1 //===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" 10 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" 11 #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" 12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 13 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 14 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 18 #include "mlir/Dialect/Linalg/Passes.h" 19 #include "mlir/Dialect/Linalg/Utils/Intrinsics.h" 20 #include "mlir/EDSC/Builders.h" 21 #include "mlir/EDSC/Intrinsics.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/AffineMap.h" 24 #include "mlir/IR/Attributes.h" 25 #include "mlir/IR/Builders.h" 26 #include "mlir/IR/MLIRContext.h" 27 #include "mlir/IR/Module.h" 28 #include "mlir/IR/Operation.h" 29 #include "mlir/IR/PatternMatch.h" 30 #include "mlir/IR/StandardTypes.h" 31 #include "mlir/IR/Types.h" 32 #include "mlir/Pass/Pass.h" 33 #include "mlir/Pass/PassManager.h" 34 #include "mlir/Support/LogicalResult.h" 35 #include "mlir/Transforms/DialectConversion.h" 36 #include "mlir/Transforms/Passes.h" 37 38 #include "llvm/ADT/SetVector.h" 39 #include "llvm/IR/DerivedTypes.h" 40 #include "llvm/IR/Module.h" 41 #include "llvm/IR/Type.h" 42 #include "llvm/Support/Allocator.h" 43 #include "llvm/Support/ErrorHandling.h" 44 45 using namespace mlir; 46 using namespace mlir::edsc; 47 using namespace mlir::edsc::intrinsics; 48 using namespace mlir::LLVM; 49 
using namespace mlir::linalg;
using namespace mlir::linalg::intrinsics;

// Short EDSC builder aliases so the lowering code below reads like the LLVM
// dialect IR it emits.
using add = ValueBuilder<mlir::LLVM::AddOp>;
using addi = ValueBuilder<mlir::AddIOp>;
using bitcast = ValueBuilder<mlir::LLVM::BitcastOp>;
using cmpi = ValueBuilder<mlir::CmpIOp>;
using constant = ValueBuilder<mlir::LLVM::ConstantOp>;
using extractvalue = ValueBuilder<mlir::LLVM::ExtractValueOp>;
using gep = ValueBuilder<mlir::LLVM::GEPOp>;
using insertvalue = ValueBuilder<mlir::LLVM::InsertValueOp>;
using llvm_call = OperationBuilder<mlir::LLVM::CallOp>;
using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
using llvm_load = ValueBuilder<LLVM::LoadOp>;
using llvm_store = OperationBuilder<LLVM::StoreOp>;
using llvm_select = ValueBuilder<LLVM::SelectOp>;
using mul = ValueBuilder<mlir::LLVM::MulOp>;
using ptrtoint = ValueBuilder<mlir::LLVM::PtrToIntOp>;
using sub = ValueBuilder<mlir::LLVM::SubOp>;
using llvm_undef = ValueBuilder<mlir::LLVM::UndefOp>;
using urem = ValueBuilder<mlir::LLVM::URemOp>;
using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
using llvm_return = OperationBuilder<LLVM::ReturnOp>;

