//===- TosaInferShapes.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Propagate shapes forward along TOSA operations to resolve dynamic shape
// operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

void propagateShapesInRegion(Region &region);

void propagateShapesToTosaIf(
    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
  IfOp ifOp = dyn_cast<IfOp>(op);
  if (!ifOp)
    return;

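  // Both the then and else regions of tosa.cond_if receive the if op's
  // operands, minus the leading condition value, as block arguments.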
  for (auto &region : op.getRegions()) {
    Block &frontBlock = region.front();
    if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
      return;

    for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
      auto inferredTy = shapesStorage[op.getOperand(i)];
      auto blockArg = frontBlock.getArgument(i - 1);
      auto oldType = blockArg.getType().cast<ShapedType>();

      if (inferredTy.hasRank()) {
        Type newType = oldType.clone(inferredTy.getDims());
        blockArg.setType(newType);
      }
    }

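    // Join the operand's knowledge with the block argument's so each
    // argument carries the most refined type compatible with both.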
    for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
      ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
          ifOp.getOperand(i + 1).getType());
      ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
          frontBlock.getArgument(i).getType());
      ValueKnowledge joinedKnowledge =
          ValueKnowledge::join(operandKnowledge, blockKnowledge);
      if (!joinedKnowledge)
        continue;
      frontBlock.getArgument(i).setType(joinedKnowledge.getType());
    }

    propagateShapesInRegion(region);
  }
}

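// Fixed-point shape propagation for tosa.while_loop: repeatedly run shape
// inference on the loop body until the inferred block argument types stop
// changing, then apply the result to both regions.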
void propagateShapesToTosaWhile(
    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
  WhileOp whileOp = dyn_cast<WhileOp>(op);
  if (!whileOp)
    return;

  // Determine what the expected argument types are to the cond/body blocks.
  // The expected arguments should be compatible with every iteration of the
  // loop body / condition for tosa.while.
  llvm::SmallVector<Type> argTypes;
  for (auto operand : op.getOperands()) {
    auto operandTy = operand.getType().cast<ShapedType>();
    auto shapedTypeComponent = shapesStorage[operand];
    if (shapedTypeComponent.hasRank()) {
      auto newTy = operandTy.clone(shapedTypeComponent.getDims());
      argTypes.push_back(newTy);
    } else {
      argTypes.push_back(operand.getType());
    }
  }

  // Save out the type information so we can restore at the end.
  llvm::DenseMap<Value, Type> originalTypeMap;
  for (auto &block : op.getRegion(1)) {
    for (auto arg : block.getArguments())
      originalTypeMap[arg] = arg.getType();
    for (auto &op : block)
      for (auto result : op.getResults())
        originalTypeMap[result] = result.getType();
  }

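  // Iterate shape inference on the loop body until the block argument types
  // reach a fixed point. For example, if the body concatenates a tensor along
  // dimension 0 every iteration, that dimension must generalize to dynamic.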
  bool hasNewTypes = true;
  while (hasNewTypes) {

    // Set types on the block args.
    Region &bodyRegion = op.getRegion(1);
    Block &block = bodyRegion.front();
    for (int i = 0, s = argTypes.size(); i < s; i++) {
      block.getArgument(i).setType(argTypes[i]);
    }

    // Propagate to the end.
    propagateShapesInRegion(bodyRegion);

    // Find all the tosa yield types and verify there is at least one.
    llvm::SmallVector<YieldOp> yieldOps;
    for (auto &block : bodyRegion)
      if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
        yieldOps.push_back(yieldOp);

    if (yieldOps.empty())
      return;

    // Using the new tosa.yield operand types, infer the new subtypes.
    llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
    for (auto ty : argTypes) {
      yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
    }

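    // Meet the entry types with the knowledge from every yield so the
    // argument types remain valid for all paths through the loop body;
    // mismatched dimensions generalize to dynamic.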
    for (auto yieldOp : yieldOps) {
      for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
        auto newKnowledge =
            ValueKnowledge::getKnowledgeFromType(it.value().getType());
        yieldTypeInfo[it.index()] =
            ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
      }
    }

    // This should never happen.
    if (yieldTypeInfo.size() != argTypes.size()) {
      op.emitWarning("has a tosa.yield with the incorrect number of operands");
      return;
    }

    // Determine the new block args and see if any changed.
    hasNewTypes = false;
    for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
      Type newType = yieldTypeInfo[i].getType();
      hasNewTypes |= (newType != argTypes[i]);
      argTypes[i] = newType;
    }

    // The types inferred in the block assume the operand types specified for
    // this iteration. We need to restore the original types to ensure future
    // iterations only use the newly specified types, not types carried over
    // from previous iterations.
    for (auto &block : bodyRegion) {
      for (auto arg : block.getArguments())
        arg.setType(originalTypeMap[arg]);
      for (auto &op : block)
        for (auto result : op.getResults())
          result.setType(originalTypeMap[result]);
    }
  }

  // Set the block arguments of both the cond and body regions according to
  // the fixed-point shape inference results, and run one final propagation
  // with those types in place.
  for (auto &region : op.getRegions()) {
    for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
      region.front().getArgument(i).setType(argTypes[i]);
    }

    propagateShapesInRegion(region);
  }
}

void propagateShapesInRegion(Region &region) {
  DenseMap<Value, ShapedTypeComponents> shapesStorage;
  // Record the refined shape for a value without rewriting its type yet.
  auto setShapes = [&](Value val, Type t) {
    if (auto st = t.dyn_cast<ShapedType>())
      shapesStorage[val] = st;
    else
      shapesStorage[val] = t;
  };
  auto operandShape = [&](Value val) -> ShapeAdaptor {
    // Query the WIP mapping rather than the type if set.
    auto it = shapesStorage.find(val);
    if (it == shapesStorage.end())
      return nullptr;
    return it->second;
  };

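  // Walk each op in the region in program order, recursing into TOSA control
  // flow first, then inferring result shapes via InferShapedTypeOpInterface.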
  for (auto &block : region) {
    for (Operation &op : block) {
      if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
        continue;

      propagateShapesToTosaIf(op, shapesStorage);
      propagateShapesToTosaWhile(op, shapesStorage);

      InferShapedTypeOpInterface shapeInterface =
          dyn_cast<InferShapedTypeOpInterface>(op);
      if (!shapeInterface)
        continue;

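      // Build a ValueShapeRange that resolves operand shapes through the
      // work-in-progress mapping, so inference sees already-refined shapes.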
      SmallVector<ShapedTypeComponents> returnedShapes;

      ValueShapeRange range(op.getOperands(), operandShape);
      if (shapeInterface
              .inferReturnTypeComponents(op.getContext(), op.getLoc(), range,
                                         op.getAttrDictionary(),
                                         op.getRegions(), returnedShapes)
              .succeeded()) {
        for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
          Value result = std::get<0>(it);
          ShapedTypeComponents predictedShape = std::get<1>(it);

          // Check whether this result's type is replaceable. We define it as
          // replaceable if every user is either a func.return or a TOSA op,
          // since other dialects may not tolerate the type change.
          bool replaceable = true;
          for (auto *user : result.getUsers()) {
            if (isa<func::ReturnOp>(user))
              continue;
            if (user->getDialect()->getNamespace() ==
                TosaDialect::getDialectNamespace())
              continue;

            replaceable = false;
          }

          // Determine the knowledge based on the output type.
          // TODO: should also query WIP type probably
          Type resultTy = result.getType();
          auto currentKnowledge =
              ValueKnowledge::getKnowledgeFromType(resultTy);

          // Compute the knowledge based on the inferred type.
          auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
          inferredKnowledge.dtype =
              resultTy.cast<ShapedType>().getElementType();
          inferredKnowledge.hasRank = predictedShape.hasRank();
          if (predictedShape.hasRank()) {
            for (auto dim : predictedShape.getDims()) {
              inferredKnowledge.sizes.push_back(dim);
            }
          }

          if (!replaceable)
            continue;

254
255 // Compute the new type based on the joined version.
256 auto newKnowledge =
257 ValueKnowledge::join(currentKnowledge, inferredKnowledge);
258 if (!newKnowledge)
259 continue;
260 setShapes(result, newKnowledge.getType());
261 }
262 }
263 }
264 }
265
  // Actually update types with updated shape knowledge.
  for (auto it : shapesStorage) {
    auto result = it.second;
    if (result.hasRank()) {
      Type t = it.first.getType().cast<ShapedType>().clone(result.getDims());
      it.first.setType(t);
    }
  }
}

/// Pass that performs shape propagation across TOSA operations. This includes
/// propagation into the regions of if/while operations.
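///
/// For example, running the pass refines
///   %0 = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
/// to
///   %0 = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>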
struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
public:
  void runOnOperation() override {
    func::FuncOp func = getOperation();

    IRRewriter rewriter(func.getContext());

    propagateShapesInRegion(func.getBody());

    // Insert tensor.cast operations to guarantee that the ReturnOp agrees
    // with the FuncOp type.
    func.walk([&](func::ReturnOp op) {
      func::FuncOp parent = dyn_cast<func::FuncOp>(op->getParentOp());
      if (!parent)
        return;

      rewriter.setInsertionPoint(op);
      FunctionType funcTy = func.getFunctionType();
      auto resultTys = funcTy.getResults();

      bool castAdded = false;
      SmallVector<Value> castedValues;
      for (auto it : llvm::zip(op->getOperands(), resultTys)) {
        auto operand = std::get<0>(it);
        auto currentTy = operand.getType();
        auto castTy = std::get<1>(it);
        if (currentTy == castTy) {
          castedValues.push_back(operand);
          continue;
        }

        castedValues.push_back(
            rewriter.create<tensor::CastOp>(op.getLoc(), castTy, operand)
                .getResult());

        castAdded = true;
      }

      if (castAdded) {
        rewriter.replaceOpWithNewOp<func::ReturnOp>(op, castedValues);
      }
    });
  }
};
} // namespace

std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() {
  return std::make_unique<TosaInferShapes>();
}