//===- Bufferize.cpp - Bufferization of linalg ops ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

static SmallVector<Range, 4> computeLoopRanges(Location loc, LinalgOp linalgOp,
                                               OpBuilder &b) {
  auto indexingMaps = llvm::to_vector<4>(
      linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
  auto inputIndexingMaps =
      llvm::makeArrayRef(indexingMaps).take_front(linalgOp.getNumInputs());

  mlir::edsc::ScopedContext scope(b, loc);
  return emitLoopRanges(scope.getBuilderRef(), loc,
                        concatAffineMaps(inputIndexingMaps),
                        getShape(b, linalgOp));
}

static Value maybeConvertToIndex(Location loc, Value val, OpBuilder &b) {
  if (val.getType().isIndex())
    return val;
  return b.create<IndexCastOp>(loc, val, b.getIndexType());
}

static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp,
                          linalg::GenericOpAdaptor &adaptor,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  for (auto en : llvm::enumerate(linalgOp.getOperation()->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (tensorType == nullptr) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType =
        MemRefType::get(tensorShape, tensorType.getElementType());

    // Allocate buffers for init tensors that are assumed to fold onto the
    // first results.
    // TODO: update this assumption because the reality is more complex under
    // linalg on tensor based transformations.
    bool foldedInitTensor = resultIndex < linalgOp.getNumInitTensors();
    if (foldedInitTensor) {
      // Dealing with an init tensor requires distinguishing between 1-use
      // and many-use cases which would create aliasing and WAR hazards.
      Value initTensor = linalgOp.getInitTensor(resultIndex);
      Value initBuffer = adaptor.init_tensors()[resultIndex];
      if (initTensor.hasOneUse()) {
        resultBuffers.push_back(initBuffer);
        continue;
      }
      SmallVector<Value, 4> dynOperands;
      for (auto dim : llvm::enumerate(tensorShape)) {
        if (dim.value() == TensorType::kDynamicSize) {
          dynOperands.push_back(b.create<DimOp>(loc, initTensor, dim.index()));
        }
      }
      auto alloc = b.create<AllocOp>(loc, memrefType, dynOperands);
      b.create<linalg::CopyOp>(loc, initBuffer, alloc);
      resultBuffers.push_back(alloc);
      continue;
    }

    // Allocate buffers for statically-shaped results.
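    // For example (illustrative only), a tensor<4x8xf32> result becomes a
    // plain `alloc() : memref<4x8xf32>` with no dynamic size operands.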
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
      continue;
    }

    // Perform a naive shape inference for the dynamically-shaped results.
    // The dynamic sizes are extracted from the loop ranges, which are
    // computed on demand.
    SmallVector<Value, 4> dynOperands;
    auto resultIndexingMap = linalgOp.getOutputIndexingMap(resultIndex);
    for (auto shapeElement : llvm::enumerate(tensorType.getShape())) {
      if (loopRanges.empty())
        loopRanges = computeLoopRanges(loc, linalgOp, b);

      if (shapeElement.value() != ShapedType::kDynamicSize)
        continue;

      AffineExpr expr = resultIndexingMap.getResult(shapeElement.index());
      switch (expr.getKind()) {
      case AffineExprKind::DimId: {
        int64_t loopIndex = expr.cast<AffineDimExpr>().getPosition();
        Value size = maybeConvertToIndex(loc, loopRanges[loopIndex].size, b);
        dynOperands.push_back(size);
        break;
      }
      default:
        return failure();
      }
    }
    resultBuffers.push_back(b.create<AllocOp>(loc, memrefType, dynOperands));
  }
  return success();
}

// Specialization for `linalg::GenericOp`.
/// Converts Generic Linalg operations that work on tensors to use buffers.
/// The buffer placement pass should be run afterwards to move the Alloc
/// operations to the correct positions and to insert the missing Dealloc
/// operations in the correct places.
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::GenericOp genericOp,
                                     ValueRange inputs, ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputBuffers=*/outputs,
      /*initTensors=*/llvm::None, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr(), genericOp.symbol_sourceAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Add the result arguments to the new block.
  for (Value v : ValueRange(outputs).drop_front(genericOp.getNumInitTensors()))
    newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}

// TODO: Specialization for `linalg::IndexedGenericOp`.

// Specialization for all other `linalg::LinalgOp`.
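// For illustration, a named op such as linalg.matmul on tensors takes this
// path: it is cloned with its tensor operands replaced by the corresponding
// input memrefs and the freshly allocated result buffers appended as output
// buffers.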
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation()));
  SmallVector<Value, 8> newOperands(inputs.begin(), inputs.end());
  newOperands.append(outputs.begin(), outputs.end());
  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  LinalgOp res = cast<LinalgOp>(linalgOp.clone(rewriter, linalgOp.getLoc(),
                                               /*resultTypes=*/ArrayRef<Type>{},
                                               newOperands));
  // Need to mutate the operands_segment_sizes in the resulting op.
  res.setNumOutputBuffers(outputs.size());
  res.setNumInitTensors(0);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}

LogicalResult mlir::linalg::LinalgOpConverter::matchAndRewrite(
    Operation *op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return failure();

  // We abuse the GenericOpAdaptor here.
  // TODO: Manually create an Adaptor that captures inputs, output_buffers and
  // init_tensors for all linalg::LinalgOp interface ops.
  linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

  // All inputs need to be turned into buffers first. Until then, bail out.
  if (llvm::any_of(adaptor.inputs(),
                   [](Value in) { return !in.getType().isa<MemRefType>(); }))
    return failure();

  // All init_tensors need to be turned into buffers first. Until then, bail
  // out.
  if (llvm::any_of(adaptor.init_tensors(),
                   [](Value in) { return !in.getType().isa<MemRefType>(); }))
    return failure();

  Location loc = linalgOp.getLoc();
  SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
                                         adaptor.output_buffers().end());

  if (failed(allocateBuffersForResults(loc, linalgOp, adaptor, newOutputBuffers,
                                       rewriter))) {
    linalgOp.emitOpError() << "failed to allocate buffers for tensor results";
    return failure();
  }

  // Delegate to the linalg generic pattern.
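  // linalg.generic is special-cased because its region must be rebuilt with
  // additional block arguments for the output buffers; every other LinalgOp
  // goes through the clone-based overload above.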
  if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
    finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(),
                             newOutputBuffers);
    return success();
  }

  finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
                           newOutputBuffers);
  return success();
}

LogicalResult mlir::linalg::TensorConstantOpConverter::matchAndRewrite(
    ConstantOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  RankedTensorType rankedTensorType =
      op.getType().dyn_cast<RankedTensorType>();
  if (!rankedTensorType)
    return failure();
  if (llvm::any_of(rankedTensorType.getShape(), [](int64_t s) {
        return s == 0 || ShapedType::isDynamic(s);
      }))
    return failure();

  int64_t nElements = 1;
  for (int64_t s : rankedTensorType.getShape())
    nElements *= s;
  Type elementType = rankedTensorType.getElementType();
  MemRefType memrefType =
      converter.convertType(op.getType()).cast<MemRefType>();
  VectorType flatVectorType = VectorType::get({nElements}, elementType);
  MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType);
  MemRefType flatMemrefType = MemRefType::get({nElements}, elementType);

  Location loc = op.getLoc();
  auto attr = op.getValue().cast<DenseElementsAttr>();
  Value alloc =
      rewriter.create<AllocOp>(loc, memrefOfFlatVectorType, ValueRange{});
  Value cstVec = rewriter.create<ConstantOp>(loc, flatVectorType,
                                             attr.reshape(flatVectorType));
  rewriter.create<StoreOp>(loc, cstVec, alloc);

  Value memref =
      rewriter.create<vector::TypeCastOp>(loc, flatMemrefType, alloc);
  if (rankedTensorType.getRank() > 1) {
    // Introduce a linalg.reshape to expand the flat memref back to the
    // original rank.
    AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap(
        /*numDims=*/rankedTensorType.getRank(), op.getContext());
    memref = rewriter.create<linalg::ReshapeOp>(
        loc, memrefType, memref,
        rewriter.getAffineMapArrayAttr(collapseAllDims));
  }
  rewriter.replaceOp(op, memref);

  return success();
}

LogicalResult mlir::linalg::TensorCastOpConverter::matchAndRewrite(
    TensorCastOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (op.getType().hasRank())
    return failure();
  Type t = UnrankedMemRefType::get(op.getType().getElementType(),
                                   /*memorySpace=*/0);
  rewriter.replaceOpWithNewOp<MemRefCastOp>(op, t, operands.front());
  return success();
}

namespace {

/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter converter;

    // Mark all Standard and Vector operations legal.
    target.addLegalDialect<StandardOpsDialect, vector::VectorDialect>();
    target.addLegalOp<ModuleOp>();
    target.addLegalOp<ModuleTerminatorOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return converter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(
        Optional<ConversionTarget::DynamicLegalityCallbackFn>(
            isLegalOperation));

    // Mark operations that consume or return tensors illegal.
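    // For instance (illustrative), `return %t : tensor<4xf32>` remains
    // illegal until %t has been rewritten to a memref.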
    auto isLegal = [&](Operation *op) {
      if (llvm::any_of(op->getOperandTypes(),
                       [&](Type t) { return !converter.isLegal(t); }))
        return false;
      if (llvm::any_of(op->getResultTypes(),
                       [&](Type t) { return !converter.isLegal(t); }))
        return false;
      return true;
    };
    target.addDynamicallyLegalOp<
        // clang-format off
        CallOp,
        ConstantOp,
        ConstantIntOp,
        ConstantIndexOp,
        ConstantFloatOp,
        ReturnOp,
        TensorCastOp
        // clang-format on
        >(isLegal);

    // Mark the function operation illegal as long as an argument is a tensor.
    // TODO: if the FuncOp is only a declaration (e.g. of an externally defined
    // symbol such as an external library call), only convert it if some
    // special attribute is set. This will allow more control of interop
    // across ABI boundaries.
    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
      return converter.isSignatureLegal(funcOp.getType()) &&
             llvm::none_of(
                 funcOp.getType().getResults(),
                 [&](Type type) { return type.isa<MemRefType>(); }) &&
             converter.isLegal(&funcOp.getBody());
    });

    converter.setResultConversionKind<RankedTensorType, MemRefType>(
        BufferizeTypeConverter::AppendToArgumentsList);

    OwningRewritePatternList patterns;
    populateLinalgBufferizePatterns(&context, converter, patterns);
    populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, mlir::ReturnOp,
                                              linalg::CopyOp>(
        &context, converter, patterns);
    if (failed(applyFullConversion(this->getOperation(), target, patterns)))
      this->signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<ModuleOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    MLIRContext *context, BufferizeTypeConverter &converter,
    OwningRewritePatternList &patterns) {
  patterns.insert<
      // clang-format off
      LinalgOpConverter,
      TensorCastOpConverter,
      TensorConstantOpConverter
      // clang-format on
      >(context, converter);
}
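
// Illustrative sketch (simplified, not verified IR) of the overall effect of
// this pass: a linalg op producing a tensor, e.g.
//
//   %0 = linalg.generic ... ins(%in : tensor<4xf32>) { ... } -> tensor<4xf32>
//
// roughly becomes, once %in has been converted to a memref,
//
//   %buf = alloc() : memref<4xf32>
//   linalg.generic ... ins(%in : memref<4xf32>) outs(%buf : memref<4xf32>) {
//     ...
//   }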