1 //===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the folders and canonicalization patterns for SPIR-V ops.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include <utility>
14
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16
17 #include "mlir/Dialect/CommonFolders.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22
23 using namespace mlir;
24
25 //===----------------------------------------------------------------------===//
26 // Common utility functions
27 //===----------------------------------------------------------------------===//
28
29 /// Returns the boolean value under the hood if the given `boolAttr` is a scalar
30 /// or splat vector bool constant.
getScalarOrSplatBoolAttr(Attribute boolAttr)31 static Optional<bool> getScalarOrSplatBoolAttr(Attribute boolAttr) {
32 if (!boolAttr)
33 return llvm::None;
34
35 auto type = boolAttr.getType();
36 if (type.isInteger(1)) {
37 auto attr = boolAttr.cast<BoolAttr>();
38 return attr.getValue();
39 }
40 if (auto vecType = type.cast<VectorType>()) {
41 if (vecType.getElementType().isInteger(1))
42 if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
43 return attr.getSplatValue<bool>();
44 }
45 return llvm::None;
46 }
47
48 // Extracts an element from the given `composite` by following the given
49 // `indices`. Returns a null Attribute if error happens.
extractCompositeElement(Attribute composite,ArrayRef<unsigned> indices)50 static Attribute extractCompositeElement(Attribute composite,
51 ArrayRef<unsigned> indices) {
52 // Check that given composite is a constant.
53 if (!composite)
54 return {};
55 // Return composite itself if we reach the end of the index chain.
56 if (indices.empty())
57 return composite;
58
59 if (auto vector = composite.dyn_cast<ElementsAttr>()) {
60 assert(indices.size() == 1 && "must have exactly one index for a vector");
61 return vector.getValues<Attribute>()[indices[0]];
62 }
63
64 if (auto array = composite.dyn_cast<ArrayAttr>()) {
65 assert(!indices.empty() && "must have at least one index for an array");
66 return extractCompositeElement(array.getValue()[indices[0]],
67 indices.drop_front());
68 }
69
70 return {};
71 }
72
73 //===----------------------------------------------------------------------===//
74 // TableGen'erated canonicalizers
75 //===----------------------------------------------------------------------===//
76
77 namespace {
78 #include "SPIRVCanonicalization.inc"
79 } // namespace
80
81 //===----------------------------------------------------------------------===//
82 // spv.AccessChainOp
83 //===----------------------------------------------------------------------===//
84
85 namespace {
86
87 /// Combines chained `spirv::AccessChainOp` operations into one
88 /// `spirv::AccessChainOp` operation.
89 struct CombineChainedAccessChain
90 : public OpRewritePattern<spirv::AccessChainOp> {
91 using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
92
matchAndRewrite__anon884ef7a50211::CombineChainedAccessChain93 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
94 PatternRewriter &rewriter) const override {
95 auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
96 accessChainOp.base_ptr().getDefiningOp());
97
98 if (!parentAccessChainOp) {
99 return failure();
100 }
101
102 // Combine indices.
103 SmallVector<Value, 4> indices(parentAccessChainOp.indices());
104 indices.append(accessChainOp.indices().begin(),
105 accessChainOp.indices().end());
106
107 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
108 accessChainOp, parentAccessChainOp.base_ptr(), indices);
109
110 return success();
111 }
112 };
113 } // namespace
114
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)115 void spirv::AccessChainOp::getCanonicalizationPatterns(
116 RewritePatternSet &results, MLIRContext *context) {
117 results.add<CombineChainedAccessChain>(context);
118 }
119
120 //===----------------------------------------------------------------------===//
121 // spv.BitcastOp
122 //===----------------------------------------------------------------------===//
123
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)124 void spirv::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
125 MLIRContext *context) {
126 results.add<ConvertChainedBitcast>(context);
127 }
128
129 //===----------------------------------------------------------------------===//
130 // spv.CompositeExtractOp
131 //===----------------------------------------------------------------------===//
132
fold(ArrayRef<Attribute> operands)133 OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
134 assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
135 auto indexVector =
136 llvm::to_vector<8>(llvm::map_range(indices(), [](Attribute attr) {
137 return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
138 }));
139 return extractCompositeElement(operands[0], indexVector);
140 }
141
142 //===----------------------------------------------------------------------===//
143 // spv.Constant
144 //===----------------------------------------------------------------------===//
145
fold(ArrayRef<Attribute> operands)146 OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
147 assert(operands.empty() && "spv.Constant has no operands");
148 return value();
149 }
150
151 //===----------------------------------------------------------------------===//
152 // spv.IAdd
153 //===----------------------------------------------------------------------===//
154
fold(ArrayRef<Attribute> operands)155 OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
156 assert(operands.size() == 2 && "spv.IAdd expects two operands");
157 // x + 0 = x
158 if (matchPattern(operand2(), m_Zero()))
159 return operand1();
160
161 // According to the SPIR-V spec:
162 //
163 // The resulting value will equal the low-order N bits of the correct result
164 // R, where N is the component width and R is computed with enough precision
165 // to avoid overflow and underflow.
166 return constFoldBinaryOp<IntegerAttr>(
167 operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
168 }
169
170 //===----------------------------------------------------------------------===//
171 // spv.IMul
172 //===----------------------------------------------------------------------===//
173
fold(ArrayRef<Attribute> operands)174 OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
175 assert(operands.size() == 2 && "spv.IMul expects two operands");
176 // x * 0 == 0
177 if (matchPattern(operand2(), m_Zero()))
178 return operand2();
179 // x * 1 = x
180 if (matchPattern(operand2(), m_One()))
181 return operand1();
182
183 // According to the SPIR-V spec:
184 //
185 // The resulting value will equal the low-order N bits of the correct result
186 // R, where N is the component width and R is computed with enough precision
187 // to avoid overflow and underflow.
188 return constFoldBinaryOp<IntegerAttr>(
189 operands, [](const APInt &a, const APInt &b) { return a * b; });
190 }
191
192 //===----------------------------------------------------------------------===//
193 // spv.ISub
194 //===----------------------------------------------------------------------===//
195
fold(ArrayRef<Attribute> operands)196 OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
197 // x - x = 0
198 if (operand1() == operand2())
199 return Builder(getContext()).getIntegerAttr(getType(), 0);
200
201 // According to the SPIR-V spec:
202 //
203 // The resulting value will equal the low-order N bits of the correct result
204 // R, where N is the component width and R is computed with enough precision
205 // to avoid overflow and underflow.
206 return constFoldBinaryOp<IntegerAttr>(
207 operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
208 }
209
210 //===----------------------------------------------------------------------===//
211 // spv.LogicalAnd
212 //===----------------------------------------------------------------------===//
213
fold(ArrayRef<Attribute> operands)214 OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
215 assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");
216
217 if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
218 // x && true = x
219 if (rhs.value())
220 return operand1();
221
222 // x && false = false
223 if (!rhs.value())
224 return operands.back();
225 }
226
227 return Attribute();
228 }
229
230 //===----------------------------------------------------------------------===//
231 // spv.LogicalNot
232 //===----------------------------------------------------------------------===//
233
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)234 void spirv::LogicalNotOp::getCanonicalizationPatterns(
235 RewritePatternSet &results, MLIRContext *context) {
236 results
237 .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
238 ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
239 context);
240 }
241
242 //===----------------------------------------------------------------------===//
243 // spv.LogicalOr
244 //===----------------------------------------------------------------------===//
245
fold(ArrayRef<Attribute> operands)246 OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
247 assert(operands.size() == 2 && "spv.LogicalOr should take two operands");
248
249 if (auto rhs = getScalarOrSplatBoolAttr(operands.back())) {
250 if (rhs.value())
251 // x || true = true
252 return operands.back();
253
254 // x || false = x
255 if (!rhs.value())
256 return operand1();
257 }
258
259 return Attribute();
260 }
261
262 //===----------------------------------------------------------------------===//
263 // spv.mlir.selection
264 //===----------------------------------------------------------------------===//
265
266 namespace {
267 // Blocks from the given `spv.mlir.selection` operation must satisfy the
268 // following layout:
269 //
270 // +-----------------------------------------------+
271 // | header block |
272 // | spv.BranchConditionalOp %cond, ^case0, ^case1 |
273 // +-----------------------------------------------+
274 // / \
275 // ...
276 //
277 //
278 // +------------------------+ +------------------------+
279 // | case #0 | | case #1 |
280 // | spv.Store %ptr %value0 | | spv.Store %ptr %value1 |
281 // | spv.Branch ^merge | | spv.Branch ^merge |
282 // +------------------------+ +------------------------+
283 //
284 //
285 // ...
286 // \ /
287 // v
288 // +-------------+
289 // | merge block |
290 // +-------------+
291 //
292 struct ConvertSelectionOpToSelect
293 : public OpRewritePattern<spirv::SelectionOp> {
294 using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
295
matchAndRewrite__anon884ef7a50711::ConvertSelectionOpToSelect296 LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
297 PatternRewriter &rewriter) const override {
298 auto *op = selectionOp.getOperation();
299 auto &body = op->getRegion(0);
300 // Verifier allows an empty region for `spv.mlir.selection`.
301 if (body.empty()) {
302 return failure();
303 }
304
305 // Check that region consists of 4 blocks:
306 // header block, `true` block, `false` block and merge block.
307 if (std::distance(body.begin(), body.end()) != 4) {
308 return failure();
309 }
310
311 auto *headerBlock = selectionOp.getHeaderBlock();
312 if (!onlyContainsBranchConditionalOp(headerBlock)) {
313 return failure();
314 }
315
316 auto brConditionalOp =
317 cast<spirv::BranchConditionalOp>(headerBlock->front());
318
319 auto *trueBlock = brConditionalOp.getSuccessor(0);
320 auto *falseBlock = brConditionalOp.getSuccessor(1);
321 auto *mergeBlock = selectionOp.getMergeBlock();
322
323 if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
324 return failure();
325
326 auto trueValue = getSrcValue(trueBlock);
327 auto falseValue = getSrcValue(falseBlock);
328 auto ptrValue = getDstPtr(trueBlock);
329 auto storeOpAttributes =
330 cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
331
332 auto selectOp = rewriter.create<spirv::SelectOp>(
333 selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
334 trueValue, falseValue);
335 rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
336 selectOp.getResult(), storeOpAttributes);
337
338 // `spv.mlir.selection` is not needed anymore.
339 rewriter.eraseOp(op);
340 return success();
341 }
342
343 private:
344 // Checks that given blocks follow the following rules:
345 // 1. Each conditional block consists of two operations, the first operation
346 // is a `spv.Store` and the last operation is a `spv.Branch`.
347 // 2. Each `spv.Store` uses the same pointer and the same memory attributes.
348 // 3. A control flow goes into the given merge block from the given
349 // conditional blocks.
350 LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
351 Block *mergeBlock) const;
352
onlyContainsBranchConditionalOp__anon884ef7a50711::ConvertSelectionOpToSelect353 bool onlyContainsBranchConditionalOp(Block *block) const {
354 return std::next(block->begin()) == block->end() &&
355 isa<spirv::BranchConditionalOp>(block->front());
356 }
357
isSameAttrList__anon884ef7a50711::ConvertSelectionOpToSelect358 bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
359 return lhs->getAttrDictionary() == rhs->getAttrDictionary();
360 }
361
362 // Returns a source value for the given block.
getSrcValue__anon884ef7a50711::ConvertSelectionOpToSelect363 Value getSrcValue(Block *block) const {
364 auto storeOp = cast<spirv::StoreOp>(block->front());
365 return storeOp.value();
366 }
367
368 // Returns a destination value for the given block.
getDstPtr__anon884ef7a50711::ConvertSelectionOpToSelect369 Value getDstPtr(Block *block) const {
370 auto storeOp = cast<spirv::StoreOp>(block->front());
371 return storeOp.ptr();
372 }
373 };
374
canCanonicalizeSelection(Block * trueBlock,Block * falseBlock,Block * mergeBlock) const375 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
376 Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
377 // Each block must consists of 2 operations.
378 if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
379 (std::distance(falseBlock->begin(), falseBlock->end()) != 2)) {
380 return failure();
381 }
382
383 auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
384 auto trueBrBranchOp =
385 dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
386 auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
387 auto falseBrBranchOp =
388 dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
389
390 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
391 !falseBrBranchOp) {
392 return failure();
393 }
394
395 // Checks that given type is valid for `spv.SelectOp`.
396 // According to SPIR-V spec:
397 // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
398 // Starting with version 1.4, Result Type can additionally be a composite type
399 // other than a vector."
400 bool isScalarOrVector = trueBrStoreOp.value()
401 .getType()
402 .cast<spirv::SPIRVType>()
403 .isScalarOrVector();
404
405 // Check that each `spv.Store` uses the same pointer, memory access
406 // attributes and a valid type of the value.
407 if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
408 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
409 return failure();
410 }
411
412 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
413 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
414 return failure();
415 }
416
417 return success();
418 }
419 } // namespace
420
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)421 void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
422 MLIRContext *context) {
423 results.add<ConvertSelectionOpToSelect>(context);
424 }
425