//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

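/// Bufferization of tensor.cast. Replace with memref.cast.
///
/// Illustrative example (schematic; %t_buf stands for the bufferized form of
/// %t, and the layout of the result follows the source buffer):
///
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///
/// bufferizes to
///
///   %0 = memref.cast %t_buf : memref<4xf32> to memref<?xf32>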
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, castOp.getSource(), options);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType =
        getMemRefType(castOp.getResult(), options, layout,
                      sourceMemRefType.getMemorySpaceAsInt());

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
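///
/// Illustrative example (schematic; the layout attribute of the result
/// depends on the source buffer and may differ):
///
///   %0 = tensor.collapse_shape %t [[0, 1]]
///       : tensor<2x3xf32> into tensor<6xf32>
///
/// bufferizes to roughly
///
///   %0 = memref.collapse_shape %t_buf [[0, 1]]
///       : memref<2x3xf32> into memref<6xf32>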
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, collapseShapeOp.getSrc(), options);
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
    auto bufferType = buffer.getType().cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(),
          analysisState.isTensorYielded(collapseShapeOp.getResult()), options);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpaceAsInt());
      buffer = rewriter.create<bufferization::ToMemrefOp>(
          op->getLoc(), memrefType, *tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
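///
/// Illustrative example (schematic; #layout stands for whatever layout the
/// source buffer happens to have):
///
///   %d = tensor.dim %t, %c0 : tensor<?x?xf32>
///
/// bufferizes to
///
///   %d = memref.dim %t_buf, %c0 : memref<?x?xf32, #layout>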
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto dimOp = cast<tensor::DimOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.getIndex());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
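///
/// Illustrative example (schematic):
///
///   %0 = tensor.expand_shape %t [[0, 1]]
///       : tensor<6xf32> into tensor<2x3xf32>
///
/// bufferizes to roughly
///
///   %0 = memref.expand_shape %t_buf [[0, 1]]
///       : memref<6xf32> into memref<2x3xf32>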
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    FailureOr<Value> buffer =
        getBuffer(rewriter, expandShapeOp.getSrc(), options);
    if (failed(buffer))
      return failure();

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
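///
/// Illustrative example (#strided is a schematic placeholder; the actual
/// subview layout is inferred from the offsets/sizes/strides):
///
///   %0 = tensor.extract_slice %t[4][8][1] : tensor<16xf32> to tensor<8xf32>
///
/// bufferizes to
///
///   %0 = memref.subview %t_buf[4][8][1]
///       : memref<16xf32> to memref<8xf32, #strided>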
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    Location loc = extractSliceOp.getLoc();

    // Get source buffer.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    auto srcMemrefType = srcMemref->getType().cast<MemRefType>();

    // Take a subview of the source buffer.
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            extractSliceOp.getType().getShape(), srcMemrefType, mixedOffsets,
            mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
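///
/// Illustrative example (schematic):
///
///   %e = tensor.extract %t[%i, %j] : tensor<?x?xf32>
///
/// bufferizes to
///
///   %e = memref.load %t_buf[%i, %j] : memref<?x?xf32, #layout>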
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractOp.getTensor(), options);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.getIndices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
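// E.g., for a tensor<2x2xf32>, stores are created for the indices (0, 0),
// (0, 1), (1, 0), (1, 1), consuming op.elements() in row-major order.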
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
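///
/// Illustrative example (schematic; the allocation is actually created via a
/// bufferization.alloc_tensor / to_memref pair):
///
///   %0 = tensor.from_elements %a, %b : tensor<2xf32>
///
/// bufferizes to roughly
///
///   %buf = memref.alloc() : memref<2xf32>
///   memref.store %a, %buf[%c0] : memref<2xf32>
///   memref.store %b, %buf[%c1] : memref<2xf32>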
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {

  bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);
    // Should the buffer be deallocated?
    bool dealloc = shouldDeallocateOpResult(
        fromElementsOp.getResult().cast<OpResult>(), options);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != static_cast<unsigned>(0))
      return op->emitError("memory space not implemented yet");

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    FailureOr<Value> tensorAlloc =
        allocateTensorForShapedValue(rewriter, loc, fromElementsOp.getResult(),
                                     /*escape=*/!dealloc, options,
                                     /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(
          loc, fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);

    return success();
  }
};

/// Bufferization of tensor.generate.
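///
/// Illustrative example (schematic):
///
///   %0 = tensor.generate %n {
///   ^bb0(%i: index):
///     %v = ...
///     tensor.yield %v : f32
///   } : tensor<?xf32>
///
/// bufferizes to an allocation of size %n plus an scf.parallel loop whose
/// body stores the yielded value into the allocated buffer at index %i.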
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {

  bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto generateOp = cast<tensor::GenerateOp>(op);
    // Should the buffer be deallocated?
    bool dealloc = shouldDeallocateOpResult(
        generateOp.getResult().cast<OpResult>(), options);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != static_cast<unsigned>(0))
      return op->emitError("memory space not implemented yet");

    auto tensorType = generateOp.getType().cast<RankedTensorType>();
    // Allocate memory.
    Location loc = op->getLoc();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    FailureOr<Value> tensorAlloc =
        allocateTensorForShapedValue(rewriter, loc, generateOp.getResult(),
                                     /*escape=*/!dealloc, options,
                                     /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound =
          memrefType.isDynamicDim(i)
              ? generateOp.getDynamicExtents()[nextDynamicIndex++]
              : rewriter.create<arith::ConstantIndexOp>(
                    loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(&generateOp.getBody().front(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], buffer,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, buffer);

    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
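///
/// Illustrative example (schematic):
///
///   %0 = tensor.insert %f into %t[%i] : tensor<8xf32>
///
/// bufferizes to a store into the (possibly copied) destination buffer:
///
///   memref.store %f, %t_buf[%i] : memref<8xf32>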
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        getBuffer(rewriter, insertOp.getDest(), options);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
                                     *destMemref, insertOp.getIndices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
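///
/// E.g., the following pair matches (same underlying tensor %t and identical
/// offsets/sizes/strides):
///
///   %0 = tensor.extract_slice %t[%a][%c][1] : tensor<?xf32> to tensor<?xf32>
///   ...
///   %1 = tensor.insert_slice %x into %t[%a][%c][1]
///       : tensor<?xf32> into tensor<?xf32>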
template <typename OpTy>
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp extractSliceOp,
                                         OpTy insertSliceOp) {
  if (!extractSliceOp || !insertSliceOp)
    return false;
  if (extractSliceOp != insertSliceOp &&
      !state.areEquivalentBufferizedValues(extractSliceOp.getSource(),
                                           insertSliceOp.getDest()))
    return false;
  if (!sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
                                  isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
template <typename OpTy>
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      OpTy insertSliceOp) {
  auto condition = [&](Value val) {
    if (auto extractSliceOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractSliceOp, insertSliceOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

template <typename OpTy>
static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
                                              OpOperand *uConflictingWrite,
                                              const AnalysisState &state) {
  Operation *readingOp = uRead->getOwner();
  Operation *conflictingWritingOp = uConflictingWrite->getOwner();

  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
  // uRead is an InsertSliceOp...
  if (auto insertSliceOp = dyn_cast<OpTy>(readingOp)) {
    // As an example, consider the following IR.
    //
    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
    // %1 = linalg.fill %cst, %0 {inplace= [true] }
    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
    //     {inplace= [true] }

    // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
    if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
        hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                  insertSliceOp))
      // Case 1: The main insight is that InsertSliceOp reads only part of
      // the destination tensor. The overwritten area is not read. If
      // uConflictingWrite writes into exactly the memory location that is
      // being read by uRead, this is not a conflict.
      //
      // In the above example:
      // uRead = OpOperand 1 (%t) of tensor.insert_slice
      // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
      //
      // The read of %t does not conflict with the write of the FillOp
      // (same aliases!) because the area that the FillOp operates on is
      // exactly the one that is *not* read via %t.
      return true;

    if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
        uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
        hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
      // Case 2: The read of the source tensor and the write to the dest
      // tensor via an InsertSliceOp is not a conflict if the read is
      // reading exactly that part of an equivalent tensor that the
      // InsertSliceOp is writing.
      //
      // In the above example:
      // uRead = OpOperand 0 (%1) of tensor.insert_slice
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      return true;
  }

  // If uConflictingWrite is an InsertSliceOp...
  if (auto insertSliceOp = dyn_cast<OpTy>(conflictingWritingOp))
    // As an example, consider the following IR.
    //
    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
    // %1 = linalg.fill %cst, %0 {inplace= [true] }
    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
    //     {inplace= [true] }
    // %3 = vector.transfer_read %1, %cst
    //
    // In the above example:
    // uRead = OpOperand 0 (%1) of vector.transfer_read
    // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
    // lastWrite = %1
    //
    // This is not a conflict because the InsertSliceOp overwrites the
    // memory segment of %1 with the exact same data. (Effectively, there
    // is no memory write here.)
    if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
        state.areEquivalentBufferizedValues(uRead->get(),
                                            insertSliceOp.getSource()) &&
        hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                  insertSliceOp))
      return true;

  return false;
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
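///
/// Illustrative example (#strided is a schematic placeholder; with the
/// default options the copy is typically a memref.copy):
///
///   %0 = tensor.insert_slice %s into %t[%o][%sz][1]
///       : tensor<?xf32> into tensor<?xf32>
///
/// bufferizes to roughly
///
///   %sv = memref.subview %t_buf[%o][%sz][1]
///       : memref<?xf32> to memref<?xf32, #strided>
///   memref.copy %s_buf, %sv : memref<?xf32> to memref<?xf32, #strided>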
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    return isNotConflictingInsertSliceLikeOp<tensor::InsertSliceOp>(
        op, uRead, uConflictingWrite, state);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    Location loc = insertSliceOp.getLoc();

    // Get destination buffer.
    FailureOr<Value> dstMemref =
        getBuffer(rewriter, insertSliceOp.getDest(), options);
    if (failed(dstMemref))
      return failure();

    // Take a subview of the destination buffer.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getShape(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, insertSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
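///
/// Illustrative example (schematic):
///
///   %r = tensor.rank %t : tensor<*xf32>
///
/// bufferizes to
///
///   %r = memref.rank %t_buf : memref<*xf32>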
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto rankOp = cast<tensor::RankOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
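///
/// Illustrative example (schematic):
///
///   %0 = tensor.reshape %t(%shape)
///       : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
///
/// bufferizes to
///
///   %0 = memref.reshape %t_buf(%shape_buf)
///       : (memref<?xf32>, memref<1xindex>) -> memref<?xf32>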
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /* shape */)
      return true;
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getOpResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, reshapeOp.getSource(), options);
    FailureOr<Value> shapeBuffer =
        getBuffer(rewriter, reshapeOp.getShape(), options);
    if (failed(srcBuffer) || failed(shapeBuffer))
      return failure();
    auto resultMemRefType = getMemRefType(
        reshapeOp.getResult(), options, /*layout=*/{},
        srcBuffer->getType().cast<BaseMemRefType>().getMemorySpaceAsInt());
    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
    return success();
  }
};

/// Analysis of ParallelInsertSliceOp.
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand != &op->getOpOperand(1) /*dest*/)
      return {};

    // ParallelInsertSliceOp itself has no results, query its tied op results.
    auto insertOp = cast<ParallelInsertSliceOp>(op);
    return {insertOp.getTiedOpResult()};
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    // This interface method is overridden because we want to set a custom
    // insertion point for tensor copies. They should be inserted right before
    // the ForeachThreadOp. E.g.:
    //
    // %r0, %r1 = foreach_thread ... {
    //   ...
    //   perform_concurrently {
    //     parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
    //     parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
    //   }
    // }
    //
    // After TensorCopyInsertion:
    //
    // %copy = bufferization.alloc_tensor() copy(%d)
    // %r0, %r1 = foreach_thread ... {
    //   ...
    //   perform_concurrently {
    //     parallel_insert_slice %a into %b ...
    //     parallel_insert_slice %c into %copy ...
    //   }
    // }

    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();
    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();

    // Nothing to do if the destination tensor is inplace.
    assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
           "source is always in-place");
    if (state.isInPlace(op->getOpOperand(1) /*dest*/))
      return success();

    // Find corresponding OpResult.
    OpResult opResult = parallelInsertSliceOp.getTiedOpResult();

    // Insert tensor allocation right before the ForeachThreadOp.
    rewriter.setInsertionPoint(parallelIteratingOp);
    bool isYielded = state.isTensorYielded(opResult);
    FailureOr<Value> alloc = allocateTensorForShapedValue(
        rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
        /*escape=*/isYielded, state.getOptions());
    if (failed(alloc))
      return failure();

    // Update destination operand.
    rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
      parallelInsertSliceOp.getDestMutable().assign(*alloc);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();
    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();

    // Get destination buffer.
    FailureOr<Value> destBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
    if (failed(destBuffer))
      return failure();

    // Bufferize the ParallelInsertSliceOp outside of
    // `parallelCombiningParent`.
    rewriter.setInsertionPoint(parallelCombiningParent);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
    if (failed(srcBuffer))
      return failure();

    // Take a subview of the destination buffer.
    auto destBufferType = destBuffer->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
            parallelInsertSliceOp.getMixedOffsets(),
            parallelInsertSliceOp.getMixedSizes(),
            parallelInsertSliceOp.getMixedStrides())
            .cast<MemRefType>();
    Value subview = rewriter.create<memref::SubViewOp>(
        parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
        parallelInsertSliceOp.getMixedOffsets(),
        parallelInsertSliceOp.getMixedSizes(),
        parallelInsertSliceOp.getMixedStrides());

    // This memcpy will fold away if everything bufferizes in-place.
    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
                                    *srcBuffer, subview)))
      return failure();

    // Replace all uses of parallelIteratingOp (just the corresponding result).
    rewriter.setInsertionPointAfter(parallelIteratingOp);
    Value toTensorOp =
        rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
    // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
    SmallVector<OpOperand *> resultUses = llvm::to_vector(
        llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
                        [](OpOperand &use) { return &use; }));
    for (OpOperand *use : resultUses) {
      rewriter.updateRootInPlace(use->getOwner(),
                                 [&]() { use->set(toTensorOp); });
    }
    rewriter.eraseOp(op);
    return success();
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    return isNotConflictingInsertSliceLikeOp<tensor::ParallelInsertSliceOp>(
        op, uRead, uConflictingWrite, state);
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);

    // Load additional dialects of which ops may get created.
    ctx->loadDialect<arith::ArithmeticDialect, scf::SCFDialect>();
  });
}