"""Core named Linalg structured operations, declared via the OpDSL.

Each decorated function below defines one `linalg` named op. The body is
traced once by `@linalg_structured_op` to derive the op's indexing maps and
scalar computation; it never performs the arithmetic eagerly. `S.*` are shape
symbols, `D.*` are iteration dimensions, `U` is the accumulator/output element
type, and `T1`/`T2` are the input element type variables.
"""

from ..lang import *

T1 = TV.T1
T2 = TV.T2

Batch = S.Batch


@linalg_structured_op
def copy(I=TensorDef(T1),
         O=TensorDef(U, output=True),
         cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
    """Copies the tensor elementwise.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    O[None] = cast(U, I[None])


@linalg_structured_op
def elemwise_unary(I=TensorDef(T1),
                   O=TensorDef(U, output=True),
                   fun=UnaryFnAttrDef(default=UnaryFn.exp),
                   cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
    """Applies the unary function fun elementwise.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    O[None] = fun(cast(U, I[None]))


@linalg_structured_op
def elemwise_binary(lhs=TensorDef(T1),
                    rhs=TensorDef(T2),
                    O=TensorDef(U, output=True),
                    fun=BinaryFnAttrDef(default=BinaryFn.add),
                    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
    """Applies the binary function fun elementwise.

    Numeric casting is performed on the input operand, promoting it to the same
    data type as the accumulator/output.
    """
    O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))


@linalg_structured_op
def matmul(A=TensorDef(T1, S.M, S.K),
           B=TensorDef(T2, S.K, S.N),
           C=TensorDef(U, S.M, S.N, output=True),
           cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
    """Performs a matrix multiplication of two 2D inputs.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    domain(D.m, D.n, D.k)
    implements(ContractionOpInterface)
    C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])


@linalg_structured_op
def matmul_unsigned(A=TensorDef(T1, S.M, S.K),
                    B=TensorDef(T2, S.K, S.N),
                    C=TensorDef(U, S.M, S.N, output=True)):
    """Performs an unsigned matrix multiplication of two 2D inputs.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    domain(D.m, D.n, D.k)
    implements(ContractionOpInterface)
    C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned(
        U, B[D.k, D.n])


@linalg_structured_op
def quantized_matmul(A=TensorDef(T1, S.M, S.K),
                     B=TensorDef(T2, S.K, S.N),
                     AZp=ScalarDef(I32),
                     BZp=ScalarDef(I32),
                     C=TensorDef(U, S.M, S.N, output=True)):
    """Performs a matrix multiplication of two 2D inputs.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output. The
    quantized variant includes zero-point adjustments for the left and right
    operands of the matmul.
    """
    domain(D.m, D.n, D.k)
    C[D.m, D.n] += (TypeFn.cast_signed(U, A[D.m, D.k]) -
                    TypeFn.cast_signed(U, AZp)) * (
                        TypeFn.cast_signed(U, B[D.k, D.n]) -
                        TypeFn.cast_signed(U, BZp))


@linalg_structured_op
def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
          rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
          accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True)):
    """Performs a matrix-matrix-transpose multiplication of two 4D inputs.

    Differences from linalg.matmul:
    * The right hand side is transposed, whence the 't' in 'mmt'.
    * The input and output tensors have a 4D shape instead of a 2D shape. They
      are interpreted as 2D matrices with one level of 2D tile subdivision,
      whence the 2+2=4 dimensions. The inner tile dimensions are identified
      with '0' suffixes below, for instance the LHS matrix shape (M, K, M0,
      K0) reads as: MxK tiles, each of shape M0xK0.
    """
    domain(D.m, D.n, D.k, D.m0, D.n0, D.k0)
    implements(ContractionOpInterface)
    accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed(
        TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast_signed(
            TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])


