1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/SCF/SCF.h" 17 #include "mlir/Dialect/StandardOps/IR/Ops.h" 18 19 using namespace mlir; 20 using namespace mlir::linalg; 21 22 /// Helper function to extract the operand types that are passed to the 23 /// generated CallOp. MemRefTypes have their layout canonicalized since the 24 /// information is not used in signature generation. 25 /// Note that static size information is not modified. 26 static SmallVector<Type, 4> extractOperandTypes(Operation *op) { 27 SmallVector<Type, 4> result; 28 result.reserve(op->getNumOperands()); 29 if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op)) { 30 auto *ctx = op->getContext(); 31 auto numLoops = indexedGenericOp.getNumLoops(); 32 result.reserve(op->getNumOperands() + numLoops); 33 result.assign(numLoops, IndexType::get(ctx)); 34 } 35 for (auto type : op->getOperandTypes()) { 36 // The underlying descriptor type (e.g. LLVM) does not have layout 37 // information. Canonicalizing the type at the level of std when going into 38 // a library call avoids needing to introduce DialectCastOp. 39 if (auto memrefType = type.dyn_cast<MemRefType>()) 40 result.push_back(eraseStridedLayout(memrefType)); 41 else 42 result.push_back(type); 43 } 44 return result; 45 } 46 47 // Get a SymbolRefAttr containing the library function name for the LinalgOp. 48 // If the library function does not exist, insert a declaration. 49 static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op, 50 PatternRewriter &rewriter) { 51 auto linalgOp = cast<LinalgOp>(op); 52 auto fnName = linalgOp.getLibraryCallName(); 53 if (fnName.empty()) { 54 op->emitWarning("No library call defined for: ") << *op; 55 return {}; 56 } 57 58 // fnName is a dynamic std::string, unique it via a SymbolRefAttr. 59 FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName); 60 auto module = op->getParentOfType<ModuleOp>(); 61 if (module.lookupSymbol(fnName)) { 62 return fnNameAttr; 63 } 64 65 SmallVector<Type, 4> inputTypes(extractOperandTypes(op)); 66 assert(op->getNumResults() == 0 && 67 "Library call for linalg operation can be generated only for ops that " 68 "have void return types"); 69 auto libFnType = rewriter.getFunctionType(inputTypes, {}); 70 71 OpBuilder::InsertionGuard guard(rewriter); 72 // Insert before module terminator. 73 rewriter.setInsertionPoint(module.getBody(), 74 std::prev(module.getBody()->end())); 75 FuncOp funcOp = 76 rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType); 77 // Insert a function attribute that will trigger the emission of the 78 // corresponding `_mlir_ciface_xxx` interface so that external libraries see 79 // a normalized ABI. This interface is added during std to llvm conversion. 80 funcOp->setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext())); 81 funcOp.setPrivate(); 82 return fnNameAttr; 83 } 84 85 static SmallVector<Value, 4> 86 createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc, 87 ValueRange operands) { 88 SmallVector<Value, 4> res; 89 res.reserve(operands.size()); 90 for (auto op : operands) { 91 auto memrefType = op.getType().dyn_cast<MemRefType>(); 92 if (!memrefType) { 93 res.push_back(op); 94 continue; 95 } 96 Value cast = 97 b.create<memref::CastOp>(loc, eraseStridedLayout(memrefType), op); 98 res.push_back(cast); 99 } 100 return res; 101 } 102 103 LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite( 104 LinalgOp op, PatternRewriter &rewriter) const { 105 // Only LinalgOp for which there is no specialized pattern go through this. 106 if (isa<CopyOp>(op) || isa<IndexedGenericOp>(op)) 107 return failure(); 108 109 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); 110 if (!libraryCallName) 111 return failure(); 112 113 rewriter.replaceOpWithNewOp<mlir::CallOp>( 114 op, libraryCallName.getValue(), TypeRange(), 115 createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(), 116 op->getOperands())); 117 return success(); 118 } 119 120 LogicalResult mlir::linalg::CopyOpToLibraryCallRewrite::matchAndRewrite( 121 CopyOp op, PatternRewriter &rewriter) const { 122 auto inputPerm = op.inputPermutation(); 123 if (inputPerm.hasValue() && !inputPerm->isIdentity()) 124 return failure(); 125 auto outputPerm = op.outputPermutation(); 126 if (outputPerm.hasValue() && !outputPerm->isIdentity()) 127 return failure(); 128 129 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); 130 if (!libraryCallName) 131 return failure(); 132 133 rewriter.replaceOpWithNewOp<mlir::CallOp>( 134 op, libraryCallName.getValue(), TypeRange(), 135 createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), 136 op.getOperands())); 137 return success(); 138 } 139 140 LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite( 141 CopyOp op, PatternRewriter &rewriter) const { 142 Value in = op.input(), out = op.output(); 143 144 // If either inputPerm or outputPerm are non-identities, insert transposes. 145 auto inputPerm = op.inputPermutation(); 146 if (inputPerm.hasValue() && !inputPerm->isIdentity()) 147 in = rewriter.create<memref::TransposeOp>(op.getLoc(), in, 148 AffineMapAttr::get(*inputPerm)); 149 auto outputPerm = op.outputPermutation(); 150 if (outputPerm.hasValue() && !outputPerm->isIdentity()) 151 out = rewriter.create<memref::TransposeOp>(op.getLoc(), out, 152 AffineMapAttr::get(*outputPerm)); 153 154 // If nothing was transposed, fail and let the conversion kick in. 155 if (in == op.input() && out == op.output()) 156 return failure(); 157 158 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); 159 if (!libraryCallName) 160 return failure(); 161 162 rewriter.replaceOpWithNewOp<mlir::CallOp>( 163 op, libraryCallName.getValue(), TypeRange(), 164 createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out})); 165 return success(); 166 } 167 168 LogicalResult 169 mlir::linalg::IndexedGenericOpToLibraryCallRewrite::matchAndRewrite( 170 IndexedGenericOp op, PatternRewriter &rewriter) const { 171 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); 172 if (!libraryCallName) 173 return failure(); 174 175 // TODO: Use induction variables values instead of zeros, when 176 // IndexedGenericOp is tiled. 177 auto zero = rewriter.create<mlir::ConstantOp>( 178 op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); 179 auto indexedGenericOp = cast<IndexedGenericOp>(op); 180 auto numLoops = indexedGenericOp.getNumLoops(); 181 SmallVector<Value, 4> operands; 182 operands.reserve(numLoops + op.getNumOperands()); 183 for (unsigned i = 0; i < numLoops; ++i) 184 operands.push_back(zero); 185 for (auto operand : op.getOperands()) 186 operands.push_back(operand); 187 rewriter.replaceOpWithNewOp<mlir::CallOp>( 188 op, libraryCallName.getValue(), TypeRange(), 189 createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands)); 190 return success(); 191 } 192 193 /// Populate the given list with patterns that convert from Linalg to Standard. 194 void mlir::linalg::populateLinalgToStandardConversionPatterns( 195 RewritePatternSet &patterns) { 196 // TODO: ConvOp conversion needs to export a descriptor with relevant 197 // attribute values such as kernel striding and dilation. 198 // clang-format off 199 patterns.add< 200 CopyOpToLibraryCallRewrite, 201 CopyTransposeRewrite, 202 IndexedGenericOpToLibraryCallRewrite, 203 LinalgOpToLibraryCallRewrite>(patterns.getContext()); 204 // clang-format on 205 } 206 207 namespace { 208 struct ConvertLinalgToStandardPass 209 : public ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> { 210 void runOnOperation() override; 211 }; 212 } // namespace 213 214 void ConvertLinalgToStandardPass::runOnOperation() { 215 auto module = getOperation(); 216 ConversionTarget target(getContext()); 217 target.addLegalDialect<AffineDialect, memref::MemRefDialect, scf::SCFDialect, 218 StandardOpsDialect>(); 219 target.addLegalOp<ModuleOp, FuncOp, ReturnOp>(); 220 target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>(); 221 RewritePatternSet patterns(&getContext()); 222 populateLinalgToStandardConversionPatterns(patterns); 223 if (failed(applyFullConversion(module, target, std::move(patterns)))) 224 signalPassFailure(); 225 } 226 227 std::unique_ptr<OperationPass<ModuleOp>> 228 mlir::createConvertLinalgToStandardPass() { 229 return std::make_unique<ConvertLinalgToStandardPass>(); 230 } 231