//===- Bufferize.cpp - Bufferization of linalg ops -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

static SmallVector<Range, 4> computeLoopRanges(Location loc, LinalgOp linalgOp,
                                               OpBuilder &b) {
  auto indexingMaps = llvm::to_vector<4>(
      linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
  auto inputIndexingMaps =
      llvm::makeArrayRef(indexingMaps).take_front(linalgOp.getNumInputs());

  mlir::edsc::ScopedContext scope(b, loc);
  return emitLoopRanges(scope.getBuilderRef(), loc,
                        concatAffineMaps(inputIndexingMaps),
                        getShape(b, linalgOp));
}

static Value maybeConvertToIndex(Location loc, Value val, OpBuilder &b) {
  if (val.getType().isIndex())
    return val;
  return b.create<IndexCastOp>(loc, val, b.getIndexType());
}

static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp,
                          linalg::GenericOpAdaptor &adaptor,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  for (auto en : llvm::enumerate(linalgOp.getOperation()->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (!tensorType) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());

    // Allocate buffers for init tensors that are assumed to fold onto the
    // first results.
    // TODO: update this assumption because the reality is more complex under
    // linalg on tensor based transformations.
    bool foldedInitTensor = resultIndex < linalgOp.getNumInitTensors();
    if (foldedInitTensor) {
      Value initTensor = linalgOp.getInitTensor(resultIndex);
      Value initBuffer = adaptor.init_tensors()[resultIndex];
      SmallVector<Value, 4> dynOperands;
      for (auto dim : llvm::enumerate(tensorShape)) {
        if (dim.value() == ShapedType::kDynamicSize)
          dynOperands.push_back(b.create<DimOp>(loc, initTensor, dim.index()));
      }
      auto alloc = b.create<AllocOp>(loc, memrefType, dynOperands);
      b.create<linalg::CopyOp>(loc, initBuffer, alloc);
      resultBuffers.push_back(alloc);
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
      continue;
    }

    // Perform a naive shape inference for the dynamically-shaped results:
    // derive the size of each dynamic dimension from the corresponding loop
    // range.
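    // For illustration (hypothetical shapes): with a result indexing map
    // (d0, d1) -> (d0, d1) and a result type tensor<4x?xf32>, the size of
    // dynamic dimension 1 is taken from the range of loop d1 and passed as an
    // operand to the AllocOp below. Only plain AffineDimExpr results are
    // handled; any other affine expression makes the conversion fail.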
    SmallVector<Value, 4> dynOperands;
    auto resultIndexingMap = linalgOp.getOutputIndexingMap(resultIndex);
    for (auto shapeElement : llvm::enumerate(tensorType.getShape())) {
      if (shapeElement.value() != ShapedType::kDynamicSize)
        continue;

      if (loopRanges.empty())
        loopRanges = computeLoopRanges(loc, linalgOp, b);

      AffineExpr expr = resultIndexingMap.getResult(shapeElement.index());
      switch (expr.getKind()) {
      case AffineExprKind::DimId: {
        int64_t loopIndex = expr.cast<AffineDimExpr>().getPosition();
        Value size = maybeConvertToIndex(loc, loopRanges[loopIndex].size, b);
        dynOperands.push_back(size);
        break;
      }
      default:
        return failure();
      }
    }
    resultBuffers.push_back(b.create<AllocOp>(loc, memrefType, dynOperands));
  }
  return success();
}

/// Specialization for `linalg::GenericOp`: converts a generic operation that
/// works on tensors to one that works on buffers. The BufferPlacement pass
/// should be run afterwards to move the Alloc operations to the correct
/// positions and to insert the missing Dealloc operations in the correct
/// places.
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::GenericOp genericOp,
                                     ValueRange inputs, ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputBuffers=*/outputs,
      /*initTensors=*/llvm::None, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr(), genericOp.symbol_sourceAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Add the result arguments to the new block.
  for (Value v : ValueRange(outputs).drop_front(genericOp.getNumInitTensors()))
    newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}
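
// For illustration, the net effect of the helper above on a generic op (shown
// schematically; the exact textual syntax depends on this version of linalg):
//
//   %res = linalg.generic {...} ins(%t : tensor<4xf32>) {
//     ^bb0(%in: f32):
//       ...
//   } -> tensor<4xf32>
//
// becomes, once a buffer has been allocated for the result:
//
//   %buf = alloc() : memref<4xf32>
//   linalg.generic {...} ins(%m : memref<4xf32>)
//                        outs(%buf : memref<4xf32>) {
//     ^bb0(%in: f32, %out: f32):
//       ...
//   }
//
// and all uses of %res are replaced by %buf.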

// TODO: Specialization for `linalg::IndexedGenericOp`.

// Specialization for all other `linalg::LinalgOp`.
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation()));
  SmallVector<Value, 8> newOperands(inputs.begin(), inputs.end());
  newOperands.append(outputs.begin(), outputs.end());
  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  LinalgOp res = cast<LinalgOp>(linalgOp.clone(rewriter, linalgOp.getLoc(),
                                               /*resultTypes=*/ArrayRef<Type>{},
                                               newOperands));
  // Need to mutate the operands_segment_sizes in the resulting op.
  res.setNumOutputBuffers(outputs.size());
  res.setNumInitTensors(0);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {
/// Generic conversion pattern that matches any LinalgOp. This avoids template
/// instantiating one pattern for each LinalgOp.
class BufferizeAnyLinalgOp : public ConversionPattern {
public:
  BufferizeAnyLinalgOp(TypeConverter &typeConverter)
      : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
    if (!linalgOp)
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs, output_buffers
    // and init_tensors for all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = linalgOp.getLoc();
    SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
                                           adaptor.output_buffers().end());

    if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
                                         newOutputBuffers, rewriter))) {
      linalgOp.emitOpError() << "failed to allocate buffers for tensor results";
      return failure();
    }

    // Delegate to the linalg generic pattern.
    if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
      finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(),
                               newOutputBuffers);
      return success();
    }

    finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
                             newOutputBuffers);
    return success();
  }
};
} // namespace

namespace {
/// Converts a tensor-typed ConstantOp by inserting a linearized 1-D vector
/// constant that is stored in memory. A linalg.reshape is introduced to
/// convert the flat buffer to the desired n-D buffer form.
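///
/// For illustration (a hypothetical input; the exact textual syntax depends on
/// this MLIR version), `constant dense<...> : tensor<2x3xf32>` would become
/// roughly:
///
///   %0 = alloc() : memref<vector<6xf32>>
///   %1 = constant dense<...> : vector<6xf32>
///   store %1, %0[] : memref<vector<6xf32>>
///   %2 = vector.type_cast %0 : memref<vector<6xf32>> to memref<6xf32>
///   %3 = linalg.reshape %2 [affine_map<(d0, d1) -> (d0, d1)>]
///          : memref<6xf32> into memref<2x3xf32>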
class TensorConstantOpConverter : public OpConversionPattern<ConstantOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    RankedTensorType rankedTensorType =
        op.getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    if (llvm::any_of(rankedTensorType.getShape(), [](int64_t s) {
          return s == 0 || ShapedType::isDynamic(s);
        }))
      return failure();

    int64_t nElements = 1;
    for (int64_t s : rankedTensorType.getShape())
      nElements *= s;
    Type elementType = rankedTensorType.getElementType();
    MemRefType memrefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    VectorType flatVectorType = VectorType::get({nElements}, elementType);
    MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType);
    MemRefType flatMemrefType = MemRefType::get({nElements}, elementType);

    Location loc = op.getLoc();
    auto attr = op.getValue().cast<DenseElementsAttr>();
    Value alloc =
        rewriter.create<AllocOp>(loc, memrefOfFlatVectorType, ValueRange{});
    Value cstVec = rewriter.create<ConstantOp>(loc, flatVectorType,
                                               attr.reshape(flatVectorType));
    rewriter.create<StoreOp>(loc, cstVec, alloc);

    Value memref =
        rewriter.create<vector::TypeCastOp>(loc, flatMemrefType, alloc);
    if (rankedTensorType.getRank() > 1) {
      // Introduce a linalg.reshape to convert the flat 1-D memref into the
      // desired n-D form.
      AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap(
          /*numDims=*/rankedTensorType.getRank(), op.getContext());
      memref = rewriter.create<linalg::ReshapeOp>(
          loc, memrefType, memref,
          rewriter.getAffineMapArrayAttr(collapseAllDims));
    }
    rewriter.replaceOp(op, memref);

    return success();
  }
};
} // namespace

namespace {

/// Converts Linalg operations that work on tensor-typed operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter converter;

    // Mark all Standard and Vector operations legal.
    // TODO: Remove after TensorConstantOpConverter moves to std-bufferize.
    target.addLegalDialect<StandardOpsDialect, vector::VectorDialect>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return converter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);

    OwningRewritePatternList patterns;
    populateLinalgBufferizePatterns(&context, converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<ModuleOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    MLIRContext *context, BufferizeTypeConverter &converter,
    OwningRewritePatternList &patterns) {
  patterns.insert<BufferizeAnyLinalgOp>(converter);
  patterns.insert<TensorConstantOpConverter>(converter, context);
}
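
// For illustration, a downstream pass could reuse the patterns exposed above
// (a minimal sketch; the ConversionTarget setup is elided and assumed to
// mirror runOnOperation):
//
//   BufferizeTypeConverter converter;
//   OwningRewritePatternList patterns;
//   populateLinalgBufferizePatterns(ctx, converter, patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();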