@linalg_structured_op
def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
                 B=TensorDef(T2, Batch, S.K, S.N),
                 C=TensorDef(U, Batch, S.M, S.N, output=True)):
    """Performs a batched matrix multiplication of two 3D inputs.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    domain(D.b, D.m, D.n, D.k)
    implements(ContractionOpInterface)
    C[D.b, D.m, D.n] += TypeFn.cast_signed(
        U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(U, B[D.b, D.k, D.n])


@linalg_structured_op
def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
                           B=TensorDef(T2, Batch, S.K, S.N),
                           AZp=ScalarDef(I32),
                           BZp=ScalarDef(I32),
                           C=TensorDef(U, Batch, S.M, S.N, output=True)):
    """Performs a batched matrix multiplication of two 3D inputs.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output. The
    quantized variant includes zero-point adjustments for the left and right
    operands of the matmul.
    """
    domain(D.b, D.m, D.n, D.k)
    C[D.b, D.m, D.n] += (TypeFn.cast_signed(U, A[D.b, D.m, D.k]) -
                         TypeFn.cast_signed(U, AZp)) * (
                             TypeFn.cast_signed(U, B[D.b, D.k, D.n]) -
                             TypeFn.cast_signed(U, BZp))


@linalg_structured_op
def matvec(A=TensorDef(T1, S.M, S.N),
           y=TensorDef(T2, S.N),
           x=TensorDef(U, S.M, output=True)):
    """Performs a matrix-vector multiplication.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    domain(D.m, D.n)
    implements(ContractionOpInterface)
    x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n])


@linalg_structured_op
def vecmat(y=TensorDef(T1, S.M),
           A=TensorDef(T2, S.M, S.N),
           x=TensorDef(U, S.N, output=True)):
    """Performs a vector-matrix multiplication.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    domain(D.n, D.m)
    implements(ContractionOpInterface)
    x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n])


@linalg_structured_op
def batch_matvec(A=TensorDef(T1, Batch, S.M, S.K),
                 B=TensorDef(T2, Batch, S.K),
                 C=TensorDef(U, Batch, S.M, output=True)):
    """Performs a batched matrix-vector multiplication.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    domain(D.b, D.m, D.k)
    implements(ContractionOpInterface)
    C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
        U, B[D.b, D.k])


@linalg_structured_op
def dot(A=TensorDef(T1, S.M),
        B=TensorDef(T2, S.M),
        C=TensorDef(U, output=True)):
    """Performs a dot product of two vectors to a scalar result.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    implements(ContractionOpInterface)
    C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])


@linalg_structured_op
def conv_1d(I=TensorDef(T1, S.OW + S.KW),
            K=TensorDef(T2, S.KW),
            O=TensorDef(U, S.OW, output=True)):
    """Performs 1-D convolution with no channels.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.ow, D.kw)
    O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed(
        U, K[D.kw])


@linalg_structured_op
def conv_2d(I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW),
            K=TensorDef(T2, S.KH, S.KW),
            O=TensorDef(U, S.OH, S.OW, output=True)):
    """Performs 2-D convolution with no channels.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.oh, D.ow, D.kh, D.kw)
    O[D.oh, D.ow] += TypeFn.cast_signed(
        U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kh, D.kw])


@linalg_structured_op
def conv_3d(I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW),
            K=TensorDef(T2, S.KD, S.KH, S.KW),
            O=TensorDef(U, S.OD, S.OH, S.OW, output=True)):
    """Performs 3-D convolution with no channels.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw)
    O[D.od, D.oh, D.ow] += TypeFn.cast_signed(
        U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(
            U, K[D.kd, D.kh, D.kw])


@linalg_structured_op
def conv_1d_nwc_wcf(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
                    K=TensorDef(T2, S.KW, S.C, S.F),
                    O=TensorDef(U, S.N, S.OW, S.F, output=True),
                    strides=IndexAttrDef(S.SW, default=[1]),
                    dilations=IndexAttrDef(S.DW, default=[1])):
    """Performs 1-D convolution.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.ow, D.f, D.kw, D.c)
    O[D.n, D.ow, D.f] += TypeFn.cast_signed(
        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed(
            U, K[D.kw, D.c, D.f])


@linalg_structured_op
def conv_2d_nhwc_hwcf(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
                                  S.OW * S.SW + S.KW * S.DW, S.C),
                      K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
                      O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
                      strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
                      dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
    """Performs 2-D convolution.

    Layout:
      * Input: NHWC.
      * Kernel: HWCF.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
    O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
             D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f])


