1 //===- Generalization.cpp - linalg named ops to generic ops --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the Linalg generalization pass. It converts named 10 // Linalg ops to linalg.generic ops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/IR/AffineMap.h" 19 #include "mlir/IR/Attributes.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/ImplicitLocOpBuilder.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 #include "llvm/ADT/SmallVector.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "linalg-generalization" 28 29 using namespace mlir; 30 using namespace mlir::linalg; 31 32 // Creates a linalg.generic op from the given `namedOp`. Returns a null op if 33 // the given `namedOp` does not have a region builder. 34 static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp, 35 PatternRewriter &rewriter) { 36 SmallVector<Value> inputOperands = namedOp.getInputOperands(); 37 SmallVector<Value> outputOperands = namedOp.getOutputOperands(); 38 SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps(); 39 SmallVector<StringRef> iterators = llvm::to_vector<4>( 40 namedOp.iterator_types().getAsValueRange<StringAttr>()); 41 SmallVector<RankedTensorType> resultTypes = namedOp.getOutputTensorTypes(); 42 SmallVector<Type> types(resultTypes.begin(), resultTypes.end()); 43 44 // Inline the existing region if the named operation has a region attached. 45 if (namedOp->getNumRegions() == 1) { 46 GenericOp genericOp = 47 rewriter.create<GenericOp>(namedOp.getLoc(), types, inputOperands, 48 outputOperands, indexingMaps, iterators); 49 rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(), 50 genericOp.region().begin()); 51 return genericOp; 52 } 53 54 // Otherwise use the region builder to generate a new region. 55 // TODO: Remove this path once all linag operations have a region attached. 56 auto regionBuilder = namedOp.getRegionBuilder(); 57 if (!regionBuilder) { 58 LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n"); 59 return nullptr; 60 } 61 return rewriter.create<GenericOp>( 62 namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps, 63 iterators, 64 [®ionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) { 65 ImplicitLocOpBuilder b(loc, bodyBuilder); 66 regionBuilder(b, *bodyBuilder.getBlock()); 67 }); 68 } 69 70 namespace { 71 72 /// Base class for all linalg generalization patterns. A subclass must provide 73 /// the following method: 74 /// GenericOp createGenericOp(RootOp, PatternRewriter &) 75 /// for creating the generic op. 76 // TODO: remove this pattern after migrating all manually-written named ops 77 // into auto-generated ones. 78 template <typename ConcretePattern, typename RootOp> 79 struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> { 80 LinalgGeneralizationPattern(MLIRContext *context, 81 LinalgTransformationFilter marker, 82 PatternBenefit benefit = 1) 83 : OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {} 84 85 LogicalResult matchAndRewrite(RootOp rootOp, 86 PatternRewriter &rewriter) const override { 87 auto linalgOp = dyn_cast<LinalgOp>(rootOp.getOperation()); 88 if (!linalgOp) 89 return failure(); 90 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 91 return failure(); 92 93 auto *pattern = static_cast<const ConcretePattern *>(this); 94 GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter); 95 if (!genericOp) 96 return failure(); 97 98 rewriter.replaceOp(rootOp, genericOp.getResults()); 99 marker.replaceLinalgTransformationFilter(rewriter, 100 genericOp.getOperation()); 101 return success(); 102 } 103 104 private: 105 LinalgTransformationFilter marker; 106 }; 107 108 struct GeneralizeConvOp 109 : public LinalgGeneralizationPattern<GeneralizeConvOp, ConvOp> { 110 using LinalgGeneralizationPattern::LinalgGeneralizationPattern; 111 112 GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const; 113 }; 114 115 /// Catch-all pattern for converting all named ops with a region builder into 116 /// linalg.generic. 117 struct LinalgNamedOpGeneralizationPattern : RewritePattern { 118 LinalgNamedOpGeneralizationPattern(MLIRContext *context, 119 LinalgTransformationFilter marker, 120 PatternBenefit benefit = 1) 121 : RewritePattern(MatchAnyOpTypeTag(), benefit, context), 122 marker(std::move(marker)) {} 123 124 LogicalResult matchAndRewrite(Operation *rootOp, 125 PatternRewriter &rewriter) const override { 126 auto linalgOp = dyn_cast<LinalgOp>(rootOp); 127 if (!linalgOp) 128 return failure(); 129 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 130 return failure(); 131 132 // No nothing to do for linalg.generic. 133 if (isa<GenericOp>(rootOp)) 134 return failure(); 135 136 GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter); 137 if (!genericOp) 138 return failure(); 139 140 rewriter.replaceOp(rootOp, genericOp.getResults()); 141 marker.replaceLinalgTransformationFilter(rewriter, 142 genericOp.getOperation()); 143 return success(); 144 } 145 146 private: 147 LinalgTransformationFilter marker; 148 }; 149 150 struct LinalgGeneralizationPass 151 : public LinalgGeneralizationBase<LinalgGeneralizationPass> { 152 void runOnFunction() override; 153 }; 154 155 } // namespace 156 157 void LinalgGeneralizationPass::runOnFunction() { 158 FuncOp func = getFunction(); 159 RewritePatternSet patterns(&getContext()); 160 populateLinalgConvGeneralizationPatterns(patterns); 161 populateLinalgNamedOpsGeneralizationPatterns(patterns); 162 (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns)); 163 } 164 165 GenericOp GeneralizeConvOp::createGenericOp(ConvOp convOp, 166 OpBuilder &builder) const { 167 SmallVector<AffineMap> indexingMaps = convOp.getIndexingMaps(); 168 auto iterators = 169 llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>()); 170 SmallVector<Value> inputBuffers = convOp.getInputBufferOperands(); 171 SmallVector<Value> outputBuffers = convOp.getOutputBufferOperands(); 172 return builder.create<GenericOp>( 173 convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(), inputBuffers, 174 outputBuffers, indexingMaps, iterators, 175 [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) { 176 Value mul = 177 bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]); 178 Value add = bodyBuilder.create<AddFOp>(bodyLoc, mul, bodyArgs[2]); 179 bodyBuilder.create<YieldOp>(bodyLoc, add); 180 }); 181 } 182 183 void mlir::linalg::populateLinalgConvGeneralizationPatterns( 184 RewritePatternSet &patterns, LinalgTransformationFilter marker) { 185 patterns.add<GeneralizeConvOp>(patterns.getContext(), marker); 186 } 187 188 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( 189 RewritePatternSet &patterns, LinalgTransformationFilter marker) { 190 patterns.add<LinalgNamedOpGeneralizationPattern>(patterns.getContext(), 191 marker); 192 } 193 194 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() { 195 return std::make_unique<LinalgGeneralizationPass>(); 196 } 197