| 1 | /******************************************************************************* |
| 2 | * Copyright 2017-2018 Intel Corporation |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | *******************************************************************************/ |
| 16 | |
| 17 | #ifndef CPU_JIT_TRANSPOSE_SRC_HPP |
| 18 | #define CPU_JIT_TRANSPOSE_SRC_HPP |
| 19 | |
| 20 | #include "cpu_barrier.hpp" |
| 21 | #include "jit_primitive_conf.hpp" |
| 22 | |
| 23 | namespace mkldnn { |
| 24 | namespace impl { |
| 25 | namespace cpu { |
| 26 | |
| 27 | struct jit_trans_src_t { |
| 28 | struct ctx_t { |
| 29 | const void *src; |
| 30 | const void *tr_src; |
| 31 | const void *src_prf; |
| 32 | const void *tr_src_prf; |
| 33 | |
| 34 | /* 1st conv 4fma: backward by weights */ |
| 35 | int nthr_oc_b; /* number of threads process given src image */ |
| 36 | int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */ |
| 37 | simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */ |
| 38 | }; |
| 39 | |
| 40 | jit_trans_src_t(const jit_conv_conf_t *conf) |
| 41 | : conf_(conf), ker_(nullptr) {} |
| 42 | virtual ~jit_trans_src_t() {} |
| 43 | |
| 44 | void operator()(const ctx_t *ctx) |
| 45 | { assert(ker_); ker_(ctx); } |
| 46 | |
| 47 | const jit_conv_conf_t *conf_; |
| 48 | void (*ker_)(const ctx_t *); |
| 49 | }; |
| 50 | |
/* Argument block passed by pointer to the generated jit_transpose4x16_src
 * kernel (its operator() and jit_ker take a jit_src_transpose_s *).
 * NOTE(review): the generated code presumably reads these fields by offset,
 * so their order and types must not change — confirm against the
 * generator's .cpp before editing. */
struct jit_src_transpose_s {
    size_t size;            /* NOTE(review): presumably the amount of data to transpose — confirm units */
    const void *src;        /* source pointer */
    const void *tr_src;     /* transposed-source pointer */
    const void *src_prf;    /* prefetch pointer paired with src */
    const void *tr_src_prf; /* prefetch pointer paired with tr_src */
};
| 58 | |
| 59 | struct jit_trans_dst_t { |
| 60 | struct ctx_t { |
| 61 | const void *src; |
| 62 | const void *tr_src; |
| 63 | const void *src_prf; |
| 64 | const void *tr_src_prf; |
| 65 | |
| 66 | /* 1st conv 4fma: backward by weights */ |
| 67 | int nthr_oc_b; /* number of threads process given src image */ |
| 68 | int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */ |
| 69 | simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */ |
| 70 | }; |
| 71 | |
| 72 | jit_trans_dst_t(const jit_conv_conf_t *conf) |
| 73 | : conf_(conf), ker_(nullptr) {} |
| 74 | virtual ~jit_trans_dst_t() {} |
| 75 | |
| 76 | void operator()(const ctx_t *ctx) |
| 77 | { assert(ker_); ker_(ctx); } |
| 78 | |
| 79 | const jit_conv_conf_t *conf_; |
| 80 | void (*ker_)(const ctx_t *); |
| 81 | }; |
| 82 | |
/* Tuning parameters consumed by jit_transpose4x16_src, which keeps a
 * pointer to an instance of this struct (its `tparams` member).
 * The members have no initializers, so callers must set every field (or
 * value-initialize the struct) before handing it over.
 * NOTE(review): the names suggest prefetch controls — "pf0" a prefetch
 * distance and "pf1" an on/off switch for a second prefetch level —
 * confirm against the generator's .cpp. */
struct jit_transpose4x16_src_t {
    int src_pf0_distance;    /* presumably prefetch distance for src */
    int tr_src_pf0_distance; /* presumably prefetch distance for tr_src */
    bool src_pf1;            /* presumably enables second-level prefetch of src */
    bool tr_src_pf1;         /* presumably enables second-level prefetch of tr_src */
};
| 89 | |
| 90 | struct jit_transpose4x16_src : public jit_generator { |
| 91 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_transpose4x16_src) |
| 92 | |
| 93 | jit_transpose4x16_src(const jit_1x1_conv_conf_t *aparams, |
| 94 | jit_transpose4x16_src_t *tparams_) |
| 95 | : params(aparams), tparams(tparams_) |
| 96 | { |
| 97 | this->generate(); |
| 98 | jit_ker = (decltype(jit_ker))this->getCode(); |
| 99 | } |
| 100 | |
| 101 | const jit_1x1_conv_conf_t *params; |
| 102 | const jit_transpose4x16_src_t *tparams; |
| 103 | void (*jit_ker)(jit_src_transpose_s *); |
| 104 | |
| 105 | void operator()(jit_src_transpose_s *arg) { jit_ker(arg); } |
| 106 | |
| 107 | static const int transpose_size = 4; |
| 108 | private: |
| 109 | static const int typesize = sizeof(float); |
| 110 | |
| 111 | int src_stride, tr_src_stride; |
| 112 | |
| 113 | Xbyak::Reg64 imm_addr64 = rbx; |
| 114 | |
| 115 | Xbyak::Opmask kF0 = k1; |
| 116 | Xbyak::Opmask kCC = k2; |
| 117 | Xbyak::Opmask k33 = k3; |
| 118 | Xbyak::Opmask kFFFF = k4; |
| 119 | |
| 120 | Xbyak::Zmm vidx01 = zmm31; |
| 121 | Xbyak::Zmm vidx10 = zmm30; |
| 122 | Xbyak::Zmm vidx1 = zmm29; |
| 123 | Xbyak::Zmm vidxP = zmm28; |
| 124 | |
| 125 | Xbyak::Reg64 reg_src = r8; |
| 126 | Xbyak::Reg64 reg_tr_src = r9; |
| 127 | Xbyak::Reg64 reg_src_prf = r10; |
| 128 | Xbyak::Reg64 reg_tr_src_prf = r11; |
| 129 | Xbyak::Reg64 reg_loop = r12; |
| 130 | Xbyak::Reg64 reg_tr_src_tmp = r13; |
| 131 | Xbyak::Reg32 regw_tmp = r14d; |
| 132 | |
| 133 | void transpose_block(int ur, int nrows); |
| 134 | void transpose(int nrows); |
| 135 | void generate(); |
| 136 | }; |
| 137 | |
/* Factory functions returning a transposition kernel wrapper configured for
 * `conf`. The wrappers keep `conf` as a non-owning pointer, so it must
 * outlive the returned object.
 * NOTE(review): ownership of the returned raw pointer is not visible here —
 * presumably the caller is responsible for deleting it (both classes have
 * virtual destructors); confirm against the definitions. */
jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf);
jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf);
| 140 | |
| 141 | } |
| 142 | } |
| 143 | } |
| 144 | |
| 145 | #endif |
| 146 | |