| 1 | /******************************************************************************* |
| 2 | * Copyright 2018 Intel Corporation |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | *******************************************************************************/ |
| 16 | |
| 17 | #include "mkldnn.h" |
| 18 | |
| 19 | #include "c_types_map.hpp" |
| 20 | #include "type_helpers.hpp" |
| 21 | #include "utils.hpp" |
| 22 | #include "cpu/gemm/os_blas.hpp" |
| 23 | |
| 24 | using namespace mkldnn::impl; |
| 25 | using namespace mkldnn::impl::status; |
| 26 | using namespace mkldnn::impl::types; |
| 27 | using namespace mkldnn::impl::utils; |
| 28 | |
| 29 | namespace { |
| 30 | memory_desc_t copy_maybe_null(const memory_desc_t *md) { |
| 31 | return md ? *md : zero_md(); |
| 32 | } |
| 33 | |
| 34 | rnn_desc_t zero_rnn_desc() { |
| 35 | auto rd = rnn_desc_t(); |
| 36 | rd.src_layer_desc = zero_md(); |
| 37 | rd.src_iter_desc = zero_md(); |
| 38 | rd.weights_layer_desc = zero_md(); |
| 39 | rd.weights_iter_desc = zero_md(); |
| 40 | rd.bias_desc = zero_md(); |
| 41 | rd.dst_layer_desc = zero_md(); |
| 42 | rd.dst_iter_desc = zero_md(); |
| 43 | rd.diff_src_layer_desc = zero_md(); |
| 44 | rd.diff_src_iter_desc = zero_md(); |
| 45 | rd.diff_weights_layer_desc = zero_md(); |
| 46 | rd.diff_weights_iter_desc = zero_md(); |
| 47 | rd.diff_bias_desc = zero_md(); |
| 48 | rd.diff_dst_layer_desc = zero_md(); |
| 49 | rd.diff_dst_iter_desc = zero_md(); |
| 50 | return rd; |
| 51 | } |
| 52 | } |
| 53 | |
| 54 | /* Public C Api */ |
| 55 | |
| 56 | status_t mkldnn_rnn_cell_desc_init(rnn_cell_desc_t *rnn_cell_desc, |
| 57 | mkldnn_alg_kind_t cell_kind, mkldnn_alg_kind_t act_f, |
| 58 | unsigned int flags, float alpha, float clipping) { |
| 59 | using namespace mkldnn::impl::alg_kind; |
| 60 | |
| 61 | bool args_ok = true |
| 62 | && one_of(cell_kind, vanilla_rnn, vanilla_lstm, vanilla_gru, |
| 63 | gru_linear_before_reset) |
| 64 | && IMPLICATION(cell_kind == vanilla_rnn, |
| 65 | one_of(act_f, eltwise_relu, eltwise_tanh, eltwise_logistic)); |
| 66 | if (!args_ok) |
| 67 | return invalid_arguments; |
| 68 | |
| 69 | auto rcd = mkldnn_rnn_cell_desc_t(); |
| 70 | |
| 71 | rcd.cell_kind = cell_kind; |
| 72 | rcd.activation_kind = act_f; |
| 73 | rcd.flags = flags; |
| 74 | rcd.alpha = rcd.flags & mkldnn_rnn_cell_with_relu ? alpha : 0; |
| 75 | rcd.clipping = rcd.flags & mkldnn_rnn_cell_with_clipping ? clipping : 0; |
| 76 | |
| 77 | *rnn_cell_desc = rcd; |
| 78 | |
| 79 | return success; |
| 80 | } |
| 81 | |
| 82 | int mkldnn_rnn_cell_get_gates_count(const rnn_cell_desc_t *rnn_cell_desc) { |
| 83 | switch (rnn_cell_desc->cell_kind) { |
| 84 | case mkldnn::impl::alg_kind::vanilla_rnn: return 1; |
| 85 | case mkldnn::impl::alg_kind::vanilla_gru: return 3; |
| 86 | case mkldnn::impl::alg_kind::gru_linear_before_reset: return 3; |
| 87 | case mkldnn::impl::alg_kind::vanilla_lstm: return 4; |
| 88 | default: assert(!"unknown cell kind" ); return 0; |
| 89 | } |
| 90 | return 0; |
| 91 | } |
| 92 | |
| 93 | int mkldnn_rnn_cell_get_states_count(const rnn_cell_desc_t *rnn_cell_desc) { |
| 94 | switch (rnn_cell_desc->cell_kind) { |
| 95 | case mkldnn::impl::alg_kind::vanilla_rnn: return 1; |
| 96 | case mkldnn::impl::alg_kind::vanilla_gru: return 1; |
| 97 | case mkldnn::impl::alg_kind::gru_linear_before_reset: return 1; |
| 98 | case mkldnn::impl::alg_kind::vanilla_lstm: return 2; |
| 99 | default: assert(!"unknown cell kind" ); return 0; |
| 100 | } |
| 101 | return 0; |
| 102 | } |
| 103 | |
| 104 | status_t check_data_type_consistency_fwd(const rnn_cell_desc_t *rnn_cell_desc, |
| 105 | prop_kind_t prop_kind, const memory_desc_t *src_layer_desc, |
| 106 | const memory_desc_t *src_iter_desc, |
| 107 | const memory_desc_t *weights_layer_desc, |
| 108 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
| 109 | const memory_desc_t *dst_layer_desc, |
| 110 | const memory_desc_t *dst_iter_desc) { |
| 111 | using namespace data_type; |
| 112 | data_type_t src_layer_dt = src_layer_desc->data_type; |
| 113 | data_type_t dst_layer_dt = dst_layer_desc->data_type; |
| 114 | data_type_t weights_iter_dt = weights_iter_desc->data_type; |
| 115 | data_type_t weights_layer_dt = weights_layer_desc->data_type; |
| 116 | |
| 117 | bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt, weights_iter_dt, |
| 118 | weights_layer_dt) |
| 119 | && IMPLICATION(!is_zero_md(src_iter_desc), |
| 120 | src_iter_desc->data_type == f32) |
| 121 | && IMPLICATION(!is_zero_md(dst_iter_desc), |
| 122 | dst_iter_desc->data_type == f32) |
| 123 | && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); |
| 124 | |
| 125 | #if USE_MKL_PACKED_GEMM |
| 126 | bool is_u8u8u8 = src_layer_dt == u8 |
| 127 | && IMPLICATION(!is_zero_md(src_iter_desc), |
| 128 | src_iter_desc->data_type == u8) |
| 129 | && IMPLICATION(!is_zero_md(dst_iter_desc), |
| 130 | dst_iter_desc->data_type == u8) |
| 131 | && one_of(dst_layer_dt, u8, f32) |
| 132 | && everyone_is(s8, weights_iter_dt, weights_layer_dt) |
| 133 | && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); |
| 134 | |
| 135 | bool is_f32u8f32 = src_layer_dt == u8 |
| 136 | && IMPLICATION(!is_zero_md(src_iter_desc), |
| 137 | src_iter_desc->data_type == f32) |
| 138 | && IMPLICATION(!is_zero_md(dst_iter_desc), |
| 139 | dst_iter_desc->data_type == f32) |
| 140 | && one_of(dst_layer_dt, u8, f32) |
| 141 | && everyone_is(s8, weights_iter_dt, weights_layer_dt) |
| 142 | && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); |
| 143 | |
| 144 | bool is_inference = prop_kind == prop_kind::forward_inference; |
| 145 | bool is_lstm = rnn_cell_desc->cell_kind == mkldnn_vanilla_lstm; |
| 146 | |
| 147 | return (is_f32 || ((is_u8u8u8 || is_f32u8f32) && is_lstm && is_inference)) |
| 148 | ? success |
| 149 | : unimplemented; |
| 150 | #else |
| 151 | return is_f32 ? success : unimplemented; |
| 152 | #endif |
| 153 | } |
| 154 | |
| 155 | status_t check_dim_consistency(const rnn_cell_desc_t *rnn_cell_desc, |
| 156 | rnn_direction_t direction, int L, int D, int T, int N, int S, int G, |
| 157 | int SLC, int SIC, int DLC, int DIC, const memory_desc_t *src_layer_desc, |
| 158 | const memory_desc_t *src_iter_desc, |
| 159 | const memory_desc_t *weights_layer_desc, |
| 160 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
| 161 | const memory_desc_t *dst_layer_desc, |
| 162 | const memory_desc_t *dst_iter_desc) { |
| 163 | bool args_ok; |
| 164 | |
| 165 | // * algorithm specific |
| 166 | args_ok = true |
| 167 | && IMPLICATION(rnn_cell_desc->cell_kind == alg_kind::vanilla_gru, |
| 168 | DIC == SIC); |
| 169 | if (!args_ok) return invalid_arguments; |
| 170 | int = |
| 171 | rnn_cell_desc->cell_kind == alg_kind::gru_linear_before_reset; |
| 172 | |
| 173 | // * on num layers |
| 174 | args_ok = true |
| 175 | && L == weights_layer_desc->dims[0] |
| 176 | && L == weights_iter_desc->dims[0] |
| 177 | && IMPLICATION(!is_zero_md(bias_desc), L == bias_desc->dims[0]) |
| 178 | && IMPLICATION(!is_zero_md(src_iter_desc), L == src_iter_desc->dims[0]) |
| 179 | && IMPLICATION(!is_zero_md(dst_iter_desc), L == dst_iter_desc->dims[0]); |
| 180 | if (!args_ok) return invalid_arguments; |
| 181 | |
| 182 | // * on num directions |
| 183 | args_ok = true |
| 184 | && D == weights_layer_desc->dims[1] |
| 185 | && D == weights_iter_desc->dims[1] |
| 186 | && IMPLICATION(!is_zero_md(bias_desc), D == bias_desc->dims[1]) |
| 187 | && IMPLICATION(!is_zero_md(src_iter_desc), D == src_iter_desc->dims[1]) |
| 188 | && IMPLICATION(!is_zero_md(dst_iter_desc), D == dst_iter_desc->dims[1]); |
| 189 | if (!args_ok) return invalid_arguments; |
| 190 | |
| 191 | // * on num iterations |
| 192 | args_ok = true |
| 193 | && T == src_layer_desc->dims[0] |
| 194 | && T == dst_layer_desc->dims[0]; |
| 195 | if (!args_ok) return invalid_arguments; |
| 196 | |
| 197 | // * on mb |
| 198 | args_ok = true |
| 199 | && N == src_layer_desc->dims[1] |
| 200 | && N == dst_layer_desc->dims[1] |
| 201 | && IMPLICATION(!is_zero_md(src_iter_desc), N == src_iter_desc->dims[3]) |
| 202 | && IMPLICATION(!is_zero_md(dst_iter_desc), N == dst_iter_desc->dims[3]); |
| 203 | if (!args_ok) return invalid_arguments; |
| 204 | |
| 205 | // * on num gates |
| 206 | args_ok = true |
| 207 | && G == mkldnn_rnn_cell_get_gates_count(rnn_cell_desc) |
| 208 | && G == weights_layer_desc->dims[3] |
| 209 | && G == weights_iter_desc->dims[3] |
| 210 | && IMPLICATION(!is_zero_md(bias_desc), |
| 211 | G + extra_bias == bias_desc->dims[2]); |
| 212 | if (!args_ok) return invalid_arguments; |
| 213 | |
| 214 | // * on num states |
| 215 | args_ok = true |
| 216 | && S == mkldnn_rnn_cell_get_states_count(rnn_cell_desc) |
| 217 | && IMPLICATION(!is_zero_md(src_iter_desc), S == src_iter_desc->dims[2]) |
| 218 | && IMPLICATION(!is_zero_md(dst_iter_desc), S == dst_iter_desc->dims[2]); |
| 219 | if (!args_ok) return invalid_arguments; |
| 220 | |
| 221 | // * on slc |
| 222 | args_ok = true |
| 223 | && SLC == weights_layer_desc->dims[2] |
| 224 | && SLC == src_layer_desc->dims[2]; |
| 225 | if (!args_ok) return invalid_arguments; |
| 226 | |
| 227 | // * on sic |
| 228 | args_ok = true |
| 229 | && SIC == weights_iter_desc->dims[2] |
| 230 | && IMPLICATION(!is_zero_md(src_iter_desc), |
| 231 | SIC == src_iter_desc->dims[4]); |
| 232 | if (!args_ok) return invalid_arguments; |
| 233 | |
| 234 | // * on dlc |
| 235 | int dlc_multiplier = (direction == mkldnn_bidirectional_concat) ? 2 : 1; |
| 236 | args_ok = true |
| 237 | && DLC == dlc_multiplier * DIC |
| 238 | && DLC == dst_layer_desc->dims[2]; |
| 239 | if (!args_ok) return invalid_arguments; |
| 240 | |
| 241 | // * on dic |
| 242 | args_ok = true |
| 243 | && DIC == weights_layer_desc->dims[4] |
| 244 | && DIC == weights_iter_desc->dims[4] |
| 245 | && IMPLICATION(!is_zero_md(bias_desc), DIC == bias_desc->dims[3]) |
| 246 | && IMPLICATION(!is_zero_md(dst_iter_desc), |
| 247 | DIC == dst_iter_desc->dims[4]); |
| 248 | if (!args_ok) return invalid_arguments; |
| 249 | |
| 250 | // * unrolling/fusion conditions |
| 251 | args_ok = true |
| 252 | && IMPLICATION(L > 1, (dlc_multiplier * SLC) == DLC) |
| 253 | && IMPLICATION(T > 1, SIC == DIC); |
| 254 | if (!args_ok) return invalid_arguments; |
| 255 | |
| 256 | return success; |
| 257 | } |
| 258 | |
| 259 | status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc, |
| 260 | prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, |
| 261 | const rnn_direction_t direction, const memory_desc_t *src_layer_desc, |
| 262 | const memory_desc_t *src_iter_desc, |
| 263 | const memory_desc_t *weights_layer_desc, |
| 264 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
| 265 | const memory_desc_t *dst_layer_desc, |
| 266 | const memory_desc_t *dst_iter_desc) { |
| 267 | bool args_ok = true && rnn_cell_desc != nullptr |
| 268 | && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, |
| 269 | dst_layer_desc); |
| 270 | if (!args_ok) return invalid_arguments; |
| 271 | |
| 272 | //check dimensions consistency |
| 273 | int L = weights_layer_desc->dims[0]; |
| 274 | int T = src_layer_desc->dims[0]; |
| 275 | int N = src_layer_desc->dims[1]; |
| 276 | const int D = one_of(direction, mkldnn_unidirectional_left2right, |
| 277 | mkldnn_unidirectional_right2left) ? |
| 278 | 1 : |
| 279 | 2; |
| 280 | int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc); |
| 281 | int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc); |
| 282 | int SLC = src_layer_desc->dims[2]; |
| 283 | int SIC = weights_iter_desc->dims[2]; |
| 284 | int DLC = dst_layer_desc->dims[2]; |
| 285 | int DIC = weights_layer_desc->dims[4]; |
| 286 | |
| 287 | CHECK(check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, |
| 288 | G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc, |
| 289 | weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc, |
| 290 | dst_iter_desc)); |
| 291 | |
| 292 | CHECK(check_data_type_consistency_fwd(rnn_cell_desc, prop_kind, |
| 293 | src_layer_desc, src_iter_desc, weights_layer_desc, |
| 294 | weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc)); |
| 295 | |
| 296 | // Create the descriptor |
| 297 | mkldnn_rnn_desc_t rd = zero_rnn_desc(); |
| 298 | |
| 299 | rd.primitive_kind = primitive_kind::rnn; |
| 300 | rd.prop_kind = prop_kind; |
| 301 | rd.cell_desc = *rnn_cell_desc; |
| 302 | rd.direction = direction; |
| 303 | rd.src_layer_desc = copy_maybe_null(src_layer_desc); |
| 304 | rd.src_iter_desc = copy_maybe_null(src_iter_desc); |
| 305 | rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); |
| 306 | rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); |
| 307 | rd.bias_desc = copy_maybe_null(bias_desc); |
| 308 | rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); |
| 309 | rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); |
| 310 | |
| 311 | *rnn_desc = rd; |
| 312 | |
| 313 | return success; |
| 314 | } |
| 315 | |
| 316 | status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc, |
| 317 | prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, |
| 318 | const rnn_direction_t direction, const memory_desc_t *src_layer_desc, |
| 319 | const memory_desc_t *src_iter_desc, |
| 320 | const memory_desc_t *weights_layer_desc, |
| 321 | const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, |
| 322 | const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, |
| 323 | const memory_desc_t *diff_src_layer_desc, |
| 324 | const memory_desc_t *diff_src_iter_desc, |
| 325 | const memory_desc_t *diff_weights_layer_desc, |
| 326 | const memory_desc_t *diff_weights_iter_desc, |
| 327 | const memory_desc_t *diff_bias_desc, |
| 328 | const memory_desc_t *diff_dst_layer_desc, |
| 329 | const memory_desc_t *diff_dst_iter_desc) { |
| 330 | bool args_ok = true |
| 331 | && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, |
| 332 | dst_layer_desc, diff_src_layer_desc, |
| 333 | diff_weights_layer_desc, diff_weights_iter_desc, |
| 334 | diff_dst_layer_desc); |
| 335 | if (!args_ok) |
| 336 | return invalid_arguments; |
| 337 | |
| 338 | auto xnor_md = [=](const memory_desc_t *a_md, const memory_desc_t *b_md) { |
| 339 | return is_zero_md(a_md) == is_zero_md(b_md); |
| 340 | }; |
| 341 | |
| 342 | args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc) |
| 343 | && xnor_md(dst_iter_desc, diff_dst_iter_desc) |
| 344 | && xnor_md(src_iter_desc, diff_src_iter_desc); |
| 345 | if (!args_ok) |
| 346 | return invalid_arguments; |
| 347 | |
| 348 | //check dimensions consistency |
| 349 | int L = weights_layer_desc->dims[0]; |
| 350 | int T = src_layer_desc->dims[0]; |
| 351 | int N = src_layer_desc->dims[1]; |
| 352 | const int D = one_of(direction, mkldnn_unidirectional_left2right, |
| 353 | mkldnn_unidirectional_right2left) ? |
| 354 | 1 : |
| 355 | 2; |
| 356 | int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc); |
| 357 | int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc); |
| 358 | int SLC = src_layer_desc->dims[2]; |
| 359 | int SIC = weights_iter_desc->dims[2]; |
| 360 | int DLC = dst_layer_desc->dims[2]; |
| 361 | int DIC = weights_layer_desc->dims[4]; |
| 362 | |
| 363 | status_t st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, |
| 364 | G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc, |
| 365 | weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc, |
| 366 | dst_iter_desc); |
| 367 | if (st != success) return st; |
| 368 | |
| 369 | st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, |
| 370 | G, SLC, SIC, DLC, DIC, diff_src_layer_desc, diff_src_iter_desc, |
| 371 | diff_weights_layer_desc, diff_weights_iter_desc, diff_bias_desc, |
| 372 | diff_dst_layer_desc, diff_dst_iter_desc); |
| 373 | if (st != success) return st; |
| 374 | |
| 375 | mkldnn_rnn_desc_t rd = zero_rnn_desc(); |
| 376 | |
| 377 | rd.primitive_kind = primitive_kind::rnn; |
| 378 | rd.prop_kind = prop_kind; |
| 379 | rd.cell_desc = *rnn_cell_desc; |
| 380 | rd.direction = direction; |
| 381 | |
| 382 | rd.src_layer_desc = copy_maybe_null(src_layer_desc); |
| 383 | rd.src_iter_desc = copy_maybe_null(src_iter_desc); |
| 384 | rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); |
| 385 | rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); |
| 386 | rd.bias_desc = copy_maybe_null(bias_desc); |
| 387 | rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); |
| 388 | rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); |
| 389 | rd.diff_src_layer_desc = copy_maybe_null(diff_src_layer_desc); |
| 390 | rd.diff_src_iter_desc = copy_maybe_null(diff_src_iter_desc); |
| 391 | rd.diff_weights_layer_desc = copy_maybe_null(diff_weights_layer_desc); |
| 392 | rd.diff_weights_iter_desc = copy_maybe_null(diff_weights_iter_desc); |
| 393 | rd.diff_bias_desc = copy_maybe_null(diff_bias_desc); |
| 394 | rd.diff_dst_layer_desc = copy_maybe_null(diff_dst_layer_desc); |
| 395 | rd.diff_dst_iter_desc = copy_maybe_null(diff_dst_iter_desc); |
| 396 | |
| 397 | *rnn_desc = rd; |
| 398 | |
| 399 | return success; |
| 400 | } |
| 401 | |