namespace {

// Returns the LLVM pointer type to the converted element type of
// `containerType` (any type exposing getElementType()).
template <typename T>
static LLVMType getPtrToElementType(T containerType,
                                    LLVMTypeConverter &lowering) {
  return lowering.convertType(containerType.getElementType())
      .template cast<LLVMType>()
      .getPointerTo();
}

// Convert the given type to the LLVM IR Dialect type. The following
// conversions are supported:
// - an Index type is converted into an LLVM integer type with pointer
//   bitwidth (analogous to intptr_t in C);
// - an Integer type is converted into an LLVM integer type of the same width;
// - an F32 type is converted into an LLVM float type
// - a Buffer, Range or View is converted into an LLVM structure type
//   containing the respective dynamic values.
// NOTE(review): convertLinalgType below only handles RangeType; the other
// cases listed above are covered by LLVMTypeConverter, which
// LinalgTypeConverter::convertType tries first.
91 static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) { 92 auto *context = t.getContext(); 93 auto int64Ty = lowering.convertType(IntegerType::get(64, context)) 94 .cast<LLVM::LLVMType>(); 95 96 // Range descriptor contains the range bounds and the step as 64-bit integers. 97 // 98 // struct { 99 // int64_t min; 100 // int64_t max; 101 // int64_t step; 102 // }; 103 if (t.isa<RangeType>()) 104 return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty); 105 106 return Type(); 107 } 108 109 /// EDSC-compatible wrapper for MemRefDescriptor. 110 class BaseViewConversionHelper { 111 public: 112 BaseViewConversionHelper(Type type) 113 : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {} 114 115 BaseViewConversionHelper(Value v) : d(v) {} 116 117 /// Wrappers around MemRefDescriptor that use EDSC builder and location. 118 Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); } 119 void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); } 120 Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); } 121 void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); } 122 Value offset() { return d.offset(rewriter(), loc()); } 123 void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); } 124 Value size(unsigned i) { return d.size(rewriter(), loc(), i); } 125 void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); } 126 void setConstantSize(unsigned i, int64_t v) { 127 d.setConstantSize(rewriter(), loc(), i, v); 128 } 129 Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); } 130 void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); } 131 void setConstantStride(unsigned i, int64_t v) { 132 d.setConstantStride(rewriter(), loc(), i, v); 133 } 134 135 operator Value() { return d; } 136 137 private: 138 OpBuilder &rewriter() { return ScopedContext::getBuilder(); } 139 Location loc() { return ScopedContext::getLocation(); } 140 141 MemRefDescriptor d; 142 }; 143 144 // 
RangeOp creates a new range descriptor. 145 class RangeOpConversion : public LLVMOpLowering { 146 public: 147 explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) 148 : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {} 149 150 PatternMatchResult 151 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 152 ConversionPatternRewriter &rewriter) const override { 153 auto rangeOp = cast<RangeOp>(op); 154 auto rangeDescriptorTy = 155 convertLinalgType(rangeOp.getResult().getType(), lowering); 156 157 edsc::ScopedContext context(rewriter, op->getLoc()); 158 159 // Fill in an aggregate value of the descriptor. 160 RangeOpOperandAdaptor adaptor(operands); 161 Value desc = llvm_undef(rangeDescriptorTy); 162 desc = insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0)); 163 desc = insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1)); 164 desc = insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2)); 165 rewriter.replaceOp(op, desc); 166 return matchSuccess(); 167 } 168 }; 169 170 // ReshapeOp creates a new view descriptor of the proper rank. 171 // For now, the only conversion supported is for target MemRef with static sizes 172 // and strides. 
173 class ReshapeOpConversion : public LLVMOpLowering { 174 public: 175 explicit ReshapeOpConversion(MLIRContext *context, 176 LLVMTypeConverter &lowering_) 177 : LLVMOpLowering(ReshapeOp::getOperationName(), context, lowering_) {} 178 179 PatternMatchResult 180 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 181 ConversionPatternRewriter &rewriter) const override { 182 auto reshapeOp = cast<ReshapeOp>(op); 183 MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>(); 184 185 if (!dstType.hasStaticShape()) 186 return matchFailure(); 187 188 int64_t offset; 189 SmallVector<int64_t, 4> strides; 190 auto res = getStridesAndOffset(dstType, strides, offset); 191 if (failed(res) || llvm::any_of(strides, [](int64_t val) { 192 return ShapedType::isDynamicStrideOrOffset(val); 193 })) 194 return matchFailure(); 195 196 edsc::ScopedContext context(rewriter, op->getLoc()); 197 ReshapeOpOperandAdaptor adaptor(operands); 198 BaseViewConversionHelper baseDesc(adaptor.view()); 199 BaseViewConversionHelper desc(lowering.convertType(dstType)); 200 desc.setAllocatedPtr(baseDesc.allocatedPtr()); 201 desc.setAlignedPtr(baseDesc.alignedPtr()); 202 desc.setOffset(baseDesc.offset()); 203 for (auto en : llvm::enumerate(dstType.getShape())) 204 desc.setConstantSize(en.index(), en.value()); 205 for (auto en : llvm::enumerate(strides)) 206 desc.setConstantStride(en.index(), en.value()); 207 rewriter.replaceOp(op, {desc}); 208 return matchSuccess(); 209 } 210 }; 211 212 /// Conversion pattern that transforms a linalg.slice op into: 213 /// 1. A function entry `alloca` operation to allocate a ViewDescriptor. 214 /// 2. A load of the ViewDescriptor from the pointer allocated in 1. 215 /// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size 216 /// and stride corresponding to the region of memory within the bounds of 217 /// the parent view. 218 /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer. 
/// The linalg.slice op is replaced by the alloca'ed pointer.
/// NOTE(review): the implementation below builds the result descriptor with
/// undef + insertvalue (via BaseViewConversionHelper) rather than an explicit
/// alloca/load/store sequence; the steps above describe the logical effect.
class SliceOpConversion : public LLVMOpLowering {
public:
  explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    edsc::ScopedContext context(rewriter, op->getLoc());
    SliceOpOperandAdaptor adaptor(operands);
    // Descriptor of the base view the slice is taken from.
    BaseViewConversionHelper baseDesc(adaptor.view());

