| 1 | /******************************************************************************* |
| 2 | * Copyright 2018 Intel Corporation |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | *******************************************************************************/ |
| 16 | |
| 17 | #include <assert.h> |
| 18 | |
| 19 | #include "c_types_map.hpp" |
| 20 | #include "memory_desc_wrapper.hpp" |
| 21 | #include "mkldnn_debug.h" |
| 22 | #include "nstl.hpp" |
| 23 | #include "type_helpers.hpp" |
| 24 | #include "utils.hpp" |
| 25 | |
| 26 | #include "jit_uni_reorder.hpp" |
| 27 | |
| 28 | using namespace mkldnn::impl::types; |
| 29 | using namespace mkldnn::impl::status; |
| 30 | |
| 31 | namespace mkldnn { |
| 32 | namespace impl { |
| 33 | namespace cpu { |
| 34 | |
| 35 | namespace tr { |
| 36 | |
| 37 | /** ad-hoc structure to describe blocked memory layout */ |
/** ad-hoc structure to describe blocked memory layout */
struct layout_desc_t {
    data_type_t dt; // data type of the tensor elements
    int ndims; // number of valid entries in id/dims/strides below
    dims_t id; // logical dimension index each entry belongs to
    dims_t dims; // size of each entry (outer dim count or inner block size)
    strides_t strides; // stride of each entry, in elements
};
| 45 | |
/** Flattens the blocked memory descriptor @p md_ into layout description
 *  @p ld: one (id, dim, stride) entry per logical dimension, plus one extra
 *  entry per inner block of that dimension. For a blocked dimension the
 *  entries are ordered from the outermost level to the innermost one.
 *  @returns invalid_arguments if @p md_ is not a plain blocking descriptor
 *           or carries extra flags (e.g. compensation), success otherwise. */
status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_,
        layout_desc_t &ld) {
    const auto md = memory_desc_wrapper(md_);

    // only plain blocked layouts without extra metadata are supported
    bool ok = true
        && md.is_blocking_desc()
        && md.extra().flags == 0;
    if (!ok) return invalid_arguments;

    const auto &bd = md.blocking_desc();

    ld.ndims = 0;
    ld.dt = md.data_type();

    // helper: append one (logical id, dim, stride) entry to ld
    auto P = [&ld](int id, int dim, ptrdiff_t stride) {
        assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0]));
        ld.id[ld.ndims] = id;
        ld.dims[ld.ndims] = dim;
        ld.strides[ld.ndims] = stride;
        ++ld.ndims;
    };

    dims_t blocks;
    md.compute_blocks(blocks);

    for (int d = 0; d < md.ndims(); ++d) {
        const int ld_ndims_start = ld.ndims;
        if (blocks[d] != 1) {
            // emit one entry per inner block of dimension d; innermost block
            // has stride 1, each outer level's stride is the product of the
            // inner block sizes below it
            stride_t stride = 1;
            for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) {
                if (bd.inner_idxs[iblk] == d)
                    P(d, bd.inner_blks[iblk], stride);
                stride *= bd.inner_blks[iblk];
            }
        }
        // the outer (non-blocked) part of dimension d; padded_dims is
        // divisible by blocks[d] (checked by the caller in prb_init)
        P(d, md.padded_dims()[d] / blocks[d], bd.strides[d]);

        // TODO: NOW: revisit, do we need a reverse?
        // TODO: NOW: consider using strides instead of block sizes in md
        // reverse the order of dims
        // (entries above were appended innermost-first; flip the entries of
        // this dimension so they read outermost-first)
        for (int ld_d = 0; ld_d < (ld.ndims - ld_ndims_start) / 2; ++ld_d) {
            const int idx0 = ld_ndims_start + ld_d;
            const int idx1 = ld.ndims - 1 - ld_d;
            nstl::swap(ld.dims[idx0], ld.dims[idx1]);
            nstl::swap(ld.strides[idx0], ld.strides[idx1]);
        }
    }

    return success;
}
| 96 | |
/** Initializes the reorder problem descriptor @p p from the input memory
 *  descriptor @p imd, output memory descriptor @p omd, and attributes
 *  @p attr. Both descriptors are flattened to layout descriptions, then
 *  merged into a single list of nodes, splitting dimensions where the input
 *  and output block sizes disagree.
 *  @returns unimplemented for unsupported cases, runtime_error on internal
 *           inconsistency, success otherwise. */
