//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

static SmallVector<Range, 4> computeLoopRanges(Location loc, LinalgOp linalgOp,
                                               OpBuilder &b) {
  auto indexingMaps = llvm::to_vector<4>(
      linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
  auto inputIndexingMaps =
      llvm::makeArrayRef(indexingMaps).take_front(linalgOp.getNumInputs());

  mlir::edsc::ScopedContext scope(b, loc);
  return emitLoopRanges(scope.getBuilderRef(), loc,
                        concatAffineMaps(inputIndexingMaps),
                        getShape(b, linalgOp));
}

static Value maybeConvertToIndex(Location loc, Value val, OpBuilder &b) {
  if (val.getType().isIndex())
    return val;
  return b.create<IndexCastOp>(loc, val, b.getIndexType());
}

static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp,
                          linalg::GenericOpAdaptor &adaptor,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  for (auto en : llvm::enumerate(linalgOp.getOperation()->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (!tensorType) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());

    // Allocate buffers for init tensors that are assumed to fold onto the
    // first results.
    // TODO: update this assumption because the reality is more complex under
    // linalg on tensor based transformations.
    bool foldedInitTensor = resultIndex < linalgOp.getNumInitTensors();
    if (foldedInitTensor) {
      Value initTensor = linalgOp.getInitTensor(resultIndex);
      Value initBuffer = adaptor.init_tensors()[resultIndex];
      SmallVector<Value, 4> dynOperands;
      for (auto dim : llvm::enumerate(tensorShape)) {
        if (dim.value() == ShapedType::kDynamicSize) {
          dynOperands.push_back(b.create<DimOp>(loc, initTensor, dim.index()));
        }
      }
      auto alloc = b.create<AllocOp>(loc, memrefType, dynOperands);
      b.create<linalg::CopyOp>(loc, initBuffer, alloc);
      resultBuffers.push_back(alloc);
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
      continue;
    }

    // Perform a naive shape inference for the dynamically-shaped results:
    // each dynamic dimension derives its size from the corresponding loop
    // range.
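    // For illustration (a hand-written sketch, not drawn from the code
    // above): a result of type tensor<?x4xf32> whose output indexing map is
    // (d0, d1) -> (d0, d1) has its dynamic dimension 0 tied to loop d0, so
    // loopRanges[0].size supplies the dynamic operand of the alloc below.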
    SmallVector<Value, 4> dynOperands;
    auto resultIndexingMap = linalgOp.getOutputIndexingMap(resultIndex);
    for (auto shapeElement : llvm::enumerate(tensorType.getShape())) {
      if (loopRanges.empty())
        loopRanges = computeLoopRanges(loc, linalgOp, b);

      if (shapeElement.value() != ShapedType::kDynamicSize)
        continue;

      AffineExpr expr = resultIndexingMap.getResult(shapeElement.index());
      switch (expr.getKind()) {
      case AffineExprKind::DimId: {
        int64_t loopIndex = expr.cast<AffineDimExpr>().getPosition();
        Value size = maybeConvertToIndex(loc, loopRanges[loopIndex].size, b);
        dynOperands.push_back(size);
        break;
      }
      default:
        return failure();
      }
    }
    resultBuffers.push_back(b.create<AllocOp>(loc, memrefType, dynOperands));
  }
  return success();
}

// Specialization for `linalg::GenericOp`.
/// A pattern to convert Generic Linalg operations which work on tensors to
/// use buffers. The BufferPlacement pass should be run afterwards to move the
/// Alloc operations to the correct positions and to insert the missing
/// Dealloc operations in the correct places.
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::GenericOp genericOp,
                                     ValueRange inputs, ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputBuffers=*/outputs,
      /*initTensors=*/llvm::None, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr(), genericOp.symbol_sourceAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Add the result arguments to the new block.
  for (Value v : ValueRange(outputs).drop_front(genericOp.getNumInitTensors()))
    newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}
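
// For illustration (a hand-written pseudo-IR sketch, not verbatim pass
// output), the rewrite above turns
//   %0 = linalg.generic #trait ins(%t : tensor<4xf32>) { ... } -> tensor<4xf32>
// into roughly
//   %b = alloc() : memref<4xf32>
//   linalg.generic #trait ins(%tb : memref<4xf32>) outs(%b : memref<4xf32>) { ... }
// with %0 replaced by %b; the tensor_to_memref/tensor_load materializations
// are inserted by the conversion infrastructure.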

// TODO: Specialization for `linalg::IndexedGenericOp`.

// Specialization for all other `linalg::LinalgOp`.
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation()));
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  LinalgOp res = cast<LinalgOp>(linalgOp.clone(rewriter, linalgOp.getLoc(),
                                               /*resultTypes=*/ArrayRef<Type>{},
                                               newOperands));
  // Need to mutate the operands_segment_sizes in the resulting op.
  res.setNumOutputBuffers(outputs.size());
  res.setNumInitTensors(0);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {
/// Generic conversion pattern that matches any LinalgOp. This avoids
/// instantiating one pattern for each concrete LinalgOp.
class BufferizeAnyLinalgOp : public ConversionPattern {
public:
  BufferizeAnyLinalgOp(TypeConverter &typeConverter)
      : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {

    LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
    if (!linalgOp)
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs, output_buffers
    // and init_tensors for all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = linalgOp.getLoc();
    SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
                                           adaptor.output_buffers().end());

    if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
                                         newOutputBuffers, rewriter))) {
      linalgOp.emitOpError()
          << "failed to allocate buffers for tensor results";
      return failure();
    }

    // Delegate to the linalg generic pattern.
    if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
      finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(),
                               newOutputBuffers);
      return success();
    }

    finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
                             newOutputBuffers);
    return success();
  }
};
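
// For illustration (a hand-written pseudo-IR sketch, not verbatim pass
// output), BufferizeAnyLinalgOp rewrites a named op such as
//   %0 = linalg.matmul ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
//                      init(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
// into roughly
//   %alloc = alloc() : memref<4x16xf32>
//   linalg.copy(%Cb, %alloc) : memref<4x16xf32>, memref<4x16xf32>
//   linalg.matmul ins(%Ab, %Bb : memref<4x8xf32>, memref<8x16xf32>)
//                 outs(%alloc : memref<4x16xf32>)
// with %0 replaced by %alloc.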

// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
  return llvm::to_vector<4>(
      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
        return a.cast<IntegerAttr>().getInt();
      }));
}

/// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy
/// pattern.
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
public:
  using OpConversionPattern<SubTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorOpAdaptor adaptor(operands,
                               op.getOperation()->getAttrDictionary());
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() capture exactly the dynamic alloc operands matching the
    // subviewMemRefType thanks to subview/subtensor canonicalization and
    // verification.
    Value alloc =
        rewriter.create<AllocOp>(op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<SubViewOp>(
        op.getLoc(), sourceMemref, extractFromI64ArrayAttr(op.static_offsets()),
        extractFromI64ArrayAttr(op.static_sizes()),
        extractFromI64ArrayAttr(op.static_strides()), op.offsets(), op.sizes(),
        op.strides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
/// %t` to a tensor_to_memref + subview + copy + tensor_load pattern.
/// tensor_to_memref and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   linalg_copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorInsertOpConverter
    : public OpConversionPattern<SubTensorInsertOp> {
public:
  using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorInsertOpAdaptor adaptor(operands,
                                     op.getOperation()->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    Value destMemRef = adaptor.dest();
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<SubViewOp>(
        op.getLoc(), destMemRef, extractFromI64ArrayAttr(op.static_offsets()),
        extractFromI64ArrayAttr(op.static_sizes()),
        extractFromI64ArrayAttr(op.static_strides()), adaptor.offsets(),
        adaptor.sizes(), adaptor.strides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};

/// TensorConstantOp conversion inserts a linearized 1-D vector constant that
/// is stored in memory. A linalg.reshape is introduced to convert to the
/// desired n-D buffer form.
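/// For illustration (a hand-written sketch, not verbatim pass output):
/// ```
///   %0 = constant dense<...> : tensor<2x3xf32>
/// ```
/// becomes roughly
/// ```
///   %alloc = alloc() : memref<vector<6xf32>>
///   %cst = constant dense<...> : vector<6xf32>
///   store %cst, %alloc[] : memref<vector<6xf32>>
///   %flat = vector.type_cast %alloc : memref<vector<6xf32>> to memref<6xf32>
///   %res = linalg.reshape %flat [...] : memref<6xf32> into memref<2x3xf32>
/// ```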
class TensorConstantOpConverter : public OpConversionPattern<ConstantOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {

    RankedTensorType rankedTensorType =
        op.getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    if (llvm::any_of(rankedTensorType.getShape(), [](int64_t s) {
          return s == 0 || ShapedType::isDynamic(s);
        }))
      return failure();

    int64_t nElements = 1;
    for (int64_t s : rankedTensorType.getShape())
      nElements *= s;
    Type elementType = rankedTensorType.getElementType();
    MemRefType memrefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    VectorType flatVectorType = VectorType::get({nElements}, elementType);
    MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType);
    MemRefType flatMemrefType = MemRefType::get({nElements}, elementType);

    Location loc = op.getLoc();
    auto attr = op.getValue().cast<DenseElementsAttr>();
    Value alloc =
        rewriter.create<AllocOp>(loc, memrefOfFlatVectorType, ValueRange{});
    Value cstVec = rewriter.create<ConstantOp>(loc, flatVectorType,
                                               attr.reshape(flatVectorType));
    rewriter.create<StoreOp>(loc, cstVec, alloc);

    Value memref =
        rewriter.create<vector::TypeCastOp>(loc, flatMemrefType, alloc);
    if (rankedTensorType.getRank() > 1) {
      // Introduce a linalg.reshape to expand the flat 1-D memref into the
      // desired n-D form.
      AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap(
          /*numDims=*/rankedTensorType.getRank(), op.getContext());
      memref = rewriter.create<linalg::ReshapeOp>(
          loc, memrefType, memref,
          rewriter.getAffineMapArrayAttr(collapseAllDims));
    }
    rewriter.replaceOp(op, memref);

    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark all Standard and Vector operations legal.
    target.addLegalDialect<StandardOpsDialect, vector::VectorDialect>();
    target.addIllegalOp<SubTensorOp, SubTensorInsertOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);

    OwningRewritePatternList patterns;
    populateLinalgBufferizePatterns(&context, typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<ModuleOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}
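
// Illustrative usage (a hedged sketch; `module` and `context` are assumed to
// exist in the caller and are not part of this file's API surface):
//   PassManager pm(&context);
//   pm.addPass(createLinalgBufferizePass());
//   if (failed(pm.run(module)))
//     llvm::errs() << "linalg bufferization failed\n";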

void mlir::linalg::populateLinalgBufferizePatterns(
    MLIRContext *context, BufferizeTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
  // TODO: Drop this once tensor constants work in standard.
  patterns.insert<
      // clang-format off
      SubTensorOpConverter,
      SubTensorInsertOpConverter,
      TensorConstantOpConverter
      // clang-format on
      >(typeConverter, context);
}