    auto sliceOp = cast<SliceOp>(op);
    auto memRefType = sliceOp.getBaseViewType();
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64))
                       .cast<LLVM::LLVMType>();

    // Descriptor of the (possibly rank-reduced) result view.
    BaseViewConversionHelper desc(
        lowering.convertType(sliceOp.getShapedType()));

    // TODO(ntv): extract sizes and emit asserts.
    // Cache the base view strides; they are reused both for the offset
    // computation and for the strides of the result.
    SmallVector<Value, 4> strides(memRefType.getRank());
    for (int i = 0, e = memRefType.getRank(); i < e; ++i)
      strides[i] = baseDesc.stride(i);

    // Helper building the position attribute for extractvalue.
    auto pos = [&rewriter](ArrayRef<int64_t> values) {
      return rewriter.getI64ArrayAttr(values);
    };

    // Compute base offset: offset(base) + sum_i min_i * stride_i, where min_i
    // is the scalar index for a point indexing, or field 0 (min) of the
    // {min, max, step} range descriptor for a range indexing.
    Value baseOffset = baseDesc.offset();
    for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
      Value indexing = adaptor.indexings()[i];
      Value min = indexing;
      if (sliceOp.indexing(i).getType().isa<RangeType>())
        min = extractvalue(int64Ty, indexing, pos(0));
      baseOffset = add(baseOffset, mul(min, strides[i]));
    }

    // Insert the base and aligned pointers.
    desc.setAllocatedPtr(baseDesc.allocatedPtr());
    desc.setAlignedPtr(baseDesc.alignedPtr());

    // Insert base offset.
    desc.setOffset(baseOffset);

    // Corner case, no sizes or strides: early return the descriptor.
    // (comma-return: emit the replacement, then report success.)
    if (sliceOp.getShapedType().getRank() == 0)
      return rewriter.replaceOp(op, {desc}), matchSuccess();

    Value zero =
        constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
    // Compute and insert view sizes (max - min along the range) and strides.
    // Skip the non-range operands as they will be projected away from the view.
    int numNewDims = 0;
    for (auto en : llvm::enumerate(sliceOp.indexings())) {
      Value indexing = en.value();
      if (indexing.getType().isa<RangeType>()) {
        int rank = en.index();
        // Unpack the {min, max, step} fields of the range descriptor.
        Value rangeDescriptor = adaptor.indexings()[rank];
        Value min = extractvalue(int64Ty, rangeDescriptor, pos(0));
        Value max = extractvalue(int64Ty, rangeDescriptor, pos(1));
        Value step = extractvalue(int64Ty, rangeDescriptor, pos(2));
        Value baseSize = baseDesc.size(rank);

        // Bound upper by base view upper bound.
        max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
                          baseSize);
        Value size = sub(max, min);
        // Bound lower by zero.
        size =
            llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
        // Striding through the range scales the base stride by the step.
        Value stride = mul(strides[rank], step);
        desc.setSize(numNewDims, size);
        desc.setStride(numNewDims, stride);
        ++numNewDims;
      }
    }

    rewriter.replaceOp(op, {desc});
    return matchSuccess();
  }
};

