1 //===- TosaInferShapes.cpp ------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Propogate shapes forward along TOSA operations to resolve dynamic shape 10 // operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 17 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h" 18 #include "mlir/Dialect/Tosa/Transforms/Passes.h" 19 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" 20 #include "mlir/IR/BlockAndValueMapping.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/BuiltinOps.h" 23 #include "mlir/IR/Matchers.h" 24 #include "mlir/Pass/Pass.h" 25 #include "mlir/Transforms/DialectConversion.h" 26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 27 #include "llvm/Support/FormatVariadic.h" 28 29 using namespace mlir; 30 using namespace mlir::tosa; 31 32 namespace { 33 34 void propagateShapesInRegion(Region ®ion); 35 36 void propagateShapesToTosaIf( 37 Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) { 38 IfOp ifOp = dyn_cast<IfOp>(op); 39 if (!ifOp) 40 return; 41 42 for (auto ®ion : op.getRegions()) { 43 Block &frontBlock = region.front(); 44 if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands()) 45 return; 46 47 for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) { 48 auto inferredTy = shapesStorage[op.getOperand(i)]; 49 auto blockArg = frontBlock.getArgument(i - 1); 50 auto oldType = blockArg.getType().cast<ShapedType>(); 51 52 if (inferredTy.hasRank()) { 53 Type newType = oldType.clone(inferredTy.getDims()); 54 blockArg.setType(newType); 55 } 56 } 57 58 for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) { 59 ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType( 60 ifOp.getOperand(i + 1).getType()); 61 ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType( 62 frontBlock.getArgument(i).getType()); 63 ValueKnowledge joinedKnowledge = 64 ValueKnowledge::join(operandKnowledge, blockKnowledge); 65 if (!joinedKnowledge) 66 continue; 67 frontBlock.getArgument(i).setType(joinedKnowledge.getType()); 68 } 69 70 propagateShapesInRegion(region); 71 } 72 } 73 74 void propagateShapesToTosaWhile( 75 Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) { 76 WhileOp whileOp = dyn_cast<WhileOp>(op); 77 if (!whileOp) 78 return; 79 80 // Determine what the expected argument types are to the cond/body blocks. 81 // The expected arguments should be compatible with ever iteration of the 82 // loop body / condition for tosa.while. 83 llvm::SmallVector<Type> argTypes; 84 for (auto operand : op.getOperands()) { 85 auto operandTy = operand.getType().cast<ShapedType>(); 86 auto shapedTypeComponent = shapesStorage[operand]; 87 if (shapedTypeComponent.hasRank()) { 88 auto newTy = operandTy.clone(shapedTypeComponent.getDims()); 89 argTypes.push_back(newTy); 90 } else { 91 argTypes.push_back(operand.getType()); 92 } 93 } 94 95 // Save out the type information so we can restore at the end. 96 llvm::DenseMap<Value, Type> originalTypeMap; 97 for (auto &block : op.getRegion(1)) { 98 for (auto arg : block.getArguments()) 99 originalTypeMap[arg] = arg.getType(); 100 for (auto &op : block) 101 for (auto result : op.getResults()) 102 originalTypeMap[result] = result.getType(); 103 } 104 105 bool hasNewTypes = true; 106 while (hasNewTypes) { 107 108 // Set types on the block args. 109 Region &bodyRegion = op.getRegion(1); 110 Block &block = bodyRegion.front(); 111 for (int i = 0, s = argTypes.size(); i < s; i++) { 112 block.getArgument(i).setType(argTypes[i]); 113 } 114 115 // Propagate to the end. 116 propagateShapesInRegion(bodyRegion); 117 118 // Find all the tosa yield types and verify there is atleast one. 119 llvm::SmallVector<YieldOp> yieldOps; 120 for (auto &block : bodyRegion) 121 if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator())) 122 yieldOps.push_back(yieldOp); 123 124 if (yieldOps.empty()) 125 return; 126 127 // Using the new tosa.yield operand types, infer the new subtypes. 128 llvm::SmallVector<ValueKnowledge> yieldTypeInfo; 129 for (auto ty : argTypes) { 130 yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty)); 131 } 132 133 for (auto yieldOp : yieldOps) { 134 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { 135 auto newKnowledge = 136 ValueKnowledge::getKnowledgeFromType(it.value().getType()); 137 yieldTypeInfo[it.index()] = 138 ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge); 139 } 140 } 141 142 // This should never happen. 143 if (yieldTypeInfo.size() != argTypes.size()) { 144 op.emitWarning("has a tosa.yield with the incorrect number of operands"); 145 return; 146 } 147 148 // Determine the new block args and see if any changed. 149 hasNewTypes = false; 150 for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) { 151 Type newType = yieldTypeInfo[i].getType(); 152 hasNewTypes |= (newType != argTypes[i]); 153 argTypes[i] = newType; 154 } 155 156 // The types inferred in the block assume the operand types specified for 157 // this iteration. We need to restore the original types to ensure that 158 // future iterations only use the already specified types, not possible 159 // types from previous iterations. 160 for (auto &block : bodyRegion) { 161 for (auto arg : block.getArguments()) 162 arg.setType(originalTypeMap[arg]); 163 for (auto &op : block) 164 for (auto result : op.getResults()) 165 result.setType(originalTypeMap[result]); 166 } 167 } 168 169 // We now set the block arguments according to the most recent shape 170 // inference results. This gives us the block arg types for the next 171 // iteration. 172 for (auto ®ion : op.getRegions()) { 173 for (unsigned int i = 0, s = argTypes.size(); i < s; i++) { 174 region.front().getArgument(i).setType(argTypes[i]); 175 } 176 177 propagateShapesInRegion(region); 178 } 179 } 180 181 void propagateShapesInRegion(Region ®ion) { 182 DenseMap<Value, ShapedTypeComponents> shapesStorage; 183 auto setShapes = [&](Value val, Type t) { 184 if (auto st = t.dyn_cast<ShapedType>()) 185 shapesStorage[val] = st; 186 else 187 shapesStorage[val] = t; 188 }; 189 auto operandShape = [&](Value val) -> ShapeAdaptor { 190 // Query the WIP mapping rather than the type if set. 191 auto it = shapesStorage.find(val); 192 if (it == shapesStorage.end()) 193 return nullptr; 194 return it->second; 195 }; 196 197 for (auto &block : region) { 198 for (Operation &op : block) { 199 if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace()) 200 continue; 201 202 propagateShapesToTosaIf(op, shapesStorage); 203 propagateShapesToTosaWhile(op, shapesStorage); 204 205 InferShapedTypeOpInterface shapeInterface = 206 dyn_cast<InferShapedTypeOpInterface>(op); 207 if (!shapeInterface) 208 continue; 209 210 SmallVector<ShapedTypeComponents> returnedShapes; 211 212 ValueShapeRange range(op.getOperands(), operandShape); 213 if (shapeInterface 214 .inferReturnTypeComponents(op.getContext(), op.getLoc(), range, 215 op.getAttrDictionary(), 216 op.getRegions(), returnedShapes) 217 .succeeded()) { 218 for (auto it : llvm::zip(op.getResults(), returnedShapes)) { 219 Value result = std::get<0>(it); 220 ShapedTypeComponents predictedShape = std::get<1>(it); 221 222 // Check whether this use case is replaceable. We define an op as 223 // being replaceable if it is used by a ReturnOp or a TosaOp. 224 bool replaceable = true; 225 for (auto *user : result.getUsers()) { 226 if (isa<func::ReturnOp>(user)) 227 continue; 228 if (user->getDialect()->getNamespace() == 229 TosaDialect::getDialectNamespace()) 230 continue; 231 232 replaceable = false; 233 } 234 235 // Determine the knowledge based on the output type. 236 // TODO: should also query WIP type probably 237 Type resultTy = result.getType(); 238 auto currentKnowledge = 239 ValueKnowledge::getKnowledgeFromType(resultTy); 240 241 // Compute the knowledge based on the inferred type. 242 auto inferredKnowledge = ValueKnowledge::getPessimisticValueState(); 243 inferredKnowledge.dtype = 244 resultTy.cast<ShapedType>().getElementType(); 245 inferredKnowledge.hasRank = predictedShape.hasRank(); 246 if (predictedShape.hasRank()) { 247 for (auto dim : predictedShape.getDims()) { 248 inferredKnowledge.sizes.push_back(dim); 249 } 250 } 251 252 if (!replaceable) 253 continue; 254 255 // Compute the new type based on the joined version. 256 auto newKnowledge = 257 ValueKnowledge::join(currentKnowledge, inferredKnowledge); 258 if (!newKnowledge) 259 continue; 260 setShapes(result, newKnowledge.getType()); 261 } 262 } 263 } 264 } 265 266 // Actually update types with updated shape knowledge. 267 for (auto it : shapesStorage) { 268 auto result = it.second; 269 if (result.hasRank()) { 270 Type t = it.first.getType().cast<ShapedType>().clone(result.getDims()); 271 it.first.setType(t); 272 } 273 } 274 } 275 276 /// Pass that performs shape propagation across TOSA operations. This includes 277 /// migrating to within the regions of if/while operations. 278 struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> { 279 public: 280 void runOnOperation() override { 281 func::FuncOp func = getOperation(); 282 283 IRRewriter rewriter(func.getContext()); 284 285 propagateShapesInRegion(func.getBody()); 286 287 // Insert UnrealizedConversionCasts to guarantee ReturnOp agress with 288 // the FuncOp type. 289 func.walk([&](func::ReturnOp op) { 290 func::FuncOp parent = dyn_cast<func::FuncOp>(op->getParentOp()); 291 if (!parent) 292 return; 293 294 rewriter.setInsertionPoint(op); 295 FunctionType funcTy = func.getFunctionType(); 296 auto resultTys = funcTy.getResults(); 297 298 bool castAdded = false; 299 SmallVector<Value> castedValues; 300 for (auto it : llvm::zip(op->getOperands(), resultTys)) { 301 auto operand = std::get<0>(it); 302 auto currentTy = operand.getType(); 303 auto castTy = std::get<1>(it); 304 if (currentTy == castTy) { 305 castedValues.push_back(operand); 306 continue; 307 } 308 309 castedValues.push_back( 310 rewriter.create<tensor::CastOp>(op.getLoc(), castTy, operand) 311 .getResult()); 312 313 castAdded = true; 314 } 315 316 if (castAdded) { 317 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, castedValues); 318 } 319 }); 320 } 321 }; 322 } // namespace 323 324 std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() { 325 return std::make_unique<TosaInferShapes>(); 326 } 327