1 //===- LinalgToSPIRV.cpp - Linalg to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h"
10 #include "mlir/Dialect/Linalg/IR/Linalg.h"
11 #include "mlir/Dialect/Linalg/Utils/Utils.h"
12 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
17 #include "mlir/IR/AffineExpr.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 
20 using namespace mlir;
21 
22 //===----------------------------------------------------------------------===//
23 // Utilities
24 //===----------------------------------------------------------------------===//
25 
26 /// Returns a `Value` containing the `dim`-th dimension's size of SPIR-V
27 /// location invocation ID. This function will create necessary operations with
28 /// `builder` at the proper region containing `op`.
29 static Value getLocalInvocationDimSize(Operation *op, int dim, Type integerType,
30                                        Location loc, OpBuilder *builder) {
31   assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions");
32   Value invocation = spirv::getBuiltinVariableValue(
33       op, spirv::BuiltIn::LocalInvocationId, integerType, *builder);
34   Type xType = invocation.getType().cast<ShapedType>().getElementType();
35   return builder->create<spirv::CompositeExtractOp>(
36       loc, xType, invocation, builder->getI32ArrayAttr({dim}));
37 }
38 
39 //===----------------------------------------------------------------------===//
40 // Reduction (single workgroup)
41 //===----------------------------------------------------------------------===//
42 
namespace {

/// A pattern to convert a linalg.generic op to SPIR-V ops under the condition
/// that the linalg.generic op is performing reduction with a workload size that
/// can fit in one workgroup.
struct SingleWorkgroupReduction final
    : public OpConversionPattern<linalg::GenericOp> {
  using OpConversionPattern::OpConversionPattern;

  /// Matches the given linalg.generic op as performing reduction and returns
  /// the binary op kind if successful. Returns llvm::None when the op does not
  /// fit the expected shape (one static 1-D memref input reduced into a
  /// one-element memref output via a single reduction loop).
  static Optional<linalg::RegionMatcher::BinaryOpKind>
  matchAsPerformingReduction(linalg::GenericOp genericOp);

  /// Rewrites a matched linalg.generic reduction into SPIR-V group-non-uniform
  /// and atomic ops; fails if the op does not match or the local workgroup
  /// size does not line up with the input shape.
  LogicalResult
  matchAndRewrite(linalg::GenericOp genericOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace
63 
64 Optional<linalg::RegionMatcher::BinaryOpKind>
65 SingleWorkgroupReduction::matchAsPerformingReduction(
66     linalg::GenericOp genericOp) {
67   Operation *op = genericOp.getOperation();
68 
69   // Make sure the linalg.generic is working on memrefs.
70   if (!genericOp.hasBufferSemantics())
71     return llvm::None;
72 
73   // Make sure this is reduction with one input and one output.
74   if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1)
75     return llvm::None;
76 
77   auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
78   auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();
79 
80   // Make sure the original input has one dimension.
81   if (!originalInputType.hasStaticShape() || originalInputType.getRank() != 1)
82     return llvm::None;
83   // Make sure the original output has one element.
84   if (!originalOutputType.hasStaticShape() ||
85       originalOutputType.getNumElements() != 1)
86     return llvm::None;
87 
88   if (!genericOp.hasSingleReductionLoop())
89     return llvm::None;
90 
91   if (genericOp.indexing_maps().getValue().size() != 2)
92     return llvm::None;
93 
94   // TODO: create utility functions for these checks in Linalg
95   // and use them.
96   auto inputMap = genericOp.indexing_maps().getValue()[0].cast<AffineMapAttr>();
97   auto outputMap =
98       genericOp.indexing_maps().getValue()[1].cast<AffineMapAttr>();
99   // The indexing map for the input should be `(i) -> (i)`.
100   if (inputMap.getValue() !=
101       AffineMap::get(1, 0, getAffineDimExpr(0, op->getContext())))
102     return llvm::None;
103   // The indexing map for the input should be `(i) -> (0)`.
104   if (outputMap.getValue() !=
105       AffineMap::get(1, 0, getAffineConstantExpr(0, op->getContext())))
106     return llvm::None;
107 
108   return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp);
109 }
110 
LogicalResult SingleWorkgroupReduction::matchAndRewrite(
    linalg::GenericOp genericOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Operation *op = genericOp.getOperation();
  // Operand 0 is the input memref, operand 1 the output memref; both casts
  // are guarded by the matchAsPerformingReduction() check below.
  auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
  auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();

  auto binaryOpKind = matchAsPerformingReduction(genericOp);
  if (!binaryOpKind)
    return failure();

  // Query the shader interface for local workgroup size to make sure the
  // invocation configuration fits with the input memref's shape.
  DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp);
  if (!localSize)
    return failure();