status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
        const primitive_attr_t *attr) {
    auto im_d = memory_desc_wrapper(imd);
    auto om_d = memory_desc_wrapper(omd);

    bool ok = true
        && im_d.is_blocking_desc()
        && om_d.is_blocking_desc()
        && !im_d.has_zero_dim()
        && !om_d.has_zero_dim();
    if (!ok)
        return unimplemented;

    dims_t iblocks, oblocks;
    im_d.compute_blocks(iblocks);
    om_d.compute_blocks(oblocks);

    /* padding_dim consistency check */
    // padded sizes must match and be divisible by both sides' blocks, so
    // that the merge loop below can always split one dim by the other
    for (int d = 0; d < im_d.ndims(); ++d) {
        const auto pdim = im_d.padded_dims()[d];
        bool ok = true
            && pdim == om_d.padded_dims()[d]
            && pdim % iblocks[d] == 0
            && pdim % oblocks[d] == 0;
        if (!ok) return unimplemented;
    }

    // ild/old: flattened input/output layout descriptions
    layout_desc_t ild, old;
    status_t status = cvt_mem_desc_to_layout_desc(imd, ild);
    if (status != success) return status;
    status = cvt_mem_desc_to_layout_desc(omd, old);
    if (status != success) return status;

    p.itype = ild.dt;
    p.otype = old.dt;

    // NONE: no scaling; COMMON: single scale; MANY: per-dim scales (mask)
    p.scale_type = attr->output_scales_.has_default_values()
        ? scale_type_t::NONE
        : (attr->output_scales_.mask_ == 0
                ? scale_type_t::COMMON
                : scale_type_t::MANY);

    // ss[d]: stride into the scales array for output layout entry d;
    // computed innermost-to-outermost over the dims selected by the mask
    ptrdiff_t ss[max_ndims] = {0};
    if (p.scale_type == scale_type_t::MANY) {
        ptrdiff_t last_ss = 1;
        for (int d = old.ndims - 1; d >=0; --d) {
            assert((d == 0 || old.id[d - 1] <= old.id[d])
                    && "logical dimensions should be in ascending order");
            if (attr->output_scales_.mask_ & (1 << old.id[d])) {
                ss[d] = last_ss;
                last_ss *= old.dims[d];
            }
        }
    }

    int ndims = 0;

    int i_pos = 0; /* state for input -- current dimension */
    int o_pos = 0; /* state for output -- current dimension */

    // merge the two layout descriptions into p.nodes: when the current input
    // and output entries have equal sizes, emit one node; otherwise split the
    // larger entry by the smaller one and keep the remainder for next round
    while (i_pos < ild.ndims && o_pos < old.ndims) {
        assert(ild.id[i_pos] == old.id[o_pos]);
        if (ild.id[i_pos] != old.id[o_pos])
            return runtime_error;

        assert(ndims < max_ndims);
        if (ndims == max_ndims)
            return runtime_error;

        if (ild.dims[i_pos] == old.dims[o_pos]) {
            // sizes match: one node covering both entries
            p.nodes[ndims].n = ild.dims[i_pos];
            p.nodes[ndims].is = ild.strides[i_pos];
            p.nodes[ndims].os = old.strides[o_pos];
            p.nodes[ndims].ss = ss[o_pos];
            ++ndims;
            ++i_pos;
            ++o_pos;
        } else if (ild.dims[i_pos] < old.dims[o_pos]) {
            // output entry is larger: split it by the input size; the node's
            // output stride/scale stride account for the remaining factor
            assert(old.dims[o_pos] % ild.dims[i_pos] == 0);
            int factor = old.dims[o_pos] / ild.dims[i_pos];
            p.nodes[ndims].n = ild.dims[i_pos];
            p.nodes[ndims].is = ild.strides[i_pos];
            p.nodes[ndims].os = old.strides[o_pos] * factor;
            p.nodes[ndims].ss = ss[o_pos] * factor;
            ++ndims;
            ++i_pos;
            old.dims[o_pos] = factor; // leftover for the next iteration
        } else if (ild.dims[i_pos] > old.dims[o_pos]) {
            // input entry is larger: symmetric case of the branch above
            assert(ild.dims[i_pos] % old.dims[o_pos] == 0);
            int factor = ild.dims[i_pos] / old.dims[o_pos];
            p.nodes[ndims].n = old.dims[o_pos];
            p.nodes[ndims].is = ild.strides[i_pos] * factor;
            p.nodes[ndims].os = old.strides[o_pos];
            p.nodes[ndims].ss = ss[o_pos];
            ++ndims;
            ++o_pos;
            ild.dims[i_pos] = factor; // leftover for the next iteration
        }
    }
    p.ndims = ndims;

    // element offsets of the logical (0, ..., 0) point in each memory
    dims_t zero_pos = {0};
    p.ioff = memory_desc_wrapper(imd).off_v(zero_pos);
    p.ooff = memory_desc_wrapper(omd).off_v(zero_pos);

    // beta != 0 means accumulate into the destination (sum post-op)
    const int sum_idx = attr->post_ops_.find(primitive_kind::sum);
    p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale;

    return success;
}
| 207 | |
| 208 | void prb_normalize(prb_t &p) { |
| 209 | for (int d = 0; d < p.ndims; ++d) { |
| 210 | int min_pos = d; |
| 211 | for (int j = d + 1; j < p.ndims; ++j) { |
| 212 | bool new_min = false |
| 213 | || p.nodes[j].os < p.nodes[min_pos].os |
| 214 | || (true |
| 215 | && p.nodes[j].os == p.nodes[min_pos].os |
| 216 | && p.nodes[j].n < p.nodes[min_pos].n); |
| 217 | if (new_min) min_pos = j; |
| 218 | } |
| 219 | if (min_pos != d) |
| 220 | nstl::swap(p.nodes[d], p.nodes[min_pos]); |
| 221 | } |
| 222 | } |
| 223 | |
/** Reduces the number of nodes in @p p by dropping trivial (size-1) nodes
 *  and folding adjacent nodes whose strides are exactly contiguous, i.e.
 *  the next node's strides equal this node's strides times this node's
 *  size. Assumes nodes are already ordered (see prb_normalize). */
