1 //===- InlineScalarOperands.cpp - Pass to inline scalar operands =============// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns/pass to inline scalar operands into a generic 10 // operation. A scalar operand is an operand whose indexing map has a constant 11 // rhs. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "PassDetail.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 19 #include "mlir/IR/AffineExpr.h" 20 #include "mlir/IR/AffineMap.h" 21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22 23 using namespace mlir; 24 using namespace mlir::linalg; 25 26 namespace { 27 struct InlineScalarOperands : public OpRewritePattern<GenericOp> { 28 using OpRewritePattern<GenericOp>::OpRewritePattern; 29 LogicalResult matchAndRewrite(GenericOp genericOp, 30 PatternRewriter &rewriter) const override { 31 if (!genericOp.hasTensorSemantics()) 32 return failure(); 33 34 SmallVector<size_t> scalarOperands; 35 SmallVector<AffineMap> newIndexingMaps; 36 SmallVector<Value> newOperands; 37 for (OpOperand *opOperand : genericOp.getInputOperands()) { 38 AffineMap map = genericOp.getTiedIndexingMap(opOperand); 39 if (genericOp.isInputTensor(opOperand) && map.isConstant()) { 40 scalarOperands.emplace_back(opOperand->getOperandNumber()); 41 } else { 42 newIndexingMaps.emplace_back(map); 43 newOperands.emplace_back(opOperand->get()); 44 } 45 } 46 47 if (scalarOperands.empty()) 48 return failure(); 49 50 for (OpOperand *opOperand : genericOp.getOutputOperands()) 51 newIndexingMaps.emplace_back(genericOp.getTiedIndexingMap(opOperand)); 52 53 Location loc = genericOp->getLoc(); 54 SmallVector<Value> outputOperands = genericOp.getOutputOperands(); 55 auto newOp = rewriter.create<GenericOp>( 56 loc, genericOp->getResultTypes(), newOperands, outputOperands, 57 newIndexingMaps, 58 llvm::to_vector<4>( 59 genericOp.iterator_types().template getAsValueRange<StringAttr>())); 60 rewriter.cloneRegionBefore(genericOp.region(), newOp.region(), 61 newOp.region().begin()); 62 63 Block *body = newOp.getBody(); 64 PatternRewriter::InsertionGuard guard(rewriter); 65 rewriter.setInsertionPointToStart(body); 66 67 for (auto idx : llvm::reverse(scalarOperands)) { 68 OpOperand *opOperand = genericOp.getInputOperand(idx); 69 AffineMap map = genericOp.getTiedIndexingMap(opOperand); 70 SmallVector<int64_t> indices = map.getConstantResults(); 71 SmallVector<Value> indicesValues; 72 for (auto idx : indices) 73 indicesValues.emplace_back(rewriter.create<ConstantIndexOp>(loc, idx)); 74 Value extractedValue = rewriter.create<tensor::ExtractOp>( 75 loc, opOperand->get(), indicesValues); 76 body->getArgument(idx).replaceAllUsesWith(extractedValue); 77 body->eraseArgument(idx); 78 } 79 80 rewriter.replaceOp(genericOp, newOp->getResults()); 81 return success(); 82 } 83 }; 84 } // namespace 85 86 /// Patterns that are used to inline constant operands into linalg generic 87 /// ops. 88 void mlir::linalg::populateInlineConstantOperandsPatterns( 89 RewritePatternSet &patterns) { 90 auto *context = patterns.getContext(); 91 patterns.add<InlineScalarOperands>(context); 92 } 93 94 namespace { 95 /// Pass that removes unit-extent dims within generic ops. 96 struct LinalgInlineScalarOperandsPass 97 : public LinalgInlineScalarOperandsBase<LinalgInlineScalarOperandsPass> { 98 void runOnFunction() override { 99 FuncOp funcOp = getFunction(); 100 MLIRContext *context = funcOp.getContext(); 101 RewritePatternSet patterns(context); 102 103 populateInlineConstantOperandsPatterns(patterns); 104 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); 105 } 106 }; 107 } // namespace 108 109 std::unique_ptr<OperationPass<FuncOp>> 110 mlir::createLinalgInlineScalarOperandsPass() { 111 return std::make_unique<LinalgInlineScalarOperandsPass>(); 112 } 113