/// Conversion pattern that transforms a linalg.transpose op into:
///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
///   2. A load of the ViewDescriptor from the pointer allocated in 1.
///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
///      and stride. Size and stride are permutations of the original values.
///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
/// The linalg.transpose op is replaced by the alloca'ed pointer.
311 class TransposeOpConversion : public LLVMOpLowering { 312 public: 313 explicit TransposeOpConversion(MLIRContext *context, 314 LLVMTypeConverter &lowering_) 315 : LLVMOpLowering(TransposeOp::getOperationName(), context, lowering_) {} 316 317 PatternMatchResult 318 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 319 ConversionPatternRewriter &rewriter) const override { 320 // Initialize the common boilerplate and alloca at the top of the FuncOp. 321 edsc::ScopedContext context(rewriter, op->getLoc()); 322 TransposeOpOperandAdaptor adaptor(operands); 323 BaseViewConversionHelper baseDesc(adaptor.view()); 324 325 auto transposeOp = cast<TransposeOp>(op); 326 // No permutation, early exit. 327 if (transposeOp.permutation().isIdentity()) 328 return rewriter.replaceOp(op, {baseDesc}), matchSuccess(); 329 330 BaseViewConversionHelper desc( 331 lowering.convertType(transposeOp.getShapedType())); 332 333 // Copy the base and aligned pointers from the old descriptor to the new 334 // one. 335 desc.setAllocatedPtr(baseDesc.allocatedPtr()); 336 desc.setAlignedPtr(baseDesc.alignedPtr()); 337 338 // Copy the offset pointer from the old descriptor to the new one. 339 desc.setOffset(baseDesc.offset()); 340 341 // Iterate over the dimensions and apply size/stride permutation. 342 for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) { 343 int sourcePos = en.index(); 344 int targetPos = en.value().cast<AffineDimExpr>().getPosition(); 345 desc.setSize(targetPos, baseDesc.size(sourcePos)); 346 desc.setStride(targetPos, baseDesc.stride(sourcePos)); 347 } 348 349 rewriter.replaceOp(op, {desc}); 350 return matchSuccess(); 351 } 352 }; 353 354 // YieldOp produces and LLVM::ReturnOp. 
class YieldOpConversion : public LLVMOpLowering {
public:
  explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(YieldOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // linalg.yield maps 1-1 onto llvm.return with the converted operands.
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
    return matchSuccess();
  }
};

// Returns the operand types the library call for `LinalgOp` expects; by
// default these are exactly the operand types of `op`.
template <typename LinalgOp>
static SmallVector<Type, 4> ExtractOperandTypes(Operation *op) {
  return SmallVector<Type, 4>{op->getOperandTypes()};
}

// IndexedGenericOp is special-cased: its library call additionally takes one
// index per loop, prepended to the regular operand types.
template <>
SmallVector<Type, 4> ExtractOperandTypes<IndexedGenericOp>(Operation *op) {
  auto ctx = op->getContext();
  auto indexedGenericOp = cast<IndexedGenericOp>(op);
  auto numLoops = indexedGenericOp.getNumLoops();

  SmallVector<Type, 4> result;
  result.reserve(numLoops + op->getNumOperands());
  for (unsigned i = 0; i < numLoops; ++i) {
    result.push_back(IndexType::get(ctx));
  }
  for (auto type : op->getOperandTypes()) {
    result.push_back(type);
  }
  return result;
}

// Get a SymbolRefAttr containing the library function name for the LinalgOp.
// If the library function does not exist, insert a declaration.
template <typename LinalgOp>
static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
                                                 PatternRewriter &rewriter) {
  auto linalgOp = cast<LinalgOp>(op);
  auto fnName = linalgOp.getLibraryCallName();
  // A null attr is returned when no library call is defined; callers treat
  // this as a match failure.
  if (fnName.empty()) {
    op->emitWarning("No library call defined for: ") << *op;
    return {};
  }

  // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
  FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
  auto module = op->getParentOfType<ModuleOp>();
  // Reuse the symbol if it is already declared in the module.
  if (module.lookupSymbol(fnName)) {
    return fnNameAttr;
  }

  SmallVector<Type, 4> inputTypes(ExtractOperandTypes<LinalgOp>(op));
  assert(op->getNumResults() == 0 &&
         "Library call for linalg operation can be generated only for ops that "
         "have void return types");
  auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());

  OpBuilder::InsertionGuard guard(rewriter);
  // Insert before module terminator.
  rewriter.setInsertionPoint(module.getBody(),
                             std::prev(module.getBody()->end()));
  rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType,
                          ArrayRef<NamedAttribute>{});
  return fnNameAttr;
}

} // namespace