@linalg_structured_op
def conv_2d_nhwc_fhwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
                                  S.OW * S.SW + S.KW * S.DW, S.C),
                      K=TensorDef(T2, S.F, S.KH, S.KW, S.C),
                      O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
                      strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
                      dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
    """Performs 2-D convolution.

    Layout:
      * Input: NHWC.
      * Kernel: FHWC.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
    O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
             D.c]) * TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c])


@linalg_structured_op
def conv_2d_nhwc_hwcf_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
                                    S.OW * S.SW + S.KW * S.DW, S.C),
                        K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
                        IZp=ScalarDef(I32),
                        KZp=ScalarDef(I32),
                        O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
                        strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
                        dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
    """Performs 2-D convolution with zero point offsets.

    Layout:
      * Input: NHWC.
      * Kernel: HWCF.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output. This
    includes the zero point offsets common to quantized operations.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
    O[D.n, D.oh, D.ow, D.f] += (TypeFn.cast_signed(
        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) -
                                TypeFn.cast_signed(U, IZp)) * (
                                    TypeFn.cast_signed(
                                        U, K[D.kh, D.kw, D.c, D.f]) -
                                    TypeFn.cast_signed(U, KZp))


@linalg_structured_op
def conv_2d_nchw_fchw(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
                                  S.OW * S.SW + S.KW * S.DW),
                      K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
                      O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
                      strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
                      dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
    """Performs 2-D convolution.

    Layout:
      * Input: NCHW.
      * Kernel: FCHW.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed(
        U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
             D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(
                 U, K[D.f, D.c, D.kh, D.kw])


@linalg_structured_op
def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C,
                                    S.OH * S.SH + S.KH * S.DH,
                                    S.OW * S.SW + S.KW * S.DW),
                        K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
                        O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW,
                                    output=True),
                        strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
                        dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
    """Performs 2-D grouped convolution.

    Layout:
      * Input: NGCHW.
      * Kernel: FGCHW.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
        U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH,
             D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(
                 U, K[D.g, D.fg, D.c, D.kh, D.kw])


@linalg_structured_op
def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
                                    S.OH * S.SH + S.KH * S.DH,
                                    S.OW * S.SW + S.KW * S.DW, S.C),
                        K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
                        O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F,
                                    output=True),
                        strides=IndexAttrDef(S.SD, S.SH, S.SW,
                                             default=[1, 1, 1]),
                        dilations=IndexAttrDef(S.DD, S.DH, S.DW,
                                               default=[1, 1, 1])):
    """Performs 3-D convolution.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
    O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed(
        U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
             D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed(
                 U, K[D.kd, D.kh, D.kw, D.c, D.f])


@linalg_structured_op
def depthwise_conv_1d_nwc_wc(I=TensorDef(T1, S.N,
                                         S.OW * S.SW + S.KW * S.DW, S.IC),
                             K=TensorDef(T2, S.KW, S.IC),
                             O=TensorDef(U, S.N, S.OW, S.IC, output=True),
                             strides=IndexAttrDef(S.SW, default=[1]),
                             dilations=IndexAttrDef(S.DW, default=[1])):
    """Performs depth-wise 1-D convolution.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output. Multiplier
    is set to 1 which is a special case for most depthwise convolutions.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.ow, D.ic, D.kw)
    O[D.n, D.ow, D.ic] += \
        TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \
        TypeFn.cast_signed(U, K[D.kw, D.ic])


@linalg_structured_op
def depthwise_conv_1d_nwc_wcm(I=TensorDef(T1, S.N,
                                          S.OW * S.SW + S.KW * S.DW, S.IC),
                              K=TensorDef(T2, S.KW, S.IC, S.CM),
                              O=TensorDef(U, S.N, S.OW, S.IC, S.CM,
                                          output=True),
                              strides=IndexAttrDef(S.SW, default=[1]),
                              dilations=IndexAttrDef(S.DW, default=[1])):
    """Performs depth-wise 1-D convolution.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.ow, D.ic, D.cm, D.kw)
    O[D.n, D.ow, D.ic, D.cm] += \
        TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \
        TypeFn.cast_signed(U, K[D.kw, D.ic, D.cm])


@linalg_structured_op
def depthwise_conv_2d_nhwc_hwc(I=TensorDef(T1, S.N,
                                           S.OH * S.SH + S.KH * S.DH,
                                           S.OW * S.SW + S.KW * S.DW, S.IC),
                               K=TensorDef(T2, S.KH, S.KW, S.IC),
                               O=TensorDef(U, S.N, S.OH, S.OW, S.IC,
                                           output=True),
                               strides=IndexAttrDef(S.SH, S.SW,
                                                    default=[1, 1]),
                               dilations=IndexAttrDef(S.DH, S.DW,
                                                      default=[1, 1])):
    """Performs depth-wise 2-D convolution.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output. Multiplier
    is set to 1 which is a special case for most depthwise convolutions.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
    O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
             D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic])


