1 //===- Generalization.cpp - linalg named ops to generic ops --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the Linalg generalization pass. It converts named 10 // Linalg ops to linalg.generic ops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/IR/AffineMap.h" 19 #include "mlir/IR/Attributes.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/ImplicitLocOpBuilder.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 #include "llvm/ADT/SmallVector.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "linalg-generalization" 28 29 using namespace mlir; 30 using namespace mlir::linalg; 31 32 // Creates a linalg.generic op from the given `namedOp`. Returns a null op if 33 // the given `namedOp` does not have a region builder. 34 static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp, 35 PatternRewriter &rewriter) { 36 SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps(); 37 SmallVector<StringRef> iterators = llvm::to_vector<4>( 38 namedOp.iterator_types().getAsValueRange<StringAttr>()); 39 SmallVector<RankedTensorType> resultTypes = namedOp.getOutputTensorTypes(); 40 SmallVector<Type> types(resultTypes.begin(), resultTypes.end()); 41 42 // Inline the existing region if the named operation has a region attached. 43 if (namedOp->getNumRegions() == 1) { 44 GenericOp genericOp = rewriter.create<GenericOp>( 45 namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputs(), 46 indexingMaps, iterators); 47 rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(), 48 genericOp.region().begin()); 49 return genericOp; 50 } 51 52 // Otherwise use the region builder to generate a new region. 53 // TODO: Remove this path once all linag operations have a region attached. 54 auto regionBuilder = namedOp.getRegionBuilder(); 55 if (!regionBuilder) { 56 LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n"); 57 return nullptr; 58 } 59 return rewriter.create<GenericOp>( 60 namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputs(), 61 indexingMaps, iterators, 62 [®ionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) { 63 ImplicitLocOpBuilder b(loc, bodyBuilder); 64 regionBuilder(b, *bodyBuilder.getBlock(), 65 /*captures=*/{}); 66 }); 67 } 68 69 namespace { 70 71 /// Base class for all linalg generalization patterns. A subclass must provide 72 /// the following method: 73 /// GenericOp createGenericOp(RootOp, PatternRewriter &) 74 /// for creating the generic op. 75 // TODO: remove this pattern after migrating all manually-written named ops 76 // into auto-generated ones. 77 template <typename ConcretePattern, typename RootOp> 78 struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> { 79 LinalgGeneralizationPattern(MLIRContext *context, 80 LinalgTransformationFilter marker, 81 PatternBenefit benefit = 1) 82 : OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {} 83 84 LogicalResult matchAndRewrite(RootOp rootOp, 85 PatternRewriter &rewriter) const override { 86 auto linalgOp = dyn_cast<LinalgOp>(rootOp.getOperation()); 87 if (!linalgOp) 88 return failure(); 89 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 90 return failure(); 91 92 auto *pattern = static_cast<const ConcretePattern *>(this); 93 GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter); 94 if (!genericOp) 95 return failure(); 96 97 rewriter.replaceOp(rootOp, genericOp.getResults()); 98 marker.replaceLinalgTransformationFilter(rewriter, 99 genericOp.getOperation()); 100 return success(); 101 } 102 103 private: 104 LinalgTransformationFilter marker; 105 }; 106 107 struct GeneralizeConvOp 108 : public LinalgGeneralizationPattern<GeneralizeConvOp, ConvOp> { 109 using LinalgGeneralizationPattern::LinalgGeneralizationPattern; 110 111 GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const; 112 }; 113 114 /// Catch-all pattern for converting all named ops with a region builder into 115 /// linalg.generic. 116 struct LinalgNamedOpGeneralizationPattern : RewritePattern { 117 LinalgNamedOpGeneralizationPattern(MLIRContext *context, 118 LinalgTransformationFilter marker, 119 PatternBenefit benefit = 1) 120 : RewritePattern(MatchAnyOpTypeTag(), benefit, context), 121 marker(std::move(marker)) {} 122 123 LogicalResult matchAndRewrite(Operation *rootOp, 124 PatternRewriter &rewriter) const override { 125 auto linalgOp = dyn_cast<LinalgOp>(rootOp); 126 if (!linalgOp) 127 return failure(); 128 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 129 return failure(); 130 131 // No nothing to do for linalg.generic and linalg.indexed_generic. 132 if (isa<GenericOp, IndexedGenericOp>(rootOp)) 133 return failure(); 134 135 GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter); 136 if (!genericOp) 137 return failure(); 138 139 rewriter.replaceOp(rootOp, genericOp.getResults()); 140 marker.replaceLinalgTransformationFilter(rewriter, 141 genericOp.getOperation()); 142 return success(); 143 } 144 145 private: 146 LinalgTransformationFilter marker; 147 }; 148 149 struct LinalgGeneralizationPass 150 : public LinalgGeneralizationBase<LinalgGeneralizationPass> { 151 void runOnFunction() override; 152 }; 153 154 } // namespace 155 156 void LinalgGeneralizationPass::runOnFunction() { 157 FuncOp func = getFunction(); 158 RewritePatternSet patterns(&getContext()); 159 populateLinalgConvGeneralizationPatterns(patterns); 160 populateLinalgNamedOpsGeneralizationPatterns(patterns); 161 (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns)); 162 } 163 164 GenericOp GeneralizeConvOp::createGenericOp(ConvOp convOp, 165 OpBuilder &builder) const { 166 SmallVector<AffineMap, 4> indexingMaps = convOp.getIndexingMaps(); 167 auto iterators = 168 llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>()); 169 return builder.create<GenericOp>( 170 convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(), 171 convOp.getInputBuffers(), convOp.getOutputBuffers(), indexingMaps, 172 iterators, 173 [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) { 174 Value mul = 175 bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]); 176 Value add = bodyBuilder.create<AddFOp>(bodyLoc, mul, bodyArgs[2]); 177 bodyBuilder.create<YieldOp>(bodyLoc, add); 178 }); 179 } 180 181 void mlir::linalg::populateLinalgConvGeneralizationPatterns( 182 RewritePatternSet &patterns, LinalgTransformationFilter marker) { 183 patterns.add<GeneralizeConvOp>(patterns.getContext(), marker); 184 } 185 186 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( 187 RewritePatternSet &patterns, LinalgTransformationFilter marker) { 188 patterns.add<LinalgNamedOpGeneralizationPattern>(patterns.getContext(), 189 marker); 190 } 191 192 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() { 193 return std::make_unique<LinalgGeneralizationPass>(); 194 } 195