1from ..lang import *
2
# Short aliases for the two type variables that are used as the element types
# of the input operands in the op definitions below.
T1 = TV.T1
T2 = TV.T2

# Shape symbol used as the leading dimension of the batch_* ops below.
Batch = S.Batch
7
8
@linalg_structured_op
def copy(
    I=TensorDef(T1),
    O=TensorDef(U, output=True),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
  """Copies the tensor elementwise.

  Every input element is converted by the cast type function to the data type
  of the accumulator/output before it is stored. The cast is signed by
  default.
  """
  source_elem = I[None]
  O[None] = cast(U, source_elem)
19
20
@linalg_structured_op
def elemwise_unary(
    I=TensorDef(T1),
    O=TensorDef(U, output=True),
    fun=UnaryFnAttrDef(default=UnaryFn.exp),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
  """Applies the unary function fun elementwise.

  The input element is first converted by the cast type function to the data
  type of the accumulator/output; fun is then applied to the converted value.
  By default fun is exp and the cast is signed.
  """
  converted = cast(U, I[None])
  O[None] = fun(converted)
32
33
@linalg_structured_op
def elemwise_binary(
    lhs=TensorDef(T1),
    rhs=TensorDef(T2),
    O=TensorDef(U, output=True),
    fun=BinaryFnAttrDef(default=BinaryFn.add),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
  """Applies the binary function fun elementwise.

  Both input elements are first converted by the cast type function to the
  data type of the accumulator/output; fun is then applied to the converted
  pair. By default fun is add and the cast is signed.
  """
  lhs_elem = cast(U, lhs[None])
  rhs_elem = cast(U, rhs[None])
  O[None] = fun(lhs_elem, rhs_elem)
46
47
@linalg_structured_op
def matmul(
    A=TensorDef(T1, S.M, S.K),
    B=TensorDef(T2, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
  """Performs a matrix multiplication of two 2D inputs.

  The operands of the inner multiply are converted by the cast type function
  (signed by default) to the data type of the accumulator/output before they
  are multiplied.
  """
  domain(D.m, D.n, D.k)
  implements(ContractionOpInterface)
  lhs_elem = cast(U, A[D.m, D.k])
  rhs_elem = cast(U, B[D.k, D.n])
  C[D.m, D.n] += lhs_elem * rhs_elem
61
62
@linalg_structured_op
def matmul_unsigned(
    A=TensorDef(T1, S.M, S.K),
    B=TensorDef(T2, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True),
):
  """Performs an unsigned matrix multiplication of two 2D inputs.

  The operands of the inner multiply are converted with an unsigned cast to
  the data type of the accumulator/output before they are multiplied.
  """
  domain(D.m, D.n, D.k)
  implements(ContractionOpInterface)
  lhs_elem = TypeFn.cast_unsigned(U, A[D.m, D.k])
  rhs_elem = TypeFn.cast_unsigned(U, B[D.k, D.n])
  C[D.m, D.n] += lhs_elem * rhs_elem
76
77
@linalg_structured_op
def quantized_matmul(
    A=TensorDef(T1, S.M, S.K),
    B=TensorDef(T2, S.K, S.N),
    AZp=ScalarDef(I32),
    BZp=ScalarDef(I32),
    C=TensorDef(U, S.M, S.N, output=True),
):
  """Performs a matrix multiplication of two 2D inputs.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output. The quantized variant additionally
  subtracts the zero points AZp and BZp, cast the same way, from the left and
  right operands of the matmul before they are multiplied.
  """
  domain(D.m, D.n, D.k)
  lhs_elem = TypeFn.cast_signed(U, A[D.m, D.k])
  lhs_zero_point = TypeFn.cast_signed(U, AZp)
  rhs_elem = TypeFn.cast_signed(U, B[D.k, D.n])
  rhs_zero_point = TypeFn.cast_signed(U, BZp)
  C[D.m, D.n] += (lhs_elem - lhs_zero_point) * (rhs_elem - rhs_zero_point)
96
97
@linalg_structured_op
def mmt4d(
    lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
    rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
    accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True),
):
  """Performs a matrix-matrix-transpose multiplication of two 4D inputs.

  Differences from linalg.matmul:
  * The right hand side is transposed, whence the 't' in 'mmt'.
  * The input and output tensors have a 4D shape instead of a 2D shape. They
    are interpreted as 2D matrices with one level of 2D tile subdivision,
    whence the 2+2=4 dimensions. The inner tile dimensions are identified with
    '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
    as: MxK tiles, each of shape M0xK0.

  Both operands of the inner multiply are converted with a signed cast to the
  accumulator data type before they are multiplied.
  """
  domain(D.m, D.n, D.k, D.m0, D.n0, D.k0)
  implements(ContractionOpInterface)
  lhs_elem = TypeFn.cast_signed(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0])
  rhs_elem = TypeFn.cast_signed(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
  accum[D.m, D.n, D.m0, D.n0] += lhs_elem * rhs_elem
117
118
@linalg_structured_op
def batch_matmul(
    A=TensorDef(T1, Batch, S.M, S.K),
    B=TensorDef(T2, Batch, S.K, S.N),
    C=TensorDef(U, Batch, S.M, S.N, output=True),
):
  """Performs a batched matrix multiplication of two 3D inputs.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied.
  """
  domain(D.b, D.m, D.n, D.k)
  implements(ContractionOpInterface)
  lhs_elem = TypeFn.cast_signed(U, A[D.b, D.m, D.k])
  rhs_elem = TypeFn.cast_signed(U, B[D.b, D.k, D.n])
  C[D.b, D.m, D.n] += lhs_elem * rhs_elem
133
134
@linalg_structured_op
def quantized_batch_matmul(
    A=TensorDef(T1, Batch, S.M, S.K),
    B=TensorDef(T2, Batch, S.K, S.N),
    AZp=ScalarDef(I32),
    BZp=ScalarDef(I32),
    C=TensorDef(U, Batch, S.M, S.N, output=True),
):
  """Performs a batched matrix multiplication of two 3D inputs.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output. The quantized variant additionally
  subtracts the zero points AZp and BZp, cast the same way, from the left and
  right operands of the matmul before they are multiplied.
  """
  domain(D.b, D.m, D.n, D.k)
  lhs_elem = TypeFn.cast_signed(U, A[D.b, D.m, D.k])
  lhs_zero_point = TypeFn.cast_signed(U, AZp)
  rhs_elem = TypeFn.cast_signed(U, B[D.b, D.k, D.n])
  rhs_zero_point = TypeFn.cast_signed(U, BZp)
  C[D.b, D.m, D.n] += (lhs_elem - lhs_zero_point) * (rhs_elem - rhs_zero_point)
152
153
@linalg_structured_op
def matvec(
    A=TensorDef(T1, S.M, S.N),
    y=TensorDef(T2, S.N),
    x=TensorDef(U, S.M, output=True),
):
  """Performs a matrix-vector multiplication.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied.
  """
  domain(D.m, D.n)
  implements(ContractionOpInterface)
  matrix_elem = TypeFn.cast_signed(U, A[D.m, D.n])
  vector_elem = TypeFn.cast_signed(U, y[D.n])
  x[D.m] += matrix_elem * vector_elem
166
167
@linalg_structured_op
def vecmat(
    y=TensorDef(T1, S.M),
    A=TensorDef(T2, S.M, S.N),
    x=TensorDef(U, S.N, output=True),
):
  """Performs a vector-matrix multiplication.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied.
  """
  domain(D.n, D.m)
  implements(ContractionOpInterface)
  vector_elem = TypeFn.cast_signed(U, y[D.m])
  matrix_elem = TypeFn.cast_signed(U, A[D.m, D.n])
  x[D.n] += vector_elem * matrix_elem
180
181
@linalg_structured_op
def batch_matvec(
    A=TensorDef(T1, Batch, S.M, S.K),
    B=TensorDef(T2, Batch, S.K),
    C=TensorDef(U, Batch, S.M, output=True),
):
  """Performs a batched matrix-vector multiplication.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied.
  """
  domain(D.b, D.m, D.k)
  implements(ContractionOpInterface)
  matrix_elem = TypeFn.cast_signed(U, A[D.b, D.m, D.k])
  vector_elem = TypeFn.cast_signed(U, B[D.b, D.k])
  C[D.b, D.m] += matrix_elem * vector_elem
195
196
@linalg_structured_op
def dot(
    A=TensorDef(T1, S.M),
    B=TensorDef(T2, S.M),
    C=TensorDef(U, output=True),
):
  """Performs a dot product of two vectors to a scalar result.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied.
  """
  implements(ContractionOpInterface)
  lhs_elem = TypeFn.cast_signed(U, A[D.m])
  rhs_elem = TypeFn.cast_signed(U, B[D.m])
  C[None] += lhs_elem * rhs_elem
207
208
@linalg_structured_op
def conv_1d(
    I=TensorDef(T1, S.OW + S.KW),
    K=TensorDef(T2, S.KW),
    O=TensorDef(U, S.OW, output=True),
):
  """Performs 1-D convolution with no channels.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied.
  """
  implements(ConvolutionOpInterface)
  domain(D.ow, D.kw)
  input_elem = TypeFn.cast_signed(U, I[D.ow + D.kw])
  kernel_elem = TypeFn.cast_signed(U, K[D.kw])
  O[D.ow] += input_elem * kernel_elem
222
223
@linalg_structured_op
def conv_2d(
    I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW),
    K=TensorDef(T2, S.KH, S.KW),
    O=TensorDef(U, S.OH, S.OW, output=True),
):
  """Performs 2-D convolution with no channels.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied.
  """
  implements(ConvolutionOpInterface)
  domain(D.oh, D.ow, D.kh, D.kw)
  input_elem = TypeFn.cast_signed(U, I[D.oh + D.kh, D.ow + D.kw])
  kernel_elem = TypeFn.cast_signed(U, K[D.kh, D.kw])
  O[D.oh, D.ow] += input_elem * kernel_elem
237
238
@linalg_structured_op
def conv_3d(
    I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW),
    K=TensorDef(T2, S.KD, S.KH, S.KW),
    O=TensorDef(U, S.OD, S.OH, S.OW, output=True),
):
  """Performs 3-D convolution with no channels.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied.
  """
  implements(ConvolutionOpInterface)
  domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw)
  input_elem = TypeFn.cast_signed(
      U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw])
  kernel_elem = TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw])
  O[D.od, D.oh, D.ow] += input_elem * kernel_elem