@linalg_structured_op
def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC,
                                           S.OH * S.SH + S.KH * S.DH,
                                           S.OW * S.SW + S.KW * S.DW),
                               K=TensorDef(T2, S.IC, S.KH, S.KW),
                               O=TensorDef(U, S.N, S.IC, S.OH, S.OW,
                                           output=True),
                               strides=IndexAttrDef(S.SH, S.SW,
                                                    default=[1, 1]),
                               dilations=IndexAttrDef(S.DH, S.DW,
                                                      default=[1, 1])):
    """Performs depth-wise 2-D convolution.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output. Multiplier
    is set to 1 which is a special case for most depthwise convolutions.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
    O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed(
        U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH,
             D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(
                 U, K[D.ic, D.kh, D.kw])


@linalg_structured_op
def depthwise_conv_2d_nhwc_hwc_q(I=TensorDef(T1, S.N,
                                             S.OH * S.SH + S.KH * S.DH,
                                             S.OW * S.SW + S.KW * S.DW,
                                             S.IC),
                                 K=TensorDef(T2, S.KH, S.KW, S.IC),
                                 IZp=ScalarDef(I32),
                                 KZp=ScalarDef(I32),
                                 O=TensorDef(U, S.N, S.OH, S.OW, S.IC,
                                             output=True),
                                 strides=IndexAttrDef(S.SH, S.SW,
                                                      default=[1, 1]),
                                 dilations=IndexAttrDef(S.DH, S.DW,
                                                        default=[1, 1])):
    """Performs depth-wise 2-D convolution.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
    O[D.n, D.oh, D.ow, D.ic] += ((TypeFn.cast_signed(
        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
             D.ic]) - TypeFn.cast_signed(U, IZp)) *
                                 (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) -
                                  TypeFn.cast_signed(U, KZp)))


@linalg_structured_op
def depthwise_conv_2d_nhwc_hwcm(I=TensorDef(T1, S.N,
                                            S.OH * S.SH + S.KH * S.DH,
                                            S.OW * S.SW + S.KW * S.DW, S.IC),
                                K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
                                O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM,
                                            output=True),
                                strides=IndexAttrDef(S.SH, S.SW,
                                                     default=[1, 1]),
                                dilations=IndexAttrDef(S.DH, S.DW,
                                                       default=[1, 1])):
    """Performs depth-wise 2-D convolution.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
    O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
             D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm])


@linalg_structured_op
def depthwise_conv_2d_nhwc_hwcm_q(I=TensorDef(T1, S.N,
                                              S.OH * S.SH + S.KH * S.DH,
                                              S.OW * S.SW + S.KW * S.DW,
                                              S.IC),
                                  K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
                                  IZp=ScalarDef(I32),
                                  KZp=ScalarDef(I32),
                                  O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM,
                                              output=True),
                                  strides=IndexAttrDef(S.SH, S.SW,
                                                       default=[1, 1]),
                                  dilations=IndexAttrDef(S.DH, S.DW,
                                                         default=[1, 1])):
    """Performs depth-wise 2-D convolution.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
    O[D.n, D.oh, D.ow, D.ic, D.cm] += ((TypeFn.cast_signed(
        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
             D.ic]) - TypeFn.cast_signed(U, IZp)) *
                                       (TypeFn.cast_signed(
                                           U, K[D.kh, D.kw, D.ic, D.cm]) -
                                        TypeFn.cast_signed(U, KZp)))


