1 //===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert SCF dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" 14 #include "mlir/Dialect/SCF/SCF.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18 #include "mlir/IR/BuiltinOps.h" 19 20 using namespace mlir; 21 22 namespace mlir { 23 struct ScfToSPIRVContextImpl { 24 // Map between the spirv region control flow operation (spv.loop or 25 // spv.selection) to the VariableOp created to store the region results. The 26 // order of the VariableOp matches the order of the results. 27 DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars; 28 }; 29 } // namespace mlir 30 31 /// We use ScfToSPIRVContext to store information about the lowering of the scf 32 /// region that need to be used later on. When we lower scf.for/scf.if we create 33 /// VariableOp to store the results. We need to keep track of the VariableOp 34 /// created as we need to insert stores into them when lowering Yield. Those 35 /// StoreOp cannot be created earlier as they may use a different type than 36 /// yield operands. 37 ScfToSPIRVContext::ScfToSPIRVContext() { 38 impl = std::make_unique<ScfToSPIRVContextImpl>(); 39 } 40 ScfToSPIRVContext::~ScfToSPIRVContext() = default; 41 42 namespace { 43 /// Common class for all vector to GPU patterns. 44 template <typename OpTy> 45 class SCFToSPIRVPattern : public SPIRVOpLowering<OpTy> { 46 public: 47 SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter, 48 ScfToSPIRVContextImpl *scfToSPIRVContext) 49 : SPIRVOpLowering<OpTy>::SPIRVOpLowering(context, converter), 50 scfToSPIRVContext(scfToSPIRVContext) {} 51 52 protected: 53 ScfToSPIRVContextImpl *scfToSPIRVContext; 54 }; 55 56 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp. 57 class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> { 58 public: 59 using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern; 60 61 LogicalResult 62 matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands, 63 ConversionPatternRewriter &rewriter) const override; 64 }; 65 66 /// Pattern to convert a scf::IfOp within kernel functions into 67 /// spirv::SelectionOp. 68 class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> { 69 public: 70 using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern; 71 72 LogicalResult 73 matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands, 74 ConversionPatternRewriter &rewriter) const override; 75 }; 76 77 class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> { 78 public: 79 using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern; 80 81 LogicalResult 82 matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands, 83 ConversionPatternRewriter &rewriter) const override; 84 }; 85 } // namespace 86 87 /// Helper function to replaces SCF op outputs with SPIR-V variable loads. 88 /// We create VariableOp to handle the results value of the control flow region. 89 /// spv.loop/spv.selection currently don't yield value. Right after the loop 90 /// we load the value from the allocation and use it as the SCF op result. 91 template <typename ScfOp, typename OpTy> 92 static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp, 93 SPIRVTypeConverter &typeConverter, 94 ConversionPatternRewriter &rewriter, 95 ScfToSPIRVContextImpl *scfToSPIRVContext, 96 ArrayRef<Type> returnTypes) { 97 98 Location loc = scfOp.getLoc(); 99 auto &allocas = scfToSPIRVContext->outputVars[newOp]; 100 // Clearing the allocas is necessary in case a dialect conversion path failed 101 // previously, and this is the second attempt of this conversion. 102 allocas.clear(); 103 SmallVector<Value, 8> resultValue; 104 for (Type convertedType : returnTypes) { 105 auto pointerType = 106 spirv::PointerType::get(convertedType, spirv::StorageClass::Function); 107 rewriter.setInsertionPoint(newOp); 108 auto alloc = rewriter.create<spirv::VariableOp>( 109 loc, pointerType, spirv::StorageClass::Function, 110 /*initializer=*/nullptr); 111 allocas.push_back(alloc); 112 rewriter.setInsertionPointAfter(newOp); 113 Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc); 114 resultValue.push_back(loadResult); 115 } 116 rewriter.replaceOp(scfOp, resultValue); 117 } 118 119 //===----------------------------------------------------------------------===// 120 // scf::ForOp. 121 //===----------------------------------------------------------------------===// 122 123 LogicalResult 124 ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands, 125 ConversionPatternRewriter &rewriter) const { 126 // scf::ForOp can be lowered to the structured control flow represented by 127 // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop 128 // latch and the merge block the exit block. The resulting spirv::LoopOp has a 129 // single back edge from the continue to header block, and a single exit from 130 // header to merge. 131 scf::ForOpAdaptor forOperands(operands); 132 auto loc = forOp.getLoc(); 133 auto loopControl = rewriter.getI32IntegerAttr( 134 static_cast<uint32_t>(spirv::LoopControl::None)); 135 auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl); 136 loopOp.addEntryAndMergeBlock(); 137 138 OpBuilder::InsertionGuard guard(rewriter); 139 // Create the block for the header. 140 auto *header = new Block(); 141 // Insert the header. 142 loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header); 143 144 // Create the new induction variable to use. 145 BlockArgument newIndVar = 146 header->addArgument(forOperands.lowerBound().getType()); 147 for (Value arg : forOperands.initArgs()) 148 header->addArgument(arg.getType()); 149 Block *body = forOp.getBody(); 150 151 // Apply signature conversion to the body of the forOp. It has a single block, 152 // with argument which is the induction variable. That has to be replaced with 153 // the new induction variable. 154 TypeConverter::SignatureConversion signatureConverter( 155 body->getNumArguments()); 156 signatureConverter.remapInput(0, newIndVar); 157 for (unsigned i = 1, e = body->getNumArguments(); i < e; i++) 158 signatureConverter.remapInput(i, header->getArgument(i)); 159 body = rewriter.applySignatureConversion(&forOp.getLoopBody(), 160 signatureConverter); 161 162 // Move the blocks from the forOp into the loopOp. This is the body of the 163 // loopOp. 164 rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(), 165 std::next(loopOp.body().begin(), 2)); 166 167 SmallVector<Value, 8> args(1, forOperands.lowerBound()); 168 args.append(forOperands.initArgs().begin(), forOperands.initArgs().end()); 169 // Branch into it from the entry. 170 rewriter.setInsertionPointToEnd(&(loopOp.body().front())); 171 rewriter.create<spirv::BranchOp>(loc, header, args); 172 173 // Generate the rest of the loop header. 174 rewriter.setInsertionPointToEnd(header); 175 auto *mergeBlock = loopOp.getMergeBlock(); 176 auto cmpOp = rewriter.create<spirv::SLessThanOp>( 177 loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound()); 178 179 rewriter.create<spirv::BranchConditionalOp>( 180 loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>()); 181 182 // Generate instructions to increment the step of the induction variable and 183 // branch to the header. 184 Block *continueBlock = loopOp.getContinueBlock(); 185 rewriter.setInsertionPointToEnd(continueBlock); 186 187 // Add the step to the induction variable and branch to the header. 188 Value updatedIndVar = rewriter.create<spirv::IAddOp>( 189 loc, newIndVar.getType(), newIndVar, forOperands.step()); 190 rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar); 191 192 // Infer the return types from the init operands. Vector type may get 193 // converted to CooperativeMatrix or to Vector type, to avoid having complex 194 // extra logic to figure out the right type we just infer it from the Init 195 // operands. 196 SmallVector<Type, 8> initTypes; 197 for (auto arg : forOperands.initArgs()) 198 initTypes.push_back(arg.getType()); 199 replaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter, 200 scfToSPIRVContext, initTypes); 201 return success(); 202 } 203 204 //===----------------------------------------------------------------------===// 205 // scf::IfOp. 206 //===----------------------------------------------------------------------===// 207 208 LogicalResult 209 IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands, 210 ConversionPatternRewriter &rewriter) const { 211 // When lowering `scf::IfOp` we explicitly create a selection header block 212 // before the control flow diverges and a merge block where control flow 213 // subsequently converges. 214 scf::IfOpAdaptor ifOperands(operands); 215 auto loc = ifOp.getLoc(); 216 217 // Create `spv.selection` operation, selection header block and merge block. 218 auto selectionControl = rewriter.getI32IntegerAttr( 219 static_cast<uint32_t>(spirv::SelectionControl::None)); 220 auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl); 221 auto *mergeBlock = 222 rewriter.createBlock(&selectionOp.body(), selectionOp.body().end()); 223 rewriter.create<spirv::MergeOp>(loc); 224 225 OpBuilder::InsertionGuard guard(rewriter); 226 auto *selectionHeaderBlock = 227 rewriter.createBlock(&selectionOp.body().front()); 228 229 // Inline `then` region before the merge block and branch to it. 230 auto &thenRegion = ifOp.thenRegion(); 231 auto *thenBlock = &thenRegion.front(); 232 rewriter.setInsertionPointToEnd(&thenRegion.back()); 233 rewriter.create<spirv::BranchOp>(loc, mergeBlock); 234 rewriter.inlineRegionBefore(thenRegion, mergeBlock); 235 236 auto *elseBlock = mergeBlock; 237 // If `else` region is not empty, inline that region before the merge block 238 // and branch to it. 239 if (!ifOp.elseRegion().empty()) { 240 auto &elseRegion = ifOp.elseRegion(); 241 elseBlock = &elseRegion.front(); 242 rewriter.setInsertionPointToEnd(&elseRegion.back()); 243 rewriter.create<spirv::BranchOp>(loc, mergeBlock); 244 rewriter.inlineRegionBefore(elseRegion, mergeBlock); 245 } 246 247 // Create a `spv.BranchConditional` operation for selection header block. 248 rewriter.setInsertionPointToEnd(selectionHeaderBlock); 249 rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(), 250 thenBlock, ArrayRef<Value>(), 251 elseBlock, ArrayRef<Value>()); 252 253 SmallVector<Type, 8> returnTypes; 254 for (auto result : ifOp.results()) { 255 auto convertedType = typeConverter.convertType(result.getType()); 256 returnTypes.push_back(convertedType); 257 } 258 replaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter, 259 scfToSPIRVContext, returnTypes); 260 return success(); 261 } 262 263 /// Yield is lowered to stores to the VariableOp created during lowering of the 264 /// parent region. For loops we also need to update the branch looping back to 265 /// the header with the loop carried values. 266 LogicalResult TerminatorOpConversion::matchAndRewrite( 267 scf::YieldOp terminatorOp, ArrayRef<Value> operands, 268 ConversionPatternRewriter &rewriter) const { 269 // If the region is return values, store each value into the associated 270 // VariableOp created during lowering of the parent region. 271 if (!operands.empty()) { 272 auto loc = terminatorOp.getLoc(); 273 auto &allocas = scfToSPIRVContext->outputVars[terminatorOp->getParentOp()]; 274 assert(allocas.size() == operands.size()); 275 for (unsigned i = 0, e = operands.size(); i < e; i++) 276 rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]); 277 if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) { 278 // For loops we also need to update the branch jumping back to the header. 279 auto br = 280 cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator()); 281 SmallVector<Value, 8> args(br.getBlockArguments()); 282 args.append(operands.begin(), operands.end()); 283 rewriter.setInsertionPoint(br); 284 rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(), 285 args); 286 rewriter.eraseOp(br); 287 } 288 } 289 rewriter.eraseOp(terminatorOp); 290 return success(); 291 } 292 293 void mlir::populateSCFToSPIRVPatterns(MLIRContext *context, 294 SPIRVTypeConverter &typeConverter, 295 ScfToSPIRVContext &scfToSPIRVContext, 296 OwningRewritePatternList &patterns) { 297 patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>( 298 context, typeConverter, scfToSPIRVContext.getImpl()); 299 } 300