1 //===- InlineScalarOperands.cpp - Pass to inline scalar operands =============// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns/pass to inline scalar operands into a generic 10 // operation. A scalar operand is an operand whose indexing map has a constant 11 // rhs. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "PassDetail.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 19 #include "mlir/IR/AffineExpr.h" 20 #include "mlir/IR/AffineMap.h" 21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22 23 using namespace mlir; 24 using namespace mlir::linalg; 25 26 namespace { 27 struct InlineScalarOperands : public OpRewritePattern<GenericOp> { 28 using OpRewritePattern<GenericOp>::OpRewritePattern; 29 LogicalResult matchAndRewrite(GenericOp genericOp, 30 PatternRewriter &rewriter) const override { 31 if (!genericOp.hasTensorSemantics()) 32 return failure(); 33 34 SmallVector<size_t> scalarOperands; 35 SmallVector<AffineMap> newIndexingMaps; 36 SmallVector<Value> newOperands; 37 for (auto it : llvm::enumerate(llvm::zip(genericOp.getInputIndexingMaps(), 38 genericOp.getInputTensors()))) { 39 AffineMap map = std::get<0>(it.value()); 40 if (map.isConstant()) { 41 scalarOperands.emplace_back(it.index()); 42 } else { 43 newIndexingMaps.emplace_back(map); 44 newOperands.emplace_back(std::get<1>(it.value())); 45 } 46 } 47 48 if (scalarOperands.empty()) 49 return failure(); 50 51 newIndexingMaps.append(genericOp.getOutputIndexingMaps()); 52 53 Location loc = genericOp->getLoc(); 54 auto newOp = rewriter.create<GenericOp>( 55 loc, genericOp->getResultTypes(), newOperands, 56 genericOp.getOutputTensors(), newIndexingMaps, 57 llvm::to_vector<4>( 58 genericOp.iterator_types().template getAsValueRange<StringAttr>())); 59 rewriter.cloneRegionBefore(genericOp.region(), newOp.region(), 60 newOp.region().begin()); 61 62 Block *body = newOp.getBody(); 63 PatternRewriter::InsertionGuard guard(rewriter); 64 rewriter.setInsertionPointToStart(body); 65 66 for (auto idx : llvm::reverse(scalarOperands)) { 67 Value operand = genericOp.getInput(idx); 68 AffineMap map = genericOp.getInputIndexingMap(idx); 69 SmallVector<int64_t> indices = map.getConstantResults(); 70 SmallVector<Value> indicesValues; 71 for (auto idx : indices) 72 indicesValues.emplace_back(rewriter.create<ConstantIndexOp>(loc, idx)); 73 operand = rewriter.create<tensor::ExtractOp>(loc, operand, indicesValues); 74 body->getArgument(idx).replaceAllUsesWith(operand); 75 body->eraseArgument(idx); 76 } 77 78 rewriter.replaceOp(genericOp, newOp->getResults()); 79 return success(); 80 } 81 }; 82 } // namespace 83 84 /// Patterns that are used to inline constant operands into linalg generic 85 /// ops. 86 void mlir::linalg::populateInlineConstantOperandsPatterns( 87 RewritePatternSet &patterns) { 88 auto *context = patterns.getContext(); 89 patterns.add<InlineScalarOperands>(context); 90 } 91 92 namespace { 93 /// Pass that removes unit-extent dims within generic ops. 94 struct LinalgInlineScalarOperandsPass 95 : public LinalgInlineScalarOperandsBase<LinalgInlineScalarOperandsPass> { 96 void runOnFunction() override { 97 FuncOp funcOp = getFunction(); 98 MLIRContext *context = funcOp.getContext(); 99 RewritePatternSet patterns(context); 100 101 populateInlineConstantOperandsPatterns(patterns); 102 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); 103 } 104 }; 105 } // namespace 106 107 std::unique_ptr<OperationPass<FuncOp>> 108 mlir::createLinalgInlineScalarOperandsPass() { 109 return std::make_unique<LinalgInlineScalarOperandsPass>(); 110 } 111