1 //===- SCFToSPIRV.cpp - Convert SCF ops to SPIR-V dialect -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the conversion patterns from SCF ops to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" 13 #include "mlir/Dialect/SCF/SCF.h" 14 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/SPIRVLowering.h" 16 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 17 #include "mlir/IR/BuiltinOps.h" 18 19 using namespace mlir; 20 21 namespace mlir { 22 struct ScfToSPIRVContextImpl { 23 // Map between the spirv region control flow operation (spv.loop or 24 // spv.selection) to the VariableOp created to store the region results. The 25 // order of the VariableOp matches the order of the results. 26 DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars; 27 }; 28 } // namespace mlir 29 30 /// We use ScfToSPIRVContext to store information about the lowering of the scf 31 /// region that need to be used later on. When we lower scf.for/scf.if we create 32 /// VariableOp to store the results. We need to keep track of the VariableOp 33 /// created as we need to insert stores into them when lowering Yield. Those 34 /// StoreOp cannot be created earlier as they may use a different type than 35 /// yield operands. 36 ScfToSPIRVContext::ScfToSPIRVContext() { 37 impl = std::make_unique<ScfToSPIRVContextImpl>(); 38 } 39 ScfToSPIRVContext::~ScfToSPIRVContext() = default; 40 41 namespace { 42 /// Common class for all vector to GPU patterns. 43 template <typename OpTy> 44 class SCFToSPIRVPattern : public SPIRVOpLowering<OpTy> { 45 public: 46 SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter, 47 ScfToSPIRVContextImpl *scfToSPIRVContext) 48 : SPIRVOpLowering<OpTy>::SPIRVOpLowering(context, converter), 49 scfToSPIRVContext(scfToSPIRVContext) {} 50 51 protected: 52 ScfToSPIRVContextImpl *scfToSPIRVContext; 53 }; 54 55 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp. 56 class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> { 57 public: 58 using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern; 59 60 LogicalResult 61 matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands, 62 ConversionPatternRewriter &rewriter) const override; 63 }; 64 65 /// Pattern to convert a scf::IfOp within kernel functions into 66 /// spirv::SelectionOp. 67 class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> { 68 public: 69 using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern; 70 71 LogicalResult 72 matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands, 73 ConversionPatternRewriter &rewriter) const override; 74 }; 75 76 class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> { 77 public: 78 using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern; 79 80 LogicalResult 81 matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands, 82 ConversionPatternRewriter &rewriter) const override; 83 }; 84 } // namespace 85 86 /// Helper function to replaces SCF op outputs with SPIR-V variable loads. 87 /// We create VariableOp to handle the results value of the control flow region. 88 /// spv.loop/spv.selection currently don't yield value. Right after the loop 89 /// we load the value from the allocation and use it as the SCF op result. 90 template <typename ScfOp, typename OpTy> 91 static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp, 92 SPIRVTypeConverter &typeConverter, 93 ConversionPatternRewriter &rewriter, 94 ScfToSPIRVContextImpl *scfToSPIRVContext, 95 ArrayRef<Type> returnTypes) { 96 97 Location loc = scfOp.getLoc(); 98 auto &allocas = scfToSPIRVContext->outputVars[newOp]; 99 // Clearing the allocas is necessary in case a dialect conversion path failed 100 // previously, and this is the second attempt of this conversion. 101 allocas.clear(); 102 SmallVector<Value, 8> resultValue; 103 for (Type convertedType : returnTypes) { 104 auto pointerType = 105 spirv::PointerType::get(convertedType, spirv::StorageClass::Function); 106 rewriter.setInsertionPoint(newOp); 107 auto alloc = rewriter.create<spirv::VariableOp>( 108 loc, pointerType, spirv::StorageClass::Function, 109 /*initializer=*/nullptr); 110 allocas.push_back(alloc); 111 rewriter.setInsertionPointAfter(newOp); 112 Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc); 113 resultValue.push_back(loadResult); 114 } 115 rewriter.replaceOp(scfOp, resultValue); 116 } 117 118 //===----------------------------------------------------------------------===// 119 // scf::ForOp. 120 //===----------------------------------------------------------------------===// 121 122 LogicalResult 123 ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands, 124 ConversionPatternRewriter &rewriter) const { 125 // scf::ForOp can be lowered to the structured control flow represented by 126 // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop 127 // latch and the merge block the exit block. The resulting spirv::LoopOp has a 128 // single back edge from the continue to header block, and a single exit from 129 // header to merge. 130 scf::ForOpAdaptor forOperands(operands); 131 auto loc = forOp.getLoc(); 132 auto loopControl = rewriter.getI32IntegerAttr( 133 static_cast<uint32_t>(spirv::LoopControl::None)); 134 auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl); 135 loopOp.addEntryAndMergeBlock(); 136 137 OpBuilder::InsertionGuard guard(rewriter); 138 // Create the block for the header. 139 auto *header = new Block(); 140 // Insert the header. 141 loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header); 142 143 // Create the new induction variable to use. 144 BlockArgument newIndVar = 145 header->addArgument(forOperands.lowerBound().getType()); 146 for (Value arg : forOperands.initArgs()) 147 header->addArgument(arg.getType()); 148 Block *body = forOp.getBody(); 149 150 // Apply signature conversion to the body of the forOp. It has a single block, 151 // with argument which is the induction variable. That has to be replaced with 152 // the new induction variable. 153 TypeConverter::SignatureConversion signatureConverter( 154 body->getNumArguments()); 155 signatureConverter.remapInput(0, newIndVar); 156 for (unsigned i = 1, e = body->getNumArguments(); i < e; i++) 157 signatureConverter.remapInput(i, header->getArgument(i)); 158 body = rewriter.applySignatureConversion(&forOp.getLoopBody(), 159 signatureConverter); 160 161 // Move the blocks from the forOp into the loopOp. This is the body of the 162 // loopOp. 163 rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(), 164 std::next(loopOp.body().begin(), 2)); 165 166 SmallVector<Value, 8> args(1, forOperands.lowerBound()); 167 args.append(forOperands.initArgs().begin(), forOperands.initArgs().end()); 168 // Branch into it from the entry. 169 rewriter.setInsertionPointToEnd(&(loopOp.body().front())); 170 rewriter.create<spirv::BranchOp>(loc, header, args); 171 172 // Generate the rest of the loop header. 173 rewriter.setInsertionPointToEnd(header); 174 auto *mergeBlock = loopOp.getMergeBlock(); 175 auto cmpOp = rewriter.create<spirv::SLessThanOp>( 176 loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound()); 177 178 rewriter.create<spirv::BranchConditionalOp>( 179 loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>()); 180 181 // Generate instructions to increment the step of the induction variable and 182 // branch to the header. 183 Block *continueBlock = loopOp.getContinueBlock(); 184 rewriter.setInsertionPointToEnd(continueBlock); 185 186 // Add the step to the induction variable and branch to the header. 187 Value updatedIndVar = rewriter.create<spirv::IAddOp>( 188 loc, newIndVar.getType(), newIndVar, forOperands.step()); 189 rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar); 190 191 // Infer the return types from the init operands. Vector type may get 192 // converted to CooperativeMatrix or to Vector type, to avoid having complex 193 // extra logic to figure out the right type we just infer it from the Init 194 // operands. 195 SmallVector<Type, 8> initTypes; 196 for (auto arg : forOperands.initArgs()) 197 initTypes.push_back(arg.getType()); 198 replaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter, 199 scfToSPIRVContext, initTypes); 200 return success(); 201 } 202 203 //===----------------------------------------------------------------------===// 204 // scf::IfOp. 205 //===----------------------------------------------------------------------===// 206 207 LogicalResult 208 IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands, 209 ConversionPatternRewriter &rewriter) const { 210 // When lowering `scf::IfOp` we explicitly create a selection header block 211 // before the control flow diverges and a merge block where control flow 212 // subsequently converges. 213 scf::IfOpAdaptor ifOperands(operands); 214 auto loc = ifOp.getLoc(); 215 216 // Create `spv.selection` operation, selection header block and merge block. 217 auto selectionControl = rewriter.getI32IntegerAttr( 218 static_cast<uint32_t>(spirv::SelectionControl::None)); 219 auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl); 220 auto *mergeBlock = 221 rewriter.createBlock(&selectionOp.body(), selectionOp.body().end()); 222 rewriter.create<spirv::MergeOp>(loc); 223 224 OpBuilder::InsertionGuard guard(rewriter); 225 auto *selectionHeaderBlock = 226 rewriter.createBlock(&selectionOp.body().front()); 227 228 // Inline `then` region before the merge block and branch to it. 229 auto &thenRegion = ifOp.thenRegion(); 230 auto *thenBlock = &thenRegion.front(); 231 rewriter.setInsertionPointToEnd(&thenRegion.back()); 232 rewriter.create<spirv::BranchOp>(loc, mergeBlock); 233 rewriter.inlineRegionBefore(thenRegion, mergeBlock); 234 235 auto *elseBlock = mergeBlock; 236 // If `else` region is not empty, inline that region before the merge block 237 // and branch to it. 238 if (!ifOp.elseRegion().empty()) { 239 auto &elseRegion = ifOp.elseRegion(); 240 elseBlock = &elseRegion.front(); 241 rewriter.setInsertionPointToEnd(&elseRegion.back()); 242 rewriter.create<spirv::BranchOp>(loc, mergeBlock); 243 rewriter.inlineRegionBefore(elseRegion, mergeBlock); 244 } 245 246 // Create a `spv.BranchConditional` operation for selection header block. 247 rewriter.setInsertionPointToEnd(selectionHeaderBlock); 248 rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(), 249 thenBlock, ArrayRef<Value>(), 250 elseBlock, ArrayRef<Value>()); 251 252 SmallVector<Type, 8> returnTypes; 253 for (auto result : ifOp.results()) { 254 auto convertedType = typeConverter.convertType(result.getType()); 255 returnTypes.push_back(convertedType); 256 } 257 replaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter, 258 scfToSPIRVContext, returnTypes); 259 return success(); 260 } 261 262 /// Yield is lowered to stores to the VariableOp created during lowering of the 263 /// parent region. For loops we also need to update the branch looping back to 264 /// the header with the loop carried values. 265 LogicalResult TerminatorOpConversion::matchAndRewrite( 266 scf::YieldOp terminatorOp, ArrayRef<Value> operands, 267 ConversionPatternRewriter &rewriter) const { 268 // If the region is return values, store each value into the associated 269 // VariableOp created during lowering of the parent region. 270 if (!operands.empty()) { 271 auto loc = terminatorOp.getLoc(); 272 auto &allocas = scfToSPIRVContext->outputVars[terminatorOp.getParentOp()]; 273 assert(allocas.size() == operands.size()); 274 for (unsigned i = 0, e = operands.size(); i < e; i++) 275 rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]); 276 if (isa<spirv::LoopOp>(terminatorOp.getParentOp())) { 277 // For loops we also need to update the branch jumping back to the header. 278 auto br = 279 cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator()); 280 SmallVector<Value, 8> args(br.getBlockArguments()); 281 args.append(operands.begin(), operands.end()); 282 rewriter.setInsertionPoint(br); 283 rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(), 284 args); 285 rewriter.eraseOp(br); 286 } 287 } 288 rewriter.eraseOp(terminatorOp); 289 return success(); 290 } 291 292 void mlir::populateSCFToSPIRVPatterns(MLIRContext *context, 293 SPIRVTypeConverter &typeConverter, 294 ScfToSPIRVContext &scfToSPIRVContext, 295 OwningRewritePatternList &patterns) { 296 patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>( 297 context, typeConverter, scfToSPIRVContext.getImpl()); 298 } 299