1 //===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the folders and canonicalization patterns for SPIR-V ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <utility>
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 
17 #include "mlir/Dialect/CommonFolders.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 
23 using namespace mlir;
24 
25 //===----------------------------------------------------------------------===//
26 // Common utility functions
27 //===----------------------------------------------------------------------===//
28 
29 /// Returns the boolean value under the hood if the given `boolAttr` is a scalar
30 /// or splat vector bool constant.
getScalarOrSplatBoolAttr(Attribute boolAttr)31 static Optional<bool> getScalarOrSplatBoolAttr(Attribute boolAttr) {
32   if (!boolAttr)
33     return llvm::None;
34 
35   auto type = boolAttr.getType();
36   if (type.isInteger(1)) {
37     auto attr = boolAttr.cast<BoolAttr>();
38     return attr.getValue();
39   }
40   if (auto vecType = type.cast<VectorType>()) {
41     if (vecType.getElementType().isInteger(1))
42       if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
43         return attr.getSplatValue<bool>();
44   }
45   return llvm::None;
46 }
47 
48 // Extracts an element from the given `composite` by following the given
49 // `indices`. Returns a null Attribute if error happens.
extractCompositeElement(Attribute composite,ArrayRef<unsigned> indices)50 static Attribute extractCompositeElement(Attribute composite,
51                                          ArrayRef<unsigned> indices) {
52   // Check that given composite is a constant.
53   if (!composite)
54     return {};
55   // Return composite itself if we reach the end of the index chain.
56   if (indices.empty())
57     return composite;
58 
59   if (auto vector = composite.dyn_cast<ElementsAttr>()) {
60     assert(indices.size() == 1 && "must have exactly one index for a vector");
61     return vector.getValues<Attribute>()[indices[0]];
62   }
63 
64   if (auto array = composite.dyn_cast<ArrayAttr>()) {
65     assert(!indices.empty() && "must have at least one index for an array");
66     return extractCompositeElement(array.getValue()[indices[0]],
67                                    indices.drop_front());
68   }
69 
70   return {};
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // TableGen'erated canonicalizers
75 //===----------------------------------------------------------------------===//
76 
77 namespace {
78 #include "SPIRVCanonicalization.inc"
79 } // namespace
80 
81 //===----------------------------------------------------------------------===//
82 // spv.AccessChainOp
83 //===----------------------------------------------------------------------===//
84 
85 namespace {
86 
87 /// Combines chained `spirv::AccessChainOp` operations into one
88 /// `spirv::AccessChainOp` operation.
89 struct CombineChainedAccessChain
90     : public OpRewritePattern<spirv::AccessChainOp> {
91   using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
92 
matchAndRewrite__anon884ef7a50211::CombineChainedAccessChain93   LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
94                                 PatternRewriter &rewriter) const override {
95     auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
96         accessChainOp.base_ptr().getDefiningOp());
97 
98     if (!parentAccessChainOp) {
99       return failure();
100     }
101 
102     // Combine indices.
103     SmallVector<Value, 4> indices(parentAccessChainOp.indices());
104     indices.append(accessChainOp.indices().begin(),
105                    accessChainOp.indices().end());
106 
107     rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
108         accessChainOp, parentAccessChainOp.base_ptr(), indices);
109 
110     return success();
111   }
112 };
113 } // namespace
114 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)115 void spirv::AccessChainOp::getCanonicalizationPatterns(
116     RewritePatternSet &results, MLIRContext *context) {
117   results.add<CombineChainedAccessChain>(context);
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // spv.BitcastOp
122 //===----------------------------------------------------------------------===//
123 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)124 void spirv::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
125                                                    MLIRContext *context) {
126   results.add<ConvertChainedBitcast>(context);
127 }
128 
129 //===----------------------------------------------------------------------===//
130 // spv.CompositeExtractOp
131 //===----------------------------------------------------------------------===//
132 
fold(ArrayRef<Attribute> operands)133 OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
134   assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
135   auto indexVector =
136       llvm::to_vector<8>(llvm::map_range(indices(), [](Attribute attr) {
137         return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
138       }));
139   return extractCompositeElement(operands[0], indexVector);
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // spv.Constant
144 //===----------------------------------------------------------------------===//
145 
fold(ArrayRef<Attribute> operands)146 OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
147   assert(operands.empty() && "spv.Constant has no operands");
148   return value();
149 }
150 
151 //===----------------------------------------------------------------------===//
152 // spv.IAdd
153 //===----------------------------------------------------------------------===//
154 
fold(ArrayRef<Attribute> operands)155 OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
156   assert(operands.size() == 2 && "spv.IAdd expects two operands");
157   // x + 0 = x
158   if (matchPattern(operand2(), m_Zero()))
159     return operand1();
160 
161   // According to the SPIR-V spec:
162   //
163   // The resulting value will equal the low-order N bits of the correct result
164   // R, where N is the component width and R is computed with enough precision
165   // to avoid overflow and underflow.
166   return constFoldBinaryOp<IntegerAttr>(
167       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
168 }
169 
170 //===----------------------------------------------------------------------===//
171 // spv.IMul
172 //===----------------------------------------------------------------------===//
173 
fold(ArrayRef<Attribute> operands)174 OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
175   assert(operands.size() == 2 && "spv.IMul expects two operands");
176   // x * 0 == 0
177   if (matchPattern(operand2(), m_Zero()))
178     return operand2();
179   // x * 1 = x
180   if (matchPattern(operand2(), m_One()))
181     return operand1();
182 
183   // According to the SPIR-V spec:
184   //
185   // The resulting value will equal the low-order N bits of the correct result
186   // R, where N is the component width and R is computed with enough precision
187   // to avoid overflow and underflow.
188   return constFoldBinaryOp<IntegerAttr>(
189       operands, [](const APInt &a, const APInt &b) { return a * b; });
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // spv.ISub
194 //===----------------------------------------------------------------------===//
195 
fold(ArrayRef<Attribute> operands)196 OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
197   // x - x = 0
198   if (operand1() == operand2())
199     return Builder(getContext()).getIntegerAttr(getType(), 0);
200 
201   // According to the SPIR-V spec:
202   //
203   // The resulting value will equal the low-order N bits of the correct result
204   // R, where N is the component width and R is computed with enough precision
205   // to avoid overflow and underflow.
206   return constFoldBinaryOp<IntegerAttr>(
207       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
208 }
209 
210 //===----------------------------------------------------------------------===//
211 // spv.LogicalAnd
212 //===----------------------------------------------------------------------===//
213 
fold(ArrayRef<Attribute> operands)214 OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
215   assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");
216 
217   if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
218     // x && true = x
219     if (rhs.value())
220       return operand1();
221 
222     // x && false = false
223     if (!rhs.value())
224       return operands.back();
225   }
226 
227   return Attribute();
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // spv.LogicalNot
232 //===----------------------------------------------------------------------===//
233 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)234 void spirv::LogicalNotOp::getCanonicalizationPatterns(
235     RewritePatternSet &results, MLIRContext *context) {
236   results
237       .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
238            ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
239           context);
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // spv.LogicalOr
244 //===----------------------------------------------------------------------===//
245 
fold(ArrayRef<Attribute> operands)246 OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
247   assert(operands.size() == 2 && "spv.LogicalOr should take two operands");
248 
249   if (auto rhs = getScalarOrSplatBoolAttr(operands.back())) {
250     if (rhs.value())
251       // x || true = true
252       return operands.back();
253 
254     // x || false = x
255     if (!rhs.value())
256       return operand1();
257   }
258 
259   return Attribute();
260 }
261 
262 //===----------------------------------------------------------------------===//
263 // spv.mlir.selection
264 //===----------------------------------------------------------------------===//
265 
266 namespace {
267 // Blocks from the given `spv.mlir.selection` operation must satisfy the
268 // following layout:
269 //
270 //       +-----------------------------------------------+
271 //       | header block                                  |
272 //       | spv.BranchConditionalOp %cond, ^case0, ^case1 |
273 //       +-----------------------------------------------+
274 //                            /   \
275 //                             ...
276 //
277 //
278 //   +------------------------+    +------------------------+
279 //   | case #0                |    | case #1                |
280 //   | spv.Store %ptr %value0 |    | spv.Store %ptr %value1 |
281 //   | spv.Branch ^merge      |    | spv.Branch ^merge      |
282 //   +------------------------+    +------------------------+
283 //
284 //
285 //                             ...
286 //                            \   /
287 //                              v
288 //                       +-------------+
289 //                       | merge block |
290 //                       +-------------+
291 //
292 struct ConvertSelectionOpToSelect
293     : public OpRewritePattern<spirv::SelectionOp> {
294   using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
295 
matchAndRewrite__anon884ef7a50711::ConvertSelectionOpToSelect296   LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
297                                 PatternRewriter &rewriter) const override {
298     auto *op = selectionOp.getOperation();
299     auto &body = op->getRegion(0);
300     // Verifier allows an empty region for `spv.mlir.selection`.
301     if (body.empty()) {
302       return failure();
303     }
304 
305     // Check that region consists of 4 blocks:
306     // header block, `true` block, `false` block and merge block.
307     if (std::distance(body.begin(), body.end()) != 4) {
308       return failure();
309     }
310 
311     auto *headerBlock = selectionOp.getHeaderBlock();
312     if (!onlyContainsBranchConditionalOp(headerBlock)) {
313       return failure();
314     }
315 
316     auto brConditionalOp =
317         cast<spirv::BranchConditionalOp>(headerBlock->front());
318 
319     auto *trueBlock = brConditionalOp.getSuccessor(0);
320     auto *falseBlock = brConditionalOp.getSuccessor(1);
321     auto *mergeBlock = selectionOp.getMergeBlock();
322 
323     if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
324       return failure();
325 
326     auto trueValue = getSrcValue(trueBlock);
327     auto falseValue = getSrcValue(falseBlock);
328     auto ptrValue = getDstPtr(trueBlock);
329     auto storeOpAttributes =
330         cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
331 
332     auto selectOp = rewriter.create<spirv::SelectOp>(
333         selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
334         trueValue, falseValue);
335     rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
336                                     selectOp.getResult(), storeOpAttributes);
337 
338     // `spv.mlir.selection` is not needed anymore.
339     rewriter.eraseOp(op);
340     return success();
341   }
342 
343 private:
344   // Checks that given blocks follow the following rules:
345   // 1. Each conditional block consists of two operations, the first operation
346   //    is a `spv.Store` and the last operation is a `spv.Branch`.
347   // 2. Each `spv.Store` uses the same pointer and the same memory attributes.
348   // 3. A control flow goes into the given merge block from the given
349   //    conditional blocks.
350   LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
351                                          Block *mergeBlock) const;
352 
onlyContainsBranchConditionalOp__anon884ef7a50711::ConvertSelectionOpToSelect353   bool onlyContainsBranchConditionalOp(Block *block) const {
354     return std::next(block->begin()) == block->end() &&
355            isa<spirv::BranchConditionalOp>(block->front());
356   }
357 
isSameAttrList__anon884ef7a50711::ConvertSelectionOpToSelect358   bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
359     return lhs->getAttrDictionary() == rhs->getAttrDictionary();
360   }
361 
362   // Returns a source value for the given block.
getSrcValue__anon884ef7a50711::ConvertSelectionOpToSelect363   Value getSrcValue(Block *block) const {
364     auto storeOp = cast<spirv::StoreOp>(block->front());
365     return storeOp.value();
366   }
367 
368   // Returns a destination value for the given block.
getDstPtr__anon884ef7a50711::ConvertSelectionOpToSelect369   Value getDstPtr(Block *block) const {
370     auto storeOp = cast<spirv::StoreOp>(block->front());
371     return storeOp.ptr();
372   }
373 };
374 
canCanonicalizeSelection(Block * trueBlock,Block * falseBlock,Block * mergeBlock) const375 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
376     Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
377   // Each block must consists of 2 operations.
378   if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
379       (std::distance(falseBlock->begin(), falseBlock->end()) != 2)) {
380     return failure();
381   }
382 
383   auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
384   auto trueBrBranchOp =
385       dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
386   auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
387   auto falseBrBranchOp =
388       dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
389 
390   if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
391       !falseBrBranchOp) {
392     return failure();
393   }
394 
395   // Checks that given type is valid for `spv.SelectOp`.
396   // According to SPIR-V spec:
397   // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
398   // Starting with version 1.4, Result Type can additionally be a composite type
399   // other than a vector."
400   bool isScalarOrVector = trueBrStoreOp.value()
401                               .getType()
402                               .cast<spirv::SPIRVType>()
403                               .isScalarOrVector();
404 
405   // Check that each `spv.Store` uses the same pointer, memory access
406   // attributes and a valid type of the value.
407   if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
408       !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
409     return failure();
410   }
411 
412   if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
413       (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
414     return failure();
415   }
416 
417   return success();
418 }
419 } // namespace
420 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)421 void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
422                                                      MLIRContext *context) {
423   results.add<ConvertSelectionOpToSelect>(context);
424 }
425