1 //===- Generalization.cpp - linalg named ops to generic ops --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the Linalg generalization pass. It converts named 10 // Linalg ops to linalg.generic ops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/IR/AffineMap.h" 19 #include "mlir/IR/Attributes.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/ImplicitLocOpBuilder.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 #include "llvm/ADT/SmallVector.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "linalg-generalization" 28 29 using namespace mlir; 30 using namespace mlir::linalg; 31 32 LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) { 33 LinalgOp namedOp = dyn_cast<LinalgOp>(op); 34 // Check if the operation is a LinalgOp but not a GenericOp. 35 if (!namedOp || isa<GenericOp>(op)) 36 return failure(); 37 // Check if the operation has a region builder. 38 if (!namedOp.getRegionBuilder()) 39 return failure(); 40 return success(); 41 } 42 43 GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter, 44 LinalgOp namedOp) { 45 SmallVector<Value> inputOperands = namedOp.getInputOperands(); 46 SmallVector<Value> outputOperands = namedOp.getOutputOperands(); 47 SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps(); 48 SmallVector<StringRef> iterators = llvm::to_vector<4>( 49 namedOp.iterator_types().getAsValueRange<StringAttr>()); 50 SmallVector<RankedTensorType> resultTypes = namedOp.getOutputTensorTypes(); 51 SmallVector<Type> types(resultTypes.begin(), resultTypes.end()); 52 53 // Inline the existing region if the named operation has a region attached. 54 if (namedOp->getNumRegions() == 1) { 55 GenericOp genericOp = 56 rewriter.create<GenericOp>(namedOp.getLoc(), types, inputOperands, 57 outputOperands, indexingMaps, iterators); 58 rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(), 59 genericOp.region().begin()); 60 return genericOp; 61 } 62 63 // Otherwise use the region builder to generate a new region. 64 // TODO: Remove this path once all linag operations have a region attached. 65 auto regionBuilder = namedOp.getRegionBuilder(); 66 assert(regionBuilder && "expect the operation to have region builder"); 67 return rewriter.create<GenericOp>( 68 namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps, 69 iterators, 70 [®ionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) { 71 ImplicitLocOpBuilder b(loc, bodyBuilder); 72 regionBuilder(b, *bodyBuilder.getBlock()); 73 }); 74 } 75 76 namespace { 77 78 struct LinalgGeneralizationPass 79 : public LinalgGeneralizationBase<LinalgGeneralizationPass> { 80 void runOnFunction() override; 81 }; 82 83 } // namespace 84 85 void LinalgGeneralizationPass::runOnFunction() { 86 FuncOp func = getFunction(); 87 RewritePatternSet patterns(&getContext()); 88 populateLinalgNamedOpsGeneralizationPatterns(patterns); 89 (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns)); 90 } 91 92 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( 93 RewritePatternSet &patterns, LinalgTransformationFilter marker) { 94 patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker); 95 } 96 97 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() { 98 return std::make_unique<LinalgGeneralizationPass>(); 99 } 100