@linalg_structured_op
def depthwise_conv_3d_ndhwc_dhwc(I=TensorDef(T1, S.N,
                                             S.OD * S.SD + S.KD * S.DD,
                                             S.OH * S.SH + S.KH * S.DH,
                                             S.OW * S.SW + S.KW * S.DW,
                                             S.IC),
                                 K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC),
                                 # FIX: the output previously omitted S.IC
                                 # even though the computation below writes
                                 # O[..., D.ic]; the declared rank (4) did not
                                 # match the 5 output indices.
                                 O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.IC,
                                             output=True),
                                 strides=IndexAttrDef(S.SD, S.SH, S.SW,
                                                      default=[1, 1, 1]),
                                 dilations=IndexAttrDef(S.DD, S.DH, S.DW,
                                                        default=[1, 1, 1])):
    """Performs depth-wise 3-D convolution.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output. Multiplier
    is set to 1 which is a special case for most depthwise convolutions.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
    O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
        U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
             D.ow * S.SW + D.kw * S.DW, D.ic]) * TypeFn.cast_signed(
                 U, K[D.kd, D.kh, D.kw, D.ic])


@linalg_structured_op
def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1, S.N,
                                              S.OD * S.SD + S.KD * S.DD,
                                              S.OH * S.SH + S.KH * S.DH,
                                              S.OW * S.SW + S.KW * S.DW,
                                              S.IC),
                                  K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC,
                                              S.CM),
                                  # FIX: the output previously declared only
                                  # (N, OD, OH, OW, CM) but the computation
                                  # writes O[..., D.ic, D.cm] (6 indices);
                                  # S.IC was missing before S.CM.
                                  O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.IC,
                                              S.CM, output=True),
                                  strides=IndexAttrDef(S.SD, S.SH, S.SW,
                                                       default=[1, 1, 1]),
                                  dilations=IndexAttrDef(S.DD, S.DH, S.DW,
                                                         default=[1, 1, 1])):
    """Performs depth-wise 3-D convolution.

    Numeric casting is performed on the operands to the inner multiply,
    promoting them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic)
    O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
        U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
             D.ow * S.SW + D.kw * S.DW, D.ic]) * TypeFn.cast_signed(
                 U, K[D.kd, D.kh, D.kw, D.ic, D.cm])


@linalg_structured_op
def pooling_nhwc_sum(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
                                 S.OW * S.SW + S.KW * S.DW, S.C),
                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
                     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
    """Performs sum pooling.

    Layout:
      * Input: NHWC.
      * Kernel: HW.

    Numeric casting is performed on the input operand, promoting it to the
    same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])


@linalg_structured_op
def pooling_nchw_sum(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
                                 S.OW * S.SW + S.KW * S.DW),
                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
                     O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
    """Performs sum pooling.

    Layout:
      * Input: NCHW.
      * Kernel: HW.

    Numeric casting is performed on the input operand, promoting it to the
    same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
    O[D.n, D.c, D.oh, D.ow] += TypeFn.cast_signed(
        U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW])


@linalg_structured_op
def pooling_nhwc_max(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
                                 S.OW * S.SW + S.KW * S.DW, S.C),
                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
                     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
    """Performs max pooling.

    Numeric casting is performed on the input operand, promoting it to the
    same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw](
        TypeFn.cast_signed(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                 D.c]))


@linalg_structured_op
def pooling_nhwc_max_unsigned(I=TensorDef(T1, S.N,
                                          S.OH * S.SH + S.KH * S.DH,
                                          S.OW * S.SW + S.KW * S.DW, S.C),
                              K=TensorDef(T2, S.KH, S.KW,
                                          index_dims=[D.kh, D.kw]),
                              O=TensorDef(U, S.N, S.OH, S.OW, S.C,
                                          output=True),
                              strides=IndexAttrDef(S.SH, S.SW,
                                                   default=[1, 1]),
                              dilations=IndexAttrDef(S.DH, S.DW,
                                                     default=[1, 1])):
    """Performs unsigned max pooling.

    Numeric casting is performed on the input operand, promoting it to the
    same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
        TypeFn.cast_unsigned(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                 D.c]))


@linalg_structured_op
def pooling_nchw_max(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
                                 S.OW * S.SW + S.KW * S.DW),
                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
                     O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
    """Performs max pooling.

    Numeric casting is performed on the input operand, promoting it to the
    same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
    # Note: a stray trailing comma inside the subscript was removed (it made
    # no semantic difference — a subscript tuple with a trailing comma is the
    # same tuple — but was inconsistent with every sibling op).
    O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw](
        TypeFn.cast_signed(
            U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
                 D.ow * S.SW + D.kw * S.DW]))


@linalg_structured_op
def pooling_nhwc_min(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
                                 S.OW * S.SW + S.KW * S.DW, S.C),
                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
                     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
    """Performs min pooling.

    Numeric casting is performed on the input operand, promoting it to the
    same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw](
        TypeFn.cast_signed(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                 D.c]))


