1 //===- LinalgToSPIRV.cpp - Linalg to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h"
10 #include "mlir/Dialect/Linalg/IR/Linalg.h"
11 #include "mlir/Dialect/Linalg/Utils/Utils.h"
12 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
15 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
16 #include "mlir/IR/AffineExpr.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 
19 using namespace mlir;
20 
21 //===----------------------------------------------------------------------===//
22 // Utilities
23 //===----------------------------------------------------------------------===//
24 
25 /// Returns a `Value` containing the `dim`-th dimension's size of SPIR-V
26 /// location invocation ID. This function will create necessary operations with
27 /// `builder` at the proper region containing `op`.
getLocalInvocationDimSize(Operation * op,int dim,Type integerType,Location loc,OpBuilder * builder)28 static Value getLocalInvocationDimSize(Operation *op, int dim, Type integerType,
29                                        Location loc, OpBuilder *builder) {
30   assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions");
31   Value invocation = spirv::getBuiltinVariableValue(
32       op, spirv::BuiltIn::LocalInvocationId, integerType, *builder);
33   Type xType = invocation.getType().cast<ShapedType>().getElementType();
34   return builder->create<spirv::CompositeExtractOp>(
35       loc, xType, invocation, builder->getI32ArrayAttr({dim}));
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // Reduction (single workgroup)
40 //===----------------------------------------------------------------------===//
41 
42 namespace {
43 
44 /// A pattern to convert a linalg.generic op to SPIR-V ops under the condition
45 /// that the linalg.generic op is performing reduction with a workload size that
46 /// can fit in one workgroup.
47 struct SingleWorkgroupReduction final
48     : public OpConversionPattern<linalg::GenericOp> {
49   using OpConversionPattern::OpConversionPattern;
50 
51   /// Matches the given linalg.generic op as performing reduction and returns
52   /// the binary op kind if successful.
53   static Optional<linalg::RegionMatcher::BinaryOpKind>
54   matchAsPerformingReduction(linalg::GenericOp genericOp);
55 
56   LogicalResult
57   matchAndRewrite(linalg::GenericOp genericOp, OpAdaptor adaptor,
58                   ConversionPatternRewriter &rewriter) const override;
59 };
60 
61 } // namespace
62 
63 Optional<linalg::RegionMatcher::BinaryOpKind>
matchAsPerformingReduction(linalg::GenericOp genericOp)64 SingleWorkgroupReduction::matchAsPerformingReduction(
65     linalg::GenericOp genericOp) {
66   Operation *op = genericOp.getOperation();
67 
68   // Make sure the linalg.generic is working on memrefs.
69   if (!genericOp.hasBufferSemantics())
70     return llvm::None;
71 
72   // Make sure this is reduction with one input and one output.
73   if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1)
74     return llvm::None;
75 
76   auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
77   auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();
78 
79   // Make sure the original input has one dimension.
80   if (!originalInputType.hasStaticShape() || originalInputType.getRank() != 1)
81     return llvm::None;
82   // Make sure the original output has one element.
83   if (!originalOutputType.hasStaticShape() ||
84       originalOutputType.getNumElements() != 1)
85     return llvm::None;
86 
87   if (!genericOp.hasSingleReductionLoop())
88     return llvm::None;
89 
90   auto indexingMaps = genericOp.getIndexingMapsArray();
91   if (indexingMaps.size() != 2)
92     return llvm::None;
93 
94   // TODO: create utility functions for these checks in Linalg
95   // and use them.
96   auto inputMap = indexingMaps[0];
97   auto outputMap = indexingMaps[1];
98   // The indexing map for the input should be `(i) -> (i)`.
99   if (inputMap != AffineMap::get(1, 0, getAffineDimExpr(0, op->getContext())))
100     return llvm::None;
101   // The indexing map for the input should be `(i) -> (0)`.
102   if (outputMap !=
103       AffineMap::get(1, 0, getAffineConstantExpr(0, op->getContext())))
104     return llvm::None;
105 
106   return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp);
107 }
108 
matchAndRewrite(linalg::GenericOp genericOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const109 LogicalResult SingleWorkgroupReduction::matchAndRewrite(
110     linalg::GenericOp genericOp, OpAdaptor adaptor,
111     ConversionPatternRewriter &rewriter) const {
112   Operation *op = genericOp.getOperation();
113   auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
114   auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();
115 
116   auto binaryOpKind = matchAsPerformingReduction(genericOp);
117   if (!binaryOpKind)
118     return failure();
119 
120   // Query the shader interface for local workgroup size to make sure the
121   // invocation configuration fits with the input memref's shape.
122   DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp);
123   if (!localSize)
124     return failure();
125 
126   if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0))
127     return failure();
128   if (llvm::any_of(llvm::drop_begin(localSize.getValues<APInt>(), 1),
129                    [](const APInt &size) { return !size.isOneValue(); }))
130     return failure();
131 
132   // TODO: Query the target environment to make sure the current
133   // workload fits in a local workgroup.
134 
135   Value convertedInput = adaptor.getOperands()[0];
136   Value convertedOutput = adaptor.getOperands()[1];
137   Location loc = genericOp.getLoc();
138 
139   auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
140   auto indexType = typeConverter->getIndexType();
141 
142   // Get the invocation ID.
143   Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, indexType, loc,
144                                       &rewriter);
145 
146   // TODO: Load to Workgroup storage class first.
147 
148   // Get the input element accessed by this invocation.
149   Value inputElementPtr = spirv::getElementPtr(
150       *typeConverter, originalInputType, convertedInput, {x}, loc, rewriter);
151   Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr);
152 
153   // Perform the group reduction operation.
154   Value groupOperation;
155 #define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp)                         \
156   case linalg::RegionMatcher::BinaryOpKind::opKind: {                          \
157     groupOperation = rewriter.create<spirv::spvOp>(                            \
158         loc, originalInputType.getElementType(), spirv::Scope::Subgroup,       \
159         spirv::GroupOperation::Reduce, inputElement,                           \
160         /*cluster_size=*/nullptr);                                             \
161   } break
162   switch (*binaryOpKind) {
163     CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp);
164   }
165 #undef CREATE_GROUP_NON_UNIFORM_BIN_OP
166 
167   // Get the output element accessed by this reduction.
168   Value zero = spirv::ConstantOp::getZero(indexType, loc, rewriter);
169   SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero);
170   Value outputElementPtr =
171       spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput,
172                            zeroIndices, loc, rewriter);
173 
174   // Write out the final reduction result. This should be only conducted by one
175   // invocation. We use spv.GroupNonUniformElect to find the invocation with the
176   // lowest ID.
177   //
178   // ```
179   // if (spv.GroupNonUniformElect) { output = ... }
180   // ```
181 
182   Value condition = rewriter.create<spirv::GroupNonUniformElectOp>(
183       loc, spirv::Scope::Subgroup);
184 
185   auto createAtomicOp = [&](OpBuilder &builder) {
186 #define CREATE_ATOMIC_BIN_OP(opKind, spvOp)                                    \
187   case linalg::RegionMatcher::BinaryOpKind::opKind: {                          \
188     builder.create<spirv::spvOp>(loc, outputElementPtr, spirv::Scope::Device,  \
189                                  spirv::MemorySemantics::AcquireRelease,       \
190                                  groupOperation);                              \
191   } break
192     switch (*binaryOpKind) { CREATE_ATOMIC_BIN_OP(IAdd, AtomicIAddOp); }
193 #undef CREATE_ATOMIC_BIN_OP
194   };
195 
196   spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, rewriter);
197 
198   rewriter.eraseOp(genericOp);
199   return success();
200 }
201 
202 //===----------------------------------------------------------------------===//
203 // Pattern population
204 //===----------------------------------------------------------------------===//
205 
populateLinalgToSPIRVPatterns(SPIRVTypeConverter & typeConverter,RewritePatternSet & patterns)206 void mlir::populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
207                                          RewritePatternSet &patterns) {
208   patterns.add<SingleWorkgroupReduction>(typeConverter, patterns.getContext());
209 }
210