253
254
@linalg_structured_op
def conv_1d_nwc_wcf(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KW, S.C, S.F),
    O=TensorDef(U, S.N, S.OW, S.F, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
  """Performs 1-D convolution.

  Layout:
    * Input: NWC.
    * Kernel: WCF.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.ow, D.f, D.kw, D.c)
  # Input position visited for output position ow and kernel tap kw.
  in_w = D.ow * S.SW + D.kw * S.DW
  input_elem = TypeFn.cast_signed(U, I[D.n, in_w, D.c])
  kernel_elem = TypeFn.cast_signed(U, K[D.kw, D.c, D.f])
  O[D.n, D.ow, D.f] += input_elem * kernel_elem
271
272
@linalg_structured_op
def conv_2d_nhwc_hwcf(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
  """Performs 2-D convolution.

  Layout:
    * Input: NHWC.
    * Kernel: HWCF.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
  # Input position visited for an output position and a kernel tap.
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  input_elem = TypeFn.cast_signed(U, I[D.n, in_h, in_w, D.c])
  kernel_elem = TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f])
  O[D.n, D.oh, D.ow, D.f] += input_elem * kernel_elem
294
295
@linalg_structured_op
def conv_2d_nhwc_fhwc(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.F, S.KH, S.KW, S.C),
    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
  """Performs 2-D convolution.

  Layout:
    * Input: NHWC.
    * Kernel: FHWC.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
  # Input position visited for an output position and a kernel tap.
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  input_elem = TypeFn.cast_signed(U, I[D.n, in_h, in_w, D.c])
  kernel_elem = TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c])
  O[D.n, D.oh, D.ow, D.f] += input_elem * kernel_elem
317
318
@linalg_structured_op
def conv_2d_nhwc_hwcf_q(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
    IZp=ScalarDef(I32),
    KZp=ScalarDef(I32),
    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
  """Performs 2-D convolution with zero point offsets.

  Layout:
    * Input: NHWC.
    * Kernel: HWCF.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output. This includes the zero point offsets
  common to quantized operations: IZp is subtracted from the input element and
  KZp from the kernel element before they are multiplied.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
  # Input position visited for an output position and a kernel tap.
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  input_elem = TypeFn.cast_signed(U, I[D.n, in_h, in_w, D.c])
  input_zero_point = TypeFn.cast_signed(U, IZp)
  kernel_elem = TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f])
  kernel_zero_point = TypeFn.cast_signed(U, KZp)
  O[D.n, D.oh, D.ow, D.f] += (input_elem - input_zero_point) * (
      kernel_elem - kernel_zero_point)