void prb_simplify(prb_t &p) {
#if defined(__GNUC__) && __GNUC__ >= 4
    /* GCC produces bogus array subscript is above array bounds warning for
     * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
#endif
    for (int d = 0; d < p.ndims - 1; ++d) {
        auto &this_node = p.nodes[d + 0];
        auto &next_node = p.nodes[d + 1];
        const bool fold = false
            || next_node.n == (size_t)1 // trivial case, just drop next node
            || (true // or real folding if possible
                    && next_node.is == (ptrdiff_t)this_node.n * this_node.is
                    && next_node.os == (ptrdiff_t)this_node.n * this_node.os
                    && next_node.ss == (ptrdiff_t)this_node.n * this_node.ss);
        if (fold) {
            // merge next into this, shift the tail left, retry at same d
            this_node.n *= next_node.n;
            for (int j = d + 2; j < p.ndims; ++j)
                p.nodes[j - 1] = p.nodes[j];
            --p.ndims;
            --d; // make another try
        }
    }
#if defined(__GNUC__) && __GNUC__ >= 4
#pragma GCC diagnostic pop
#endif
}
| 252 | |
| 253 | void prb_node_split(prb_t &p, int dim, size_t n1) { |
| 254 | assert(dim < p.ndims); |
| 255 | assert(p.ndims < max_ndims); |
| 256 | assert(p.nodes[dim].n % n1 == 0); |
| 257 | |
| 258 | p.ndims += 1; |
| 259 | |
| 260 | for (int d = p.ndims; d > dim + 1; --d) |
| 261 | p.nodes[d] = p.nodes[d - 1]; |
| 262 | |
| 263 | p.nodes[dim + 1].n = p.nodes[dim].n / n1; |
| 264 | p.nodes[dim + 1].is = p.nodes[dim].is * n1; |
| 265 | p.nodes[dim + 1].os = p.nodes[dim].os * n1; |
| 266 | p.nodes[dim + 1].ss = p.nodes[dim].ss * n1; |
| 267 | |
| 268 | p.nodes[dim].n = n1; |
| 269 | } |
| 270 | |
| 271 | void prb_node_swap(prb_t &p, int d0, int d1) { |
| 272 | assert(d0 < p.ndims); |
| 273 | assert(d1 < p.ndims); |
| 274 | assert(p.ndims < max_ndims); |
| 275 | |
| 276 | if (d0 == d1) return; |
| 277 | |
| 278 | nstl::swap(p.nodes[d0], p.nodes[d1]); |
| 279 | } |
| 280 | |
| 281 | void prb_node_move(prb_t &p, int d0, int d1) { |
| 282 | assert(d0 < p.ndims); |
| 283 | assert(d1 < p.ndims); |
| 284 | assert(p.ndims < max_ndims); |
| 285 | |
| 286 | if (d0 == d1) return; |
| 287 | |
| 288 | node_t node = p.nodes[d0]; |
| 289 | |
| 290 | if (d0 < d1) |
| 291 | for (int d = d0; d < d1; ++d) |
| 292 | p.nodes[d] = p.nodes[d + 1]; |
| 293 | else |
| 294 | for (int d = d0; d > d1; --d) |
| 295 | p.nodes[d] = p.nodes[d - 1]; |
| 296 | |
| 297 | p.nodes[d1] = node; |
| 298 | } |
| 299 | |
| 300 | void prb_dump(const prb_t &p) { |
| 301 | printf("@@@ type:%s:%s ndims:%d " , mkldnn_dt2str(p.itype), |
| 302 | mkldnn_dt2str(p.otype), p.ndims); |
| 303 | for (int d = 0; d < p.ndims; ++d) |
| 304 | printf("[%zu:%td:%td:%td]" , |
| 305 | p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, p.nodes[d].ss); |
| 306 | printf(" off:%zu:%zu\n" , p.ioff, p.ooff); |
| 307 | } |
| 308 | |
| 309 | } |
| 310 | |
| 311 | } |
| 312 | } |
| 313 | } |
| 314 | |