1 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the Linalg operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/Dialect/StandardOps/IR/Ops.h" 19 #include "mlir/Dialect/Tensor/IR/Tensor.h" 20 #include "mlir/Dialect/Utils/StaticValueUtils.h" 21 #include "mlir/IR/AffineExprVisitor.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/IR/OpImplementation.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Interfaces/InferTypeOpInterface.h" 26 #include "mlir/Parser.h" 27 28 #include "llvm/ADT/DenseMap.h" 29 #include "llvm/ADT/SetVector.h" 30 #include "llvm/ADT/SmallSet.h" 31 #include "llvm/ADT/StringSet.h" 32 #include "llvm/Support/FormatVariadic.h" 33 #include "llvm/Support/MathExtras.h" 34 #include "llvm/Support/raw_ostream.h" 35 36 using namespace mlir; 37 using namespace mlir::linalg; 38 39 /// Forward declarations. 40 41 /// Generic entry point to create the block for the region of a LinalgOp. 42 /// This is used by both named structured ops created by ods-gen and by manually 43 /// defined C++ ops. 44 /// This is used by both builders and parsers. 45 /// This function creates the block in the region with arguments corresponding 46 /// to the elemental types of `inputTypes` and `outputTypes`. The latter are 47 /// asserted to be of ShapedType. 48 template <typename NamedStructuredOpType> 49 static void fillStructuredOpRegion( 50 OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, 51 TypeRange outputTypes, 52 std::function<void(unsigned, unsigned)> errorHandler = nullptr); 53 54 /// Generic entry point to create both the region and the block of a LinalgOp. 55 template <typename NamedStructuredOpType> 56 static void 57 createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result, 58 TypeRange inputTypes, TypeRange outputTypes); 59 60 /// Common parsing and printing used for both named structured ops created by 61 /// ods-gen and by manually defined C++ ops. Does not handle regions. 62 static ParseResult 63 parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, 64 SmallVectorImpl<Type> &inputTypes, 65 SmallVectorImpl<Type> &outputTypes); 66 template <typename NamedStructuredOpType> 67 static void printCommonStructuredOpParts(OpAsmPrinter &p, 68 NamedStructuredOpType op); 69 70 /// Specific parsing and printing for named structured ops created by ods-gen. 71 template <typename NamedStructuredOpType> 72 static ParseResult 73 parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, 74 TypeRange inputTypes, TypeRange outputTypes); 75 76 static ParseResult 77 parseNamedStructuredOpResults(OpAsmParser &parser, 78 SmallVectorImpl<Type> &resultTypes); 79 80 template <typename NamedStructuredOpType> 81 static ParseResult parseNamedStructuredOp(OpAsmParser &parser, 82 OperationState &result); 83 84 static void printNamedStructuredOpResults(OpAsmPrinter &p, 85 TypeRange resultTypes); 86 87 template <typename NamedStructuredOpType> 88 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op); 89 90 /// Helper function to convert a Value into an OpFoldResult, if the Value is 91 /// known to be a constant index value. 92 static SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) { 93 return llvm::to_vector<4>( 94 llvm::map_range(values, [](Value v) -> OpFoldResult { 95 APInt intValue; 96 if (v.getType().isa<IndexType>() && 97 matchPattern(v, m_ConstantInt(&intValue))) { 98 return IntegerAttr::get(v.getType(), intValue.getSExtValue()); 99 } 100 return v; 101 })); 102 } 103 104 /// Helper function to convert a vector of `OpFoldResult`s into a vector of 105 /// `Value`s. 106 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc, 107 ArrayRef<OpFoldResult> valueOrAttrVec) { 108 return llvm::to_vector<4>( 109 llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value { 110 if (auto attr = value.dyn_cast<Attribute>()) 111 return b.create<ConstantIndexOp>(loc, 112 attr.cast<IntegerAttr>().getInt()); 113 return value.get<Value>(); 114 })); 115 } 116 117 /// This is a common class used for patterns of the form 118 /// ``` 119 /// someop(memrefcast(%src)) -> someop(%src) 120 /// ``` 121 /// It folds the source of the memref.cast into the root operation directly. 122 static LogicalResult foldMemRefCast(Operation *op) { 123 bool folded = false; 124 for (OpOperand &operand : op->getOpOperands()) { 125 auto castOp = operand.get().getDefiningOp<memref::CastOp>(); 126 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { 127 operand.set(castOp.getOperand()); 128 folded = true; 129 } 130 } 131 return success(folded); 132 } 133 134 /// This is a specialization of `foldMemRefCast` used for patterns of the form 135 /// ``` 136 /// tiled_loop(memrefcast(%src)) -> tiled_loop(%src) 137 /// ``` 138 /// It folds the source of the memref.cast into the root operation directly. 139 static LogicalResult foldMemRefCastInTiledLoopOp(TiledLoopOp op) { 140 bool folded = false; 141 Location loc = op->getLoc(); 142 143 Block *body = op.getBody(); 144 OpBuilder b = OpBuilder::atBlockBegin(body); 145 146 // Update `input` and `output` operands and block arguments if necessary. 147 // Operands list: [lbs, ubs, steps, inputs, outputs]. 148 // Block args list: [ivs, inputs, outputs]. 149 for (size_t operandIndex = op.getNumControlOperands(), 150 bbArgIndex = op.getNumLoops(), e = op.getNumOperands(); 151 operandIndex < e; ++operandIndex, ++bbArgIndex) { 152 OpOperand &operand = op->getOpOperand(operandIndex); 153 154 auto castOp = operand.get().getDefiningOp<memref::CastOp>(); 155 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { 156 operand.set(castOp.getOperand()); 157 BlockArgument newBbArg = 158 body->insertArgument(bbArgIndex, castOp.getOperand().getType()); 159 BlockArgument oldBbArg = body->getArgument(newBbArg.getArgNumber() + 1); 160 161 // Insert memref.cast back to the original type. 162 oldBbArg.replaceAllUsesWith( 163 b.create<memref::CastOp>(loc, oldBbArg.getType(), newBbArg)); 164 body->eraseArgument(oldBbArg.getArgNumber()); 165 166 folded = true; 167 } 168 } 169 return success(folded); 170 } 171 172 //===----------------------------------------------------------------------===// 173 // Region builder helper. 174 // TODO: Move this to a utility library. 175 // The public methods on this class are referenced directly from generated code 176 // and bind by name to math functions in the DSL as: 177 // `applyfn__{fnName}` 178 // Examples: 179 // `applyfn__add` 180 // `applyfn__mul` 181 // The naming convention is intentional in order to match snake-cased DSL names. 182 // See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class. 183 // 184 // Implementations of the math functions must be polymorphic over numeric types, 185 // internally performing necessary casts. If the function application makes no 186 // sense, then the only recourse is to assert and return nullptr. This can be 187 // extended later if it becomes possible to fail construction of the region. The 188 // invariant should be enforced at a higher level. 189 // 190 // TODO: These helpers are currently type polymorphic over the class of integer 191 // and floating point types, but they will not internally cast within bit 192 // widths of a class (mixed precision such as i8->i32) or across classes 193 // (i.e. mixed float and integer). Many such combinations are ambiguous or need 194 // to be handled with care and work is being considered to extend the op 195 // language to make such cases explicit. In the mean-time, violating this will 196 // fail verification, which is deemed acceptable. 197 //===----------------------------------------------------------------------===// 198 199 namespace { 200 201 class RegionBuilderHelper { 202 public: 203 RegionBuilderHelper(MLIRContext *context, Block &block) 204 : context(context), block(block) {} 205 206 // Generates operations to cast the given operand to a specified type. 207 // If the cast cannot be performed, a warning will be issued and the 208 // operand returned as-is (which will presumably yield a verification 209 // issue downstream). 210 Value cast(Type toType, Value operand) { 211 OpBuilder builder = getBuilder(); 212 auto loc = operand.getLoc(); 213 214 if (operand.getType() == toType) 215 return operand; 216 if (auto toIntType = toType.dyn_cast<IntegerType>()) { 217 // If operand is floating point, cast directly to the int type. 218 if (operand.getType().isa<FloatType>()) 219 return builder.create<FPToSIOp>(loc, toType, operand); 220 // Cast index operands directly to the int type. 221 if (operand.getType().isIndex()) 222 return builder.create<IndexCastOp>(loc, toType, operand); 223 if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) { 224 // Either sign extend or truncate. 225 if (toIntType.getWidth() > fromIntType.getWidth()) 226 return builder.create<SignExtendIOp>(loc, toType, operand); 227 if (toIntType.getWidth() < fromIntType.getWidth()) 228 return builder.create<TruncateIOp>(loc, toType, operand); 229 } 230 } else if (auto toFloatType = toType.dyn_cast<FloatType>()) { 231 // If operand is integer, cast directly to the float type. 232 // Note that it is unclear how to cast from BF16<->FP16. 233 if (operand.getType().isa<IntegerType>()) 234 return builder.create<SIToFPOp>(loc, toFloatType, operand); 235 if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) { 236 if (toFloatType.getWidth() > fromFloatType.getWidth()) 237 return builder.create<FPExtOp>(loc, toFloatType, operand); 238 if (toFloatType.getWidth() < fromFloatType.getWidth()) 239 return builder.create<FPTruncOp>(loc, toFloatType, operand); 240 } 241 } 242 243 emitWarning(operand.getLoc()) << "could not cast operand of type " 244 << operand.getType() << " to " << toType; 245 return operand; 246 } 247 248 Value applyfn__add(Value lhs, Value rhs) { 249 OpBuilder builder = getBuilder(); 250 if (isFloatingPoint(lhs)) 251 return builder.create<AddFOp>(lhs.getLoc(), lhs, rhs); 252 if (isInteger(lhs)) 253 return builder.create<AddIOp>(lhs.getLoc(), lhs, rhs); 254 llvm_unreachable("unsupported non numeric type"); 255 } 256 257 Value applyfn__sub(Value lhs, Value rhs) { 258 OpBuilder builder = getBuilder(); 259 if (isFloatingPoint(lhs)) 260 return builder.create<SubFOp>(lhs.getLoc(), lhs, rhs); 261 if (isInteger(lhs)) 262 return builder.create<SubIOp>(lhs.getLoc(), lhs, rhs); 263 llvm_unreachable("unsupported non numeric type"); 264 } 265 266 Value applyfn__mul(Value lhs, Value rhs) { 267 OpBuilder builder = getBuilder(); 268 if (isFloatingPoint(lhs)) 269 return builder.create<MulFOp>(lhs.getLoc(), lhs, rhs); 270 if (isInteger(lhs)) 271 return builder.create<MulIOp>(lhs.getLoc(), lhs, rhs); 272 llvm_unreachable("unsupported non numeric type"); 273 } 274 275 void yieldOutputs(ValueRange values) { 276 assert(!values.empty() && "linalg ops must yield outputs"); 277 if (values.empty()) 278 return; 279 Value first = values.front(); 280 OpBuilder builder = getBuilder(); 281 builder.create<YieldOp>(first.getLoc(), values); 282 } 283 284 Value constant(std::string value) { 285 OpBuilder builder = getBuilder(); 286 Location loc = builder.getUnknownLoc(); 287 Attribute valueAttr = parseAttribute(value, builder.getContext()); 288 return builder.create<ConstantOp>(loc, valueAttr.getType(), valueAttr); 289 } 290 291 Value index(int64_t dim) { 292 OpBuilder builder = getBuilder(); 293 return builder.create<IndexOp>(builder.getUnknownLoc(), dim); 294 } 295 296 Type getIntegerType(unsigned width) { 297 return IntegerType::get(context, width); 298 } 299 300 Type getFloat32Type() { return Float32Type::get(context); } 301 302 Type getFloat64Type() { return Float64Type::get(context); } 303 304 private: 305 MLIRContext *context; 306 Block █ 307 308 bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); } 309 bool isInteger(Value value) { return value.getType().isa<IntegerType>(); } 310 311 OpBuilder getBuilder() { 312 OpBuilder builder(context); 313 builder.setInsertionPointToEnd(&block); 314 return builder; 315 } 316 }; 317 318 } // namespace 319 320 //===----------------------------------------------------------------------===// 321 // CopyOp 322 //===----------------------------------------------------------------------===// 323 void CopyOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) { 324 assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args"); 325 b.create<linalg::YieldOp>(block.getArgument(0)); 326 } 327 328 void CopyOp::build(OpBuilder &builder, OperationState &result, Value input, 329 Value output, AffineMap inputPermutation, 330 AffineMap outputPermutation, 331 ArrayRef<NamedAttribute> namedAttrs) { 332 result.addOperands({input, output}); 333 result.addAttributes(namedAttrs); 334 if (inputPermutation) 335 result.addAttribute("inputPermutation", 336 AffineMapAttr::get(inputPermutation)); 337 if (outputPermutation) 338 result.addAttribute("outputPermutation", 339 AffineMapAttr::get(outputPermutation)); 340 result.addRegion(); 341 fillStructuredOpRegion<CopyOp>(builder, *result.regions.front(), 342 TypeRange{input.getType()}, 343 TypeRange{output.getType()}); 344 } 345 346 ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType, 347 Type outputType) { 348 OpBuilder opBuilder(parser.getBuilder().getContext()); 349 fillStructuredOpRegion<CopyOp>(opBuilder, r, TypeRange{inputType}, 350 TypeRange{outputType}); 351 return success(); 352 } 353 354 /// CopyOp region is elided when printing. 355 void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {} 356 357 static LogicalResult verify(CopyOp op) { 358 OpOperand *output = op.getOutputOperand(0); 359 OpOperand *input = op.getInputOperand(0); 360 if (getElementTypeOrSelf(input->get()) != getElementTypeOrSelf(output->get())) 361 return op.emitOpError("expects views of the same type"); 362 if (op.getRank(input) != op.getRank(output)) 363 return op.emitOpError("expects views of the same rank"); 364 auto rank = op.getNumParallelLoops(); 365 auto inputPermutationMap = op.inputPermutation(); 366 if (inputPermutationMap) { 367 if (inputPermutationMap->getNumInputs() != rank) 368 return op.emitOpError("expects optional input_permutation map of rank ") 369 << rank; 370 if (!inputPermutationMap->isPermutation()) 371 return op.emitOpError( 372 "expects optional input_permutation map to be a permutation"); 373 } 374 auto outputPermutationMap = op.outputPermutation(); 375 if (outputPermutationMap) { 376 if (outputPermutationMap->getNumInputs() != rank) 377 return op.emitOpError("expects optional output_permutation map of rank ") 378 << rank; 379 if (!outputPermutationMap->isPermutation()) 380 return op.emitOpError( 381 "expects optional output_permutation map to be a permutation"); 382 } 383 if (rank == 0 && inputPermutationMap) 384 return op.emitOpError("expected no input permutation when rank == 0"); 385 if (rank == 0 && outputPermutationMap) 386 return op.emitOpError("expected no output permutation when rank == 0"); 387 return success(); 388 } 389 390 void CopyOp::getEffects( 391 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 392 &effects) { 393 effects.emplace_back(MemoryEffects::Read::get(), input(), 394 SideEffects::DefaultResource::get()); 395 effects.emplace_back(MemoryEffects::Write::get(), output(), 396 SideEffects::DefaultResource::get()); 397 } 398 399 //===----------------------------------------------------------------------===// 400 // FillOp 401 //===----------------------------------------------------------------------===// 402 void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) { 403 assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args"); 404 b.create<linalg::YieldOp>(block.getArgument(0)); 405 } 406 407 void FillOp::build(OpBuilder &builder, OperationState &result, Value value, 408 Value output) { 409 build(builder, result, output.getType().dyn_cast<RankedTensorType>(), value, 410 output); 411 fillStructuredOpRegion<FillOp>(builder, *result.regions.front(), 412 TypeRange{value.getType()}, 413 TypeRange{output.getType()}, {}); 414 } 415 416 ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type valueType, 417 Type outputType) { 418 OpBuilder opBuilder(parser.getBuilder().getContext()); 419 fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{valueType}, 420 TypeRange{outputType}); 421 return success(); 422 } 423 424 /// FillOp region is elided when printing. 425 void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {} 426 427 static LogicalResult verify(FillOp op) { 428 OpOperand *output = op.getOutputOperand(0); 429 Type fillType = op.value().getType(); 430 if (getElementTypeOrSelf(output->get()) != fillType) 431 return op.emitOpError("expects fill type to match view elemental type"); 432 if (!op.getNumResults() && !output->get().getType().isa<MemRefType>()) { 433 return op.emitOpError( 434 "expected fill op with no result value to use memref type"); 435 } 436 return success(); 437 } 438 439 void FillOp::getEffects( 440 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 441 &effects) { 442 if (output().getType().isa<MemRefType>()) 443 effects.emplace_back(MemoryEffects::Write::get(), output(), 444 SideEffects::DefaultResource::get()); 445 } 446 447 //===----------------------------------------------------------------------===// 448 // GenericOps 449 //===----------------------------------------------------------------------===// 450 void GenericOp::build( 451 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, 452 ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps, 453 ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall, 454 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) { 455 build(builder, result, resultTensorTypes, inputs, outputs, 456 builder.getAffineMapArrayAttr(indexingMaps), 457 builder.getStrArrayAttr(iteratorTypes), 458 doc.empty() ? StringAttr() : builder.getStringAttr(doc), 459 libraryCall.empty() ? StringAttr() 460 : builder.getStringAttr(libraryCall)); 461 if (!bodyBuild) 462 return; 463 464 SmallVector<Type, 4> blockArgTypes; 465 for (ValueRange container : {inputs, outputs}) 466 for (Value v : container) 467 blockArgTypes.push_back(getElementTypeOrSelf(v)); 468 469 OpBuilder::InsertionGuard guard(builder); 470 auto ®ion = *result.regions.front(); 471 Block *bodyBlock = builder.createBlock(®ion, region.end(), blockArgTypes); 472 bodyBuild(builder, result.location, bodyBlock->getArguments()); 473 } 474 475 void GenericOp::build( 476 OpBuilder &builder, OperationState &result, ValueRange inputs, 477 ValueRange outputs, ArrayRef<AffineMap> indexingMaps, 478 ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall, 479 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) { 480 build(builder, result, TypeRange{}, inputs, outputs, indexingMaps, 481 iteratorTypes, doc, libraryCall, bodyBuild); 482 } 483 484 void GenericOp::build( 485 OpBuilder &builder, OperationState &result, ValueRange inputs, 486 ValueRange outputs, ArrayRef<AffineMap> indexingMaps, 487 ArrayRef<StringRef> iteratorTypes, 488 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) { 489 build(builder, result, inputs, outputs, indexingMaps, iteratorTypes, 490 /*doc=*/"", 491 /*libraryCall=*/"", bodyBuild); 492 } 493 494 void GenericOp::build( 495 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, 496 ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps, 497 ArrayRef<StringRef> iteratorTypes, 498 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) { 499 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps, 500 iteratorTypes, 501 /*doc=*/"", 502 /*libraryCall=*/"", bodyBuild); 503 } 504 505 static void print(OpAsmPrinter &p, GenericOp op) { 506 p << op.getOperationName() << " "; 507 508 // Print extra attributes. 509 auto genericAttrNames = op.linalgTraitAttrNames(); 510 511 llvm::StringSet<> genericAttrNamesSet; 512 genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end()); 513 SmallVector<NamedAttribute, 8> genericAttrs; 514 for (auto attr : op->getAttrs()) 515 if (genericAttrNamesSet.count(attr.first.strref()) > 0) 516 genericAttrs.push_back(attr); 517 if (!genericAttrs.empty()) { 518 auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs); 519 p << genericDictAttr; 520 } 521 522 // Printing is shared with named ops, except for the region and attributes 523 printCommonStructuredOpParts(p, op); 524 525 genericAttrNames.push_back("operand_segment_sizes"); 526 genericAttrNamesSet.insert(genericAttrNames.back()); 527 528 bool hasExtraAttrs = false; 529 for (NamedAttribute n : op->getAttrs()) { 530 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.first.strref()))) 531 break; 532 } 533 if (hasExtraAttrs) { 534 p << " attrs = "; 535 p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/genericAttrNames); 536 } 537 538 // Print region. 539 if (!op.region().empty()) 540 p.printRegion(op.region()); 541 542 // Print results. 543 printNamedStructuredOpResults(p, op.result_tensors().getTypes()); 544 } 545 546 static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) { 547 DictionaryAttr dictAttr; 548 // Parse the core linalg traits that must check into a dictAttr. 549 // The name is unimportant as we will overwrite result.attributes. 550 // The core linalg traits must contain the information necessary to pass the 551 // verifier. 552 if (parser.parseAttribute(dictAttr, "_", result.attributes)) 553 return failure(); 554 result.attributes.assign(dictAttr.getValue().begin(), 555 dictAttr.getValue().end()); 556 557 // Parsing is shared with named ops, except for the region. 558 SmallVector<Type, 1> inputTypes, outputTypes; 559 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) 560 return failure(); 561 562 // Optional attributes may be added. 563 if (succeeded(parser.parseOptionalKeyword("attrs"))) 564 if (failed(parser.parseEqual()) || 565 failed(parser.parseOptionalAttrDict(result.attributes))) 566 return failure(); 567 568 SmallVector<OpAsmParser::OperandType, 8> regionOperands; 569 std::unique_ptr<Region> region = std::make_unique<Region>(); 570 SmallVector<Type, 8> operandTypes, regionTypes; 571 if (parser.parseRegion(*region, regionOperands, regionTypes)) 572 return failure(); 573 result.addRegion(std::move(region)); 574 575 // Generic ops may specify that a subset of its outputs are tensors. Such 576 // outputs are specified in the result type. 577 // TODO: may need to move output parsing before region parsing. 578 // Need to wait for declarative assembly resolution to decide. 579 SmallVector<Type, 1> outputTensorsTypes; 580 if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) 581 return failure(); 582 result.addTypes(outputTensorsTypes); 583 584 return success(); 585 } 586 587 static void getGenericEffectsImpl( 588 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 589 &effects, 590 ValueRange results, ValueRange inputBuffers, ValueRange outputs) { 591 for (Value value : results) { 592 effects.emplace_back(MemoryEffects::Allocate::get(), value, 593 SideEffects::DefaultResource::get()); 594 } 595 for (Value value : inputBuffers) { 596 effects.emplace_back(MemoryEffects::Read::get(), value, 597 SideEffects::DefaultResource::get()); 598 } 599 for (Value value : outputs) { 600 effects.emplace_back(MemoryEffects::Read::get(), value, 601 SideEffects::DefaultResource::get()); 602 effects.emplace_back(MemoryEffects::Write::get(), value, 603 SideEffects::DefaultResource::get()); 604 } 605 } 606 607 void GenericOp::getEffects( 608 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 609 &effects) { 610 SmallVector<Value> inputBuffers = getInputBufferOperands(); 611 SmallVector<Value> outputBuffers = getOutputBufferOperands(); 612 getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers, 613 outputBuffers); 614 } 615 616 template <typename GenericOpType> 617 static LogicalResult verifyGenericOp(GenericOpType op) { 618 return success(); 619 } 620 621 static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); } 622 623 //===----------------------------------------------------------------------===// 624 // InitTensorOp 625 //===----------------------------------------------------------------------===// 626 void InitTensorOp::build(OpBuilder &b, OperationState &result, 627 ArrayRef<OpFoldResult> sizes, Type elementType, 628 ArrayRef<NamedAttribute> attrs) { 629 unsigned rank = sizes.size(); 630 SmallVector<Value, 4> dynamicSizes; 631 SmallVector<int64_t, 4> staticSizes; 632 for (unsigned i = 0; i < rank; ++i) { 633 dispatchIndexOpFoldResult(sizes[i], dynamicSizes, staticSizes, 634 ShapedType::kDynamicSize); 635 } 636 auto resultType = RankedTensorType ::get(staticSizes, elementType); 637 build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes)); 638 result.addAttributes(attrs); 639 } 640 641 static LogicalResult verify(InitTensorOp op) { 642 RankedTensorType resultType = op.getType(); 643 SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range( 644 op.static_sizes().cast<ArrayAttr>(), 645 [](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); })); 646 647 if (failed(verifyListOfOperandsOrIntegers(op, "sizes", resultType.getRank(), 648 op.static_sizes(), op.sizes(), 649 ShapedType::isDynamic))) 650 return failure(); 651 652 if (op.static_sizes().size() != static_cast<unsigned>(resultType.getRank())) 653 return op->emitError("expected ") 654 << resultType.getRank() << " sizes values"; 655 656 Type expectedType = 657 InitTensorOp::inferResultType(staticSizes, resultType.getElementType()); 658 if (resultType != expectedType) { 659 return op.emitError("specified type ") 660 << resultType << " does not match the inferred type " 661 << expectedType; 662 } 663 return success(); 664 } 665 666 Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes, 667 Type elementType) { 668 return RankedTensorType::get(staticSizes, elementType); 669 } 670 671 namespace { 672 /// Change the type of the result of a `linalg.init_tensor` by making the result 673 /// type statically sized along dimension that in the original operation where 674 /// defined as dynamic, but the size was defined using a `constant` op. For 675 /// example 676 /// 677 /// %c5 = constant 5: index 678 /// %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32> 679 /// 680 /// to 681 /// 682 /// %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32> 683 struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> { 684 using OpRewritePattern<InitTensorOp>::OpRewritePattern; 685 686 LogicalResult matchAndRewrite(InitTensorOp op, 687 PatternRewriter &rewriter) const override { 688 SmallVector<Value, 4> dynamicSizes; 689 SmallVector<int64_t, 4> staticSizes; 690 for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) { 691 // If the size is already static, nothing to do. 692 if (!op.isDynamicSize(i)) { 693 staticSizes.push_back(op.getStaticSize(i)); 694 continue; 695 } 696 697 // If the size is dynamic but defined using a `constant` op, get the 698 // constant value to find the static size to use. 699 unsigned operandNum = op.getIndexOfDynamicSize(i); 700 Value sizeOperand = op.getOperand(operandNum); 701 if (auto constantIndexOp = sizeOperand.getDefiningOp<ConstantIndexOp>()) { 702 staticSizes.push_back(constantIndexOp.getValue()); 703 continue; 704 } 705 706 // Fallback case. Keep the size dynamic. 707 dynamicSizes.push_back(sizeOperand); 708 staticSizes.push_back(ShapedType::kDynamicSize); 709 } 710 RankedTensorType newType = 711 RankedTensorType::get(staticSizes, op.getType().getElementType()); 712 if (newType == op.getType()) 713 return failure(); 714 auto newOp = 715 rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes, 716 rewriter.getI64ArrayAttr(staticSizes)); 717 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 718 return success(); 719 } 720 }; 721 } // namespace 722 723 namespace { 724 /// Since `init_tensor` operation creates a tensor needed only for its shape, a 725 /// slice of this is also needed only for its shape. The result can be 726 /// replaced by a new init_tensor operation of the same size as the extract 727 /// slice op. 728 struct FoldInitTensorWithExtractSliceOp 729 : public OpRewritePattern<tensor::ExtractSliceOp> { 730 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern; 731 732 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, 733 PatternRewriter &rewriter) const override { 734 if (!sliceOp.source().getDefiningOp<linalg::InitTensorOp>()) 735 return failure(); 736 rewriter.replaceOpWithNewOp<linalg::InitTensorOp>( 737 sliceOp, sliceOp.sizes(), 738 llvm::to_vector<4>(llvm::map_range( 739 sliceOp.static_sizes(), 740 [](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); })), 741 sliceOp.getSourceType().getElementType()); 742 return success(); 743 } 744 }; 745 746 template <typename TensorReshapeOp> 747 struct FoldInitTensorWithTensorReshapeOp 748 : public OpRewritePattern<TensorReshapeOp> { 749 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; 750 751 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, 752 PatternRewriter &rewriter) const override { 753 if (!reshapeOp.src().template getDefiningOp<InitTensorOp>()) 754 return failure(); 755 Location loc = reshapeOp.getLoc(); 756 SmallVector<SmallVector<Value>, 4> resultShapes; 757 if (failed(reshapeOp.reifyReturnTypeShapesPerResultDim(rewriter, 758 resultShapes)) || 759 !llvm::hasSingleElement(resultShapes)) 760 return failure(); 761 Value initTensor = rewriter.create<InitTensorOp>( 762 loc, getAsOpFoldResult(resultShapes[0]), 763 reshapeOp.getResultType().getElementType()); 764 if (initTensor.getType() != reshapeOp.getResultType()) { 765 rewriter.replaceOpWithNewOp<tensor::CastOp>( 766 reshapeOp, reshapeOp.getResultType(), initTensor); 767 } else { 768 rewriter.replaceOp(reshapeOp, initTensor); 769 } 770 return success(); 771 } 772 }; 773 } // namespace 774 775 void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, 776 MLIRContext *context) { 777 results.add<FoldInitTensorWithExtractSliceOp, 778 FoldInitTensorWithTensorReshapeOp<TensorExpandShapeOp>, 779 FoldInitTensorWithTensorReshapeOp<TensorCollapseShapeOp>, 780 ReplaceStaticShapeDims>(context); 781 } 782 783 LogicalResult InitTensorOp::reifyReturnTypeShapesPerResultDim( 784 OpBuilder &builder, 785 SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) { 786 auto shapes = llvm::to_vector<4>(llvm::map_range( 787 llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value { 788 if (isDynamicSize(dim)) 789 return getDynamicSize(dim); 790 return builder.create<ConstantIndexOp>(getLoc(), getStaticSize(dim)); 791 })); 792 reifiedReturnShapes.emplace_back(std::move(shapes)); 793 return success(); 794 } 795 796 //===----------------------------------------------------------------------===// 797 // PadTensorOp 798 //===----------------------------------------------------------------------===// 799 800 static LogicalResult verify(PadTensorOp op) { 801 auto sourceType = op.source().getType().cast<RankedTensorType>(); 802 auto resultType = op.result().getType().cast<RankedTensorType>(); 803 auto expectedType = PadTensorOp::inferResultType( 804 sourceType, extractFromI64ArrayAttr(op.static_low()), 805 extractFromI64ArrayAttr(op.static_high())); 806 for (int i = 0, e = sourceType.getRank(); i < e; ++i) { 807 if (resultType.getDimSize(i) == expectedType.getDimSize(i)) 808 continue; 809 if (expectedType.isDynamicDim(i)) 810 continue; 811 return op.emitError("specified type ") 812 << resultType << " does not match the inferred type " 813 << expectedType; 814 } 815 816 auto ®ion = op.region(); 817 unsigned rank = resultType.getRank(); 818 Block &block = region.front(); 819 if (block.getNumArguments() != rank) 820 return op.emitError("expected the block to have ") << rank << " arguments"; 821 822 // Note: the number and type of yield values are checked in the YieldOp. 823 for (auto en : llvm::enumerate(block.getArgumentTypes())) { 824 if (!en.value().isIndex()) 825 return op.emitOpError("expected block argument ") 826 << (en.index() + 1) << " to be an index"; 827 } 828 829 return success(); 830 } 831 832 RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType, 833 ArrayRef<int64_t> staticLow, 834 ArrayRef<int64_t> staticHigh) { 835 unsigned rank = sourceType.getRank(); 836 assert(staticLow.size() == rank && "unexpected staticLow size mismatch"); 837 assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch"); 838 839 SmallVector<int64_t, 4> resultShape; 840 for (auto i : llvm::seq<unsigned>(0, rank)) { 841 if (sourceType.isDynamicDim(i) || 842 staticLow[i] == ShapedType::kDynamicSize || 843 staticHigh[i] == ShapedType::kDynamicSize) { 844 resultShape.push_back(ShapedType::kDynamicSize); 845 } else { 846 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i]; 847 resultShape.push_back(size); 848 } 849 } 850 851 return RankedTensorType::get(resultShape, sourceType.getElementType()); 852 } 853 854 void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source, 855 ArrayRef<int64_t> staticLow, 856 ArrayRef<int64_t> staticHigh, ValueRange low, 857 ValueRange high, ArrayRef<NamedAttribute> attrs) { 858 auto sourceType = source.getType().cast<RankedTensorType>(); 859 auto resultType = inferResultType(sourceType, staticLow, staticHigh); 860 build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow), 861 b.getI64ArrayAttr(staticHigh)); 862 result.addAttributes(attrs); 863 } 864 865 void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source, 866 ValueRange low, ValueRange high, 867 ArrayRef<NamedAttribute> attrs) { 868 auto sourceType = source.getType().cast<RankedTensorType>(); 869 unsigned rank = sourceType.getRank(); 870 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize); 871 build(b, result, source, staticVector, staticVector, low, high, attrs); 872 } 873 874 void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType, 875 Value source, ArrayRef<OpFoldResult> low, 876 ArrayRef<OpFoldResult> high, 877 ArrayRef<NamedAttribute> attrs) { 878 assert(resultType.isa<RankedTensorType>()); 879 auto sourceType = source.getType().cast<RankedTensorType>(); 880 unsigned rank = sourceType.getRank(); 881 SmallVector<Value, 4> dynamicLow, dynamicHigh; 882 SmallVector<int64_t, 4> staticLow, staticHigh; 883 for (unsigned i = 0; i < rank; ++i) { 884 // staticLow and staticHigh have full information of the padding config. 885 // This will grow staticLow and staticHigh with 1 value. If the config is 886 // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1 887 // value as well. 888 dispatchIndexOpFoldResult(low[i], dynamicLow, staticLow, 889 ShapedType::kDynamicSize); 890 dispatchIndexOpFoldResult(high[i], dynamicHigh, staticHigh, 891 ShapedType::kDynamicSize); 892 } 893 if (!resultType) { 894 resultType = 895 PadTensorOp::inferResultType(sourceType, staticLow, staticHigh); 896 } 897 build(b, result, resultType, source, dynamicLow, dynamicHigh, 898 b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh)); 899 } 900 901 PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad, 902 ArrayRef<OpFoldResult> low, 903 ArrayRef<OpFoldResult> high, 904 Location loc, OpBuilder &builder) { 905 auto padTensorOp = 906 builder.create<linalg::PadTensorOp>(loc, type, source, low, high); 907 int rank = padTensorOp.getResultType().getRank(); 908 SmallVector<Type, 4> blockArgTypes; 909 blockArgTypes.assign(rank, builder.getIndexType()); 910 auto ®ion = padTensorOp.region(); 911 // `builder.createBlock` changes the insertion point within the block. Create 912 // a guard to reset the insertion point of the builder after it is destroyed. 913 OpBuilder::InsertionGuard guard(builder); 914 builder.createBlock(®ion, region.end(), blockArgTypes); 915 builder.create<linalg::YieldOp>(loc, pad); 916 return padTensorOp; 917 } 918 919 PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad, 920 Location loc, OpBuilder &builder) { 921 SmallVector<OpFoldResult, 4> low, high; 922 auto rankedTensorType = type.cast<RankedTensorType>(); 923 assert(rankedTensorType.hasStaticShape()); 924 int rank = rankedTensorType.getRank(); 925 for (int i = 0; i < rank; ++i) { 926 auto dimOp = builder.createOrFold<memref::DimOp>(loc, source, i); 927 auto resultDimSize = builder.createOrFold<ConstantIndexOp>( 928 loc, rankedTensorType.getDimSize(i)); 929 auto highValue = builder.createOrFold<SubIOp>(loc, resultDimSize, dimOp); 930 high.push_back(highValue); 931 low.push_back(builder.createOrFold<ConstantIndexOp>(loc, 0)); 932 } 933 return PadTensorOp::createPadScalarOp(type, source, pad, low, high, loc, 934 builder); 935 } 936 937 LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim( 938 OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) { 939 Location loc = getLoc(); 940 auto lowPad = getMixedLowPad(); 941 auto highPad = getMixedHighPad(); 942 SmallVector<Value> shapes; 943 for (auto dim : llvm::seq<int64_t>(0, getSourceType().getRank())) { 944 // Shape along each dimension is source dim + low pad + high pad. 945 SmallVector<Value> mapOperands; 946 mapOperands.push_back(b.createOrFold<memref::DimOp>(loc, source(), dim)); 947 AffineExpr expr = b.getAffineDimExpr(0); 948 unsigned numSymbols = 0; 949 auto addOpFoldResult = [&](OpFoldResult valueOrAttr) { 950 if (Value v = valueOrAttr.dyn_cast<Value>()) { 951 expr = expr + b.getAffineSymbolExpr(numSymbols++); 952 mapOperands.push_back(v); 953 return; 954 } 955 int64_t staticValue = 956 valueOrAttr.get<Attribute>().cast<IntegerAttr>().getInt(); 957 expr = expr + staticValue; 958 }; 959 addOpFoldResult(lowPad[dim]); 960 addOpFoldResult(highPad[dim]); 961 shapes.push_back(applyMapToValues( 962 b, loc, AffineMap::get(1, numSymbols, expr), mapOperands)[0]); 963 } 964 reifiedReturnShapes.emplace_back(std::move(shapes)); 965 return success(); 966 } 967 968 namespace { 969 // Folds linalg.pad_tensor when padding is static zeros. 970 struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> { 971 using OpRewritePattern<PadTensorOp>::OpRewritePattern; 972 973 LogicalResult matchAndRewrite(PadTensorOp padTensorOp, 974 PatternRewriter &rewriter) const override { 975 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad()) 976 return failure(); 977 rewriter.replaceOpWithNewOp<tensor::CastOp>( 978 padTensorOp, padTensorOp.result().getType(), padTensorOp.source()); 979 return success(); 980 } 981 }; 982 983 } // namespace 984 985 void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, 986 MLIRContext *context) { 987 results.add<FoldStaticZeroPadding>(context); 988 } 989 990 /// Return the padding value of the PadTensorOp if it constant. In this context, 991 /// "constant" means an actual constant or "defined outside of the block". 992 /// 993 /// Values are considered constant in three cases: 994 /// - A ConstantLike value. 995 /// - A basic block argument from a different block. 996 /// - A value defined outside of the block. 997 /// 998 /// If the padding value is not constant, an empty Value is returned. 999 Value PadTensorOp::getConstantPaddingValue() { 1000 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator()); 1001 if (!yieldOp || yieldOp.values().size() != 1) 1002 return {}; 1003 Value padValue = yieldOp.values().front(); 1004 // Check if yield value is a constant. 1005 if (matchPattern(padValue, m_Constant())) 1006 return padValue; 1007 // Check if yield value is defined inside the PadTensorOp block. 1008 if (padValue.getParentBlock() == &getRegion().front()) 1009 return {}; 1010 // Else: Yield value defined outside of the PadTensorOp block. 1011 return padValue; 1012 } 1013 1014 OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) { 1015 if (getResultType().hasStaticShape() && getResultType() == getSourceType()) 1016 return source(); 1017 return {}; 1018 } 1019 1020 //===----------------------------------------------------------------------===// 1021 // ReshapeOp 1022 //===----------------------------------------------------------------------===// 1023 1024 Optional<SmallVector<ReassociationIndices>> 1025 mlir::linalg::getReassociationIndicesForReshape(ShapedType sourceType, 1026 ShapedType targetType) { 1027 // Make the sourceType greater rank than the targetType. If they are same 1028 // rank, then its an unsupported reshape op. 1029 if (sourceType.getRank() == targetType.getRank()) 1030 return llvm::None; 1031 if (sourceType.getRank() < targetType.getRank()) 1032 std::swap(sourceType, targetType); 1033 1034 ArrayRef<int64_t> sourceShape = sourceType.getShape(); 1035 ArrayRef<int64_t> targetShape = targetType.getShape(); 1036 unsigned sourceDim = 0; 1037 SmallVector<ReassociationIndices> reassociationMap; 1038 reassociationMap.reserve(targetType.getRank()); 1039 1040 ReassociationIndices currIndices; 1041 int64_t prodOfCollapsedDims = 1; 1042 while (sourceDim < sourceShape.size()) { 1043 unsigned targetDim = reassociationMap.size(); 1044 1045 // If all the dimensions of the targetShape are exhausted, then the 1046 // remaining dims in the source shape must be all 1s. So for such cases, set 1047 // 1 as the target shape. The actual reassociation indices will be handled 1048 // later. 1049 int64_t currTargetShape = 1050 (targetDim < targetType.getRank() ? targetShape[targetDim] : 1); 1051 while (sourceShape[sourceDim] != ShapedType::kDynamicSize && 1052 prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape && 1053 sourceDim < sourceShape.size()) { 1054 prodOfCollapsedDims *= sourceShape[sourceDim]; 1055 currIndices.push_back(sourceDim++); 1056 } 1057 1058 // If the current expanded dimension is dynamic, then the collapsed 1059 // dimensions should also be dynamic and product of all previous unprocessed 1060 // dimensions of the expanded shape should be 1. 1061 if (sourceShape[sourceDim] == ShapedType::kDynamicSize && 1062 (currTargetShape != ShapedType::kDynamicSize || 1063 prodOfCollapsedDims != 1)) 1064 return llvm::None; 1065 1066 // If the collapsed dim is dynamic, the current expanded dim should also 1067 // be dynamic. 1068 if (currTargetShape == ShapedType::kDynamicSize && 1069 sourceShape[sourceDim] != ShapedType::kDynamicSize) 1070 return llvm::None; 1071 1072 // For static shapes, if the product of dimensions of the expanded shape 1073 // should match the collapsed dimension shape. 1074 if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape) 1075 return llvm::None; 1076 1077 currIndices.push_back(sourceDim++); 1078 // If the reassociation is empty but the currIndices is not, this by 1079 // definition is folding unit-dimensions with the result being scalar type. 1080 // So only append the `currIndices` if reassociation map is not empty. 1081 if (targetDim == targetShape.size()) { 1082 if (!reassociationMap.empty() && !currIndices.empty()) 1083 reassociationMap.back().append(currIndices.begin(), currIndices.end()); 1084 // Break out of the loops. We should be done here. 1085 break; 1086 } 1087 reassociationMap.emplace_back(ReassociationIndices{}); 1088 std::swap(reassociationMap.back(), currIndices); 1089 prodOfCollapsedDims = 1; 1090 } 1091 // All the dimensions in the two shapes must have been processed. 1092 if (reassociationMap.size() != targetShape.size() || 1093 sourceDim != sourceShape.size()) 1094 return llvm::None; 1095 return reassociationMap; 1096 } 1097 1098 template <typename ReshapeLikeOp> 1099 static void print(OpAsmPrinter &p, ReshapeLikeOp op) { 1100 p << op.getOperationName() << ' ' << op.src() << " ["; 1101 1102 llvm::interleaveComma(op.reassociation(), p, [&](const Attribute &attr) { 1103 p << '['; 1104 auto arrayAttr = attr.template cast<ArrayAttr>(); 1105 llvm::interleaveComma(arrayAttr, p, [&](const Attribute &attr) { 1106 p << attr.cast<IntegerAttr>().getInt(); 1107 }); 1108 p << ']'; 1109 }); 1110 1111 p << "] "; 1112 p.printOptionalAttrDict(op->getAttrs(), 1113 /*elidedAttrs=*/{op.getReassociationAttrName()}); 1114 p << ": " << op.src().getType() << " into " << op.getType(); 1115 } 1116 1117 static void print(OpAsmPrinter &p, linalg::ExpandShapeOp op) { 1118 print<linalg::ExpandShapeOp>(p, op); 1119 } 1120 1121 static void print(OpAsmPrinter &p, linalg::CollapseShapeOp op) { 1122 print<linalg::CollapseShapeOp>(p, op); 1123 } 1124 1125 static void print(OpAsmPrinter &p, linalg::TensorExpandShapeOp op) { 1126 print<linalg::TensorExpandShapeOp>(p, op); 1127 } 1128 1129 static void print(OpAsmPrinter &p, linalg::TensorCollapseShapeOp op) { 1130 print<linalg::TensorCollapseShapeOp>(p, op); 1131 } 1132 1133 static constexpr StringRef getReassociationAttrName() { 1134 return "reassociation"; 1135 } 1136 1137 static ParseResult parseReshapeLikeOp(OpAsmParser &parser, 1138 OperationState &result) { 1139 // Parse the operand. 1140 OpAsmParser::OperandType src; 1141 if (parser.parseOperand(src)) 1142 return failure(); 1143 1144 // Parse reassociation indices. 1145 Builder &b = parser.getBuilder(); 1146 SmallVector<Attribute, 4> reassociation; 1147 if (parser.parseLSquare()) 1148 return failure(); 1149 1150 while (true) { 1151 if (succeeded(parser.parseOptionalRSquare())) 1152 break; 1153 if (parser.parseLSquare()) 1154 return failure(); 1155 SmallVector<int64_t> indices; 1156 while (true) { 1157 int64_t index; 1158 if (parser.parseInteger(index)) 1159 return failure(); 1160 indices.push_back(index); 1161 1162 if (succeeded(parser.parseOptionalComma())) 1163 continue; 1164 if (failed(parser.parseRSquare())) 1165 return failure(); 1166 break; 1167 } 1168 reassociation.push_back(b.getI64ArrayAttr(indices)); 1169 if (succeeded(parser.parseOptionalComma())) 1170 continue; 1171 if (failed(parser.parseRSquare())) 1172 return failure(); 1173 break; 1174 } 1175 1176 result.addAttribute(getReassociationAttrName(), 1177 b.getArrayAttr(reassociation)); 1178 1179 // Parse optional attributes. 1180 parser.parseOptionalAttrDict(result.attributes); 1181 1182 // Parse types. 1183 Type srcType; 1184 Type resultType; 1185 if (parser.parseColon() || parser.parseType(srcType) || 1186 parser.resolveOperand(src, srcType, result.operands) || 1187 parser.parseKeyword("into") || parser.parseType(resultType)) 1188 return failure(); 1189 result.addTypes(resultType); 1190 return success(); 1191 } 1192 1193 /// Collapse reassociation maps that are used in pair of reshape ops where one 1194 /// is a producer and other is the consumer. Only valid to use this method when 1195 /// both the producer and consumer are collapsing dimensions or both are 1196 /// expanding dimensions. 1197 /// 1198 /// For example, 1199 /// mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>, 1200 /// affine_map<(d0, d1, d2, d3, d4) -> (d2)>, 1201 /// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] 1202 /// mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>, 1203 /// affine_map<(d0, d1, d2) -> (d2)>] 1204 /// 1205 /// is folded into 1206 /// 1207 /// result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, 1208 /// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] 1209 static Optional<SmallVector<ReassociationIndices>> 1210 collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer, 1211 ArrayRef<AffineMap> mapsConsumer, 1212 MLIRContext *context) { 1213 // Make the producer the larger sized vector. If they are of same size, the 1214 // resulting reshape is not a supported reshape op. 1215 if (mapsProducer.size() == mapsConsumer.size()) 1216 return llvm::None; 1217 if (mapsProducer.size() < mapsConsumer.size()) 1218 std::swap(mapsProducer, mapsConsumer); 1219 1220 // Handle the corner case of the result being a rank 0 shaped type. Return an 1221 // empty reassociation. 1222 if (mapsConsumer.empty()) 1223 return SmallVector<ReassociationIndices>{}; 1224 if (mapsProducer.size() != mapsConsumer[0].getNumDims()) 1225 return llvm::None; 1226 1227 unsigned currDim = 0; 1228 SmallVector<ReassociationIndices> reassociationMaps; 1229 for (AffineMap rhs : mapsConsumer) { 1230 ReassociationIndices reassociations; 1231 for (AffineExpr rhsExpr : rhs.getResults()) { 1232 AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>(); 1233 for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults(); 1234 i < e; ++i) 1235 reassociations.push_back(currDim++); 1236 } 1237 reassociationMaps.push_back(std::move(reassociations)); 1238 } 1239 return reassociationMaps; 1240 } 1241 1242 namespace { 1243 /// Pattern to collapse producer/consumer reshape ops that are both collapsing 1244 /// dimensions or are both expanding dimensions. 1245 template <typename ReshapeOpTy> 1246 struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> { 1247 using OpRewritePattern<ReshapeOpTy>::OpRewritePattern; 1248 LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, 1249 PatternRewriter &rewriter) const override { 1250 auto srcReshapeOp = reshapeOp.src().template getDefiningOp<ReshapeOpTy>(); 1251 if (!srcReshapeOp) 1252 return failure(); 1253 1254 ShapedType resultType = reshapeOp.getResultType(); 1255 Optional<SmallVector<ReassociationIndices>> reassociationIndices = 1256 collapseReassociationIndices(srcReshapeOp.getReassociationMaps(), 1257 reshapeOp.getReassociationMaps(), 1258 rewriter.getContext()); 1259 if (!reassociationIndices) 1260 return failure(); 1261 rewriter.replaceOpWithNewOp<ReshapeOpTy>( 1262 reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices); 1263 return success(); 1264 } 1265 }; 1266 1267 /// Pattern to collapse producer/consumer reshape ops that are both collapsing 1268 /// dimensions or are both expanding dimensions. 1269 template <typename ReshapeOpTy, typename InverseReshapeOpTy> 1270 struct CollapseMixedReshapeOps : public OpRewritePattern<ReshapeOpTy> { 1271 using OpRewritePattern<ReshapeOpTy>::OpRewritePattern; 1272 LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, 1273 PatternRewriter &rewriter) const override { 1274 auto srcReshapeOp = 1275 reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>(); 1276 if (!srcReshapeOp) 1277 return failure(); 1278 1279 ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType(); 1280 ShapedType intermediateType = reshapeOp.getSrcType(); 1281 ShapedType resultType = reshapeOp.getResultType(); 1282 1283 // If the source reshape can be collapsed/expanded into the target reshape 1284 // they can still be folded. This can only be reasoned about statically 1285 // for cases where 1286 // - either all shapes are static, or 1287 // - The number of dynamic dimensions matches in the source of source and 1288 // result with all other dimensions being 1. 1289 Optional<SmallVector<ReassociationIndices>> reassociationIndices = 1290 getReassociationIndicesForReshape(srcReshapeSrcType, resultType); 1291 if (!reassociationIndices) 1292 return failure(); 1293 bool originalOpExpands = 1294 intermediateType.getRank() > srcReshapeSrcType.getRank(); 1295 bool resultingOpExpands = 1296 resultType.getRank() > srcReshapeSrcType.getRank(); 1297 if (!(resultingOpExpands ^ originalOpExpands)) 1298 rewriter.replaceOpWithNewOp<InverseReshapeOpTy>( 1299 reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices); 1300 else 1301 rewriter.replaceOpWithNewOp<ReshapeOpTy>( 1302 reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices); 1303 return success(); 1304 } 1305 }; 1306 } // namespace 1307 1308 template <typename ReshapeOpTy, typename InverseReshapeOpTy> 1309 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, 1310 ArrayRef<Attribute> operands) { 1311 // Fold producer-consumer reshape ops that where the operand type of the 1312 // producer is same as the return type of the consumer. 1313 auto reshapeSrcOp = 1314 reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>(); 1315 if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType()) 1316 return reshapeSrcOp.src(); 1317 // Reshape of a constant can be replaced with a new constant. 1318 if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) { 1319 return elements.reshape( 1320 reshapeOp.getResult().getType().template cast<ShapedType>()); 1321 } 1322 return nullptr; 1323 } 1324 1325 /// Return true if the reassociation specification is valid, false otherwise. 1326 /// When false, the `invalidIndex` integer pointer is optionally filled with the 1327 /// index of the offending reassociation map. 1328 static bool isReassociationValid(ArrayRef<AffineMap> reassociation, 1329 int *invalidIndex = nullptr) { 1330 if (reassociation.empty()) 1331 return true; 1332 unsigned nDims = reassociation[0].getNumDims(); 1333 unsigned nextExpectedDim = 0; 1334 for (auto it : llvm::enumerate(reassociation)) { 1335 auto m = it.value(); 1336 if (m.getNumDims() != nDims || m.getNumSymbols() != 0) { 1337 if (invalidIndex) 1338 *invalidIndex = it.index(); 1339 return false; 1340 } 1341 for (auto e : m.getResults()) { 1342 auto d = e.dyn_cast<AffineDimExpr>(); 1343 if (!d || d.getPosition() != nextExpectedDim++) { 1344 if (invalidIndex) 1345 *invalidIndex = it.index(); 1346 return false; 1347 } 1348 } 1349 } 1350 if (nextExpectedDim != nDims) { 1351 if (invalidIndex) 1352 *invalidIndex = reassociation.size() - 1; 1353 return false; 1354 } 1355 return true; 1356 } 1357 1358 /// Detect whether memref dims [dim, dim + extent) can be reshaped without 1359 /// copies. 1360 static bool isReshapableDimBand(unsigned dim, unsigned extent, 1361 ArrayRef<int64_t> sizes, 1362 ArrayRef<AffineExpr> strides) { 1363 assert(sizes.size() == strides.size() && "mismatched ranks"); 1364 // off by 1 indexing to avoid out of bounds 1365 // V 1366 for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) { 1367 // Only bands of static shapes are reshapable. This is due to the fact that 1368 // there is no relation between dynamic sizes and dynamic strides: we do not 1369 // have enough information to know whether a "-1" size corresponds to the 1370 // proper symbol in the AffineExpr of a stride. 1371 if (ShapedType::isDynamic(sizes[dim + 1])) 1372 return false; 1373 // TODO: Refine this by passing the proper nDims and nSymbols so we can 1374 // simplify on the fly and catch more reshapable cases. 1375 if (strides[idx] != strides[idx + 1] * sizes[idx + 1]) 1376 return false; 1377 } 1378 return true; 1379 } 1380 1381 /// Compute the MemRefType obtained by applying the `reassociation` (which is 1382 /// expected to be valid) to `type`. 1383 /// If `type` is Contiguous MemRefType, this always produce a contiguous 1384 /// MemRefType. 1385 static MemRefType 1386 computeReshapeCollapsedType(MemRefType type, 1387 ArrayRef<AffineMap> reassociation) { 1388 auto sizes = type.getShape(); 1389 AffineExpr offset; 1390 SmallVector<AffineExpr, 4> strides; 1391 auto status = getStridesAndOffset(type, strides, offset); 1392 (void)status; 1393 assert(succeeded(status) && "expected strided memref"); 1394 1395 SmallVector<int64_t, 4> newSizes; 1396 newSizes.reserve(reassociation.size()); 1397 SmallVector<AffineExpr, 4> newStrides; 1398 newStrides.reserve(reassociation.size()); 1399 1400 // Use the fact that reassociation is valid to simplify the logic: only use 1401 // each map's rank. 1402 assert(isReassociationValid(reassociation) && "invalid reassociation"); 1403 unsigned currentDim = 0; 1404 for (AffineMap m : reassociation) { 1405 unsigned dim = m.getNumResults(); 1406 int64_t size = 1; 1407 AffineExpr stride = strides[currentDim + dim - 1]; 1408 if (!isReshapableDimBand(currentDim, dim, sizes, strides)) { 1409 size = ShapedType::kDynamicSize; 1410 stride = AffineExpr(); 1411 } else { 1412 for (unsigned d = 0; d < dim; ++d) 1413 size *= sizes[currentDim + d]; 1414 } 1415 newSizes.push_back(size); 1416 newStrides.push_back(stride); 1417 currentDim += dim; 1418 } 1419 1420 // Early-exit: if `type` is contiguous, the result must be contiguous. 1421 if (canonicalizeStridedLayout(type).getAffineMaps().empty()) 1422 return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({}); 1423 1424 // Convert back to int64_t because we don't have enough information to create 1425 // new strided layouts from AffineExpr only. This corresponds to a case where 1426 // copies may be necessary. 1427 int64_t intOffset = ShapedType::kDynamicStrideOrOffset; 1428 if (auto o = offset.dyn_cast<AffineConstantExpr>()) 1429 intOffset = o.getValue(); 1430 SmallVector<int64_t, 4> intStrides; 1431 intStrides.reserve(strides.size()); 1432 for (auto stride : newStrides) { 1433 if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>()) 1434 intStrides.push_back(cst.getValue()); 1435 else 1436 intStrides.push_back(ShapedType::kDynamicStrideOrOffset); 1437 } 1438 auto layout = 1439 makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); 1440 return canonicalizeStridedLayout( 1441 MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout})); 1442 } 1443 1444 template <typename AffineExprTy> 1445 unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) { 1446 unsigned pos = 0; 1447 for (const auto &exprs : exprArrays) { 1448 for (auto expr : exprs) { 1449 expr.walk([&pos](AffineExpr e) { 1450 if (auto d = e.dyn_cast<AffineExprTy>()) 1451 pos = std::max(pos, d.getPosition()); 1452 }); 1453 } 1454 } 1455 return pos; 1456 } 1457 1458 static SmallVector<AffineMap, 4> 1459 getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) { 1460 unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation); 1461 assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 && 1462 "Expected symbol-less expressions"); 1463 SmallVector<AffineMap, 4> maps; 1464 maps.reserve(reassociation.size()); 1465 for (const auto &exprs : reassociation) { 1466 assert(!exprs.empty()); 1467 maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext())); 1468 } 1469 return maps; 1470 } 1471 1472 static SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices( 1473 OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) { 1474 SmallVector<ReassociationIndices, 2> reassociationIndices; 1475 for (const auto &exprs : reassociationExprs) { 1476 ReassociationIndices indices; 1477 indices.reserve(exprs.size()); 1478 for (const auto &expr : exprs) 1479 indices.push_back(expr.cast<AffineDimExpr>().getPosition()); 1480 reassociationIndices.push_back(indices); 1481 } 1482 return reassociationIndices; 1483 } 1484 1485 static SmallVector<SmallVector<AffineExpr, 2>, 2> 1486 convertReassociationIndicesToExprs( 1487 OpBuilder &b, ArrayRef<ReassociationIndices> reassociationIndices) { 1488 SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps; 1489 for (const auto &indices : reassociationIndices) { 1490 SmallVector<AffineExpr, 2> reassociationMap; 1491 reassociationMap.reserve(indices.size()); 1492 for (int64_t index : indices) 1493 reassociationMap.push_back(b.getAffineDimExpr(index)); 1494 reassociationMaps.push_back(std::move(reassociationMap)); 1495 } 1496 return reassociationMaps; 1497 } 1498 1499 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() { 1500 return getSymbolLessAffineMaps(getReassociationExprs()); 1501 } 1502 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() { 1503 OpBuilder b(this->getContext()); 1504 return convertReassociationIndicesToExprs(b, getReassociationIndices()); 1505 } 1506 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() { 1507 return getSymbolLessAffineMaps(getReassociationExprs()); 1508 } 1509 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() { 1510 OpBuilder b(this->getContext()); 1511 return convertReassociationIndicesToExprs(b, getReassociationIndices()); 1512 } 1513 1514 SmallVector<AffineMap, 4> TensorCollapseShapeOp::getReassociationMaps() { 1515 return getSymbolLessAffineMaps(getReassociationExprs()); 1516 } 1517 SmallVector<ReassociationExprs, 4> 1518 TensorCollapseShapeOp::getReassociationExprs() { 1519 OpBuilder b(this->getContext()); 1520 return convertReassociationIndicesToExprs(b, getReassociationIndices()); 1521 } 1522 SmallVector<AffineMap, 4> TensorExpandShapeOp::getReassociationMaps() { 1523 return getSymbolLessAffineMaps(getReassociationExprs()); 1524 } 1525 SmallVector<ReassociationExprs, 4> 1526 TensorExpandShapeOp::getReassociationExprs() { 1527 OpBuilder b(this->getContext()); 1528 return convertReassociationIndicesToExprs(b, getReassociationIndices()); 1529 } 1530 1531 /// For reshape op compute the shape at dimension `dimIndex` of the output in 1532 /// terms of shape of the `src`, when the reshape op is a collapsing 1533 /// operation. It is the product of the shape of the collapsed dimensions of the 1534 /// `src`. 1535 static OpFoldResult 1536 getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc, 1537 int64_t dimIndex, Value src, 1538 ArrayRef<AffineMap> reassociationMap) { 1539 AffineMap map = reassociationMap[dimIndex]; 1540 unsigned startPos = 1541 map.getResults().front().cast<AffineDimExpr>().getPosition(); 1542 unsigned endPos = map.getResults().back().cast<AffineDimExpr>().getPosition(); 1543 AffineExpr expr; 1544 SmallVector<Value, 2> dynamicDims; 1545 for (auto dim : llvm::seq(startPos, endPos + 1)) { 1546 dynamicDims.push_back(builder.createOrFold<memref::DimOp>(loc, src, dim)); 1547 AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos); 1548 expr = (expr ? expr * currExpr : currExpr); 1549 } 1550 return applyMapToValues(builder, loc, 1551 AffineMap::get(0, endPos - startPos + 1, expr), 1552 dynamicDims)[0]; 1553 } 1554 1555 /// Given the `src` of a collapsing reshape op and its reassociation maps, 1556 /// compute the shape of the result of the reshape. 1557 static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape( 1558 OpBuilder &builder, Location loc, Value src, 1559 ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) { 1560 return llvm::to_vector<4>(llvm::map_range( 1561 llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) { 1562 return getCollapsedOutputDimFromInputShape(builder, loc, dim, src, 1563 reassociation); 1564 })); 1565 } 1566 1567 /// Compute a map that for a given dimension of the expanded type gives the 1568 /// dimension in the collapsed type it maps to. Essentially its the inverse of 1569 /// the `reassocation` maps. 1570 static llvm::DenseMap<int64_t, int64_t> 1571 getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) { 1572 llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim; 1573 for (auto map : enumerate(reassociation)) { 1574 unsigned startPos = 1575 map.value().getResults().front().cast<AffineDimExpr>().getPosition(); 1576 unsigned endPos = 1577 map.value().getResults().back().cast<AffineDimExpr>().getPosition(); 1578 for (auto dim : llvm::seq(startPos, endPos + 1)) { 1579 expandedDimToCollapsedDim[dim] = map.index(); 1580 } 1581 } 1582 return expandedDimToCollapsedDim; 1583 } 1584 1585 /// For an expanding reshape op, compute the value for a dimension of the output 1586 /// from the shape of the input. 1587 static OpFoldResult getExpandedOutputDimFromInputShape( 1588 OpBuilder &builder, Location loc, int64_t dimIndex, Value src, 1589 ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation, 1590 llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) { 1591 if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) { 1592 return builder.getI64IntegerAttr(dstStaticShape[dimIndex]); 1593 } 1594 unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex]; 1595 unsigned startPos = reassociation[sourceDimPos] 1596 .getResults() 1597 .front() 1598 .cast<AffineDimExpr>() 1599 .getPosition(); 1600 unsigned endPos = reassociation[sourceDimPos] 1601 .getResults() 1602 .back() 1603 .cast<AffineDimExpr>() 1604 .getPosition(); 1605 int64_t linearizedStaticDim = 1; 1606 for (auto d : 1607 llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) { 1608 if (d.index() + startPos == static_cast<unsigned>(dimIndex)) 1609 continue; 1610 assert(!ShapedType::isDynamic(d.value()) && 1611 "single dimension cannot be expanded into multiple dynamic " 1612 "dimensions"); 1613 linearizedStaticDim *= d.value(); 1614 } 1615 Value sourceDim = builder.create<memref::DimOp>(loc, src, sourceDimPos); 1616 return applyMapToValues( 1617 builder, loc, 1618 AffineMap::get( 1619 0, 1, builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)), 1620 sourceDim)[0]; 1621 } 1622 1623 /// Given the `src` of an expanding reshape op, the reassociation maps and the 1624 /// result type, compute the shape of the result of the reshape. 1625 static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape( 1626 OpBuilder &builder, Location loc, Value src, 1627 ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) { 1628 llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim = 1629 getExpandedDimToCollapsedDimMap(reassociation); 1630 return llvm::to_vector<4>(llvm::map_range( 1631 llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) { 1632 return getExpandedOutputDimFromInputShape(builder, loc, dim, src, 1633 dstStaticShape, reassociation, 1634 expandedDimToCollapsedDim); 1635 })); 1636 } 1637 1638 static SmallVector<OpFoldResult, 4> 1639 getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src, 1640 ArrayRef<int64_t> dstStaticShape, 1641 ArrayRef<AffineMap> reassocation) { 1642 return dstStaticShape.size() > 1643 static_cast<size_t>(src.getType().cast<ShapedType>().getRank()) 1644 ? getExpandedOutputShapeFromInputShape( 1645 builder, loc, src, dstStaticShape, reassocation) 1646 : getCollapsedOutputShapeFromInputShape( 1647 builder, loc, src, dstStaticShape, reassocation); 1648 } 1649 1650 static ArrayAttr 1651 getReassociationIndicesAttribute(OpBuilder &b, 1652 ArrayRef<ReassociationIndices> reassociation) { 1653 SmallVector<Attribute, 4> reassociationAttr = 1654 llvm::to_vector<4>(llvm::map_range( 1655 reassociation, [&](ReassociationIndices indices) -> Attribute { 1656 return b.getI64ArrayAttr(indices).cast<Attribute>(); 1657 })); 1658 return b.getArrayAttr(reassociationAttr); 1659 } 1660 1661 void mlir::linalg::ExpandShapeOp::build( 1662 OpBuilder &b, OperationState &result, Value src, 1663 ArrayRef<ReassociationIndices> reassociation, 1664 ArrayRef<NamedAttribute> attrs) { 1665 auto memRefType = src.getType().cast<MemRefType>(); 1666 auto resultType = computeReshapeCollapsedType( 1667 memRefType, getSymbolLessAffineMaps( 1668 convertReassociationIndicesToExprs(b, reassociation))); 1669 build(b, result, resultType, src, attrs); 1670 result.addAttribute(getReassociationAttrName(), 1671 getReassociationIndicesAttribute(b, reassociation)); 1672 } 1673 1674 Value mlir::linalg::ExpandShapeOp::getViewSource() { return src(); } 1675 1676 void mlir::linalg::CollapseShapeOp::build( 1677 OpBuilder &b, OperationState &result, Value src, 1678 ArrayRef<ReassociationIndices> reassociation, 1679 ArrayRef<NamedAttribute> attrs) { 1680 auto memRefType = src.getType().cast<MemRefType>(); 1681 auto resultType = computeReshapeCollapsedType( 1682 memRefType, getSymbolLessAffineMaps( 1683 convertReassociationIndicesToExprs(b, reassociation))); 1684 build(b, result, resultType, src, attrs); 1685 result.addAttribute(getReassociationAttrName(), 1686 getReassociationIndicesAttribute(b, reassociation)); 1687 } 1688 1689 Value mlir::linalg::CollapseShapeOp::getViewSource() { return src(); } 1690 1691 /// Verify that shapes of the reshaped types using following rules 1692 /// 1) if a dimension in the collapsed type is static, then the corresponding 1693 /// dimensions in the expanded shape should be 1694 /// a) static 1695 /// b) the product should be same as the collaped shape. 1696 /// 2) if a dimension in the collaped type is dynamic, one and only one of the 1697 /// corresponding dimensions in the expanded type should be dynamic. This 1698 /// rule is only needed with reshape operations that are expanding. 1699 template <typename OpTy> 1700 static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType, 1701 ShapedType expandedType, 1702 bool isExpandingReshape) { 1703 ArrayRef<int64_t> collapsedShape = collapsedType.getShape(); 1704 ArrayRef<int64_t> expandedShape = expandedType.getShape(); 1705 unsigned expandedDimStart = 0; 1706 for (auto map : llvm::enumerate(op.getReassociationMaps())) { 1707 Optional<int64_t> dynamicShape; 1708 int64_t linearizedStaticShape = 1; 1709 for (auto dim : llvm::enumerate(expandedShape.slice( 1710 expandedDimStart, map.value().getNumResults()))) { 1711 if (ShapedType::isDynamic(dim.value())) { 1712 if (isExpandingReshape && dynamicShape) { 1713 return op->emitOpError("invalid to have a single dimension (") 1714 << map.index() << ") expanded into multiple dynamic dims (" 1715 << expandedDimStart + dynamicShape.getValue() << "," 1716 << expandedDimStart + dim.index() << ")"; 1717 } 1718 dynamicShape = dim.index(); 1719 } else { 1720 linearizedStaticShape *= dim.value(); 1721 } 1722 } 1723 if (dynamicShape) { 1724 if (!ShapedType::isDynamic(collapsedShape[map.index()])) { 1725 return op->emitOpError("expected dimension ") 1726 << map.index() 1727 << " of collapsed type to be dynamic since one or more of the " 1728 "corresponding dimensions in the expanded type is dynamic"; 1729 } 1730 } else { 1731 if (collapsedShape[map.index()] != linearizedStaticShape) { 1732 return op->emitOpError("expected dimension ") 1733 << map.index() << " of collapsed type to be static value of " 1734 << linearizedStaticShape << " "; 1735 } 1736 } 1737 expandedDimStart += map.value().getNumResults(); 1738 } 1739 return success(); 1740 } 1741 1742 // Common verifier for reshape-like types. Fills `expandedType` and 1743 // `collapsedType` with the proper `src` or `result` type. 1744 template <typename Op, typename T, 1745 bool isExpansion = std::is_same<Op, TensorExpandShapeOp>::value || 1746 std::is_same<Op, ExpandShapeOp>::value> 1747 static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, 1748 T collapsedType) { 1749 unsigned expandedRank = expandedType.getRank(); 1750 unsigned collapsedRank = collapsedType.getRank(); 1751 if (expandedRank < collapsedRank) 1752 return op.emitOpError("expected the type ") 1753 << expandedType 1754 << " to have higher rank than the type = " << collapsedType; 1755 if (expandedRank == 0) 1756 return op.emitOpError("expected non-zero memref ranks"); 1757 if (expandedRank == collapsedRank) 1758 return op.emitOpError("expected to collapse or expand dims"); 1759 1760 if (collapsedRank == 0) { 1761 // If collapsed rank is 0, then expanded type must be static shaped and of 1762 // sizes 1. 1763 if (llvm::any_of(expandedType.getShape(), 1764 [](int64_t dim) -> bool { return dim != 1; })) 1765 return op.emitOpError("invalid to reshape tensor/memref with non-unit " 1766 "extent dimensions to zero-rank tensor/memref"); 1767 return success(); 1768 } 1769 if (collapsedRank != op.reassociation().size()) 1770 return op.emitOpError("expected rank of the collapsed type(") 1771 << collapsedRank << ") to be the number of reassociation maps(" 1772 << op.reassociation().size() << ")"; 1773 auto maps = op.getReassociationMaps(); 1774 for (auto it : llvm::enumerate(maps)) 1775 if (it.value().getNumDims() != expandedRank) 1776 return op.emitOpError("expected reassociation map #") 1777 << it.index() << " of same rank as expanded memref(" 1778 << expandedRank << "), but got " << it.value().getNumDims(); 1779 int invalidIdx = 0; 1780 if (!isReassociationValid(maps, &invalidIdx)) 1781 return op.emitOpError("expected reassociation map #") 1782 << invalidIdx << " to be valid and contiguous"; 1783 return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion); 1784 } 1785 1786 template <typename TensorReshapeOp> 1787 static LogicalResult verifyReshapeOp(TensorReshapeOp op, 1788 MemRefType expandedType, 1789 MemRefType collapsedType) { 1790 if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType))) 1791 return failure(); 1792 auto maps = op.getReassociationMaps(); 1793 MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); 1794 if (collapsedType != expectedType) 1795 return op.emitOpError("expected collapsed type to be ") 1796 << expectedType << ", but got " << collapsedType; 1797 return success(); 1798 } 1799 1800 static LogicalResult verify(ExpandShapeOp op) { 1801 return verifyReshapeOp(op, op.getResultType(), op.getSrcType()); 1802 } 1803 1804 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1805 MLIRContext *context) { 1806 results.add<CollapseReshapeOps<ExpandShapeOp>, 1807 CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context); 1808 } 1809 1810 static LogicalResult verify(CollapseShapeOp op) { 1811 return verifyReshapeOp(op, op.getSrcType(), op.getResultType()); 1812 } 1813 1814 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1815 MLIRContext *context) { 1816 results.add<CollapseReshapeOps<CollapseShapeOp>, 1817 CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>>(context); 1818 } 1819 1820 //===----------------------------------------------------------------------===// 1821 // TensorReshapeOp 1822 //===----------------------------------------------------------------------===// 1823 1824 /// Compute the RankedTensorType obtained by applying `reassociation` to `type`. 1825 static RankedTensorType 1826 computeTensorReshapeCollapsedType(RankedTensorType type, 1827 ArrayRef<AffineMap> reassociation) { 1828 auto shape = type.getShape(); 1829 SmallVector<int64_t, 4> newShape; 1830 newShape.reserve(reassociation.size()); 1831 1832 // Use the fact that reassociation is valid to simplify the logic: only use 1833 // each map's rank. 1834 assert(isReassociationValid(reassociation) && "invalid reassociation"); 1835 unsigned currentDim = 0; 1836 for (AffineMap m : reassociation) { 1837 unsigned dim = m.getNumResults(); 1838 auto band = shape.slice(currentDim, dim); 1839 int64_t size = 1; 1840 if (llvm::is_contained(band, ShapedType::kDynamicSize)) 1841 size = ShapedType::kDynamicSize; 1842 else 1843 for (unsigned d = 0; d < dim; ++d) 1844 size *= shape[currentDim + d]; 1845 newShape.push_back(size); 1846 currentDim += dim; 1847 } 1848 1849 return RankedTensorType::get(newShape, type.getElementType()); 1850 } 1851 1852 void mlir::linalg::TensorCollapseShapeOp::build( 1853 OpBuilder &b, OperationState &result, Value src, 1854 ArrayRef<ReassociationIndices> reassociation, 1855 ArrayRef<NamedAttribute> attrs) { 1856 auto resultType = computeTensorReshapeCollapsedType( 1857 src.getType().cast<RankedTensorType>(), 1858 getSymbolLessAffineMaps( 1859 convertReassociationIndicesToExprs(b, reassociation))); 1860 build(b, result, resultType, src, attrs); 1861 result.addAttribute(getReassociationAttrName(), 1862 getReassociationIndicesAttribute(b, reassociation)); 1863 } 1864 1865 void mlir::linalg::TensorExpandShapeOp::build( 1866 OpBuilder &b, OperationState &result, Value src, 1867 ArrayRef<ReassociationIndices> reassociation, 1868 ArrayRef<NamedAttribute> attrs) { 1869 auto resultType = computeTensorReshapeCollapsedType( 1870 src.getType().cast<RankedTensorType>(), 1871 getSymbolLessAffineMaps( 1872 convertReassociationIndicesToExprs(b, reassociation))); 1873 build(b, result, resultType, src, attrs); 1874 result.addAttribute(getReassociationAttrName(), 1875 getReassociationIndicesAttribute(b, reassociation)); 1876 } 1877 1878 template <typename TensorReshapeOp> 1879 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, 1880 RankedTensorType expandedType, 1881 RankedTensorType collapsedType) { 1882 if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType))) 1883 return failure(); 1884 1885 auto maps = op.getReassociationMaps(); 1886 RankedTensorType expectedType = 1887 computeTensorReshapeCollapsedType(expandedType, maps); 1888 if (collapsedType != expectedType) 1889 return op.emitOpError("expected collapsed type to be ") 1890 << expectedType << ", but got " << collapsedType; 1891 return success(); 1892 } 1893 1894 static LogicalResult verify(TensorExpandShapeOp op) { 1895 return verifyTensorReshapeOp(op, op.getResultType(), op.getSrcType()); 1896 } 1897 1898 static LogicalResult verify(TensorCollapseShapeOp op) { 1899 return verifyTensorReshapeOp(op, op.getSrcType(), op.getResultType()); 1900 } 1901 1902 namespace { 1903 /// Reshape of a splat constant can be replaced with a constant of the result 1904 /// type. 1905 template <typename TensorReshapeOp> 1906 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> { 1907 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; 1908 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, 1909 PatternRewriter &rewriter) const override { 1910 DenseElementsAttr attr; 1911 if (!matchPattern(reshapeOp.src(), m_Constant(&attr))) 1912 return failure(); 1913 if (!attr || !attr.isSplat()) 1914 return failure(); 1915 DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer( 1916 reshapeOp.getResultType(), attr.getRawData(), true); 1917 rewriter.replaceOpWithNewOp<ConstantOp>(reshapeOp, newAttr); 1918 return success(); 1919 } 1920 }; 1921 1922 /// Fold linalg.fill -> linalg.tensor_reshape chain. 1923 /// 1924 /// For such op chains, we can create new linalg.fill ops with the result 1925 /// type of the linalg.tensor_reshape op. 1926 template <typename TensorReshapeOp> 1927 struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> { 1928 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; 1929 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, 1930 PatternRewriter &rewriter) const override { 1931 auto oldFill = reshapeOp.src().template getDefiningOp<FillOp>(); 1932 if (!oldFill) 1933 return failure(); 1934 1935 Location loc = oldFill.getLoc(); 1936 auto newInit = rewriter.create<TensorReshapeOp>( 1937 loc, reshapeOp.getResultType(), oldFill.output(), 1938 reshapeOp.reassociation()); 1939 rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, oldFill.value(), newInit); 1940 1941 return success(); 1942 } 1943 }; 1944 } // namespace 1945 1946 void TensorExpandShapeOp::getCanonicalizationPatterns( 1947 RewritePatternSet &results, MLIRContext *context) { 1948 results 1949 .add<CollapseReshapeOps<TensorExpandShapeOp>, 1950 CollapseMixedReshapeOps<TensorExpandShapeOp, TensorCollapseShapeOp>, 1951 FoldFillWithTensorReshape<TensorExpandShapeOp>, 1952 FoldInitTensorWithTensorReshapeOp<TensorExpandShapeOp>, 1953 FoldReshapeWithConstant<TensorExpandShapeOp>>(context); 1954 } 1955 1956 void TensorCollapseShapeOp::getCanonicalizationPatterns( 1957 RewritePatternSet &results, MLIRContext *context) { 1958 results 1959 .add<CollapseReshapeOps<TensorCollapseShapeOp>, 1960 CollapseMixedReshapeOps<TensorCollapseShapeOp, TensorExpandShapeOp>, 1961 FoldFillWithTensorReshape<TensorCollapseShapeOp>, 1962 FoldInitTensorWithTensorReshapeOp<TensorCollapseShapeOp>, 1963 FoldReshapeWithConstant<TensorCollapseShapeOp>>(context); 1964 } 1965 1966 LogicalResult TensorExpandShapeOp::reifyReturnTypeShapesPerResultDim( 1967 OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) { 1968 auto resultShape = 1969 getAsValues(b, getLoc(), 1970 getReshapeOutputShapeFromInputShape( 1971 b, getLoc(), src(), getResultType().getShape(), 1972 getReassociationMaps())); 1973 reifiedReturnShapes.emplace_back(std::move(resultShape)); 1974 return success(); 1975 } 1976 1977 LogicalResult TensorCollapseShapeOp::reifyReturnTypeShapesPerResultDim( 1978 OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) { 1979 auto resultShape = 1980 getAsValues(b, getLoc(), 1981 getReshapeOutputShapeFromInputShape( 1982 b, getLoc(), src(), getResultType().getShape(), 1983 getReassociationMaps())); 1984 reifiedReturnShapes.emplace_back(std::move(resultShape)); 1985 return success(); 1986 } 1987 1988 //===----------------------------------------------------------------------===// 1989 // YieldOp 1990 //===----------------------------------------------------------------------===// 1991 1992 static void print(OpAsmPrinter &p, linalg::YieldOp op) { 1993 p << op.getOperationName(); 1994 if (op.getNumOperands() > 0) 1995 p << ' ' << op.getOperands(); 1996 p.printOptionalAttrDict(op->getAttrs()); 1997 if (op.getNumOperands() > 0) 1998 p << " : " << op.getOperandTypes(); 1999 } 2000 2001 static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) { 2002 SmallVector<OpAsmParser::OperandType, 2> opInfo; 2003 SmallVector<Type, 2> types; 2004 llvm::SMLoc loc = parser.getCurrentLocation(); 2005 return failure(parser.parseOperandList(opInfo) || 2006 parser.parseOptionalAttrDict(result.attributes) || 2007 (!opInfo.empty() && parser.parseColonTypeList(types)) || 2008 parser.resolveOperands(opInfo, types, loc, result.operands)); 2009 } 2010 2011 // Check the operand number and types must match the element types of the 2012 // LinalgOp interface's shaped operands. 2013 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) { 2014 if (op.getNumOperands() != linalgOp.getNumOutputs()) 2015 return op.emitOpError("expected number of yield values (") 2016 << linalgOp.getNumOutputs() 2017 << ") to match the number of operands of the enclosing " 2018 << "LinalgOp (" << op.getNumOperands() << ")"; 2019 2020 for (OpOperand &opOperand : op->getOpOperands()) { 2021 OpOperand *outputOperand = 2022 linalgOp.getOutputOperand(opOperand.getOperandNumber()); 2023 Type elementType = getElementTypeOrSelf(outputOperand->get().getType()); 2024 if (opOperand.get().getType() != elementType) 2025 return op.emitOpError("type of yield operand ") 2026 << (opOperand.getOperandNumber() + 1) << " (" 2027 << opOperand.get().getType() << ") doesn't match " 2028 << "the element type of the enclosing linalg.generic op (" 2029 << elementType << ")"; 2030 } 2031 return success(); 2032 } 2033 2034 static LogicalResult verify(linalg::YieldOp op) { 2035 auto *parentOp = op->getParentOp(); 2036 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty()) 2037 return op.emitOpError("expected single non-empty parent region"); 2038 2039 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp)) 2040 return verifyYield(op, cast<LinalgOp>(parentOp)); 2041 2042 if (auto padTensorOp = dyn_cast<linalg::PadTensorOp>(parentOp)) { 2043 if (op.getNumOperands() != 1) 2044 return op.emitOpError("expected single yield operand (got ") 2045 << op->getNumOperands() << ")"; 2046 if (op.getOperand(0).getType() != 2047 padTensorOp.getType().cast<ShapedType>().getElementType()) 2048 return op.emitOpError("expected yield type to match shape element type"); 2049 return success(); 2050 } 2051 2052 if (auto tiledLoopOp = dyn_cast<linalg::TiledLoopOp>(parentOp)) { 2053 // Check if output args with tensor types match results types. 2054 SmallVector<Value, 2> tensorOuts; 2055 llvm::copy_if( 2056 tiledLoopOp.outputs(), std::back_inserter(tensorOuts), 2057 [&](Value out) { return out.getType().isa<RankedTensorType>(); }); 2058 if (tensorOuts.size() != op.values().size()) 2059 return op.emitOpError("expected number of tensor output args = ") 2060 << tensorOuts.size() << " to match the number of yield operands = " 2061 << op.values().size(); 2062 2063 TypeRange tensorTypes(llvm::makeArrayRef(tensorOuts)); 2064 for (auto &item : 2065 llvm::enumerate(llvm::zip(tensorTypes, op.getOperandTypes()))) { 2066 Type outType, resultType; 2067 unsigned index = item.index(); 2068 std::tie(outType, resultType) = item.value(); 2069 if (outType != resultType) 2070 return op.emitOpError("expected yield operand ") 2071 << index << " with type = " << resultType 2072 << " to match output arg type = " << outType; 2073 } 2074 return success(); 2075 } 2076 return op.emitOpError("expected parent op with LinalgOp interface"); 2077 } 2078 2079 //===----------------------------------------------------------------------===// 2080 // TiledLoopOp 2081 //===----------------------------------------------------------------------===// 2082 2083 void TiledLoopOp::build(OpBuilder &builder, OperationState &result, 2084 ValueRange lowerBounds, ValueRange upperBounds, 2085 ValueRange steps, ValueRange inputs, ValueRange outputs, 2086 ArrayAttr iteratorTypes, 2087 function_ref<void(OpBuilder &, Location, ValueRange, 2088 ValueRange, ValueRange)> 2089 bodyBuilderFn) { 2090 build(builder, result, lowerBounds, upperBounds, steps, inputs, outputs, 2091 iteratorTypes, llvm::None, bodyBuilderFn); 2092 } 2093 2094 void TiledLoopOp::build(OpBuilder &builder, OperationState &result, 2095 ValueRange lowerBounds, ValueRange upperBounds, 2096 ValueRange steps, ValueRange inputs, ValueRange outputs, 2097 ArrayAttr iteratorTypes, 2098 Optional<ArrayAttr> distributionTypes, 2099 function_ref<void(OpBuilder &, Location, ValueRange, 2100 ValueRange, ValueRange)> 2101 bodyBuilderFn) { 2102 result.addOperands(lowerBounds); 2103 result.addOperands(upperBounds); 2104 result.addOperands(steps); 2105 result.addOperands(inputs); 2106 result.addOperands(outputs); 2107 result.addAttribute( 2108 TiledLoopOp::getOperandSegmentSizeAttr(), 2109 builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()), 2110 static_cast<int32_t>(upperBounds.size()), 2111 static_cast<int32_t>(steps.size()), 2112 static_cast<int32_t>(inputs.size()), 2113 static_cast<int32_t>(outputs.size())})); 2114 result.addAttribute(getIteratorTypesAttrName(), iteratorTypes); 2115 2116 if (distributionTypes.hasValue()) 2117 result.addAttribute(getDistributionTypesAttrName(), 2118 distributionTypes.getValue()); 2119 2120 // Add output types for `RankedTensorType` output arguments. 2121 for (Value output : outputs) { 2122 Type outputType = output.getType(); 2123 if (outputType.isa<RankedTensorType>()) 2124 result.addTypes(outputType); 2125 } 2126 2127 OpBuilder::InsertionGuard guard(builder); 2128 unsigned numIVs = steps.size(); 2129 SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType()); 2130 for (Type type : TypeRange(inputs)) 2131 argTypes.push_back(type); 2132 for (Type type : TypeRange(outputs)) 2133 argTypes.push_back(type); 2134 Region *bodyRegion = result.addRegion(); 2135 Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes); 2136 2137 if (bodyBuilderFn) { 2138 builder.setInsertionPointToStart(bodyBlock); 2139 bodyBuilderFn(builder, result.location, 2140 bodyBlock->getArguments().take_front(numIVs), 2141 bodyBlock->getArguments().slice(numIVs, inputs.size()), 2142 bodyBlock->getArguments().take_back(outputs.size())); 2143 TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location); 2144 } 2145 } 2146 2147 static void print(OpAsmPrinter &p, TiledLoopOp op) { 2148 p << op.getOperationName() << " (" << op.getInductionVars() << ") = (" 2149 << op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step() 2150 << ")"; 2151 2152 if (!op.inputs().empty()) { 2153 p << " ins ("; 2154 llvm::interleaveComma(llvm::zip(op.getRegionInputArgs(), op.inputs()), p, 2155 [&](auto it) { 2156 p << std::get<0>(it) << " = " << std::get<1>(it) 2157 << ": " << std::get<1>(it).getType(); 2158 }); 2159 p << ")"; 2160 } 2161 if (!op.outputs().empty()) { 2162 p << " outs ("; 2163 llvm::interleaveComma(llvm::zip(op.getRegionOutputArgs(), op.outputs()), p, 2164 [&](auto it) { 2165 p << std::get<0>(it) << " = " << std::get<1>(it) 2166 << ": " << std::get<1>(it).getType(); 2167 }); 2168 p << ")"; 2169 } 2170 2171 if (llvm::any_of(op.iterator_types(), [](Attribute attr) { 2172 return attr.cast<StringAttr>().getValue() != 2173 getParallelIteratorTypeName(); 2174 })) 2175 p << " iterators" << op.iterator_types() << ""; 2176 2177 if (op.distribution_types().hasValue()) 2178 p << " distribution" << op.distribution_types().getValue() << ""; 2179 2180 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 2181 p.printOptionalAttrDict( 2182 op->getAttrs(), /*elidedAttrs=*/{TiledLoopOp::getOperandSegmentSizeAttr(), 2183 getIteratorTypesAttrName(), 2184 getDistributionTypesAttrName()}); 2185 } 2186 2187 static ParseResult parseTiledLoopOp(OpAsmParser &parser, 2188 OperationState &result) { 2189 auto &builder = parser.getBuilder(); 2190 // Parse an opening `(` followed by induction variables followed by `)` 2191 SmallVector<OpAsmParser::OperandType, 4> ivs; 2192 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 2193 OpAsmParser::Delimiter::Paren)) 2194 return failure(); 2195 2196 // Parse loop bounds. 2197 SmallVector<OpAsmParser::OperandType, 4> lower; 2198 if (parser.parseEqual() || 2199 parser.parseOperandList(lower, ivs.size(), 2200 OpAsmParser::Delimiter::Paren) || 2201 parser.resolveOperands(lower, builder.getIndexType(), result.operands)) 2202 return failure(); 2203 2204 SmallVector<OpAsmParser::OperandType, 4> upper; 2205 if (parser.parseKeyword("to") || 2206 parser.parseOperandList(upper, ivs.size(), 2207 OpAsmParser::Delimiter::Paren) || 2208 parser.resolveOperands(upper, builder.getIndexType(), result.operands)) 2209 return failure(); 2210 2211 // Parse step values. 2212 SmallVector<OpAsmParser::OperandType, 4> steps; 2213 if (parser.parseKeyword("step") || 2214 parser.parseOperandList(steps, ivs.size(), 2215 OpAsmParser::Delimiter::Paren) || 2216 parser.resolveOperands(steps, builder.getIndexType(), result.operands)) 2217 return failure(); 2218 2219 // Parse input tensors. 2220 SmallVector<OpAsmParser::OperandType, 4> inputs, input_region_args; 2221 SmallVector<Type, 4> inputTypes; 2222 if (succeeded(parser.parseOptionalKeyword("ins"))) { 2223 llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation(); 2224 2225 if (parser.parseAssignmentListWithTypes(input_region_args, inputs, 2226 inputTypes)) 2227 return failure(); 2228 2229 if (parser.resolveOperands(inputs, inputTypes, inputsOperandsLoc, 2230 result.operands)) 2231 return failure(); 2232 } 2233 2234 // Parse output tensors. 2235 SmallVector<OpAsmParser::OperandType, 4> outputs, output_region_args; 2236 SmallVector<Type, 4> outputTypes; 2237 if (succeeded(parser.parseOptionalKeyword("outs"))) { 2238 llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation(); 2239 2240 if (parser.parseAssignmentListWithTypes(output_region_args, outputs, 2241 outputTypes)) 2242 return failure(); 2243 2244 if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc, 2245 result.operands)) 2246 return failure(); 2247 for (Type outputType : outputTypes) 2248 if (outputType.isa<RankedTensorType>()) 2249 result.addTypes(outputType); 2250 } 2251 2252 // Parse attributes. 2253 SmallVector<Attribute, 4> iterTypes, distributionTypes; 2254 auto parseAttr = [&](StringRef keyword, SmallVector<Attribute, 4> *attrs) { 2255 if (succeeded(parser.parseOptionalKeyword(keyword))) { 2256 StringAttr attr; 2257 2258 if (parser.parseLSquare() || parser.parseAttribute(attr)) 2259 return failure(); 2260 attrs->push_back(attr); 2261 for (int i = 1, e = ivs.size(); i < e; ++i) { 2262 if (parser.parseComma() || parser.parseAttribute(attr)) 2263 return failure(); 2264 attrs->push_back(attr); 2265 } 2266 if (parser.parseRSquare()) 2267 return failure(); 2268 } 2269 return success(); 2270 }; 2271 if (failed(parseAttr("iterators", &iterTypes)) || 2272 failed(parseAttr("distribution", &distributionTypes))) 2273 return failure(); 2274 2275 // Set all loop iterator types to "parallel" if they are not printed in IR. 2276 if (iterTypes.empty()) { 2277 auto parallelIter = builder.getStringAttr(getParallelIteratorTypeName()); 2278 iterTypes = SmallVector<Attribute, 4>(ivs.size(), parallelIter); 2279 } 2280 result.addAttribute(getIteratorTypesAttrName(), 2281 builder.getArrayAttr(iterTypes)); 2282 if (!distributionTypes.empty()) 2283 result.addAttribute(getDistributionTypesAttrName(), 2284 builder.getArrayAttr(distributionTypes)); 2285 result.addAttribute( 2286 TiledLoopOp::getOperandSegmentSizeAttr(), 2287 builder.getI32VectorAttr({static_cast<int32_t>(lower.size()), 2288 static_cast<int32_t>(upper.size()), 2289 static_cast<int32_t>(steps.size()), 2290 static_cast<int32_t>(inputs.size()), 2291 static_cast<int32_t>(outputs.size())})); 2292 2293 // Parse the body. 2294 Region *body = result.addRegion(); 2295 2296 SmallVector<Type, 4> region_types(ivs.size(), builder.getIndexType()); 2297 region_types.append(inputTypes); 2298 region_types.append(outputTypes); 2299 2300 SmallVector<OpAsmParser::OperandType, 4> region_args(ivs); 2301 region_args.append(input_region_args); 2302 region_args.append(output_region_args); 2303 2304 if (parser.parseRegion(*body, region_args, region_types)) 2305 return failure(); 2306 2307 // Parse optional attributes. 2308 parser.parseOptionalAttrDict(result.attributes); 2309 2310 return success(); 2311 } 2312 2313 Region &TiledLoopOp::getLoopBody() { return region(); } 2314 2315 LogicalResult TiledLoopOp::moveOutOfLoop(ArrayRef<Operation *> ops) { 2316 for (auto *op : ops) 2317 op->moveBefore(*this); 2318 return success(); 2319 } 2320 2321 bool TiledLoopOp::isDefinedOutsideOfLoop(Value value) { 2322 return !region().isAncestor(value.getParentRegion()); 2323 } 2324 2325 static LogicalResult verify(TiledLoopOp op) { 2326 // Check if iterator types are provided for every loop dimension. 2327 if (op.iterator_types().size() != op.getNumLoops()) 2328 return op.emitOpError("expected iterator types array attribute size = ") 2329 << op.iterator_types().size() 2330 << " to match the number of loops = " << op.getNumLoops(); 2331 2332 // Check if types of input arguments match region args types. 2333 for (auto &item : 2334 llvm::enumerate(llvm::zip(op.inputs(), op.getRegionInputArgs()))) { 2335 Value input, inputRegionArg; 2336 unsigned index = item.index(); 2337 std::tie(input, inputRegionArg) = item.value(); 2338 if (input.getType() != inputRegionArg.getType()) 2339 return op.emitOpError("expected input arg ") 2340 << index << " with type = " << input.getType() 2341 << " to match region arg " << index + op.getNumLoops() 2342 << " type = " << inputRegionArg.getType(); 2343 } 2344 2345 // Check if types of input arguments match region args types. 2346 for (auto &item : 2347 llvm::enumerate(llvm::zip(op.outputs(), op.getRegionOutputArgs()))) { 2348 Value output, outputRegionArg; 2349 unsigned index = item.index(); 2350 std::tie(output, outputRegionArg) = item.value(); 2351 if (output.getType() != outputRegionArg.getType()) 2352 return op.emitOpError("expected output arg ") 2353 << index << " with type = " << output.getType() 2354 << " to match region arg " 2355 << index + op.getNumLoops() + op.inputs().size() 2356 << " type = " << outputRegionArg.getType(); 2357 } 2358 return success(); 2359 } 2360 2361 namespace { 2362 2363 static constexpr int64_t kNoMatch = -1; 2364 2365 // Folds away TiledLoopOp inputs if they have no uses within the body. 2366 // 2367 // Example: 2368 // 2369 // %0 = linalg.tiled_loop ... ins (%in_ = %in: tensor<...>, 2370 // %in_buf_ = %in_buf: memref<...>) {...} 2371 // Becomes 2372 // 2373 // linalg.tiled_loop ... ins (%in_buf_ = %in_buf: memref<...>) {...} 2374 struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> { 2375 using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern; 2376 2377 LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop, 2378 PatternRewriter &rewriter) const final { 2379 SmallVector<Value, 2> newInputs, regionInputTensorArgs; 2380 // Store ids of the corresponding old and new input operands. 2381 SmallVector<int64_t, 2> oldInputIdToNew(tiledLoop.inputs().size(), 2382 kNoMatch); 2383 for (auto en : llvm::enumerate( 2384 llvm::zip(tiledLoop.inputs(), tiledLoop.getRegionInputArgs()))) { 2385 Value in, bbArg; 2386 size_t index = en.index(); 2387 std::tie(in, bbArg) = en.value(); 2388 if (!bbArg.use_empty()) { 2389 oldInputIdToNew[index] = newInputs.size(); 2390 newInputs.push_back(in); 2391 } 2392 } 2393 if (newInputs.size() == tiledLoop.inputs().size()) 2394 return failure(); 2395 Location loc = tiledLoop.getLoc(); 2396 auto newTiledLoop = rewriter.create<TiledLoopOp>( 2397 loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(), 2398 newInputs, tiledLoop.outputs(), tiledLoop.iterator_types(), 2399 tiledLoop.distribution_types()); 2400 2401 // Clone the region. 2402 BlockAndValueMapping bvm; 2403 bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars()); 2404 bvm.map(tiledLoop.getRegionOutputArgs(), 2405 newTiledLoop.getRegionOutputArgs()); 2406 for (const auto &en : llvm::enumerate(oldInputIdToNew)) 2407 if (en.value() != kNoMatch) 2408 bvm.map(tiledLoop.getRegionInputArgs()[en.index()], 2409 newTiledLoop.getRegionInputArgs()[en.value()]); 2410 OpBuilder innerBuilder = 2411 OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener()); 2412 for (auto &op : *tiledLoop.getBody()) 2413 innerBuilder.clone(op, bvm); 2414 rewriter.replaceOp(tiledLoop, newTiledLoop.getResults()); 2415 2416 return success(); 2417 } 2418 }; 2419 2420 // Folds away TiledLoopOp output tensors when the following conditions are met: 2421 // * result of `linalg.tiled_loop` has no uses 2422 // * output tensor is the argument of `linalg.yield` 2423 // 2424 // Example: 2425 // 2426 // %0 = linalg.tiled_loop ... outs (%o_ = %out: tensor<...>, 2427 // %obuf_ = %out_buf: memref<...>) { 2428 // ... 2429 // linalg.yield %o_ : tensor ... 2430 // } 2431 // 2432 // Becomes 2433 // 2434 // linalg.tiled_loop ... outs (%obuf_ = %out_buf: memref<...>) { 2435 // ... 2436 // linalg.yield 2437 // } 2438 struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> { 2439 using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern; 2440 2441 LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop, 2442 PatternRewriter &rewriter) const final { 2443 if (tiledLoop.getNumResults() == 0) 2444 return failure(); 2445 2446 Block *block = tiledLoop.getBody(); 2447 auto yieldOp = cast<linalg::YieldOp>(block->getTerminator()); 2448 2449 // Match the pattern and collect output buffers that will replace the output 2450 // tensors and also the ops that will be ignored when cloning the body. 2451 SmallVector<Value, 2> newOutputOperands, newYieldArgs; 2452 int resultId = 0; 2453 // Store ids of the corresponding old and new output operands. 2454 SmallVector<int64_t, 2> oldOutputIdToNew(tiledLoop.outputs().size(), 2455 kNoMatch); 2456 // Store ids of the corresponding old and new results. 2457 SmallVector<int64_t, 2> oldResultIdToNew(tiledLoop.getNumResults(), 2458 kNoMatch); 2459 SmallVector<Value, 2> resultReplacement(tiledLoop.getNumResults()); 2460 for (auto en : llvm::enumerate( 2461 llvm::zip(tiledLoop.outputs(), tiledLoop.getRegionOutputArgs()))) { 2462 size_t index = en.index(); 2463 Value out = std::get<0>(en.value()); 2464 Value outRegionArg = std::get<1>(en.value()); 2465 2466 if (!out.getType().isa<RankedTensorType>()) { 2467 oldOutputIdToNew[index] = newOutputOperands.size(); 2468 newOutputOperands.push_back(out); 2469 continue; 2470 } 2471 Value result = tiledLoop.getResult(resultId); 2472 Value yieldArg = yieldOp.getOperand(resultId); 2473 if (yieldArg != outRegionArg || !result.use_empty()) { 2474 oldOutputIdToNew[index] = newOutputOperands.size(); 2475 oldResultIdToNew[resultId] = newYieldArgs.size(); 2476 resultReplacement[resultId] = out; 2477 newOutputOperands.push_back(out); 2478 newYieldArgs.push_back(yieldArg); 2479 } 2480 ++resultId; 2481 } 2482 if (newOutputOperands.size() == tiledLoop.outputs().size()) 2483 return failure(); 2484 2485 Location loc = tiledLoop.getLoc(); 2486 auto newTiledLoop = rewriter.create<TiledLoopOp>( 2487 loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(), 2488 tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types(), 2489 tiledLoop.distribution_types()); 2490 2491 // Clone the region. 2492 BlockAndValueMapping bvm; 2493 bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars()); 2494 bvm.map(tiledLoop.getRegionInputArgs(), newTiledLoop.getRegionInputArgs()); 2495 for (const auto &en : llvm::enumerate(oldOutputIdToNew)) { 2496 if (en.value() != kNoMatch) 2497 bvm.map(tiledLoop.getRegionOutputArgs()[en.index()], 2498 newTiledLoop.getRegionOutputArgs()[en.value()]); 2499 else 2500 bvm.map(tiledLoop.getRegionOutputArgs()[en.index()], 2501 tiledLoop.outputs()[en.index()]); 2502 } 2503 OpBuilder innerBuilder = 2504 OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener()); 2505 for (auto &op : tiledLoop.getBody()->without_terminator()) 2506 innerBuilder.clone(op, bvm); 2507 innerBuilder.create<linalg::YieldOp>( 2508 loc, llvm::to_vector<2>(llvm::map_range( 2509 newYieldArgs, [&](Value arg) { return bvm.lookup(arg); }))); 2510 2511 for (const auto &en : llvm::enumerate(oldResultIdToNew)) 2512 if (en.value() != kNoMatch) 2513 resultReplacement[en.index()] = newTiledLoop.getResult(en.value()); 2514 rewriter.replaceOp(tiledLoop, resultReplacement); 2515 2516 return success(); 2517 } 2518 }; 2519 } // namespace 2520 2521 void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 2522 MLIRContext *context) { 2523 results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder>(context); 2524 } 2525 2526 LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>, 2527 SmallVectorImpl<OpFoldResult> &) { 2528 return foldMemRefCastInTiledLoopOp(*this); 2529 } 2530 2531 //===----------------------------------------------------------------------===// 2532 // IndexOp 2533 //===----------------------------------------------------------------------===// 2534 2535 static LogicalResult verify(IndexOp op) { 2536 auto linalgOp = dyn_cast<LinalgOp>(op->getParentOp()); 2537 if (!linalgOp) 2538 return op.emitOpError("expected parent op with LinalgOp interface"); 2539 if (linalgOp.getNumLoops() <= op.dim()) 2540 return op.emitOpError("expected dim (") 2541 << op.dim() << ") to be lower than the number of loops (" 2542 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp"; 2543 return success(); 2544 } 2545 2546 /////// Operations corresponding to library calls defined with Tablegen //////// 2547 2548 template <typename LinalgPoolingOp> 2549 static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op, 2550 ArrayRef<Attribute> attrs, 2551 bool isStride) { 2552 auto strideOrDilation = isStride ? "stride" : "dilation"; 2553 if (attrs.size() != op.getNumWindowLoops()) 2554 return op.emitOpError("expects num ") 2555 << strideOrDilation 2556 << "s equal to number of window dimensions: " << attrs.size() 2557 << " vs " << op.getNumWindowLoops(); 2558 return success(); 2559 } 2560 2561 void ConvOp::getEffects( 2562 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 2563 &effects) { 2564 effects.emplace_back(MemoryEffects::Read::get(), input(), 2565 SideEffects::DefaultResource::get()); 2566 effects.emplace_back(MemoryEffects::Read::get(), filter(), 2567 SideEffects::DefaultResource::get()); 2568 effects.emplace_back(MemoryEffects::Write::get(), output(), 2569 SideEffects::DefaultResource::get()); 2570 } 2571 2572 static LogicalResult verify(ConvOp op) { 2573 auto oType = op.output().getType().cast<MemRefType>(); 2574 auto fType = op.filter().getType().cast<MemRefType>(); 2575 auto iType = op.input().getType().cast<MemRefType>(); 2576 if (oType.getElementType() != iType.getElementType() || 2577 oType.getElementType() != fType.getElementType()) 2578 return op.emitOpError("expects memref elemental types to match"); 2579 if (oType.getRank() != iType.getRank() || oType.getRank() != fType.getRank()) 2580 return op.emitOpError("expects memref ranks to match"); 2581 if (auto strides = op.strides()) { 2582 if (failed(verifyStrideOrDilation(op, strides->getValue(), 2583 /*isStride=*/true))) 2584 return failure(); 2585 } 2586 if (auto dilations = op.dilations()) { 2587 if (failed(verifyStrideOrDilation(op, dilations->getValue(), 2588 /*isStride=*/false))) 2589 return failure(); 2590 } 2591 return success(); 2592 } 2593 2594 template <typename PoolingOp> 2595 static LogicalResult verifySingleInputPoolingOp(PoolingOp op) { 2596 auto inputType = op.input().getType().template cast<MemRefType>(); 2597 auto outputType = op.output().getType().template cast<MemRefType>(); 2598 if (outputType.getElementType() != inputType.getElementType()) 2599 return op.emitOpError("expects memref elemental types to match"); 2600 2601 auto windowDimsType = op.windowDims().getType().template cast<MemRefType>(); 2602 if (outputType.getRank() != inputType.getRank() || 2603 outputType.getRank() != windowDimsType.getRank()) 2604 return op.emitOpError("expects memref ranks to match"); 2605 2606 if (auto strides = op.strides()) { 2607 if (failed(verifyStrideOrDilation(op, strides->getValue(), 2608 /*isStride=*/true))) 2609 return failure(); 2610 } 2611 if (auto dilations = op.dilations()) { 2612 if (failed(verifyStrideOrDilation(op, dilations->getValue(), 2613 /*isStride=*/false))) 2614 return failure(); 2615 } 2616 return success(); 2617 } 2618 2619 #define DEFINE_POOLING_OP_GET_EFFECTS(OP_NAME) \ 2620 void OP_NAME::getEffects( \ 2621 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \ 2622 &effects) { \ 2623 effects.emplace_back(MemoryEffects::Read::get(), input(), \ 2624 SideEffects::DefaultResource::get()); \ 2625 effects.emplace_back(MemoryEffects::Write::get(), output(), \ 2626 SideEffects::DefaultResource::get()); \ 2627 } 2628 2629 static LogicalResult verify(PoolingMaxOp op) { 2630 return verifySingleInputPoolingOp(op); 2631 } 2632 static LogicalResult verify(PoolingMinOp op) { 2633 return verifySingleInputPoolingOp(op); 2634 } 2635 static LogicalResult verify(PoolingSumOp op) { 2636 return verifySingleInputPoolingOp(op); 2637 } 2638 2639 DEFINE_POOLING_OP_GET_EFFECTS(PoolingMaxOp) 2640 DEFINE_POOLING_OP_GET_EFFECTS(PoolingMinOp) 2641 DEFINE_POOLING_OP_GET_EFFECTS(PoolingSumOp) 2642 2643 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.cpp.inc" 2644 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc" 2645 2646 #define GET_OP_CLASSES 2647 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" 2648 2649 #define GET_OP_CLASSES 2650 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 2651 2652 /// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`. 2653 /// Assumes `op` is a LinalgOp. 2654 void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName, 2655 SmallVectorImpl<AffineExpr> &res) { 2656 if (!cast<LinalgOp>(op).iterator_types()) 2657 return; 2658 2659 unsigned dim = 0; 2660 MLIRContext *ctx = op->getContext(); 2661 for (auto tn : 2662 cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) { 2663 if (tn == iteratorTypeName) 2664 res.push_back(getAffineDimExpr(dim, ctx)); 2665 ++dim; 2666 } 2667 } 2668 2669 AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap, 2670 unsigned rank, 2671 MLIRContext *context) { 2672 if (maybeMap) 2673 return maybeMap.getValue(); 2674 if (rank == 0) 2675 return AffineMap::get(context); 2676 return AffineMap::getMultiDimIdentityMap(rank, context); 2677 } 2678 2679 SmallVector<AffineExpr, 4> 2680 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx, 2681 MLIRContext *context) { 2682 SmallVector<AffineExpr, 4> res; 2683 res.reserve(num); 2684 for (unsigned i = 0; i < num; ++i) 2685 res.push_back(getAffineDimExpr(startIdx++, context)); 2686 return res; 2687 } 2688 2689 template <typename PoolingOp> 2690 SmallVector<AffineExpr, 4> 2691 mlir::linalg::weightedPoolingInputIndex(PoolingOp op, 2692 ArrayRef<AffineExpr> outputDims, 2693 ArrayRef<AffineExpr> windowDims) { 2694 assert(outputDims.size() == windowDims.size()); 2695 SmallVector<AffineExpr, 4> res; 2696 res.reserve(outputDims.size()); 2697 for (unsigned i = 0, e = outputDims.size(); i < e; ++i) { 2698 // TODO: add a level of indirection to linalg.generic. 2699 auto expr = op.getStride(i) * outputDims[i] + 2700 op.getDilation(i) * windowDims[i] - op.getLowPad(i); 2701 res.push_back(expr); 2702 } 2703 return res; 2704 } 2705 2706 #define INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(OP_TYPE) \ 2707 template SmallVector<AffineExpr, 4> \ 2708 mlir::linalg::weightedPoolingInputIndex<OP_TYPE>( \ 2709 OP_TYPE op, ArrayRef<AffineExpr> outputDims, \ 2710 ArrayRef<AffineExpr> windowDims); 2711 2712 INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(ConvOp) 2713 INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMaxOp) 2714 INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMinOp) 2715 INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingSumOp) 2716 2717 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a, 2718 ArrayRef<AffineExpr> b) { 2719 auto rangeA = llvm::make_range(a.begin(), a.end()); 2720 auto rangeB = llvm::make_range(b.begin(), b.end()); 2721 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB); 2722 return llvm::to_vector<4>(concatRanges); 2723 } 2724 2725 static void appendMangledType(llvm::raw_string_ostream &ss, Type t) { 2726 if (auto memref = t.dyn_cast<MemRefType>()) { 2727 ss << "view"; 2728 for (auto size : memref.getShape()) 2729 if (size < 0) 2730 ss << "sx"; 2731 else 2732 ss << size << "x"; 2733 appendMangledType(ss, memref.getElementType()); 2734 } else if (auto vec = t.dyn_cast<VectorType>()) { 2735 ss << "vector"; 2736 llvm::interleave( 2737 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; }); 2738 appendMangledType(ss, vec.getElementType()); 2739 } else if (t.isSignlessIntOrIndexOrFloat()) { 2740 ss << t; 2741 } else { 2742 llvm_unreachable("Invalid type for linalg library name mangling"); 2743 } 2744 } 2745 2746 std::string mlir::linalg::generateLibraryCallName(Operation *op) { 2747 assert(isa<LinalgOp>(op)); 2748 std::string name(op->getName().getStringRef().str()); 2749 name.reserve(128); 2750 std::replace(name.begin(), name.end(), '.', '_'); 2751 llvm::raw_string_ostream ss(name); 2752 ss << "_"; 2753 auto types = op->getOperandTypes(); 2754 llvm::interleave( 2755 types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); }, 2756 [&]() { ss << "_"; }); 2757 return ss.str(); 2758 } 2759 2760 // TODO: Consider making all this boilerplate easy to autogenerate 2761 // with Tablegen. This seems a desirable property in the context of 2762 // OpInterfaces where a Linalg "named" op **isa** LinalgOp. 2763 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) { 2764 if (succeeded(foldMemRefCast(*this))) 2765 return getResult(); 2766 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands); 2767 } 2768 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) { 2769 if (succeeded(foldMemRefCast(*this))) 2770 return getResult(); 2771 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands); 2772 } 2773 OpFoldResult TensorExpandShapeOp::fold(ArrayRef<Attribute> operands) { 2774 return foldReshapeOp<TensorExpandShapeOp, TensorCollapseShapeOp>(*this, 2775 operands); 2776 } 2777 OpFoldResult TensorCollapseShapeOp::fold(ArrayRef<Attribute> operands) { 2778 return foldReshapeOp<TensorCollapseShapeOp, TensorExpandShapeOp>(*this, 2779 operands); 2780 } 2781 2782 //===----------------------------------------------------------------------===// 2783 // Support for named Linalg ops defined in ods-gen. 2784 //===----------------------------------------------------------------------===// 2785 2786 /// Generic entry point to create the block for the region of a LinalgOp. 2787 /// This is used by both named structured ops created by ods-gen and by manually 2788 /// defined C++ ops. 2789 /// This is used by both builders and parsers. 2790 /// This function creates the block in the region with arguments corresponding 2791 /// to the elemental types of `inputTypes` and `outputTypes`, which are asserted 2792 /// to be ShapedType. 2793 template <typename NamedStructuredOpType> 2794 static void 2795 fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, 2796 TypeRange inputTypes, TypeRange outputTypes, 2797 std::function<void(unsigned, unsigned)> errorHandler) { 2798 assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); })); 2799 2800 // TODO: atm all operands go through getElementTypeOrSelf, 2801 // reconsider when we have evidence we need to. 2802 SmallVector<Type, 8> argTypes; 2803 for (auto containers : {inputTypes, outputTypes}) 2804 for (auto t : containers) 2805 argTypes.push_back(getElementTypeOrSelf(t)); 2806 2807 // RAII. 2808 OpBuilder::InsertionGuard guard(opBuilder); 2809 Block *body = opBuilder.createBlock(®ion, /*insertPt=*/{}, argTypes); 2810 unsigned actual = body->getNumArguments(); 2811 unsigned expected = NamedStructuredOpType::getNumRegionArgs(); 2812 if (expected != actual) { 2813 if (errorHandler) 2814 errorHandler(expected, actual); 2815 return; 2816 } 2817 2818 opBuilder.setInsertionPointToStart(body); 2819 ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); 2820 NamedStructuredOpType::regionBuilder(b, *body); 2821 2822 // indexing_maps is an auto-generated method. 2823 2824 // iterator_types is an auto-generated method. 2825 } 2826 2827 /// Generic entry point to create both the region and the block of a LinalgOp. 2828 template <typename NamedStructuredOpType> 2829 void createAndFillStructuredOpRegion(OpBuilder &opBuilder, 2830 OperationState &result, 2831 TypeRange inputTypes, 2832 TypeRange outputTypes) { 2833 Region ®ion = *result.addRegion(); 2834 fillStructuredOpRegion<NamedStructuredOpType>( 2835 opBuilder, region, inputTypes, outputTypes, 2836 [&](unsigned expected, unsigned actual) { 2837 assert(expected != actual && "incorrect number of arguments"); 2838 }); 2839 } 2840 2841 /// Common parsing used for both named structured ops created by ods-gen and by 2842 /// manually defined C++ ops. Does not handle regions. 2843 static ParseResult 2844 parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, 2845 SmallVectorImpl<Type> &inputTypes, 2846 SmallVectorImpl<Type> &outputTypes) { 2847 llvm::SMLoc inputsOperandsLoc, outputsOperandsLoc; 2848 SmallVector<OpAsmParser::OperandType, 4> inputsOperands, outputsOperands; 2849 2850 parser.parseOptionalAttrDict(result.attributes); 2851 2852 if (succeeded(parser.parseOptionalKeyword("ins"))) { 2853 if (parser.parseLParen()) 2854 return failure(); 2855 2856 inputsOperandsLoc = parser.getCurrentLocation(); 2857 if (parser.parseOperandList(inputsOperands) || 2858 parser.parseColonTypeList(inputTypes) || parser.parseRParen()) 2859 return failure(); 2860 } 2861 2862 if (succeeded(parser.parseOptionalKeyword("outs"))) { 2863 outputsOperandsLoc = parser.getCurrentLocation(); 2864 if (parser.parseLParen() || parser.parseOperandList(outputsOperands) || 2865 parser.parseColonTypeList(outputTypes) || parser.parseRParen()) 2866 return failure(); 2867 } 2868 2869 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc, 2870 result.operands) || 2871 parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc, 2872 result.operands)) 2873 return failure(); 2874 2875 result.addAttribute("operand_segment_sizes", 2876 parser.getBuilder().getI32VectorAttr( 2877 {static_cast<int32_t>(inputsOperands.size()), 2878 static_cast<int32_t>(outputsOperands.size())})); 2879 return success(); 2880 } 2881 2882 template <typename NamedStructuredOpType> 2883 static void printCommonStructuredOpParts(OpAsmPrinter &p, 2884 NamedStructuredOpType op) { 2885 if (!op.inputs().empty()) 2886 p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")"; 2887 if (!op.outputs().empty()) 2888 p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")"; 2889 } 2890 2891 //===----------------------------------------------------------------------===// 2892 // Specific parsing and printing for named structured ops created by ods-gen. 2893 //===----------------------------------------------------------------------===// 2894 2895 template <typename NamedStructuredOpType> 2896 static ParseResult 2897 parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, 2898 TypeRange inputTypes, TypeRange outputTypes) { 2899 ParseResult res = success(); 2900 OpBuilder opBuilder(parser.getBuilder().getContext()); 2901 // Resolve `captures` into `capturedValues` at parse time so we can build the 2902 // region with captures. 2903 SmallVector<Value> capturedValues; 2904 fillStructuredOpRegion<NamedStructuredOpType>( 2905 opBuilder, region, inputTypes, outputTypes, 2906 [&](unsigned expected, unsigned actual) { 2907 res = parser.emitError( 2908 parser.getCurrentLocation(), 2909 llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated " 2910 "region expects {0} args, got {1}", 2911 expected, actual)); 2912 region.front().dump(); 2913 }); 2914 return res; 2915 } 2916 2917 static ParseResult 2918 parseNamedStructuredOpResults(OpAsmParser &parser, 2919 SmallVectorImpl<Type> &resultTypes) { 2920 if (parser.parseOptionalArrowTypeList(resultTypes)) 2921 return failure(); 2922 return success(); 2923 } 2924 2925 template <typename NamedStructuredOpType> 2926 static ParseResult parseNamedStructuredOp(OpAsmParser &parser, 2927 OperationState &result) { 2928 // TODO: Enable when ods-gen supports captures. 2929 SmallVector<Type, 1> inputTypes, outputTypes; 2930 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) 2931 return failure(); 2932 2933 // TODO: consider merging results parsing into region parsing. 2934 // Need to wait for declarative assembly resolution to decide. 2935 SmallVector<Type, 1> outputTensorsTypes; 2936 if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) 2937 return failure(); 2938 result.addTypes(outputTensorsTypes); 2939 2940 std::unique_ptr<Region> region = std::make_unique<Region>(); 2941 if (parseNamedStructuredOpRegion<NamedStructuredOpType>( 2942 parser, *region, inputTypes, outputTypes)) 2943 return failure(); 2944 result.addRegion(std::move(region)); 2945 2946 return success(); 2947 } 2948 2949 static void printNamedStructuredOpResults(OpAsmPrinter &p, 2950 TypeRange resultTypes) { 2951 if (resultTypes.empty()) 2952 return; 2953 p.printOptionalArrowTypeList(resultTypes); 2954 } 2955 2956 template <typename NamedStructuredOpType> 2957 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { 2958 p << op.getOperationName(); 2959 p.printOptionalAttrDict( 2960 op->getAttrs(), 2961 /*elidedAttrs=*/{"operand_segment_sizes", 2962 // See generated code in mlir-linalg-yaml-gen.cpp 2963 "linalg.memoized_indexing_maps"}); 2964 2965 // Printing is shared with generic ops, except for the region and 2966 // attributes. 2967 printCommonStructuredOpParts(p, op); 2968 2969 // Results printing. 2970 printNamedStructuredOpResults(p, op.result_tensors().getTypes()); 2971 2972 // Region is elided. 2973 } 2974 2975 template <typename NamedStructuredOpType> 2976 static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) { 2977 return verifyGenericOp<NamedStructuredOpType>(op); 2978 } 2979 2980 //===----------------------------------------------------------------------===// 2981 // Canonicalizers and Folders. 2982 //===----------------------------------------------------------------------===// 2983 2984 namespace { 2985 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> { 2986 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern; 2987 2988 LogicalResult matchAndRewrite(LinalgOp op, 2989 PatternRewriter &rewriter) const override { 2990 for (OpOperand *opOperand : op.getInputAndOutputOperands()) { 2991 // Linalg "inputs" may be either tensor or memref type. 2992 // tensor<0xelt_type> is a convention that may not always mean 2993 // "0 iterations". Only erase in cases we see memref<...x0x...>. 2994 auto mt = opOperand->get().getType().dyn_cast<MemRefType>(); 2995 if (!mt) 2996 continue; 2997 if (llvm::is_contained(op.getShape(opOperand), 0)) { 2998 rewriter.eraseOp(op); 2999 return success(); 3000 } 3001 } 3002 return failure(); 3003 } 3004 }; 3005 3006 struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> { 3007 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern; 3008 3009 LogicalResult matchAndRewrite(LinalgOp op, 3010 PatternRewriter &rewriter) const override { 3011 // If no operand comes from a tensor::CastOp and can be folded then fail. 3012 bool hasTensorCastOperand = 3013 llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) { 3014 if (opOperand->get().isa<BlockArgument>()) 3015 return false; 3016 auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>(); 3017 return castOp && canFoldIntoConsumerOp(castOp); 3018 }); 3019 if (!hasTensorCastOperand) 3020 return failure(); 3021 3022 SmallVector<Type, 4> newResultTypes; 3023 newResultTypes.reserve(op->getNumResults()); 3024 SmallVector<Value, 4> newOperands; 3025 newOperands.reserve(op->getNumOperands()); 3026 // Inputs may fold. 3027 for (OpOperand *opOperand : op.getInputOperands()) { 3028 auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>(); 3029 newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp) 3030 ? tensorCastOp.source() 3031 : opOperand->get()); 3032 } 3033 // Init tensors may fold, in which case the resultType must also change. 3034 for (OpOperand *opOperand : op.getOutputOperands()) { 3035 auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>(); 3036 bool fold = canFoldIntoConsumerOp(tensorCastOp); 3037 newOperands.push_back(fold ? tensorCastOp.getOperand() 3038 : opOperand->get()); 3039 newResultTypes.push_back(newOperands.back().getType()); 3040 } 3041 // Clone op. 3042 Operation *newOp = 3043 op.clone(rewriter, op->getLoc(), newResultTypes, newOperands); 3044 SmallVector<Value, 4> replacements; 3045 replacements.reserve(newOp->getNumResults()); 3046 for (auto result : llvm::zip(op->getResults(), newOp->getResults())) { 3047 Value oldResult = std::get<0>(result); 3048 Value newResult = std::get<1>(result); 3049 if (newResult.getType() != oldResult.getType()) { 3050 replacements.push_back(rewriter.create<tensor::CastOp>( 3051 op->getLoc(), oldResult.getType(), newResult)); 3052 } else { 3053 replacements.push_back(newResult); 3054 } 3055 } 3056 rewriter.replaceOp(op, replacements); 3057 3058 return success(); 3059 } 3060 }; 3061 } // namespace 3062 3063 namespace { 3064 // Deduplicate redundant args of a linalg op. 3065 // An arg is redundant if it has the same Value and indexing map as another. 3066 struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> { 3067 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern; 3068 3069 LogicalResult matchAndRewrite(LinalgOp op, 3070 PatternRewriter &rewriter) const override { 3071 // This pattern reduces the number of arguments of an op, which breaks 3072 // the invariants of semantically charged named ops. 3073 if (!isa<GenericOp>(op)) 3074 return failure(); 3075 3076 // Associate each input to an equivalent "canonical" input that has the same 3077 // Value and indexing map. 3078 // 3079 // In the non-duplicate case, input `i` will have canonical input `i`. But 3080 // in the case of duplicated inputs, the canonical input could be some other 3081 // input `< i`. That is, a later input will have some earlier input as its 3082 // canonical input. 3083 llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput; 3084 // For later remapping tasks like deduplicating payload block arguments, 3085 // having a simple "inputIndex -> canonicalInputIndex" integer mapping is 3086 // convenient. 3087 SmallVector<unsigned> canonicalInputIndices; 3088 for (OpOperand *opOperand : op.getInputOperands()) { 3089 AffineMap indexingMap = op.getTiedIndexingMap(opOperand); 3090 // STL-like maps have a convenient behavior for our use case here. In the 3091 // case of duplicate keys, the insertion is rejected, and the returned 3092 // iterator gives access to the value already in the map. 3093 auto pair = canonicalInput.insert( 3094 {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()}); 3095 canonicalInputIndices.push_back(pair.first->second); 3096 } 3097 3098 // If there are no duplicate args, then bail out. 3099 if (canonicalInput.size() == op.getNumInputs()) 3100 return failure(); 3101 3102 // The operands for the newly canonicalized op. 3103 SmallVector<Value> newOperands; 3104 for (OpOperand *opOperand : op.getInputOperands()) 3105 if (canonicalInputIndices[opOperand->getOperandNumber()] == 3106 opOperand->getOperandNumber()) 3107 newOperands.push_back(opOperand->get()); 3108 SmallVector<Value> outputOperands = op.getOutputOperands(); 3109 llvm::append_range(newOperands, outputOperands); 3110 3111 // Repair the indexing maps by filtering out the ones that have been 3112 // eliminated. 3113 SmallVector<AffineMap> newIndexingMaps; 3114 for (OpOperand *opOperand : op.getInputOperands()) 3115 if (canonicalInputIndices[opOperand->getOperandNumber()] == 3116 opOperand->getOperandNumber()) 3117 newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand)); 3118 for (OpOperand *opOperand : op.getOutputOperands()) 3119 newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand)); 3120 3121 // Clone the old op with new operands. 3122 Operation *newOp = 3123 op.clone(rewriter, op->getLoc(), op->getResultTypes(), newOperands); 3124 auto newLinalgOp = cast<LinalgOp>(newOp); 3125 newOp->setAttr("indexing_maps", 3126 rewriter.getAffineMapArrayAttr(newIndexingMaps)); 3127 3128 // Set the number of inputs to the new value. The `clone` call above kept 3129 // the value from the original op. 3130 newLinalgOp.setNumInputs(canonicalInput.size()); 3131 3132 // Repair the payload entry block by RAUW'ing redundant arguments and 3133 // erasing them. 3134 Block &payload = newOp->getRegion(0).front(); 3135 SmallVector<OpOperand *> inputOperands = op.getInputOperands(); 3136 for (OpOperand *opOperand : llvm::reverse(inputOperands)) { 3137 // Iterate in reverse, so that we erase later args first, preventing the 3138 // argument list from shifting unexpectedly and invalidating all our 3139 // indices. 3140 unsigned operandNumber = opOperand->getOperandNumber(); 3141 if (canonicalInputIndices[operandNumber] == operandNumber) 3142 continue; 3143 payload.getArgument(operandNumber) 3144 .replaceAllUsesWith( 3145 payload.getArgument(canonicalInputIndices[operandNumber])); 3146 payload.eraseArgument(operandNumber); 3147 } 3148 3149 rewriter.replaceOp(op, newOp->getResults()); 3150 return success(); 3151 } 3152 }; 3153 3154 /// Remove generic operations (on tensors) that are just copying 3155 /// the values from inputs to the results. Requirements are 3156 /// 1) All iterator types are parallel 3157 /// 2) The body contains just a yield operation with the yielded values being 3158 /// the arguments corresponding to the operands. 3159 struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> { 3160 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern; 3161 3162 LogicalResult matchAndRewrite(LinalgOp op, 3163 PatternRewriter &rewriter) const override { 3164 if (auto copyOp = dyn_cast<CopyOp>(*op)) { 3165 assert(copyOp.hasBufferSemantics()); 3166 if (copyOp.input() == copyOp.output() && 3167 copyOp.inputPermutation() == copyOp.outputPermutation()) { 3168 rewriter.eraseOp(op); 3169 return success(); 3170 } 3171 } 3172 3173 if (!isa<GenericOp>(op)) 3174 return failure(); 3175 if (!op.hasTensorSemantics()) 3176 return failure(); 3177 // Check all indexing maps are identity. 3178 if (llvm::any_of(op.getIndexingMaps(), 3179 [](AffineMap map) { return !map.isIdentity(); })) 3180 return failure(); 3181 3182 // Check that the body of the linalg operation is just a linalg.yield 3183 // operation. 3184 Block &body = op->getRegion(0).front(); 3185 if (!llvm::hasSingleElement(body)) 3186 return failure(); 3187 auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator()); 3188 if (!yieldOp) 3189 return failure(); 3190 3191 // Get the argument number of the returned values. That is the operand 3192 // number to use for replacing uses of this operation. 3193 SmallVector<Value, 4> returnedArgs; 3194 for (Value yieldVal : yieldOp.values()) { 3195 auto yieldArg = yieldVal.dyn_cast<BlockArgument>(); 3196 if (!yieldArg || yieldArg.getOwner() != &body) 3197 return failure(); 3198 unsigned argumentNumber = yieldArg.getArgNumber(); 3199 returnedArgs.push_back(op->getOperand(argumentNumber)); 3200 } 3201 if (returnedArgs.size() != op.getOperation()->getNumResults()) 3202 return failure(); 3203 rewriter.replaceOp(op, returnedArgs); 3204 return success(); 3205 } 3206 }; 3207 } // namespace 3208 3209 #define LINALGOP_FOLDERS(XXX) \ 3210 LogicalResult XXX::fold(ArrayRef<Attribute>, \ 3211 SmallVectorImpl<OpFoldResult> &) { \ 3212 return foldMemRefCast(*this); \ 3213 } 3214 3215 LINALGOP_FOLDERS(ConvOp) 3216 LINALGOP_FOLDERS(PoolingMaxOp) 3217 LINALGOP_FOLDERS(PoolingMinOp) 3218 LINALGOP_FOLDERS(PoolingSumOp) 3219 LINALGOP_FOLDERS(CopyOp) 3220 LINALGOP_FOLDERS(FillOp) 3221 LINALGOP_FOLDERS(GenericOp) 3222 3223 // All named ops canonicalizers and folders are auto-generated in the 3224 // .cpp.inc. 3225 3226 //===----------------------------------------------------------------------===// 3227 // LinalgDialect 3228 //===----------------------------------------------------------------------===// 3229 3230 void LinalgDialect::getCanonicalizationPatterns( 3231 RewritePatternSet &results) const { 3232 results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp, 3233 RemoveIdentityLinalgOps>(getContext()); 3234 } 3235