1 //===- SCFToSPIRV.cpp - Convert SCF ops to SPIR-V dialect -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the conversion patterns from SCF ops to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" 13 #include "mlir/Dialect/SCF/SCF.h" 14 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/SPIRVLowering.h" 16 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 17 #include "mlir/IR/Module.h" 18 19 using namespace mlir; 20 21 namespace mlir { 22 struct ScfToSPIRVContextImpl { 23 // Map between the spirv region control flow operation (spv.loop or 24 // spv.selection) to the VariableOp created to store the region results. The 25 // order of the VariableOp matches the order of the results. 26 DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars; 27 }; 28 } // namespace mlir 29 30 /// We use ScfToSPIRVContext to store information about the lowering of the scf 31 /// region that need to be used later on. When we lower scf.for/scf.if we create 32 /// VariableOp to store the results. We need to keep track of the VariableOp 33 /// created as we need to insert stores into them when lowering Yield. Those 34 /// StoreOp cannot be created earlier as they may use a different type than 35 /// yield operands. 36 ScfToSPIRVContext::ScfToSPIRVContext() { 37 impl = std::make_unique<ScfToSPIRVContextImpl>(); 38 } 39 ScfToSPIRVContext::~ScfToSPIRVContext() = default; 40 41 namespace { 42 /// Common class for all vector to GPU patterns. 43 template <typename OpTy> 44 class SCFToSPIRVPattern : public SPIRVOpLowering<OpTy> { 45 public: 46 SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter, 47 ScfToSPIRVContextImpl *scfToSPIRVContext) 48 : SPIRVOpLowering<OpTy>::SPIRVOpLowering(context, converter), 49 scfToSPIRVContext(scfToSPIRVContext) {} 50 51 protected: 52 ScfToSPIRVContextImpl *scfToSPIRVContext; 53 }; 54 55 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp. 56 class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> { 57 public: 58 using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern; 59 60 LogicalResult 61 matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands, 62 ConversionPatternRewriter &rewriter) const override; 63 }; 64 65 /// Pattern to convert a scf::IfOp within kernel functions into 66 /// spirv::SelectionOp. 67 class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> { 68 public: 69 using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern; 70 71 LogicalResult 72 matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands, 73 ConversionPatternRewriter &rewriter) const override; 74 }; 75 76 class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> { 77 public: 78 using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern; 79 80 LogicalResult 81 matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands, 82 ConversionPatternRewriter &rewriter) const override; 83 }; 84 } // namespace 85 86 /// Helper function to replaces SCF op outputs with SPIR-V variable loads. 87 /// We create VariableOp to handle the results value of the control flow region. 88 /// spv.loop/spv.selection currently don't yield value. Right after the loop 89 /// we load the value from the allocation and use it as the SCF op result. 90 template <typename ScfOp, typename OpTy> 91 static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp, 92 SPIRVTypeConverter &typeConverter, 93 ConversionPatternRewriter &rewriter, 94 ScfToSPIRVContextImpl *scfToSPIRVContext, 95 ArrayRef<Type> returnTypes) { 96 97 Location loc = scfOp.getLoc(); 98 auto &allocas = scfToSPIRVContext->outputVars[newOp]; 99 SmallVector<Value, 8> resultValue; 100 for (Type convertedType : returnTypes) { 101 auto pointerType = 102 spirv::PointerType::get(convertedType, spirv::StorageClass::Function); 103 rewriter.setInsertionPoint(newOp); 104 auto alloc = rewriter.create<spirv::VariableOp>( 105 loc, pointerType, spirv::StorageClass::Function, 106 /*initializer=*/nullptr); 107 allocas.push_back(alloc); 108 rewriter.setInsertionPointAfter(newOp); 109 Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc); 110 resultValue.push_back(loadResult); 111 } 112 rewriter.replaceOp(scfOp, resultValue); 113 } 114 115 //===----------------------------------------------------------------------===// 116 // scf::ForOp. 117 //===----------------------------------------------------------------------===// 118 119 LogicalResult 120 ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands, 121 ConversionPatternRewriter &rewriter) const { 122 // scf::ForOp can be lowered to the structured control flow represented by 123 // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop 124 // latch and the merge block the exit block. The resulting spirv::LoopOp has a 125 // single back edge from the continue to header block, and a single exit from 126 // header to merge. 127 scf::ForOpAdaptor forOperands(operands); 128 auto loc = forOp.getLoc(); 129 auto loopControl = rewriter.getI32IntegerAttr( 130 static_cast<uint32_t>(spirv::LoopControl::None)); 131 auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl); 132 loopOp.addEntryAndMergeBlock(); 133 134 OpBuilder::InsertionGuard guard(rewriter); 135 // Create the block for the header. 136 auto *header = new Block(); 137 // Insert the header. 138 loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header); 139 140 // Create the new induction variable to use. 141 BlockArgument newIndVar = 142 header->addArgument(forOperands.lowerBound().getType()); 143 for (Value arg : forOperands.initArgs()) 144 header->addArgument(arg.getType()); 145 Block *body = forOp.getBody(); 146 147 // Apply signature conversion to the body of the forOp. It has a single block, 148 // with argument which is the induction variable. That has to be replaced with 149 // the new induction variable. 150 TypeConverter::SignatureConversion signatureConverter( 151 body->getNumArguments()); 152 signatureConverter.remapInput(0, newIndVar); 153 for (unsigned i = 1, e = body->getNumArguments(); i < e; i++) 154 signatureConverter.remapInput(i, header->getArgument(i)); 155 body = rewriter.applySignatureConversion(&forOp.getLoopBody(), 156 signatureConverter); 157 158 // Move the blocks from the forOp into the loopOp. This is the body of the 159 // loopOp. 160 rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(), 161 std::next(loopOp.body().begin(), 2)); 162 163 SmallVector<Value, 8> args(1, forOperands.lowerBound()); 164 args.append(forOperands.initArgs().begin(), forOperands.initArgs().end()); 165 // Branch into it from the entry. 166 rewriter.setInsertionPointToEnd(&(loopOp.body().front())); 167 rewriter.create<spirv::BranchOp>(loc, header, args); 168 169 // Generate the rest of the loop header. 170 rewriter.setInsertionPointToEnd(header); 171 auto *mergeBlock = loopOp.getMergeBlock(); 172 auto cmpOp = rewriter.create<spirv::SLessThanOp>( 173 loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound()); 174 175 rewriter.create<spirv::BranchConditionalOp>( 176 loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>()); 177 178 // Generate instructions to increment the step of the induction variable and 179 // branch to the header. 180 Block *continueBlock = loopOp.getContinueBlock(); 181 rewriter.setInsertionPointToEnd(continueBlock); 182 183 // Add the step to the induction variable and branch to the header. 184 Value updatedIndVar = rewriter.create<spirv::IAddOp>( 185 loc, newIndVar.getType(), newIndVar, forOperands.step()); 186 rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar); 187 188 // Infer the return types from the init operands. Vector type may get 189 // converted to CooperativeMatrix or to Vector type, to avoid having complex 190 // extra logic to figure out the right type we just infer it from the Init 191 // operands. 192 SmallVector<Type, 8> initTypes; 193 for (auto arg : forOperands.initArgs()) 194 initTypes.push_back(arg.getType()); 195 replaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter, 196 scfToSPIRVContext, initTypes); 197 return success(); 198 } 199 200 //===----------------------------------------------------------------------===// 201 // scf::IfOp. 202 //===----------------------------------------------------------------------===// 203 204 LogicalResult 205 IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands, 206 ConversionPatternRewriter &rewriter) const { 207 // When lowering `scf::IfOp` we explicitly create a selection header block 208 // before the control flow diverges and a merge block where control flow 209 // subsequently converges. 210 scf::IfOpAdaptor ifOperands(operands); 211 auto loc = ifOp.getLoc(); 212 213 // Create `spv.selection` operation, selection header block and merge block. 214 auto selectionControl = rewriter.getI32IntegerAttr( 215 static_cast<uint32_t>(spirv::SelectionControl::None)); 216 auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl); 217 selectionOp.addMergeBlock(); 218 auto *mergeBlock = selectionOp.getMergeBlock(); 219 220 OpBuilder::InsertionGuard guard(rewriter); 221 auto *selectionHeaderBlock = new Block(); 222 selectionOp.body().getBlocks().push_front(selectionHeaderBlock); 223 224 // Inline `then` region before the merge block and branch to it. 225 auto &thenRegion = ifOp.thenRegion(); 226 auto *thenBlock = &thenRegion.front(); 227 rewriter.setInsertionPointToEnd(&thenRegion.back()); 228 rewriter.create<spirv::BranchOp>(loc, mergeBlock); 229 rewriter.inlineRegionBefore(thenRegion, mergeBlock); 230 231 auto *elseBlock = mergeBlock; 232 // If `else` region is not empty, inline that region before the merge block 233 // and branch to it. 234 if (!ifOp.elseRegion().empty()) { 235 auto &elseRegion = ifOp.elseRegion(); 236 elseBlock = &elseRegion.front(); 237 rewriter.setInsertionPointToEnd(&elseRegion.back()); 238 rewriter.create<spirv::BranchOp>(loc, mergeBlock); 239 rewriter.inlineRegionBefore(elseRegion, mergeBlock); 240 } 241 242 // Create a `spv.BranchConditional` operation for selection header block. 243 rewriter.setInsertionPointToEnd(selectionHeaderBlock); 244 rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(), 245 thenBlock, ArrayRef<Value>(), 246 elseBlock, ArrayRef<Value>()); 247 248 SmallVector<Type, 8> returnTypes; 249 for (auto result : ifOp.results()) { 250 auto convertedType = typeConverter.convertType(result.getType()); 251 returnTypes.push_back(convertedType); 252 } 253 replaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter, 254 scfToSPIRVContext, returnTypes); 255 return success(); 256 } 257 258 /// Yield is lowered to stores to the VariableOp created during lowering of the 259 /// parent region. For loops we also need to update the branch looping back to 260 /// the header with the loop carried values. 261 LogicalResult TerminatorOpConversion::matchAndRewrite( 262 scf::YieldOp terminatorOp, ArrayRef<Value> operands, 263 ConversionPatternRewriter &rewriter) const { 264 // If the region is return values, store each value into the associated 265 // VariableOp created during lowering of the parent region. 266 if (!operands.empty()) { 267 auto loc = terminatorOp.getLoc(); 268 auto &allocas = scfToSPIRVContext->outputVars[terminatorOp.getParentOp()]; 269 assert(allocas.size() == operands.size()); 270 for (unsigned i = 0, e = operands.size(); i < e; i++) 271 rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]); 272 if (isa<spirv::LoopOp>(terminatorOp.getParentOp())) { 273 // For loops we also need to update the branch jumping back to the header. 274 auto br = 275 cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator()); 276 SmallVector<Value, 8> args(br.getBlockArguments()); 277 args.append(operands.begin(), operands.end()); 278 rewriter.setInsertionPoint(br); 279 rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(), 280 args); 281 rewriter.eraseOp(br); 282 } 283 } 284 rewriter.eraseOp(terminatorOp); 285 return success(); 286 } 287 288 void mlir::populateSCFToSPIRVPatterns(MLIRContext *context, 289 SPIRVTypeConverter &typeConverter, 290 ScfToSPIRVContext &scfToSPIRVContext, 291 OwningRewritePatternList &patterns) { 292 patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>( 293 context, typeConverter, scfToSPIRVContext.getImpl()); 294 } 295