345
346
@linalg_structured_op
def conv_2d_nchw_fchw(
    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
    O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
  """Performs 2-D convolution.

  Layout:
    * Input: NCHW.
    * Kernel: FCHW.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
  # Input position visited for an output position and a kernel tap.
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  input_elem = TypeFn.cast_signed(U, I[D.n, D.c, in_h, in_w])
  kernel_elem = TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw])
  O[D.n, D.f, D.oh, D.ow] += input_elem * kernel_elem
368
@linalg_structured_op
def conv_2d_ngchw_fgchw(
    I=TensorDef(T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
    O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
  """Performs 2-D grouped convolution.

  Layout:
    * Input: NGCHW.
    * Kernel: FGCHW.

  Numeric casting is performed on the operands to the inner multiply, promoting
  them to the same data type as the accumulator/output.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
  # Input position visited for an output position and a kernel tap.
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  input_elem = TypeFn.cast_signed(U, I[D.n, D.g, D.c, in_h, in_w])
  # The kernel is declared (and documented) as FGCHW, so the per-group filter
  # dimension fg is indexed before the group dimension g.
  kernel_elem = TypeFn.cast_signed(U, K[D.fg, D.g, D.c, D.kh, D.kw])
  # The output is indexed as (n, g, fg, oh, ow); its declared shape above
  # lists the dimensions in that same order.
  O[D.n, D.g, D.fg, D.oh, D.ow] += input_elem * kernel_elem
390
@linalg_structured_op
def conv_3d_ndhwc_dhwcf(
    I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
                S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
  """Performs 3-D convolution.

  Layout:
    * Input: NDHWC.
    * Kernel: DHWCF.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
  # Input position visited for an output position and a kernel tap.
  in_d = D.od * S.SD + D.kd * S.DD
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  input_elem = TypeFn.cast_signed(U, I[D.n, in_d, in_h, in_w, D.c])
  kernel_elem = TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f])
  O[D.n, D.od, D.oh, D.ow, D.f] += input_elem * kernel_elem
416
417
@linalg_structured_op
def depthwise_conv_1d_nwc_wc(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KW, S.IC),
    O=TensorDef(U, S.N, S.OW, S.IC, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
  """Performs depth-wise 1-D convolution.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied. Multiplier
  is set to 1 which is a special case for most depthwise convolutions.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.ow, D.ic, D.kw)
  # Input position visited for output position ow and kernel tap kw.
  in_w = D.ow * S.SW + D.kw * S.DW
  input_elem = TypeFn.cast_signed(U, I[D.n, in_w, D.ic])
  kernel_elem = TypeFn.cast_signed(U, K[D.kw, D.ic])
  O[D.n, D.ow, D.ic] += input_elem * kernel_elem
436
437
@linalg_structured_op
def depthwise_conv_1d_nwc_wcm(
    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KW, S.IC, S.CM),
    O=TensorDef(U, S.N, S.OW, S.IC, S.CM, output=True),
    strides=IndexAttrDef(S.SW, default=[1]),
    dilations=IndexAttrDef(S.DW, default=[1]),
):
  """Performs depth-wise 1-D convolution.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.ow, D.ic, D.cm, D.kw)
  # Input position visited for output position ow and kernel tap kw.
  in_w = D.ow * S.SW + D.kw * S.DW
  input_elem = TypeFn.cast_signed(U, I[D.n, in_w, D.ic])
  kernel_elem = TypeFn.cast_signed(U, K[D.kw, D.ic, D.cm])
  O[D.n, D.ow, D.ic, D.cm] += input_elem * kernel_elem
456
457
@linalg_structured_op
def depthwise_conv_2d_nhwc_hwc(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KH, S.KW, S.IC),
    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
  """Performs depth-wise 2-D convolution.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied. Multiplier
  is set to 1 which is a special case for most depthwise convolutions.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
  # Input position visited for an output position and a kernel tap.
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  input_elem = TypeFn.cast_signed(U, I[D.n, in_h, in_w, D.ic])
  kernel_elem = TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic])
  O[D.n, D.oh, D.ow, D.ic] += input_elem * kernel_elem
