1 //===- Generalization.cpp - linalg named ops to generic ops --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the Linalg generalization pass. It converts named 10 // Linalg ops to linalg.generic ops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/IR/AffineMap.h" 19 #include "mlir/IR/Attributes.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/ImplicitLocOpBuilder.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 #include "llvm/ADT/SmallVector.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "linalg-generalization" 28 29 using namespace mlir; 30 using namespace mlir::linalg; 31 32 LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) { 33 LinalgOp namedOp = dyn_cast<LinalgOp>(op); 34 // Check if the operation is a LinalgOp but not a GenericOp. 35 if (!namedOp || isa<GenericOp>(op)) 36 return failure(); 37 // Check if the operation has a region builder. 38 if (!namedOp.getRegionBuilder()) 39 return failure(); 40 return success(); 41 } 42 43 GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter, 44 LinalgOp namedOp) { 45 SmallVector<Value> inputOperands = namedOp.getInputOperands(); 46 SmallVector<Value> outputOperands = namedOp.getOutputOperands(); 47 SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps(); 48 SmallVector<StringRef> iterators = llvm::to_vector<4>( 49 namedOp.iterator_types().getAsValueRange<StringAttr>()); 50 SmallVector<RankedTensorType> resultTypes = namedOp.getOutputTensorTypes(); 51 SmallVector<Type> types(resultTypes.begin(), resultTypes.end()); 52 53 // Inline the existing region if the named operation has a region attached. 54 if (namedOp->getNumRegions() == 1) { 55 GenericOp genericOp = 56 rewriter.create<GenericOp>(namedOp.getLoc(), types, inputOperands, 57 outputOperands, indexingMaps, iterators); 58 rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(), 59 genericOp.region().begin()); 60 return genericOp; 61 } 62 63 // Otherwise use the region builder to generate a new region. 64 // TODO: Remove this path once all linag operations have a region attached. 65 auto regionBuilder = namedOp.getRegionBuilder(); 66 assert(regionBuilder && "expect the operation to have region builder"); 67 return rewriter.create<GenericOp>( 68 namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps, 69 iterators, 70 [®ionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) { 71 ImplicitLocOpBuilder b(loc, bodyBuilder); 72 regionBuilder(b, *bodyBuilder.getBlock()); 73 }); 74 } 75 76 namespace { 77 78 /// Base class for all linalg generalization patterns. A subclass must provide 79 /// the following method: 80 /// GenericOp createGenericOp(RootOp, PatternRewriter &) 81 /// for creating the generic op. 82 // TODO: remove this pattern after migrating all manually-written named ops 83 // into auto-generated ones. 84 template <typename ConcretePattern, typename RootOp> 85 struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> { 86 LinalgGeneralizationPattern(MLIRContext *context, 87 LinalgTransformationFilter marker, 88 PatternBenefit benefit = 1) 89 : OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {} 90 91 LogicalResult matchAndRewrite(RootOp rootOp, 92 PatternRewriter &rewriter) const override { 93 auto linalgOp = dyn_cast<LinalgOp>(rootOp.getOperation()); 94 if (!linalgOp) 95 return failure(); 96 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 97 return failure(); 98 99 auto *pattern = static_cast<const ConcretePattern *>(this); 100 GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter); 101 if (!genericOp) 102 return failure(); 103 104 rewriter.replaceOp(rootOp, genericOp.getResults()); 105 marker.replaceLinalgTransformationFilter(rewriter, 106 genericOp.getOperation()); 107 return success(); 108 } 109 110 private: 111 LinalgTransformationFilter marker; 112 }; 113 114 struct GeneralizeConvOp 115 : public LinalgGeneralizationPattern<GeneralizeConvOp, ConvOp> { 116 using LinalgGeneralizationPattern::LinalgGeneralizationPattern; 117 118 GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const; 119 }; 120 121 struct LinalgGeneralizationPass 122 : public LinalgGeneralizationBase<LinalgGeneralizationPass> { 123 void runOnFunction() override; 124 }; 125 126 } // namespace 127 128 void LinalgGeneralizationPass::runOnFunction() { 129 FuncOp func = getFunction(); 130 RewritePatternSet patterns(&getContext()); 131 populateLinalgConvGeneralizationPatterns(patterns); 132 populateLinalgNamedOpsGeneralizationPatterns(patterns); 133 (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns)); 134 } 135 136 GenericOp GeneralizeConvOp::createGenericOp(ConvOp convOp, 137 OpBuilder &builder) const { 138 SmallVector<AffineMap> indexingMaps = convOp.getIndexingMaps(); 139 auto iterators = 140 llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>()); 141 SmallVector<Value> inputBuffers = convOp.getInputBufferOperands(); 142 SmallVector<Value> outputBuffers = convOp.getOutputBufferOperands(); 143 return builder.create<GenericOp>( 144 convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(), inputBuffers, 145 outputBuffers, indexingMaps, iterators, 146 [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) { 147 Value mul = 148 bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]); 149 Value add = bodyBuilder.create<AddFOp>(bodyLoc, mul, bodyArgs[2]); 150 bodyBuilder.create<YieldOp>(bodyLoc, add); 151 }); 152 } 153 154 void mlir::linalg::populateLinalgConvGeneralizationPatterns( 155 RewritePatternSet &patterns, LinalgTransformationFilter marker) { 156 patterns.add<GeneralizeConvOp>(patterns.getContext(), marker); 157 } 158 159 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( 160 RewritePatternSet &patterns, LinalgTransformationFilter marker) { 161 patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker); 162 } 163 164 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() { 165 return std::make_unique<LinalgGeneralizationPass>(); 166 } 167