1 //===- StructuredOpsUtils.h - Utilities used by structured ops --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This header file defines utilities that operate on builtin types and are
// useful across multiple dialects that use structured ops abstractions. These
// abstractions consist of defining custom operations that encode and transport
// information about their semantics (e.g. type of iterators like parallel,
// reduction, etc.) as attributes.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
18 #define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
19 
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/BuiltinAttributes.h"
22 #include "mlir/IR/Location.h"
23 #include "mlir/Support/LLVM.h"
24 #include "llvm/ADT/StringRef.h"
25 
26 namespace mlir {
27 
28 class OpBuilder;
29 
30 /// Tests whether the given maps describe a row major matmul. The test is
31 /// permutation-invariant. Note that this only checks the affine maps from an
32 /// operation, so does not perform any checks on the math being performed within
33 /// the reduction.
34 bool isRowMajorMatmul(ArrayAttr indexingMaps);
35 
36 /// Tests whether the given maps describe a column major matmul. The test is
37 /// permutation-invariant. Note that this only checks the affine maps from an
38 /// operation, so does not perform any checks on the math being performed within
39 /// the reduction.
40 bool isColumnMajorMatmul(ArrayAttr indexingMaps);
41 
42 /// Tests whether the given maps describe a row major batch matmul. The test is
43 /// permutation-invariant. Note that this only checks the affine maps from an
44 /// operation, so does not perform any checks on the math being performed within
45 /// the reduction.
46 bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
47 
48 /// Attribute name for the AffineArrayAttr which encodes the relationship
49 /// between a structured op iterators' and its operands.
getIndexingMapsAttrName()50 constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; }
51 
52 /// Attribute name for the StrArrayAttr which encodes the type of a structured
53 /// op's iterators.
getIteratorTypesAttrName()54 constexpr StringRef getIteratorTypesAttrName() { return "iterator_types"; }
55 
56 /// Attribute name for the StrArrayAttr which encodes the distribution type for
57 /// `linalg.tiled_loop`.
getDistributionTypesAttrName()58 constexpr StringRef getDistributionTypesAttrName() {
59   return "distribution_types";
60 }
61 
62 /// Attribute name for the StringAttr which encodes an optional documentation
63 /// string of the structured op.
getDocAttrName()64 constexpr StringRef getDocAttrName() { return "doc"; }
65 
66 /// Attribute name for the StrArrayAttr which encodes the external library
67 /// function that implements the structured op.
getLibraryCallAttrName()68 constexpr StringRef getLibraryCallAttrName() { return "library_call"; }
69 
70 /// Attribute name for the StrArrayAttr which encodes the value of strides.
getStridesAttrName()71 constexpr StringRef getStridesAttrName() { return "strides"; }
72 
73 /// Attribute name for the StrArrayAttr which encodes the value of dilations.
getDilationsAttrName()74 constexpr StringRef getDilationsAttrName() { return "dilations"; }
75 
76 /// Attribute name for the StrArrayAttr which encodes the value of paddings.
getPaddingAttrName()77 constexpr StringRef getPaddingAttrName() { return "padding"; }
78 
79 /// Use to encode that a particular iterator type has parallel semantics.
getParallelIteratorTypeName()80 constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }
isParallelIterator(Attribute attr)81 inline bool isParallelIterator(Attribute attr) {
82   auto strAttr = attr.dyn_cast_or_null<StringAttr>();
83   return strAttr && strAttr.getValue() == getParallelIteratorTypeName();
84 }
85 
86 /// Use to encode that a particular iterator type has reduction semantics.
getReductionIteratorTypeName()87 constexpr StringRef getReductionIteratorTypeName() { return "reduction"; }
isReductionIterator(Attribute attr)88 inline bool isReductionIterator(Attribute attr) {
89   auto strAttr = attr.dyn_cast_or_null<StringAttr>();
90   return strAttr && strAttr.getValue() == getReductionIteratorTypeName();
91 }
92 
93 /// Use to encode that a particular iterator type has window semantics.
getWindowIteratorTypeName()94 constexpr StringRef getWindowIteratorTypeName() { return "window"; }
isWindowIterator(Attribute attr)95 inline bool isWindowIterator(Attribute attr) {
96   auto strAttr = attr.dyn_cast_or_null<StringAttr>();
97   return strAttr && strAttr.getValue() == getWindowIteratorTypeName();
98 }
99 
100 /// Use to encode that a particular iterator type has window semantics.
getAllIteratorTypeNames()101 inline ArrayRef<StringRef> getAllIteratorTypeNames() {
102   static constexpr StringRef names[3] = {getParallelIteratorTypeName(),
103                                          getReductionIteratorTypeName(),
104                                          getWindowIteratorTypeName()};
105   return llvm::makeArrayRef(names);
106 }
107 
108 /// Returns the iterator of a certain type.
getNumIterators(StringRef name,ArrayAttr iteratorTypes)109 inline unsigned getNumIterators(StringRef name, ArrayAttr iteratorTypes) {
110   auto names = getAllIteratorTypeNames();
111   (void)names;
112   assert(llvm::is_contained(names, name));
113   return llvm::count_if(iteratorTypes, [name](Attribute a) {
114     return a.cast<StringAttr>().getValue() == name;
115   });
116 }
117 
getNumIterators(ArrayAttr iteratorTypes)118 inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
119   unsigned res = 0;
120   for (auto n : getAllIteratorTypeNames())
121     res += getNumIterators(n, iteratorTypes);
122   return res;
123 }
124 
/// Enum counterpart of the iterator-type strings; toString maps each
/// enumerator back to its canonical string spelling.
enum class IteratorType {
  Parallel,  ///< Corresponds to getParallelIteratorTypeName().
  Reduction, ///< Corresponds to getReductionIteratorTypeName().
};
127 
toString(IteratorType t)128 inline StringRef toString(IteratorType t) {
129   switch (t) {
130   case IteratorType::Parallel:
131     return getParallelIteratorTypeName();
132   case IteratorType::Reduction:
133     return getReductionIteratorTypeName();
134   }
135   llvm_unreachable("Unsupported IteratorType");
136 }
137 
138 /// Helper StructuredGenerator class to manipulate and rewrite ops with
139 /// `StructuredOpInterface`. This is templated for now because VectorOps do not
140 /// yet implement the StructuredOpInterface itself.
141 template <typename StructuredOpInterface>
142 class StructuredGenerator {
143 public:
144   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
145 
146   struct IteratorType {
IteratorTypeIteratorType147     IteratorType(StringRef strRef) : strRef(strRef) {}
isOfTypeIteratorType148     bool isOfType(Attribute attr) const {
149       auto sAttr = attr.dyn_cast<StringAttr>();
150       return sAttr && sAttr.getValue() == strRef;
151     }
152     StringRef strRef;
153   };
154   struct Par : public IteratorType {
ParPar155     Par() : IteratorType(getParallelIteratorTypeName()) {}
156   };
157   struct Red : public IteratorType {
RedRed158     Red() : IteratorType(getReductionIteratorTypeName()) {}
159   };
160   struct Win : public IteratorType {
WinWin161     Win() : IteratorType(getWindowIteratorTypeName()) {}
162   };
163 
StructuredGenerator(OpBuilder & builder,StructuredOpInterface op)164   StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
165       : builder(builder), ctx(op.getContext()), loc(op.getLoc()),
166         iterators(op.getIteratorTypes()), maps(op.getIndexingMapsArray()),
167         op(op) {}
168 
iters(ArrayRef<IteratorType> its)169   bool iters(ArrayRef<IteratorType> its) {
170     if (its.size() != iterators.size())
171       return false;
172     for (int i = 0, e = its.size(); i != e; ++i) {
173       if (!its[i].isOfType(iterators[i]))
174         return false;
175     }
176     return true;
177   }
178 
layout(MapList l)179   bool layout(MapList l) {
180     auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
181     return maps == infer(l);
182   }
183 
184 protected:
185   OpBuilder &builder;
186   MLIRContext *ctx;
187   Location loc;
188   ArrayAttr iterators;
189   SmallVector<AffineMap, 4> maps;
190   Operation *op;
191 };
192 
193 } // namespace mlir
194 
195 #endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
196