| 1 | /******************************************************************************* |
| 2 | * Copyright 2016-2018 Intel Corporation |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | *******************************************************************************/ |
| 16 | |
| 17 | #include <assert.h> |
| 18 | |
| 19 | #include <initializer_list> |
| 20 | |
| 21 | #include "c_types_map.hpp" |
| 22 | #include "memory_desc_wrapper.hpp" |
| 23 | #include "type_helpers.hpp" |
| 24 | #include "utils.hpp" |
| 25 | |
| 26 | namespace mkldnn { |
| 27 | namespace impl { |
| 28 | |
| 29 | status_t fill_blocked(memory_desc_t &md, |
| 30 | std::initializer_list<int> perm, |
| 31 | std::initializer_list<int> inner_blks, |
| 32 | std::initializer_list<int> inner_idxs) { |
| 33 | const bool ok = true |
| 34 | && perm.size() == (size_t)md.ndims |
| 35 | && inner_blks.size() == inner_idxs.size(); |
| 36 | if (!ok) return status::invalid_arguments; |
| 37 | |
| 38 | md.offset0 = 0; |
| 39 | |
| 40 | blocking_desc_t &blk = md.format_desc.blocking; |
| 41 | |
| 42 | dim_t block_size = 1; |
| 43 | dims_t blocks = {0}; |
| 44 | utils::array_set(blocks, 1, md.ndims); |
| 45 | |
| 46 | blk.inner_nblks = (int)inner_blks.size(); |
| 47 | |
| 48 | int iblk = 0; |
| 49 | for (const auto &b: inner_idxs) |
| 50 | blk.inner_idxs[iblk++] = b; |
| 51 | |
| 52 | iblk = 0; |
| 53 | for (const auto &b: inner_blks) { |
| 54 | int dim = blk.inner_idxs[iblk]; |
| 55 | block_size *= b; |
| 56 | blocks[dim] *= b; |
| 57 | blk.inner_blks[iblk++] = b; |
| 58 | } |
| 59 | |
| 60 | utils::array_set(md.padded_offsets, 0, md.ndims); |
| 61 | for (int d = 0; d < md.ndims; ++d) |
| 62 | md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]); |
| 63 | |
| 64 | dim_t stride = block_size; |
| 65 | // if only we use C++14, the initializer_list would have rbegin()/rend()... |
| 66 | for (int d = 0; d < md.ndims; ++d) |
| 67 | stride *= md.padded_dims[d] == 0 ? 1 : md.padded_dims[d] / blocks[d]; |
| 68 | |
| 69 | for (const auto &d: perm) { |
| 70 | if (md.padded_dims[d] == 0) { |
| 71 | blk.strides[d] = 1; |
| 72 | continue; |
| 73 | } |
| 74 | stride /= md.padded_dims[d] / blocks[d]; |
| 75 | blk.strides[d] = stride; |
| 76 | } |
| 77 | |
| 78 | assert(stride == block_size); |
| 79 | |
| 80 | return status::success; |
| 81 | } |
| 82 | |
| 83 | status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc, |
| 84 | format_tag_t tag) |
| 85 | { |
| 86 | using namespace format_tag; |
| 87 | |
| 88 | if (memory_desc.ndims == 0) return status::invalid_arguments; |
| 89 | |
| 90 | # define C(tag, ... /* perm, inner_blks, inner_idxs */) \ |
| 91 | case tag: return fill_blocked(memory_desc, __VA_ARGS__) |
| 92 | |
| 93 | switch (tag) { |
| 94 | C(a, {0}, {}, {}); |
| 95 | C(ab, {0, 1}, {}, {}); |
| 96 | C(abc, {0, 1, 2}, {}, {}); |
| 97 | C(abcd, {0, 1, 2, 3}, {}, {}); |
| 98 | C(abcde, {0, 1, 2, 3, 4}, {}, {}); |
| 99 | C(abcdef, {0, 1, 2, 3, 4, 5}, {}, {}); |
| 100 | C(abdec, {0, 1, 3, 4, 2}, {}, {}); |
| 101 | C(acb, {0, 2, 1}, {}, {}); |
| 102 | C(acbde, {0, 2, 1, 3, 4}, {}, {}); |
| 103 | C(acdb, {0, 2, 3, 1}, {}, {}); |
| 104 | C(acdeb, {0, 2, 3, 4, 1}, {}, {}); |
| 105 | C(ba, {1, 0}, {}, {}); |
| 106 | C(bac, {1, 0, 2}, {}, {}); |
| 107 | C(bacd, {1, 0, 2, 3}, {}, {}); |
| 108 | C(bcda, {1, 2, 3, 0}, {}, {}); |
| 109 | C(cba, {2, 1, 0}, {}, {}); |
| 110 | C(cdba, {2, 3, 1, 0}, {}, {}); |
| 111 | C(cdeba, {2, 3, 4, 1, 0}, {}, {}); |
| 112 | C(decab, {3, 4, 2, 0, 1}, {}, {}); |
| 113 | |
| 114 | C(Abc4a, {0, 1, 2}, {4}, {0}); |
| 115 | C(aBc4b, {0, 1, 2}, {4}, {1}); |
| 116 | C(ABc4b16a4b, {0, 1, 2}, {4, 16, 4}, {1, 0, 1}); |
| 117 | C(ABc4b4a, {0, 1, 2}, {4, 4}, {1, 0}); |
| 118 | C(Abcd4a, {0, 1, 2, 3}, {4}, {0}); |
| 119 | C(aBcd4b, {0, 1, 2, 3}, {4}, {1}); |
| 120 | C(ABcd4b4a, {0, 1, 2, 3}, {4, 4}, {1, 0}); |
| 121 | C(aBCd4c16b4c, {0, 1, 2, 3}, {4, 16, 4}, {2, 1, 2}); |
| 122 | C(aBCd4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1}); |
| 123 | C(Abcde4a, {0, 1, 2, 3, 4}, {4}, {0}); |
| 124 | C(aBcde4b, {0, 1, 2, 3, 4}, {4}, {1}); |
| 125 | C(ABcde4b4a, {0, 1, 2, 3, 4}, {4, 4}, {1, 0}); |
| 126 | C(aBCde4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1}); |
| 127 | C(aBcdef4b, {0, 1, 2, 3, 4, 5}, {4}, {1}); |
| 128 | C(aBCdef4c4b, {0, 1, 2, 3, 4, 5}, {4, 4}, {2, 1}); |
| 129 | C(aBdc4b, {0, 1, 3, 2}, {4}, {1}); |
| 130 | C(aBdec4b, {0, 1, 3, 4, 2}, {4}, {1}); |
| 131 | C(aBdefc4b, {0, 1, 3, 4, 5, 2}, {4}, {1}); |
| 132 | C(Acb4a, {0, 2, 1}, {4}, {0}); |
| 133 | C(Acdb4a, {0, 2, 3, 1}, {4}, {0}); |
| 134 | C(Acdeb4a, {0, 2, 3, 4, 1}, {4}, {0}); |
| 135 | |
| 136 | C(Abc16a, {0, 1, 2}, {16}, {0}); |
| 137 | C(ABc16a16b, {0, 1, 2}, {16, 16}, {0, 1}); |
| 138 | C(aBc16b, {0, 1, 2}, {16}, {1}); |
| 139 | C(ABc16b16a, {0, 1, 2}, {16, 16}, {1, 0}); |
| 140 | C(ABc8a16b2a, {0, 1, 2}, {8, 16, 2}, {0, 1, 0}); |
| 141 | C(ABc8a8b, {0, 1, 2}, {8, 8}, {0, 1}); |
| 142 | C(aBc8b, {0, 1, 2}, {8}, {1}); |
| 143 | C(ABc8b16a2b, {0, 1, 2}, {8, 16, 2}, {1, 0, 1}); |
| 144 | C(ABc8b8a, {0, 1, 2}, {8, 8}, {1, 0}); |
| 145 | C(Abcd16a, {0, 1, 2, 3}, {16}, {0}); |
| 146 | C(ABcd16a16b, {0, 1, 2, 3}, {16, 16}, {0, 1}); |
| 147 | C(aBcd16b, {0, 1, 2, 3}, {16}, {1}); |
| 148 | C(ABcd16b16a, {0, 1, 2, 3}, {16, 16}, {1, 0}); |
| 149 | C(aBCd16b16c, {0, 1, 2, 3}, {16, 16}, {1, 2}); |
| 150 | C(aBCd16c16b, {0, 1, 2, 3}, {16, 16}, {2, 1}); |
| 151 | C(ABcd4b16a4b, {0, 1, 2, 3}, {4, 16, 4}, {1, 0, 1}); |
| 152 | C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0}); |
| 153 | C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1}); |
| 154 | C(aBcd8b, {0, 1, 2, 3}, {8}, {1}); |
| 155 | C(ABcd8b16a2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 0, 1}); |
| 156 | C(aBCd8b16c2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 2, 1}); |
| 157 | C(ABcd8b8a, {0, 1, 2, 3}, {8, 8}, {1, 0}); |
| 158 | C(aBCd8b8c, {0, 1, 2, 3}, {8, 8}, {1, 2}); |
| 159 | C(aBCd8c16b2c, {0, 1, 2, 3}, {8, 16, 2}, {2, 1, 2}); |
| 160 | C(aBCd8c8b, {0, 1, 2, 3}, {8, 8}, {2, 1}); |
| 161 | C(Abcde16a, {0, 1, 2, 3, 4}, {16}, {0}); |
| 162 | C(ABcde16a16b, {0, 1, 2, 3, 4}, {16, 16}, {0, 1}); |
| 163 | C(aBcde16b, {0, 1, 2, 3, 4}, {16}, {1}); |
| 164 | C(ABcde16b16a, {0, 1, 2, 3, 4}, {16, 16}, {1, 0}); |
| 165 | C(aBCde16b16c, {0, 1, 2, 3, 4}, {16, 16}, {1, 2}); |
| 166 | C(aBCde16c16b, {0, 1, 2, 3, 4}, {16, 16}, {2, 1}); |
| 167 | C(aBCde2c8b4c, {0, 1, 2, 3, 4}, {2, 8, 4}, {2, 1, 2}); |
| 168 | C(aBCde4b4c, {0, 1, 2, 3, 4}, {4, 4}, {1, 2}); |
| 169 | C(aBCde4c16b4c, {0, 1, 2, 3, 4}, {4, 16, 4}, {2, 1, 2}); |
| 170 | C(Abcde8a, {0, 1, 2, 3, 4}, {8}, {0}); |
| 171 | C(ABcde8a8b, {0, 1, 2, 3, 4}, {8, 8}, {0, 1}); |
| 172 | C(aBcde8b, {0, 1, 2, 3, 4}, {8}, {1}); |
| 173 | C(ABcde8b16a2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 0, 1}); |
| 174 | C(aBCde8b16c2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 2, 1}); |
| 175 | C(ABcde8b8a, {0, 1, 2, 3, 4}, {8, 8}, {1, 0}); |
| 176 | C(aBCde8b8c, {0, 1, 2, 3, 4}, {8, 8}, {1, 2}); |
| 177 | C(aBCde8c16b2c, {0, 1, 2, 3, 4}, {8, 16, 2}, {2, 1, 2}); |
| 178 | C(aBCde8c8b, {0, 1, 2, 3, 4}, {8, 8}, {2, 1}); |
| 179 | C(aBcdef16b, {0, 1, 2, 3, 4, 5}, {16}, {1}); |
| 180 | C(aBCdef16b16c, {0, 1, 2, 3, 4, 5}, {16, 16}, {1, 2}); |
| 181 | C(aBCdef16c16b, {0, 1, 2, 3, 4, 5}, {16, 16}, {2, 1}); |
| 182 | C(aBCdef8b8c, {0, 1, 2, 3, 4, 5}, {8, 8}, {1, 2}); |
| 183 | C(aBCdef8c16b2c, {0, 1, 2, 3, 4, 5}, {8, 16, 2}, {2, 1, 2}); |
| 184 | C(aBCdef8c8b, {0, 1, 2, 3, 4, 5}, {8, 8}, {2, 1}); |
| 185 | C(aBdc16b, {0, 1, 3, 2}, {16}, {1}); |
| 186 | C(aBdc8b, {0, 1, 3, 2}, {8}, {1}); |
| 187 | C(aBdec16b, {0, 1, 3, 4, 2}, {16}, {1}); |
| 188 | C(aBdec8b, {0, 1, 3, 4, 2}, {8}, {1}); |
| 189 | C(aBdefc16b, {0, 1, 3, 4, 5, 2}, {16}, {1}); |
| 190 | C(aBdefc8b, {0, 1, 3, 4, 5, 2}, {8}, {1}); |
| 191 | C(Acb16a, {0, 2, 1}, {16}, {0}); |
| 192 | C(Acb8a, {0, 2, 1}, {8}, {0}); |
| 193 | C(aCBd16b16c, {0, 2, 1, 3}, {16, 16}, {1, 2}); |
| 194 | C(aCBde16b16c, {0, 2, 1, 3, 4}, {16, 16}, {1, 2}); |
| 195 | C(Acdb16a, {0, 2, 3, 1}, {16}, {0}); |
| 196 | C(Acdb8a, {0, 2, 3, 1}, {8}, {0}); |
| 197 | C(Acdeb16a, {0, 2, 3, 4, 1}, {16}, {0}); |
| 198 | C(Acdeb8a, {0, 2, 3, 4, 1}, {8}, {0}); |
| 199 | C(BAc16a16b, {1, 0, 2}, {16, 16}, {0, 1}); |
| 200 | C(BAcd16a16b, {1, 0, 2, 3}, {16, 16}, {0, 1}); |
| 201 | default: break; |
| 202 | } |
| 203 | |
| 204 | #undef C |
| 205 | |
| 206 | return status::invalid_arguments; |
| 207 | } |
| 208 | |
| 209 | } |
| 210 | } |
| 211 | |
| 212 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
| 213 | |