// Try the standard-to-LLVM conversion first and fall back to the
// Linalg-specific lowering (currently only RangeType) when it does not apply.
Type LinalgTypeConverter::convertType(Type t) {
  if (auto result = LLVMTypeConverter::convertType(t))
    return result;
  return convertLinalgType(t, *this);
}

namespace {

// LinalgOpConversion<LinalgOp> creates a new call to the
// `LinalgOp::getLibraryCallName()` function.
// The implementation of the function can be either in the same module or in an
// externally linked library.
template <typename LinalgOp>
class LinalgOpConversion : public OpRewritePattern<LinalgOp> {
public:
  using OpRewritePattern<LinalgOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(LinalgOp op,
                                     PatternRewriter &rewriter) const override {
    // Resolve (or declare) the library function symbol for this op.
    auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
    if (!libraryCallName)
      return this->matchFailure();

    rewriter.replaceOpWithNewOp<mlir::CallOp>(
        op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
    return this->matchSuccess();
  }
};

/// Conversion pattern specialization for CopyOp.
/// This kicks in when both input
/// and output permutations are left unspecified or are the identity.
template <> class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> {
public:
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(CopyOp op,
                                     PatternRewriter &rewriter) const override {
    // Copies with non-identity permutations are handled by
    // CopyTransposeConversion below.
    auto inputPerm = op.inputPermutation();
    if (inputPerm.hasValue() && !inputPerm->isIdentity())
      return matchFailure();
    auto outputPerm = op.outputPermutation();
    if (outputPerm.hasValue() && !outputPerm->isIdentity())
      return matchFailure();

    auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
    if (!libraryCallName)
      return matchFailure();

    rewriter.replaceOpWithNewOp<mlir::CallOp>(
        op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
    return matchSuccess();
  }
};

/// Conversion pattern specialization for IndexedGenericOp.
template <>
class LinalgOpConversion<IndexedGenericOp>
    : public OpRewritePattern<IndexedGenericOp> {
public:
  using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(IndexedGenericOp op,
                                     PatternRewriter &rewriter) const override {
    auto libraryCallName =
        getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
    if (!libraryCallName)
      return this->matchFailure();

    // TODO(pifon, ntv): Use induction variables values instead of zeros, when
    // IndexedGenericOp is tiled.
    auto zero = rewriter.create<mlir::ConstantOp>(
        op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
    auto indexedGenericOp = cast<IndexedGenericOp>(op);
    auto numLoops = indexedGenericOp.getNumLoops();
    // Prepend one index operand per loop, mirroring the call signature built
    // by ExtractOperandTypes<IndexedGenericOp>.
    SmallVector<Value, 4> operands;
    operands.reserve(numLoops + op.getNumOperands());
    for (unsigned i = 0; i < numLoops; ++i) {
      operands.push_back(zero);
    }
    for (auto operand : op.getOperands()) {
      operands.push_back(operand);
    }
    rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
                                              ArrayRef<Type>{}, operands);
    return this->matchSuccess();
  }
};

/// A non-conversion rewrite pattern kicks in to convert CopyOp with
/// permutations into a sequence of TransposeOp and permutation-free CopyOp.
/// This interplays together with TransposeOpConversion and
/// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
public:
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(CopyOp op,
                                     PatternRewriter &rewriter) const override {
    Value in = op.input(), out = op.output();

    // If either inputPerm or outputPerm are non-identities, insert transposes.
    auto inputPerm = op.inputPermutation();
    if (inputPerm.hasValue() && !inputPerm->isIdentity())
      in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
                                                AffineMapAttr::get(*inputPerm));
    auto outputPerm = op.outputPermutation();
    if (outputPerm.hasValue() && !outputPerm->isIdentity())
      out = rewriter.create<linalg::TransposeOp>(
          op.getLoc(), out, AffineMapAttr::get(*outputPerm));

    // If nothing was transposed, fail and let the conversion kick in.
    if (in == op.input() && out == op.output())
      return matchFailure();

    rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
    return matchSuccess();
  }
};

/// Populate the given list with patterns that convert from Linalg to Standard.
static void
populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *ctx) {
  // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant
  // attribute values such as kernel striding and dilation.
  patterns.insert<CopyTransposeConversion, LinalgOpConversion<ConvOp>,
                  LinalgOpConversion<CopyOp>, LinalgOpConversion<DotOp>,
                  LinalgOpConversion<FillOp>, LinalgOpConversion<GenericOp>,
                  LinalgOpConversion<IndexedGenericOp>,
                  LinalgOpConversion<MatmulOp>, LinalgOpConversion<MatvecOp>>(
      ctx);
}

} // namespace

