1 //===- SimplifyAffineStructures.cpp ---------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to simplify affine structures.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/AffineStructures.h"
14 #include "mlir/IR/IntegerSet.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/Passes.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Transforms/Utils.h"
21 
22 #define DEBUG_TYPE "simplify-affine-structure"
23 
24 using namespace mlir;
25 
26 namespace {
27 
28 /// Simplifies affine maps and sets appearing in the operations of the Function.
29 /// This part is mainly to test the simplifyAffineExpr method. In addition,
30 /// all memrefs with non-trivial layout maps are converted to ones with trivial
31 /// identity layout ones.
32 struct SimplifyAffineStructures
33     : public FunctionPass<SimplifyAffineStructures> {
34   void runOnFunction() override;
35 
36   /// Utility to simplify an affine attribute and update its entry in the parent
37   /// operation if necessary.
38   template <typename AttributeT>
39   void simplifyAndUpdateAttribute(Operation *op, Identifier name,
40                                   AttributeT attr) {
41     auto &simplified = simplifiedAttributes[attr];
42     if (simplified == attr)
43       return;
44 
45     // This is a newly encountered attribute.
46     if (!simplified) {
47       // Try to simplify the value of the attribute.
48       auto value = attr.getValue();
49       auto simplifiedValue = simplify(value);
50       if (simplifiedValue == value) {
51         simplified = attr;
52         return;
53       }
54       simplified = AttributeT::get(simplifiedValue);
55     }
56 
57     // Simplification was successful, so update the attribute.
58     op->setAttr(name, simplified);
59   }
60 
61   /// Performs basic integer set simplifications. Checks if it's empty, and
62   /// replaces it with the canonical empty set if it is.
63   IntegerSet simplify(IntegerSet set) {
64     FlatAffineConstraints fac(set);
65     if (fac.isEmpty())
66       return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
67                                      &getContext());
68     return set;
69   }
70 
71   /// Performs basic affine map simplifications.
72   AffineMap simplify(AffineMap map) {
73     MutableAffineMap mMap(map);
74     mMap.simplify();
75     return mMap.getAffineMap();
76   }
77 
78   DenseMap<Attribute, Attribute> simplifiedAttributes;
79 };
80 
81 } // end anonymous namespace
82 
83 std::unique_ptr<OpPassBase<FuncOp>> mlir::createSimplifyAffineStructuresPass() {
84   return std::make_unique<SimplifyAffineStructures>();
85 }
86 
87 void SimplifyAffineStructures::runOnFunction() {
88   auto func = getFunction();
89   simplifiedAttributes.clear();
90   func.walk([&](Operation *opInst) {
91     for (auto attr : opInst->getAttrs()) {
92       if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>())
93         simplifyAndUpdateAttribute(opInst, attr.first, mapAttr);
94       else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>())
95         simplifyAndUpdateAttribute(opInst, attr.first, setAttr);
96     }
97   });
98 
99   // Turn memrefs' non-identity layouts maps into ones with identity. Collect
100   // alloc ops first and then process since normalizeMemRef replaces/erases ops
101   // during memref rewriting.
102   SmallVector<AllocOp, 4> allocOps;
103   func.walk([&](AllocOp op) { allocOps.push_back(op); });
104   for (auto allocOp : allocOps) {
105     normalizeMemRef(allocOp);
106   }
107 }
108 
109 static PassRegistration<SimplifyAffineStructures>
110     pass("simplify-affine-structures",
111          "Simplify affine expressions in maps/sets and normalize memrefs");
112