1 //===- LinalgToSPIRV.cpp - Linalg to SPIR-V Patterns ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h" 10 #include "mlir/Dialect/Linalg/IR/Linalg.h" 11 #include "mlir/Dialect/Linalg/Utils/Utils.h" 12 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 14 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 17 #include "mlir/IR/AffineExpr.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 20 using namespace mlir; 21 22 //===----------------------------------------------------------------------===// 23 // Utilities 24 //===----------------------------------------------------------------------===// 25 26 /// Returns a `Value` containing the `dim`-th component of the SPIR-V 27 /// local invocation ID. This function will create necessary operations with 28 /// `builder` at the proper region containing `op`. 
29 static Value getLocalInvocationDimSize(Operation *op, int dim, Type integerType, 30 Location loc, OpBuilder *builder) { 31 assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions"); 32 Value invocation = spirv::getBuiltinVariableValue( 33 op, spirv::BuiltIn::LocalInvocationId, integerType, *builder); 34 Type xType = invocation.getType().cast<ShapedType>().getElementType(); 35 return builder->create<spirv::CompositeExtractOp>( 36 loc, xType, invocation, builder->getI32ArrayAttr({dim})); 37 } 38 39 //===----------------------------------------------------------------------===// 40 // Reduction (single workgroup) 41 //===----------------------------------------------------------------------===// 42 43 namespace { 44 45 /// A pattern to convert a linalg.generic op to SPIR-V ops under the condition 46 /// that the linalg.generic op is performing reduction with a workload size that 47 /// can fit in one workgroup. 48 struct SingleWorkgroupReduction final 49 : public OpConversionPattern<linalg::GenericOp> { 50 using OpConversionPattern::OpConversionPattern; 51 52 /// Matches the given linalg.generic op as performing reduction and returns 53 /// the binary op kind if successful. 54 static Optional<linalg::RegionMatcher::BinaryOpKind> 55 matchAsPerformingReduction(linalg::GenericOp genericOp); 56 57 LogicalResult 58 matchAndRewrite(linalg::GenericOp genericOp, OpAdaptor adaptor, 59 ConversionPatternRewriter &rewriter) const override; 60 }; 61 62 } // namespace 63 64 Optional<linalg::RegionMatcher::BinaryOpKind> 65 SingleWorkgroupReduction::matchAsPerformingReduction( 66 linalg::GenericOp genericOp) { 67 Operation *op = genericOp.getOperation(); 68 69 // Make sure the linalg.generic is working on memrefs. 70 if (!genericOp.hasBufferSemantics()) 71 return llvm::None; 72 73 // Make sure this is reduction with one input and one output. 
74 if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1) 75 return llvm::None; 76 77 auto originalInputType = op->getOperand(0).getType().cast<MemRefType>(); 78 auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>(); 79 80 // Make sure the original input has one dimension. 81 if (!originalInputType.hasStaticShape() || originalInputType.getRank() != 1) 82 return llvm::None; 83 // Make sure the original output has one element. 84 if (!originalOutputType.hasStaticShape() || 85 originalOutputType.getNumElements() != 1) 86 return llvm::None; 87 88 if (!genericOp.hasSingleReductionLoop()) 89 return llvm::None; 90 91 if (genericOp.indexing_maps().getValue().size() != 2) 92 return llvm::None; 93 94 // TODO: create utility functions for these checks in Linalg 95 // and use them. 96 auto inputMap = genericOp.indexing_maps().getValue()[0].cast<AffineMapAttr>(); 97 auto outputMap = 98 genericOp.indexing_maps().getValue()[1].cast<AffineMapAttr>(); 99 // The indexing map for the input should be `(i) -> (i)`. 100 if (inputMap.getValue() != 101 AffineMap::get(1, 0, getAffineDimExpr(0, op->getContext()))) 102 return llvm::None; 103 // The indexing map for the input should be `(i) -> (0)`. 
104 if (outputMap.getValue() != 105 AffineMap::get(1, 0, getAffineConstantExpr(0, op->getContext()))) 106 return llvm::None; 107 108 return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp); 109 } 110 111 LogicalResult SingleWorkgroupReduction::matchAndRewrite( 112 linalg::GenericOp genericOp, OpAdaptor adaptor, 113 ConversionPatternRewriter &rewriter) const { 114 Operation *op = genericOp.getOperation(); 115 auto originalInputType = op->getOperand(0).getType().cast<MemRefType>(); 116 auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>(); 117 118 auto binaryOpKind = matchAsPerformingReduction(genericOp); 119 if (!binaryOpKind) 120 return failure(); 121 122 // Query the shader interface for local workgroup size to make sure the 123 // invocation configuration fits with the input memref's shape. 124 DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp); 125 if (!localSize) 126 return failure(); 127 128 if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0)) 129 return failure(); 130 if (llvm::any_of(llvm::drop_begin(localSize.getValues<APInt>(), 1), 131 [](const APInt &size) { return !size.isOneValue(); })) 132 return failure(); 133 134 // TODO: Query the target environment to make sure the current 135 // workload fits in a local workgroup. 136 137 Value convertedInput = adaptor.getOperands()[0]; 138 Value convertedOutput = adaptor.getOperands()[1]; 139 Location loc = genericOp.getLoc(); 140 141 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); 142 auto indexType = typeConverter->getIndexType(); 143 144 // Get the invocation ID. 145 Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, indexType, loc, 146 &rewriter); 147 148 // TODO: Load to Workgroup storage class first. 149 150 151 // Get the input element accessed by this invocation. 
152 Value inputElementPtr = spirv::getElementPtr( 153 *typeConverter, originalInputType, convertedInput, {x}, loc, rewriter); 154 Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr); 155 156 // Perform the group reduction operation. 157 Value groupOperation; 158 #define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp) \ 159 case linalg::RegionMatcher::BinaryOpKind::opKind: { \ 160 groupOperation = rewriter.create<spirv::spvOp>( \ 161 loc, originalInputType.getElementType(), spirv::Scope::Subgroup, \ 162 spirv::GroupOperation::Reduce, inputElement, \ 163 /*cluster_size=*/nullptr); \ 164 } break 165 switch (*binaryOpKind) { 166 CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp); 167 } 168 #undef CREATE_GROUP_NON_UNIFORM_BIN_OP 169 170 // Get the output element accessed by this reduction. 171 Value zero = spirv::ConstantOp::getZero(indexType, loc, rewriter); 172 SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero); 173 Value outputElementPtr = 174 spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput, 175 zeroIndices, loc, rewriter); 176 177 // Write out the final reduction result. This should be only conducted by one 178 // invocation. We use spv.GroupNonUniformElect to find the invocation with the 179 // lowest ID. 180 // 181 // ``` 182 // if (spv.GroupNonUniformElect) { output = ... 
} 183 // ``` 184 185 Value condition = rewriter.create<spirv::GroupNonUniformElectOp>( 186 loc, spirv::Scope::Subgroup); 187 188 auto createAtomicOp = [&](OpBuilder &builder) { 189 #define CREATE_ATOMIC_BIN_OP(opKind, spvOp) \ 190 case linalg::RegionMatcher::BinaryOpKind::opKind: { \ 191 builder.create<spirv::spvOp>(loc, outputElementPtr, spirv::Scope::Device, \ 192 spirv::MemorySemantics::AcquireRelease, \ 193 groupOperation); \ 194 } break 195 switch (*binaryOpKind) { CREATE_ATOMIC_BIN_OP(IAdd, AtomicIAddOp); } 196 #undef CREATE_ATOMIC_BIN_OP 197 }; 198 199 spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, rewriter); 200 201 rewriter.eraseOp(genericOp); 202 return success(); 203 } 204 205 //===----------------------------------------------------------------------===// 206 // Pattern population 207 //===----------------------------------------------------------------------===// 208 209 void mlir::populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 210 RewritePatternSet &patterns) { 211 patterns.add<SingleWorkgroupReduction>(typeConverter, patterns.getContext()); 212 } 213