@linalg_structured_op
def pooling_nhwc_min_unsigned(I=TensorDef(T1, S.N,
                                          S.OH * S.SH + S.KH * S.DH,
                                          S.OW * S.SW + S.KW * S.DW, S.C),
                              K=TensorDef(T2, S.KH, S.KW,
                                          index_dims=[D.kh, D.kw]),
                              O=TensorDef(U, S.N, S.OH, S.OW, S.C,
                                          output=True),
                              strides=IndexAttrDef(S.SH, S.SW,
                                                   default=[1, 1]),
                              dilations=IndexAttrDef(S.DH, S.DW,
                                                     default=[1, 1])):
    """Performs unsigned min pooling.

    Numeric casting is performed on the input operand, promoting it to the
    same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
        TypeFn.cast_unsigned(
            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                 D.c]))


@linalg_structured_op
def pooling_ndhwc_sum(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
                                  S.OH * S.SH + S.KH * S.DH,
                                  S.OW * S.SW + S.KW * S.DW, S.C),
                      K=TensorDef(T2, S.KD, S.KH, S.KW,
                                  index_dims=[D.kd, D.kh, D.kw]),
                      O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C,
                                  output=True),
                      strides=IndexAttrDef(S.SD, S.SH, S.SW,
                                           default=[1, 1, 1]),
                      dilations=IndexAttrDef(S.DD, S.DH, S.DW,
                                             default=[1, 1, 1])):
    """Performs 3D sum pooling.

    Numeric casting is performed on the input operand, promoting it to the
    same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
    O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed(
        U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
             D.ow * S.SW + D.kw * S.DW, D.c])


@linalg_structured_op
def pooling_ndhwc_max(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
                                  S.OH * S.SH + S.KH * S.DH,
                                  S.OW * S.SW + S.KW * S.DW, S.C),
                      K=TensorDef(T2, S.KD, S.KH, S.KW,
                                  index_dims=[D.kd, D.kh, D.kw]),
                      O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C,
                                  output=True),
                      strides=IndexAttrDef(S.SD, S.SH, S.SW,
                                           default=[1, 1, 1]),
                      dilations=IndexAttrDef(S.DD, S.DH, S.DW,
                                             default=[1, 1, 1])):
    """Performs 3D max pooling.

    Numeric casting is performed on the input operand, promoting it to the
    same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
    O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](
        TypeFn.cast_signed(
            U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
                 D.ow * S.SW + D.kw * S.DW, D.c]))


@linalg_structured_op
def pooling_ndhwc_min(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
                                  S.OH * S.SH + S.KH * S.DH,
                                  S.OW * S.SW + S.KW * S.DW, S.C),
                      K=TensorDef(T2, S.KD, S.KH, S.KW,
                                  index_dims=[D.kd, D.kh, D.kw]),
                      O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C,
                                  output=True),
                      strides=IndexAttrDef(S.SD, S.SH, S.SW,
                                           default=[1, 1, 1]),
                      dilations=IndexAttrDef(S.DD, S.DH, S.DW,
                                             default=[1, 1, 1])):
    """Performs 3D min pooling.

    Numeric casting is performed on the input operand, promoting it to the
    same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
    O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](
        TypeFn.cast_signed(
            U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
                 D.ow * S.SW + D.kw * S.DW, D.c]))


@linalg_structured_op
def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)):
    """Fills the output tensor with the given value.

    Works for arbitrary ranked output tensors since the operation performs
    scalar accesses only and is thus rank polymorphic. Numeric casting is
    performed on the value operand, promoting it to the same data type as the
    output.
    """
    implements(FillOpInterface)
    defines(Canonicalizer)
    O[None] = TypeFn.cast_signed(U, value)


@linalg_structured_op
def fill_rng_2d(min=ScalarDef(F64),
                max=ScalarDef(F64),
                seed=ScalarDef(I32),
                O=TensorDef(T, S.M, S.N, output=True)):
    """Fills the output tensor with pseudo random numbers.

    The operation generations pseudo random numbers using a linear
    congruential generator. It provides no guarantees regarding the
    distribution of the generated random numbers. Instead of generating the
    random numbers sequentially, it instantiates one random number generator
    per data element and runs them in parallel. The seed operand and the
    indices of the data element seed the random number generation. The min and
    max operands limit the range of the generated random numbers.
    """
    domain(D.m, D.n)
    # Classic LCG constants (glibc rand); two chained steps mix the (m, n)
    # element index with the seed so each element gets its own stream.
    multiplier = TypeFn.cast_signed(I32, const(1103515245))
    increment = TypeFn.cast_signed(I32, const(12345))
    rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment
    rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment
    # Map the i32 state into [min, max): 2.3283064e-10 ~= 2^-32, and the
    # offset recenters the signed range before scaling.
    inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10))
    offset = TypeFn.cast_signed(F64, const(2147483647))
    scaling = (max - min) * inv_range
    O[D.m, D.n] = TypeFn.cast_signed(
        T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min)