1 //===- LinalgToSPIRV.cpp - Linalg to SPIR-V Patterns ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h" 10 #include "mlir/Dialect/Linalg/IR/Linalg.h" 11 #include "mlir/Dialect/Linalg/Utils/Utils.h" 12 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 14 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 15 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 16 #include "mlir/IR/AffineExpr.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 19 using namespace mlir; 20 21 //===----------------------------------------------------------------------===// 22 // Utilities 23 //===----------------------------------------------------------------------===// 24 25 /// Returns a `Value` containing the `dim`-th dimension's size of SPIR-V 26 /// location invocation ID. This function will create necessary operations with 27 /// `builder` at the proper region containing `op`. 28 static Value getLocalInvocationDimSize(Operation *op, int dim, Type integerType, 29 Location loc, OpBuilder *builder) { 30 assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions"); 31 Value invocation = spirv::getBuiltinVariableValue( 32 op, spirv::BuiltIn::LocalInvocationId, integerType, *builder); 33 Type xType = invocation.getType().cast<ShapedType>().getElementType(); 34 return builder->create<spirv::CompositeExtractOp>( 35 loc, xType, invocation, builder->getI32ArrayAttr({dim})); 36 } 37 38 //===----------------------------------------------------------------------===// 39 // Reduction (single workgroup) 40 //===----------------------------------------------------------------------===// 41 42 namespace { 43 44 /// A pattern to convert a linalg.generic op to SPIR-V ops under the condition 45 /// that the linalg.generic op is performing reduction with a workload size that 46 /// can fit in one workgroup. 47 struct SingleWorkgroupReduction final 48 : public OpConversionPattern<linalg::GenericOp> { 49 using OpConversionPattern::OpConversionPattern; 50 51 /// Matches the given linalg.generic op as performing reduction and returns 52 /// the binary op kind if successful. 53 static Optional<linalg::RegionMatcher::BinaryOpKind> 54 matchAsPerformingReduction(linalg::GenericOp genericOp); 55 56 LogicalResult 57 matchAndRewrite(linalg::GenericOp genericOp, OpAdaptor adaptor, 58 ConversionPatternRewriter &rewriter) const override; 59 }; 60 61 } // namespace 62 63 Optional<linalg::RegionMatcher::BinaryOpKind> 64 SingleWorkgroupReduction::matchAsPerformingReduction( 65 linalg::GenericOp genericOp) { 66 Operation *op = genericOp.getOperation(); 67 68 // Make sure the linalg.generic is working on memrefs. 69 if (!genericOp.hasBufferSemantics()) 70 return llvm::None; 71 72 // Make sure this is reduction with one input and one output. 73 if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1) 74 return llvm::None; 75 76 auto originalInputType = op->getOperand(0).getType().cast<MemRefType>(); 77 auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>(); 78 79 // Make sure the original input has one dimension. 80 if (!originalInputType.hasStaticShape() || originalInputType.getRank() != 1) 81 return llvm::None; 82 // Make sure the original output has one element. 83 if (!originalOutputType.hasStaticShape() || 84 originalOutputType.getNumElements() != 1) 85 return llvm::None; 86 87 if (!genericOp.hasSingleReductionLoop()) 88 return llvm::None; 89 90 auto indexingMaps = genericOp.getIndexingMapsArray(); 91 if (indexingMaps.size() != 2) 92 return llvm::None; 93 94 // TODO: create utility functions for these checks in Linalg 95 // and use them. 96 auto inputMap = indexingMaps[0]; 97 auto outputMap = indexingMaps[1]; 98 // The indexing map for the input should be `(i) -> (i)`. 99 if (inputMap != AffineMap::get(1, 0, getAffineDimExpr(0, op->getContext()))) 100 return llvm::None; 101 // The indexing map for the input should be `(i) -> (0)`. 102 if (outputMap != 103 AffineMap::get(1, 0, getAffineConstantExpr(0, op->getContext()))) 104 return llvm::None; 105 106 return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp); 107 } 108 109 LogicalResult SingleWorkgroupReduction::matchAndRewrite( 110 linalg::GenericOp genericOp, OpAdaptor adaptor, 111 ConversionPatternRewriter &rewriter) const { 112 Operation *op = genericOp.getOperation(); 113 auto originalInputType = op->getOperand(0).getType().cast<MemRefType>(); 114 auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>(); 115 116 auto binaryOpKind = matchAsPerformingReduction(genericOp); 117 if (!binaryOpKind) 118 return failure(); 119 120 // Query the shader interface for local workgroup size to make sure the 121 // invocation configuration fits with the input memref's shape. 122 DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp); 123 if (!localSize) 124 return failure(); 125 126 if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0)) 127 return failure(); 128 if (llvm::any_of(llvm::drop_begin(localSize.getValues<APInt>(), 1), 129 [](const APInt &size) { return !size.isOneValue(); })) 130 return failure(); 131 132 // TODO: Query the target environment to make sure the current 133 // workload fits in a local workgroup. 134 135 Value convertedInput = adaptor.getOperands()[0]; 136 Value convertedOutput = adaptor.getOperands()[1]; 137 Location loc = genericOp.getLoc(); 138 139 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); 140 auto indexType = typeConverter->getIndexType(); 141 142 // Get the invocation ID. 143 Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, indexType, loc, 144 &rewriter); 145 146 // TODO: Load to Workgroup storage class first. 147 148 // Get the input element accessed by this invocation. 149 Value inputElementPtr = spirv::getElementPtr( 150 *typeConverter, originalInputType, convertedInput, {x}, loc, rewriter); 151 Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr); 152 153 // Perform the group reduction operation. 154 Value groupOperation; 155 #define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp) \ 156 case linalg::RegionMatcher::BinaryOpKind::opKind: { \ 157 groupOperation = rewriter.create<spirv::spvOp>( \ 158 loc, originalInputType.getElementType(), spirv::Scope::Subgroup, \ 159 spirv::GroupOperation::Reduce, inputElement, \ 160 /*cluster_size=*/nullptr); \ 161 } break 162 switch (*binaryOpKind) { 163 CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp); 164 } 165 #undef CREATE_GROUP_NON_UNIFORM_BIN_OP 166 167 // Get the output element accessed by this reduction. 168 Value zero = spirv::ConstantOp::getZero(indexType, loc, rewriter); 169 SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero); 170 Value outputElementPtr = 171 spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput, 172 zeroIndices, loc, rewriter); 173 174 // Write out the final reduction result. This should be only conducted by one 175 // invocation. We use spv.GroupNonUniformElect to find the invocation with the 176 // lowest ID. 177 // 178 // ``` 179 // if (spv.GroupNonUniformElect) { output = ... } 180 // ``` 181 182 Value condition = rewriter.create<spirv::GroupNonUniformElectOp>( 183 loc, spirv::Scope::Subgroup); 184 185 auto createAtomicOp = [&](OpBuilder &builder) { 186 #define CREATE_ATOMIC_BIN_OP(opKind, spvOp) \ 187 case linalg::RegionMatcher::BinaryOpKind::opKind: { \ 188 builder.create<spirv::spvOp>(loc, outputElementPtr, spirv::Scope::Device, \ 189 spirv::MemorySemantics::AcquireRelease, \ 190 groupOperation); \ 191 } break 192 switch (*binaryOpKind) { CREATE_ATOMIC_BIN_OP(IAdd, AtomicIAddOp); } 193 #undef CREATE_ATOMIC_BIN_OP 194 }; 195 196 spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, rewriter); 197 198 rewriter.eraseOp(genericOp); 199 return success(); 200 } 201 202 //===----------------------------------------------------------------------===// 203 // Pattern population 204 //===----------------------------------------------------------------------===// 205 206 void mlir::populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 207 RewritePatternSet &patterns) { 208 patterns.add<SingleWorkgroupReduction>(typeConverter, patterns.getContext()); 209 } 210