483
484
@linalg_structured_op
def depthwise_conv_2d_nchw_chw(
    I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.IC, S.KH, S.KW),
    O=TensorDef(U, S.N, S.IC, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
  """Performs depth-wise 2-D convolution.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied. Multiplier
  is set to 1 which is a special case for most depthwise convolutions.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
  # Input position visited for an output position and a kernel tap.
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  input_elem = TypeFn.cast_signed(U, I[D.n, D.ic, in_h, in_w])
  kernel_elem = TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw])
  O[D.n, D.ic, D.oh, D.ow] += input_elem * kernel_elem
509
510
@linalg_structured_op
def depthwise_conv_2d_nhwc_hwc_q(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KH, S.KW, S.IC),
    IZp=ScalarDef(I32),
    KZp=ScalarDef(I32),
    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
  """Performs depth-wise 2-D convolution.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output. The zero points IZp and KZp are cast
  the same way and subtracted from the input and kernel elements respectively
  before they are multiplied.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
  # Input position visited for an output position and a kernel tap.
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  input_elem = TypeFn.cast_signed(U, I[D.n, in_h, in_w, D.ic])
  input_zero_point = TypeFn.cast_signed(U, IZp)
  kernel_elem = TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic])
  kernel_zero_point = TypeFn.cast_signed(U, KZp)
  O[D.n, D.oh, D.ow, D.ic] += (input_elem - input_zero_point) * (
      kernel_elem - kernel_zero_point)
