| 1 | /******************************************************************************* |
| 2 | * Copyright 2017-2018 Intel Corporation |
| 3 | * Copyright 2018 YANDEX LLC |
| 4 | * |
| 5 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | * you may not use this file except in compliance with the License. |
| 7 | * You may obtain a copy of the License at |
| 8 | * |
| 9 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | * |
| 11 | * Unless required by applicable law or agreed to in writing, software |
| 12 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | * See the License for the specific language governing permissions and |
| 15 | * limitations under the License. |
| 16 | *******************************************************************************/ |
| 17 | |
| 18 | #ifndef JIT_UNI_POOL_KERNEL_F32_HPP |
| 19 | #define JIT_UNI_POOL_KERNEL_F32_HPP |
| 20 | |
| 21 | #include <cfloat> |
| 22 | |
| 23 | #include "c_types_map.hpp" |
| 24 | #include "pooling_pd.hpp" |
| 25 | #include "type_helpers.hpp" |
| 26 | |
| 27 | #include "jit_generator.hpp" |
| 28 | #include "jit_primitive_conf.hpp" |
| 29 | |
| 30 | namespace mkldnn { |
| 31 | namespace impl { |
| 32 | namespace cpu { |
| 33 | |
| 34 | using namespace Xbyak; |
| 35 | |
| 36 | template <cpu_isa_t isa> |
| 37 | struct jit_uni_pool_kernel_f32: public jit_generator { |
| 38 | jit_uni_pool_kernel_f32(jit_pool_conf_t ajpp): jpp(ajpp) |
| 39 | { |
| 40 | this->generate(); |
| 41 | jit_ker = (decltype(jit_ker))this->getCode(); |
| 42 | } |
| 43 | |
| 44 | jit_pool_conf_t jpp; |
| 45 | |
| 46 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel_f32) |
| 47 | |
| 48 | void operator()(jit_pool_call_s *arg) { jit_ker(arg); } |
| 49 | static status_t init_conf(jit_pool_conf_t &jbp, const pooling_pd_t *ppd); |
| 50 | |
| 51 | private: |
| 52 | using Vmm = typename utils::conditional3<isa == sse42, Xmm, isa == avx, |
| 53 | Ymm, Zmm>::type; |
| 54 | Xmm xreg(int idx) { return Xmm((isa == avx512_common ? 31 : 15) - idx); } |
| 55 | Ymm yreg(int idx) { return Ymm(xreg(idx).getIdx()); } |
| 56 | Vmm vreg(int idx) { return Vmm(xreg(idx).getIdx()); } |
| 57 | |
| 58 | const AddressFrame &vmmword = (isa == sse42) ? xword : |
| 59 | (isa == avx) ? yword : zword; |
| 60 | |
| 61 | Xmm vmm_mask = Xmm(0); |
| 62 | Xmm xmm_ker_area_h = Xmm(2); |
| 63 | Xmm xmm_one = Xmm(2); |
| 64 | Xmm xmm_tmp = Xmm(3); |
| 65 | |
| 66 | Vmm vmm_ker_area_h = Vmm(2); |
| 67 | Vmm vmm_one = Vmm(2); |
| 68 | Vmm vmm_tmp = Vmm(3); |
| 69 | |
| 70 | Vmm vmm_k_offset = Vmm(1); |
| 71 | |
| 72 | Opmask k_index_mask = Opmask(6); |
| 73 | Opmask k_store_mask = Opmask(7); |
| 74 | |
| 75 | // Here be some (tame) dragons. This kernel does not follow the regular |
| 76 | // OS-agnostic ABI pattern because when isa is sse42 it uses maskmovdqu |
| 77 | // instruction which has its destination hardcoded in rdi. Therefore: |
| 78 | // - all registers are hardcoded |
| 79 | // - on Windows rdi and rcx are swapped to mimic the Unix x86_64 ABI |
| 80 | // |
| 81 | // While this is only required by the backward pass, the quirk above |
| 82 | // is applied to the forward pass as well to keep things simpler. |
| 83 | |
| 84 | using reg64_t = const Xbyak::Reg64; |
| 85 | reg64_t reg_param = rdi; // Always mimic the Unix ABI |
| 86 | reg64_t reg_input = r8; |
| 87 | reg64_t aux_reg_input = r9; |
| 88 | reg64_t reg_index = r10; |
| 89 | reg64_t reg_output = r12; |
| 90 | reg64_t reg_kd_pad_shift = r13; |
| 91 | reg64_t dst_ptr = rdi; // Must be rdi due to maskmovdqu |
| 92 | |
| 93 | reg64_t kj = r14; |
| 94 | reg64_t oi_iter = r15; |
| 95 | reg64_t reg_kh = rax; |
| 96 | reg64_t reg_k_shift = rbx; |
| 97 | reg64_t tmp_gpr = rcx; // Must be rcx because rdi is used above |
| 98 | reg64_t reg_ker_area_h = rdx; |
| 99 | |
| 100 | reg64_t zero_size = r15; |
| 101 | reg64_t ki = r12; |
| 102 | reg64_t aux_reg_input_d = r8; |
| 103 | |
| 104 | Xbyak::Reg32 reg_shuf_mask = esi; |
| 105 | |
| 106 | int prev_kw; |
| 107 | void (*jit_ker)(jit_pool_call_s *); |
| 108 | |
| 109 | void maybe_recalculate_divisor(int jj, int ur_w, int pad_l, int pad_r); |
| 110 | void avg_step(int ur_w, int pad_l, int pad_r); |
| 111 | void max_step_fwd(int ur_w, int pad_l, int pad_r); |
| 112 | void max_step_bwd(int ur_w, int pad_l, int pad_r); |
| 113 | |
| 114 | void maybe_zero_diff_src(); |
| 115 | |
| 116 | void step(int ur_w, int pad_l, int pad_r) { |
| 117 | if (jpp.alg == alg_kind::pooling_max) { |
| 118 | if(jpp.is_backward) |
| 119 | max_step_bwd(ur_w, pad_l, pad_r); |
| 120 | else |
| 121 | max_step_fwd(ur_w, pad_l, pad_r); |
| 122 | } |
| 123 | else |
| 124 | avg_step(ur_w, pad_l, pad_r); |
| 125 | } |
| 126 | |
| 127 | void step_high_half(int ur_w, int pad_l, int pad_r) { |
| 128 | add(reg_input, sizeof(float) * 4); |
| 129 | add(reg_output, sizeof(float) * 4); |
| 130 | if (jpp.alg == alg_kind::pooling_max && |
| 131 | (jpp.is_training || jpp.is_backward)) |
| 132 | add(reg_index, types::data_type_size(jpp.ind_dt) * 4); |
| 133 | |
| 134 | step(ur_w, pad_l, pad_r); |
| 135 | } |
| 136 | |
| 137 | void generate(); |
| 138 | |
| 139 | void avx_vpadd1(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) { |
| 140 | assert(y0.getIdx() != x1.getIdx()); |
| 141 | vextractf128(xtmp, y0, 0); |
| 142 | vpaddd(xtmp, xtmp, x1); |
| 143 | vinsertf128(y0, y0, xtmp, 0); |
| 144 | vextractf128(xtmp, y0, 1); |
| 145 | vpaddd(xtmp, xtmp, x1); |
| 146 | vinsertf128(y0, y0, xtmp, 1); |
| 147 | } |
| 148 | |
| 149 | void avx_vpadd1(const Xmm& x0, const Xmm& x1, const Xmm&) { |
| 150 | assert(false /*function should not be used*/); |
| 151 | paddd(x0, x1); |
| 152 | } |
| 153 | |
| 154 | void avx_pmovzxbd(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) { |
| 155 | Xmm x0(y0.getIdx()); |
| 156 | pshufd(xmm_tmp, x1, 1); |
| 157 | pmovzxbd(x0, x1); |
| 158 | pmovzxbd(xmm_tmp, xmm_tmp); |
| 159 | vinsertf128(y0, y0, xmm_tmp, 1); |
| 160 | } |
| 161 | |
| 162 | void avx_pmovzxbd(const Xmm& x0, const Xmm& x1, const Xmm&) { |
| 163 | assert(false /*function should not be used*/); |
| 164 | pmovzxbd(x0, x1); |
| 165 | } |
| 166 | |
| 167 | void avx_pcmpeqd(const Ymm& y0, const Ymm& y1, const Ymm& y2, const Xmm& xtmp) { |
| 168 | assert(y0.getIdx() != y1.getIdx()); |
| 169 | assert(y0.getIdx() != y2.getIdx()); |
| 170 | Xmm x0(y0.getIdx()); |
| 171 | Xmm x2(y2.getIdx()); |
| 172 | vextractf128(x0, y1, 1); |
| 173 | vextractf128(xtmp, y2, 1); |
| 174 | pcmpeqd(xtmp, x0); |
| 175 | vextractf128(x0, y1, 0); |
| 176 | pcmpeqd(x0, x2); |
| 177 | vinsertf128(y0, y0, xtmp, 1); |
| 178 | } |
| 179 | |
| 180 | void avx_pcmpeqd(const Xmm& x0, const Xmm& x1, const Xmm&, const Xmm&) { |
| 181 | assert(false /*function should not be used*/); |
| 182 | pcmpeqd(x0, x1); |
| 183 | } |
| 184 | }; |
| 185 | |
| 186 | } |
| 187 | } |
| 188 | } |
| 189 | |
| 190 | #endif |
| 191 | |
| 192 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
| 193 | |