1 //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 14 #include "mlir/Dialect/SCF/SCF.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 17 using namespace mlir; 18 using namespace mlir::linalg; 19 20 /// Helper function to extract the operand types that are passed to the 21 /// generated CallOp. MemRefTypes have their layout canonicalized since the 22 /// information is not used in signature generation. 23 /// Note that static size information is not modified. 24 template <typename LinalgOp> 25 static SmallVector<Type, 4> extractOperandTypes(Operation *op) { 26 SmallVector<Type, 4> result; 27 result.reserve(op->getNumOperands()); 28 for (auto type : op->getOperandTypes()) { 29 // The underlying descriptor type (e.g. LLVM) does not have layout 30 // information. Canonicalizing the type at the level of std when going into 31 // a library call avoids needing to introduce DialectCastOp. 32 if (auto memrefType = type.dyn_cast<MemRefType>()) 33 result.push_back(eraseStridedLayout(memrefType)); 34 else 35 result.push_back(type); 36 } 37 return result; 38 } 39 40 template <> 41 SmallVector<Type, 4> extractOperandTypes<IndexedGenericOp>(Operation *op) { 42 auto *ctx = op->getContext(); 43 auto indexedGenericOp = cast<IndexedGenericOp>(op); 44 auto numLoops = indexedGenericOp.getNumLoops(); 45 46 SmallVector<Type, 4> result(numLoops, IndexType::get(ctx)); 47 auto canonicalizedOperands = extractOperandTypes<LinalgOp>(op); 48 result.append(canonicalizedOperands.begin(), canonicalizedOperands.end()); 49 return result; 50 } 51 52 // Get a SymbolRefAttr containing the library function name for the LinalgOp. 53 // If the library function does not exist, insert a declaration. 54 template <typename LinalgOp> 55 static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op, 56 PatternRewriter &rewriter) { 57 auto linalgOp = cast<LinalgOp>(op); 58 auto fnName = linalgOp.getLibraryCallName(); 59 if (fnName.empty()) { 60 op->emitWarning("No library call defined for: ") << *op; 61 return {}; 62 } 63 64 // fnName is a dynamic std::string, unique it via a SymbolRefAttr. 65 FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName); 66 auto module = op->getParentOfType<ModuleOp>(); 67 if (module.lookupSymbol(fnName)) { 68 return fnNameAttr; 69 } 70 71 SmallVector<Type, 4> inputTypes(extractOperandTypes<LinalgOp>(op)); 72 assert(op->getNumResults() == 0 && 73 "Library call for linalg operation can be generated only for ops that " 74 "have void return types"); 75 auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext()); 76 77 OpBuilder::InsertionGuard guard(rewriter); 78 // Insert before module terminator. 79 rewriter.setInsertionPoint(module.getBody(), 80 std::prev(module.getBody()->end())); 81 FuncOp funcOp = 82 rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType); 83 // Insert a function attribute that will trigger the emission of the 84 // corresponding `_mlir_ciface_xxx` interface so that external libraries see 85 // a normalized ABI. This interface is added during std to llvm conversion. 86 funcOp.setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext())); 87 return fnNameAttr; 88 } 89 90 namespace { 91 92 SmallVector<Value, 4> 93 createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc, 94 ValueRange operands) { 95 SmallVector<Value, 4> res; 96 res.reserve(operands.size()); 97 for (auto op : operands) { 98 auto memrefType = op.getType().dyn_cast<MemRefType>(); 99 if (!memrefType) { 100 res.push_back(op); 101 continue; 102 } 103 Value cast = 104 b.create<MemRefCastOp>(loc, eraseStridedLayout(memrefType), op); 105 res.push_back(cast); 106 } 107 return res; 108 } 109 110 // LinalgOpConversion<LinalgOp> creates a new call to the type-canonicalized 111 // `LinalgOp::getLibraryCallName()` function. 112 // The implementation of the function can be either in the same module or in an 113 // externally linked library. 114 template <typename LinalgOp> 115 class LinalgOpConversion : public OpRewritePattern<LinalgOp> { 116 public: 117 using OpRewritePattern<LinalgOp>::OpRewritePattern; 118 119 LogicalResult matchAndRewrite(LinalgOp op, 120 PatternRewriter &rewriter) const override { 121 auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter); 122 if (!libraryCallName) 123 return failure(); 124 125 rewriter.replaceOpWithNewOp<mlir::CallOp>( 126 op, libraryCallName.getValue(), ArrayRef<Type>{}, 127 createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), 128 op.getOperands())); 129 return success(); 130 } 131 }; 132 133 /// Conversion pattern specialization for CopyOp. This kicks in when both input 134 /// and output permutations are left unspecified or are the identity. 135 template <> 136 class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> { 137 public: 138 using OpRewritePattern<CopyOp>::OpRewritePattern; 139 140 LogicalResult matchAndRewrite(CopyOp op, 141 PatternRewriter &rewriter) const override { 142 auto inputPerm = op.inputPermutation(); 143 if (inputPerm.hasValue() && !inputPerm->isIdentity()) 144 return failure(); 145 auto outputPerm = op.outputPermutation(); 146 if (outputPerm.hasValue() && !outputPerm->isIdentity()) 147 return failure(); 148 149 auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter); 150 if (!libraryCallName) 151 return failure(); 152 153 rewriter.replaceOpWithNewOp<mlir::CallOp>( 154 op, libraryCallName.getValue(), ArrayRef<Type>{}, 155 createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), 156 op.getOperands())); 157 return success(); 158 } 159 }; 160 161 /// Conversion pattern specialization for IndexedGenericOp. 162 template <> 163 class LinalgOpConversion<IndexedGenericOp> 164 : public OpRewritePattern<IndexedGenericOp> { 165 public: 166 using OpRewritePattern<IndexedGenericOp>::OpRewritePattern; 167 168 LogicalResult matchAndRewrite(IndexedGenericOp op, 169 PatternRewriter &rewriter) const override { 170 auto libraryCallName = 171 getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter); 172 if (!libraryCallName) 173 return failure(); 174 175 // TODO: Use induction variables values instead of zeros, when 176 // IndexedGenericOp is tiled. 177 auto zero = rewriter.create<mlir::ConstantOp>( 178 op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); 179 auto indexedGenericOp = cast<IndexedGenericOp>(op); 180 auto numLoops = indexedGenericOp.getNumLoops(); 181 SmallVector<Value, 4> operands; 182 operands.reserve(numLoops + op.getNumOperands()); 183 for (unsigned i = 0; i < numLoops; ++i) 184 operands.push_back(zero); 185 for (auto operand : op.getOperands()) 186 operands.push_back(operand); 187 rewriter.replaceOpWithNewOp<mlir::CallOp>( 188 op, libraryCallName.getValue(), ArrayRef<Type>{}, 189 createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands)); 190 return success(); 191 } 192 }; 193 194 /// A non-conversion rewrite pattern kicks in to convert CopyOp with 195 /// permutations into a sequence of TransposeOp and permutation-free CopyOp. 196 /// This interplays together with TransposeOpConversion and 197 /// LinalgConversion<CopyOp> to create a path to the LLVM dialect. 198 class CopyTransposeConversion : public OpRewritePattern<CopyOp> { 199 public: 200 using OpRewritePattern<CopyOp>::OpRewritePattern; 201 202 LogicalResult matchAndRewrite(CopyOp op, 203 PatternRewriter &rewriter) const override { 204 Value in = op.input(), out = op.output(); 205 206 // If either inputPerm or outputPerm are non-identities, insert transposes. 207 auto inputPerm = op.inputPermutation(); 208 if (inputPerm.hasValue() && !inputPerm->isIdentity()) 209 in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in, 210 AffineMapAttr::get(*inputPerm)); 211 auto outputPerm = op.outputPermutation(); 212 if (outputPerm.hasValue() && !outputPerm->isIdentity()) 213 out = rewriter.create<linalg::TransposeOp>( 214 op.getLoc(), out, AffineMapAttr::get(*outputPerm)); 215 216 // If nothing was transposed, fail and let the conversion kick in. 217 if (in == op.input() && out == op.output()) 218 return failure(); 219 220 rewriter.replaceOpWithNewOp<CopyOp>(op, in, out); 221 return success(); 222 } 223 }; 224 } // namespace 225 226 /// Populate the given list with patterns that convert from Linalg to Standard. 227 void mlir::populateLinalgToStandardConversionPatterns( 228 OwningRewritePatternList &patterns, MLIRContext *ctx) { 229 // TODO: ConvOp conversion needs to export a descriptor with relevant 230 // attribute values such as kernel striding and dilation. 231 // clang-format off 232 patterns.insert< 233 CopyTransposeConversion, 234 LinalgOpConversion<ConvOp>, 235 LinalgOpConversion<PoolingMaxOp>, 236 LinalgOpConversion<PoolingMinOp>, 237 LinalgOpConversion<PoolingSumOp>, 238 LinalgOpConversion<CopyOp>, 239 LinalgOpConversion<FillOp>, 240 LinalgOpConversion<GenericOp>, 241 LinalgOpConversion<IndexedGenericOp>>(ctx); 242 // TODO: collect all auto-generated named ops with a tblgen directive. 243 patterns.insert< 244 LinalgOpConversion<DotOp>, 245 LinalgOpConversion<BatchMatmulOp>, 246 LinalgOpConversion<MatvecOp>, 247 LinalgOpConversion<VecmatOp>, 248 LinalgOpConversion<MatmulOp>, 249 LinalgOpConversion<ConvWOp>, 250 LinalgOpConversion<ConvNWCOp>, 251 LinalgOpConversion<ConvNCWOp>, 252 LinalgOpConversion<ConvHWOp>, 253 LinalgOpConversion<ConvNHWCOp>, 254 LinalgOpConversion<ConvNCHWOp>, 255 LinalgOpConversion<ConvDHWOp>, 256 LinalgOpConversion<ConvNDHWCOp>, 257 LinalgOpConversion<ConvNCDHWOp>>(ctx); 258 // clang-format on 259 } 260 261 namespace { 262 struct ConvertLinalgToStandardPass 263 : public ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> { 264 void runOnOperation() override; 265 }; 266 } // namespace 267 268 void ConvertLinalgToStandardPass::runOnOperation() { 269 auto module = getOperation(); 270 ConversionTarget target(getContext()); 271 target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect>(); 272 target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>(); 273 target.addLegalOp<linalg::TransposeOp, linalg::ReshapeOp, linalg::RangeOp>(); 274 OwningRewritePatternList patterns; 275 populateLinalgToStandardConversionPatterns(patterns, &getContext()); 276 if (failed(applyFullConversion(module, target, patterns))) 277 signalPassFailure(); 278 } 279 280 std::unique_ptr<OperationPass<ModuleOp>> 281 mlir::createConvertLinalgToStandardPass() { 282 return std::make_unique<ConvertLinalgToStandardPass>(); 283 } 284