541
542
@linalg_structured_op
def depthwise_conv_2d_nhwc_hwcm(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
  """Performs depth-wise 2-D convolution.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output before they are multiplied.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
  # Input position visited for an output position and a kernel tap.
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  input_elem = TypeFn.cast_signed(U, I[D.n, in_h, in_w, D.ic])
  kernel_elem = TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm])
  O[D.n, D.oh, D.ow, D.ic, D.cm] += input_elem * kernel_elem
569
570
@linalg_structured_op
def depthwise_conv_2d_nhwc_hwcm_q(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
    IZp=ScalarDef(I32),
    KZp=ScalarDef(I32),
    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
  """Performs depth-wise 2-D convolution.

  The operands of the inner multiply are converted with a signed cast to the
  data type of the accumulator/output. The zero points IZp and KZp are cast
  the same way and subtracted from the input and kernel elements respectively
  before they are multiplied.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
  # Input position visited for an output position and a kernel tap.
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  input_elem = TypeFn.cast_signed(U, I[D.n, in_h, in_w, D.ic])
  input_zero_point = TypeFn.cast_signed(U, IZp)
  kernel_elem = TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm])
  kernel_zero_point = TypeFn.cast_signed(U, KZp)
  O[D.n, D.oh, D.ow, D.ic, D.cm] += (input_elem - input_zero_point) * (
      kernel_elem - kernel_zero_point)
