1 //===- TosaInferShapes.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Propagate shapes forward along TOSA operations to resolve dynamic shape
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Analysis/DataFlowAnalysis.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
18 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
19 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
20 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
21 #include "mlir/IR/BlockAndValueMapping.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinOps.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28 #include "llvm/Support/FormatVariadic.h"
29 
30 using namespace mlir;
31 using namespace mlir::tosa;
32 
33 namespace {
34 
// Forward declaration: region propagation and the tosa.cond_if / tosa.while
// handlers below are mutually recursive.
void propagateShapesInRegion(Region &region);
36 
37 void propagateShapesToTosaIf(
38     Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
39   IfOp ifOp = dyn_cast<IfOp>(op);
40   if (!ifOp)
41     return;
42 
43   for (auto &region : op.getRegions()) {
44     Block &frontBlock = region.front();
45     if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
46       return;
47 
48     for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
49       auto inferredTy = shapesStorage[op.getOperand(i)];
50       auto blockArg = frontBlock.getArgument(i - 1);
51       auto oldType = blockArg.getType().cast<ShapedType>();
52 
53       if (inferredTy.hasRank()) {
54         Type newType = oldType.clone(inferredTy.getDims());
55         blockArg.setType(newType);
56       }
57     }
58 
59     for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
60       ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
61           ifOp.getOperand(i + 1).getType());
62       ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
63           frontBlock.getArgument(i).getType());
64       ValueKnowledge joinedKnowledge =
65           ValueKnowledge::join(operandKnowledge, blockKnowledge);
66       if (!joinedKnowledge)
67         continue;
68       frontBlock.getArgument(i).setType(joinedKnowledge.getType());
69     }
70 
71     propagateShapesInRegion(region);
72   }
73 }
74 
// Refines the region block-argument types of a tosa.while_loop. Because the
// loop body feeds its own arguments via tosa.yield, the argument types are
// widened to a fixed point: shapes are propagated through the body, met with
// the yield operand types, and the process repeats until no type changes.
// Non-WhileOp operations are ignored.
void propagateShapesToTosaWhile(
    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
  WhileOp whileOp = dyn_cast<WhileOp>(op);
  if (!whileOp)
    return;

  // Determine what the expected argument types are to the cond/body blocks.
  // The expected arguments should be compatible with every iteration of the
  // loop body / condition for tosa.while.
  llvm::SmallVector<Type> argTypes;
  for (auto operand : op.getOperands()) {
    auto operandTy = operand.getType().cast<ShapedType>();
    auto shapedTypeComponent = shapesStorage[operand];
    if (shapedTypeComponent.hasRank()) {
      // Prefer the ranked shape inferred so far over the declared type.
      auto newTy = operandTy.clone(shapedTypeComponent.getDims());
      argTypes.push_back(newTy);
    } else {
      argTypes.push_back(operand.getType());
    }
  }

  // Save out the type information so we can restore at the end.
  // Region 1 is the loop body (named bodyRegion below); its block-argument
  // and result types are mutated speculatively on each fixed-point pass.
  llvm::DenseMap<Value, Type> originalTypeMap;
  for (auto &block : op.getRegion(1)) {
    for (auto arg : block.getArguments())
      originalTypeMap[arg] = arg.getType();
    // Note: this inner `op` intentionally shadows the outer parameter.
    for (auto &op : block)
      for (auto result : op.getResults())
        originalTypeMap[result] = result.getType();
  }

  // Fixed-point loop: repeat until the block-argument types stabilize.
  bool hasNewTypes = true;
  while (hasNewTypes) {

    // Set types on the block args.
    Region &bodyRegion = op.getRegion(1);
    Block &block = bodyRegion.front();
    for (int i = 0, s = argTypes.size(); i < s; i++) {
      block.getArgument(i).setType(argTypes[i]);
    }

    // Propagate to the end.
    propagateShapesInRegion(bodyRegion);

    // Find all the tosa yield types and verify there is at least one.
    llvm::SmallVector<YieldOp> yieldOps;
    for (auto &block : bodyRegion)
      if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
        yieldOps.push_back(yieldOp);

    if (yieldOps.empty())
      return;

    // Using the new tosa.yield operand types, infer the new subtypes.
    // Seed with the current argument types, then meet with every yield so the
    // result is compatible with all paths back to the loop header.
    llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
    for (auto ty : argTypes) {
      yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
    }

    for (auto yieldOp : yieldOps) {
      for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
        auto newKnowledge =
            ValueKnowledge::getKnowledgeFromType(it.value().getType());
        yieldTypeInfo[it.index()] =
            ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
      }
    }

    // This should never happen.
    if (yieldTypeInfo.size() != argTypes.size()) {
      op.emitWarning("has a tosa.yield with the incorrect number of operands");
      return;
    }

    // Determine the new block args and see if any changed.
    hasNewTypes = false;
    for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
      Type newType = yieldTypeInfo[i].getType();
      hasNewTypes |= (newType != argTypes[i]);
      argTypes[i] = newType;
    }

    // The types inferred in the block assume the operand types specified for
    // this iteration. We need to restore the original types to ensure that
    // future iterations only use the already specified types, not possible
    // types from previous iterations.
    for (auto &block : bodyRegion) {
      for (auto arg : block.getArguments())
        arg.setType(originalTypeMap[arg]);
      for (auto &op : block)
        for (auto result : op.getResults())
          result.setType(originalTypeMap[result]);
    }
  }

  // We now set the block arguments according to the most recent shape
  // inference results. This gives us the block arg types for the next
  // iteration.
  for (auto &region : op.getRegions()) {
    for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
      region.front().getArgument(i).setType(argTypes[i]);
    }

    propagateShapesInRegion(region);
  }
}
181 
182 void propagateShapesInRegion(Region &region) {
183   DenseMap<Value, ShapedTypeComponents> shapesStorage;
184   auto setShapes = [&](Value val, Type t) {
185     if (auto st = t.dyn_cast<ShapedType>())
186       shapesStorage[val] = st;
187     else
188       shapesStorage[val] = t;
189   };
190   auto operandShape = [&](Value val) -> ShapeAdaptor {
191     // Query the WIP mapping rather than the type if set.
192     auto it = shapesStorage.find(val);
193     if (it == shapesStorage.end())
194       return nullptr;
195     return it->second;
196   };
197 
198   for (auto &block : region) {
199     for (Operation &op : block) {
200       if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
201         continue;
202 
203       propagateShapesToTosaIf(op, shapesStorage);
204       propagateShapesToTosaWhile(op, shapesStorage);
205 
206       InferShapedTypeOpInterface shapeInterface =
207           dyn_cast<InferShapedTypeOpInterface>(op);
208       if (!shapeInterface)
209         continue;
210 
211       SmallVector<ShapedTypeComponents> returnedShapes;
212 
213       ValueShapeRange range(op.getOperands(), operandShape);
214       if (shapeInterface
215               .inferReturnTypeComponents(op.getContext(), op.getLoc(), range,
216                                          op.getAttrDictionary(),
217                                          op.getRegions(), returnedShapes)
218               .succeeded()) {
219         for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
220           Value result = std::get<0>(it);
221           ShapedTypeComponents predictedShape = std::get<1>(it);
222 
223           // Check whether this use case is replaceable. We define an op as
224           // being replaceable if it is used by a ReturnOp or a TosaOp.
225           bool replaceable = true;
226           for (auto *user : result.getUsers()) {
227             if (isa<ReturnOp>(user))
228               continue;
229             if (user->getDialect()->getNamespace() ==
230                 TosaDialect::getDialectNamespace())
231               continue;
232 
233             replaceable = false;
234           }
235 
236           // Determine the knowledge based on the output type.
237           // TODO: should also query WIP type probably
238           Type resultTy = result.getType();
239           auto currentKnowledge =
240               ValueKnowledge::getKnowledgeFromType(resultTy);
241 
242           // Compute the knowledge based on the inferred type.
243           auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
244           inferredKnowledge.dtype =
245               resultTy.cast<ShapedType>().getElementType();
246           inferredKnowledge.hasRank = predictedShape.hasRank();
247           if (predictedShape.hasRank()) {
248             for (auto dim : predictedShape.getDims()) {
249               inferredKnowledge.sizes.push_back(dim);
250             }
251           }
252 
253           if (!replaceable)
254             continue;
255 
256           // Compute the new type based on the joined version.
257           auto newKnowledge =
258               ValueKnowledge::join(currentKnowledge, inferredKnowledge);
259           if (!newKnowledge)
260             continue;
261           setShapes(result, newKnowledge.getType());
262         }
263       }
264     }
265   }
266 
267   // Actually update types with updated shape knowledge.
268   for (auto it : shapesStorage) {
269     auto result = it.second;
270     if (result.hasRank()) {
271       Type t = it.first.getType().cast<ShapedType>().clone(result.getDims());
272       it.first.setType(t);
273     }
274   }
275 }
276 
277 /// Pass that performs shape propagation across TOSA operations. This includes
278 /// migrating to within the regions of if/while operations.
279 struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
280 public:
281   void runOnFunction() override {
282     FuncOp func = getOperation();
283 
284     IRRewriter rewriter(func.getContext());
285 
286     propagateShapesInRegion(func.body());
287 
288     // Insert UnrealizedConversionCasts to guarantee ReturnOp agress with
289     // the FuncOp type.
290     func.walk([&](ReturnOp op) {
291       FuncOp parent = dyn_cast<FuncOp>(op->getParentOp());
292       if (!parent)
293         return;
294 
295       rewriter.setInsertionPoint(op);
296       FunctionType funcTy = func.getType();
297       auto resultTys = funcTy.getResults();
298 
299       bool castAdded = false;
300       SmallVector<Value> castedValues;
301       for (auto it : llvm::zip(op->getOperands(), resultTys)) {
302         auto operand = std::get<0>(it);
303         auto currentTy = operand.getType();
304         auto castTy = std::get<1>(it);
305         if (currentTy == castTy) {
306           castedValues.push_back(operand);
307           continue;
308         }
309 
310         castedValues.push_back(
311             rewriter.create<tensor::CastOp>(op.getLoc(), castTy, operand)
312                 .getResult());
313 
314         castAdded = true;
315       }
316 
317       if (castAdded) {
318         rewriter.replaceOpWithNewOp<ReturnOp>(op, castedValues);
319       }
320     });
321   }
322 };
323 } // namespace
324 
325 std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() {
326   return std::make_unique<TosaInferShapes>();
327 }
328