1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/Dialect/SCF/IR/SCF.h" 19 20 using namespace mlir; 21 using namespace mlir::linalg; 22 23 /// Helper function to extract the operand types that are passed to the 24 /// generated CallOp. MemRefTypes have their layout canonicalized since the 25 /// information is not used in signature generation. 26 /// Note that static size information is not modified. 27 static SmallVector<Type, 4> extractOperandTypes(Operation *op) { 28 SmallVector<Type, 4> result; 29 result.reserve(op->getNumOperands()); 30 for (auto type : op->getOperandTypes()) { 31 // The underlying descriptor type (e.g. LLVM) does not have layout 32 // information. Canonicalizing the type at the level of std when going into 33 // a library call avoids needing to introduce DialectCastOp. 34 if (auto memrefType = type.dyn_cast<MemRefType>()) 35 result.push_back(eraseStridedLayout(memrefType)); 36 else 37 result.push_back(type); 38 } 39 return result; 40 } 41 42 // Get a SymbolRefAttr containing the library function name for the LinalgOp. 43 // If the library function does not exist, insert a declaration. 44 static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op, 45 PatternRewriter &rewriter) { 46 auto linalgOp = cast<LinalgOp>(op); 47 auto fnName = linalgOp.getLibraryCallName(); 48 if (fnName.empty()) { 49 op->emitWarning("No library call defined for: ") << *op; 50 return {}; 51 } 52 53 // fnName is a dynamic std::string, unique it via a SymbolRefAttr. 54 FlatSymbolRefAttr fnNameAttr = 55 SymbolRefAttr::get(rewriter.getContext(), fnName); 56 auto module = op->getParentOfType<ModuleOp>(); 57 if (module.lookupSymbol(fnNameAttr.getAttr())) 58 return fnNameAttr; 59 60 SmallVector<Type, 4> inputTypes(extractOperandTypes(op)); 61 assert(op->getNumResults() == 0 && 62 "Library call for linalg operation can be generated only for ops that " 63 "have void return types"); 64 auto libFnType = rewriter.getFunctionType(inputTypes, {}); 65 66 OpBuilder::InsertionGuard guard(rewriter); 67 // Insert before module terminator. 68 rewriter.setInsertionPoint(module.getBody(), 69 std::prev(module.getBody()->end())); 70 func::FuncOp funcOp = rewriter.create<func::FuncOp>( 71 op->getLoc(), fnNameAttr.getValue(), libFnType); 72 // Insert a function attribute that will trigger the emission of the 73 // corresponding `_mlir_ciface_xxx` interface so that external libraries see 74 // a normalized ABI. This interface is added during std to llvm conversion. 75 funcOp->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), 76 UnitAttr::get(op->getContext())); 77 funcOp.setPrivate(); 78 return fnNameAttr; 79 } 80 81 static SmallVector<Value, 4> 82 createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc, 83 ValueRange operands) { 84 SmallVector<Value, 4> res; 85 res.reserve(operands.size()); 86 for (auto op : operands) { 87 auto memrefType = op.getType().dyn_cast<MemRefType>(); 88 if (!memrefType) { 89 res.push_back(op); 90 continue; 91 } 92 Value cast = 93 b.create<memref::CastOp>(loc, eraseStridedLayout(memrefType), op); 94 res.push_back(cast); 95 } 96 return res; 97 } 98 99 LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite( 100 LinalgOp op, PatternRewriter &rewriter) const { 101 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); 102 if (!libraryCallName) 103 return failure(); 104 105 // TODO: Add support for more complex library call signatures that include 106 // indices or captured values. 107 rewriter.replaceOpWithNewOp<func::CallOp>( 108 op, libraryCallName.getValue(), TypeRange(), 109 createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(), 110 op->getOperands())); 111 return success(); 112 } 113 114 /// Populate the given list with patterns that convert from Linalg to Standard. 115 void mlir::linalg::populateLinalgToStandardConversionPatterns( 116 RewritePatternSet &patterns) { 117 // TODO: ConvOp conversion needs to export a descriptor with relevant 118 // attribute values such as kernel striding and dilation. 119 patterns.add<LinalgOpToLibraryCallRewrite>(patterns.getContext()); 120 } 121 122 namespace { 123 struct ConvertLinalgToStandardPass 124 : public ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> { 125 void runOnOperation() override; 126 }; 127 } // namespace 128 129 void ConvertLinalgToStandardPass::runOnOperation() { 130 auto module = getOperation(); 131 ConversionTarget target(getContext()); 132 target.addLegalDialect<AffineDialect, arith::ArithmeticDialect, 133 func::FuncDialect, memref::MemRefDialect, 134 scf::SCFDialect>(); 135 target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(); 136 RewritePatternSet patterns(&getContext()); 137 populateLinalgToStandardConversionPatterns(patterns); 138 if (failed(applyFullConversion(module, target, std::move(patterns)))) 139 signalPassFailure(); 140 } 141 142 std::unique_ptr<OperationPass<ModuleOp>> 143 mlir::createConvertLinalgToStandardPass() { 144 return std::make_unique<ConvertLinalgToStandardPass>(); 145 } 146