604
605
@linalg_structured_op
def depthwise_conv_3d_ndhwc_dhwc(
    I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
                S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC),
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.IC, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
  """Performs depth-wise 3-D convolution.

  Numeric casting is performed on the operands to the inner multiply, promoting
  them to the same data type as the accumulator/output. Multiplier is set to 1
  which is a special case for most depthwise convolutions.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
  # Input position visited for an output position and a kernel tap.
  in_d = D.od * S.SD + D.kd * S.DD
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  input_elem = TypeFn.cast_signed(U, I[D.n, in_d, in_h, in_w, D.ic])
  kernel_elem = TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic])
  # The output is indexed with five dimensions (n, od, oh, ow, ic), so its
  # declared shape above carries the trailing channel dimension IC as well.
  O[D.n, D.od, D.oh, D.ow, D.ic] += input_elem * kernel_elem
633
634
@linalg_structured_op
def depthwise_conv_3d_ndhwc_dhwcm(
    I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
                S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW, S.IC),
    K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM),
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.IC, S.CM, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
):
  """Performs depth-wise 3-D convolution.

  Numeric casting is performed on the operands to the inner multiply, promoting
  them to the same data type as the accumulator/output.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic)
  # Input position visited for an output position and a kernel tap.
  in_d = D.od * S.SD + D.kd * S.DD
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  input_elem = TypeFn.cast_signed(U, I[D.n, in_d, in_h, in_w, D.ic])
  kernel_elem = TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic, D.cm])
  # The output is indexed with six dimensions (n, od, oh, ow, ic, cm), so its
  # declared shape above carries the channel dimension IC ahead of CM.
  O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += input_elem * kernel_elem