/// Populate the given list with patterns that convert from Linalg to LLVM.
void mlir::populateLinalgToLLVMConversionPatterns(
    LinalgTypeConverter &converter, OwningRewritePatternList &patterns,
    MLIRContext *ctx) {
  patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
                  TransposeOpConversion, YieldOpConversion>(ctx, converter);
}

namespace {
// Module pass driving the full Linalg (via Standard) to LLVM dialect
// conversion.
struct ConvertLinalgToLLVMPass : public ModulePass<ConvertLinalgToLLVMPass> {
  void runOnModule() override;
};
} // namespace

void ConvertLinalgToLLVMPass::runOnModule() {
  auto module = getModule();

  // Convert to the LLVM IR dialect using the converter defined above.
  OwningRewritePatternList patterns;
  LinalgTypeConverter converter(&getContext());
  // Gather the whole lowering pipeline: affine/loop -> std, std/vector ->
  // LLVM, plus the Linalg-specific patterns defined in this file.
  populateAffineToStdConversionPatterns(patterns, &getContext());
  populateLoopToStdConversionPatterns(patterns, &getContext());
  populateStdToLLVMConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateLinalgToStandardConversionPatterns(patterns, &getContext());
  populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());

  ConversionTarget target(getContext());
  target.addLegalDialect<LLVM::LLVMDialect>();
  // FuncOps are legal only once their signature is expressible in the
  // converted type system.
  target.addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
  if (failed(applyFullConversion(module, target, patterns, &converter)))
    signalPassFailure();
}

std::unique_ptr<OpPassBase<ModuleOp>>
mlir::linalg::createConvertLinalgToLLVMPass() {
  return std::make_unique<ConvertLinalgToLLVMPass>();
}

static PassRegistration<ConvertLinalgToLLVMPass> pass(
    "convert-linalg-to-llvm",
    "Convert the operations from the linalg dialect into the LLVM dialect");