1 //===- Bufferize.cpp - Bufferization of linalg ops ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/Bufferize.h" 10 #include "PassDetail.h" 11 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 12 #include "mlir/Dialect/Linalg/Passes.h" 13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 14 #include "mlir/Dialect/Linalg/Utils/Utils.h" 15 #include "mlir/Dialect/StandardOps/Transforms/Passes.h" 16 #include "mlir/Dialect/Vector/VectorOps.h" 17 #include "mlir/IR/Function.h" 18 #include "mlir/IR/Operation.h" 19 #include "mlir/Pass/Pass.h" 20 21 using namespace ::mlir; 22 using namespace ::mlir::linalg; 23 24 static SmallVector<Range, 4> computeLoopRanges(Location loc, LinalgOp linalgOp, 25 OpBuilder &b) { 26 auto indexingMaps = llvm::to_vector<4>( 27 linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>()); 28 auto inputIndexingMaps = 29 llvm::makeArrayRef(indexingMaps).take_front(linalgOp.getNumInputs()); 30 31 mlir::edsc::ScopedContext scope(b, loc); 32 return emitLoopRanges(scope.getBuilderRef(), loc, 33 concatAffineMaps(inputIndexingMaps), 34 getShape(b, linalgOp)); 35 } 36 37 static Value maybeConvertToIndex(Location loc, Value val, OpBuilder &b) { 38 if (val.getType().isIndex()) 39 return val; 40 return b.create<IndexCastOp>(loc, val, b.getIndexType()); 41 } 42 43 static Value cloneMemref(Location loc, Value memref, OpBuilder &b) { 44 auto memrefType = memref.getType().cast<MemRefType>(); 45 SmallVector<Value, 4> dynOperands; 46 for (auto dim : llvm::enumerate(memrefType.getShape())) { 47 if (dim.value() == TensorType::kDynamicSize) { 48 dynOperands.push_back(b.create<DimOp>(loc, memref, dim.index())); 49 } 50 } 51 auto alloc = b.create<AllocOp>(loc, 
memrefType, dynOperands); 52 b.create<linalg::CopyOp>(loc, memref, alloc); 53 return alloc; 54 } 55 56 static LogicalResult 57 allocateBuffersForResults(Location loc, LinalgOp linalgOp, 58 linalg::GenericOpAdaptor &adaptor, 59 SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) { 60 // Lazily compute loopRanges. 61 SmallVector<Range, 4> loopRanges; 62 63 // Allocate a buffer for every tensor result. 64 for (auto en : llvm::enumerate(linalgOp.getOperation()->getResultTypes())) { 65 size_t resultIndex = en.index(); 66 Type resultType = en.value(); 67 68 auto tensorType = resultType.dyn_cast<RankedTensorType>(); 69 if (tensorType == nullptr) { 70 linalgOp.emitOpError() 71 << "tensor to buffer conversion expects ranked tensor results"; 72 return failure(); 73 } 74 auto tensorShape = tensorType.getShape(); 75 auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType()); 76 77 // Allocate buffers for init tensors that are assumed to fold onto the first 78 // results. 79 // TODO: update this assumption because the reality is more complex 80 // under linalg on tensor based transformations. 81 bool hasInitTensor = resultIndex < linalgOp.getNumInitTensors(); 82 if (hasInitTensor) { 83 resultBuffers.push_back( 84 cloneMemref(loc, adaptor.init_tensors()[resultIndex], b)); 85 continue; 86 } 87 88 // Allocate buffers for statically-shaped results. 89 if (memrefType.hasStaticShape()) { 90 resultBuffers.push_back(b.create<AllocOp>(loc, memrefType)); 91 continue; 92 } 93 94 // Perform a naive shape inference for the dynamically-shaped results. 95 // Extract the required element out of the vector. 
96 SmallVector<Value, 4> dynOperands; 97 auto resultIndexingMap = linalgOp.getOutputIndexingMap(resultIndex); 98 for (auto shapeElement : llvm::enumerate(tensorType.getShape())) { 99 if (loopRanges.empty()) 100 loopRanges = computeLoopRanges(loc, linalgOp, b); 101 102 if (shapeElement.value() != ShapedType::kDynamicSize) 103 continue; 104 105 AffineExpr expr = resultIndexingMap.getResult(shapeElement.index()); 106 switch (expr.getKind()) { 107 case AffineExprKind::DimId: { 108 int64_t loopIndex = expr.cast<AffineDimExpr>().getPosition(); 109 Value size = maybeConvertToIndex(loc, loopRanges[loopIndex].size, b); 110 dynOperands.push_back(size); 111 break; 112 } 113 default: 114 return failure(); 115 } 116 } 117 resultBuffers.push_back(b.create<AllocOp>(loc, memrefType, dynOperands)); 118 } 119 return success(); 120 } 121 122 // Specialization for `linalg::GenericOp`. 123 /// A pattern to convert Generic Linalg operations which work on tensors to 124 /// use buffers. BufferPlacement pass should be later used to move 125 /// Alloc operations to the correct positions and insert the missing Dealloc 126 /// operations in the correct places. 127 static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter, 128 linalg::GenericOp genericOp, 129 ValueRange inputs, ValueRange outputs) { 130 // Generate a new linalg operation that works on buffers. 131 auto newGenericOp = rewriter.create<linalg::GenericOp>( 132 genericOp.getLoc(), 133 /*resultTensorTypes=*/llvm::None, 134 /*inputs=*/inputs, 135 /*outputBuffers=*/outputs, 136 /*initTensors=*/llvm::None, genericOp.indexing_maps(), 137 genericOp.iterator_types(), genericOp.docAttr(), 138 genericOp.library_callAttr(), genericOp.sparseAttr(), 139 genericOp.symbol_sourceAttr()); 140 141 // Create a new block in the region of the new Generic Op. 
142 Block *oldBlock = genericOp.getBody(); 143 Region &newRegion = newGenericOp.region(); 144 Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(), 145 oldBlock->getArgumentTypes()); 146 147 // Add the result arguments to the new block. 148 for (Value v : ValueRange(outputs).drop_front(genericOp.getNumInitTensors())) 149 newBlock->addArgument(v.getType().cast<MemRefType>().getElementType()); 150 151 // Clone the body of the old block to the new block. 152 BlockAndValueMapping mapping; 153 mapping.map(oldBlock->getArguments(), newBlock->getArguments()); 154 155 OpBuilder::InsertionGuard guard(rewriter); 156 rewriter.setInsertionPointToEnd(newBlock); 157 for (auto &op : oldBlock->getOperations()) { 158 Operation *clonedOp = rewriter.clone(op, mapping); 159 mapping.map(op.getResults(), clonedOp->getResults()); 160 } 161 162 // Replace the results of the old op with the new output buffers. 163 rewriter.replaceOp(genericOp, outputs); 164 } 165 166 // TODO: Specialization for `linalg::IndexedGenericOp`. 167 168 // Specialization for all other `linalg::LinalgOp`. 169 static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter, 170 linalg::LinalgOp linalgOp, 171 ValueRange inputs, ValueRange outputs) { 172 assert(!isa<linalg::GenericOp>(linalgOp.getOperation())); 173 assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation())); 174 SmallVector<Value, 8> newOperands = inputs; 175 newOperands.append(outputs.begin(), outputs.end()); 176 auto otherOperands = linalgOp.getAssumedNonShapedOperands(); 177 newOperands.append(otherOperands.begin(), otherOperands.end()); 178 LinalgOp res = cast<LinalgOp>(linalgOp.clone(rewriter, linalgOp.getLoc(), 179 /*resultTypes=*/ArrayRef<Type>{}, 180 newOperands)); 181 // Need to mutate the operands_segment_sizes in the resulting op. 182 res.setNumOutputBuffers(outputs.size()); 183 res.setNumInitTensors(0); 184 // Replace the results of the old op with the new output buffers. 
185 rewriter.replaceOp(linalgOp, outputs); 186 } 187 188 //===----------------------------------------------------------------------===// 189 // Bufferization patterns. 190 //===----------------------------------------------------------------------===// 191 192 namespace { 193 /// Generic conversion pattern that matches any LinalgOp. This avoids template 194 /// instantiating one pattern for each LinalgOp. 195 class BufferizeAnyLinalgOp : public ConversionPattern { 196 public: 197 BufferizeAnyLinalgOp(TypeConverter &typeConverter) 198 : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {} 199 200 LogicalResult 201 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 202 ConversionPatternRewriter &rewriter) const final { 203 204 LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op); 205 if (!linalgOp) 206 return failure(); 207 208 // We abuse the GenericOpAdaptor here. 209 // TODO: Manually create an Adaptor that captures inputs, output_buffers and 210 // init_tensors for all linalg::LinalgOp interface ops. 211 linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary()); 212 213 Location loc = linalgOp.getLoc(); 214 SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(), 215 adaptor.output_buffers().end()); 216 217 if (failed(allocateBuffersForResults(loc, linalgOp, adaptor, 218 newOutputBuffers, rewriter))) { 219 linalgOp.emitOpError() 220 << "Failed to allocate buffers for tensor results."; 221 return failure(); 222 } 223 224 // Delegate to the linalg generic pattern. 225 if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) { 226 finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(), 227 newOutputBuffers); 228 return success(); 229 } 230 231 finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(), 232 newOutputBuffers); 233 return success(); 234 } 235 }; 236 237 // Extract int64_t values from the assumed ArrayAttr of IntegerAttr. 
238 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) { 239 return llvm::to_vector<4>( 240 llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t { 241 return a.cast<IntegerAttr>().getInt(); 242 })); 243 } 244 245 /// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy 246 /// pattern. 247 /// ``` 248 /// %a = alloc(sizes) 249 /// %sv = subview %source [offsets][sizes][strides] 250 /// linalg_copy(%sv, %a) 251 /// ``` 252 /// 253 /// This pattern is arguable a std pattern once linalg::CopyOp becomes 254 /// std::CopyOp. 255 class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> { 256 public: 257 using OpConversionPattern<SubTensorOp>::OpConversionPattern; 258 259 LogicalResult 260 matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands, 261 ConversionPatternRewriter &rewriter) const final { 262 SubTensorOpAdaptor adaptor(operands, 263 op.getOperation()->getAttrDictionary()); 264 Value sourceMemref = adaptor.source(); 265 assert(sourceMemref.getType().isa<MemRefType>()); 266 267 MemRefType subviewMemRefType = 268 getTypeConverter()->convertType(op.getType()).cast<MemRefType>(); 269 // op.sizes() capture exactly the dynamic alloc operands matching the 270 // subviewMemRefType thanks to subview/subtensor canonicalization and 271 // verification. 272 Value alloc = 273 rewriter.create<AllocOp>(op.getLoc(), subviewMemRefType, op.sizes()); 274 Value subView = rewriter.create<SubViewOp>( 275 op.getLoc(), sourceMemref, extractFromI64ArrayAttr(op.static_offsets()), 276 extractFromI64ArrayAttr(op.static_sizes()), 277 extractFromI64ArrayAttr(op.static_strides()), op.offsets(), op.sizes(), 278 op.strides()); 279 rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc); 280 rewriter.replaceOp(op, alloc); 281 return success(); 282 } 283 }; 284 285 /// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] -> 286 /// %t` to an tensor_to_memref + subview + copy + tensor_load pattern. 
/// tensor_to_memref and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
/// %sv = subview %dest [offsets][sizes][strides]
/// linalg_copy(%source, %sv)
/// // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorInsertOpConverter
    : public OpConversionPattern<SubTensorInsertOp> {
public:
  using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorInsertOpAdaptor adaptor(operands,
                                     op.getOperation()->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview of the cloned destination covering the insertion region.
    Value subview = rewriter.create<SubViewOp>(
        op.getLoc(), destMemRef, extractFromI64ArrayAttr(op.static_offsets()),
        extractFromI64ArrayAttr(op.static_sizes()),
        extractFromI64ArrayAttr(op.static_strides()), adaptor.offsets(),
        adaptor.sizes(), adaptor.strides());
    // Copy the source into that subview, then the clone stands in for the
    // result tensor.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark all Standard operations legal.
    target.addLegalDialect<StandardOpsDialect>();
    // Subtensor ops are rewritten by the patterns registered below.
    target.addIllegalOp<SubTensorOp, SubTensorInsertOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);

    OwningRewritePatternList patterns;
    populateLinalgBufferizePatterns(&context, typeConverter, patterns);
    // Partial conversion: ops not covered by the target spec are left as-is.
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

/// Creates an instance of the Linalg bufferization pass.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

/// Populates `patterns` with the Linalg bufferization patterns plus the
/// subtensor/subtensor_insert conversions.
void mlir::linalg::populateLinalgBufferizePatterns(
    MLIRContext *context, BufferizeTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
  // TODO: Drop this once tensor constants work in standard.
  patterns.insert<
      // clang-format off
      SubTensorOpConverter,
      SubTensorInsertOpConverter
      // clang-format on
      >(typeConverter, context);
}