  // The x dimension of the workgroup must equal the number of input elements
  // so that each invocation handles exactly one element.
  if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0))
    return failure();
  // The remaining (y, z) dimensions must all be 1.
  if (llvm::any_of(llvm::drop_begin(localSize.getValues<APInt>(), 1),
                   [](const APInt &size) { return !size.isOneValue(); }))
    return failure();

  // TODO: Query the target environment to make sure the current
  // workload fits in a local workgroup.

  // Use the already-converted (SPIR-V-typed) operands from the adaptor.
  Value convertedInput = adaptor.getOperands()[0];
  Value convertedOutput = adaptor.getOperands()[1];
  Location loc = genericOp.getLoc();

  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
  auto indexType = typeConverter->getIndexType();

  // Get the invocation ID.
  Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, indexType, loc,
                                      &rewriter);

  // TODO: Load to Workgroup storage class first.


  // Get the input element accessed by this invocation.
  Value inputElementPtr = spirv::getElementPtr(
      *typeConverter, originalInputType, convertedInput, {x}, loc, rewriter);
  Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr);

  // Perform the group reduction operation. The macro expands to one switch
  // case mapping a matched Linalg binary op kind to the corresponding SPIR-V
  // subgroup reduce op (currently only integer addition is handled).
  Value groupOperation;
#define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp)                         \
  case linalg::RegionMatcher::BinaryOpKind::opKind: {                          \
    groupOperation = rewriter.create<spirv::spvOp>(                            \
        loc, originalInputType.getElementType(), spirv::Scope::Subgroup,       \
        spirv::GroupOperation::Reduce, inputElement,                           \
        /*cluster_size=*/nullptr);                                             \
  } break
  switch (*binaryOpKind) {
    CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp);
  }
#undef CREATE_GROUP_NON_UNIFORM_BIN_OP

  // Get the output element accessed by this reduction. The output has exactly
  // one element, so every index is the constant zero.
  Value zero = spirv::ConstantOp::getZero(indexType, loc, rewriter);
  SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero);
  Value outputElementPtr =
      spirv::getElementPtr(*typeConverter, originalOutputType, convertedOutput,
                           zeroIndices, loc, rewriter);

  // Write out the final reduction result. This should be only conducted by one
  // invocation. We use spv.GroupNonUniformElect to find the invocation with the
  // lowest ID.
  //
  // ```
  // if (spv.GroupNonUniformElect) { output = ... }
  // ```

  Value condition = rewriter.create<spirv::GroupNonUniformElectOp>(
      loc, spirv::Scope::Subgroup);

  // The store is an atomic read-modify-write so multiple subgroups within the
  // workgroup can each accumulate their partial result into the output.
  auto createAtomicOp = [&](OpBuilder &builder) {
#define CREATE_ATOMIC_BIN_OP(opKind, spvOp)                                    \
  case linalg::RegionMatcher::BinaryOpKind::opKind: {                          \
    builder.create<spirv::spvOp>(loc, outputElementPtr, spirv::Scope::Device,  \
                                 spirv::MemorySemantics::AcquireRelease,       \
                                 groupOperation);                              \
  } break
    switch (*binaryOpKind) { CREATE_ATOMIC_BIN_OP(IAdd, AtomicIAddOp); }
#undef CREATE_ATOMIC_BIN_OP
  };

  spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, rewriter);

  // The linalg.generic op is fully replaced by the ops built above.
  rewriter.eraseOp(genericOp);
  return success();
}
204 
205 //===----------------------------------------------------------------------===//
206 // Pattern population
207 //===----------------------------------------------------------------------===//
208 
209 void mlir::populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
210                                          RewritePatternSet &patterns) {
211   patterns.add<SingleWorkgroupReduction>(typeConverter, patterns.getContext());
212 }
213