662
663
@linalg_structured_op
def pooling_nhwc_sum(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
  """Performs sum pooling.

  Layout:
    * Input: NHWC.
    * Kernel: HW.

  The input element is converted with a signed cast to the data type of the
  accumulator/output before it is accumulated.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
  # Input position visited for an output position and a window offset.
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  window_elem = TypeFn.cast_signed(U, I[D.n, in_h, in_w, D.c])
  O[D.n, D.oh, D.ow, D.c] += window_elem
684
685
@linalg_structured_op
def pooling_nchw_sum(
    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
  """Performs sum pooling.

  Layout:
    * Input: NCHW.
    * Kernel: HW.

  The input element is converted with a signed cast to the data type of the
  accumulator/output before it is accumulated.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
  # Input position visited for an output position and a window offset.
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  window_elem = TypeFn.cast_signed(U, I[D.n, D.c, in_h, in_w])
  O[D.n, D.c, D.oh, D.ow] += window_elem
706
707
@linalg_structured_op
def pooling_nhwc_max(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW,
                S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
  """Performs max pooling over a strided, dilated 2-D window.

  The input operand is numerically cast (signed) to the data type of the
  accumulator/output before the maximum is taken.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
  # K is never read below; it is only declared with index_dims=[D.kh, D.kw].
  # Input coordinates visited by window offset (kh, kw) for output (oh, ow).
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  # Signed maximum reduced over the two window dimensions.
  window_max = ReduceFn.max_signed[D.kh, D.kw]
  O[D.n, D.oh, D.ow, D.c] = window_max(
      TypeFn.cast_signed(U, I[D.n, in_h, in_w, D.c]))
724
725
@linalg_structured_op
def pooling_nhwc_max_unsigned(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW,
                S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
  """Performs unsigned max pooling over a strided, dilated 2-D window.

  The input operand is numerically cast (unsigned) to the data type of the
  accumulator/output before the maximum is taken.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
  # K is never read below; it is only declared with index_dims=[D.kh, D.kw].
  # Input coordinates visited by window offset (kh, kw) for output (oh, ow).
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  # Unsigned maximum reduced over the two window dimensions.
  window_max = ReduceFn.max_unsigned[D.kh, D.kw]
  O[D.n, D.oh, D.ow, D.c] = window_max(
      TypeFn.cast_unsigned(U, I[D.n, in_h, in_w, D.c]))
747
748
@linalg_structured_op
def pooling_nchw_max(
    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
  """Performs max pooling over a strided, dilated 2-D window (NCHW input).

  The input operand is numerically cast (signed) to the data type of the
  accumulator/output before the maximum is taken.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
  # K is never read below; it is only declared with index_dims=[D.kh, D.kw].
  # Input coordinates visited by window offset (kh, kw) for output (oh, ow).
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  # Signed maximum reduced over the two window dimensions.
  window_max = ReduceFn.max_signed[D.kh, D.kw]
  O[D.n, D.c, D.oh, D.ow] = window_max(
      TypeFn.cast_signed(U, I[D.n, D.c, in_h, in_w]))
765
766
@linalg_structured_op
def pooling_nhwc_min(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW,
                S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
  """Performs min pooling over a strided, dilated 2-D window.

  The input operand is numerically cast (signed) to the data type of the
  accumulator/output before the minimum is taken.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
  # K is never read below; it is only declared with index_dims=[D.kh, D.kw].
  # Input coordinates visited by window offset (kh, kw) for output (oh, ow).
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  # Signed minimum reduced over the two window dimensions.
  window_min = ReduceFn.min_signed[D.kh, D.kw]
  O[D.n, D.oh, D.ow, D.c] = window_min(
      TypeFn.cast_signed(U, I[D.n, in_h, in_w, D.c]))
783
784
@linalg_structured_op
def pooling_nhwc_min_unsigned(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW,
                S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
  """Performs unsigned min pooling over a strided, dilated 2-D window.

  The input operand is numerically cast (unsigned) to the data type of the
  accumulator/output before the minimum is taken.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
  # K is never read below; it is only declared with index_dims=[D.kh, D.kw].
  # Input coordinates visited by window offset (kh, kw) for output (oh, ow).
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  # Unsigned minimum reduced over the two window dimensions.
  window_min = ReduceFn.min_unsigned[D.kh, D.kw]
  O[D.n, D.oh, D.ow, D.c] = window_min(
      TypeFn.cast_unsigned(U, I[D.n, in_h, in_w, D.c]))
806
807
@linalg_structured_op
def pooling_ndhwc_sum(
    I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
  """Performs 3D sum pooling over a strided, dilated window.

  The input operand is numerically cast to the data type of the
  accumulator/output before it is accumulated.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
  # K is never read below; it is only declared with
  # index_dims=[D.kd, D.kh, D.kw].
  # Input coordinates visited by window offset (kd, kh, kw) for output
  # position (od, oh, ow).
  in_d = D.od * S.SD + D.kd * S.DD
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed(
      U, I[D.n, in_d, in_h, in_w, D.c])
833
834
@linalg_structured_op
def pooling_ndhwc_max(
    I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
  """Performs 3D max pooling over a strided, dilated window.

  The input operand is numerically cast (signed) to the data type of the
  accumulator/output before the maximum is taken.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
  # K is never read below; it is only declared with
  # index_dims=[D.kd, D.kh, D.kw].
  # Input coordinates visited by window offset (kd, kh, kw) for output
  # position (od, oh, ow).
  in_d = D.od * S.SD + D.kd * S.DD
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  # Signed maximum reduced over the three window dimensions.
  window_max = ReduceFn.max_signed[D.kd, D.kh, D.kw]
  O[D.n, D.od, D.oh, D.ow, D.c] = window_max(
      TypeFn.cast_signed(U, I[D.n, in_d, in_h, in_w, D.c]))
861
862
@linalg_structured_op
def pooling_ndhwc_min(
    I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH,
                S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
  """Performs 3D min pooling over a strided, dilated window.

  The input operand is numerically cast (signed) to the data type of the
  accumulator/output before the minimum is taken.
  """
  implements(ConvolutionOpInterface)
  domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
  # K is never read below; it is only declared with
  # index_dims=[D.kd, D.kh, D.kw].
  # Input coordinates visited by window offset (kd, kh, kw) for output
  # position (od, oh, ow).
  in_d = D.od * S.SD + D.kd * S.DD
  in_h = D.oh * S.SH + D.kh * S.DH
  in_w = D.ow * S.SW + D.kw * S.DW
  # Signed minimum reduced over the three window dimensions.
  window_min = ReduceFn.min_signed[D.kd, D.kh, D.kw]
  O[D.n, D.od, D.oh, D.ow, D.c] = window_min(
      TypeFn.cast_signed(U, I[D.n, in_d, in_h, in_w, D.c]))
889
890
@linalg_structured_op
def fill(value=ScalarDef(T1),
         O=TensorDef(U, output=True)):
  """Fills the output tensor with the given value.

  The operation only performs scalar accesses, which makes it rank polymorphic:
  it works for output tensors of arbitrary rank. The value operand is
  numerically cast to the data type of the output.
  """
  implements(FillOpInterface)
  defines(Canonicalizer)
  # Cast the scalar operand to the output element type, then store it through
  # the scalar access form O[None].
  fill_value = TypeFn.cast_signed(U, value)
  O[None] = fill_value
902
903
@linalg_structured_op
def fill_rng_2d(
    min=ScalarDef(F64),
    max=ScalarDef(F64),
    seed=ScalarDef(I32),
    O=TensorDef(T, S.M, S.N, output=True)):
  """Fills the output tensor with pseudo random numbers.

  The operation generates pseudo random numbers using a linear congruential
  generator and gives no guarantees regarding their distribution. Rather than
  producing the numbers sequentially, it instantiates one random number
  generator per data element and runs them in parallel. Each generator is
  seeded from the seed operand and the indices of its data element. The min
  and max operands limit the range of the generated random numbers.
  """
  domain(D.m, D.n)
  # Linear congruential generator parameters, held as 32-bit integers.
  lcg_mul = TypeFn.cast_signed(I32, const(1103515245))
  lcg_inc = TypeFn.cast_signed(I32, const(12345))
  # Two chained generator steps: the row index is combined with the seed
  # first, then the column index is combined with that intermediate state.
  row_idx = TypeFn.cast_signed(I32, index(D.m))
  row_state = (row_idx + seed) * lcg_mul + lcg_inc
  col_idx = TypeFn.cast_signed(I32, index(D.n))
  elem_state = (col_idx + row_state) * lcg_mul + lcg_inc
  # 2.3283064e-10 is approximately 2**-32 and 2147483647 equals 2**31 - 1;
  # presumably these rescale the signed 32-bit state into the unit interval
  # before it is stretched to the requested range -- TODO confirm.
  unit_scale = TypeFn.cast_signed(F64, const(2.3283064e-10))
  shift = TypeFn.cast_signed(F64, const(2147483647))
  span = (max - min) * unit_scale
  sample = (shift + TypeFn.cast_signed(F64, elem_state)) * span + min
  O[D.m, D.n] = TypeFn.cast_signed(T, sample)
929