1 //===- Generalization.cpp - linalg named ops to generic ops --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the Linalg generalization pass. It converts named 10 // Linalg ops to linalg.generic ops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/EDSC/Builders.h" 19 #include "mlir/IR/AffineMap.h" 20 #include "mlir/IR/Attributes.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 #include "llvm/ADT/SmallVector.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "linalg-generalization" 28 29 using namespace mlir; 30 31 // Creates a linalg.generic op from the given `namedOp`. Returns a null op if 32 // the given `namedOp` does not have a region builder. 33 static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp, 34 OpBuilder &builder) { 35 auto regionBuilder = namedOp.getRegionBuilder(); 36 if (!regionBuilder) { 37 LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n"); 38 return nullptr; 39 } 40 41 SmallVector<AffineMap, 4> indexingMaps = namedOp.getIndexingMaps(); 42 auto iterators = llvm::to_vector<4>( 43 namedOp.iterator_types().getAsValueRange<StringAttr>()); 44 auto resultTypes = namedOp.getOutputTensorTypes(); 45 SmallVector<Type, 4> types(resultTypes.begin(), resultTypes.end()); 46 47 return builder.create<linalg::GenericOp>( 48 namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputBuffers(), 49 namedOp.getInitTensors(), indexingMaps, iterators, 50 [®ionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) { 51 edsc::ScopedContext scope(bodyBuilder, loc); 52 regionBuilder(*bodyBuilder.getBlock()); 53 }); 54 } 55 56 namespace { 57 58 /// Base class for all linalg generalization patterns. A subclass must provide 59 /// the following method: 60 /// linalg::GenericOp createGenericOp(RootOp, PatternRewriter &) 61 /// for creating the generic op. 62 // TODO: remove this pattern after migrating all manually-written named ops 63 // into auto-generated ones. 64 template <typename ConcretePattern, typename RootOp> 65 struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> { 66 LinalgGeneralizationPattern(MLIRContext *context, linalg::LinalgMarker marker, 67 PatternBenefit benefit = 1) 68 : OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {} 69 70 LogicalResult matchAndRewrite(RootOp rootOp, 71 PatternRewriter &rewriter) const override { 72 auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp.getOperation()); 73 if (!linalgOp) 74 return failure(); 75 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 76 return failure(); 77 78 auto *pattern = static_cast<const ConcretePattern *>(this); 79 linalg::GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter); 80 if (!genericOp) 81 return failure(); 82 83 rewriter.replaceOp(rootOp, genericOp.getResults()); 84 marker.replaceLinalgMarker(rewriter, genericOp.getOperation()); 85 return success(); 86 } 87 88 private: 89 linalg::LinalgMarker marker; 90 }; 91 92 struct GeneralizeConvOp 93 : public LinalgGeneralizationPattern<GeneralizeConvOp, linalg::ConvOp> { 94 using LinalgGeneralizationPattern::LinalgGeneralizationPattern; 95 96 linalg::GenericOp createGenericOp(linalg::ConvOp, OpBuilder &rewriter) const; 97 }; 98 99 /// Catch-all pattern for converting all named ops with a region builder into 100 /// linalg.generic. 101 struct LinalgNamedOpGeneralizationPattern : RewritePattern { 102 LinalgNamedOpGeneralizationPattern(MLIRContext *context, 103 linalg::LinalgMarker marker, 104 PatternBenefit benefit = 1) 105 : RewritePattern(benefit, MatchAnyOpTypeTag()), 106 marker(std::move(marker)) {} 107 108 LogicalResult matchAndRewrite(Operation *rootOp, 109 PatternRewriter &rewriter) const override { 110 auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp); 111 if (!linalgOp) 112 return failure(); 113 if (failed(marker.checkAndNotify(rewriter, linalgOp))) 114 return failure(); 115 116 // No nothing to do for linalg.generic and linalg.indexed_generic. 117 if (isa<linalg::GenericOp, linalg::IndexedGenericOp>(rootOp)) 118 return failure(); 119 120 linalg::GenericOp genericOp = 121 createGenericOpFromNamedOp(linalgOp, rewriter); 122 if (!genericOp) 123 return failure(); 124 125 rewriter.replaceOp(rootOp, genericOp.getResults()); 126 marker.replaceLinalgMarker(rewriter, genericOp.getOperation()); 127 return success(); 128 } 129 130 private: 131 linalg::LinalgMarker marker; 132 }; 133 134 struct LinalgGeneralizationPass 135 : public LinalgGeneralizationBase<LinalgGeneralizationPass> { 136 void runOnFunction() override; 137 }; 138 139 } // namespace 140 141 void LinalgGeneralizationPass::runOnFunction() { 142 FuncOp func = getFunction(); 143 OwningRewritePatternList patterns; 144 linalg::populateLinalgConvGeneralizationPatterns(&getContext(), patterns); 145 linalg::populateLinalgNamedOpsGeneralizationPatterns(&getContext(), patterns); 146 applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns)); 147 } 148 149 linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp, 150 OpBuilder &builder) const { 151 SmallVector<AffineMap, 4> indexingMaps = convOp.getIndexingMaps(); 152 auto iterators = 153 llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>()); 154 return builder.create<linalg::GenericOp>( 155 convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(), 156 convOp.getInputBuffers(), convOp.getOutputBuffers(), 157 /*initTensors=*/ValueRange(), indexingMaps, iterators, 158 [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) { 159 Value mul = 160 bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]); 161 Value add = bodyBuilder.create<AddFOp>(bodyLoc, mul, bodyArgs[2]); 162 bodyBuilder.create<linalg::YieldOp>(bodyLoc, add); 163 }); 164 } 165 166 void mlir::linalg::populateLinalgConvGeneralizationPatterns( 167 MLIRContext *context, OwningRewritePatternList &patterns, 168 linalg::LinalgMarker marker) { 169 patterns.insert<GeneralizeConvOp>(context, marker); 170 } 171 172 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( 173 MLIRContext *context, OwningRewritePatternList &patterns, 174 linalg::LinalgMarker marker) { 175 patterns.insert<LinalgNamedOpGeneralizationPattern>(context, marker); 176 } 177 178 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() { 179 return std::make_unique<LinalgGeneralizationPass>(); 180 } 181