1 //===- TosaInferShapes.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Propogate shapes forward along TOSA operations to resolve dynamic shape 10 // operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Analysis/DataFlowAnalysis.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 18 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h" 19 #include "mlir/Dialect/Tosa/Transforms/Passes.h" 20 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" 21 #include "mlir/IR/BlockAndValueMapping.h" 22 #include "mlir/IR/Builders.h" 23 #include "mlir/IR/BuiltinOps.h" 24 #include "mlir/IR/Matchers.h" 25 #include "mlir/Pass/Pass.h" 26 #include "mlir/Transforms/DialectConversion.h" 27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 28 #include "llvm/Support/FormatVariadic.h" 29 30 using namespace mlir; 31 using namespace mlir::tosa; 32 33 namespace { 34 35 void propagateShapesInRegion(Region ®ion); 36 37 void propagateShapesToTosaIf( 38 Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) { 39 IfOp ifOp = dyn_cast<IfOp>(op); 40 if (!ifOp) 41 return; 42 43 for (auto ®ion : op.getRegions()) { 44 Block &frontBlock = region.front(); 45 if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands()) 46 return; 47 48 for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) { 49 auto inferredTy = shapesStorage[op.getOperand(i)]; 50 auto blockArg = frontBlock.getArgument(i - 1); 51 auto oldType = blockArg.getType().cast<ShapedType>(); 52 53 if (inferredTy.hasRank()) { 54 Type newType = oldType.clone(inferredTy.getDims()); 55 blockArg.setType(newType); 56 } 57 } 58 59 for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) { 60 ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType( 61 ifOp.getOperand(i + 1).getType()); 62 ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType( 63 frontBlock.getArgument(i).getType()); 64 ValueKnowledge joinedKnowledge = 65 ValueKnowledge::join(operandKnowledge, blockKnowledge); 66 if (!joinedKnowledge) 67 continue; 68 frontBlock.getArgument(i).setType(joinedKnowledge.getType()); 69 } 70 71 propagateShapesInRegion(region); 72 } 73 } 74 75 void propagateShapesToTosaWhile( 76 Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) { 77 WhileOp whileOp = dyn_cast<WhileOp>(op); 78 if (!whileOp) 79 return; 80 81 // Determine what the expected argument types are to the cond/body blocks. 82 // The expected arguments should be compatible with ever iteration of the 83 // loop body / condition for tosa.while. 84 llvm::SmallVector<Type> argTypes; 85 for (auto operand : op.getOperands()) { 86 auto operandTy = operand.getType().cast<ShapedType>(); 87 auto shapedTypeComponent = shapesStorage[operand]; 88 if (shapedTypeComponent.hasRank()) { 89 auto newTy = operandTy.clone(shapedTypeComponent.getDims()); 90 argTypes.push_back(newTy); 91 } else { 92 argTypes.push_back(operand.getType()); 93 } 94 } 95 96 // Save out the type information so we can restore at the end. 97 llvm::DenseMap<Value, Type> originalTypeMap; 98 for (auto &block : op.getRegion(1)) { 99 for (auto arg : block.getArguments()) 100 originalTypeMap[arg] = arg.getType(); 101 for (auto &op : block) 102 for (auto result : op.getResults()) 103 originalTypeMap[result] = result.getType(); 104 } 105 106 bool hasNewTypes = true; 107 while (hasNewTypes) { 108 109 // Set types on the block args. 110 Region &bodyRegion = op.getRegion(1); 111 Block &block = bodyRegion.front(); 112 for (int i = 0, s = argTypes.size(); i < s; i++) { 113 block.getArgument(i).setType(argTypes[i]); 114 } 115 116 // Propagate to the end. 117 propagateShapesInRegion(bodyRegion); 118 119 // Find all the tosa yield types and verify there is atleast one. 120 llvm::SmallVector<YieldOp> yieldOps; 121 for (auto &block : bodyRegion) 122 if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator())) 123 yieldOps.push_back(yieldOp); 124 125 if (yieldOps.empty()) 126 return; 127 128 // Using the new tosa.yield operand types, infer the new subtypes. 129 llvm::SmallVector<ValueKnowledge> yieldTypeInfo; 130 for (auto ty : argTypes) { 131 yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty)); 132 } 133 134 for (auto yieldOp : yieldOps) { 135 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { 136 auto newKnowledge = 137 ValueKnowledge::getKnowledgeFromType(it.value().getType()); 138 yieldTypeInfo[it.index()] = 139 ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge); 140 } 141 } 142 143 // This should never happen. 144 if (yieldTypeInfo.size() != argTypes.size()) { 145 op.emitWarning("has a tosa.yield with the incorrect number of operands"); 146 return; 147 } 148 149 // Determine the new block args and see if any changed. 150 hasNewTypes = false; 151 for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) { 152 Type newType = yieldTypeInfo[i].getType(); 153 hasNewTypes |= (newType != argTypes[i]); 154 argTypes[i] = newType; 155 } 156 157 // The types inferred in the block assume the operand types specified for 158 // this iteration. We need to restore the original types to ensure that 159 // future iterations only use the already specified types, not possible 160 // types from previous iterations. 161 for (auto &block : bodyRegion) { 162 for (auto arg : block.getArguments()) 163 arg.setType(originalTypeMap[arg]); 164 for (auto &op : block) 165 for (auto result : op.getResults()) 166 result.setType(originalTypeMap[result]); 167 } 168 } 169 170 // We now set the block arguments according to the most recent shape 171 // inference results. This gives us the block arg types for the next 172 // iteration. 173 for (auto ®ion : op.getRegions()) { 174 for (unsigned int i = 0, s = argTypes.size(); i < s; i++) { 175 region.front().getArgument(i).setType(argTypes[i]); 176 } 177 178 propagateShapesInRegion(region); 179 } 180 } 181 182 void propagateShapesInRegion(Region ®ion) { 183 DenseMap<Value, ShapedTypeComponents> shapesStorage; 184 auto setShapes = [&](Value val, Type t) { 185 if (auto st = t.dyn_cast<ShapedType>()) 186 shapesStorage[val] = st; 187 else 188 shapesStorage[val] = t; 189 }; 190 auto operandShape = [&](Value val) -> ShapeAdaptor { 191 // Query the WIP mapping rather than the type if set. 192 auto it = shapesStorage.find(val); 193 if (it == shapesStorage.end()) 194 return nullptr; 195 return it->second; 196 }; 197 198 for (auto &block : region) { 199 for (Operation &op : block) { 200 if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace()) 201 continue; 202 203 propagateShapesToTosaIf(op, shapesStorage); 204 propagateShapesToTosaWhile(op, shapesStorage); 205 206 InferShapedTypeOpInterface shapeInterface = 207 dyn_cast<InferShapedTypeOpInterface>(op); 208 if (!shapeInterface) 209 continue; 210 211 SmallVector<ShapedTypeComponents> returnedShapes; 212 213 ValueShapeRange range(op.getOperands(), operandShape); 214 if (shapeInterface 215 .inferReturnTypeComponents(op.getContext(), op.getLoc(), range, 216 op.getAttrDictionary(), 217 op.getRegions(), returnedShapes) 218 .succeeded()) { 219 for (auto it : llvm::zip(op.getResults(), returnedShapes)) { 220 Value result = std::get<0>(it); 221 ShapedTypeComponents predictedShape = std::get<1>(it); 222 223 // Check whether this use case is replaceable. We define an op as 224 // being replaceable if it is used by a ReturnOp or a TosaOp. 225 bool replaceable = true; 226 for (auto *user : result.getUsers()) { 227 if (isa<ReturnOp>(user)) 228 continue; 229 if (user->getDialect()->getNamespace() == 230 TosaDialect::getDialectNamespace()) 231 continue; 232 233 replaceable = false; 234 } 235 236 // Determine the knowledge based on the output type. 237 // TODO: should also query WIP type probably 238 Type resultTy = result.getType(); 239 auto currentKnowledge = 240 ValueKnowledge::getKnowledgeFromType(resultTy); 241 242 // Compute the knowledge based on the inferred type. 243 auto inferredKnowledge = ValueKnowledge::getPessimisticValueState(); 244 inferredKnowledge.dtype = 245 resultTy.cast<ShapedType>().getElementType(); 246 inferredKnowledge.hasRank = predictedShape.hasRank(); 247 if (predictedShape.hasRank()) { 248 for (auto dim : predictedShape.getDims()) { 249 inferredKnowledge.sizes.push_back(dim); 250 } 251 } 252 253 if (!replaceable) 254 continue; 255 256 // Compute the new type based on the joined version. 257 auto newKnowledge = 258 ValueKnowledge::join(currentKnowledge, inferredKnowledge); 259 if (!newKnowledge) 260 continue; 261 setShapes(result, newKnowledge.getType()); 262 } 263 } 264 } 265 } 266 267 // Actually update types with updated shape knowledge. 268 for (auto it : shapesStorage) { 269 auto result = it.second; 270 if (result.hasRank()) { 271 Type t = it.first.getType().cast<ShapedType>().clone(result.getDims()); 272 it.first.setType(t); 273 } 274 } 275 } 276 277 /// Pass that performs shape propagation across TOSA operations. This includes 278 /// migrating to within the regions of if/while operations. 279 struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> { 280 public: 281 void runOnFunction() override { 282 FuncOp func = getOperation(); 283 284 IRRewriter rewriter(func.getContext()); 285 286 propagateShapesInRegion(func.body()); 287 288 // Insert UnrealizedConversionCasts to guarantee ReturnOp agress with 289 // the FuncOp type. 290 func.walk([&](ReturnOp op) { 291 FuncOp parent = dyn_cast<FuncOp>(op->getParentOp()); 292 if (!parent) 293 return; 294 295 rewriter.setInsertionPoint(op); 296 FunctionType funcTy = func.getType(); 297 auto resultTys = funcTy.getResults(); 298 299 bool castAdded = false; 300 SmallVector<Value> castedValues; 301 for (auto it : llvm::zip(op->getOperands(), resultTys)) { 302 auto operand = std::get<0>(it); 303 auto currentTy = operand.getType(); 304 auto castTy = std::get<1>(it); 305 if (currentTy == castTy) { 306 castedValues.push_back(operand); 307 continue; 308 } 309 310 castedValues.push_back( 311 rewriter.create<tensor::CastOp>(op.getLoc(), castTy, operand) 312 .getResult()); 313 314 castAdded = true; 315 } 316 317 if (castAdded) { 318 rewriter.replaceOpWithNewOp<ReturnOp>(op, castedValues); 319 } 320 }); 321 } 322 }; 323 } // namespace 324 325 std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() { 326 return std::make_unique<TosaInferShapes>(); 327 } 328