1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/SCF/SCF.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 18 using namespace mlir; 19 using namespace mlir::linalg; 20 21 /// Helper function to extract the operand types that are passed to the 22 /// generated CallOp. MemRefTypes have their layout canonicalized since the 23 /// information is not used in signature generation. 24 /// Note that static size information is not modified. 25 static SmallVector<Type, 4> extractOperandTypes(Operation *op) { 26 SmallVector<Type, 4> result; 27 result.reserve(op->getNumOperands()); 28 if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op)) { 29 auto *ctx = op->getContext(); 30 auto numLoops = indexedGenericOp.getNumLoops(); 31 result.reserve(op->getNumOperands() + numLoops); 32 result.assign(numLoops, IndexType::get(ctx)); 33 } 34 for (auto type : op->getOperandTypes()) { 35 // The underlying descriptor type (e.g. LLVM) does not have layout 36 // information. Canonicalizing the type at the level of std when going into 37 // a library call avoids needing to introduce DialectCastOp. 38 if (auto memrefType = type.dyn_cast<MemRefType>()) 39 result.push_back(eraseStridedLayout(memrefType)); 40 else 41 result.push_back(type); 42 } 43 return result; 44 } 45 46 // Get a SymbolRefAttr containing the library function name for the LinalgOp. 47 // If the library function does not exist, insert a declaration. 48 static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op, 49 PatternRewriter &rewriter) { 50 auto linalgOp = cast<LinalgOp>(op); 51 auto fnName = linalgOp.getLibraryCallName(); 52 if (fnName.empty()) { 53 op->emitWarning("No library call defined for: ") << *op; 54 return {}; 55 } 56 57 // fnName is a dynamic std::string, unique it via a SymbolRefAttr. 58 FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName); 59 auto module = op->getParentOfType<ModuleOp>(); 60 if (module.lookupSymbol(fnName)) { 61 return fnNameAttr; 62 } 63 64 SmallVector<Type, 4> inputTypes(extractOperandTypes(op)); 65 assert(op->getNumResults() == 0 && 66 "Library call for linalg operation can be generated only for ops that " 67 "have void return types"); 68 auto libFnType = rewriter.getFunctionType(inputTypes, {}); 69 70 OpBuilder::InsertionGuard guard(rewriter); 71 // Insert before module terminator. 72 rewriter.setInsertionPoint(module.getBody(), 73 std::prev(module.getBody()->end())); 74 FuncOp funcOp = 75 rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType); 76 // Insert a function attribute that will trigger the emission of the 77 // corresponding `_mlir_ciface_xxx` interface so that external libraries see 78 // a normalized ABI. This interface is added during std to llvm conversion. 79 funcOp->setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext())); 80 funcOp.setPrivate(); 81 return fnNameAttr; 82 } 83 84 static SmallVector<Value, 4> 85 createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc, 86 ValueRange operands) { 87 SmallVector<Value, 4> res; 88 res.reserve(operands.size()); 89 for (auto op : operands) { 90 auto memrefType = op.getType().dyn_cast<MemRefType>(); 91 if (!memrefType) { 92 res.push_back(op); 93 continue; 94 } 95 Value cast = 96 b.create<MemRefCastOp>(loc, eraseStridedLayout(memrefType), op); 97 res.push_back(cast); 98 } 99 return res; 100 } 101 102 LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite( 103 Operation *op, PatternRewriter &rewriter) const { 104 // Only LinalgOp for which there is no specialized pattern go through this. 105 if (!isa<LinalgOp>(op) || isa<CopyOp>(op) || isa<IndexedGenericOp>(op)) 106 return failure(); 107 108 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); 109 if (!libraryCallName) 110 return failure(); 111 112 rewriter.replaceOpWithNewOp<mlir::CallOp>( 113 op, libraryCallName.getValue(), TypeRange(), 114 createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(), 115 op->getOperands())); 116 return success(); 117 } 118 119 LogicalResult mlir::linalg::CopyOpToLibraryCallRewrite::matchAndRewrite( 120 CopyOp op, PatternRewriter &rewriter) const { 121 auto inputPerm = op.inputPermutation(); 122 if (inputPerm.hasValue() && !inputPerm->isIdentity()) 123 return failure(); 124 auto outputPerm = op.outputPermutation(); 125 if (outputPerm.hasValue() && !outputPerm->isIdentity()) 126 return failure(); 127 128 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); 129 if (!libraryCallName) 130 return failure(); 131 132 rewriter.replaceOpWithNewOp<mlir::CallOp>( 133 op, libraryCallName.getValue(), TypeRange(), 134 createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), 135 op.getOperands())); 136 return success(); 137 } 138 139 LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite( 140 CopyOp op, PatternRewriter &rewriter) const { 141 Value in = op.input(), out = op.output(); 142 143 // If either inputPerm or outputPerm are non-identities, insert transposes. 144 auto inputPerm = op.inputPermutation(); 145 if (inputPerm.hasValue() && !inputPerm->isIdentity()) 146 in = rewriter.create<TransposeOp>(op.getLoc(), in, 147 AffineMapAttr::get(*inputPerm)); 148 auto outputPerm = op.outputPermutation(); 149 if (outputPerm.hasValue() && !outputPerm->isIdentity()) 150 out = rewriter.create<TransposeOp>(op.getLoc(), out, 151 AffineMapAttr::get(*outputPerm)); 152 153 // If nothing was transposed, fail and let the conversion kick in. 154 if (in == op.input() && out == op.output()) 155 return failure(); 156 157 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); 158 if (!libraryCallName) 159 return failure(); 160 161 rewriter.replaceOpWithNewOp<mlir::CallOp>( 162 op, libraryCallName.getValue(), TypeRange(), 163 createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out})); 164 return success(); 165 } 166 167 LogicalResult 168 mlir::linalg::IndexedGenericOpToLibraryCallRewrite::matchAndRewrite( 169 IndexedGenericOp op, PatternRewriter &rewriter) const { 170 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); 171 if (!libraryCallName) 172 return failure(); 173 174 // TODO: Use induction variables values instead of zeros, when 175 // IndexedGenericOp is tiled. 176 auto zero = rewriter.create<mlir::ConstantOp>( 177 op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); 178 auto indexedGenericOp = cast<IndexedGenericOp>(op); 179 auto numLoops = indexedGenericOp.getNumLoops(); 180 SmallVector<Value, 4> operands; 181 operands.reserve(numLoops + op.getNumOperands()); 182 for (unsigned i = 0; i < numLoops; ++i) 183 operands.push_back(zero); 184 for (auto operand : op.getOperands()) 185 operands.push_back(operand); 186 rewriter.replaceOpWithNewOp<mlir::CallOp>( 187 op, libraryCallName.getValue(), TypeRange(), 188 createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands)); 189 return success(); 190 } 191 192 /// Populate the given list with patterns that convert from Linalg to Standard. 193 void mlir::linalg::populateLinalgToStandardConversionPatterns( 194 OwningRewritePatternList &patterns, MLIRContext *ctx) { 195 // TODO: ConvOp conversion needs to export a descriptor with relevant 196 // attribute values such as kernel striding and dilation. 197 // clang-format off 198 patterns.insert< 199 CopyOpToLibraryCallRewrite, 200 CopyTransposeRewrite, 201 IndexedGenericOpToLibraryCallRewrite>(ctx); 202 patterns.insert<LinalgOpToLibraryCallRewrite>(); 203 // clang-format on 204 } 205 206 namespace { 207 struct ConvertLinalgToStandardPass 208 : public ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> { 209 void runOnOperation() override; 210 }; 211 } // namespace 212 213 void ConvertLinalgToStandardPass::runOnOperation() { 214 auto module = getOperation(); 215 ConversionTarget target(getContext()); 216 target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect>(); 217 target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>(); 218 target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>(); 219 OwningRewritePatternList patterns; 220 populateLinalgToStandardConversionPatterns(patterns, &getContext()); 221 if (failed(applyFullConversion(module, target, std::move(patterns)))) 222 signalPassFailure(); 223 } 224 225 std::unique_ptr<OperationPass<ModuleOp>> 226 mlir::createConvertLinalgToStandardPass() { 227 return std::make_unique<ConvertLinalgToStandardPass>(); 228 } 229