//===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert SCF dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Context
//===----------------------------------------------------------------------===//

namespace mlir {
/// State shared between the SCF to SPIR-V patterns in this file.
struct ScfToSPIRVContextImpl {
  // Maps each SPIR-V region control flow operation (spv.mlir.loop or
  // spv.mlir.selection) to the VariableOps created to store the region's
  // results. The order of the VariableOps matches the order of the results.
  DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
};
} // namespace mlir

/// We use ScfToSPIRVContext to store information about the lowering of the scf
/// region that needs to be used later on. When we lower scf.for/scf.if we
/// create VariableOps to store the results. We need to keep track of the
/// VariableOps created, as we need to insert stores into them when lowering
/// Yield. Those StoreOps cannot be created earlier as they may use a different
/// type than the yield operands.
42 ScfToSPIRVContext::ScfToSPIRVContext() { 43 impl = std::make_unique<ScfToSPIRVContextImpl>(); 44 } 45 46 ScfToSPIRVContext::~ScfToSPIRVContext() = default; 47 48 //===----------------------------------------------------------------------===// 49 // Pattern Declarations 50 //===----------------------------------------------------------------------===// 51 52 namespace { 53 /// Common class for all vector to GPU patterns. 54 template <typename OpTy> 55 class SCFToSPIRVPattern : public OpConversionPattern<OpTy> { 56 public: 57 SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter, 58 ScfToSPIRVContextImpl *scfToSPIRVContext) 59 : OpConversionPattern<OpTy>::OpConversionPattern(converter, context), 60 scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {} 61 62 protected: 63 ScfToSPIRVContextImpl *scfToSPIRVContext; 64 // FIXME: We explicitly keep a reference of the type converter here instead of 65 // passing it to OpConversionPattern during construction. This effectively 66 // bypasses the conversion framework's automation on type conversion. This is 67 // needed right now because the conversion framework will unconditionally 68 // legalize all types used by SCF ops upon discovering them, for example, the 69 // types of loop carried values. We use SPIR-V variables for those loop 70 // carried values. Depending on the available capabilities, the SPIR-V 71 // variable can be different, for example, cooperative matrix or normal 72 // variable. We'd like to detach the conversion of the loop carried values 73 // from the SCF ops (which is mainly a region). So we need to "mark" types 74 // used by SCF ops as legal, if to use the conversion framework for type 75 // conversion. There isn't a straightforward way to do that yet, as when 76 // converting types, ops aren't taken into consideration. Therefore, we just 77 // bypass the framework's type conversion for now. 
78 SPIRVTypeConverter &typeConverter; 79 }; 80 81 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp. 82 class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> { 83 public: 84 using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern; 85 86 LogicalResult 87 matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor, 88 ConversionPatternRewriter &rewriter) const override; 89 }; 90 91 /// Pattern to convert a scf::IfOp within kernel functions into 92 /// spirv::SelectionOp. 93 class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> { 94 public: 95 using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern; 96 97 LogicalResult 98 matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, 99 ConversionPatternRewriter &rewriter) const override; 100 }; 101 102 class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> { 103 public: 104 using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern; 105 106 LogicalResult 107 matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor, 108 ConversionPatternRewriter &rewriter) const override; 109 }; 110 111 class WhileOpConversion final : public SCFToSPIRVPattern<scf::WhileOp> { 112 public: 113 using SCFToSPIRVPattern<scf::WhileOp>::SCFToSPIRVPattern; 114 115 LogicalResult 116 matchAndRewrite(scf::WhileOp forOp, OpAdaptor adaptor, 117 ConversionPatternRewriter &rewriter) const override; 118 }; 119 } // namespace 120 121 /// Helper function to replaces SCF op outputs with SPIR-V variable loads. 122 /// We create VariableOp to handle the results value of the control flow region. 123 /// spv.mlir.loop/spv.mlir.selection currently don't yield value. Right after 124 /// the loop we load the value from the allocation and use it as the SCF op 125 /// result. 
126 template <typename ScfOp, typename OpTy> 127 static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp, 128 ConversionPatternRewriter &rewriter, 129 ScfToSPIRVContextImpl *scfToSPIRVContext, 130 ArrayRef<Type> returnTypes) { 131 132 Location loc = scfOp.getLoc(); 133 auto &allocas = scfToSPIRVContext->outputVars[newOp]; 134 // Clearing the allocas is necessary in case a dialect conversion path failed 135 // previously, and this is the second attempt of this conversion. 136 allocas.clear(); 137 SmallVector<Value, 8> resultValue; 138 for (Type convertedType : returnTypes) { 139 auto pointerType = 140 spirv::PointerType::get(convertedType, spirv::StorageClass::Function); 141 rewriter.setInsertionPoint(newOp); 142 auto alloc = rewriter.create<spirv::VariableOp>( 143 loc, pointerType, spirv::StorageClass::Function, 144 /*initializer=*/nullptr); 145 allocas.push_back(alloc); 146 rewriter.setInsertionPointAfter(newOp); 147 Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc); 148 resultValue.push_back(loadResult); 149 } 150 rewriter.replaceOp(scfOp, resultValue); 151 } 152 153 static Region::iterator getBlockIt(Region ®ion, unsigned index) { 154 return std::next(region.begin(), index); 155 } 156 157 //===----------------------------------------------------------------------===// 158 // scf::ForOp 159 //===----------------------------------------------------------------------===// 160 161 LogicalResult 162 ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor, 163 ConversionPatternRewriter &rewriter) const { 164 // scf::ForOp can be lowered to the structured control flow represented by 165 // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop 166 // latch and the merge block the exit block. The resulting spirv::LoopOp has a 167 // single back edge from the continue to header block, and a single exit from 168 // header to merge. 
169 auto loc = forOp.getLoc(); 170 auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None); 171 loopOp.addEntryAndMergeBlock(); 172 173 OpBuilder::InsertionGuard guard(rewriter); 174 // Create the block for the header. 175 auto *header = new Block(); 176 // Insert the header. 177 loopOp.body().getBlocks().insert(getBlockIt(loopOp.body(), 1), header); 178 179 // Create the new induction variable to use. 180 Value adapLowerBound = adaptor.getLowerBound(); 181 BlockArgument newIndVar = 182 header->addArgument(adapLowerBound.getType(), adapLowerBound.getLoc()); 183 for (Value arg : adaptor.getInitArgs()) 184 header->addArgument(arg.getType(), arg.getLoc()); 185 Block *body = forOp.getBody(); 186 187 // Apply signature conversion to the body of the forOp. It has a single block, 188 // with argument which is the induction variable. That has to be replaced with 189 // the new induction variable. 190 TypeConverter::SignatureConversion signatureConverter( 191 body->getNumArguments()); 192 signatureConverter.remapInput(0, newIndVar); 193 for (unsigned i = 1, e = body->getNumArguments(); i < e; i++) 194 signatureConverter.remapInput(i, header->getArgument(i)); 195 body = rewriter.applySignatureConversion(&forOp.getLoopBody(), 196 signatureConverter); 197 198 // Move the blocks from the forOp into the loopOp. This is the body of the 199 // loopOp. 200 rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(), 201 getBlockIt(loopOp.body(), 2)); 202 203 SmallVector<Value, 8> args(1, adaptor.getLowerBound()); 204 args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end()); 205 // Branch into it from the entry. 206 rewriter.setInsertionPointToEnd(&(loopOp.body().front())); 207 rewriter.create<spirv::BranchOp>(loc, header, args); 208 209 // Generate the rest of the loop header. 
210 rewriter.setInsertionPointToEnd(header); 211 auto *mergeBlock = loopOp.getMergeBlock(); 212 auto cmpOp = rewriter.create<spirv::SLessThanOp>( 213 loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound()); 214 215 rewriter.create<spirv::BranchConditionalOp>( 216 loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>()); 217 218 // Generate instructions to increment the step of the induction variable and 219 // branch to the header. 220 Block *continueBlock = loopOp.getContinueBlock(); 221 rewriter.setInsertionPointToEnd(continueBlock); 222 223 // Add the step to the induction variable and branch to the header. 224 Value updatedIndVar = rewriter.create<spirv::IAddOp>( 225 loc, newIndVar.getType(), newIndVar, adaptor.getStep()); 226 rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar); 227 228 // Infer the return types from the init operands. Vector type may get 229 // converted to CooperativeMatrix or to Vector type, to avoid having complex 230 // extra logic to figure out the right type we just infer it from the Init 231 // operands. 232 SmallVector<Type, 8> initTypes; 233 for (auto arg : adaptor.getInitArgs()) 234 initTypes.push_back(arg.getType()); 235 replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, initTypes); 236 return success(); 237 } 238 239 //===----------------------------------------------------------------------===// 240 // scf::IfOp 241 //===----------------------------------------------------------------------===// 242 243 LogicalResult 244 IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, 245 ConversionPatternRewriter &rewriter) const { 246 // When lowering `scf::IfOp` we explicitly create a selection header block 247 // before the control flow diverges and a merge block where control flow 248 // subsequently converges. 249 auto loc = ifOp.getLoc(); 250 251 // Create `spv.selection` operation, selection header block and merge block. 
252 auto selectionOp = 253 rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None); 254 auto *mergeBlock = 255 rewriter.createBlock(&selectionOp.body(), selectionOp.body().end()); 256 rewriter.create<spirv::MergeOp>(loc); 257 258 OpBuilder::InsertionGuard guard(rewriter); 259 auto *selectionHeaderBlock = 260 rewriter.createBlock(&selectionOp.body().front()); 261 262 // Inline `then` region before the merge block and branch to it. 263 auto &thenRegion = ifOp.getThenRegion(); 264 auto *thenBlock = &thenRegion.front(); 265 rewriter.setInsertionPointToEnd(&thenRegion.back()); 266 rewriter.create<spirv::BranchOp>(loc, mergeBlock); 267 rewriter.inlineRegionBefore(thenRegion, mergeBlock); 268 269 auto *elseBlock = mergeBlock; 270 // If `else` region is not empty, inline that region before the merge block 271 // and branch to it. 272 if (!ifOp.getElseRegion().empty()) { 273 auto &elseRegion = ifOp.getElseRegion(); 274 elseBlock = &elseRegion.front(); 275 rewriter.setInsertionPointToEnd(&elseRegion.back()); 276 rewriter.create<spirv::BranchOp>(loc, mergeBlock); 277 rewriter.inlineRegionBefore(elseRegion, mergeBlock); 278 } 279 280 // Create a `spv.BranchConditional` operation for selection header block. 
281 rewriter.setInsertionPointToEnd(selectionHeaderBlock); 282 rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(), 283 thenBlock, ArrayRef<Value>(), 284 elseBlock, ArrayRef<Value>()); 285 286 SmallVector<Type, 8> returnTypes; 287 for (auto result : ifOp.getResults()) { 288 auto convertedType = typeConverter.convertType(result.getType()); 289 returnTypes.push_back(convertedType); 290 } 291 replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext, 292 returnTypes); 293 return success(); 294 } 295 296 //===----------------------------------------------------------------------===// 297 // scf::YieldOp 298 //===----------------------------------------------------------------------===// 299 300 /// Yield is lowered to stores to the VariableOp created during lowering of the 301 /// parent region. For loops we also need to update the branch looping back to 302 /// the header with the loop carried values. 303 LogicalResult TerminatorOpConversion::matchAndRewrite( 304 scf::YieldOp terminatorOp, OpAdaptor adaptor, 305 ConversionPatternRewriter &rewriter) const { 306 ValueRange operands = adaptor.getOperands(); 307 308 // If the region is return values, store each value into the associated 309 // VariableOp created during lowering of the parent region. 310 if (!operands.empty()) { 311 auto &allocas = scfToSPIRVContext->outputVars[terminatorOp->getParentOp()]; 312 if (allocas.size() != operands.size()) 313 return failure(); 314 315 auto loc = terminatorOp.getLoc(); 316 for (unsigned i = 0, e = operands.size(); i < e; i++) 317 rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]); 318 if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) { 319 // For loops we also need to update the branch jumping back to the header. 
320 auto br = 321 cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator()); 322 SmallVector<Value, 8> args(br.getBlockArguments()); 323 args.append(operands.begin(), operands.end()); 324 rewriter.setInsertionPoint(br); 325 rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(), 326 args); 327 rewriter.eraseOp(br); 328 } 329 } 330 rewriter.eraseOp(terminatorOp); 331 return success(); 332 } 333 334 //===----------------------------------------------------------------------===// 335 // scf::WhileOp 336 //===----------------------------------------------------------------------===// 337 338 LogicalResult 339 WhileOpConversion::matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor, 340 ConversionPatternRewriter &rewriter) const { 341 auto loc = whileOp.getLoc(); 342 auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None); 343 loopOp.addEntryAndMergeBlock(); 344 345 OpBuilder::InsertionGuard guard(rewriter); 346 347 Region &beforeRegion = whileOp.getBefore(); 348 Region &afterRegion = whileOp.getAfter(); 349 350 Block &entryBlock = *loopOp.getEntryBlock(); 351 Block &beforeBlock = beforeRegion.front(); 352 Block &afterBlock = afterRegion.front(); 353 Block &mergeBlock = *loopOp.getMergeBlock(); 354 355 auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator()); 356 SmallVector<Value> condArgs; 357 if (failed(rewriter.getRemappedValues(cond.getArgs(), condArgs))) 358 return failure(); 359 360 Value conditionVal = rewriter.getRemappedValue(cond.getCondition()); 361 if (!conditionVal) 362 return failure(); 363 364 auto yield = cast<scf::YieldOp>(afterBlock.getTerminator()); 365 SmallVector<Value> yieldArgs; 366 if (failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs))) 367 return failure(); 368 369 // Move the while before block as the initial loop header block. 
370 rewriter.inlineRegionBefore(beforeRegion, loopOp.body(), 371 getBlockIt(loopOp.body(), 1)); 372 373 // Move the while after block as the initial loop body block. 374 rewriter.inlineRegionBefore(afterRegion, loopOp.body(), 375 getBlockIt(loopOp.body(), 2)); 376 377 // Jump from the loop entry block to the loop header block. 378 rewriter.setInsertionPointToEnd(&entryBlock); 379 rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits()); 380 381 auto condLoc = cond.getLoc(); 382 383 SmallVector<Value> resultValues(condArgs.size()); 384 385 // For other SCF ops, the scf.yield op yields the value for the whole SCF op. 386 // So we use the scf.yield op as the anchor to create/load/store SPIR-V local 387 // variables. But for the scf.while op, the scf.yield op yields a value for 388 // the before region, which may not matching the whole op's result. Instead, 389 // the scf.condition op returns values matching the whole op's results. So we 390 // need to create/load/store variables according to that. 391 for (const auto &it : llvm::enumerate(condArgs)) { 392 auto res = it.value(); 393 auto i = it.index(); 394 auto pointerType = 395 spirv::PointerType::get(res.getType(), spirv::StorageClass::Function); 396 397 // Create local variables before the scf.while op. 398 rewriter.setInsertionPoint(loopOp); 399 auto alloc = rewriter.create<spirv::VariableOp>( 400 condLoc, pointerType, spirv::StorageClass::Function, 401 /*initializer=*/nullptr); 402 403 // Load the final result values after the scf.while op. 404 rewriter.setInsertionPointAfter(loopOp); 405 auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc); 406 resultValues[i] = loadResult; 407 408 // Store the current iteration's result value. 
409 rewriter.setInsertionPointToEnd(&beforeBlock); 410 rewriter.create<spirv::StoreOp>(condLoc, alloc, res); 411 } 412 413 rewriter.setInsertionPointToEnd(&beforeBlock); 414 rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>( 415 cond, conditionVal, &afterBlock, condArgs, &mergeBlock, llvm::None); 416 417 // Convert the scf.yield op to a branch back to the header block. 418 rewriter.setInsertionPointToEnd(&afterBlock); 419 rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock, yieldArgs); 420 421 rewriter.replaceOp(whileOp, resultValues); 422 return success(); 423 } 424 425 //===----------------------------------------------------------------------===// 426 // Hooks 427 //===----------------------------------------------------------------------===// 428 429 void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 430 ScfToSPIRVContext &scfToSPIRVContext, 431 RewritePatternSet &patterns) { 432 patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion, 433 WhileOpConversion>(patterns.getContext(), typeConverter, 434 scfToSPIRVContext.getImpl()); 435 } 436