//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

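/// Bufferization of tensor.cast. Replace with memref.cast.
///
/// A rough sketch of the rewrite (types are illustrative only, not taken
/// from a specific test case):
///
///   %1 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
///
/// bufferizes to a cast of the source buffer %m:
///
///   %1 = memref.cast %m : memref<?xf32> to memref<4xf32>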
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, castOp.getSource(), options);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType =
        getMemRefType(castOp.getResult(), options, layout,
                      sourceMemRefType.getMemorySpaceAsInt());

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
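///
/// Sketch of the rewrite (illustrative types; assumes the source buffer %m
/// is collapsible, otherwise a copy into a fresh buffer is inserted first):
///
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///       : tensor<2x4xf32> into tensor<8xf32>
///
/// becomes
///
///   %1 = memref.collapse_shape %m [[0, 1]]
///       : memref<2x4xf32> into memref<8xf32>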
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, collapseShapeOp.getSrc(), options);
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
    auto bufferType = buffer.getType().cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(),
          analysisState.isTensorYielded(collapseShapeOp.getResult()), options);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpaceAsInt());
      buffer = rewriter.create<bufferization::ToMemrefOp>(
          op->getLoc(), memrefType, *tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
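///
/// For example (%m is the buffer of %t):
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
///
/// becomes
///
///   %d = memref.dim %m, %c0 : memref<?xf32>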
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto dimOp = cast<tensor::DimOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.getIndex());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
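///
/// Sketch of the rewrite (illustrative types; %m is the buffer of %0):
///
///   %1 = tensor.expand_shape %0 [[0, 1]]
///       : tensor<8xf32> into tensor<2x4xf32>
///
/// becomes
///
///   %1 = memref.expand_shape %m [[0, 1]]
///       : memref<8xf32> into memref<2x4xf32>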
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    FailureOr<Value> buffer =
        getBuffer(rewriter, expandShapeOp.getSrc(), options);
    if (failed(buffer))
      return failure();

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
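///
/// Sketch of the rewrite (illustrative types; the exact result layout #map
/// is computed by inferRankReducedResultType below):
///
///   %1 = tensor.extract_slice %0[%off][%sz][1]
///       : tensor<?xf32> to tensor<?xf32>
///
/// becomes a subview into the source buffer %m:
///
///   %1 = memref.subview %m[%off][%sz][1]
///       : memref<?xf32> to memref<?xf32, #map>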
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    Location loc = extractSliceOp.getLoc();

    // Get source buffer.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    auto srcMemrefType = srcMemref->getType().cast<MemRefType>();

    // Take a subview of the source buffer.
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            extractSliceOp.getType().getShape(), srcMemrefType, mixedOffsets,
            mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
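///
/// For example (%m is the buffer of %t):
///
///   %e = tensor.extract %t[%i] : tensor<?xf32>
///
/// becomes
///
///   %e = memref.load %m[%i] : memref<?xf32>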
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractOp.getTensor(), options);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.getIndices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
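//
// E.g., for a buffer of shape 2x3, stores are created for the indices
// (0,0), (0,1), (0,2), (1,0), (1,1), (1,2), in that (row-major) order,
// consuming one element of op.elements() per store.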
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
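///
/// Rough sketch: a new buffer is allocated and every element is written with
/// an individual memref.store. E.g. (illustrative; the allocation is created
/// as a bufferization.alloc_tensor and is shown here after bufferization):
///
///   %t = tensor.from_elements %a, %b : tensor<2xf32>
///
/// becomes approximately
///
///   %m = memref.alloc() : memref<2xf32>
///   memref.store %a, %m[%c0] : memref<2xf32>
///   memref.store %b, %m[%c1] : memref<2xf32>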
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {

  bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);
    // Should the buffer be deallocated?
    bool dealloc = shouldDeallocateOpResult(
        fromElementsOp.getResult().cast<OpResult>(), options);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != static_cast<unsigned>(0))
      return op->emitError("memory space not implemented yet");

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    FailureOr<Value> tensorAlloc =
        allocateTensorForShapedValue(rewriter, loc, fromElementsOp.getResult(),
                                     /*escape=*/!dealloc, options,
                                     /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(
          loc, fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);

    return success();
  }
};

/// Bufferization of tensor.generate.
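///
/// Rough sketch: a new buffer is allocated and filled by an scf.parallel
/// loop whose body is the inlined generator region. E.g. (illustrative; the
/// allocation is shown after bufferization):
///
///   %t = tensor.generate %n {
///   ^bb0(%i : index):
///     ...
///     tensor.yield %v : f32
///   } : tensor<?xf32>
///
/// becomes approximately
///
///   %m = memref.alloc(%n) : memref<?xf32>
///   scf.parallel (%i) = (%c0) to (%n) step (%c1) {
///     ...
///     memref.store %v, %m[%i] : memref<?xf32>
///     scf.yield
///   }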
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {

  bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto generateOp = cast<tensor::GenerateOp>(op);
    // Should the buffer be deallocated?
    bool dealloc = shouldDeallocateOpResult(
        generateOp.getResult().cast<OpResult>(), options);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != static_cast<unsigned>(0))
      return op->emitError("memory space not implemented yet");

    auto tensorType = generateOp.getType().cast<RankedTensorType>();
    // Allocate memory.
    Location loc = op->getLoc();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    FailureOr<Value> tensorAlloc =
        allocateTensorForShapedValue(rewriter, loc, generateOp.getResult(),
                                     /*escape=*/!dealloc, options,
                                     /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound =
          memrefType.isDynamicDim(i)
              ? generateOp.getDynamicExtents()[nextDynamicIndex++]
              : rewriter.create<arith::ConstantIndexOp>(
                    loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(&generateOp.getBody().front(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], buffer,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, buffer);

    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
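///
/// For example (%m is the buffer of %t):
///
///   %1 = tensor.insert %f into %t[%i] : tensor<?xf32>
///
/// becomes a store into the destination buffer:
///
///   memref.store %f, %m[%i] : memref<?xf32>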
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        getBuffer(rewriter, insertOp.getDest(), options);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
                                     *destMemref, insertOp.getIndices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.,
/// equivalent operand / result and same offset/sizes/strides specification).
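///
/// For example, the following two ops match (same underlying tensor %t and
/// the same offsets/sizes/strides):
///
///   %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1]
///   %1 = tensor.insert_slice %x into %t[%a, %b][%c, %d][1, 1]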
template <typename OpTy>
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp extractSliceOp,
                                         OpTy insertSliceOp) {
  if (!extractSliceOp || !insertSliceOp)
    return false;
  if (extractSliceOp != insertSliceOp &&
      !state.areEquivalentBufferizedValues(extractSliceOp.getSource(),
                                           insertSliceOp.getDest()))
    return false;
  if (!sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
                                  isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches
/// the given InsertSliceOp.
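///
/// For example, in the following IR, %1 originates from an ExtractSliceOp
/// that matches the final InsertSliceOp; the fill in between does not change
/// which slice of %t is addressed:
///
///   %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1]
///   %1 = linalg.fill %cst, %0
///   %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]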
template <typename OpTy>
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      OpTy insertSliceOp) {
  auto condition = [&](Value val) {
    if (auto extractSliceOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractSliceOp, insertSliceOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Return true if a read via `uRead` and a write via `uConflictingWrite` are
/// known not to constitute a RaW conflict, based on the
/// ExtractSliceOp/InsertSliceOp matching rules implemented below. `OpTy` is
/// tensor::InsertSliceOp or tensor::ParallelInsertSliceOp.
template <typename OpTy>
static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
                                              OpOperand *uConflictingWrite,
                                              const AnalysisState &state) {
  Operation *readingOp = uRead->getOwner();
  Operation *conflictingWritingOp = uConflictingWrite->getOwner();

  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
  // uRead is an InsertSliceOp...
  if (auto insertSliceOp = dyn_cast<OpTy>(readingOp)) {
    // As an example, consider the following IR.
    //
    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
    // %1 = linalg.fill %cst, %0 {inplace= [true] }
    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
    //     {inplace= [true] }

    // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
    if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
        hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                  insertSliceOp))
      // Case 1: The main insight is that InsertSliceOp reads only part of
      // the destination tensor. The overwritten area is not read. If
      // uConflictingWrite writes into exactly the memory location that is
      // being read by uRead, this is not a conflict.
      //
      // In the above example:
      // uRead             = OpOperand 1 (%t) of tensor.insert_slice
      // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
      //
      // The read of %t does not conflict with the write of the FillOp
      // (same aliases!) because the area that the FillOp operates on is
      // exactly the one that is *not* read via %t.
      return true;

    if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
        uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
        hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
      // Case 2: The read of the source tensor and the write to the dest
      // tensor via an InsertSliceOp is not a conflict if the read is
      // reading exactly that part of an equivalent tensor that the
      // InsertSliceOp is writing.
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of tensor.insert_slice
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      return true;
  }

  // If uConflictingWrite is an InsertSliceOp...
  if (auto insertSliceOp = dyn_cast<OpTy>(conflictingWritingOp))
    // As an example, consider the following IR.
    //
    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
    // %1 = linalg.fill %cst, %0 {inplace= [true] }
    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
    //     {inplace= [true] }
    // %3 = vector.transfer_read %1, %cst
    //
    // In the above example:
    // uRead             = OpOperand 0 (%1) of vector.transfer_read
    // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
    // lastWrite         = %1
    //
    // This is not a conflict because the InsertSliceOp overwrites the
    // memory segment of %1 with the exact same data. (Effectively, there
    // is no memory write here.)
    if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
        state.areEquivalentBufferizedValues(uRead->get(),
                                            insertSliceOp.getSource()) &&
        hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                  insertSliceOp))
      return true;

  return false;
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
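///
/// Sketch of the rewrite (illustrative types; with the default options,
/// createMemCpy below produces a memref.copy):
///
///   %1 = tensor.insert_slice %src into %dst[%off][%sz][1]
///       : tensor<?xf32> into tensor<?xf32>
///
/// becomes a subview of the destination buffer plus a copy:
///
///   %sv = memref.subview %dst_m[%off][%sz][1]
///       : memref<?xf32> to memref<?xf32, #map>
///   memref.copy %src_m, %sv : memref<?xf32> to memref<?xf32, #map>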
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    return isNotConflictingInsertSliceLikeOp<tensor::InsertSliceOp>(
        op, uRead, uConflictingWrite, state);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    Location loc = insertSliceOp.getLoc();

    // Get destination buffer.
    FailureOr<Value> dstMemref =
        getBuffer(rewriter, insertSliceOp.getDest(), options);
    if (failed(dstMemref))
      return failure();

    // Take a subview of the destination buffer.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getShape(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, insertSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
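///
/// For example (%m is the buffer of %t):
///
///   %r = tensor.rank %t : tensor<*xf32>
///
/// becomes
///
///   %r = memref.rank %m : memref<*xf32>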
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto rankOp = cast<tensor::RankOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
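///
/// Sketch of the rewrite (illustrative types; %m and %shape_m are the
/// buffers of %t and %shape):
///
///   %r = tensor.reshape %t(%shape)
///       : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
///
/// becomes
///
///   %r = memref.reshape %m(%shape_m)
///       : (memref<?xf32>, memref<1xindex>) -> memref<?xf32>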
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*shape*/)
      return true;
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getOpResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, reshapeOp.getSource(), options);
    FailureOr<Value> shapeBuffer =
        getBuffer(rewriter, reshapeOp.getShape(), options);
    if (failed(srcBuffer) || failed(shapeBuffer))
      return failure();
    auto resultMemRefType = getMemRefType(
        reshapeOp.getResult(), options, /*layout=*/{},
        srcBuffer->getType().cast<BaseMemRefType>().getMemorySpaceAsInt());
    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
    return success();
  }
};

/// Analysis of ParallelInsertSliceOp.
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand != &op->getOpOperand(1) /*dest*/)
      return {};

    // ParallelInsertSliceOp itself has no results, query its tied op results.
    auto insertOp = cast<ParallelInsertSliceOp>(op);
    return {insertOp.getTiedOpResult()};
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    // This interface method is overridden because we want to set a custom
    // insertion point for tensor copies. They should be inserted right before
    // the ForeachThreadOp. E.g.:
    //
    // %r0, %r1 = foreach_thread ... {
    //   ...
    //   perform_concurrently {
    //     parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
    //     parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
    //   }
    // }
    //
    // After TensorCopyInsertion:
    //
    // %copy = bufferization.alloc_tensor() copy(%d)
    // %r0, %r1 = foreach_thread ... {
    //   ...
    //   perform_concurrently {
    //     parallel_insert_slice %a into %b ...
    //     parallel_insert_slice %c into %copy ...
    //   }
    // }

    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();
    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();

    // Nothing to do if the destination tensor is inplace.
    assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
           "source is always in-place");
    if (state.isInPlace(op->getOpOperand(1) /*dest*/))
      return success();

    // Find corresponding OpResult.
    OpResult opResult = parallelInsertSliceOp.getTiedOpResult();

    // Insert tensor allocation right before the ForeachThreadOp.
    rewriter.setInsertionPoint(parallelIteratingOp);
    bool isYielded = state.isTensorYielded(opResult);
    FailureOr<Value> alloc = allocateTensorForShapedValue(
        rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
        /*escape=*/isYielded, state.getOptions());
    if (failed(alloc))
      return failure();

    // Update destination operand.
    rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
      parallelInsertSliceOp.getDestMutable().assign(*alloc);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();
    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();

    // Get destination buffer.
    FailureOr<Value> destBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
    if (failed(destBuffer))
      return failure();

    // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
    rewriter.setInsertionPoint(parallelCombiningParent);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
    if (failed(srcBuffer))
      return failure();

    // Take a subview of the destination buffer.
    auto destBufferType = destBuffer->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
            parallelInsertSliceOp.getMixedOffsets(),
            parallelInsertSliceOp.getMixedSizes(),
            parallelInsertSliceOp.getMixedStrides())
            .cast<MemRefType>();
    Value subview = rewriter.create<memref::SubViewOp>(
        parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
        parallelInsertSliceOp.getMixedOffsets(),
        parallelInsertSliceOp.getMixedSizes(),
        parallelInsertSliceOp.getMixedStrides());

    // This memcpy will fold away if everything bufferizes in-place.
    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
                                    *srcBuffer, subview)))
      return failure();

    // Replace all uses of parallelIteratingOp (just the corresponding result).
    rewriter.setInsertionPointAfter(parallelIteratingOp);
    Value toTensorOp =
        rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
    // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
    SmallVector<OpOperand *> resultUses = llvm::to_vector(
        llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
                        [](OpOperand &use) { return &use; }));
    for (OpOperand *use : resultUses) {
      rewriter.updateRootInPlace(use->getOwner(),
                                 [&]() { use->set(toTensorOp); });
    }
    rewriter.eraseOp(op);
    return success();
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    return isNotConflictingInsertSliceLikeOp<tensor::ParallelInsertSliceOp>(
        op, uRead, uConflictingWrite, state);
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);

    // Load additional dialects of which ops may get created.
    ctx->loadDialect<arith::ArithmeticDialect, scf::SCFDialect>();
  });
}