1 //===- TosaInferShapes.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Propagate shapes forward along TOSA operations to resolve dynamic shape
// operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
17 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
18 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
19 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "llvm/Support/FormatVariadic.h"
28 
29 using namespace mlir;
30 using namespace mlir::tosa;
31 
32 namespace {
33 
// Forward declaration: the tosa if/while helpers below call this function on
// the regions nested in those ops, and it is defined after them.
void propagateShapesInRegion(Region &region);
35 
propagateShapesToTosaIf(Operation & op,DenseMap<Value,ShapedTypeComponents> & shapesStorage)36 void propagateShapesToTosaIf(
37     Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
38   IfOp ifOp = dyn_cast<IfOp>(op);
39   if (!ifOp)
40     return;
41 
42   for (auto &region : op.getRegions()) {
43     Block &frontBlock = region.front();
44     if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
45       return;
46 
47     for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
48       auto inferredTy = shapesStorage[op.getOperand(i)];
49       auto blockArg = frontBlock.getArgument(i - 1);
50       auto oldType = blockArg.getType().cast<ShapedType>();
51 
52       if (inferredTy.hasRank()) {
53         Type newType = oldType.clone(inferredTy.getDims());
54         blockArg.setType(newType);
55       }
56     }
57 
58     for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
59       ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
60           ifOp.getOperand(i + 1).getType());
61       ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
62           frontBlock.getArgument(i).getType());
63       ValueKnowledge joinedKnowledge =
64           ValueKnowledge::join(operandKnowledge, blockKnowledge);
65       if (!joinedKnowledge)
66         continue;
67       frontBlock.getArgument(i).setType(joinedKnowledge.getType());
68     }
69 
70     propagateShapesInRegion(region);
71   }
72 }
73 
propagateShapesToTosaWhile(Operation & op,DenseMap<Value,ShapedTypeComponents> & shapesStorage)74 void propagateShapesToTosaWhile(
75     Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
76   WhileOp whileOp = dyn_cast<WhileOp>(op);
77   if (!whileOp)
78     return;
79 
80   // Determine what the expected argument types are to the cond/body blocks.
81   // The expected arguments should be compatible with ever iteration of the
82   // loop body / condition for tosa.while.
83   llvm::SmallVector<Type> argTypes;
84   for (auto operand : op.getOperands()) {
85     auto operandTy = operand.getType().cast<ShapedType>();
86     auto shapedTypeComponent = shapesStorage[operand];
87     if (shapedTypeComponent.hasRank()) {
88       auto newTy = operandTy.clone(shapedTypeComponent.getDims());
89       argTypes.push_back(newTy);
90     } else {
91       argTypes.push_back(operand.getType());
92     }
93   }
94 
95   // Save out the type information so we can restore at the end.
96   llvm::DenseMap<Value, Type> originalTypeMap;
97   for (auto &block : op.getRegion(1)) {
98     for (auto arg : block.getArguments())
99       originalTypeMap[arg] = arg.getType();
100     for (auto &op : block)
101       for (auto result : op.getResults())
102         originalTypeMap[result] = result.getType();
103   }
104 
105   bool hasNewTypes = true;
106   while (hasNewTypes) {
107 
108     // Set types on the block args.
109     Region &bodyRegion = op.getRegion(1);
110     Block &block = bodyRegion.front();
111     for (int i = 0, s = argTypes.size(); i < s; i++) {
112       block.getArgument(i).setType(argTypes[i]);
113     }
114 
115     // Propagate to the end.
116     propagateShapesInRegion(bodyRegion);
117 
118     // Find all the tosa yield types and verify there is atleast one.
119     llvm::SmallVector<YieldOp> yieldOps;
120     for (auto &block : bodyRegion)
121       if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
122         yieldOps.push_back(yieldOp);
123 
124     if (yieldOps.empty())
125       return;
126 
127     // Using the new tosa.yield operand types, infer the new subtypes.
128     llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
129     for (auto ty : argTypes) {
130       yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
131     }
132 
133     for (auto yieldOp : yieldOps) {
134       for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
135         auto newKnowledge =
136             ValueKnowledge::getKnowledgeFromType(it.value().getType());
137         yieldTypeInfo[it.index()] =
138             ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
139       }
140     }
141 
142     // This should never happen.
143     if (yieldTypeInfo.size() != argTypes.size()) {
144       op.emitWarning("has a tosa.yield with the incorrect number of operands");
145       return;
146     }
147 
148     // Determine the new block args and see if any changed.
149     hasNewTypes = false;
150     for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
151       Type newType = yieldTypeInfo[i].getType();
152       hasNewTypes |= (newType != argTypes[i]);
153       argTypes[i] = newType;
154     }
155 
156     // The types inferred in the block assume the operand types specified for
157     // this iteration. We need to restore the original types to ensure that
158     // future iterations only use the already specified types, not possible
159     // types from previous iterations.
160     for (auto &block : bodyRegion) {
161       for (auto arg : block.getArguments())
162         arg.setType(originalTypeMap[arg]);
163       for (auto &op : block)
164         for (auto result : op.getResults())
165           result.setType(originalTypeMap[result]);
166     }
167   }
168 
169   // We now set the block arguments according to the most recent shape
170   // inference results. This gives us the block arg types for the next
171   // iteration.
172   for (auto &region : op.getRegions()) {
173     for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
174       region.front().getArgument(i).setType(argTypes[i]);
175     }
176 
177     propagateShapesInRegion(region);
178   }
179 }
180 
propagateShapesInRegion(Region & region)181 void propagateShapesInRegion(Region &region) {
182   DenseMap<Value, ShapedTypeComponents> shapesStorage;
183   auto setShapes = [&](Value val, Type t) {
184     if (auto st = t.dyn_cast<ShapedType>())
185       shapesStorage[val] = st;
186     else
187       shapesStorage[val] = t;
188   };
189   auto operandShape = [&](Value val) -> ShapeAdaptor {
190     // Query the WIP mapping rather than the type if set.
191     auto it = shapesStorage.find(val);
192     if (it == shapesStorage.end())
193       return nullptr;
194     return it->second;
195   };
196 
197   for (auto &block : region) {
198     for (Operation &op : block) {
199       if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
200         continue;
201 
202       propagateShapesToTosaIf(op, shapesStorage);
203       propagateShapesToTosaWhile(op, shapesStorage);
204 
205       InferShapedTypeOpInterface shapeInterface =
206           dyn_cast<InferShapedTypeOpInterface>(op);
207       if (!shapeInterface)
208         continue;
209 
210       SmallVector<ShapedTypeComponents> returnedShapes;
211 
212       ValueShapeRange range(op.getOperands(), operandShape);
213       if (shapeInterface
214               .inferReturnTypeComponents(op.getContext(), op.getLoc(), range,
215                                          op.getAttrDictionary(),
216                                          op.getRegions(), returnedShapes)
217               .succeeded()) {
218         for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
219           Value result = std::get<0>(it);
220           ShapedTypeComponents predictedShape = std::get<1>(it);
221 
222           // Check whether this use case is replaceable. We define an op as
223           // being replaceable if it is used by a ReturnOp or a TosaOp.
224           bool replaceable = true;
225           for (auto *user : result.getUsers()) {
226             if (isa<func::ReturnOp>(user))
227               continue;
228             if (user->getDialect()->getNamespace() ==
229                 TosaDialect::getDialectNamespace())
230               continue;
231 
232             replaceable = false;
233           }
234 
235           // Determine the knowledge based on the output type.
236           // TODO: should also query WIP type probably
237           Type resultTy = result.getType();
238           auto currentKnowledge =
239               ValueKnowledge::getKnowledgeFromType(resultTy);
240 
241           // Compute the knowledge based on the inferred type.
242           auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
243           inferredKnowledge.dtype =
244               resultTy.cast<ShapedType>().getElementType();
245           inferredKnowledge.hasRank = predictedShape.hasRank();
246           if (predictedShape.hasRank()) {
247             for (auto dim : predictedShape.getDims()) {
248               inferredKnowledge.sizes.push_back(dim);
249             }
250           }
251 
252           if (!replaceable)
253             continue;
254 
255           // Compute the new type based on the joined version.
256           auto newKnowledge =
257               ValueKnowledge::join(currentKnowledge, inferredKnowledge);
258           if (!newKnowledge)
259             continue;
260           setShapes(result, newKnowledge.getType());
261         }
262       }
263     }
264   }
265 
266   // Actually update types with updated shape knowledge.
267   for (auto it : shapesStorage) {
268     auto result = it.second;
269     if (result.hasRank()) {
270       Type t = it.first.getType().cast<ShapedType>().clone(result.getDims());
271       it.first.setType(t);
272     }
273   }
274 }
275 
276 /// Pass that performs shape propagation across TOSA operations. This includes
277 /// migrating to within the regions of if/while operations.
278 struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
279 public:
runOnOperation__anon805bfaf10111::TosaInferShapes280   void runOnOperation() override {
281     func::FuncOp func = getOperation();
282 
283     IRRewriter rewriter(func.getContext());
284 
285     propagateShapesInRegion(func.getBody());
286 
287     // Insert UnrealizedConversionCasts to guarantee ReturnOp agress with
288     // the FuncOp type.
289     func.walk([&](func::ReturnOp op) {
290       func::FuncOp parent = dyn_cast<func::FuncOp>(op->getParentOp());
291       if (!parent)
292         return;
293 
294       rewriter.setInsertionPoint(op);
295       FunctionType funcTy = func.getFunctionType();
296       auto resultTys = funcTy.getResults();
297 
298       bool castAdded = false;
299       SmallVector<Value> castedValues;
300       for (auto it : llvm::zip(op->getOperands(), resultTys)) {
301         auto operand = std::get<0>(it);
302         auto currentTy = operand.getType();
303         auto castTy = std::get<1>(it);
304         if (currentTy == castTy) {
305           castedValues.push_back(operand);
306           continue;
307         }
308 
309         castedValues.push_back(
310             rewriter.create<tensor::CastOp>(op.getLoc(), castTy, operand)
311                 .getResult());
312 
313         castAdded = true;
314       }
315 
316       if (castAdded) {
317         rewriter.replaceOpWithNewOp<func::ReturnOp>(op, castedValues);
318       }
319     });
320   }
321 };
322 } // namespace
323 
createTosaInferShapesPass()324 std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() {
325   return std::make_unique<TosaInferShapes>();
326 }
327