1 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the Linalg operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Linalg/IR/Linalg.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 18 #include "mlir/Dialect/Math/IR/Math.h" 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 #include "mlir/Dialect/SCF/SCF.h" 21 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 22 #include "mlir/Dialect/Tensor/IR/Tensor.h" 23 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" 24 #include "mlir/Dialect/Utils/StaticValueUtils.h" 25 #include "mlir/IR/AffineExprVisitor.h" 26 #include "mlir/IR/Matchers.h" 27 #include "mlir/IR/OpImplementation.h" 28 #include "mlir/IR/PatternMatch.h" 29 #include "mlir/Interfaces/InferTypeOpInterface.h" 30 #include "mlir/Parser/Parser.h" 31 32 #include "llvm/ADT/DenseMap.h" 33 #include "llvm/ADT/SetVector.h" 34 #include "llvm/ADT/SmallSet.h" 35 #include "llvm/ADT/StringSet.h" 36 #include "llvm/ADT/TypeSwitch.h" 37 #include "llvm/Support/FormatVariadic.h" 38 #include "llvm/Support/MathExtras.h" 39 #include "llvm/Support/raw_ostream.h" 40 41 using namespace mlir; 42 using namespace mlir::linalg; 43 44 //===----------------------------------------------------------------------===// 45 // Support for named Linalg ops defined in ods-gen. 46 //===----------------------------------------------------------------------===// 47 48 using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &, 49 ArrayRef<NamedAttribute>)>; 50 51 /// Fills the region of a structured operation using the provided 52 /// `regionBuilder`. The method is used by both named structured ops created by 53 /// ods-gen and by manually defined C++ ops. It is called by both builders and 54 /// parsers and creates a block with arguments corresponding to the elemental 55 /// types of `inputTypes` and `outputTypes`. All output types are asserted to be 56 /// ShapedType. 57 static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, 58 TypeRange inputTypes, TypeRange outputTypes, 59 ArrayRef<NamedAttribute> attrs, 60 RegionBuilderFn regionBuilder) { 61 assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); })); 62 63 // TODO: atm all operands go through getElementTypeOrSelf, 64 // reconsider when we have evidence we need to. 65 SmallVector<Type, 8> argTypes; 66 SmallVector<Location, 8> argLocs; 67 for (auto containers : {inputTypes, outputTypes}) { 68 for (auto t : containers) { 69 argTypes.push_back(getElementTypeOrSelf(t)); 70 71 // TODO: Pass in a proper location here. 72 argLocs.push_back(opBuilder.getUnknownLoc()); 73 } 74 } 75 76 // RAII. 77 OpBuilder::InsertionGuard guard(opBuilder); 78 Block *body = 79 opBuilder.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs); 80 81 opBuilder.setInsertionPointToStart(body); 82 ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); 83 regionBuilder(b, *body, attrs); 84 85 // indexing_maps is an auto-generated method. 86 87 // iterator_types is an auto-generated method. 88 } 89 90 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`. 91 /// The result types are derived automatically if `resultTensorTypes` is none. 92 /// The body of the operation is filled using `regionBuilder`. All ods-gen 93 /// created structured operations use the method to implement their builders. 94 static void buildStructuredOp(OpBuilder &b, OperationState &state, 95 llvm::Optional<TypeRange> resultTensorTypes, 96 ValueRange inputs, ValueRange outputs, 97 ArrayRef<NamedAttribute> attributes, 98 RegionBuilderFn regionBuilder) { 99 // Derive the result types if needed. 100 SmallVector<Type> derivedResultTypes = 101 resultTensorTypes.getValueOr(TypeRange()); 102 if (!resultTensorTypes.hasValue()) 103 copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes), 104 [](Type type) { return type.isa<RankedTensorType>(); }); 105 106 state.addOperands(inputs); 107 state.addOperands(outputs); 108 state.addTypes(derivedResultTypes); 109 state.addAttributes(attributes); 110 state.addAttribute( 111 "operand_segment_sizes", 112 b.getI32VectorAttr({static_cast<int32_t>(inputs.size()), 113 static_cast<int32_t>(outputs.size())})); 114 115 // Create and fill the region of the structured operation. 116 Region ®ion = *state.addRegion(); 117 fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs), 118 state.attributes.getAttrs(), regionBuilder); 119 } 120 121 /// Common parsing used for both named structured ops created by ods-gen and by 122 /// manually defined C++ ops. Does not handle regions. 123 static ParseResult 124 parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, 125 SmallVectorImpl<Type> &inputTypes, 126 SmallVectorImpl<Type> &outputTypes) { 127 SMLoc inputsOperandsLoc, outputsOperandsLoc; 128 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands, 129 outputsOperands; 130 131 parser.parseOptionalAttrDict(result.attributes); 132 133 if (succeeded(parser.parseOptionalKeyword("ins"))) { 134 if (parser.parseLParen()) 135 return failure(); 136 137 inputsOperandsLoc = parser.getCurrentLocation(); 138 if (parser.parseOperandList(inputsOperands) || 139 parser.parseColonTypeList(inputTypes) || parser.parseRParen()) 140 return failure(); 141 } 142 143 if (succeeded(parser.parseOptionalKeyword("outs"))) { 144 outputsOperandsLoc = parser.getCurrentLocation(); 145 if (parser.parseLParen() || parser.parseOperandList(outputsOperands) || 146 parser.parseColonTypeList(outputTypes) || parser.parseRParen()) 147 return failure(); 148 } 149 150 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc, 151 result.operands) || 152 parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc, 153 result.operands)) 154 return failure(); 155 156 result.addAttribute("operand_segment_sizes", 157 parser.getBuilder().getI32VectorAttr( 158 {static_cast<int32_t>(inputsOperands.size()), 159 static_cast<int32_t>(outputsOperands.size())})); 160 return success(); 161 } 162 163 static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, 164 ValueRange outputs) { 165 if (!inputs.empty()) 166 p << " ins(" << inputs << " : " << inputs.getTypes() << ")"; 167 if (!outputs.empty()) 168 p << " outs(" << outputs << " : " << outputs.getTypes() << ")"; 169 } 170 171 //===----------------------------------------------------------------------===// 172 // Specific parsing and printing for named structured ops created by ods-gen. 173 //===----------------------------------------------------------------------===// 174 175 static ParseResult parseNamedStructuredOpRegion( 176 OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, 177 TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs, 178 RegionBuilderFn regionBuilder) { 179 if (numRegionArgs != inputTypes.size() + outputTypes.size()) { 180 return parser.emitError( 181 parser.getCurrentLocation(), 182 llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated " 183 "region expects {0} args, got {1}", 184 numRegionArgs, inputTypes.size() + outputTypes.size())); 185 } 186 187 OpBuilder opBuilder(parser.getContext()); 188 fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs, 189 regionBuilder); 190 return success(); 191 } 192 193 static ParseResult 194 parseNamedStructuredOpResults(OpAsmParser &parser, 195 SmallVectorImpl<Type> &resultTypes) { 196 if (parser.parseOptionalArrowTypeList(resultTypes)) 197 return failure(); 198 return success(); 199 } 200 201 static ParseResult parseNamedStructuredOp(OpAsmParser &parser, 202 OperationState &result, 203 unsigned numRegionArgs, 204 RegionBuilderFn regionBuilder) { 205 // TODO: Enable when ods-gen supports captures. 206 SmallVector<Type, 1> inputTypes, outputTypes; 207 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) 208 return failure(); 209 210 // TODO: consider merging results parsing into region parsing. 211 // Need to wait for declarative assembly resolution to decide. 212 SmallVector<Type, 1> outputTensorsTypes; 213 if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) 214 return failure(); 215 result.addTypes(outputTensorsTypes); 216 217 std::unique_ptr<Region> region = std::make_unique<Region>(); 218 if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes, 219 outputTypes, result.attributes.getAttrs(), 220 regionBuilder)) 221 return failure(); 222 result.addRegion(std::move(region)); 223 224 return success(); 225 } 226 227 static void printNamedStructuredOpResults(OpAsmPrinter &p, 228 TypeRange resultTypes) { 229 if (resultTypes.empty()) 230 return; 231 p.printOptionalArrowTypeList(resultTypes); 232 } 233 234 static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, 235 ValueRange inputs, ValueRange outputs) { 236 p.printOptionalAttrDict( 237 op->getAttrs(), 238 /*elidedAttrs=*/{"operand_segment_sizes", 239 // See generated code in mlir-linalg-yaml-gen.cpp 240 "linalg.memoized_indexing_maps"}); 241 242 // Printing is shared with generic ops, except for the region and 243 // attributes. 244 printCommonStructuredOpParts(p, inputs, outputs); 245 246 // Results printing. 247 printNamedStructuredOpResults(p, op->getResultTypes()); 248 249 // Region is elided. 250 } 251 252 /// This is a common class used for patterns of the form 253 /// ``` 254 /// someop(memrefcast(%src)) -> someop(%src) 255 /// ``` 256 /// It folds the source of the memref.cast into the root operation directly. 257 static LogicalResult foldMemRefCast(Operation *op) { 258 bool folded = false; 259 for (OpOperand &operand : op->getOpOperands()) { 260 auto castOp = operand.get().getDefiningOp<memref::CastOp>(); 261 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { 262 operand.set(castOp.getOperand()); 263 folded = true; 264 } 265 } 266 return success(folded); 267 } 268 269 /// Helper function to find if there is atleast one dimension in an AffineMap 270 /// testMap that is contained in `testMapLocation` of `maps` but not in any 271 /// other locations 272 static bool hasaUniqueDim(ArrayRef<AffineMap> maps, unsigned testMapLocation) { 273 AffineMap testMap = maps[testMapLocation]; 274 llvm::SmallDenseSet<unsigned> dimsToCheck; 275 for (auto result : testMap.getResults()) { 276 auto expr = result.dyn_cast<AffineDimExpr>(); 277 if (expr != nullptr) 278 dimsToCheck.insert(expr.getPosition()); 279 } 280 for (const auto &it : llvm::enumerate(maps)) { 281 if (it.index() == testMapLocation) 282 continue; 283 auto map = it.value(); 284 for (auto result : map.getResults()) { 285 auto expr = result.dyn_cast<AffineDimExpr>(); 286 if (expr != nullptr) { 287 dimsToCheck.erase(expr.getPosition()); 288 } 289 if (dimsToCheck.empty()) 290 return false; 291 } 292 } 293 return true; 294 } 295 296 //===----------------------------------------------------------------------===// 297 // Region builder helper. 298 // TODO: Move this to a utility library. 299 // The public methods on this class are referenced directly from generated code. 300 // Helper build the unary, binary, and type conversion functions defined by the 301 // DSL. See mlir-linalg-ods-yaml-gen.cpp for the code that uses this class. 302 // 303 // Implementations of the math functions must be polymorphic over numeric types, 304 // internally performing necessary casts. If the function application makes no 305 // sense, then the only recourse is to assert and return nullptr. This can be 306 // extended later if it becomes possible to fail construction of the region. The 307 // invariant should be enforced at a higher level. 308 // 309 // TODO: These helpers are currently type polymorphic over the class of integer 310 // and floating point types, but they will not internally cast within bit 311 // widths of a class (mixed precision such as i8->i32) or across classes 312 // (i.e. mixed float and integer). Many such combinations are ambiguous or need 313 // to be handled with care and work is being considered to extend the op 314 // language to make such cases explicit. In the mean-time, violating this will 315 // fail verification, which is deemed acceptable. 316 //===----------------------------------------------------------------------===// 317 318 namespace { 319 320 class RegionBuilderHelper { 321 public: 322 RegionBuilderHelper(MLIRContext *context, Block &block) 323 : context(context), block(block) {} 324 325 // Build the unary functions defined by OpDSL. 326 Value buildUnaryFn(UnaryFn unaryFn, Value arg) { 327 if (!isFloatingPoint(arg)) 328 llvm_unreachable("unsupported non numeric type"); 329 OpBuilder builder = getBuilder(); 330 switch (unaryFn) { 331 case UnaryFn::exp: 332 return builder.create<math::ExpOp>(arg.getLoc(), arg); 333 case UnaryFn::log: 334 return builder.create<math::LogOp>(arg.getLoc(), arg); 335 case UnaryFn::abs: 336 return builder.create<math::AbsOp>(arg.getLoc(), arg); 337 case UnaryFn::ceil: 338 return builder.create<math::CeilOp>(arg.getLoc(), arg); 339 case UnaryFn::floor: 340 return builder.create<math::FloorOp>(arg.getLoc(), arg); 341 case UnaryFn::negf: 342 return builder.create<arith::NegFOp>(arg.getLoc(), arg); 343 } 344 llvm_unreachable("unsupported unary function"); 345 } 346 347 // Build the binary functions defined by OpDSL. 348 Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) { 349 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1); 350 bool allInteger = isInteger(arg0) && isInteger(arg1); 351 if (!allFloatingPoint && !allInteger) 352 llvm_unreachable("unsupported non numeric type"); 353 OpBuilder builder = getBuilder(); 354 switch (binaryFn) { 355 case BinaryFn::add: 356 if (allFloatingPoint) 357 return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1); 358 return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1); 359 case BinaryFn::sub: 360 if (allFloatingPoint) 361 return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1); 362 return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1); 363 case BinaryFn::mul: 364 if (allFloatingPoint) 365 return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1); 366 return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1); 367 case BinaryFn::max_signed: 368 if (allFloatingPoint) 369 return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1); 370 return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1); 371 case BinaryFn::min_signed: 372 if (allFloatingPoint) 373 return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1); 374 return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1); 375 case BinaryFn::max_unsigned: 376 if (allFloatingPoint) 377 return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1); 378 return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1); 379 case BinaryFn::min_unsigned: 380 if (allFloatingPoint) 381 return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1); 382 return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1); 383 } 384 llvm_unreachable("unsupported binary function"); 385 } 386 387 // Build the type functions defined by OpDSL. 388 Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) { 389 switch (typeFn) { 390 case TypeFn::cast_signed: 391 return cast(toType, operand, false); 392 case TypeFn::cast_unsigned: 393 return cast(toType, operand, true); 394 } 395 llvm_unreachable("unsupported type conversion function"); 396 } 397 398 void yieldOutputs(ValueRange values) { 399 OpBuilder builder = getBuilder(); 400 Location loc = builder.getUnknownLoc(); 401 builder.create<YieldOp>(loc, values); 402 } 403 404 Value constant(const std::string &value) { 405 OpBuilder builder = getBuilder(); 406 Location loc = builder.getUnknownLoc(); 407 Attribute valueAttr = parseAttribute(value, builder.getContext()); 408 return builder.create<arith::ConstantOp>(loc, valueAttr.getType(), 409 valueAttr); 410 } 411 412 Value index(int64_t dim) { 413 OpBuilder builder = getBuilder(); 414 return builder.create<IndexOp>(builder.getUnknownLoc(), dim); 415 } 416 417 Type getIntegerType(unsigned width) { 418 return IntegerType::get(context, width); 419 } 420 421 Type getFloat32Type() { return Float32Type::get(context); } 422 Type getFloat64Type() { return Float64Type::get(context); } 423 424 private: 425 // Generates operations to cast the given operand to a specified type. 426 // If the cast cannot be performed, a warning will be issued and the 427 // operand returned as-is (which will presumably yield a verification 428 // issue downstream). 429 Value cast(Type toType, Value operand, bool isUnsignedCast) { 430 OpBuilder builder = getBuilder(); 431 auto loc = operand.getLoc(); 432 433 if (operand.getType() == toType) 434 return operand; 435 if (auto toIntType = toType.dyn_cast<IntegerType>()) { 436 // If operand is floating point, cast directly to the int type. 437 if (operand.getType().isa<FloatType>()) { 438 if (isUnsignedCast) 439 return builder.create<arith::FPToUIOp>(loc, toType, operand); 440 return builder.create<arith::FPToSIOp>(loc, toType, operand); 441 } 442 // Cast index operands directly to the int type. 443 if (operand.getType().isIndex()) 444 return builder.create<arith::IndexCastOp>(loc, toType, operand); 445 if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) { 446 // Either extend or truncate. 447 if (toIntType.getWidth() > fromIntType.getWidth()) { 448 if (isUnsignedCast) 449 return builder.create<arith::ExtUIOp>(loc, toType, operand); 450 return builder.create<arith::ExtSIOp>(loc, toType, operand); 451 } 452 if (toIntType.getWidth() < fromIntType.getWidth()) 453 return builder.create<arith::TruncIOp>(loc, toType, operand); 454 } 455 } else if (auto toFloatType = toType.dyn_cast<FloatType>()) { 456 // If operand is integer, cast directly to the float type. 457 // Note that it is unclear how to cast from BF16<->FP16. 458 if (operand.getType().isa<IntegerType>()) { 459 if (isUnsignedCast) 460 return builder.create<arith::UIToFPOp>(loc, toFloatType, operand); 461 return builder.create<arith::SIToFPOp>(loc, toFloatType, operand); 462 } 463 if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) { 464 if (toFloatType.getWidth() > fromFloatType.getWidth()) 465 return builder.create<arith::ExtFOp>(loc, toFloatType, operand); 466 if (toFloatType.getWidth() < fromFloatType.getWidth()) 467 return builder.create<arith::TruncFOp>(loc, toFloatType, operand); 468 } 469 } 470 471 emitWarning(operand.getLoc()) << "could not cast operand of type " 472 << operand.getType() << " to " << toType; 473 return operand; 474 } 475 476 bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); } 477 bool isInteger(Value value) { return value.getType().isa<IntegerType>(); } 478 479 OpBuilder getBuilder() { 480 OpBuilder builder(context); 481 builder.setInsertionPointToEnd(&block); 482 return builder; 483 } 484 485 MLIRContext *context; 486 Block █ 487 }; 488 489 } // namespace 490 491 //===----------------------------------------------------------------------===// 492 // FillOp 493 //===----------------------------------------------------------------------===// 494 495 namespace { 496 497 /// Fold linalg.fill -> tensor.expand/collapse_shape chain. 498 /// 499 /// For such op chains, we can create new linalg.fill ops with the result 500 /// type of the tensor.expand/collapse_shape op. 501 template <typename TensorReshapeOp> 502 struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> { 503 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; 504 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, 505 PatternRewriter &rewriter) const override { 506 auto oldFill = reshapeOp.src().template getDefiningOp<FillOp>(); 507 if (!oldFill) 508 return failure(); 509 510 Location loc = oldFill.getLoc(); 511 auto newInit = rewriter.create<TensorReshapeOp>( 512 loc, reshapeOp.getResultType(), oldFill.output(), 513 reshapeOp.reassociation()); 514 rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()}, 515 ValueRange{newInit}); 516 517 return success(); 518 } 519 }; 520 521 /// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the 522 /// filling value are the same. 523 struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> { 524 using OpRewritePattern::OpRewritePattern; 525 526 LogicalResult matchAndRewrite(tensor::PadOp padOp, 527 PatternRewriter &rewriter) const override { 528 auto fillOp = padOp.source().getDefiningOp<linalg::FillOp>(); 529 if (!fillOp) 530 return failure(); 531 532 // We can only fold if the padding value is the same as the original 533 // filling value. 534 Value padValue = padOp.getConstantPaddingValue(); 535 if (!padValue || fillOp.value() != padValue) 536 return failure(); 537 538 ReifiedRankedShapedTypeDims reifiedShape; 539 ReifyRankedShapedTypeOpInterface interface = 540 cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation()); 541 if (failed(interface.reifyResultShapes(rewriter, reifiedShape))) 542 return rewriter.notifyMatchFailure( 543 padOp, "failed to reify tensor.pad op result shape"); 544 545 auto oldResultType = padOp.getResultType(); 546 SmallVector<int64_t, 4> staticShape(oldResultType.getRank(), 547 ShapedType::kDynamicSize); 548 auto newInitOp = rewriter.create<InitTensorOp>( 549 padOp.getLoc(), reifiedShape.front(), staticShape, 550 oldResultType.getElementType()); 551 auto newFillOp = rewriter.create<FillOp>( 552 fillOp.getLoc(), ValueRange{padValue}, ValueRange{newInitOp}); 553 rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldResultType, 554 newFillOp.result()); 555 556 return success(); 557 } 558 }; 559 560 /// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into 561 /// tensor.insert_slice(<input>, linalg.fill) if the padding value and the 562 /// filling value are the same. 563 struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> { 564 using OpRewritePattern::OpRewritePattern; 565 566 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp, 567 PatternRewriter &rewriter) const override { 568 auto srcPadOp = insertOp.source().getDefiningOp<tensor::PadOp>(); 569 if (!srcPadOp) 570 return failure(); 571 572 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank()) 573 return failure(); 574 575 // Walk back the tensor.insert_slice chain and find the first destination 576 // value at the start of the chain. 577 Value firstDest = insertOp.dest(); 578 while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) { 579 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank()) 580 return failure(); 581 582 // Make sure the range of values accessed are disjoint. Without this, we 583 // cannot fold tensor.pad away. 584 bool disjoint = false; 585 for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) { 586 // If the dimension has dynamic offset/size, we cannot guarantee 587 // disjoint. So just skip it. 588 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) || 589 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) || 590 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i)) 591 continue; 592 593 // Get the range start and end, inclusively for both. 594 int64_t prevStart = prevOp.getStaticOffset(i); 595 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) * 596 prevOp.getStaticStride(i); 597 int64_t nextStart = insertOp.getStaticOffset(i); 598 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) * 599 insertOp.getStaticStride(i); 600 if (prevEnd < nextStart || nextEnd < prevStart) { 601 disjoint = true; 602 break; 603 } 604 } 605 606 if (!disjoint) 607 break; 608 firstDest = prevOp.dest(); 609 } 610 611 // Check whether the first destination is a fill op. For overlapped cases, 612 // this also cannot be true. 613 auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>(); 614 if (!dstFillOp) 615 return failure(); 616 617 // We can only fold if the padding value is the same as the original 618 // filling value. 619 Value padValue = srcPadOp.getConstantPaddingValue(); 620 if (!padValue || dstFillOp.value() != padValue) 621 return failure(); 622 623 SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad(); 624 SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets(); 625 626 Location loc = insertOp.getLoc(); 627 MLIRContext *context = getContext(); 628 629 AffineExpr sym0, sym1; 630 bindSymbols(context, sym0, sym1); 631 auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context); 632 633 // Calculate the new offsets for the insert. It should be the old offsets 634 // plus low padding sizes. 635 SmallVector<OpFoldResult, 4> newOffsets; 636 for (const auto &p : llvm::zip(lowPads, oldOffsets)) { 637 Value padValue = getValueOrCreateConstantIndexOp( 638 rewriter, srcPadOp.getLoc(), std::get<0>(p)); 639 Value offsetValue = getValueOrCreateConstantIndexOp( 640 rewriter, insertOp.getLoc(), std::get<1>(p)); 641 newOffsets.push_back( 642 applyMapToValues(rewriter, loc, addMap, {offsetValue, padValue})[0]); 643 } 644 645 SmallVector<OpFoldResult, 4> newSizes; 646 for (int i = 0, e = srcPadOp.getSourceType().getRank(); i < e; ++i) { 647 newSizes.push_back( 648 rewriter.create<tensor::DimOp>(loc, srcPadOp.source(), i).result()); 649 } 650 651 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>( 652 insertOp, srcPadOp.source(), insertOp.dest(), newOffsets, newSizes, 653 insertOp.getMixedStrides()); 654 return success(); 655 } 656 }; 657 658 } // namespace 659 660 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results, 661 MLIRContext *context) { 662 results 663 .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>, 664 FoldFillWithTensorReshape<tensor::ExpandShapeOp>, 665 FoldInsertPadIntoFill>(context); 666 } 667 668 //===----------------------------------------------------------------------===// 669 // GenericOps 670 //===----------------------------------------------------------------------===// 671 void GenericOp::build( 672 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, 673 ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps, 674 ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall, 675 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild, 676 ArrayRef<NamedAttribute> attributes) { 677 build(builder, result, resultTensorTypes, inputs, outputs, 678 builder.getAffineMapArrayAttr(indexingMaps), 679 builder.getStrArrayAttr(iteratorTypes), 680 doc.empty() ? StringAttr() : builder.getStringAttr(doc), 681 libraryCall.empty() ? StringAttr() 682 : builder.getStringAttr(libraryCall)); 683 result.addAttributes(attributes); 684 if (!bodyBuild) 685 return; 686 687 SmallVector<Type, 4> blockArgTypes; 688 SmallVector<Location, 4> blockArgLocs; 689 for (ValueRange container : {inputs, outputs}) { 690 for (Value v : container) { 691 blockArgTypes.push_back(getElementTypeOrSelf(v)); 692 blockArgLocs.push_back(v.getLoc()); 693 } 694 } 695 696 OpBuilder::InsertionGuard guard(builder); 697 auto ®ion = *result.regions.front(); 698 Block *bodyBlock = 699 builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs); 700 bodyBuild(builder, result.location, bodyBlock->getArguments()); 701 } 702 703 void GenericOp::build( 704 OpBuilder &builder, OperationState &result, ValueRange inputs, 705 ValueRange outputs, ArrayRef<AffineMap> indexingMaps, 706 ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall, 707 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild, 708 ArrayRef<NamedAttribute> attributes) { 709 build(builder, result, TypeRange{}, inputs, outputs, indexingMaps, 710 iteratorTypes, doc, libraryCall, bodyBuild, attributes); 711 } 712 713 void GenericOp::build( 714 OpBuilder &builder, OperationState &result, ValueRange inputs, 715 ValueRange outputs, ArrayRef<AffineMap> indexingMaps, 716 ArrayRef<StringRef> iteratorTypes, 717 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild, 718 ArrayRef<NamedAttribute> attributes) { 719 build(builder, result, inputs, outputs, indexingMaps, iteratorTypes, 720 /*doc=*/"", 721 /*libraryCall=*/"", bodyBuild, attributes); 722 } 723 724 void GenericOp::build( 725 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, 726 ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps, 727 ArrayRef<StringRef> iteratorTypes, 728 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild, 729 ArrayRef<NamedAttribute> attributes) { 730 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps, 731 iteratorTypes, 732 /*doc=*/"", 733 /*libraryCall=*/"", bodyBuild, attributes); 734 } 735 736 void GenericOp::print(OpAsmPrinter &p) { 737 p << " "; 738 739 // Print extra attributes. 740 auto genericAttrNames = linalgTraitAttrNames(); 741 742 llvm::StringSet<> genericAttrNamesSet; 743 genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end()); 744 SmallVector<NamedAttribute, 8> genericAttrs; 745 for (auto attr : (*this)->getAttrs()) 746 if (genericAttrNamesSet.count(attr.getName().strref()) > 0) 747 genericAttrs.push_back(attr); 748 if (!genericAttrs.empty()) { 749 auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs); 750 p << genericDictAttr; 751 } 752 753 // Printing is shared with named ops, except for the region and attributes 754 printCommonStructuredOpParts(p, inputs(), outputs()); 755 756 genericAttrNames.push_back("operand_segment_sizes"); 757 genericAttrNamesSet.insert(genericAttrNames.back()); 758 759 bool hasExtraAttrs = false; 760 for (NamedAttribute n : (*this)->getAttrs()) { 761 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref()))) 762 break; 763 } 764 if (hasExtraAttrs) { 765 p << " attrs = "; 766 p.printOptionalAttrDict((*this)->getAttrs(), 767 /*elidedAttrs=*/genericAttrNames); 768 } 769 770 // Print region. 771 if (!region().empty()) { 772 p << ' '; 773 p.printRegion(region()); 774 } 775 776 // Print results. 777 printNamedStructuredOpResults(p, result_tensors().getTypes()); 778 } 779 780 ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) { 781 DictionaryAttr dictAttr; 782 // Parse the core linalg traits that must check into a dictAttr. 783 // The name is unimportant as we will overwrite result.attributes. 784 // The core linalg traits must contain the information necessary to pass the 785 // verifier. 786 if (parser.parseAttribute(dictAttr, "_", result.attributes)) 787 return failure(); 788 result.attributes.assign(dictAttr.getValue().begin(), 789 dictAttr.getValue().end()); 790 791 // Parsing is shared with named ops, except for the region. 792 SmallVector<Type, 1> inputTypes, outputTypes; 793 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) 794 return failure(); 795 796 // Optional attributes may be added. 797 if (succeeded(parser.parseOptionalKeyword("attrs"))) 798 if (failed(parser.parseEqual()) || 799 failed(parser.parseOptionalAttrDict(result.attributes))) 800 return failure(); 801 802 SmallVector<OpAsmParser::UnresolvedOperand, 8> regionOperands; 803 std::unique_ptr<Region> region = std::make_unique<Region>(); 804 SmallVector<Type, 8> operandTypes, regionTypes; 805 if (parser.parseRegion(*region, regionOperands, regionTypes)) 806 return failure(); 807 result.addRegion(std::move(region)); 808 809 // Generic ops may specify that a subset of its outputs are tensors. Such 810 // outputs are specified in the result type. 811 // TODO: may need to move output parsing before region parsing. 812 // Need to wait for declarative assembly resolution to decide. 813 SmallVector<Type, 1> outputTensorsTypes; 814 if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) 815 return failure(); 816 result.addTypes(outputTensorsTypes); 817 818 return success(); 819 } 820 821 static void getGenericEffectsImpl( 822 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 823 &effects, 824 ValueRange results, ValueRange inputBuffers, ValueRange outputs) { 825 for (Value value : inputBuffers) { 826 effects.emplace_back(MemoryEffects::Read::get(), value, 827 SideEffects::DefaultResource::get()); 828 } 829 for (Value value : outputs) { 830 effects.emplace_back(MemoryEffects::Read::get(), value, 831 SideEffects::DefaultResource::get()); 832 effects.emplace_back(MemoryEffects::Write::get(), value, 833 SideEffects::DefaultResource::get()); 834 } 835 } 836 837 void GenericOp::getEffects( 838 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 839 &effects) { 840 SmallVector<Value> inputBuffers = getInputBufferOperands(); 841 SmallVector<Value> outputBuffers = getOutputBufferOperands(); 842 getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers, 843 outputBuffers); 844 } 845 846 LogicalResult GenericOp::verify() { return success(); } 847 848 namespace { 849 // Deduplicate redundant args of a linalg generic op. 850 // An arg is redundant if it has the same Value and indexing map as another. 851 struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> { 852 using OpRewritePattern<GenericOp>::OpRewritePattern; 853 854 LogicalResult matchAndRewrite(GenericOp genericOp, 855 PatternRewriter &rewriter) const override { 856 // Associate each input to an equivalent "canonical" input that has the same 857 // Value and indexing map. 858 // 859 // In the non-duplicate case, input `i` will have canonical input `i`. But 860 // in the case of duplicated inputs, the canonical input could be some other 861 // input `< i`. That is, a later input will have some earlier input as its 862 // canonical input. 863 llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput; 864 // For later remapping tasks like deduplicating payload block arguments, 865 // having a simple "inputIndex -> canonicalInputIndex" integer mapping is 866 // convenient. 867 SmallVector<unsigned> canonicalInputIndices; 868 for (OpOperand *opOperand : genericOp.getInputOperands()) { 869 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); 870 // STL-like maps have a convenient behavior for our use case here. In the 871 // case of duplicate keys, the insertion is rejected, and the returned 872 // iterator gives access to the value already in the map. 873 auto pair = canonicalInput.insert( 874 {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()}); 875 canonicalInputIndices.push_back(pair.first->second); 876 } 877 878 // If there are no duplicate args, then bail out. 879 if (canonicalInput.size() == genericOp.getNumInputs()) 880 return failure(); 881 882 // The operands for the newly canonicalized op. 883 SmallVector<Value> newInputOperands; 884 for (OpOperand *opOperand : genericOp.getInputOperands()) 885 if (canonicalInputIndices[opOperand->getOperandNumber()] == 886 opOperand->getOperandNumber()) 887 newInputOperands.push_back(opOperand->get()); 888 889 // Repair the indexing maps by filtering out the ones that have been 890 // eliminated. 891 SmallVector<AffineMap> newIndexingMaps; 892 for (OpOperand *opOperand : genericOp.getInputOperands()) 893 if (canonicalInputIndices[opOperand->getOperandNumber()] == 894 opOperand->getOperandNumber()) 895 newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand)); 896 for (OpOperand *opOperand : genericOp.getOutputOperands()) 897 newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand)); 898 899 // Clone the old op with new operands. 900 SmallVector<Value> outputOperands = genericOp.getOutputOperands(); 901 auto newOp = rewriter.create<GenericOp>( 902 genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands, 903 outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps), 904 genericOp.iterator_types(), genericOp.docAttr(), 905 genericOp.library_callAttr()); 906 907 // Copy over unknown attributes. They might be load bearing for some flow. 908 ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames(); 909 for (NamedAttribute kv : genericOp->getAttrs()) { 910 if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) { 911 newOp->setAttr(kv.getName(), kv.getValue()); 912 } 913 } 914 915 rewriter.inlineRegionBefore(genericOp.region(), newOp.region(), 916 newOp.region().begin()); 917 918 // Repair the payload entry block by RAUW'ing redundant arguments and 919 // erasing them. 920 Block &payload = newOp.region().front(); 921 SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands(); 922 for (OpOperand *opOperand : llvm::reverse(inputOperands)) { 923 // Iterate in reverse, so that we erase later args first, preventing the 924 // argument list from shifting unexpectedly and invalidating all our 925 // indices. 926 unsigned operandNumber = opOperand->getOperandNumber(); 927 if (canonicalInputIndices[operandNumber] == operandNumber) 928 continue; 929 payload.getArgument(operandNumber) 930 .replaceAllUsesWith( 931 payload.getArgument(canonicalInputIndices[operandNumber])); 932 payload.eraseArgument(operandNumber); 933 } 934 935 rewriter.replaceOp(genericOp, newOp->getResults()); 936 return success(); 937 } 938 }; 939 940 /// Remove generic operations (on tensors) that are just copying 941 /// the values from inputs to the results. Requirements are 942 /// 1) All iterator types are parallel 943 /// 2) The body contains just a yield operation with the yielded values being 944 /// the arguments corresponding to the operands. 945 struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> { 946 using OpRewritePattern<GenericOp>::OpRewritePattern; 947 948 LogicalResult matchAndRewrite(GenericOp genericOp, 949 PatternRewriter &rewriter) const override { 950 // Check all indexing maps are identity. 951 if (llvm::any_of(genericOp.getIndexingMaps(), 952 [](AffineMap map) { return !map.isIdentity(); })) 953 return failure(); 954 955 // Check that the body of the linalg operation is just a linalg.yield 956 // operation. 957 Block &body = genericOp.region().front(); 958 if (!llvm::hasSingleElement(body)) 959 return failure(); 960 auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator()); 961 if (!yieldOp) 962 return failure(); 963 964 // In the buffer case, we need to check exact buffer equality. 965 if (genericOp.hasBufferSemantics()) { 966 if (genericOp.getNumInputs() == 1 && genericOp.getNumOutputs() == 1 && 967 genericOp.getInputOperand(0)->get() == 968 genericOp.getOutputOperand(0)->get()) { 969 rewriter.eraseOp(genericOp); 970 return success(); 971 } 972 return failure(); 973 } 974 975 // Get the argument number of the returned values. That is the operand 976 // number to use for replacing uses of this operation. 977 SmallVector<Value> returnedArgs; 978 for (const auto &yieldVal : llvm::enumerate(yieldOp.values())) { 979 auto yieldArg = yieldVal.value().dyn_cast<BlockArgument>(); 980 if (!yieldArg || yieldArg.getOwner() != &body) 981 return failure(); 982 unsigned argumentNumber = yieldArg.getArgNumber(); 983 Value returnedArg = genericOp->getOperand(argumentNumber); 984 Type resultType = genericOp->getResult(yieldVal.index()).getType(); 985 // The input can have a different type than the result, e.g. a dynamic 986 // input dimension can be turned into a static output dimension. 987 Type returnType = returnedArg.getType(); 988 if (returnType != resultType) { 989 // Distinguish between sparse conversion or dense tensor casting. 990 // TODO: unify the two ops? 991 if (sparse_tensor::getSparseTensorEncoding(returnType) || 992 sparse_tensor::getSparseTensorEncoding(resultType)) 993 returnedArg = rewriter.create<sparse_tensor::ConvertOp>( 994 genericOp.getLoc(), resultType, returnedArg); 995 else { 996 if (!tensor::CastOp::areCastCompatible(returnedArg.getType(), 997 resultType)) 998 return failure(); 999 returnedArg = rewriter.create<tensor::CastOp>( 1000 genericOp.getLoc(), resultType, returnedArg); 1001 } 1002 } 1003 returnedArgs.push_back(returnedArg); 1004 } 1005 1006 if (returnedArgs.size() != genericOp->getNumResults()) 1007 return failure(); 1008 rewriter.replaceOp(genericOp, returnedArgs); 1009 return success(); 1010 } 1011 }; 1012 1013 /// Drop dead args of a linalg generic op. 1014 /// An arg is dead if it has zero uses in the op region. 1015 struct DeadArgsGenericOpInputs : public OpRewritePattern<GenericOp> { 1016 using OpRewritePattern<GenericOp>::OpRewritePattern; 1017 LogicalResult matchAndRewrite(GenericOp genericOp, 1018 PatternRewriter &rewriter) const override { 1019 SmallVector<AffineMap> oldIndexingMaps = genericOp.getIndexingMaps(); 1020 // Maps must be projected permutations. 1021 if (llvm::any_of(genericOp.getIndexingMaps(), [](AffineMap map) { 1022 return !map.isProjectedPermutation(); 1023 })) 1024 return failure(); 1025 Block &payload = genericOp.region().front(); 1026 SmallVector<Value> newInputOperands; 1027 SmallVector<AffineMap> newIndexingMaps; 1028 bool deadArgFound = false; 1029 int inputSize = genericOp.getInputOperands().size(); 1030 for (int i = inputSize - 1; i >= 0; i--) { 1031 OpOperand *opOperand = genericOp.getInputOperand(i); 1032 // Iterate in reverse, so that we erase later args first, preventing the 1033 // argument list from shifting unexpectedly and invalidating all our 1034 // indices. 1035 if (payload.getArgument(i).use_empty() && 1036 !hasaUniqueDim(oldIndexingMaps, i)) { 1037 payload.eraseArgument(i); 1038 deadArgFound = true; 1039 // remove this indexing map out of consideration for hasaUniqueDim check 1040 oldIndexingMaps.erase(oldIndexingMaps.begin() + i); 1041 } else { 1042 newInputOperands.insert(newInputOperands.begin(), opOperand->get()); 1043 newIndexingMaps.insert(newIndexingMaps.begin(), 1044 genericOp.getTiedIndexingMap(opOperand)); 1045 } 1046 } 1047 // Bail out if there are no dead args. 1048 if (!deadArgFound) 1049 return failure(); 1050 for (OpOperand *opOperand : genericOp.getOutputOperands()) 1051 newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand)); 1052 SmallVector<Value> outputOperands = genericOp.getOutputOperands(); 1053 1054 auto newOp = rewriter.create<GenericOp>( 1055 genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands, 1056 outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps), 1057 genericOp.iterator_types(), genericOp.docAttr(), 1058 genericOp.library_callAttr()); 1059 // Copy over unknown attributes. They might be load bearing for some flow. 1060 ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames(); 1061 for (NamedAttribute kv : genericOp->getAttrs()) { 1062 if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) { 1063 newOp->setAttr(kv.getName(), kv.getValue()); 1064 } 1065 } 1066 rewriter.inlineRegionBefore(genericOp.region(), newOp.region(), 1067 newOp.region().begin()); 1068 rewriter.replaceOp(genericOp, newOp->getResults()); 1069 return success(); 1070 } 1071 }; 1072 } // namespace 1073 1074 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results, 1075 MLIRContext *context) { 1076 results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp, 1077 DeadArgsGenericOpInputs>(context); 1078 } 1079 1080 LogicalResult GenericOp::fold(ArrayRef<Attribute>, 1081 SmallVectorImpl<OpFoldResult> &) { 1082 return foldMemRefCast(*this); 1083 } 1084 1085 //===----------------------------------------------------------------------===// 1086 // InitTensorOp 1087 //===----------------------------------------------------------------------===// 1088 1089 void InitTensorOp::build(OpBuilder &b, OperationState &result, 1090 ArrayRef<OpFoldResult> sizes, Type elementType, 1091 ArrayRef<NamedAttribute> attrs) { 1092 SmallVector<Value, 4> dynamicSizes; 1093 SmallVector<int64_t, 4> staticSizes; 1094 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1095 ShapedType::kDynamicSize); 1096 auto resultType = RankedTensorType ::get(staticSizes, elementType); 1097 build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes)); 1098 result.addAttributes(attrs); 1099 } 1100 1101 LogicalResult InitTensorOp::verify() { 1102 RankedTensorType resultType = getType(); 1103 SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range( 1104 static_sizes().cast<ArrayAttr>(), 1105 [](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); })); 1106 1107 if (failed(verifyListOfOperandsOrIntegers( 1108 *this, "sizes", resultType.getRank(), static_sizes(), sizes(), 1109 ShapedType::isDynamic))) 1110 return failure(); 1111 1112 if (static_sizes().size() != static_cast<unsigned>(resultType.getRank())) 1113 return emitError("expected ") << resultType.getRank() << " sizes values"; 1114 1115 Type expectedType = InitTensorOp::inferResultType( 1116 staticSizes, resultType.getElementType(), resultType.getEncoding()); 1117 if (resultType != expectedType) { 1118 return emitError("specified type ") 1119 << resultType << " does not match the inferred type " 1120 << expectedType; 1121 } 1122 return success(); 1123 } 1124 1125 Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes, 1126 Type elementType, Attribute encoding) { 1127 return RankedTensorType::get(staticSizes, elementType, encoding); 1128 } 1129 1130 SmallVector<OpFoldResult> InitTensorOp::getMixedSizes() { 1131 SmallVector<OpFoldResult> mixedSizes; 1132 mixedSizes.reserve(getType().getRank()); 1133 unsigned dynamicValIndex = 0; 1134 for (Attribute attr : static_sizes()) { 1135 auto intAttr = attr.cast<IntegerAttr>(); 1136 if (!ShapedType::isDynamic(intAttr.getInt())) { 1137 mixedSizes.push_back(intAttr); 1138 continue; 1139 } 1140 mixedSizes.push_back(sizes()[dynamicValIndex++]); 1141 } 1142 return mixedSizes; 1143 } 1144 1145 namespace { 1146 /// Change the type of the result of a `linalg.init_tensor` by making the result 1147 /// type statically sized along dimension that in the original operation where 1148 /// defined as dynamic, but the size was defined using a `constant` op. For 1149 /// example 1150 /// 1151 /// %c5 = arith.constant 5: index 1152 /// %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32> 1153 /// 1154 /// to 1155 /// 1156 /// %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32> 1157 struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> { 1158 using OpRewritePattern<InitTensorOp>::OpRewritePattern; 1159 1160 LogicalResult matchAndRewrite(InitTensorOp op, 1161 PatternRewriter &rewriter) const override { 1162 SmallVector<Value, 4> dynamicSizes; 1163 SmallVector<int64_t, 4> staticSizes; 1164 for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) { 1165 // If the size is already static, nothing to do. 1166 if (!op.isDynamicSize(i)) { 1167 staticSizes.push_back(op.getStaticSize(i)); 1168 continue; 1169 } 1170 1171 // If the size is dynamic but defined using a `constant` op, get the 1172 // constant value to find the static size to use. 1173 unsigned operandNum = op.getIndexOfDynamicSize(i); 1174 Value sizeOperand = op.getOperand(operandNum); 1175 if (auto constantIndexOp = 1176 sizeOperand.getDefiningOp<arith::ConstantIndexOp>()) { 1177 staticSizes.push_back(constantIndexOp.value()); 1178 continue; 1179 } 1180 1181 // Fallback case. Keep the size dynamic. 1182 dynamicSizes.push_back(sizeOperand); 1183 staticSizes.push_back(ShapedType::kDynamicSize); 1184 } 1185 RankedTensorType newType = 1186 RankedTensorType::get(staticSizes, op.getType().getElementType()); 1187 if (newType == op.getType()) 1188 return failure(); 1189 auto newOp = 1190 rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes, 1191 rewriter.getI64ArrayAttr(staticSizes)); 1192 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 1193 return success(); 1194 } 1195 }; 1196 } // namespace 1197 1198 namespace { 1199 /// Since `init_tensor` operation creates a tensor needed only for its shape, a 1200 /// slice of this is also needed only for its shape. The result can be 1201 /// replaced by a new init_tensor operation of the same size as the extract 1202 /// slice op. 1203 struct FoldInitTensorWithExtractSliceOp 1204 : public OpRewritePattern<tensor::ExtractSliceOp> { 1205 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern; 1206 1207 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, 1208 PatternRewriter &rewriter) const override { 1209 if (!sliceOp.source().getDefiningOp<linalg::InitTensorOp>()) 1210 return failure(); 1211 // ExtractSliceOp may be rank-reducing; its dynamic sizes must be preserved 1212 // as well as its result type. 1213 rewriter.replaceOpWithNewOp<linalg::InitTensorOp>( 1214 sliceOp, sliceOp.sizes(), 1215 sliceOp.result().getType().cast<RankedTensorType>().getShape(), 1216 sliceOp.getSourceType().getElementType()); 1217 return success(); 1218 } 1219 }; 1220 1221 template <typename TensorReshapeOp> 1222 struct FoldInitTensorWithTensorReshapeOp 1223 : public OpRewritePattern<TensorReshapeOp> { 1224 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; 1225 1226 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, 1227 PatternRewriter &rewriter) const override { 1228 if (!reshapeOp.src().template getDefiningOp<InitTensorOp>()) 1229 return failure(); 1230 Location loc = reshapeOp.getLoc(); 1231 ReifiedRankedShapedTypeDims resultShapes; 1232 ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = 1233 cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation()); 1234 if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter, 1235 resultShapes)) || 1236 !llvm::hasSingleElement(resultShapes)) 1237 return failure(); 1238 Value initTensor = rewriter.create<InitTensorOp>( 1239 loc, getAsOpFoldResult(resultShapes[0]), 1240 reshapeOp.getResultType().getElementType()); 1241 if (initTensor.getType() != reshapeOp.getResultType()) { 1242 rewriter.replaceOpWithNewOp<tensor::CastOp>( 1243 reshapeOp, reshapeOp.getResultType(), initTensor); 1244 } else { 1245 rewriter.replaceOp(reshapeOp, initTensor); 1246 } 1247 return success(); 1248 } 1249 }; 1250 1251 struct FoldInitTensorWithDimOp : public OpRewritePattern<tensor::DimOp> { 1252 using OpRewritePattern<tensor::DimOp>::OpRewritePattern; 1253 1254 LogicalResult matchAndRewrite(tensor::DimOp dimOp, 1255 PatternRewriter &rewriter) const override { 1256 Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex(); 1257 auto initTensorOp = dimOp.source().getDefiningOp<linalg::InitTensorOp>(); 1258 if (!initTensorOp || !maybeConstantIndex) 1259 return failure(); 1260 if (!initTensorOp.isDynamicSize(*maybeConstantIndex)) 1261 return failure(); 1262 rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(*maybeConstantIndex)); 1263 return success(); 1264 } 1265 }; 1266 1267 /// Canonicalize 1268 /// 1269 /// ```mlir 1270 /// %0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32> 1271 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32> 1272 /// ``` 1273 /// 1274 /// into 1275 /// 1276 /// ```mlir 1277 /// %0 = linalg.init_tensor [4, %d1] : tensor<4x?xf32> 1278 /// ``` 1279 /// 1280 /// This assumes the input program is correct in terms of its shape. So it 1281 /// is safe to assume that `%d0` is in fact 4. If that was not the case, the 1282 /// input program is wrong to begin with, so its undefined behavior anyway (i.e. 1283 /// this optimization can still triggering without violating program semantics). 1284 struct FoldInitTensorWithTensorCastOp 1285 : public OpRewritePattern<tensor::CastOp> { 1286 using OpRewritePattern<tensor::CastOp>::OpRewritePattern; 1287 1288 LogicalResult matchAndRewrite(tensor::CastOp castOp, 1289 PatternRewriter &rewriter) const override { 1290 if (!canFoldIntoProducerOp(castOp)) 1291 return failure(); 1292 auto producer = castOp.source().getDefiningOp<InitTensorOp>(); 1293 if (!producer) 1294 return failure(); 1295 1296 auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>(); 1297 ArrayRef<int64_t> resultShape = resultType.getShape(); 1298 SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes(); 1299 SmallVector<OpFoldResult> newMixedSizes; 1300 newMixedSizes.reserve(currMixedSizes.size()); 1301 assert(resultShape.size() == currMixedSizes.size() && 1302 "mismatch in result shape and sizes of init_tensor op"); 1303 for (auto it : llvm::zip(resultShape, currMixedSizes)) { 1304 int64_t newDim = std::get<0>(it); 1305 OpFoldResult currDim = std::get<1>(it); 1306 // Case 1: The init tensor dim is static. Check that the tensor cast 1307 // result dim matches. 1308 if (auto attr = currDim.dyn_cast<Attribute>()) { 1309 if (ShapedType::isDynamic(newDim) || 1310 newDim != attr.cast<IntegerAttr>().getInt()) { 1311 // Something is off, the cast result shape cannot be more dynamic than 1312 // the init tensor result shape (enforced by `canFoldIntoProducer`). 1313 // Abort for now. 1314 return rewriter.notifyMatchFailure( 1315 producer, "mismatch in static value of shape of init " 1316 "tensor result and cast result"); 1317 } 1318 newMixedSizes.push_back(attr); 1319 continue; 1320 } 1321 1322 // Case 2 : The tensor cast shape is static, but init tensor result shape 1323 // is dynamic. 1324 if (!ShapedType::isDynamic(newDim)) { 1325 newMixedSizes.push_back(rewriter.getIndexAttr(newDim)); 1326 continue; 1327 } 1328 1329 // Case 3 : The tensor cast shape is dynamic and init tensor result shape 1330 // is dynamic. Use the dynamic value from the init tensor op. 1331 newMixedSizes.push_back(currDim); 1332 } 1333 1334 rewriter.replaceOpWithNewOp<InitTensorOp>(castOp, newMixedSizes, 1335 resultType.getElementType()); 1336 return success(); 1337 } 1338 }; 1339 1340 } // namespace 1341 1342 void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, 1343 MLIRContext *context) { 1344 results.add<FoldInitTensorWithTensorCastOp, FoldInitTensorWithDimOp, 1345 FoldInitTensorWithExtractSliceOp, 1346 FoldInitTensorWithTensorReshapeOp<tensor::ExpandShapeOp>, 1347 FoldInitTensorWithTensorReshapeOp<tensor::CollapseShapeOp>, 1348 ReplaceStaticShapeDims>(context); 1349 } 1350 1351 LogicalResult InitTensorOp::reifyResultShapes( 1352 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 1353 auto shapes = llvm::to_vector<4>(llvm::map_range( 1354 llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value { 1355 if (isDynamicSize(dim)) 1356 return getDynamicSize(dim); 1357 return builder.create<arith::ConstantIndexOp>(getLoc(), 1358 getStaticSize(dim)); 1359 })); 1360 reifiedReturnShapes.emplace_back(std::move(shapes)); 1361 return success(); 1362 } 1363 1364 //===----------------------------------------------------------------------===// 1365 // YieldOp 1366 //===----------------------------------------------------------------------===// 1367 1368 void linalg::YieldOp::print(OpAsmPrinter &p) { 1369 if (getNumOperands() > 0) 1370 p << ' ' << getOperands(); 1371 p.printOptionalAttrDict((*this)->getAttrs()); 1372 if (getNumOperands() > 0) 1373 p << " : " << getOperandTypes(); 1374 } 1375 1376 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) { 1377 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo; 1378 SmallVector<Type, 2> types; 1379 SMLoc loc = parser.getCurrentLocation(); 1380 return failure(parser.parseOperandList(opInfo) || 1381 parser.parseOptionalAttrDict(result.attributes) || 1382 (!opInfo.empty() && parser.parseColonTypeList(types)) || 1383 parser.resolveOperands(opInfo, types, loc, result.operands)); 1384 } 1385 1386 // Check the operand number and types must match the element types of the 1387 // LinalgOp interface's shaped operands. 1388 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) { 1389 if (op.getNumOperands() != linalgOp.getNumOutputs()) 1390 return op.emitOpError("expected number of yield values (") 1391 << linalgOp.getNumOutputs() 1392 << ") to match the number of operands of the enclosing " 1393 << "LinalgOp (" << op.getNumOperands() << ")"; 1394 1395 for (OpOperand &opOperand : op->getOpOperands()) { 1396 OpOperand *outputOperand = 1397 linalgOp.getOutputOperand(opOperand.getOperandNumber()); 1398 Type elementType = getElementTypeOrSelf(outputOperand->get().getType()); 1399 if (opOperand.get().getType() != elementType) 1400 return op.emitOpError("type of yield operand ") 1401 << (opOperand.getOperandNumber() + 1) << " (" 1402 << opOperand.get().getType() << ") doesn't match " 1403 << "the element type of the enclosing linalg.generic op (" 1404 << elementType << ")"; 1405 } 1406 return success(); 1407 } 1408 1409 LogicalResult linalg::YieldOp::verify() { 1410 auto *parentOp = (*this)->getParentOp(); 1411 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty()) 1412 return emitOpError("expected single non-empty parent region"); 1413 1414 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp)) 1415 return verifyYield(*this, linalgOp); 1416 1417 return emitOpError("expected parent op with LinalgOp interface"); 1418 } 1419 1420 //===----------------------------------------------------------------------===// 1421 // IndexOp 1422 //===----------------------------------------------------------------------===// 1423 1424 LogicalResult IndexOp::verify() { 1425 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp()); 1426 if (!linalgOp) 1427 return emitOpError("expected parent op with LinalgOp interface"); 1428 if (linalgOp.getNumLoops() <= dim()) 1429 return emitOpError("expected dim (") 1430 << dim() << ") to be lower than the number of loops (" 1431 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp"; 1432 return success(); 1433 } 1434 1435 /////// Operations corresponding to library calls defined with Tablegen //////// 1436 1437 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc" 1438 1439 #define GET_OP_CLASSES 1440 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" 1441 1442 #define GET_OP_CLASSES 1443 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 1444 1445 /// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`. 1446 /// Assumes `op` is a LinalgOp. 1447 void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName, 1448 SmallVectorImpl<unsigned> &res) { 1449 if (!cast<LinalgOp>(op).iterator_types()) 1450 return; 1451 1452 unsigned dim = 0; 1453 for (auto tn : 1454 cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) { 1455 if (tn == iteratorTypeName) 1456 res.push_back(dim); 1457 ++dim; 1458 } 1459 } 1460 1461 AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap, 1462 unsigned rank, 1463 MLIRContext *context) { 1464 if (maybeMap) 1465 return maybeMap.getValue(); 1466 if (rank == 0) 1467 return AffineMap::get(context); 1468 return AffineMap::getMultiDimIdentityMap(rank, context); 1469 } 1470 1471 SmallVector<AffineExpr, 4> 1472 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx, 1473 MLIRContext *context) { 1474 SmallVector<AffineExpr, 4> res; 1475 res.reserve(num); 1476 for (unsigned i = 0; i < num; ++i) 1477 res.push_back(getAffineDimExpr(startIdx++, context)); 1478 return res; 1479 } 1480 1481 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a, 1482 ArrayRef<AffineExpr> b) { 1483 auto rangeA = llvm::make_range(a.begin(), a.end()); 1484 auto rangeB = llvm::make_range(b.begin(), b.end()); 1485 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB); 1486 return llvm::to_vector<4>(concatRanges); 1487 } 1488 1489 static void appendMangledType(llvm::raw_string_ostream &ss, Type t) { 1490 if (auto memref = t.dyn_cast<MemRefType>()) { 1491 ss << "view"; 1492 for (auto size : memref.getShape()) 1493 if (size < 0) 1494 ss << "sx"; 1495 else 1496 ss << size << "x"; 1497 appendMangledType(ss, memref.getElementType()); 1498 } else if (auto vec = t.dyn_cast<VectorType>()) { 1499 ss << "vector"; 1500 llvm::interleave( 1501 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; }); 1502 appendMangledType(ss, vec.getElementType()); 1503 } else if (t.isSignlessIntOrIndexOrFloat()) { 1504 ss << t; 1505 } else { 1506 llvm_unreachable("Invalid type for linalg library name mangling"); 1507 } 1508 } 1509 1510 std::string mlir::linalg::generateLibraryCallName(Operation *op) { 1511 assert(isa<LinalgOp>(op)); 1512 std::string name(op->getName().getStringRef().str()); 1513 name.reserve(128); 1514 std::replace(name.begin(), name.end(), '.', '_'); 1515 llvm::raw_string_ostream ss(name); 1516 ss << "_"; 1517 auto types = op->getOperandTypes(); 1518 llvm::interleave( 1519 types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); }, 1520 [&]() { ss << "_"; }); 1521 return ss.str(); 1522 } 1523 1524 //===----------------------------------------------------------------------===// 1525 // Canonicalizers and Folders. 1526 //===----------------------------------------------------------------------===// 1527 1528 namespace { 1529 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> { 1530 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern; 1531 1532 LogicalResult matchAndRewrite(LinalgOp op, 1533 PatternRewriter &rewriter) const override { 1534 for (OpOperand *opOperand : op.getInputAndOutputOperands()) { 1535 // Linalg "inputs" may be either tensor or memref type. 1536 // tensor<0xelt_type> is a convention that may not always mean 1537 // "0 iterations". Only erase in cases we see memref<...x0x...>. 1538 auto mt = opOperand->get().getType().dyn_cast<MemRefType>(); 1539 if (!mt) 1540 continue; 1541 if (llvm::is_contained(op.getShape(opOperand), 0)) { 1542 rewriter.eraseOp(op); 1543 return success(); 1544 } 1545 } 1546 return failure(); 1547 } 1548 }; 1549 1550 struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> { 1551 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern; 1552 1553 LogicalResult matchAndRewrite(LinalgOp op, 1554 PatternRewriter &rewriter) const override { 1555 // If no operand comes from a tensor::CastOp and can be folded then fail. 1556 bool hasTensorCastOperand = 1557 llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) { 1558 if (opOperand->get().isa<BlockArgument>()) 1559 return false; 1560 auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>(); 1561 return castOp && canFoldIntoConsumerOp(castOp); 1562 }); 1563 if (!hasTensorCastOperand) 1564 return failure(); 1565 1566 SmallVector<Type, 4> newResultTypes; 1567 newResultTypes.reserve(op->getNumResults()); 1568 SmallVector<Value, 4> newOperands; 1569 newOperands.reserve(op->getNumOperands()); 1570 // Inputs may fold. 1571 for (OpOperand *opOperand : op.getInputOperands()) { 1572 auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>(); 1573 newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp) 1574 ? tensorCastOp.source() 1575 : opOperand->get()); 1576 } 1577 // Init tensors may fold, in which case the resultType must also change. 1578 for (OpOperand *opOperand : op.getOutputOperands()) { 1579 auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>(); 1580 bool fold = canFoldIntoConsumerOp(tensorCastOp); 1581 newOperands.push_back(fold ? tensorCastOp.getOperand() 1582 : opOperand->get()); 1583 newResultTypes.push_back(newOperands.back().getType()); 1584 } 1585 // Clone op. 1586 Operation *newOp = 1587 op.clone(rewriter, op->getLoc(), newResultTypes, newOperands); 1588 SmallVector<Value, 4> replacements; 1589 replacements.reserve(newOp->getNumResults()); 1590 for (auto result : llvm::zip(op->getResults(), newOp->getResults())) { 1591 Value oldResult = std::get<0>(result); 1592 Value newResult = std::get<1>(result); 1593 if (newResult.getType() != oldResult.getType()) { 1594 replacements.push_back(rewriter.create<tensor::CastOp>( 1595 op->getLoc(), oldResult.getType(), newResult)); 1596 } else { 1597 replacements.push_back(newResult); 1598 } 1599 } 1600 rewriter.replaceOp(op, replacements); 1601 1602 return success(); 1603 } 1604 }; 1605 1606 /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has 1607 /// result that is more static than the linalg op. 1608 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> { 1609 using OpRewritePattern<tensor::CastOp>::OpRewritePattern; 1610 1611 LogicalResult matchAndRewrite(tensor::CastOp castOp, 1612 PatternRewriter &rewriter) const override { 1613 if (!tensor::canFoldIntoProducerOp(castOp)) 1614 return failure(); 1615 auto linalgOp = castOp.source().getDefiningOp<LinalgOp>(); 1616 if (!linalgOp) 1617 return failure(); 1618 1619 OpBuilder::InsertionGuard guard(rewriter); 1620 rewriter.setInsertionPoint(linalgOp); 1621 1622 Location loc = linalgOp.getLoc(); 1623 OpResult resultValue = castOp.source().cast<OpResult>(); 1624 unsigned resultNumber = resultValue.getResultNumber(); 1625 auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>(); 1626 // Replace the `outs` for the result with a `tensor.cast`. This cast is now 1627 // going from a more dynamic shape to a less dynamic shape. If the producer 1628 // for this cast, i.e. producer of the out operand, is also an operation 1629 // that folds with tensor.cast consumer (like this pattern), the cast will 1630 // continue to propagate as far up the stack as it can go. 1631 OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber); 1632 Value newOperand = 1633 rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get()); 1634 SmallVector<Value> newOperands = linalgOp.getInputOperands(); 1635 SmallVector<Value> outputOperands = linalgOp.getOutputOperands(); 1636 outputOperands[resultNumber] = newOperand; 1637 newOperands.append(outputOperands.begin(), outputOperands.end()); 1638 1639 SmallVector<Type> resultTypes(linalgOp->result_type_begin(), 1640 linalgOp->result_type_end()); 1641 resultTypes[resultNumber] = resultType; 1642 Operation *newOp = linalgOp.clone(rewriter, loc, resultTypes, newOperands); 1643 1644 // Create a tensor.cast operation back to the original type. 1645 Value castBack = rewriter.create<tensor::CastOp>( 1646 loc, resultValue.getType(), newOp->getResult(resultNumber)); 1647 1648 SmallVector<Value> results(newOp->result_begin(), newOp->result_end()); 1649 results[resultNumber] = castBack; 1650 rewriter.replaceOp(linalgOp, results); 1651 rewriter.replaceOp(castOp, newOp->getResult(resultNumber)); 1652 return success(); 1653 } 1654 }; 1655 1656 /// For each of the operand in `operands` this function maps the static sizes of 1657 /// dimensions to their affine dim expressions. 1658 static void populateMap(LinalgOp linalgOp, ArrayRef<OpOperand *> operands, 1659 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) { 1660 for (OpOperand *opOperand : operands) { 1661 if (linalgOp.isScalar(opOperand)) 1662 continue; 1663 Value src = opOperand->get(); 1664 auto sourceType = src.getType().cast<RankedTensorType>(); 1665 auto sourceMap = linalgOp.getTiedIndexingMap(opOperand); 1666 1667 // Get the `sourceShape` of the `sourceType`. If the operand is a result of 1668 // `tensor.cast` operation and source of the cast operation has a static 1669 // shape, then assign it to the `sourceShape`. 1670 auto *parentOp = src.getDefiningOp(); 1671 ArrayRef<int64_t> sourceShape = sourceType.getShape(); 1672 if (parentOp) { 1673 if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) { 1674 Value castSource = castOp.source(); 1675 auto castSourceType = castSource.getType().cast<RankedTensorType>(); 1676 if (castSourceType.hasStaticShape()) 1677 sourceShape = castSourceType.getShape(); 1678 } 1679 } 1680 1681 // If the source shape's dimension has a static shape, map the affine dim 1682 // expression to the known static size. 1683 for (unsigned i = 0; i < sourceShape.size(); i++) { 1684 if (sourceType.isDynamicDim(i)) 1685 continue; 1686 if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>()) 1687 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]); 1688 } 1689 } 1690 } 1691 1692 /// Creates new operand w.r.t 'opOperand' of `linalgOp` with static sizes 1693 /// mapped in `affineExprToSize`. New operands are created in `newOperands` and 1694 /// their result types is stored in `resultTypes`. If `opOperand` requires no 1695 /// change then `changeNeeded` is false and same operand is added in the 1696 /// `newOperands` list. 1697 static void createNewOperandWithStaticSizes( 1698 Location loc, PatternRewriter &rewriter, OpOperand *opOperand, 1699 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp, 1700 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes, 1701 bool &changeNeeded) { 1702 Value src = opOperand->get(); 1703 newOperands.push_back(src); 1704 if (linalgOp.isScalar(opOperand)) 1705 return; 1706 auto sourceType = src.getType().cast<RankedTensorType>(); 1707 Type resultType = sourceType; 1708 if (sourceType.hasStaticShape() && linalgOp.isOutputTensor(opOperand)) { 1709 resultTypes.push_back(resultType); 1710 return; 1711 } 1712 ArrayRef<int64_t> sourceShape = sourceType.getShape(); 1713 AffineMap sourceMap = linalgOp.getTiedIndexingMap(opOperand); 1714 SmallVector<int64_t> newShape; 1715 // If operand is updated with new shape, `newOperandNeeded` will be 1716 // true. 1717 bool newOperandNeeded = false; 1718 for (unsigned i = 0; i < sourceShape.size(); i++) { 1719 int64_t dimShape = sourceShape[i]; 1720 AffineExpr dimExpr = sourceMap.getResult(i); 1721 if (affineExprToSize.find(dimExpr) == affineExprToSize.end() || 1722 !sourceType.isDynamicDim(i)) { 1723 newShape.push_back(dimShape); 1724 continue; 1725 } 1726 // Dimension has a dynamic shape and corresponding affine dim 1727 // expression is present in the map. So assign the size for the 1728 // given affine dim expression to the dimension. 1729 newShape.push_back(affineExprToSize[dimExpr]); 1730 newOperandNeeded = true; 1731 } 1732 resultType = RankedTensorType::get(newShape, sourceType.getElementType()); 1733 if (newOperandNeeded) { 1734 changeNeeded = true; 1735 // Get the new operand value given its size and element type by 1736 // casting it. 1737 Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src); 1738 unsigned index = opOperand->getOperandNumber(); 1739 newOperands[index] = newOperand; 1740 } 1741 if (linalgOp.isOutputTensor(opOperand)) 1742 resultTypes.push_back(resultType); 1743 } 1744 1745 /// Static shapes for the operands can be inferred if any one of the operands 1746 /// have a static shape. This can be done by referring to the affine dim 1747 /// expressions for the operand. 1748 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> { 1749 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern; 1750 1751 LogicalResult matchAndRewrite(LinalgOp linalgOp, 1752 PatternRewriter &rewriter) const override { 1753 if (!linalgOp.hasTensorSemantics()) 1754 return failure(); 1755 1756 // Maps must be projected permutations. 1757 if (llvm::any_of(linalgOp.getIndexingMaps(), [](AffineMap map) { 1758 return !map.isProjectedPermutation(); 1759 })) 1760 return failure(); 1761 1762 // Maps affine dim expressions to the static size of that dimension. 1763 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize; 1764 Location loc = linalgOp.getLoc(); 1765 1766 // For each of the affine dim expression, check if the size is known. If 1767 // known add that in the map. 1768 populateMap(linalgOp, linalgOp.getInputAndOutputOperands(), 1769 affineExprToSize); 1770 1771 SmallVector<Value> newOperands; 1772 SmallVector<Type> resultTypes; 1773 1774 // `changeNeeded` is `false` if the operands of `linalgOp` require no 1775 // change in their types. 1776 bool changeNeeded = false; 1777 newOperands.reserve(linalgOp.getNumInputsAndOutputs()); 1778 resultTypes.reserve(linalgOp.getNumOutputs()); 1779 1780 // Iterate over all the operands and update the static sizes. 1781 for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { 1782 createNewOperandWithStaticSizes(loc, rewriter, opOperand, 1783 affineExprToSize, linalgOp, newOperands, 1784 resultTypes, changeNeeded); 1785 } 1786 1787 // If the generic op has all the required static information, no 1788 // canonicalization needed. 1789 if (!changeNeeded) 1790 return failure(); 1791 1792 // Clone op. 1793 Operation *newOp = 1794 linalgOp.clone(rewriter, linalgOp->getLoc(), resultTypes, newOperands); 1795 SmallVector<Value> replacements; 1796 replacements.reserve(newOp->getNumResults()); 1797 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) { 1798 Value newResult = std::get<1>(it); 1799 Value oldResult = std::get<0>(it); 1800 Type newType = newResult.getType(); 1801 Type oldType = oldResult.getType(); 1802 replacements.push_back( 1803 (newType != oldType) 1804 ? rewriter.create<tensor::CastOp>(loc, oldType, newResult) 1805 : newResult); 1806 } 1807 rewriter.replaceOp(linalgOp, replacements); 1808 return success(); 1809 } 1810 }; 1811 1812 } // namespace 1813 1814 // All named ops canonicalizers and folders are auto-generated in the 1815 // .cpp.inc. 1816 1817 //===----------------------------------------------------------------------===// 1818 // LinalgDialect 1819 //===----------------------------------------------------------------------===// 1820 1821 void LinalgDialect::getCanonicalizationPatterns( 1822 RewritePatternSet &results) const { 1823 results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, 1824 FoldTensorCastProducerOp, InferStaticShapeOfOperands>( 1825 getContext()); 1826 } 1827 1828 Operation *LinalgDialect::materializeConstant(OpBuilder &builder, 1829 Attribute value, Type type, 1830 Location loc) { 1831 return builder.create<arith::ConstantOp>(loc, type, value); 1832 } 1833