//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::linalg;

#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc"

/// Forward declarations.

/// Generic entry point to create the block for the region of a LinalgOp.
/// This is used by both named structured ops created by ods-gen and by
/// manually defined C++ ops.
/// This is used by both builders and parsers.
/// This function creates the block in the region with arguments corresponding
/// to the elemental types of `inputTypes` and `outputTypes`. The latter are
/// asserted to be of ShapedType.
template <typename NamedStructuredOpType>
static void fillStructuredOpRegion(
    OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
    TypeRange outputTypes,
    llvm::function_ref<void(unsigned, unsigned)> errorHandler = nullptr);

/// Generic entry point to create both the region and the block of a LinalgOp.
template <typename NamedStructuredOpType>
static void
createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result,
                                TypeRange inputTypes, TypeRange outputTypes);

/// Common parsing and printing used for both named structured ops created by
/// ods-gen and by manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes);
template <typename NamedStructuredOpType>
static void printCommonStructuredOpParts(OpAsmPrinter &p,
                                         NamedStructuredOpType op);

/// Specific parsing and printing for named structured ops created by ods-gen.
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
                             TypeRange inputTypes, TypeRange outputTypes);

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes);

template <typename NamedStructuredOpType>
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result);

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes);

template <typename NamedStructuredOpType>
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);

/// This is a common class used for patterns of the form
/// ```
/// someop(memrefcast(%src)) -> someop(%src)
/// ```
/// It folds the source of the memref.cast into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

/// This is a specialization of `foldMemRefCast` used for patterns of the form
/// ```
/// tiled_loop(memrefcast(%src)) -> tiled_loop(%src)
/// ```
/// It folds the source of the memref.cast into the root operation directly.
static LogicalResult foldMemRefCastInTiledLoopOp(TiledLoopOp op) {
  bool folded = false;
  Location loc = op->getLoc();

  Block *body = op.getBody();
  OpBuilder b = OpBuilder::atBlockBegin(body);

  // Update `input` and `output` operands and block arguments if necessary.
  // Operands list: [lbs, ubs, steps, inputs, outputs].
  // Block args list: [ivs, inputs, outputs].
  for (size_t operandIndex = op.getNumControlOperands(),
              bbArgIndex = op.getNumLoops(), e = op.getNumOperands();
       operandIndex < e; ++operandIndex, ++bbArgIndex) {
    OpOperand &operand = op->getOpOperand(operandIndex);

    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      BlockArgument newBbArg =
          body->insertArgument(bbArgIndex, castOp.getOperand().getType());
      BlockArgument oldBbArg = body->getArgument(newBbArg.getArgNumber() + 1);

      // Insert memref.cast back to the original type.
      oldBbArg.replaceAllUsesWith(
          b.create<memref::CastOp>(loc, oldBbArg.getType(), newBbArg));
      body->eraseArgument(oldBbArg.getArgNumber());

      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code
// and bind by name to math and type conversion functions in the DSL as:
//   `arithfn__{fnName}`
//   `typefn__{fnName}`
// Examples:
//   `arithfn__add`
//   `arithfn__mul`
//   `typefn__cast`
// The naming convention is intentional in order to match snake-cased DSL names.
// See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
//
// Implementations of the math functions must be polymorphic over numeric types,
// internally performing necessary casts. If the function application makes no
// sense, then the only recourse is to assert and return nullptr. This can be
// extended later if it becomes possible to fail construction of the region.
// The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the mean-time, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(MLIRContext *context, Block &block)
      : context(context), block(block) {}

  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder builder = getBuilder();
    auto loc = operand.getLoc();

    if (operand.getType() == toType)
      return operand;
    if (auto toIntType = toType.dyn_cast<IntegerType>()) {
      // If operand is floating point, cast directly to the int type.
      if (operand.getType().isa<FloatType>()) {
        if (isUnsignedCast)
          return builder.create<arith::FPToUIOp>(loc, toType, operand);
        return builder.create<arith::FPToSIOp>(loc, toType, operand);
      }
      // Cast index operands directly to the int type.
      if (operand.getType().isIndex())
        return builder.create<arith::IndexCastOp>(loc, toType, operand);
      if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
        // Either extend or truncate.
        if (toIntType.getWidth() > fromIntType.getWidth()) {
          if (isUnsignedCast)
            return builder.create<arith::ExtUIOp>(loc, toType, operand);
          return builder.create<arith::ExtSIOp>(loc, toType, operand);
        }
        if (toIntType.getWidth() < fromIntType.getWidth())
          return builder.create<arith::TruncIOp>(loc, toType, operand);
      }
    } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
      // If operand is integer, cast directly to the float type.
      // Note that it is unclear how to cast from BF16<->FP16.
      if (operand.getType().isa<IntegerType>()) {
        if (isUnsignedCast)
          return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
        return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
      }
      if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
        if (toFloatType.getWidth() > fromFloatType.getWidth())
          return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
        if (toFloatType.getWidth() < fromFloatType.getWidth())
          return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
      }
    }

    emitWarning(operand.getLoc()) << "could not cast operand of type "
                                  << operand.getType() << " to " << toType;
    return operand;
  }
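
  // A few illustrative examples of what `cast` emits (not generated code):
  // with isUnsignedCast = false, casting an i8 operand to i32 builds
  // arith.extsi, casting an i64 operand to i32 builds arith.trunci, and
  // casting an i32 operand to f32 builds arith.sitofp; mixed-width float
  // casts build arith.extf or arith.truncf.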

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value typefn__cast(Type toType, Value operand) {
    return cast(toType, operand, false);
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value typefn__cast_unsigned(Type toType, Value operand) {
    return cast(toType, operand, true);
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__add(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::AddFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::AddIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__exp(Value x) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(x))
      return builder.create<math::ExpOp>(x.getLoc(), x);
    llvm_unreachable("unsupported non numeric type");
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__log(Value x) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(x))
      return builder.create<math::LogOp>(x.getLoc(), x);
    llvm_unreachable("unsupported non numeric type");
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__sub(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::SubFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::SubIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__mul(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::MulFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::MulIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__max(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::MaxSIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__max_unsigned(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::MaxUIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__min(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::MinSIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // NOLINTNEXTLINE(*-identifier-naming): externally called.
  Value arithfn__min_unsigned(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::MinUIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  void yieldOutputs(ValueRange values) {
    assert(!values.empty() && "linalg ops must yield outputs");
    if (values.empty())
      return;
    Value first = values.front();
    OpBuilder builder = getBuilder();
    builder.create<YieldOp>(first.getLoc(), values);
  }

  Value constant(const std::string &value) {
    OpBuilder builder = getBuilder();
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
                                             valueAttr);
  }

  Value index(int64_t dim) {
    OpBuilder builder = getBuilder();
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(context, width);
  }

  Type getFloat32Type() { return Float32Type::get(context); }

  Type getFloat64Type() { return Float64Type::get(context); }

private:
  MLIRContext *context;
  Block &block;

  bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
  bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }

  OpBuilder getBuilder() {
    OpBuilder builder(context);
    builder.setInsertionPointToEnd(&block);
    return builder;
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
void CopyOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
  assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args");
  b.create<linalg::YieldOp>(block.getArgument(0));
}

void CopyOp::build(OpBuilder &builder, OperationState &result, Value input,
                   Value output, AffineMap inputPermutation,
                   AffineMap outputPermutation,
                   ArrayRef<NamedAttribute> namedAttrs) {
  result.addOperands({input, output});
  result.addAttributes(namedAttrs);
  if (inputPermutation)
    result.addAttribute("inputPermutation",
                        AffineMapAttr::get(inputPermutation));
  if (outputPermutation)
    result.addAttribute("outputPermutation",
                        AffineMapAttr::get(outputPermutation));
  result.addRegion();
  fillStructuredOpRegion<CopyOp>(builder, *result.regions.front(),
                                 TypeRange{input.getType()},
                                 TypeRange{output.getType()});
}

ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType,
                              Type outputType) {
  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion<CopyOp>(opBuilder, r, TypeRange{inputType},
                                 TypeRange{outputType});
  return success();
}

/// CopyOp region is elided when printing.
void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}

static LogicalResult verify(CopyOp op) {
  OpOperand *output = op.getOutputOperand(0);
  OpOperand *input = op.getInputOperand(0);
  if (getElementTypeOrSelf(input->get()) != getElementTypeOrSelf(output->get()))
    return op.emitOpError("expects views of the same type");
  if (op.getRank(input) != op.getRank(output))
    return op.emitOpError("expects views of the same rank");
  auto rank = op.getNumParallelLoops();
  auto inputPermutationMap = op.inputPermutation();
  if (inputPermutationMap) {
    if (inputPermutationMap->getNumInputs() != rank)
      return op.emitOpError("expects optional input_permutation map of rank ")
             << rank;
    if (!inputPermutationMap->isPermutation())
      return op.emitOpError(
          "expects optional input_permutation map to be a permutation");
  }
  auto outputPermutationMap = op.outputPermutation();
  if (outputPermutationMap) {
    if (outputPermutationMap->getNumInputs() != rank)
      return op.emitOpError("expects optional output_permutation map of rank ")
             << rank;
    if (!outputPermutationMap->isPermutation())
      return op.emitOpError(
          "expects optional output_permutation map to be a permutation");
  }
  if (rank == 0 && inputPermutationMap)
    return op.emitOpError("expected no input permutation when rank == 0");
  if (rank == 0 && outputPermutationMap)
    return op.emitOpError("expected no output permutation when rank == 0");
  return success();
}

void CopyOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
}

namespace {
/// Remove copy operations that copy data inplace. Requirements are:
/// 1) The input and output values are identical.
/// 2) The input and output permutation maps are identical.
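///
/// For example (illustrative IR only; exact assembly depends on the op's
/// attributes), such a self-copy is erased:
///
///   linalg.copy(%buf, %buf) : memref<8x16xf32>, memref<8x16xf32>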
struct EraseIdentityCopyOp : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    assert(copyOp.hasBufferSemantics());
    if (copyOp.input() == copyOp.output() &&
        copyOp.inputPermutation() == copyOp.outputPermutation()) {
      rewriter.eraseOp(copyOp);
      return success();
    }
    return failure();
  }
};
} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseIdentityCopyOp>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//
void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
  assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args");
  b.create<linalg::YieldOp>(block.getArgument(0));
}

void FillOp::build(OpBuilder &builder, OperationState &result, Value value,
                   Value output) {
  build(builder, result, output.getType().dyn_cast<RankedTensorType>(), value,
        output);
  fillStructuredOpRegion<FillOp>(builder, *result.regions.front(),
                                 TypeRange{value.getType()},
                                 TypeRange{output.getType()}, {});
}

ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type valueType,
                              Type outputType) {
  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{valueType},
                                 TypeRange{outputType});
  return success();
}

/// FillOp region is elided when printing.
void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}

static LogicalResult verify(FillOp op) {
  OpOperand *output = op.getOutputOperand(0);
  Type fillType = op.value().getType();
  if (getElementTypeOrSelf(output->get()) != fillType)
    return op.emitOpError("expects fill type to match view elemental type");
  return success();
}

void FillOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (output().getType().isa<MemRefType>())
    effects.emplace_back(MemoryEffects::Write::get(), output(),
                         SideEffects::DefaultResource::get());
}

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
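///
/// A sketch of the rewrite (illustrative IR; exact syntax and types depend on
/// the input):
///
///   %fill = linalg.fill(%cst, %init) : f32, tensor<4x16xf32> -> tensor<4x16xf32>
///   %0 = tensor.collapse_shape %fill [[0, 1]]
///       : tensor<4x16xf32> into tensor<64xf32>
///
/// becomes
///
///   %0 = tensor.collapse_shape %init [[0, 1]]
///       : tensor<4x16xf32> into tensor<64xf32>
///   %fill = linalg.fill(%cst, %0) : f32, tensor<64xf32> -> tensor<64xf32>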
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.src().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    auto newInit = rewriter.create<TensorReshapeOp>(
        loc, reshapeOp.getResultType(), oldFill.output(),
        reshapeOp.reassociation());
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, oldFill.value(), newInit);

    return success();
  }
};

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>>(context);
}

//===----------------------------------------------------------------------===//
// GenericOps
//===----------------------------------------------------------------------===//
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getStrArrayAttr(iteratorTypes),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr()
                            : builder.getStringAttr(libraryCall));
  result.addAttributes(attributes);
  if (!bodyBuild)
    return;

  SmallVector<Type, 4> blockArgTypes;
  for (ValueRange container : {inputs, outputs})
    for (Value v : container)
      blockArgTypes.push_back(getElementTypeOrSelf(v));

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock = builder.createBlock(&region, region.end(), blockArgTypes);
  bodyBuild(builder, result.location, bodyBlock->getArguments());
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

static void print(OpAsmPrinter &p, GenericOp op) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = op.linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : op->getAttrs())
    if (genericAttrNamesSet.count(attr.getName().strref()) > 0)
      genericAttrs.push_back(attr);
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, op);

  genericAttrNames.push_back("operand_segment_sizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : op->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!op.region().empty()) {
    p << ' ';
    p.printRegion(op.region());
  }

  // Print results.
  printNamedStructuredOpResults(p, op.result_tensors().getTypes());
}

static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  SmallVector<OpAsmParser::OperandType, 8> regionOperands;
  std::unique_ptr<Region> region = std::make_unique<Region>();
  SmallVector<Type, 8> operandTypes, regionTypes;
  if (parser.parseRegion(*region, regionOperands, regionTypes))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of their outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
  for (Value value : results) {
    effects.emplace_back(MemoryEffects::Allocate::get(), value,
                         SideEffects::DefaultResource::get());
  }
  for (Value value : inputBuffers) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         SideEffects::DefaultResource::get());
  }
  for (Value value : outputs) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), value,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  SmallVector<Value> inputBuffers = getInputBufferOperands();
  SmallVector<Value> outputBuffers = getOutputBufferOperands();
  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
                        outputBuffers);
}

template <typename GenericOpType>
static LogicalResult verifyGenericOp(GenericOpType op) {
  return success();
}

static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }

namespace {
// Deduplicate redundant args of a linalg generic op.
// An arg is redundant if it has the same Value and indexing map as another.
struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Associate each input to an equivalent "canonical" input that has the
    // same Value and indexing map.
    //
    // In the non-duplicate case, input `i` will have canonical input `i`. But
    // in the case of duplicated inputs, the canonical input could be some
    // other input `< i`. That is, a later input will have some earlier input
    // as its canonical input.
    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
    // For later remapping tasks like deduplicating payload block arguments,
    // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
    // convenient.
    SmallVector<unsigned> canonicalInputIndices;
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
      // STL-like maps have a convenient behavior for our use case here. In the
      // case of duplicate keys, the insertion is rejected, and the returned
      // iterator gives access to the value already in the map.
      auto pair = canonicalInput.insert(
          {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
      canonicalInputIndices.push_back(pair.first->second);
    }

    // If there are no duplicate args, then bail out.
    if (canonicalInput.size() == genericOp.getNumInputs())
      return failure();

    // The operands for the newly canonicalized op.
    SmallVector<Value> newInputOperands;
    for (OpOperand *opOperand : genericOp.getInputOperands())
      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
          opOperand->getOperandNumber())
        newInputOperands.push_back(opOperand->get());

    // Repair the indexing maps by filtering out the ones that have been
    // eliminated.
    SmallVector<AffineMap> newIndexingMaps;
    for (OpOperand *opOperand : genericOp.getInputOperands())
      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
          opOperand->getOperandNumber())
        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
    for (OpOperand *opOperand : genericOp.getOutputOperands())
      newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));

    // Clone the old op with new operands.
    SmallVector<Value> outputOperands = genericOp.getOutputOperands();
    auto newOp = rewriter.create<GenericOp>(
        genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
        outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
        genericOp.iterator_types(), genericOp.docAttr(),
        genericOp.library_callAttr());

    // Copy over unknown attributes. They might be load bearing for some flow.
    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
    for (NamedAttribute kv : genericOp->getAttrs()) {
      if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) {
        newOp->setAttr(kv.getName(), kv.getValue());
      }
    }

    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
                                newOp.region().begin());

    // Repair the payload entry block by RAUW'ing redundant arguments and
    // erasing them.
    Block &payload = newOp.region().front();
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
      // Iterate in reverse, so that we erase later args first, preventing the
      // argument list from shifting unexpectedly and invalidating all our
      // indices.
      unsigned operandNumber = opOperand->getOperandNumber();
      if (canonicalInputIndices[operandNumber] == operandNumber)
        continue;
      payload.getArgument(operandNumber)
          .replaceAllUsesWith(
              payload.getArgument(canonicalInputIndices[operandNumber]));
      payload.eraseArgument(operandNumber);
    }

    rewriter.replaceOp(genericOp, newOp->getResults());
    return success();
  }
};

/// Remove generic operations (on tensors) that are just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
/// the arguments corresponding to the operands.
struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    // Check all indexing maps are identity.
    if (llvm::any_of(genericOp.getIndexingMaps(),
                     [](AffineMap map) { return !map.isIdentity(); }))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = genericOp.region().front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (Value yieldVal : yieldOp.values()) {
      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      returnedArgs.push_back(genericOp->getOperand(argumentNumber));
    }
    if (returnedArgs.size() != genericOp->getNumResults())
      return failure();
    rewriter.replaceOp(genericOp, returnedArgs);
    return success();
  }
};
} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
}

//===----------------------------------------------------------------------===//
// InitTensorOp
//===----------------------------------------------------------------------===//

void InitTensorOp::build(OpBuilder &b, OperationState &result,
                         ArrayRef<OpFoldResult> sizes, Type elementType,
                         ArrayRef<NamedAttribute> attrs) {
  SmallVector<Value, 4> dynamicSizes;
  SmallVector<int64_t, 4> staticSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  auto resultType = RankedTensorType::get(staticSizes, elementType);
  build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes));
  result.addAttributes(attrs);
}

static LogicalResult verify(InitTensorOp op) {
  RankedTensorType resultType = op.getType();
  SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
      op.static_sizes().cast<ArrayAttr>(),
      [](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); }));

  if (failed(verifyListOfOperandsOrIntegers(op, "sizes", resultType.getRank(),
                                            op.static_sizes(), op.sizes(),
                                            ShapedType::isDynamic)))
    return failure();

  if (op.static_sizes().size() != static_cast<unsigned>(resultType.getRank()))
    return op->emitError("expected ")
           << resultType.getRank() << " sizes values";

  Type expectedType = InitTensorOp::inferResultType(
      staticSizes, resultType.getElementType(), resultType.getEncoding());
  if (resultType != expectedType) {
    return op.emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  return success();
}

Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
                                   Type elementType, Attribute encoding) {
  return RankedTensorType::get(staticSizes, elementType, encoding);
}

namespace {
/// Change the type of the result of a `linalg.init_tensor` by making the
/// result type statically sized along dimensions that in the original
/// operation were defined as dynamic, but the size was defined using a
/// `constant` op.
/// For example
///
/// %c5 = arith.constant 5: index
/// %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
///
/// to
///
/// %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
  using OpRewritePattern<InitTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InitTensorOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 4> dynamicSizes;
    SmallVector<int64_t, 4> staticSizes;
    for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
      // If the size is already static, nothing to do.
      if (!op.isDynamicSize(i)) {
        staticSizes.push_back(op.getStaticSize(i));
        continue;
      }

      // If the size is dynamic but defined using a `constant` op, get the
      // constant value to find the static size to use.
      unsigned operandNum = op.getIndexOfDynamicSize(i);
      Value sizeOperand = op.getOperand(operandNum);
      if (auto constantIndexOp =
              sizeOperand.getDefiningOp<arith::ConstantIndexOp>()) {
        staticSizes.push_back(constantIndexOp.value());
        continue;
      }

      // Fallback case. Keep the size dynamic.
      dynamicSizes.push_back(sizeOperand);
      staticSizes.push_back(ShapedType::kDynamicSize);
    }
    RankedTensorType newType =
        RankedTensorType::get(staticSizes, op.getType().getElementType());
    if (newType == op.getType())
      return failure();
    auto newOp =
        rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
                                      rewriter.getI64ArrayAttr(staticSizes));
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
} // namespace

namespace {
/// Since `init_tensor` operation creates a tensor needed only for its shape, a
/// slice of this is also needed only for its shape. The result can be
/// replaced by a new init_tensor operation of the same size as the extract
/// slice op.
struct FoldInitTensorWithExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    if (!sliceOp.source().getDefiningOp<linalg::InitTensorOp>())
      return failure();
    // ExtractSliceOp may be rank-reducing; its dynamic sizes must be preserved
    // as well as its result type.
    rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
        sliceOp, sliceOp.sizes(),
        sliceOp.result().getType().cast<RankedTensorType>().getShape(),
        sliceOp.getSourceType().getElementType());
    return success();
  }
};

template <typename TensorReshapeOp>
struct FoldInitTensorWithTensorReshapeOp
    : public OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    if (!reshapeOp.src().template getDefiningOp<InitTensorOp>())
      return failure();
    Location loc = reshapeOp.getLoc();
    ReifiedRankedShapedTypeDims resultShapes;
    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
        cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
    if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
                                                          resultShapes)) ||
        !llvm::hasSingleElement(resultShapes))
      return failure();
    Value initTensor = rewriter.create<InitTensorOp>(
        loc, getAsOpFoldResult(resultShapes[0]),
        reshapeOp.getResultType().getElementType());
    if (initTensor.getType() != reshapeOp.getResultType()) {
      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          reshapeOp, reshapeOp.getResultType(), initTensor);
    } else {
      rewriter.replaceOp(reshapeOp, initTensor);
    }
    return success();
  }
};

struct FoldInitTensorWithDimOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto initTensorOp = dimOp.source().getDefiningOp<linalg::InitTensorOp>();
    if (!initTensorOp || !maybeConstantIndex)
      return failure();
    if (!initTensorOp.isDynamicSize(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
};
} // namespace

void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<FoldInitTensorWithDimOp, FoldInitTensorWithExtractSliceOp,
              FoldInitTensorWithTensorReshapeOp<tensor::ExpandShapeOp>,
              FoldInitTensorWithTensorReshapeOp<tensor::CollapseShapeOp>,
              ReplaceStaticShapeDims>(context);
}

LogicalResult InitTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicSize(dim))
          return getDynamicSize(dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

//===----------------------------------------------------------------------===//
// PadTensorOp
//===----------------------------------------------------------------------===//

// TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
// supports optional types.
void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
                    Type typeToInfer, Type typeToInferFrom) {}

ParseResult parseInferType(OpAsmParser &parser,
                           Optional<OpAsmParser::OperandType> optOperand,
                           Type &typeToInfer, Type typeToInferFrom) {
  if (optOperand)
    typeToInfer = typeToInferFrom;
  return success();
}

static LogicalResult verify(PadTensorOp op) {
  auto sourceType = op.source().getType().cast<RankedTensorType>();
  auto resultType = op.result().getType().cast<RankedTensorType>();
  auto expectedType = PadTensorOp::inferResultType(
      sourceType, extractFromI64ArrayAttr(op.static_low()),
      extractFromI64ArrayAttr(op.static_high()));
  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
    if (resultType.getDimSize(i) == expectedType.getDimSize(i))
      continue;
    if (expectedType.isDynamicDim(i))
      continue;
    return op.emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }

  auto &region = op.region();
  unsigned rank = resultType.getRank();
  Block &block = region.front();
  if (block.getNumArguments() != rank)
    return op.emitError("expected the block to have ") << rank << " arguments";

  // Note: the number and type of yield values are checked in the YieldOp.
  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
    if (!en.value().isIndex())
      return op.emitOpError("expected block argument ")
             << (en.index() + 1) << " to be an index";
  }

  return success();
}

RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
                                              ArrayRef<int64_t> staticLow,
                                              ArrayRef<int64_t> staticHigh,
                                              ArrayRef<int64_t> resultShape) {
  unsigned rank = sourceType.getRank();
  assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
  assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
  assert((resultShape.empty() || resultShape.size() == rank) &&
         "unexpected resultShape size mismatch");

  SmallVector<int64_t, 4> inferredShape;
  for (auto i : llvm::seq<unsigned>(0, rank)) {
    if (sourceType.isDynamicDim(i) ||
        staticLow[i] == ShapedType::kDynamicSize ||
        staticHigh[i] == ShapedType::kDynamicSize) {
      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamicSize
                                                  : resultShape[i]);
    } else {
      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
      assert((resultShape.empty() || size == resultShape[i] ||
              resultShape[i] == ShapedType::kDynamicSize) &&
             "mismatch between inferred shape and result shape");
      inferredShape.push_back(size);
    }
  }

  return RankedTensorType::get(inferredShape, sourceType.getElementType());
}

void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
                        ArrayRef<int64_t> staticLow,
                        ArrayRef<int64_t> staticHigh, ValueRange low,
                        ValueRange high, bool nofold,
                        ArrayRef<NamedAttribute> attrs) {
  auto sourceType = source.getType().cast<RankedTensorType>();
  auto resultType = inferResultType(sourceType, staticLow, staticHigh);
  build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
        b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr());
  result.addAttributes(attrs);
}

void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
                        ValueRange low, ValueRange high, bool nofold,
                        ArrayRef<NamedAttribute> attrs) {
  auto sourceType = source.getType().cast<RankedTensorType>();
  unsigned rank = sourceType.getRank();
  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
  build(b, result, source, staticVector, staticVector, low, high, nofold,
        attrs);
}

void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
                        Value source, ArrayRef<OpFoldResult> low,
                        ArrayRef<OpFoldResult> high, bool nofold,
                        ArrayRef<NamedAttribute> attrs) {
  assert(resultType.isa<RankedTensorType>());
  auto sourceType = source.getType().cast<RankedTensorType>();
  SmallVector<Value, 4> dynamicLow, dynamicHigh;
  SmallVector<int64_t, 4> staticLow, staticHigh;
  // staticLow and staticHigh have full information of the padding config.
  // This will grow staticLow and staticHigh with 1 value. If the config is
  // dynamic (i.e., not a constant), dynamicLow and dynamicHigh will grow with
  // 1 value as well.
  dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
                             ShapedType::kDynamicSize);
  if (!resultType) {
    resultType =
        PadTensorOp::inferResultType(sourceType, staticLow, staticHigh);
  }
  build(b, result, resultType, source, dynamicLow, dynamicHigh,
        b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
        nofold ? b.getUnitAttr() : UnitAttr());
  result.addAttributes(attrs);
}

PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad,
                                           ArrayRef<OpFoldResult> low,
                                           ArrayRef<OpFoldResult> high,
                                           bool nofold, Location loc,
                                           OpBuilder &builder) {
  auto padTensorOp =
      builder.create<linalg::PadTensorOp>(loc, type, source, low, high, nofold);
  int rank = padTensorOp.getResultType().getRank();
  SmallVector<Type, 4> blockArgTypes;
  blockArgTypes.assign(rank, builder.getIndexType());
  auto &region = padTensorOp.region();
  // `builder.createBlock` changes the insertion point within the block. Create
  // a guard to reset the insertion point of the builder after it is destroyed.
  OpBuilder::InsertionGuard guard(builder);
  builder.createBlock(&region, region.end(), blockArgTypes);
  builder.create<linalg::YieldOp>(loc, pad);
  return padTensorOp;
}

PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
                                         bool nofold, Location loc,
                                         OpBuilder &b) {
  SmallVector<OpFoldResult, 4> low, high;
  auto rankedTensorType = type.cast<RankedTensorType>();
  assert(rankedTensorType.hasStaticShape());
  for (const auto &en : enumerate(rankedTensorType.getShape())) {
    AffineExpr d0;
    bindDims(b.getContext(), d0);
    auto dimOp = b.createOrFold<tensor::DimOp>(loc, source, en.index());
    Value paddingWidth =
        makeComposedAffineApply(b, loc, en.value() - d0, {dimOp});
    high.push_back(paddingWidth);
    low.push_back(b.createOrFold<arith::ConstantIndexOp>(loc, 0));
  }
  return PadTensorOp::createPadScalarOp(type, source, pad, low, high, nofold,
                                        loc, b);
}

LogicalResult PadTensorOp::reifyResultShapes(
    OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  Location loc = getLoc();
  auto lowPad = getMixedLowPad();
  auto highPad = getMixedHighPad();
  SmallVector<Value> shapes;
  for (auto dim : llvm::seq<int64_t>(0, getSourceType().getRank())) {
    // Shape along each dimension is source dim + low pad + high pad.
    SmallVector<Value> mapOperands;
    mapOperands.push_back(b.createOrFold<tensor::DimOp>(loc, source(), dim));
    AffineExpr expr = b.getAffineDimExpr(0);
    unsigned numSymbols = 0;
    auto addOpFoldResult = [&](OpFoldResult valueOrAttr) {
      if (Value v = valueOrAttr.dyn_cast<Value>()) {
        expr = expr + b.getAffineSymbolExpr(numSymbols++);
        mapOperands.push_back(v);
        return;
      }
      int64_t staticValue =
          valueOrAttr.get<Attribute>().cast<IntegerAttr>().getInt();
      expr = expr + staticValue;
    };
    addOpFoldResult(lowPad[dim]);
    addOpFoldResult(highPad[dim]);
    shapes.push_back(applyMapToValues(
        b, loc, AffineMap::get(1, numSymbols, expr), mapOperands)[0]);
  }
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

//===----------------------------------------------------------------------===//
// Methods related to PadTensor tiling.
//===----------------------------------------------------------------------===//

SmallVector<Value> PadTensorOp::getDestinationOperands(OpBuilder &b) {
  ReifiedRankedShapedTypeDims reifiedShapes;
  (void)reifyResultShapes(b, reifiedShapes);
  SmallVector<OpFoldResult> mixedSizes = getAsOpFoldResult(reifiedShapes[0]);
  Value initTensor = b.create<InitTensorOp>(getLoc(), mixedSizes,
                                            getResultType().getElementType());
  return {initTensor};
}

SmallVector<StringRef> PadTensorOp::getLoopIteratorTypes() {
  SmallVector<StringRef> iteratorTypes(getResultType().getRank(),
                                       getParallelIteratorTypeName());
  return iteratorTypes;
}

SmallVector<Range> PadTensorOp::getIterationDomain(OpBuilder &b) {
  ReifiedRankedShapedTypeDims reifiedShapes;
  (void)reifyResultShapes(b, reifiedShapes);
  Value zero = b.create<arith::ConstantIndexOp>(getLoc(), 0);
  Value one = b.create<arith::ConstantIndexOp>(getLoc(), 1);
  // Initialize all the ranges to {zero, one, one}. All the `ub`s are
  // overwritten.
  SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
  for (const auto &ub : enumerate(reifiedShapes[0]))
    loopRanges[ub.index()].size = ub.value();
  return loopRanges;
}

SmallVector<Operation *> PadTensorOp::getTiledImplementation(
    OpBuilder &b, ValueRange dest, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, bool /*tileDestOperands*/) {
  // Only constant padding value supported.
  Value padValue = getConstantPaddingValue();
  if (!padValue)
    return {};

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = getLoc();
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](Value v1, Value v2) {
    return b.createOrFold<AffineApplyOp>(loc, addMap, ValueRange{v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](Value v1, Value v2) {
    return b.createOrFold<AffineApplyOp>(loc, subMap, ValueRange{v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
  auto min = [&](Value v1, Value v2) {
    return b.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](Value v1, Value v2) {
    return b.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
  };
  // Zero index-typed integer.
  auto zero = b.create<arith::ConstantIndexOp>(loc, 0);

  // Helper function for filling static/dynamic low/high padding indices
  // vectors of PadTensorOp.
  auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
                         SmallVector<int64_t> &staticIndices) {
    if (auto constInt = getConstantIntValue(val)) {
      staticIndices.push_back(*constInt);
    } else {
      staticIndices.push_back(ShapedType::kDynamicSize);
      dynIndices.push_back(val);
    }
  };

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<Value> newLows, newHighs;
  SmallVector<int64_t> staticNewLows, staticNewHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
  // is true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = getValueOrCreateConstantIndexOp(b, loc, getMixedLowPad()[dim]);
    bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0);
    auto high = getValueOrCreateConstantIndexOp(b, loc, getMixedHighPad()[dim]);
    bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0);
    auto offset = getValueOrCreateConstantIndexOp(b, loc, offsets[dim]);
    auto length = getValueOrCreateConstantIndexOp(b, loc, sizes[dim]);
    auto srcSize = b.createOrFold<tensor::DimOp>(loc, source(), dim);

    // The new amount of low padding is `low - offset`. Except for the case
    // where none of the low padding is read. In that case, the new amount of
    // low padding is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    appendIndex(newLow, newLows, staticNewLows);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    // max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of the source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize)
                                : min(offset, srcSize);
    newOffsets.push_back(getAsOpFoldResult(newOffset));

    // The original ExtractSliceOp was reading until position `offset +
    // length`. Therefore, the corresponding position within the source tensor
    // is:
    //
    // offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end position
    // of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, the end position of the read should be the end of the
    // source tensor. (Similar to newOffset.)
    //
    // endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    //
    // Optimization: If low = 0, then the formula can be simplified.
    Value endLoc = hasLowPad
                       ? min(max(add(sub(offset, low), length), zero), srcSize)
                       : min(add(offset, length), srcSize);
    Value newLength = sub(endLoc, newOffset);
    newLengths.push_back(getAsOpFoldResult(newLength));

    // Check if newLength is zero. In that case, no SubTensorOp should be
    // executed.
    if (auto newLengthInt = getConstantIntValue(newLength)) {
      hasZeroLen |= *newLengthInt == 0;
    } else {
      Value check = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                            newLength, zero);
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    appendIndex(newHigh, newHighs, staticNewHighs);

    // Only unit stride supported.
    newStrides.push_back(b.getIndexAttr(1));
  }

  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape, ShapedType::kDynamicSize);
  RankedTensorType resultType =
      RankedTensorType::get(shape, getResultType().getElementType());

  // Insert cast to ensure that types match. (May be folded away.)
1444 auto castResult = [&](Value val) -> Operation * {
1445 auto castOp = b.create<tensor::CastOp>(loc, resultType, val);
1446 return castOp;
1447 };
1448
1449 // In cases where the original data source is unused: Emit a GenerateOp and
1450 // do not generate a SliceOp. (The result shape of the SliceOp would
1451 // have a dimension of size 0, the semantics of which is unclear.)
1452 auto createGenerateOp = [&]() {
1453 // Create GenerateOp.
1454 auto generateOp = b.create<tensor::GenerateOp>(
1455 loc, resultType, dynDims,
1456 [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
1457 builder.create<tensor::YieldOp>(gLoc, padValue);
1458 });
1459 return castResult(generateOp);
1460 };
1461
1462 // Emit a SliceOp and a PadTensorOp. Should not be used in cases where
1463 // the result shape of the new SliceOp has a zero dimension.
1464 auto createPadTensorOfSubTensor = [&]() {
1465 // Create pad_tensor(subtensor(x)).
1466 auto newSliceOp = b.create<tensor::ExtractSliceOp>(
1467 loc, source(), newOffsets, newLengths, newStrides);
1468 auto newPadTensorOp = b.create<PadTensorOp>(
1469 loc, newSliceOp, staticNewLows, staticNewHighs, newLows, newHighs);
1470
1471 // Copy region to new PadTensorOp.
1472 BlockAndValueMapping bvm;
1473 region().cloneInto(&newPadTensorOp.getRegion(), bvm);
1474
1475 // Cast result and return.
1476 return castResult(newPadTensorOp);
1477 };
1478
1479 // Rewrite subtensor(pad_tensor(x)) into a GenerateOp if it is statically known
1480 // that the original data source x is not used.
1481 if (hasZeroLen) {
1482 return {createGenerateOp()};
1483 }
1484
1485 // If there are dynamic dimensions: Generate an scf.if check to avoid creating
1486 // SliceOps with result dimensions of size 0 at runtime.
1487 if (dynHasZeroLenCond) {
1488 auto result = b.create<scf::IfOp>(
1489 loc, resultType, dynHasZeroLenCond,
1490 /*thenBuilder=*/
1491 [&](OpBuilder &b, Location loc) {
1492 b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0));
1493 },
1494 /*elseBuilder=*/
1495 [&](OpBuilder &b, Location loc) {
1496 b.create<scf::YieldOp>(loc,
1497 createPadTensorOfSubTensor()->getResult(0));
1498 });
1499 return {result};
1500 }
1501 return {createPadTensorOfSubTensor()};
1502 }
1503
1504 namespace {
1505 // Folds linalg.pad_tensor when padding is static zeros and the attribute
1506 // doesn't request otherwise.
1507 struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> {
1508 using OpRewritePattern<PadTensorOp>::OpRewritePattern;
1509
1510 LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
1511 PatternRewriter &rewriter) const override {
1512 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
1513 return failure();
1514 if (padTensorOp.nofold())
1515 return failure();
1516 rewriter.replaceOpWithNewOp<tensor::CastOp>(
1517 padTensorOp, padTensorOp.result().getType(), padTensorOp.source());
1518 return success();
1519 }
1520 };
1521
1522 // Fold CastOp into PadTensorOp when adding static information.
1523 struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
1524 using OpRewritePattern<PadTensorOp>::OpRewritePattern;
1525
1526 LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
1527 PatternRewriter &rewriter) const override {
1528 auto castOp = padTensorOp.source().getDefiningOp<tensor::CastOp>();
1529 if (!tensor::canFoldIntoConsumerOp(castOp))
1530 return failure();
1531
1532 auto newResultType = PadTensorOp::inferResultType(
1533 castOp.source().getType().cast<RankedTensorType>(),
1534 extractFromI64ArrayAttr(padTensorOp.static_low()),
1535 extractFromI64ArrayAttr(padTensorOp.static_high()),
1536 padTensorOp.getResultType().getShape());
1537
1538 if (newResultType == padTensorOp.getResultType()) {
1539 rewriter.updateRootInPlace(padTensorOp, [&]() {
1540 padTensorOp.sourceMutable().assign(castOp.source());
1541 });
1542 } else {
1543 auto newOp = rewriter.create<PadTensorOp>(
1544 padTensorOp->getLoc(), newResultType, padTensorOp.source(),
1545 padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(),
1546 padTensorOp.static_high(), padTensorOp.nofold());
1547 BlockAndValueMapping mapper;
1548 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
1549
1550 rewriter.replaceOpWithNewOp<tensor::CastOp>(
1551 padTensorOp, padTensorOp.getResultType(), newOp);
1552 }
1553 return success();
1554 }
1555 };
1556
1557 // Fold a CastOp that consumes the result of a PadTensorOp back into the latter
1558 // if it adds static information.
1559 struct FoldTargetTensorCast : public OpRewritePattern<PadTensorOp> {
1560 using OpRewritePattern<PadTensorOp>::OpRewritePattern;
1561
1562 LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
1563 PatternRewriter &rewriter) const override {
1564 if (!padTensorOp.result().hasOneUse())
1565 return failure();
1566 auto tensorCastOp =
1567 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
1568 if (!tensorCastOp)
1569 return failure();
1570 if (!tensor::preservesStaticInformation(padTensorOp.result().getType(),
1571 tensorCastOp.dest().getType()))
1572 return failure();
1573
1574 auto replacementOp = rewriter.create<PadTensorOp>(
1575 padTensorOp.getLoc(), tensorCastOp.dest().getType(),
1576 padTensorOp.source(), padTensorOp.low(), padTensorOp.high(),
1577 padTensorOp.static_low(), padTensorOp.static_high(),
1578 padTensorOp.nofold());
1579 replacementOp.region().takeBody(padTensorOp.region());
1580
1581 rewriter.replaceOp(padTensorOp, replacementOp.result());
1582 rewriter.replaceOp(tensorCastOp, replacementOp.result());
1583 return success();
1584 }
1585 };
1586 } // namespace
1587
1588 void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
1589 MLIRContext *context) {
1590 results.add<FoldStaticZeroPadding, FoldSourceTensorCast>(context);
1591 results.add<FoldTargetTensorCast>(context);
1592 }
1593
1594 /// Return the padding value of the PadTensorOp if it is constant. In this context,
1595 /// "constant" means an actual constant or "defined outside of the block".
1596 ///
1597 /// Values are considered constant in three cases:
1598 /// - A ConstantLike value.
1599 /// - A basic block argument from a different block.
1600 /// - A value defined outside of the block.
1601 ///
1602 /// If the padding value is not constant, an empty Value is returned.
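/// A rough sketch of the common case (illustrative IR, names for exposition
/// only):
///
///   %cst = arith.constant 0.0 : f32
///   %0 = linalg.pad_tensor %arg0 low[1, 1] high[2, 2] {
///   ^bb0(%i: index, %j: index):
///     linalg.yield %cst : f32
///   } : tensor<?x?xf32> to tensor<?x?xf32>
///
/// Here the yielded %cst matches m_Constant() and is returned as the padding
/// value.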
1603 Value PadTensorOp::getConstantPaddingValue() {
1604 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
1605 if (!yieldOp || yieldOp.values().size() != 1)
1606 return {};
1607 Value padValue = yieldOp.values().front();
1608 // Check if yield value is a constant.
1609 if (matchPattern(padValue, m_Constant()))
1610 return padValue;
1611 // Check if yield value is defined inside the PadTensorOp block.
1612 if (padValue.getParentBlock() == &getRegion().front())
1613 return {};
1614 // Else: Yield value defined outside of the PadTensorOp block.
1615 return padValue;
1616 }
1617
1618 OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) {
1619 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
1620 !nofold())
1621 return source();
1622 return {};
1623 }
1624
1625 //===----------------------------------------------------------------------===//
1626 // YieldOp
1627 //===----------------------------------------------------------------------===//
1628
1629 static void print(OpAsmPrinter &p, linalg::YieldOp op) {
1630 if (op.getNumOperands() > 0)
1631 p << ' ' << op.getOperands();
1632 p.printOptionalAttrDict(op->getAttrs());
1633 if (op.getNumOperands() > 0)
1634 p << " : " << op.getOperandTypes();
1635 }
1636
1637 static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
1638 SmallVector<OpAsmParser::OperandType, 2> opInfo;
1639 SmallVector<Type, 2> types;
1640 llvm::SMLoc loc = parser.getCurrentLocation();
1641 return failure(parser.parseOperandList(opInfo) ||
1642 parser.parseOptionalAttrDict(result.attributes) ||
1643 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
1644 parser.resolveOperands(opInfo, types, loc, result.operands));
1645 }
1646
1647 // Check that the operand number and types match the element types of the
1648 // LinalgOp interface's shaped operands.
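// For instance (illustrative): a linalg.generic with
//   outs(%out : tensor<8xf32>)
// must be terminated by `linalg.yield %v : f32`; yielding a value of another
// type, or a different number of values, is rejected below.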
1649 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
1650 if (op.getNumOperands() != linalgOp.getNumOutputs())
1651 return op.emitOpError("expected number of yield values (")
1652 << op.getNumOperands()
1653 << ") to match the number of outputs of the enclosing "
1654 << "LinalgOp (" << linalgOp.getNumOutputs() << ")";
1655
1656 for (OpOperand &opOperand : op->getOpOperands()) {
1657 OpOperand *outputOperand =
1658 linalgOp.getOutputOperand(opOperand.getOperandNumber());
1659 Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
1660 if (opOperand.get().getType() != elementType)
1661 return op.emitOpError("type of yield operand ")
1662 << (opOperand.getOperandNumber() + 1) << " ("
1663 << opOperand.get().getType() << ") doesn't match "
1664 << "the element type of the enclosing linalg.generic op ("
1665 << elementType << ")";
1666 }
1667 return success();
1668 }
1669
1670 static LogicalResult verify(linalg::YieldOp op) {
1671 auto *parentOp = op->getParentOp();
1672 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
1673 return op.emitOpError("expected single non-empty parent region");
1674
1675 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
1676 return verifyYield(op, linalgOp);
1677
1678 if (auto padTensorOp = dyn_cast<linalg::PadTensorOp>(parentOp)) {
1679 if (op.getNumOperands() != 1)
1680 return op.emitOpError("expected single yield operand (got ")
1681 << op->getNumOperands() << ")";
1682 if (op.getOperand(0).getType() !=
1683 padTensorOp.getType().cast<ShapedType>().getElementType())
1684 return op.emitOpError("expected yield type to match shape element type");
1685 return success();
1686 }
1687
1688 if (auto tiledLoopOp = dyn_cast<linalg::TiledLoopOp>(parentOp)) {
1689 // Check if output args with tensor types match result types.
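    // Illustrative example (hypothetical IR): for a loop with
    //   outs (%t_ = %t: tensor<8xf32>, %m_ = %m: memref<8xf32>)
    // only the tensor output participates in the check, so the terminator must
    // be `linalg.yield %v : tensor<8xf32>` with the type of %v equal to the
    // tensor output type.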
1690 SmallVector<Value, 2> tensorOuts; 1691 llvm::copy_if( 1692 tiledLoopOp.outputs(), std::back_inserter(tensorOuts), 1693 [&](Value out) { return out.getType().isa<RankedTensorType>(); }); 1694 if (tensorOuts.size() != op.values().size()) 1695 return op.emitOpError("expected number of tensor output args = ") 1696 << tensorOuts.size() << " to match the number of yield operands = " 1697 << op.values().size(); 1698 1699 TypeRange tensorTypes(llvm::makeArrayRef(tensorOuts)); 1700 for (auto &item : 1701 llvm::enumerate(llvm::zip(tensorTypes, op.getOperandTypes()))) { 1702 Type outType, resultType; 1703 unsigned index = item.index(); 1704 std::tie(outType, resultType) = item.value(); 1705 if (outType != resultType) 1706 return op.emitOpError("expected yield operand ") 1707 << index << " with type = " << resultType 1708 << " to match output arg type = " << outType; 1709 } 1710 return success(); 1711 } 1712 return op.emitOpError("expected parent op with LinalgOp interface"); 1713 } 1714 1715 //===----------------------------------------------------------------------===// 1716 // TiledLoopOp 1717 //===----------------------------------------------------------------------===// 1718 1719 void TiledLoopOp::build(OpBuilder &builder, OperationState &result, 1720 ValueRange lowerBounds, ValueRange upperBounds, 1721 ValueRange steps, ValueRange inputs, ValueRange outputs, 1722 ArrayAttr iteratorTypes, 1723 function_ref<void(OpBuilder &, Location, ValueRange, 1724 ValueRange, ValueRange)> 1725 bodyBuilderFn) { 1726 build(builder, result, lowerBounds, upperBounds, steps, inputs, outputs, 1727 iteratorTypes, llvm::None, bodyBuilderFn); 1728 } 1729 1730 void TiledLoopOp::build(OpBuilder &builder, OperationState &result, 1731 ValueRange lowerBounds, ValueRange upperBounds, 1732 ValueRange steps, ValueRange inputs, ValueRange outputs, 1733 ArrayAttr iteratorTypes, 1734 Optional<ArrayAttr> distributionTypes, 1735 function_ref<void(OpBuilder &, Location, ValueRange, 1736 ValueRange, ValueRange)> 1737 bodyBuilderFn) { 1738 result.addOperands(lowerBounds); 1739 result.addOperands(upperBounds); 1740 result.addOperands(steps); 1741 result.addOperands(inputs); 1742 result.addOperands(outputs); 1743 result.addAttribute( 1744 TiledLoopOp::getOperandSegmentSizeAttr(), 1745 builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()), 1746 static_cast<int32_t>(upperBounds.size()), 1747 static_cast<int32_t>(steps.size()), 1748 static_cast<int32_t>(inputs.size()), 1749 static_cast<int32_t>(outputs.size())})); 1750 result.addAttribute(getIteratorTypesAttrName(), iteratorTypes); 1751 1752 if (distributionTypes.hasValue()) 1753 result.addAttribute(getDistributionTypesAttrName(), 1754 distributionTypes.getValue()); 1755 1756 // Add output types for `RankedTensorType` output arguments. 
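  // (Memref-typed outputs are updated in place and therefore do not contribute
  // to the result list; only `RankedTensorType` outputs produce SSA results.)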
1757 for (Value output : outputs) { 1758 Type outputType = output.getType(); 1759 if (outputType.isa<RankedTensorType>()) 1760 result.addTypes(outputType); 1761 } 1762 1763 OpBuilder::InsertionGuard guard(builder); 1764 unsigned numIVs = steps.size(); 1765 SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType()); 1766 for (Type type : TypeRange(inputs)) 1767 argTypes.push_back(type); 1768 for (Type type : TypeRange(outputs)) 1769 argTypes.push_back(type); 1770 Region *bodyRegion = result.addRegion(); 1771 Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes); 1772 1773 if (bodyBuilderFn) { 1774 builder.setInsertionPointToStart(bodyBlock); 1775 bodyBuilderFn(builder, result.location, 1776 bodyBlock->getArguments().take_front(numIVs), 1777 bodyBlock->getArguments().slice(numIVs, inputs.size()), 1778 bodyBlock->getArguments().take_back(outputs.size())); 1779 TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location); 1780 } 1781 } 1782 1783 static void print(OpAsmPrinter &p, TiledLoopOp op) { 1784 p << " (" << op.getInductionVars() << ") = (" << op.lowerBound() << ") to (" 1785 << op.upperBound() << ") step (" << op.step() << ")"; 1786 1787 if (!op.inputs().empty()) { 1788 p << " ins ("; 1789 llvm::interleaveComma(llvm::zip(op.getRegionInputArgs(), op.inputs()), p, 1790 [&](auto it) { 1791 p << std::get<0>(it) << " = " << std::get<1>(it) 1792 << ": " << std::get<1>(it).getType(); 1793 }); 1794 p << ")"; 1795 } 1796 if (!op.outputs().empty()) { 1797 p << " outs ("; 1798 llvm::interleaveComma(llvm::zip(op.getRegionOutputArgs(), op.outputs()), p, 1799 [&](auto it) { 1800 p << std::get<0>(it) << " = " << std::get<1>(it) 1801 << ": " << std::get<1>(it).getType(); 1802 }); 1803 p << ")"; 1804 } 1805 1806 if (llvm::any_of(op.iterator_types(), [](Attribute attr) { 1807 return attr.cast<StringAttr>().getValue() != 1808 getParallelIteratorTypeName(); 1809 })) 1810 p << " iterators" << op.iterator_types(); 1811 1812 if (op.distribution_types().hasValue()) 1813 p << " distribution" << op.distribution_types().getValue(); 1814 1815 p << ' '; 1816 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 1817 p.printOptionalAttrDict( 1818 op->getAttrs(), /*elidedAttrs=*/{TiledLoopOp::getOperandSegmentSizeAttr(), 1819 getIteratorTypesAttrName(), 1820 getDistributionTypesAttrName()}); 1821 } 1822 1823 static ParseResult parseTiledLoopOp(OpAsmParser &parser, 1824 OperationState &result) { 1825 auto &builder = parser.getBuilder(); 1826 // Parse an opening `(` followed by induction variables followed by `)` 1827 SmallVector<OpAsmParser::OperandType, 4> ivs; 1828 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 1829 OpAsmParser::Delimiter::Paren)) 1830 return failure(); 1831 1832 // Parse loop bounds. 1833 SmallVector<OpAsmParser::OperandType, 4> lower; 1834 if (parser.parseEqual() || 1835 parser.parseOperandList(lower, ivs.size(), 1836 OpAsmParser::Delimiter::Paren) || 1837 parser.resolveOperands(lower, builder.getIndexType(), result.operands)) 1838 return failure(); 1839 1840 SmallVector<OpAsmParser::OperandType, 4> upper; 1841 if (parser.parseKeyword("to") || 1842 parser.parseOperandList(upper, ivs.size(), 1843 OpAsmParser::Delimiter::Paren) || 1844 parser.resolveOperands(upper, builder.getIndexType(), result.operands)) 1845 return failure(); 1846 1847 // Parse step values. 
1848 SmallVector<OpAsmParser::OperandType, 4> steps; 1849 if (parser.parseKeyword("step") || 1850 parser.parseOperandList(steps, ivs.size(), 1851 OpAsmParser::Delimiter::Paren) || 1852 parser.resolveOperands(steps, builder.getIndexType(), result.operands)) 1853 return failure(); 1854 1855 // Parse input tensors. 1856 SmallVector<OpAsmParser::OperandType, 4> inputs, inputRegionArgs; 1857 SmallVector<Type, 4> inputTypes; 1858 if (succeeded(parser.parseOptionalKeyword("ins"))) { 1859 llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation(); 1860 1861 if (parser.parseAssignmentListWithTypes(inputRegionArgs, inputs, 1862 inputTypes)) 1863 return failure(); 1864 1865 if (parser.resolveOperands(inputs, inputTypes, inputsOperandsLoc, 1866 result.operands)) 1867 return failure(); 1868 } 1869 1870 // Parse output tensors. 1871 SmallVector<OpAsmParser::OperandType, 4> outputs, outputRegionArgs; 1872 SmallVector<Type, 4> outputTypes; 1873 if (succeeded(parser.parseOptionalKeyword("outs"))) { 1874 llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation(); 1875 1876 if (parser.parseAssignmentListWithTypes(outputRegionArgs, outputs, 1877 outputTypes)) 1878 return failure(); 1879 1880 if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc, 1881 result.operands)) 1882 return failure(); 1883 for (Type outputType : outputTypes) 1884 if (outputType.isa<RankedTensorType>()) 1885 result.addTypes(outputType); 1886 } 1887 1888 // Parse attributes. 1889 SmallVector<Attribute, 4> iterTypes, distributionTypes; 1890 auto parseAttr = [&](StringRef keyword, SmallVector<Attribute, 4> *attrs) { 1891 if (succeeded(parser.parseOptionalKeyword(keyword))) { 1892 StringAttr attr; 1893 1894 if (parser.parseLSquare() || parser.parseAttribute(attr)) 1895 return failure(); 1896 attrs->push_back(attr); 1897 for (int i = 1, e = ivs.size(); i < e; ++i) { 1898 if (parser.parseComma() || parser.parseAttribute(attr)) 1899 return failure(); 1900 attrs->push_back(attr); 1901 } 1902 if (parser.parseRSquare()) 1903 return failure(); 1904 } 1905 return success(); 1906 }; 1907 if (failed(parseAttr("iterators", &iterTypes)) || 1908 failed(parseAttr("distribution", &distributionTypes))) 1909 return failure(); 1910 1911 // Set all loop iterator types to "parallel" if they are not printed in IR. 1912 if (iterTypes.empty()) { 1913 auto parallelIter = builder.getStringAttr(getParallelIteratorTypeName()); 1914 iterTypes = SmallVector<Attribute, 4>(ivs.size(), parallelIter); 1915 } 1916 result.addAttribute(getIteratorTypesAttrName(), 1917 builder.getArrayAttr(iterTypes)); 1918 if (!distributionTypes.empty()) 1919 result.addAttribute(getDistributionTypesAttrName(), 1920 builder.getArrayAttr(distributionTypes)); 1921 result.addAttribute( 1922 TiledLoopOp::getOperandSegmentSizeAttr(), 1923 builder.getI32VectorAttr({static_cast<int32_t>(lower.size()), 1924 static_cast<int32_t>(upper.size()), 1925 static_cast<int32_t>(steps.size()), 1926 static_cast<int32_t>(inputs.size()), 1927 static_cast<int32_t>(outputs.size())})); 1928 1929 // Parse the body. 1930 Region *body = result.addRegion(); 1931 1932 SmallVector<Type, 4> regionTypes(ivs.size(), builder.getIndexType()); 1933 regionTypes.append(inputTypes); 1934 regionTypes.append(outputTypes); 1935 1936 SmallVector<OpAsmParser::OperandType, 4> regionArgs(ivs); 1937 regionArgs.append(inputRegionArgs); 1938 regionArgs.append(outputRegionArgs); 1939 1940 if (parser.parseRegion(*body, regionArgs, regionTypes)) 1941 return failure(); 1942 1943 // Parse optional attributes. 
1944 parser.parseOptionalAttrDict(result.attributes);
1945
1946 return success();
1947 }
1948
1949 Region &TiledLoopOp::getLoopBody() { return region(); }
1950
1951 LogicalResult TiledLoopOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
1952 for (auto *op : ops)
1953 op->moveBefore(*this);
1954 return success();
1955 }
1956
1957 bool TiledLoopOp::isDefinedOutsideOfLoop(Value value) {
1958 return !region().isAncestor(value.getParentRegion());
1959 }
1960
1961 static LogicalResult verify(TiledLoopOp op) {
1962 // Check if iterator types are provided for every loop dimension.
1963 if (op.iterator_types().size() != op.getNumLoops())
1964 return op.emitOpError("expected iterator types array attribute size = ")
1965 << op.iterator_types().size()
1966 << " to match the number of loops = " << op.getNumLoops();
1967
1968 // Check if types of input arguments match region args types.
1969 for (auto &item :
1970 llvm::enumerate(llvm::zip(op.inputs(), op.getRegionInputArgs()))) {
1971 Value input, inputRegionArg;
1972 unsigned index = item.index();
1973 std::tie(input, inputRegionArg) = item.value();
1974 if (input.getType() != inputRegionArg.getType())
1975 return op.emitOpError("expected input arg ")
1976 << index << " with type = " << input.getType()
1977 << " to match region arg " << index + op.getNumLoops()
1978 << " type = " << inputRegionArg.getType();
1979 }
1980
1981 // Check if types of output arguments match region args types.
1982 for (auto &item :
1983 llvm::enumerate(llvm::zip(op.outputs(), op.getRegionOutputArgs()))) {
1984 Value output, outputRegionArg;
1985 unsigned index = item.index();
1986 std::tie(output, outputRegionArg) = item.value();
1987 if (output.getType() != outputRegionArg.getType())
1988 return op.emitOpError("expected output arg ")
1989 << index << " with type = " << output.getType()
1990 << " to match region arg "
1991 << index + op.getNumLoops() + op.inputs().size()
1992 << " type = " << outputRegionArg.getType();
1993 }
1994 return success();
1995 }
1996
1997 namespace {
1998
1999 static constexpr int64_t kNoMatch = -1;
2000
2001 // Folds away TiledLoopOp inputs if they have no uses within the body.
2002 //
2003 // Example:
2004 //
2005 // %0 = linalg.tiled_loop ... ins (%in_ = %in: tensor<...>,
2006 // %in_buf_ = %in_buf: memref<...>) {...}
2007 // Becomes
2008 //
2009 // linalg.tiled_loop ... ins (%in_buf_ = %in_buf: memref<...>) {...}
2010 struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
2011 using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern;
2012
2013 LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
2014 PatternRewriter &rewriter) const final {
2015 SmallVector<Value, 2> newInputs, regionInputTensorArgs;
2016 // Store ids of the corresponding old and new input operands.
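    // Illustrative example: for inputs (%a, %b, %c) where only the block
    // argument of %b is unused in the body, `oldInputIdToNew` becomes
    // [0, kNoMatch, 1] and `newInputs` is (%a, %c).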
2017 SmallVector<int64_t, 2> oldInputIdToNew(tiledLoop.inputs().size(), 2018 kNoMatch); 2019 for (const auto &en : llvm::enumerate( 2020 llvm::zip(tiledLoop.inputs(), tiledLoop.getRegionInputArgs()))) { 2021 Value in, bbArg; 2022 size_t index = en.index(); 2023 std::tie(in, bbArg) = en.value(); 2024 if (!bbArg.use_empty()) { 2025 oldInputIdToNew[index] = newInputs.size(); 2026 newInputs.push_back(in); 2027 } 2028 } 2029 if (newInputs.size() == tiledLoop.inputs().size()) 2030 return failure(); 2031 Location loc = tiledLoop.getLoc(); 2032 auto newTiledLoop = rewriter.create<TiledLoopOp>( 2033 loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(), 2034 newInputs, tiledLoop.outputs(), tiledLoop.iterator_types(), 2035 tiledLoop.distribution_types()); 2036 2037 // Clone the region. 2038 BlockAndValueMapping bvm; 2039 bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars()); 2040 bvm.map(tiledLoop.getRegionOutputArgs(), 2041 newTiledLoop.getRegionOutputArgs()); 2042 for (const auto &en : llvm::enumerate(oldInputIdToNew)) 2043 if (en.value() != kNoMatch) 2044 bvm.map(tiledLoop.getRegionInputArgs()[en.index()], 2045 newTiledLoop.getRegionInputArgs()[en.value()]); 2046 OpBuilder innerBuilder = 2047 OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener()); 2048 for (auto &op : *tiledLoop.getBody()) 2049 innerBuilder.clone(op, bvm); 2050 rewriter.replaceOp(tiledLoop, newTiledLoop.getResults()); 2051 2052 return success(); 2053 } 2054 }; 2055 2056 } // namespace 2057 2058 /// A simple, conservative analysis to determine if the loop is shape 2059 /// conserving. I.e., the type of the arg-th yielded value is the same as the 2060 /// type of the corresponding basic block argument of the loop. 2061 /// Note: This function handles only simple cases. Expand as needed. 2062 static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg) { 2063 auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator()); 2064 if (yieldOp.values().empty()) 2065 // Tiled loop either has no outputs or is a "memref-based version". In 2066 // either case, the loop is shape conserving. 2067 return true; 2068 assert(arg < static_cast<int64_t>(yieldOp.values().size()) && 2069 "arg is out of bounds"); 2070 Value value = yieldOp.values()[arg]; 2071 while (value) { 2072 if (value == loopOp.getRegionOutputArgs()[arg]) 2073 return true; 2074 OpResult opResult = value.dyn_cast<OpResult>(); 2075 if (!opResult) 2076 return false; 2077 2078 using tensor::InsertSliceOp; 2079 value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner()) 2080 .template Case<InsertSliceOp>( 2081 [&](InsertSliceOp op) { return op.dest(); }) 2082 .template Case<TiledLoopOp>([&](TiledLoopOp loopOp) { 2083 return isShapePreserving(loopOp, opResult.getResultNumber()) 2084 ? loopOp.outputs()[opResult.getResultNumber()] 2085 : Value(); 2086 }) 2087 .Default([&](auto op) { return Value(); }); 2088 } 2089 return false; 2090 } 2091 2092 namespace { 2093 2094 /// Fold dim(x) where `x` is an input/output argument of a TiledLoopOp block 2095 /// to dim(y) where `y` is the initial input/output value of the argument. 2096 /// 2097 /// E.g.: 2098 /// %y = ... : tensor<...> 2099 /// linalg.tiled_loop ... ins(%x = %y : tensor<...>) { 2100 /// tensor.dim %x, %c0 : tensor<...> 2101 /// } 2102 /// 2103 /// is folded to: 2104 /// %y = ... : tensor<...> 2105 /// linalg.tiled_loop ... 
ins(%x = %y : tensor<...>) { 2106 /// tensor.dim %y, %c0 : tensor<...> 2107 /// } 2108 /// 2109 /// Note: Dim ops are folded only if it can be proven that the runtime type of 2110 /// the yielded value (in case of outputs) does not change with loop iterations. 2111 template <typename OpTy> 2112 struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> { 2113 using OpRewritePattern<OpTy>::OpRewritePattern; 2114 2115 LogicalResult matchAndRewrite(OpTy dimOp, 2116 PatternRewriter &rewriter) const final { 2117 auto src = dimOp.source().template dyn_cast<BlockArgument>(); 2118 if (!src) 2119 return failure(); 2120 auto loopOp = 2121 dyn_cast<TiledLoopOp>(src.getOwner()->getParent()->getParentOp()); 2122 if (!loopOp) 2123 return failure(); 2124 unsigned numLoops = loopOp.getNumLoops(); 2125 unsigned numInputArgs = loopOp.getRegionInputArgs().size(); 2126 if (src.getArgNumber() >= numInputArgs + numLoops && 2127 !isShapePreserving(loopOp, 2128 src.getArgNumber() - numInputArgs - numLoops)) 2129 return failure(); 2130 2131 auto inputArgs = loopOp.getRegionInputArgs(); 2132 auto it1 = llvm::find(inputArgs, src); 2133 if (it1 != inputArgs.end()) { 2134 rewriter.updateRootInPlace(dimOp, [&] { 2135 dimOp.sourceMutable().assign(loopOp.inputs()[it1 - inputArgs.begin()]); 2136 }); 2137 return success(); 2138 } 2139 2140 auto outputArgs = loopOp.getRegionOutputArgs(); 2141 auto it2 = llvm::find(outputArgs, src); 2142 if (it2 != outputArgs.end()) { 2143 rewriter.updateRootInPlace(dimOp, [&] { 2144 dimOp.sourceMutable().assign( 2145 loopOp.outputs()[it2 - outputArgs.begin()]); 2146 }); 2147 return success(); 2148 } 2149 2150 return failure(); 2151 } 2152 }; 2153 2154 /// Fold dim(r) where `r` is the result of a TiledLoopOp to dim(y) where `y` 2155 /// is the initial output value of the loop. 2156 /// 2157 /// E.g.: 2158 /// %y = ... : tensor<...> 2159 /// %r = linalg.tiled_loop ... outs(%i = %y : tensor<...>) { 2160 /// ... 2161 /// } 2162 /// %0 = tensor.dim %r, %c0 : tensor<...> 2163 /// 2164 /// is folded to: 2165 /// %y = ... : tensor<...> 2166 /// linalg.tiled_loop ... outs(%i = %y : tensor<...>) { 2167 /// ... 2168 /// } 2169 /// %0 = tensor.dim %y, %c0 : tensor<...> 2170 /// 2171 /// Note: Dim ops are folded only if it can be proven that the runtime type of 2172 /// the yielded value (in case of outputs) does not change with loop iterations. 2173 template <typename OpTy> 2174 struct DimOfTiledLoopResultFolder : public OpRewritePattern<OpTy> { 2175 using OpRewritePattern<OpTy>::OpRewritePattern; 2176 2177 LogicalResult matchAndRewrite(OpTy dimOp, 2178 PatternRewriter &rewriter) const final { 2179 auto loopOp = dimOp.source().template getDefiningOp<TiledLoopOp>(); 2180 if (!loopOp) 2181 return failure(); 2182 auto opResult = dimOp.source().template cast<OpResult>(); 2183 unsigned resultNumber = opResult.getResultNumber(); 2184 if (!isShapePreserving(loopOp, resultNumber)) 2185 return failure(); 2186 rewriter.updateRootInPlace(dimOp, [&]() { 2187 dimOp.sourceMutable().assign(loopOp.outputs()[resultNumber]); 2188 }); 2189 return success(); 2190 } 2191 }; 2192 2193 // Folds away TiledLoopOp output tensors when the following conditions are met: 2194 // * result of `linalg.tiled_loop` has no uses 2195 // * output tensor is the argument of `linalg.yield` 2196 // 2197 // Example: 2198 // 2199 // %0 = linalg.tiled_loop ... outs (%o_ = %out: tensor<...>, 2200 // %obuf_ = %out_buf: memref<...>) { 2201 // ... 2202 // linalg.yield %o_ : tensor ... 
2203 // } 2204 // 2205 // Becomes 2206 // 2207 // linalg.tiled_loop ... outs (%obuf_ = %out_buf: memref<...>) { 2208 // ... 2209 // linalg.yield 2210 // } 2211 struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> { 2212 using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern; 2213 2214 LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop, 2215 PatternRewriter &rewriter) const final { 2216 if (tiledLoop.getNumResults() == 0) 2217 return failure(); 2218 2219 Block *block = tiledLoop.getBody(); 2220 auto yieldOp = cast<linalg::YieldOp>(block->getTerminator()); 2221 2222 // Match the pattern and collect output buffers that will replace the output 2223 // tensors and also the ops that will be ignored when cloning the body. 2224 SmallVector<Value, 2> newOutputOperands, newYieldArgs; 2225 int resultId = 0; 2226 // Store ids of the corresponding old and new output operands. 2227 SmallVector<int64_t, 2> oldOutputIdToNew(tiledLoop.outputs().size(), 2228 kNoMatch); 2229 // Store ids of the corresponding old and new results. 2230 SmallVector<int64_t, 2> oldResultIdToNew(tiledLoop.getNumResults(), 2231 kNoMatch); 2232 SmallVector<Value, 2> resultReplacement(tiledLoop.getNumResults()); 2233 for (const auto &en : llvm::enumerate( 2234 llvm::zip(tiledLoop.outputs(), tiledLoop.getRegionOutputArgs()))) { 2235 size_t index = en.index(); 2236 Value out = std::get<0>(en.value()); 2237 Value outRegionArg = std::get<1>(en.value()); 2238 2239 if (!out.getType().isa<RankedTensorType>()) { 2240 oldOutputIdToNew[index] = newOutputOperands.size(); 2241 newOutputOperands.push_back(out); 2242 continue; 2243 } 2244 Value result = tiledLoop.getResult(resultId); 2245 Value yieldArg = yieldOp.getOperand(resultId); 2246 if (yieldArg != outRegionArg || !result.use_empty()) { 2247 oldOutputIdToNew[index] = newOutputOperands.size(); 2248 oldResultIdToNew[resultId] = newYieldArgs.size(); 2249 resultReplacement[resultId] = out; 2250 newOutputOperands.push_back(out); 2251 newYieldArgs.push_back(yieldArg); 2252 } 2253 ++resultId; 2254 } 2255 if (newOutputOperands.size() == tiledLoop.outputs().size()) 2256 return failure(); 2257 2258 Location loc = tiledLoop.getLoc(); 2259 auto newTiledLoop = rewriter.create<TiledLoopOp>( 2260 loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(), 2261 tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types(), 2262 tiledLoop.distribution_types()); 2263 2264 // Clone the region. 
2265 BlockAndValueMapping bvm; 2266 bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars()); 2267 bvm.map(tiledLoop.getRegionInputArgs(), newTiledLoop.getRegionInputArgs()); 2268 for (const auto &en : llvm::enumerate(oldOutputIdToNew)) { 2269 if (en.value() != kNoMatch) 2270 bvm.map(tiledLoop.getRegionOutputArgs()[en.index()], 2271 newTiledLoop.getRegionOutputArgs()[en.value()]); 2272 else 2273 bvm.map(tiledLoop.getRegionOutputArgs()[en.index()], 2274 tiledLoop.outputs()[en.index()]); 2275 } 2276 OpBuilder innerBuilder = 2277 OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener()); 2278 for (auto &op : tiledLoop.getBody()->without_terminator()) 2279 innerBuilder.clone(op, bvm); 2280 innerBuilder.create<linalg::YieldOp>( 2281 loc, llvm::to_vector<2>(llvm::map_range( 2282 newYieldArgs, [&](Value arg) { return bvm.lookup(arg); }))); 2283 2284 for (const auto &en : llvm::enumerate(oldResultIdToNew)) 2285 if (en.value() != kNoMatch) 2286 resultReplacement[en.index()] = newTiledLoop.getResult(en.value()); 2287 rewriter.replaceOp(tiledLoop, resultReplacement); 2288 2289 return success(); 2290 } 2291 }; 2292 } // namespace 2293 2294 void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 2295 MLIRContext *context) { 2296 results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder, 2297 DimOfTiledLoopInsOutsFolder<tensor::DimOp>, 2298 DimOfTiledLoopInsOutsFolder<memref::DimOp>, 2299 DimOfTiledLoopResultFolder<tensor::DimOp>, 2300 DimOfTiledLoopResultFolder<memref::DimOp>>(context); 2301 } 2302 2303 LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>, 2304 SmallVectorImpl<OpFoldResult> &) { 2305 return foldMemRefCastInTiledLoopOp(*this); 2306 } 2307 2308 //===----------------------------------------------------------------------===// 2309 // IndexOp 2310 //===----------------------------------------------------------------------===// 2311 2312 static LogicalResult verify(IndexOp op) { 2313 auto linalgOp = dyn_cast<LinalgOp>(op->getParentOp()); 2314 if (!linalgOp) 2315 return op.emitOpError("expected parent op with LinalgOp interface"); 2316 if (linalgOp.getNumLoops() <= op.dim()) 2317 return op.emitOpError("expected dim (") 2318 << op.dim() << ") to be lower than the number of loops (" 2319 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp"; 2320 return success(); 2321 } 2322 2323 /////// Operations corresponding to library calls defined with Tablegen //////// 2324 2325 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc" 2326 2327 #define GET_OP_CLASSES 2328 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" 2329 2330 #define GET_OP_CLASSES 2331 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 2332 2333 /// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`. 2334 /// Assumes `op` is a LinalgOp. 
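/// For example (illustrative): for an op with iterator_types
/// ["parallel", "parallel", "reduction"], passing "reduction" appends {2} to
/// `res`, while "parallel" appends {0, 1}.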
2335 void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName, 2336 SmallVectorImpl<unsigned> &res) { 2337 if (!cast<LinalgOp>(op).iterator_types()) 2338 return; 2339 2340 unsigned dim = 0; 2341 for (auto tn : 2342 cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) { 2343 if (tn == iteratorTypeName) 2344 res.push_back(dim); 2345 ++dim; 2346 } 2347 } 2348 2349 AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap, 2350 unsigned rank, 2351 MLIRContext *context) { 2352 if (maybeMap) 2353 return maybeMap.getValue(); 2354 if (rank == 0) 2355 return AffineMap::get(context); 2356 return AffineMap::getMultiDimIdentityMap(rank, context); 2357 } 2358 2359 SmallVector<AffineExpr, 4> 2360 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx, 2361 MLIRContext *context) { 2362 SmallVector<AffineExpr, 4> res; 2363 res.reserve(num); 2364 for (unsigned i = 0; i < num; ++i) 2365 res.push_back(getAffineDimExpr(startIdx++, context)); 2366 return res; 2367 } 2368 2369 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a, 2370 ArrayRef<AffineExpr> b) { 2371 auto rangeA = llvm::make_range(a.begin(), a.end()); 2372 auto rangeB = llvm::make_range(b.begin(), b.end()); 2373 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB); 2374 return llvm::to_vector<4>(concatRanges); 2375 } 2376 2377 static void appendMangledType(llvm::raw_string_ostream &ss, Type t) { 2378 if (auto memref = t.dyn_cast<MemRefType>()) { 2379 ss << "view"; 2380 for (auto size : memref.getShape()) 2381 if (size < 0) 2382 ss << "sx"; 2383 else 2384 ss << size << "x"; 2385 appendMangledType(ss, memref.getElementType()); 2386 } else if (auto vec = t.dyn_cast<VectorType>()) { 2387 ss << "vector"; 2388 llvm::interleave( 2389 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; }); 2390 appendMangledType(ss, vec.getElementType()); 2391 } else if (t.isSignlessIntOrIndexOrFloat()) { 2392 ss << t; 2393 } else { 2394 llvm_unreachable("Invalid type for linalg library name mangling"); 2395 } 2396 } 2397 2398 std::string mlir::linalg::generateLibraryCallName(Operation *op) { 2399 assert(isa<LinalgOp>(op)); 2400 std::string name(op->getName().getStringRef().str()); 2401 name.reserve(128); 2402 std::replace(name.begin(), name.end(), '.', '_'); 2403 llvm::raw_string_ostream ss(name); 2404 ss << "_"; 2405 auto types = op->getOperandTypes(); 2406 llvm::interleave( 2407 types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); }, 2408 [&]() { ss << "_"; }); 2409 return ss.str(); 2410 } 2411 2412 //===----------------------------------------------------------------------===// 2413 // Support for named Linalg ops defined in ods-gen. 2414 //===----------------------------------------------------------------------===// 2415 2416 /// Generic entry point to create the block for the region of a LinalgOp. 2417 /// This is used by both named structured ops created by ods-gen and by manually 2418 /// defined C++ ops. 2419 /// This is used by both builders and parsers. 2420 /// This function creates the block in the region with arguments corresponding 2421 /// to the elemental types of `inputTypes` and `outputTypes`, which are asserted 2422 /// to be ShapedType. 
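/// For example (illustrative): with inputTypes = {tensor<4x8xf32>, f32} and
/// outputTypes = {tensor<4x8xf32>}, the block is created with arguments
/// (f32, f32, f32), i.e. the elemental type of each operand.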
2423 template <typename NamedStructuredOpType> 2424 static void fillStructuredOpRegion( 2425 OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, 2426 TypeRange outputTypes, 2427 llvm::function_ref<void(unsigned, unsigned)> errorHandler) { 2428 assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); })); 2429 2430 // TODO: atm all operands go through getElementTypeOrSelf, 2431 // reconsider when we have evidence we need to. 2432 SmallVector<Type, 8> argTypes; 2433 for (auto containers : {inputTypes, outputTypes}) 2434 for (auto t : containers) 2435 argTypes.push_back(getElementTypeOrSelf(t)); 2436 2437 // RAII. 2438 OpBuilder::InsertionGuard guard(opBuilder); 2439 Block *body = opBuilder.createBlock(®ion, /*insertPt=*/{}, argTypes); 2440 unsigned actual = body->getNumArguments(); 2441 unsigned expected = NamedStructuredOpType::getNumRegionArgs(); 2442 if (expected != actual) { 2443 if (errorHandler) 2444 errorHandler(expected, actual); 2445 return; 2446 } 2447 2448 opBuilder.setInsertionPointToStart(body); 2449 ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); 2450 NamedStructuredOpType::regionBuilder(b, *body); 2451 2452 // indexing_maps is an auto-generated method. 2453 2454 // iterator_types is an auto-generated method. 2455 } 2456 2457 /// Generic entry point to create both the region and the block of a LinalgOp. 2458 template <typename NamedStructuredOpType> 2459 void createAndFillStructuredOpRegion(OpBuilder &opBuilder, 2460 OperationState &result, 2461 TypeRange inputTypes, 2462 TypeRange outputTypes) { 2463 Region ®ion = *result.addRegion(); 2464 fillStructuredOpRegion<NamedStructuredOpType>( 2465 opBuilder, region, inputTypes, outputTypes, 2466 [&](unsigned expected, unsigned actual) { 2467 assert(expected != actual && "incorrect number of arguments"); 2468 }); 2469 } 2470 2471 /// Common parsing used for both named structured ops created by ods-gen and by 2472 /// manually defined C++ ops. Does not handle regions. 
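/// The textual form handled here is roughly (illustrative):
///
///   ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>) outs(%c : tensor<?x?xf32>)
///
/// preceded by an optional attribute dictionary; both the `ins` and the `outs`
/// clauses are optional.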
2473 static ParseResult 2474 parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, 2475 SmallVectorImpl<Type> &inputTypes, 2476 SmallVectorImpl<Type> &outputTypes) { 2477 llvm::SMLoc inputsOperandsLoc, outputsOperandsLoc; 2478 SmallVector<OpAsmParser::OperandType, 4> inputsOperands, outputsOperands; 2479 2480 parser.parseOptionalAttrDict(result.attributes); 2481 2482 if (succeeded(parser.parseOptionalKeyword("ins"))) { 2483 if (parser.parseLParen()) 2484 return failure(); 2485 2486 inputsOperandsLoc = parser.getCurrentLocation(); 2487 if (parser.parseOperandList(inputsOperands) || 2488 parser.parseColonTypeList(inputTypes) || parser.parseRParen()) 2489 return failure(); 2490 } 2491 2492 if (succeeded(parser.parseOptionalKeyword("outs"))) { 2493 outputsOperandsLoc = parser.getCurrentLocation(); 2494 if (parser.parseLParen() || parser.parseOperandList(outputsOperands) || 2495 parser.parseColonTypeList(outputTypes) || parser.parseRParen()) 2496 return failure(); 2497 } 2498 2499 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc, 2500 result.operands) || 2501 parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc, 2502 result.operands)) 2503 return failure(); 2504 2505 result.addAttribute("operand_segment_sizes", 2506 parser.getBuilder().getI32VectorAttr( 2507 {static_cast<int32_t>(inputsOperands.size()), 2508 static_cast<int32_t>(outputsOperands.size())})); 2509 return success(); 2510 } 2511 2512 template <typename NamedStructuredOpType> 2513 static void printCommonStructuredOpParts(OpAsmPrinter &p, 2514 NamedStructuredOpType op) { 2515 if (!op.inputs().empty()) 2516 p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")"; 2517 if (!op.outputs().empty()) 2518 p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")"; 2519 } 2520 2521 //===----------------------------------------------------------------------===// 2522 // Specific parsing and printing for named structured ops created by ods-gen. 2523 //===----------------------------------------------------------------------===// 2524 2525 template <typename NamedStructuredOpType> 2526 static ParseResult 2527 parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, 2528 TypeRange inputTypes, TypeRange outputTypes) { 2529 ParseResult res = success(); 2530 OpBuilder opBuilder(parser.getContext()); 2531 // Resolve `captures` into `capturedValues` at parse time so we can build the 2532 // region with captures. 2533 SmallVector<Value> capturedValues; 2534 fillStructuredOpRegion<NamedStructuredOpType>( 2535 opBuilder, region, inputTypes, outputTypes, 2536 [&](unsigned expected, unsigned actual) { 2537 res = parser.emitError( 2538 parser.getCurrentLocation(), 2539 llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated " 2540 "region expects {0} args, got {1}", 2541 expected, actual)); 2542 region.front().dump(); 2543 }); 2544 return res; 2545 } 2546 2547 static ParseResult 2548 parseNamedStructuredOpResults(OpAsmParser &parser, 2549 SmallVectorImpl<Type> &resultTypes) { 2550 if (parser.parseOptionalArrowTypeList(resultTypes)) 2551 return failure(); 2552 return success(); 2553 } 2554 2555 template <typename NamedStructuredOpType> 2556 static ParseResult parseNamedStructuredOp(OpAsmParser &parser, 2557 OperationState &result) { 2558 // TODO: Enable when ods-gen supports captures. 
2559 SmallVector<Type, 1> inputTypes, outputTypes; 2560 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) 2561 return failure(); 2562 2563 // TODO: consider merging results parsing into region parsing. 2564 // Need to wait for declarative assembly resolution to decide. 2565 SmallVector<Type, 1> outputTensorsTypes; 2566 if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) 2567 return failure(); 2568 result.addTypes(outputTensorsTypes); 2569 2570 std::unique_ptr<Region> region = std::make_unique<Region>(); 2571 if (parseNamedStructuredOpRegion<NamedStructuredOpType>( 2572 parser, *region, inputTypes, outputTypes)) 2573 return failure(); 2574 result.addRegion(std::move(region)); 2575 2576 return success(); 2577 } 2578 2579 static void printNamedStructuredOpResults(OpAsmPrinter &p, 2580 TypeRange resultTypes) { 2581 if (resultTypes.empty()) 2582 return; 2583 p.printOptionalArrowTypeList(resultTypes); 2584 } 2585 2586 template <typename NamedStructuredOpType> 2587 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { 2588 p.printOptionalAttrDict( 2589 op->getAttrs(), 2590 /*elidedAttrs=*/{"operand_segment_sizes", 2591 // See generated code in mlir-linalg-yaml-gen.cpp 2592 "linalg.memoized_indexing_maps"}); 2593 2594 // Printing is shared with generic ops, except for the region and 2595 // attributes. 2596 printCommonStructuredOpParts(p, op); 2597 2598 // Results printing. 2599 printNamedStructuredOpResults(p, op.result_tensors().getTypes()); 2600 2601 // Region is elided. 2602 } 2603 2604 template <typename NamedStructuredOpType> 2605 static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) { 2606 return verifyGenericOp<NamedStructuredOpType>(op); 2607 } 2608 2609 //===----------------------------------------------------------------------===// 2610 // Canonicalizers and Folders. 2611 //===----------------------------------------------------------------------===// 2612 2613 namespace { 2614 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> { 2615 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern; 2616 2617 LogicalResult matchAndRewrite(LinalgOp op, 2618 PatternRewriter &rewriter) const override { 2619 for (OpOperand *opOperand : op.getInputAndOutputOperands()) { 2620 // Linalg "inputs" may be either tensor or memref type. 2621 // tensor<0xelt_type> is a convention that may not always mean 2622 // "0 iterations". Only erase in cases we see memref<...x0x...>. 2623 auto mt = opOperand->get().getType().dyn_cast<MemRefType>(); 2624 if (!mt) 2625 continue; 2626 if (llvm::is_contained(op.getShape(opOperand), 0)) { 2627 rewriter.eraseOp(op); 2628 return success(); 2629 } 2630 } 2631 return failure(); 2632 } 2633 }; 2634 2635 struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> { 2636 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern; 2637 2638 LogicalResult matchAndRewrite(LinalgOp op, 2639 PatternRewriter &rewriter) const override { 2640 // If no operand comes from a tensor::CastOp and can be folded then fail. 
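    // A cast is foldable here when it only relaxes the type, e.g.
    // (illustrative) %1 = tensor.cast %0 : tensor<4x8xf32> to tensor<?x?xf32>;
    // the op consuming %1 can then use the more static %0 directly.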
2641 bool hasTensorCastOperand = 2642 llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) { 2643 if (opOperand->get().isa<BlockArgument>()) 2644 return false; 2645 auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>(); 2646 return castOp && canFoldIntoConsumerOp(castOp); 2647 }); 2648 if (!hasTensorCastOperand) 2649 return failure(); 2650 2651 SmallVector<Type, 4> newResultTypes; 2652 newResultTypes.reserve(op->getNumResults()); 2653 SmallVector<Value, 4> newOperands; 2654 newOperands.reserve(op->getNumOperands()); 2655 // Inputs may fold. 2656 for (OpOperand *opOperand : op.getInputOperands()) { 2657 auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>(); 2658 newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp) 2659 ? tensorCastOp.source() 2660 : opOperand->get()); 2661 } 2662 // Init tensors may fold, in which case the resultType must also change. 2663 for (OpOperand *opOperand : op.getOutputOperands()) { 2664 auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>(); 2665 bool fold = canFoldIntoConsumerOp(tensorCastOp); 2666 newOperands.push_back(fold ? tensorCastOp.getOperand() 2667 : opOperand->get()); 2668 newResultTypes.push_back(newOperands.back().getType()); 2669 } 2670 // Clone op. 2671 Operation *newOp = 2672 op.clone(rewriter, op->getLoc(), newResultTypes, newOperands); 2673 SmallVector<Value, 4> replacements; 2674 replacements.reserve(newOp->getNumResults()); 2675 for (auto result : llvm::zip(op->getResults(), newOp->getResults())) { 2676 Value oldResult = std::get<0>(result); 2677 Value newResult = std::get<1>(result); 2678 if (newResult.getType() != oldResult.getType()) { 2679 replacements.push_back(rewriter.create<tensor::CastOp>( 2680 op->getLoc(), oldResult.getType(), newResult)); 2681 } else { 2682 replacements.push_back(newResult); 2683 } 2684 } 2685 rewriter.replaceOp(op, replacements); 2686 2687 return success(); 2688 } 2689 }; 2690 2691 } // namespace 2692 2693 #define LINALGOP_FOLDERS(XXX) \ 2694 LogicalResult XXX::fold(ArrayRef<Attribute>, \ 2695 SmallVectorImpl<OpFoldResult> &) { \ 2696 return foldMemRefCast(*this); \ 2697 } 2698 2699 LINALGOP_FOLDERS(CopyOp) 2700 LINALGOP_FOLDERS(FillOp) 2701 LINALGOP_FOLDERS(GenericOp) 2702 2703 // All named ops canonicalizers and folders are auto-generated in the 2704 // .cpp.inc. 2705 2706 //===----------------------------------------------------------------------===// 2707 // LinalgDialect 2708 //===----------------------------------------------------------------------===// 2709 2710 void LinalgDialect::getCanonicalizationPatterns( 2711 RewritePatternSet &results) const { 2712 results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext()); 2713 } 2714 2715 Operation *LinalgDialect::materializeConstant(OpBuilder &builder, 2716 Attribute value, Type type, 2717 Location loc) { 2718 return builder.create<arith::ConstantOp>(loc, type, value); 2719 } 2720