| 1 | /******************************************************************************* |
| 2 | * Copyright 2016-2018 Intel Corporation |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | *******************************************************************************/ |
| 16 | |
| 17 | #include <assert.h> |
| 18 | #include <stddef.h> |
| 19 | #include <stdint.h> |
| 20 | |
| 21 | #include "mkldnn.h" |
| 22 | |
| 23 | #include "c_types_map.hpp" |
| 24 | #include "engine.hpp" |
| 25 | #include "type_helpers.hpp" |
| 26 | #include "utils.hpp" |
| 27 | |
| 28 | using namespace mkldnn::impl; |
| 29 | using namespace mkldnn::impl::utils; |
| 30 | using namespace mkldnn::impl::status; |
| 31 | using namespace mkldnn::impl::data_type; |
| 32 | |
| 33 | namespace { |
| 34 | bool memory_desc_sanity_check(int ndims,const dims_t dims, |
| 35 | data_type_t data_type, format_kind_t format_kind) { |
| 36 | if (ndims == 0) return true; |
| 37 | |
| 38 | bool ok = true |
| 39 | && dims != nullptr |
| 40 | && 0 < ndims && ndims <= MKLDNN_MAX_NDIMS |
| 41 | && one_of(data_type, f32, s32, s8, u8) |
| 42 | && format_kind != format_kind::undef; |
| 43 | if (!ok) return false; |
| 44 | for (int d = 0; d < ndims; ++d) |
| 45 | if (dims[d] < 0) return false; |
| 46 | |
| 47 | return true; |
| 48 | } |
| 49 | |
| 50 | bool memory_desc_sanity_check(const memory_desc_t *md) { |
| 51 | if (md == nullptr) return false; |
| 52 | return memory_desc_sanity_check(md->ndims, md->dims, md->data_type, |
| 53 | format_kind::any); |
| 54 | } |
| 55 | } |
| 56 | |
| 57 | status_t mkldnn_memory_desc_init_by_tag(memory_desc_t *memory_desc, int ndims, |
| 58 | const dims_t dims, data_type_t data_type, format_tag_t tag) { |
| 59 | if (any_null(memory_desc)) return invalid_arguments; |
| 60 | if (ndims == 0 || tag == format_tag::undef) { |
| 61 | *memory_desc = types::zero_md(); |
| 62 | return success; |
| 63 | } |
| 64 | |
| 65 | format_kind_t format_kind = types::format_tag_to_kind(tag); |
| 66 | |
| 67 | /* memory_desc != 0 */ |
| 68 | bool args_ok = !any_null(memory_desc) |
| 69 | && memory_desc_sanity_check(ndims, dims, data_type, format_kind); |
| 70 | if (!args_ok) return invalid_arguments; |
| 71 | |
| 72 | auto md = memory_desc_t(); |
| 73 | md.ndims = ndims; |
| 74 | array_copy(md.dims, dims, ndims); |
| 75 | md.data_type = data_type; |
| 76 | array_copy(md.padded_dims, dims, ndims); |
| 77 | md.format_kind = format_kind; |
| 78 | |
| 79 | status_t status = success; |
| 80 | if (tag == format_tag::undef) { |
| 81 | status = invalid_arguments; |
| 82 | } else if (tag == format_tag::any) { |
| 83 | // nop |
| 84 | } else if (format_kind == format_kind::blocked) { |
| 85 | status = memory_desc_wrapper::compute_blocking(md, tag); |
| 86 | } else { |
| 87 | assert(!"unreachable" ); |
| 88 | status = invalid_arguments; |
| 89 | } |
| 90 | |
| 91 | if (status == success) |
| 92 | *memory_desc = md; |
| 93 | |
| 94 | return status; |
| 95 | } |
| 96 | |
| 97 | status_t mkldnn_memory_desc_init_by_strides(memory_desc_t *memory_desc, |
| 98 | int ndims, const dims_t dims, data_type_t data_type, |
| 99 | const dims_t strides) { |
| 100 | if (any_null(memory_desc)) return invalid_arguments; |
| 101 | if (ndims == 0) { |
| 102 | *memory_desc = types::zero_md(); |
| 103 | return success; |
| 104 | } |
| 105 | |
| 106 | /* memory_desc != 0 */ |
| 107 | bool args_ok = !any_null(memory_desc) |
| 108 | && memory_desc_sanity_check(ndims, dims, data_type, format_kind::any); |
| 109 | if (!args_ok) return invalid_arguments; |
| 110 | |
| 111 | auto md = memory_desc_t(); |
| 112 | md.ndims = ndims; |
| 113 | array_copy(md.dims, dims, ndims); |
| 114 | md.data_type = data_type; |
| 115 | array_copy(md.padded_dims, dims, ndims); |
| 116 | md.format_kind = format_kind::blocked; |
| 117 | |
| 118 | dims_t default_strides = {0}; |
| 119 | if (strides == nullptr) { |
| 120 | default_strides[md.ndims - 1] = 1; |
| 121 | for (int d = md.ndims - 2; d >= 0; --d) |
| 122 | default_strides[d] = default_strides[d + 1] * md.padded_dims[d + 1]; |
| 123 | strides = default_strides; |
| 124 | } else { |
| 125 | /* TODO: add sanity check for the provided strides */ |
| 126 | } |
| 127 | |
| 128 | array_copy(md.format_desc.blocking.strides, strides, md.ndims); |
| 129 | |
| 130 | *memory_desc = md; |
| 131 | |
| 132 | return status::success; |
| 133 | } |
| 134 | |
| 135 | status_t mkldnn_memory_desc_init_submemory(memory_desc_t *md, |
| 136 | const memory_desc_t *parent_md, const dims_t dims, |
| 137 | const dims_t offsets) { |
| 138 | if (any_null(md, parent_md) || !memory_desc_sanity_check(parent_md)) |
| 139 | return invalid_arguments; |
| 140 | |
| 141 | const memory_desc_wrapper src_d(parent_md); |
| 142 | |
| 143 | for (int d = 0; d < src_d.ndims(); ++d) { |
| 144 | if (dims[d] < 0 || offsets[d] < 0 |
| 145 | || (offsets[d] + dims[d] > src_d.dims()[d])) |
| 146 | return invalid_arguments; |
| 147 | } |
| 148 | |
| 149 | if (src_d.format_kind() != format_kind::blocked) |
| 150 | return unimplemented; |
| 151 | |
| 152 | dims_t blocks; |
| 153 | src_d.compute_blocks(blocks); |
| 154 | |
| 155 | memory_desc_t dst_d = *parent_md; |
| 156 | auto &dst_d_blk = dst_d.format_desc.blocking; |
| 157 | |
| 158 | /* TODO: put this into memory_desc_wrapper */ |
| 159 | for (int d = 0; d < src_d.ndims(); ++d) { |
| 160 | /* very limited functionality for now */ |
| 161 | const bool ok = true |
| 162 | && offsets[d] % blocks[d] == 0 /* [r1] */ |
| 163 | && src_d.padded_offsets()[d] == 0 |
| 164 | && (false |
| 165 | || dims[d] % blocks[d] == 0 |
| 166 | || dims[d] < blocks[d]); |
| 167 | if (!ok) |
| 168 | return unimplemented; |
| 169 | |
| 170 | const bool is_right_border = offsets[d] + dims[d] == src_d.dims()[d]; |
| 171 | |
| 172 | dst_d.dims[d] = dims[d]; |
| 173 | dst_d.padded_dims[d] = is_right_border |
| 174 | ? src_d.padded_dims()[d] - offsets[d] : dst_d.dims[d]; |
| 175 | dst_d.padded_offsets[d] = src_d.padded_offsets()[d]; |
| 176 | dst_d.offset0 += /* [r1] */ |
| 177 | offsets[d] / blocks[d] * dst_d_blk.strides[d]; |
| 178 | } |
| 179 | |
| 180 | *md = dst_d; |
| 181 | |
| 182 | return success; |
| 183 | } |
| 184 | |
| 185 | int mkldnn_memory_desc_equal(const memory_desc_t *lhs, |
| 186 | const memory_desc_t *rhs) { |
| 187 | if (lhs == rhs) return 1; |
| 188 | if (any_null(lhs, rhs)) return 0; |
| 189 | return memory_desc_wrapper(*lhs) == memory_desc_wrapper(*rhs); |
| 190 | } |
| 191 | |
| 192 | size_t mkldnn_memory_desc_get_size(const memory_desc_t *md) { |
| 193 | if (md == nullptr) return 0; |
| 194 | return memory_desc_wrapper(*md).size(); |
| 195 | } |
| 196 | |
| 197 | status_t mkldnn_memory_create(memory_t **memory, const memory_desc_t *md, |
| 198 | engine_t *engine, void *handle) { |
| 199 | if (any_null(memory, engine)) return invalid_arguments; |
| 200 | memory_desc_t z_md = types::zero_md(); |
| 201 | return engine->memory_create(memory, md ? md : &z_md, handle); |
| 202 | } |
| 203 | |
| 204 | status_t mkldnn_memory_get_memory_desc(const memory_t *memory, |
| 205 | const memory_desc_t **md) { |
| 206 | if (any_null(memory, md)) return invalid_arguments; |
| 207 | *md = memory->md(); |
| 208 | return success; |
| 209 | } |
| 210 | |
| 211 | status_t mkldnn_memory_get_engine(const memory_t *memory, engine_t **engine) { |
| 212 | if (any_null(memory, engine)) return invalid_arguments; |
| 213 | *engine = memory->engine(); |
| 214 | return success; |
| 215 | } |
| 216 | |
| 217 | status_t mkldnn_memory_get_data_handle(const memory_t *memory, |
| 218 | void **handle) { |
| 219 | if (any_null(handle)) |
| 220 | return invalid_arguments; |
| 221 | if (memory == nullptr) { |
| 222 | *handle = nullptr; |
| 223 | return success; |
| 224 | } |
| 225 | return memory->get_data_handle(handle); |
| 226 | } |
| 227 | |
| 228 | status_t mkldnn_memory_set_data_handle(memory_t *memory, void *handle) { |
| 229 | if (any_null(memory)) return invalid_arguments; |
| 230 | return memory->set_data_handle(handle); |
| 231 | } |
| 232 | |
| 233 | status_t mkldnn_memory_destroy(memory_t *memory) { |
| 234 | delete memory; |
| 235 | return success; |
| 236 | } |
| 237 | |
| 238 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
| 239 | |