1 //===- InlineScalarOperands.cpp - Pass to inline scalar operands =============// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns/pass to inline scalar operands into a generic 10 // operation. A scalar operand is an operand whose indexing map has a constant 11 // rhs. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "PassDetail.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Linalg/IR/Linalg.h" 18 #include "mlir/Dialect/Linalg/Passes.h" 19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 20 #include "mlir/IR/AffineExpr.h" 21 #include "mlir/IR/AffineMap.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 24 using namespace mlir; 25 using namespace mlir::linalg; 26 27 namespace { 28 struct InlineScalarOperands : public OpRewritePattern<GenericOp> { 29 using OpRewritePattern<GenericOp>::OpRewritePattern; 30 LogicalResult matchAndRewrite(GenericOp genericOp, 31 PatternRewriter &rewriter) const override { 32 if (!genericOp.hasTensorSemantics()) 33 return failure(); 34 35 SmallVector<size_t> scalarOperands; 36 SmallVector<AffineMap> newIndexingMaps; 37 SmallVector<Value> newOperands; 38 for (OpOperand *opOperand : genericOp.getInputOperands()) { 39 AffineMap map = genericOp.getTiedIndexingMap(opOperand); 40 if (genericOp.isInputTensor(opOperand) && map.isConstant()) { 41 scalarOperands.emplace_back(opOperand->getOperandNumber()); 42 } else { 43 newIndexingMaps.emplace_back(map); 44 newOperands.emplace_back(opOperand->get()); 45 } 46 } 47 48 if (scalarOperands.empty()) 49 return failure(); 50 51 for (OpOperand *opOperand : genericOp.getOutputOperands()) 52 newIndexingMaps.emplace_back(genericOp.getTiedIndexingMap(opOperand)); 53 54 Location loc = genericOp->getLoc(); 55 SmallVector<Value> outputOperands = genericOp.getOutputOperands(); 56 auto newOp = rewriter.create<GenericOp>( 57 loc, genericOp->getResultTypes(), newOperands, outputOperands, 58 newIndexingMaps, 59 llvm::to_vector<4>( 60 genericOp.iterator_types().template getAsValueRange<StringAttr>())); 61 rewriter.cloneRegionBefore(genericOp.region(), newOp.region(), 62 newOp.region().begin()); 63 64 Block *body = newOp.getBody(); 65 PatternRewriter::InsertionGuard guard(rewriter); 66 rewriter.setInsertionPointToStart(body); 67 68 for (auto idx : llvm::reverse(scalarOperands)) { 69 OpOperand *opOperand = genericOp.getInputOperand(idx); 70 AffineMap map = genericOp.getTiedIndexingMap(opOperand); 71 SmallVector<int64_t> indices = map.getConstantResults(); 72 SmallVector<Value> indicesValues; 73 for (auto idx : indices) 74 indicesValues.emplace_back( 75 rewriter.create<arith::ConstantIndexOp>(loc, idx)); 76 Value extractedValue = rewriter.create<tensor::ExtractOp>( 77 loc, opOperand->get(), indicesValues); 78 body->getArgument(idx).replaceAllUsesWith(extractedValue); 79 body->eraseArgument(idx); 80 } 81 82 rewriter.replaceOp(genericOp, newOp->getResults()); 83 return success(); 84 } 85 }; 86 } // namespace 87 88 /// Patterns that are used to inline constant operands into linalg generic 89 /// ops. 90 void mlir::linalg::populateInlineConstantOperandsPatterns( 91 RewritePatternSet &patterns) { 92 auto *context = patterns.getContext(); 93 patterns.add<InlineScalarOperands>(context); 94 } 95 96 namespace { 97 /// Pass that removes unit-extent dims within generic ops. 98 struct LinalgInlineScalarOperandsPass 99 : public LinalgInlineScalarOperandsBase<LinalgInlineScalarOperandsPass> { 100 void runOnOperation() override { 101 func::FuncOp funcOp = getOperation(); 102 MLIRContext *context = funcOp.getContext(); 103 RewritePatternSet patterns(context); 104 105 populateInlineConstantOperandsPatterns(patterns); 106 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); 107 } 108 }; 109 } // namespace 110 111 std::unique_ptr<OperationPass<func::FuncOp>> 112 mlir::createLinalgInlineScalarOperandsPass() { 113 return std::make_unique<LinalgInlineScalarOperandsPass>(); 114 } 115