1 //===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert SCF dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" 14 #include "mlir/Dialect/SCF/SCF.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18 #include "mlir/IR/BuiltinOps.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 21 using namespace mlir; 22 23 //===----------------------------------------------------------------------===// 24 // Context 25 //===----------------------------------------------------------------------===// 26 27 namespace mlir { 28 struct ScfToSPIRVContextImpl { 29 // Map between the spirv region control flow operation (spv.mlir.loop or 30 // spv.mlir.selection) to the VariableOp created to store the region results. 31 // The order of the VariableOp matches the order of the results. 32 DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars; 33 }; 34 } // namespace mlir 35 36 /// We use ScfToSPIRVContext to store information about the lowering of the scf 37 /// region that need to be used later on. When we lower scf.for/scf.if we create 38 /// VariableOp to store the results. We need to keep track of the VariableOp 39 /// created as we need to insert stores into them when lowering Yield. Those 40 /// StoreOp cannot be created earlier as they may use a different type than 41 /// yield operands. 42 ScfToSPIRVContext::ScfToSPIRVContext() { 43 impl = std::make_unique<ScfToSPIRVContextImpl>(); 44 } 45 46 ScfToSPIRVContext::~ScfToSPIRVContext() = default; 47 48 //===----------------------------------------------------------------------===// 49 // Pattern Declarations 50 //===----------------------------------------------------------------------===// 51 52 namespace { 53 /// Common class for all vector to GPU patterns. 54 template <typename OpTy> 55 class SCFToSPIRVPattern : public OpConversionPattern<OpTy> { 56 public: 57 SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter, 58 ScfToSPIRVContextImpl *scfToSPIRVContext) 59 : OpConversionPattern<OpTy>::OpConversionPattern(context), 60 scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {} 61 62 protected: 63 ScfToSPIRVContextImpl *scfToSPIRVContext; 64 // FIXME: We explicitly keep a reference of the type converter here instead of 65 // passing it to OpConversionPattern during construction. This effectively 66 // bypasses the conversion framework's automation on type conversion. This is 67 // needed right now because the conversion framework will unconditionally 68 // legalize all types used by SCF ops upon discovering them, for example, the 69 // types of loop carried values. We use SPIR-V variables for those loop 70 // carried values. Depending on the available capabilities, the SPIR-V 71 // variable can be different, for example, cooperative matrix or normal 72 // variable. We'd like to detach the conversion of the loop carried values 73 // from the SCF ops (which is mainly a region). So we need to "mark" types 74 // used by SCF ops as legal, if to use the conversion framework for type 75 // conversion. There isn't a straightforward way to do that yet, as when 76 // converting types, ops aren't taken into consideration. Therefore, we just 77 // bypass the framework's type conversion for now. 78 SPIRVTypeConverter &typeConverter; 79 }; 80 81 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp. 82 class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> { 83 public: 84 using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern; 85 86 LogicalResult 87 matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands, 88 ConversionPatternRewriter &rewriter) const override; 89 }; 90 91 /// Pattern to convert a scf::IfOp within kernel functions into 92 /// spirv::SelectionOp. 93 class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> { 94 public: 95 using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern; 96 97 LogicalResult 98 matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands, 99 ConversionPatternRewriter &rewriter) const override; 100 }; 101 102 class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> { 103 public: 104 using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern; 105 106 LogicalResult 107 matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands, 108 ConversionPatternRewriter &rewriter) const override; 109 }; 110 } // namespace 111 112 /// Helper function to replaces SCF op outputs with SPIR-V variable loads. 113 /// We create VariableOp to handle the results value of the control flow region. 114 /// spv.mlir.loop/spv.mlir.selection currently don't yield value. Right after 115 /// the loop we load the value from the allocation and use it as the SCF op 116 /// result. 117 template <typename ScfOp, typename OpTy> 118 static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp, 119 ConversionPatternRewriter &rewriter, 120 ScfToSPIRVContextImpl *scfToSPIRVContext, 121 ArrayRef<Type> returnTypes) { 122 123 Location loc = scfOp.getLoc(); 124 auto &allocas = scfToSPIRVContext->outputVars[newOp]; 125 // Clearing the allocas is necessary in case a dialect conversion path failed 126 // previously, and this is the second attempt of this conversion. 127 allocas.clear(); 128 SmallVector<Value, 8> resultValue; 129 for (Type convertedType : returnTypes) { 130 auto pointerType = 131 spirv::PointerType::get(convertedType, spirv::StorageClass::Function); 132 rewriter.setInsertionPoint(newOp); 133 auto alloc = rewriter.create<spirv::VariableOp>( 134 loc, pointerType, spirv::StorageClass::Function, 135 /*initializer=*/nullptr); 136 allocas.push_back(alloc); 137 rewriter.setInsertionPointAfter(newOp); 138 Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc); 139 resultValue.push_back(loadResult); 140 } 141 rewriter.replaceOp(scfOp, resultValue); 142 } 143 144 //===----------------------------------------------------------------------===// 145 // scf::ForOp 146 //===----------------------------------------------------------------------===// 147 148 LogicalResult 149 ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands, 150 ConversionPatternRewriter &rewriter) const { 151 // scf::ForOp can be lowered to the structured control flow represented by 152 // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop 153 // latch and the merge block the exit block. The resulting spirv::LoopOp has a 154 // single back edge from the continue to header block, and a single exit from 155 // header to merge. 156 scf::ForOpAdaptor forOperands(operands); 157 auto loc = forOp.getLoc(); 158 auto loopControl = rewriter.getI32IntegerAttr( 159 static_cast<uint32_t>(spirv::LoopControl::None)); 160 auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl); 161 loopOp.addEntryAndMergeBlock(); 162 163 OpBuilder::InsertionGuard guard(rewriter); 164 // Create the block for the header. 165 auto *header = new Block(); 166 // Insert the header. 167 loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header); 168 169 // Create the new induction variable to use. 170 BlockArgument newIndVar = 171 header->addArgument(forOperands.lowerBound().getType()); 172 for (Value arg : forOperands.initArgs()) 173 header->addArgument(arg.getType()); 174 Block *body = forOp.getBody(); 175 176 // Apply signature conversion to the body of the forOp. It has a single block, 177 // with argument which is the induction variable. That has to be replaced with 178 // the new induction variable. 179 TypeConverter::SignatureConversion signatureConverter( 180 body->getNumArguments()); 181 signatureConverter.remapInput(0, newIndVar); 182 for (unsigned i = 1, e = body->getNumArguments(); i < e; i++) 183 signatureConverter.remapInput(i, header->getArgument(i)); 184 body = rewriter.applySignatureConversion(&forOp.getLoopBody(), 185 signatureConverter); 186 187 // Move the blocks from the forOp into the loopOp. This is the body of the 188 // loopOp. 189 rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(), 190 std::next(loopOp.body().begin(), 2)); 191 192 SmallVector<Value, 8> args(1, forOperands.lowerBound()); 193 args.append(forOperands.initArgs().begin(), forOperands.initArgs().end()); 194 // Branch into it from the entry. 195 rewriter.setInsertionPointToEnd(&(loopOp.body().front())); 196 rewriter.create<spirv::BranchOp>(loc, header, args); 197 198 // Generate the rest of the loop header. 199 rewriter.setInsertionPointToEnd(header); 200 auto *mergeBlock = loopOp.getMergeBlock(); 201 auto cmpOp = rewriter.create<spirv::SLessThanOp>( 202 loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound()); 203 204 rewriter.create<spirv::BranchConditionalOp>( 205 loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>()); 206 207 // Generate instructions to increment the step of the induction variable and 208 // branch to the header. 209 Block *continueBlock = loopOp.getContinueBlock(); 210 rewriter.setInsertionPointToEnd(continueBlock); 211 212 // Add the step to the induction variable and branch to the header. 213 Value updatedIndVar = rewriter.create<spirv::IAddOp>( 214 loc, newIndVar.getType(), newIndVar, forOperands.step()); 215 rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar); 216 217 // Infer the return types from the init operands. Vector type may get 218 // converted to CooperativeMatrix or to Vector type, to avoid having complex 219 // extra logic to figure out the right type we just infer it from the Init 220 // operands. 221 SmallVector<Type, 8> initTypes; 222 for (auto arg : forOperands.initArgs()) 223 initTypes.push_back(arg.getType()); 224 replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, initTypes); 225 return success(); 226 } 227 228 //===----------------------------------------------------------------------===// 229 // scf::IfOp 230 //===----------------------------------------------------------------------===// 231 232 LogicalResult 233 IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands, 234 ConversionPatternRewriter &rewriter) const { 235 // When lowering `scf::IfOp` we explicitly create a selection header block 236 // before the control flow diverges and a merge block where control flow 237 // subsequently converges. 238 scf::IfOpAdaptor ifOperands(operands); 239 auto loc = ifOp.getLoc(); 240 241 // Create `spv.mlir.selection` operation, selection header block and merge 242 // block. 243 auto selectionControl = rewriter.getI32IntegerAttr( 244 static_cast<uint32_t>(spirv::SelectionControl::None)); 245 auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl); 246 auto *mergeBlock = 247 rewriter.createBlock(&selectionOp.body(), selectionOp.body().end()); 248 rewriter.create<spirv::MergeOp>(loc); 249 250 OpBuilder::InsertionGuard guard(rewriter); 251 auto *selectionHeaderBlock = 252 rewriter.createBlock(&selectionOp.body().front()); 253 254 // Inline `then` region before the merge block and branch to it. 255 auto &thenRegion = ifOp.thenRegion(); 256 auto *thenBlock = &thenRegion.front(); 257 rewriter.setInsertionPointToEnd(&thenRegion.back()); 258 rewriter.create<spirv::BranchOp>(loc, mergeBlock); 259 rewriter.inlineRegionBefore(thenRegion, mergeBlock); 260 261 auto *elseBlock = mergeBlock; 262 // If `else` region is not empty, inline that region before the merge block 263 // and branch to it. 264 if (!ifOp.elseRegion().empty()) { 265 auto &elseRegion = ifOp.elseRegion(); 266 elseBlock = &elseRegion.front(); 267 rewriter.setInsertionPointToEnd(&elseRegion.back()); 268 rewriter.create<spirv::BranchOp>(loc, mergeBlock); 269 rewriter.inlineRegionBefore(elseRegion, mergeBlock); 270 } 271 272 // Create a `spv.BranchConditional` operation for selection header block. 273 rewriter.setInsertionPointToEnd(selectionHeaderBlock); 274 rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(), 275 thenBlock, ArrayRef<Value>(), 276 elseBlock, ArrayRef<Value>()); 277 278 SmallVector<Type, 8> returnTypes; 279 for (auto result : ifOp.results()) { 280 auto convertedType = typeConverter.convertType(result.getType()); 281 returnTypes.push_back(convertedType); 282 } 283 replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext, 284 returnTypes); 285 return success(); 286 } 287 288 //===----------------------------------------------------------------------===// 289 // scf::YieldOp 290 //===----------------------------------------------------------------------===// 291 292 /// Yield is lowered to stores to the VariableOp created during lowering of the 293 /// parent region. For loops we also need to update the branch looping back to 294 /// the header with the loop carried values. 295 LogicalResult TerminatorOpConversion::matchAndRewrite( 296 scf::YieldOp terminatorOp, ArrayRef<Value> operands, 297 ConversionPatternRewriter &rewriter) const { 298 // If the region is return values, store each value into the associated 299 // VariableOp created during lowering of the parent region. 300 if (!operands.empty()) { 301 auto loc = terminatorOp.getLoc(); 302 auto &allocas = scfToSPIRVContext->outputVars[terminatorOp->getParentOp()]; 303 assert(allocas.size() == operands.size()); 304 for (unsigned i = 0, e = operands.size(); i < e; i++) 305 rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]); 306 if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) { 307 // For loops we also need to update the branch jumping back to the header. 308 auto br = 309 cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator()); 310 SmallVector<Value, 8> args(br.getBlockArguments()); 311 args.append(operands.begin(), operands.end()); 312 rewriter.setInsertionPoint(br); 313 rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(), 314 args); 315 rewriter.eraseOp(br); 316 } 317 } 318 rewriter.eraseOp(terminatorOp); 319 return success(); 320 } 321 322 //===----------------------------------------------------------------------===// 323 // Hooks 324 //===----------------------------------------------------------------------===// 325 326 void mlir::populateSCFToSPIRVPatterns(MLIRContext *context, 327 SPIRVTypeConverter &typeConverter, 328 ScfToSPIRVContext &scfToSPIRVContext, 329 OwningRewritePatternList &patterns) { 330 patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>( 331 context, typeConverter, scfToSPIRVContext.getImpl()); 332 } 333