//===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::LLVM;
using namespace mlir::linalg;

using llvm_add = ValueBuilder<LLVM::AddOp>;
using llvm_bitcast = ValueBuilder<LLVM::BitcastOp>;
using llvm_constant = ValueBuilder<LLVM::ConstantOp>;
using llvm_extractvalue = ValueBuilder<LLVM::ExtractValueOp>;
using llvm_gep = ValueBuilder<LLVM::GEPOp>;
using llvm_insertvalue = ValueBuilder<LLVM::InsertValueOp>;
using llvm_call = OperationBuilder<LLVM::CallOp>;
using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
using llvm_load = ValueBuilder<LLVM::LoadOp>;
using llvm_store = OperationBuilder<LLVM::StoreOp>;
using llvm_select = ValueBuilder<LLVM::SelectOp>;
using llvm_mul = ValueBuilder<LLVM::MulOp>;
using llvm_ptrtoint = ValueBuilder<LLVM::PtrToIntOp>;
using llvm_sub = ValueBuilder<LLVM::SubOp>;
using llvm_undef = ValueBuilder<LLVM::UndefOp>;
using llvm_urem = ValueBuilder<LLVM::URemOp>;
using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
using llvm_return = OperationBuilder<LLVM::ReturnOp>;

template <typename T>
static LLVMType getPtrToElementType(T containerType,
                                    LLVMTypeConverter &lowering) {
  return lowering.convertType(containerType.getElementType())
      .template cast<LLVMType>()
      .getPointerTo();
}

/// Convert the given range descriptor type to the LLVMIR dialect.
/// Range descriptor contains the range bounds and the step as 64-bit integers.
///
/// struct {
///   int64_t min;
///   int64_t max;
///   int64_t step;
/// };
static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
  auto *context = t.getContext();
  auto int64Ty = converter.convertType(IntegerType::get(64, context))
                     .cast<LLVM::LLVMType>();
  return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
}
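// For illustration only (a sketch; the textual LLVM dialect type syntax varies
// across MLIR versions): after this conversion, a value of type
//   !linalg.range
// is carried around as the LLVM dialect struct type
//   !llvm<"{ i64, i64, i64 }">
// i.e. a plain { min, max, step } aggregate of 64-bit integers.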
namespace {
/// EDSC-compatible wrapper for MemRefDescriptor.
class BaseViewConversionHelper {
public:
  BaseViewConversionHelper(Type type)
      : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}

  BaseViewConversionHelper(Value v) : d(v) {}

  /// Wrappers around MemRefDescriptor that use EDSC builder and location.
  Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
  void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
  Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
  void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
  Value offset() { return d.offset(rewriter(), loc()); }
  void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
  Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
  void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
  void setConstantSize(unsigned i, int64_t v) {
    d.setConstantSize(rewriter(), loc(), i, v);
  }
  Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
  void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
  void setConstantStride(unsigned i, int64_t v) {
    d.setConstantStride(rewriter(), loc(), i, v);
  }

  operator Value() { return d; }

private:
  OpBuilder &rewriter() { return ScopedContext::getBuilderRef(); }
  Location loc() { return ScopedContext::getLocation(); }

  MemRefDescriptor d;
};

// RangeOp creates a new range descriptor.
class RangeOpConversion : public ConvertToLLVMPattern {
public:
  explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto rangeOp = cast<RangeOp>(op);
    auto rangeDescriptorTy =
        convertRangeType(rangeOp.getType().cast<RangeType>(), typeConverter);

    edsc::ScopedContext context(rewriter, op->getLoc());

    // Fill in an aggregate value of the descriptor.
    RangeOpOperandAdaptor adaptor(operands);
    Value desc = llvm_undef(rangeDescriptorTy);
    desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
    desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
    desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
    rewriter.replaceOp(op, desc);
    return success();
  }
};
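// For example, with the pattern above a range creation along the lines of
//   %r = linalg.range %min : %max : %step : !linalg.range
// lowers, roughly (a sketch; value names and type syntax are illustrative), to
//   %0 = llvm.mlir.undef : !llvm<"{ i64, i64, i64 }">
//   %1 = llvm.insertvalue %min, %0[0] : !llvm<"{ i64, i64, i64 }">
//   %2 = llvm.insertvalue %max, %1[1] : !llvm<"{ i64, i64, i64 }">
//   %3 = llvm.insertvalue %step, %2[2] : !llvm<"{ i64, i64, i64 }">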
// ReshapeOp creates a new view descriptor of the proper rank.
// For now, the only conversion supported is for target MemRef with static
// sizes and strides.
class ReshapeOpConversion : public ConvertToLLVMPattern {
public:
  explicit ReshapeOpConversion(MLIRContext *context,
                               LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(ReshapeOp::getOperationName(), context,
                             lowering_) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reshapeOp = cast<ReshapeOp>(op);
    MemRefType dstType = reshapeOp.getResultType();

    if (!dstType.hasStaticShape())
      return failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto res = getStridesAndOffset(dstType, strides, offset);
    if (failed(res) || llvm::any_of(strides, [](int64_t val) {
          return ShapedType::isDynamicStrideOrOffset(val);
        }))
      return failure();

    edsc::ScopedContext context(rewriter, op->getLoc());
    ReshapeOpOperandAdaptor adaptor(operands);
    BaseViewConversionHelper baseDesc(adaptor.src());
    BaseViewConversionHelper desc(typeConverter.convertType(dstType));
    desc.setAllocatedPtr(baseDesc.allocatedPtr());
    desc.setAlignedPtr(baseDesc.alignedPtr());
    desc.setOffset(baseDesc.offset());
    for (auto en : llvm::enumerate(dstType.getShape()))
      desc.setConstantSize(en.index(), en.value());
    for (auto en : llvm::enumerate(strides))
      desc.setConstantStride(en.index(), en.value());
    rewriter.replaceOp(op, {desc});
    return success();
  }
};
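// As a sketch of the effect (illustrative shapes, not taken from a test):
// reshaping a memref<4x8xf32> view into a memref<32xf32> view with the pattern
// above copies the allocated/aligned pointers and the offset from the source
// descriptor, then fills in the target's static layout as constants, here
// size[0] = 32 and stride[0] = 1. No data moves; only the descriptor is
// rebuilt.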
/// Conversion pattern that transforms a linalg.slice op into:
///   1. An "undef" value for the ViewDescriptor.
///   2. Updates to the ViewDescriptor to introduce the data ptr, offset, size
///      and stride corresponding to the region of memory within the bounds of
///      the parent view.
/// The linalg.slice op is replaced by the resulting view descriptor value.
class SliceOpConversion : public ConvertToLLVMPattern {
public:
  explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    edsc::ScopedContext context(rewriter, op->getLoc());
    SliceOpOperandAdaptor adaptor(operands);
    BaseViewConversionHelper baseDesc(adaptor.view());

    auto sliceOp = cast<SliceOp>(op);
    auto memRefType = sliceOp.getBaseViewType();
    auto int64Ty = typeConverter.convertType(rewriter.getIntegerType(64))
                       .cast<LLVM::LLVMType>();

    BaseViewConversionHelper desc(
        typeConverter.convertType(sliceOp.getShapedType()));

    // TODO(ntv): extract sizes and emit asserts.
    SmallVector<Value, 4> strides(memRefType.getRank());
    for (int i = 0, e = memRefType.getRank(); i < e; ++i)
      strides[i] = baseDesc.stride(i);

    auto pos = [&rewriter](ArrayRef<int64_t> values) {
      return rewriter.getI64ArrayAttr(values);
    };

    // Compute base offset.
    Value baseOffset = baseDesc.offset();
    for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
      Value indexing = adaptor.indexings()[i];
      Value min = indexing;
      if (sliceOp.indexing(i).getType().isa<RangeType>())
        min = llvm_extractvalue(int64Ty, indexing, pos(0));
      baseOffset = llvm_add(baseOffset, llvm_mul(min, strides[i]));
    }

    // Insert the base and aligned pointers.
    desc.setAllocatedPtr(baseDesc.allocatedPtr());
    desc.setAlignedPtr(baseDesc.alignedPtr());

    // Insert base offset.
    desc.setOffset(baseOffset);

    // Corner case, no sizes or strides: early return the descriptor.
    if (sliceOp.getShapedType().getRank() == 0)
      return rewriter.replaceOp(op, {desc}), success();

    Value zero = llvm_constant(
        int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
    // Compute and insert view sizes (max - min along the range) and strides.
    // Skip the non-range operands as they will be projected away from the view.
    int numNewDims = 0;
    for (auto en : llvm::enumerate(sliceOp.indexings())) {
      Value indexing = en.value();
      if (indexing.getType().isa<RangeType>()) {
        int rank = en.index();
        Value rangeDescriptor = adaptor.indexings()[rank];
        Value min = llvm_extractvalue(int64Ty, rangeDescriptor, pos(0));
        Value max = llvm_extractvalue(int64Ty, rangeDescriptor, pos(1));
        Value step = llvm_extractvalue(int64Ty, rangeDescriptor, pos(2));
        Value baseSize = baseDesc.size(rank);

        // Bound upper by base view upper bound.
        max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
                          baseSize);
        Value size = llvm_sub(max, min);
        // Bound lower by zero.
        size =
            llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
        Value stride = llvm_mul(strides[rank], step);
        desc.setSize(numNewDims, size);
        desc.setStride(numNewDims, stride);
        ++numNewDims;
      }
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};
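// Worked example with made-up values: slicing a rank-2 base view that has
// offset 0 and strides [8, 1] by the ranges {min=2, max=6, step=1} and
// {min=4, max=8, step=2} produces (assuming both maxima lie within the base
// sizes)
//   offset  = 0 + 2 * 8 + 4 * 1 = 20
//   sizes   = [6 - 2, 8 - 4] = [4, 4]     (each clamped below at zero)
//   strides = [8 * 1, 1 * 2] = [8, 2]
// which is exactly the descriptor update sequence emitted above.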
/// Conversion pattern that transforms a linalg.transpose op into:
///   1. An "undef" value for the new ViewDescriptor.
///   2. A copy of the data pointers and offset from the old descriptor.
///   3. Updates to the new ViewDescriptor's sizes and strides, which are
///      permutations of the original values.
/// The linalg.transpose op is replaced by the resulting view descriptor value.
class TransposeOpConversion : public ConvertToLLVMPattern {
public:
  explicit TransposeOpConversion(MLIRContext *context,
                                 LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(TransposeOp::getOperationName(), context,
                             lowering_) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Initialize the EDSC context and wrap the incoming view descriptor.
    edsc::ScopedContext context(rewriter, op->getLoc());
    TransposeOpOperandAdaptor adaptor(operands);
    BaseViewConversionHelper baseDesc(adaptor.view());

    auto transposeOp = cast<TransposeOp>(op);
    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(op, {baseDesc}), success();

    BaseViewConversionHelper desc(
        typeConverter.convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    desc.setAllocatedPtr(baseDesc.allocatedPtr());
    desc.setAlignedPtr(baseDesc.alignedPtr());

    // Copy the offset from the old descriptor to the new one.
    desc.setOffset(baseDesc.offset());

    // Iterate over the dimensions and apply size/stride permutation. For
    // example, a permutation map (i, j) -> (j, i) moves size/stride 0 of the
    // source into slot 1 of the result and vice versa.
    for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      desc.setSize(targetPos, baseDesc.size(sourcePos));
      desc.setStride(targetPos, baseDesc.stride(sourcePos));
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};

// YieldOp produces an LLVM::ReturnOp.
class YieldOpConversion : public ConvertToLLVMPattern {
public:
  explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(YieldOp::getOperationName(), context, lowering_) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
    return success();
  }
};
} // namespace

/// Populate the given list with patterns that convert from Linalg to LLVM.
void mlir::populateLinalgToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    MLIRContext *ctx) {
  patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
                  TransposeOpConversion, YieldOpConversion>(ctx, converter);

  // Populate the type conversions for the linalg types.
  converter.addConversion(
      [&](RangeType type) { return convertRangeType(type, converter); });
}

namespace {
struct ConvertLinalgToLLVMPass
    : public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> {
  void runOnOperation() override;
};
} // namespace

void ConvertLinalgToLLVMPass::runOnOperation() {
  auto module = getOperation();

  // Convert to the LLVM IR dialect using the converter defined above.
  OwningRewritePatternList patterns;
  LLVMTypeConverter converter(&getContext());
  populateAffineToStdConversionPatterns(patterns, &getContext());
  populateLoopToStdConversionPatterns(patterns, &getContext());
  populateStdToLLVMConversionPatterns(converter, patterns);
  populateVectorToSCFConversionPatterns(patterns, &getContext());
  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());

  LLVMConversionTarget target(getContext());
  target.addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
  if (failed(applyFullConversion(module, target, patterns, &converter)))
    signalPassFailure();
}

std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() {
  return std::make_unique<ConvertLinalgToLLVMPass>();
}
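// A minimal usage sketch (assuming a ModuleOp has already been parsed into
// `module`; error handling elided):
//
//   PassManager pm(module.getContext());
//   pm.addPass(createConvertLinalgToLLVMPass());
//   if (failed(pm.run(module)))
//     llvm::errs() << "failed to lower Linalg to the LLVM dialect\n";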