1 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/IR/Linalg.h"
14 
15 #include "mlir/Dialect/SCF/SCF.h"
16 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
17 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
18 #include "mlir/Dialect/Utils/StaticValueUtils.h"
19 #include "mlir/IR/AffineExprVisitor.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/OpImplementation.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Interfaces/InferTypeOpInterface.h"
24 #include "mlir/Parser.h"
25 
26 #include "llvm/ADT/DenseMap.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallSet.h"
29 #include "llvm/ADT/StringSet.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/FormatVariadic.h"
32 #include "llvm/Support/MathExtras.h"
33 #include "llvm/Support/raw_ostream.h"
34 
35 using namespace mlir;
36 using namespace mlir::linalg;
37 
38 #include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc"
39 
40 /// Forward declarations.
41 
42 /// Generic entry point to create the block for the region of a LinalgOp.
43 /// This is used by both named structured ops created by ods-gen and by manually
44 /// defined C++ ops.
45 /// This is used by both builders and parsers.
46 /// This function creates the block in the region with arguments corresponding
47 /// to the elemental types of `inputTypes` and `outputTypes`. The latter are
48 /// asserted to be of ShapedType.
49 template <typename NamedStructuredOpType>
50 static void fillStructuredOpRegion(
51     OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
52     TypeRange outputTypes,
53     llvm::function_ref<void(unsigned, unsigned)> errorHandler = nullptr);
54 
55 /// Generic entry point to create both the region and the block of a LinalgOp.
56 template <typename NamedStructuredOpType>
57 static void
58 createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result,
59                                 TypeRange inputTypes, TypeRange outputTypes);
60 
61 /// Common parsing and printing used for both named structured ops created by
62 /// ods-gen and by manually defined C++ ops. Does not handle regions.
63 static ParseResult
64 parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
65                              SmallVectorImpl<Type> &inputTypes,
66                              SmallVectorImpl<Type> &outputTypes);
67 template <typename NamedStructuredOpType>
68 static void printCommonStructuredOpParts(OpAsmPrinter &p,
69                                          NamedStructuredOpType op);
70 
71 /// Specific parsing and printing for named structured ops created by ods-gen.
72 template <typename NamedStructuredOpType>
73 static ParseResult
74 parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
75                              TypeRange inputTypes, TypeRange outputTypes);
76 
77 static ParseResult
78 parseNamedStructuredOpResults(OpAsmParser &parser,
79                               SmallVectorImpl<Type> &resultTypes);
80 
81 template <typename NamedStructuredOpType>
82 static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
83                                           OperationState &result);
84 
85 static void printNamedStructuredOpResults(OpAsmPrinter &p,
86                                           TypeRange resultTypes);
87 
88 template <typename NamedStructuredOpType>
89 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
90 
91 /// This is a common class used for patterns of the form
92 /// ```
93 ///    someop(memrefcast(%src)) -> someop(%src)
94 /// ```
95 /// It folds the source of the memref.cast into the root operation directly.
96 static LogicalResult foldMemRefCast(Operation *op) {
97   bool folded = false;
98   for (OpOperand &operand : op->getOpOperands()) {
99     auto castOp = operand.get().getDefiningOp<memref::CastOp>();
100     if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
101       operand.set(castOp.getOperand());
102       folded = true;
103     }
104   }
105   return success(folded);
106 }
107 
108 /// This is a specialization of `foldMemRefCast` used for patterns of the form
109 /// ```
110 ///    tiled_loop(memrefcast(%src)) -> tiled_loop(%src)
111 /// ```
112 /// It folds the source of the memref.cast into the root operation directly.
113 static LogicalResult foldMemRefCastInTiledLoopOp(TiledLoopOp op) {
114   bool folded = false;
115   Location loc = op->getLoc();
116 
117   Block *body = op.getBody();
118   OpBuilder b = OpBuilder::atBlockBegin(body);
119 
120   // Update `input` and `output` operands and block arguments if necessary.
121   // Operands list: [lbs, ubs, steps, inputs, outputs].
122   // Block args list: [ivs, inputs, outputs].
123   for (size_t operandIndex = op.getNumControlOperands(),
124               bbArgIndex = op.getNumLoops(), e = op.getNumOperands();
125        operandIndex < e; ++operandIndex, ++bbArgIndex) {
126     OpOperand &operand = op->getOpOperand(operandIndex);
127 
128     auto castOp = operand.get().getDefiningOp<memref::CastOp>();
129     if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
130       operand.set(castOp.getOperand());
131       BlockArgument newBbArg =
132           body->insertArgument(bbArgIndex, castOp.getOperand().getType());
133       BlockArgument oldBbArg = body->getArgument(newBbArg.getArgNumber() + 1);
134 
135       // Insert memref.cast back to the original type.
136       oldBbArg.replaceAllUsesWith(
137           b.create<memref::CastOp>(loc, oldBbArg.getType(), newBbArg));
138       body->eraseArgument(oldBbArg.getArgNumber());
139 
140       folded = true;
141     }
142   }
143   return success(folded);
144 }
145 
146 //===----------------------------------------------------------------------===//
147 // Region builder helper.
148 // TODO: Move this to a utility library.
149 // The public methods on this class are referenced directly from generated code
150 // and bind by name to math and type conversion functions in the DSL as:
151 //   `arithfn__{fnName}`
152 //   `typefn__{fnName}`
153 // Examples:
154 //   `arithfn__add`
155 //   `arithfn__mul`
156 //   `typefn__cast`
157 // The naming convention is intentional in order to match snake-cased DSL names.
158 // See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
159 //
160 // Implementations of the math functions must be polymorphic over numeric types,
161 // internally performing necessary casts. If the function application makes no
162 // sense, then the only recourse is to assert and return nullptr. This can be
163 // extended later if it becomes possible to fail construction of the region. The
164 // invariant should be enforced at a higher level.
165 //
166 // TODO: These helpers are currently type polymorphic over the class of integer
167 // and floating point types, but they will not internally cast within bit
168 // widths of a class (mixed precision such as i8->i32) or across classes
169 // (i.e. mixed float and integer). Many such combinations are ambiguous or need
170 // to be handled with care and work is being considered to extend the op
171 // language to make such cases explicit. In the mean-time, violating this will
172 // fail verification, which is deemed acceptable.
173 //===----------------------------------------------------------------------===//
174 
175 namespace {
176 
177 class RegionBuilderHelper {
178 public:
179   RegionBuilderHelper(MLIRContext *context, Block &block)
180       : context(context), block(block) {}
181 
182   // Generates operations to cast the given operand to a specified type.
183   // If the cast cannot be performed, a warning will be issued and the
184   // operand returned as-is (which will presumably yield a verification
185   // issue downstream).
186   Value cast(Type toType, Value operand, bool isUnsignedCast) {
187     OpBuilder builder = getBuilder();
188     auto loc = operand.getLoc();
189 
190     if (operand.getType() == toType)
191       return operand;
192     if (auto toIntType = toType.dyn_cast<IntegerType>()) {
193       // If operand is floating point, cast directly to the int type.
194       if (operand.getType().isa<FloatType>()) {
195         if (isUnsignedCast)
196           return builder.create<arith::FPToUIOp>(loc, toType, operand);
197         return builder.create<arith::FPToSIOp>(loc, toType, operand);
198       }
199       // Cast index operands directly to the int type.
200       if (operand.getType().isIndex())
201         return builder.create<arith::IndexCastOp>(loc, toType, operand);
202       if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
203         // Either extend or truncate.
204         if (toIntType.getWidth() > fromIntType.getWidth()) {
205           if (isUnsignedCast)
206             return builder.create<arith::ExtUIOp>(loc, toType, operand);
207           return builder.create<arith::ExtSIOp>(loc, toType, operand);
208         }
209         if (toIntType.getWidth() < fromIntType.getWidth())
210           return builder.create<arith::TruncIOp>(loc, toType, operand);
211       }
212     } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
213       // If operand is integer, cast directly to the float type.
214       // Note that it is unclear how to cast from BF16<->FP16.
215       if (operand.getType().isa<IntegerType>()) {
216         if (isUnsignedCast)
217           return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
218         return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
219       }
220       if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
221         if (toFloatType.getWidth() > fromFloatType.getWidth())
222           return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
223         if (toFloatType.getWidth() < fromFloatType.getWidth())
224           return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
225       }
226     }
227 
228     emitWarning(operand.getLoc()) << "could not cast operand of type "
229                                   << operand.getType() << " to " << toType;
230     return operand;
231   }
232 
233   // NOLINTNEXTLINE(*-identifier-naming): externally called.
234   Value typefn__cast(Type toType, Value operand) {
235     return cast(toType, operand, false);
236   }
237 
238   // NOLINTNEXTLINE(*-identifier-naming): externally called.
239   Value typefn__cast_unsigned(Type toType, Value operand) {
240     return cast(toType, operand, true);
241   }
242 
243   // NOLINTNEXTLINE(*-identifier-naming): externally called.
244   Value arithfn__add(Value lhs, Value rhs) {
245     OpBuilder builder = getBuilder();
246     if (isFloatingPoint(lhs))
247       return builder.create<arith::AddFOp>(lhs.getLoc(), lhs, rhs);
248     if (isInteger(lhs))
249       return builder.create<arith::AddIOp>(lhs.getLoc(), lhs, rhs);
250     llvm_unreachable("unsupported non numeric type");
251   }
252 
253   // NOLINTNEXTLINE(*-identifier-naming): externally called.
254   Value arithfn__exp(Value x) {
255     OpBuilder builder = getBuilder();
256     if (isFloatingPoint(x))
257       return builder.create<math::ExpOp>(x.getLoc(), x);
258     llvm_unreachable("unsupported non numeric type");
259   }
260 
261   // NOLINTNEXTLINE(*-identifier-naming): externally called.
262   Value arithfn__log(Value x) {
263     OpBuilder builder = getBuilder();
264     if (isFloatingPoint(x))
265       return builder.create<math::LogOp>(x.getLoc(), x);
266     llvm_unreachable("unsupported non numeric type");
267   }
268 
269   // NOLINTNEXTLINE(*-identifier-naming): externally called.
270   Value arithfn__sub(Value lhs, Value rhs) {
271     OpBuilder builder = getBuilder();
272     if (isFloatingPoint(lhs))
273       return builder.create<arith::SubFOp>(lhs.getLoc(), lhs, rhs);
274     if (isInteger(lhs))
275       return builder.create<arith::SubIOp>(lhs.getLoc(), lhs, rhs);
276     llvm_unreachable("unsupported non numeric type");
277   }
278 
279   // NOLINTNEXTLINE(*-identifier-naming): externally called.
280   Value arithfn__mul(Value lhs, Value rhs) {
281     OpBuilder builder = getBuilder();
282     if (isFloatingPoint(lhs))
283       return builder.create<arith::MulFOp>(lhs.getLoc(), lhs, rhs);
284     if (isInteger(lhs))
285       return builder.create<arith::MulIOp>(lhs.getLoc(), lhs, rhs);
286     llvm_unreachable("unsupported non numeric type");
287   }
288 
289   // NOLINTNEXTLINE(*-identifier-naming): externally called.
290   Value arithfn__max(Value lhs, Value rhs) {
291     OpBuilder builder = getBuilder();
292     if (isFloatingPoint(lhs))
293       return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
294     if (isInteger(lhs))
295       return builder.create<arith::MaxSIOp>(lhs.getLoc(), lhs, rhs);
296     llvm_unreachable("unsupported non numeric type");
297   }
298 
299   // NOLINTNEXTLINE(*-identifier-naming): externally called.
300   Value arithfn__max_unsigned(Value lhs, Value rhs) {
301     OpBuilder builder = getBuilder();
302     if (isFloatingPoint(lhs))
303       return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
304     if (isInteger(lhs))
305       return builder.create<arith::MaxUIOp>(lhs.getLoc(), lhs, rhs);
306     llvm_unreachable("unsupported non numeric type");
307   }
308 
309   // NOLINTNEXTLINE(*-identifier-naming): externally called.
310   Value arithfn__min(Value lhs, Value rhs) {
311     OpBuilder builder = getBuilder();
312     if (isFloatingPoint(lhs))
313       return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
314     if (isInteger(lhs))
315       return builder.create<arith::MinSIOp>(lhs.getLoc(), lhs, rhs);
316     llvm_unreachable("unsupported non numeric type");
317   }
318 
319   // NOLINTNEXTLINE(*-identifier-naming): externally called.
320   Value arithfn__min_unsigned(Value lhs, Value rhs) {
321     OpBuilder builder = getBuilder();
322     if (isFloatingPoint(lhs))
323       return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
324     if (isInteger(lhs))
325       return builder.create<arith::MinUIOp>(lhs.getLoc(), lhs, rhs);
326     llvm_unreachable("unsupported non numeric type");
327   }
328 
329   void yieldOutputs(ValueRange values) {
330     assert(!values.empty() && "linalg ops must yield outputs");
331     if (values.empty())
332       return;
333     Value first = values.front();
334     OpBuilder builder = getBuilder();
335     builder.create<YieldOp>(first.getLoc(), values);
336   }
337 
338   Value constant(const std::string &value) {
339     OpBuilder builder = getBuilder();
340     Location loc = builder.getUnknownLoc();
341     Attribute valueAttr = parseAttribute(value, builder.getContext());
342     return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
343                                              valueAttr);
344   }
345 
346   Value index(int64_t dim) {
347     OpBuilder builder = getBuilder();
348     return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
349   }
350 
351   Type getIntegerType(unsigned width) {
352     return IntegerType::get(context, width);
353   }
354 
355   Type getFloat32Type() { return Float32Type::get(context); }
356 
357   Type getFloat64Type() { return Float64Type::get(context); }
358 
359 private:
360   MLIRContext *context;
361   Block &block;
362 
363   bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
364   bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
365 
366   OpBuilder getBuilder() {
367     OpBuilder builder(context);
368     builder.setInsertionPointToEnd(&block);
369     return builder;
370   }
371 };
372 
373 } // namespace
374 
375 //===----------------------------------------------------------------------===//
376 // CopyOp
377 //===----------------------------------------------------------------------===//
378 void CopyOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
379   assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args");
380   b.create<linalg::YieldOp>(block.getArgument(0));
381 }
382 
383 void CopyOp::build(OpBuilder &builder, OperationState &result, Value input,
384                    Value output, AffineMap inputPermutation,
385                    AffineMap outputPermutation,
386                    ArrayRef<NamedAttribute> namedAttrs) {
387   result.addOperands({input, output});
388   result.addAttributes(namedAttrs);
389   if (inputPermutation)
390     result.addAttribute("inputPermutation",
391                         AffineMapAttr::get(inputPermutation));
392   if (outputPermutation)
393     result.addAttribute("outputPermutation",
394                         AffineMapAttr::get(outputPermutation));
395   result.addRegion();
396   fillStructuredOpRegion<CopyOp>(builder, *result.regions.front(),
397                                  TypeRange{input.getType()},
398                                  TypeRange{output.getType()});
399 }
400 
401 ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType,
402                               Type outputType) {
403   OpBuilder opBuilder(parser.getContext());
404   fillStructuredOpRegion<CopyOp>(opBuilder, r, TypeRange{inputType},
405                                  TypeRange{outputType});
406   return success();
407 }
408 
409 /// CopyOp region is elided when printing.
410 void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
411 
412 static LogicalResult verify(CopyOp op) {
413   OpOperand *output = op.getOutputOperand(0);
414   OpOperand *input = op.getInputOperand(0);
415   if (getElementTypeOrSelf(input->get()) != getElementTypeOrSelf(output->get()))
416     return op.emitOpError("expects views of the same type");
417   if (op.getRank(input) != op.getRank(output))
418     return op.emitOpError("expects views of the same rank");
419   auto rank = op.getNumParallelLoops();
420   auto inputPermutationMap = op.inputPermutation();
421   if (inputPermutationMap) {
422     if (inputPermutationMap->getNumInputs() != rank)
423       return op.emitOpError("expects optional input_permutation map of rank ")
424              << rank;
425     if (!inputPermutationMap->isPermutation())
426       return op.emitOpError(
427           "expects optional input_permutation map to be a permutation");
428   }
429   auto outputPermutationMap = op.outputPermutation();
430   if (outputPermutationMap) {
431     if (outputPermutationMap->getNumInputs() != rank)
432       return op.emitOpError("expects optional output_permutation map of rank ")
433              << rank;
434     if (!outputPermutationMap->isPermutation())
435       return op.emitOpError(
436           "expects optional output_permutation map to be a permutation");
437   }
438   if (rank == 0 && inputPermutationMap)
439     return op.emitOpError("expected no input permutation when rank == 0");
440   if (rank == 0 && outputPermutationMap)
441     return op.emitOpError("expected no output permutation when rank == 0");
442   return success();
443 }
444 
445 void CopyOp::getEffects(
446     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
447         &effects) {
448   effects.emplace_back(MemoryEffects::Read::get(), input(),
449                        SideEffects::DefaultResource::get());
450   effects.emplace_back(MemoryEffects::Write::get(), output(),
451                        SideEffects::DefaultResource::get());
452 }
453 
454 namespace {
455 /// Remove copy operations that copy data inplace. Requirements are:
456 /// 1) The input and output values are identical.
457 /// 2) The input and output permutation maps are identical.
458 struct EraseIdentityCopyOp : public OpRewritePattern<CopyOp> {
459   using OpRewritePattern<CopyOp>::OpRewritePattern;
460 
461   LogicalResult matchAndRewrite(CopyOp copyOp,
462                                 PatternRewriter &rewriter) const override {
463     assert(copyOp.hasBufferSemantics());
464     if (copyOp.input() == copyOp.output() &&
465         copyOp.inputPermutation() == copyOp.outputPermutation()) {
466       rewriter.eraseOp(copyOp);
467       return success();
468     }
469     return failure();
470   }
471 };
472 } // namespace
473 
474 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
475                                          MLIRContext *context) {
476   results.add<EraseIdentityCopyOp>(context);
477 }
478 
479 //===----------------------------------------------------------------------===//
480 // FillOp
481 //===----------------------------------------------------------------------===//
482 void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
483   assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args");
484   b.create<linalg::YieldOp>(block.getArgument(0));
485 }
486 
487 void FillOp::build(OpBuilder &builder, OperationState &result, Value value,
488                    Value output) {
489   build(builder, result, output.getType().dyn_cast<RankedTensorType>(), value,
490         output);
491   fillStructuredOpRegion<FillOp>(builder, *result.regions.front(),
492                                  TypeRange{value.getType()},
493                                  TypeRange{output.getType()}, {});
494 }
495 
496 ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type valueType,
497                               Type outputType) {
498   OpBuilder opBuilder(parser.getContext());
499   fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{valueType},
500                                  TypeRange{outputType});
501   return success();
502 }
503 
504 /// FillOp region is elided when printing.
505 void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
506 
507 static LogicalResult verify(FillOp op) {
508   OpOperand *output = op.getOutputOperand(0);
509   Type fillType = op.value().getType();
510   if (getElementTypeOrSelf(output->get()) != fillType)
511     return op.emitOpError("expects fill type to match view elemental type");
512   return success();
513 }
514 
515 void FillOp::getEffects(
516     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
517         &effects) {
518   if (output().getType().isa<MemRefType>())
519     effects.emplace_back(MemoryEffects::Write::get(), output(),
520                          SideEffects::DefaultResource::get());
521 }
522 
523 namespace {
524 
525 /// Fold linalg.fill -> tensor.expand/collapse_shape chain.
526 ///
527 /// For such op chains, we can create new linalg.fill ops with the result
528 /// type of the tensor.expand/collapse_shape op.
529 template <typename TensorReshapeOp>
530 struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
531   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
532   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
533                                 PatternRewriter &rewriter) const override {
534     auto oldFill = reshapeOp.src().template getDefiningOp<FillOp>();
535     if (!oldFill)
536       return failure();
537 
538     Location loc = oldFill.getLoc();
539     auto newInit = rewriter.create<TensorReshapeOp>(
540         loc, reshapeOp.getResultType(), oldFill.output(),
541         reshapeOp.reassociation());
542     rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, oldFill.value(), newInit);
543 
544     return success();
545   }
546 };
547 
548 } // namespace
549 
550 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
551                                          MLIRContext *context) {
552   results.add<FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
553               FoldFillWithTensorReshape<tensor::ExpandShapeOp>>(context);
554 }
555 
556 //===----------------------------------------------------------------------===//
557 // GenericOps
558 //===----------------------------------------------------------------------===//
559 void GenericOp::build(
560     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
561     ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
562     ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
563     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
564     ArrayRef<NamedAttribute> attributes) {
565   build(builder, result, resultTensorTypes, inputs, outputs,
566         builder.getAffineMapArrayAttr(indexingMaps),
567         builder.getStrArrayAttr(iteratorTypes),
568         doc.empty() ? StringAttr() : builder.getStringAttr(doc),
569         libraryCall.empty() ? StringAttr()
570                             : builder.getStringAttr(libraryCall));
571   result.addAttributes(attributes);
572   if (!bodyBuild)
573     return;
574 
575   SmallVector<Type, 4> blockArgTypes;
576   for (ValueRange container : {inputs, outputs})
577     for (Value v : container)
578       blockArgTypes.push_back(getElementTypeOrSelf(v));
579 
580   OpBuilder::InsertionGuard guard(builder);
581   auto &region = *result.regions.front();
582   Block *bodyBlock = builder.createBlock(&region, region.end(), blockArgTypes);
583   bodyBuild(builder, result.location, bodyBlock->getArguments());
584 }
585 
586 void GenericOp::build(
587     OpBuilder &builder, OperationState &result, ValueRange inputs,
588     ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
589     ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
590     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
591     ArrayRef<NamedAttribute> attributes) {
592   build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
593         iteratorTypes, doc, libraryCall, bodyBuild, attributes);
594 }
595 
596 void GenericOp::build(
597     OpBuilder &builder, OperationState &result, ValueRange inputs,
598     ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
599     ArrayRef<StringRef> iteratorTypes,
600     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
601     ArrayRef<NamedAttribute> attributes) {
602   build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
603         /*doc=*/"",
604         /*libraryCall=*/"", bodyBuild, attributes);
605 }
606 
607 void GenericOp::build(
608     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
609     ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
610     ArrayRef<StringRef> iteratorTypes,
611     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
612     ArrayRef<NamedAttribute> attributes) {
613   build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
614         iteratorTypes,
615         /*doc=*/"",
616         /*libraryCall=*/"", bodyBuild, attributes);
617 }
618 
619 static void print(OpAsmPrinter &p, GenericOp op) {
620   p << " ";
621 
622   // Print extra attributes.
623   auto genericAttrNames = op.linalgTraitAttrNames();
624 
625   llvm::StringSet<> genericAttrNamesSet;
626   genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
627   SmallVector<NamedAttribute, 8> genericAttrs;
628   for (auto attr : op->getAttrs())
629     if (genericAttrNamesSet.count(attr.getName().strref()) > 0)
630       genericAttrs.push_back(attr);
631   if (!genericAttrs.empty()) {
632     auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs);
633     p << genericDictAttr;
634   }
635 
636   // Printing is shared with named ops, except for the region and attributes
637   printCommonStructuredOpParts(p, op);
638 
639   genericAttrNames.push_back("operand_segment_sizes");
640   genericAttrNamesSet.insert(genericAttrNames.back());
641 
642   bool hasExtraAttrs = false;
643   for (NamedAttribute n : op->getAttrs()) {
644     if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
645       break;
646   }
647   if (hasExtraAttrs) {
648     p << " attrs = ";
649     p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/genericAttrNames);
650   }
651 
652   // Print region.
653   if (!op.region().empty()) {
654     p << ' ';
655     p.printRegion(op.region());
656   }
657 
658   // Print results.
659   printNamedStructuredOpResults(p, op.result_tensors().getTypes());
660 }
661 
662 static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
663   DictionaryAttr dictAttr;
664   // Parse the core linalg traits that must check into a dictAttr.
665   // The name is unimportant as we will overwrite result.attributes.
666   // The core linalg traits must contain the information necessary to pass the
667   // verifier.
668   if (parser.parseAttribute(dictAttr, "_", result.attributes))
669     return failure();
670   result.attributes.assign(dictAttr.getValue().begin(),
671                            dictAttr.getValue().end());
672 
673   // Parsing is shared with named ops, except for the region.
674   SmallVector<Type, 1> inputTypes, outputTypes;
675   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
676     return failure();
677 
678   // Optional attributes may be added.
679   if (succeeded(parser.parseOptionalKeyword("attrs")))
680     if (failed(parser.parseEqual()) ||
681         failed(parser.parseOptionalAttrDict(result.attributes)))
682       return failure();
683 
684   SmallVector<OpAsmParser::OperandType, 8> regionOperands;
685   std::unique_ptr<Region> region = std::make_unique<Region>();
686   SmallVector<Type, 8> operandTypes, regionTypes;
687   if (parser.parseRegion(*region, regionOperands, regionTypes))
688     return failure();
689   result.addRegion(std::move(region));
690 
691   // Generic ops may specify that a subset of its outputs are tensors. Such
692   // outputs are specified in the result type.
693   // TODO: may need to move output parsing before region parsing.
694   // Need to wait for declarative assembly resolution to decide.
695   SmallVector<Type, 1> outputTensorsTypes;
696   if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
697     return failure();
698   result.addTypes(outputTensorsTypes);
699 
700   return success();
701 }
702 
703 static void getGenericEffectsImpl(
704     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
705         &effects,
706     ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
707   for (Value value : results) {
708     effects.emplace_back(MemoryEffects::Allocate::get(), value,
709                          SideEffects::DefaultResource::get());
710   }
711   for (Value value : inputBuffers) {
712     effects.emplace_back(MemoryEffects::Read::get(), value,
713                          SideEffects::DefaultResource::get());
714   }
715   for (Value value : outputs) {
716     effects.emplace_back(MemoryEffects::Read::get(), value,
717                          SideEffects::DefaultResource::get());
718     effects.emplace_back(MemoryEffects::Write::get(), value,
719                          SideEffects::DefaultResource::get());
720   }
721 }
722 
723 void GenericOp::getEffects(
724     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
725         &effects) {
726   SmallVector<Value> inputBuffers = getInputBufferOperands();
727   SmallVector<Value> outputBuffers = getOutputBufferOperands();
728   getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
729                         outputBuffers);
730 }
731 
732 template <typename GenericOpType>
733 static LogicalResult verifyGenericOp(GenericOpType op) {
734   return success();
735 }
736 
737 static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
738 
739 namespace {
740 // Deduplicate redundant args of a linalg generic op.
741 // An arg is redundant if it has the same Value and indexing map as another.
742 struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> {
743   using OpRewritePattern<GenericOp>::OpRewritePattern;
744 
745   LogicalResult matchAndRewrite(GenericOp genericOp,
746                                 PatternRewriter &rewriter) const override {
747     // Associate each input to an equivalent "canonical" input that has the same
748     // Value and indexing map.
749     //
750     // In the non-duplicate case, input `i` will have canonical input `i`. But
751     // in the case of duplicated inputs, the canonical input could be some other
752     // input `< i`. That is, a later input will have some earlier input as its
753     // canonical input.
754     llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
755     // For later remapping tasks like deduplicating payload block arguments,
756     // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
757     // convenient.
758     SmallVector<unsigned> canonicalInputIndices;
759     for (OpOperand *opOperand : genericOp.getInputOperands()) {
760       AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
761       // STL-like maps have a convenient behavior for our use case here. In the
762       // case of duplicate keys, the insertion is rejected, and the returned
763       // iterator gives access to the value already in the map.
764       auto pair = canonicalInput.insert(
765           {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
766       canonicalInputIndices.push_back(pair.first->second);
767     }
768 
769     // If there are no duplicate args, then bail out.
770     if (canonicalInput.size() == genericOp.getNumInputs())
771       return failure();
772 
773     // The operands for the newly canonicalized op.
774     SmallVector<Value> newInputOperands;
775     for (OpOperand *opOperand : genericOp.getInputOperands())
776       if (canonicalInputIndices[opOperand->getOperandNumber()] ==
777           opOperand->getOperandNumber())
778         newInputOperands.push_back(opOperand->get());
779 
780     // Repair the indexing maps by filtering out the ones that have been
781     // eliminated.
782     SmallVector<AffineMap> newIndexingMaps;
783     for (OpOperand *opOperand : genericOp.getInputOperands())
784       if (canonicalInputIndices[opOperand->getOperandNumber()] ==
785           opOperand->getOperandNumber())
786         newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
787     for (OpOperand *opOperand : genericOp.getOutputOperands())
788       newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
789 
790     // Clone the old op with new operands.
791     SmallVector<Value> outputOperands = genericOp.getOutputOperands();
792     auto newOp = rewriter.create<GenericOp>(
793         genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
794         outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
795         genericOp.iterator_types(), genericOp.docAttr(),
796         genericOp.library_callAttr());
797 
798     // Copy over unknown attributes. They might be load bearing for some flow.
799     ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
800     for (NamedAttribute kv : genericOp->getAttrs()) {
801       if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) {
802         newOp->setAttr(kv.getName(), kv.getValue());
803       }
804     }
805 
806     rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
807                                 newOp.region().begin());
808 
809     // Repair the payload entry block by RAUW'ing redundant arguments and
810     // erasing them.
811     Block &payload = newOp.region().front();
812     SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
813     for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
814       // Iterate in reverse, so that we erase later args first, preventing the
815       // argument list from shifting unexpectedly and invalidating all our
816       // indices.
817       unsigned operandNumber = opOperand->getOperandNumber();
818       if (canonicalInputIndices[operandNumber] == operandNumber)
819         continue;
820       payload.getArgument(operandNumber)
821           .replaceAllUsesWith(
822               payload.getArgument(canonicalInputIndices[operandNumber]));
823       payload.eraseArgument(operandNumber);
824     }
825 
826     rewriter.replaceOp(genericOp, newOp->getResults());
827     return success();
828   }
829 };
830 
831 /// Remove generic operations (on tensors) that are just copying
832 /// the values from inputs to the results. Requirements are
833 /// 1) All iterator types are parallel
834 /// 2) The body contains just a yield operation with the yielded values being
835 ///    the arguments corresponding to the operands.
836 struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
837   using OpRewritePattern<GenericOp>::OpRewritePattern;
838 
839   LogicalResult matchAndRewrite(GenericOp genericOp,
840                                 PatternRewriter &rewriter) const override {
841     if (!genericOp.hasTensorSemantics())
842       return failure();
843     // Check all indexing maps are identity.
844     if (llvm::any_of(genericOp.getIndexingMaps(),
845                      [](AffineMap map) { return !map.isIdentity(); }))
846       return failure();
847 
848     // Check that the body of the linalg operation is just a linalg.yield
849     // operation.
850     Block &body = genericOp.region().front();
851     if (!llvm::hasSingleElement(body))
852       return failure();
853     auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
854     if (!yieldOp)
855       return failure();
856 
857     // Get the argument number of the returned values. That is the operand
858     // number to use for replacing uses of this operation.
859     SmallVector<Value> returnedArgs;
860     for (Value yieldVal : yieldOp.values()) {
861       auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
862       if (!yieldArg || yieldArg.getOwner() != &body)
863         return failure();
864       unsigned argumentNumber = yieldArg.getArgNumber();
865       returnedArgs.push_back(genericOp->getOperand(argumentNumber));
866     }
867     if (returnedArgs.size() != genericOp->getNumResults())
868       return failure();
869     rewriter.replaceOp(genericOp, returnedArgs);
870     return success();
871   }
872 };
873 } // namespace
874 
875 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
876                                             MLIRContext *context) {
877   results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
878 }
879 
880 //===----------------------------------------------------------------------===//
881 // InitTensorOp
882 //===----------------------------------------------------------------------===//
883 
884 void InitTensorOp::build(OpBuilder &b, OperationState &result,
885                          ArrayRef<OpFoldResult> sizes, Type elementType,
886                          ArrayRef<NamedAttribute> attrs) {
887   SmallVector<Value, 4> dynamicSizes;
888   SmallVector<int64_t, 4> staticSizes;
889   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
890                              ShapedType::kDynamicSize);
891   auto resultType = RankedTensorType ::get(staticSizes, elementType);
892   build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes));
893   result.addAttributes(attrs);
894 }
895 
896 static LogicalResult verify(InitTensorOp op) {
897   RankedTensorType resultType = op.getType();
898   SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
899       op.static_sizes().cast<ArrayAttr>(),
900       [](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); }));
901 
902   if (failed(verifyListOfOperandsOrIntegers(op, "sizes", resultType.getRank(),
903                                             op.static_sizes(), op.sizes(),
904                                             ShapedType::isDynamic)))
905     return failure();
906 
907   if (op.static_sizes().size() != static_cast<unsigned>(resultType.getRank()))
908     return op->emitError("expected ")
909            << resultType.getRank() << " sizes values";
910 
911   Type expectedType = InitTensorOp::inferResultType(
912       staticSizes, resultType.getElementType(), resultType.getEncoding());
913   if (resultType != expectedType) {
914     return op.emitError("specified type ")
915            << resultType << " does not match the inferred type "
916            << expectedType;
917   }
918   return success();
919 }
920 
921 Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
922                                    Type elementType, Attribute encoding) {
923   return RankedTensorType::get(staticSizes, elementType, encoding);
924 }
925 
926 namespace {
927 /// Change the type of the result of a `linalg.init_tensor` by making the result
928 /// type statically sized along dimension that in the original operation where
929 /// defined as dynamic, but the size was defined using a `constant` op. For
930 /// example
931 ///
932 ///  %c5 = arith.constant 5: index
933 ///  %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
934 ///
935 ///  to
936 ///
937 ///  %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
938 struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
939   using OpRewritePattern<InitTensorOp>::OpRewritePattern;
940 
941   LogicalResult matchAndRewrite(InitTensorOp op,
942                                 PatternRewriter &rewriter) const override {
943     SmallVector<Value, 4> dynamicSizes;
944     SmallVector<int64_t, 4> staticSizes;
945     for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
946       // If the size is already static, nothing to do.
947       if (!op.isDynamicSize(i)) {
948         staticSizes.push_back(op.getStaticSize(i));
949         continue;
950       }
951 
952       // If the size is dynamic but defined using a `constant` op, get the
953       // constant value to find the static size to use.
954       unsigned operandNum = op.getIndexOfDynamicSize(i);
955       Value sizeOperand = op.getOperand(operandNum);
956       if (auto constantIndexOp =
957               sizeOperand.getDefiningOp<arith::ConstantIndexOp>()) {
958         staticSizes.push_back(constantIndexOp.value());
959         continue;
960       }
961 
962       // Fallback case. Keep the size dynamic.
963       dynamicSizes.push_back(sizeOperand);
964       staticSizes.push_back(ShapedType::kDynamicSize);
965     }
966     RankedTensorType newType =
967         RankedTensorType::get(staticSizes, op.getType().getElementType());
968     if (newType == op.getType())
969       return failure();
970     auto newOp =
971         rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
972                                       rewriter.getI64ArrayAttr(staticSizes));
973     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
974     return success();
975   }
976 };
977 } // namespace
978 
979 namespace {
980 /// Since `init_tensor` operation creates a tensor needed only for its shape, a
981 /// slice of this is also needed only for its shape. The result can be
982 /// replaced by a new init_tensor operation of the same size as the extract
983 /// slice op.
984 struct FoldInitTensorWithExtractSliceOp
985     : public OpRewritePattern<tensor::ExtractSliceOp> {
986   using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
987 
988   LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
989                                 PatternRewriter &rewriter) const override {
990     if (!sliceOp.source().getDefiningOp<linalg::InitTensorOp>())
991       return failure();
992     // ExtractSliceOp may be rank-reducing; its dynamic sizes must be preserved
993     // as well as its result type.
994     rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
995         sliceOp, sliceOp.sizes(),
996         sliceOp.result().getType().cast<RankedTensorType>().getShape(),
997         sliceOp.getSourceType().getElementType());
998     return success();
999   }
1000 };
1001 
1002 template <typename TensorReshapeOp>
1003 struct FoldInitTensorWithTensorReshapeOp
1004     : public OpRewritePattern<TensorReshapeOp> {
1005   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1006 
1007   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1008                                 PatternRewriter &rewriter) const override {
1009     if (!reshapeOp.src().template getDefiningOp<InitTensorOp>())
1010       return failure();
1011     Location loc = reshapeOp.getLoc();
1012     ReifiedRankedShapedTypeDims resultShapes;
1013     ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
1014         cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
1015     if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
1016                                                           resultShapes)) ||
1017         !llvm::hasSingleElement(resultShapes))
1018       return failure();
1019     Value initTensor = rewriter.create<InitTensorOp>(
1020         loc, getAsOpFoldResult(resultShapes[0]),
1021         reshapeOp.getResultType().getElementType());
1022     if (initTensor.getType() != reshapeOp.getResultType()) {
1023       rewriter.replaceOpWithNewOp<tensor::CastOp>(
1024           reshapeOp, reshapeOp.getResultType(), initTensor);
1025     } else {
1026       rewriter.replaceOp(reshapeOp, initTensor);
1027     }
1028     return success();
1029   }
1030 };
1031 
1032 struct FoldInitTensorWithDimOp : public OpRewritePattern<tensor::DimOp> {
1033   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1034 
1035   LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1036                                 PatternRewriter &rewriter) const override {
1037     Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1038     auto initTensorOp = dimOp.source().getDefiningOp<linalg::InitTensorOp>();
1039     if (!initTensorOp || !maybeConstantIndex)
1040       return failure();
1041     if (!initTensorOp.isDynamicSize(*maybeConstantIndex))
1042       return failure();
1043     rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(*maybeConstantIndex));
1044     return success();
1045   }
1046 };
1047 } // namespace
1048 
1049 void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
1050                                                MLIRContext *context) {
1051   results.add<FoldInitTensorWithDimOp, FoldInitTensorWithExtractSliceOp,
1052               FoldInitTensorWithTensorReshapeOp<tensor::ExpandShapeOp>,
1053               FoldInitTensorWithTensorReshapeOp<tensor::CollapseShapeOp>,
1054               ReplaceStaticShapeDims>(context);
1055 }
1056 
1057 LogicalResult InitTensorOp::reifyResultShapes(
1058     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1059   auto shapes = llvm::to_vector<4>(llvm::map_range(
1060       llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
1061         if (isDynamicSize(dim))
1062           return getDynamicSize(dim);
1063         return builder.create<arith::ConstantIndexOp>(getLoc(),
1064                                                       getStaticSize(dim));
1065       }));
1066   reifiedReturnShapes.emplace_back(std::move(shapes));
1067   return success();
1068 }
1069 
1070 //===----------------------------------------------------------------------===//
1071 // PadTensorOp
1072 //===----------------------------------------------------------------------===//
1073 
1074 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
1075 // supports optional types.
1076 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
1077                     Type typeToInfer, Type typeToInferFrom) {}
1078 
1079 ParseResult parseInferType(OpAsmParser &parser,
1080                            Optional<OpAsmParser::OperandType> optOperand,
1081                            Type &typeToInfer, Type typeToInferFrom) {
1082   if (optOperand)
1083     typeToInfer = typeToInferFrom;
1084   return success();
1085 }
1086 
1087 static LogicalResult verify(PadTensorOp op) {
1088   auto sourceType = op.source().getType().cast<RankedTensorType>();
1089   auto resultType = op.result().getType().cast<RankedTensorType>();
1090   auto expectedType = PadTensorOp::inferResultType(
1091       sourceType, extractFromI64ArrayAttr(op.static_low()),
1092       extractFromI64ArrayAttr(op.static_high()));
1093   for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
1094     if (resultType.getDimSize(i) == expectedType.getDimSize(i))
1095       continue;
1096     if (expectedType.isDynamicDim(i))
1097       continue;
1098     return op.emitError("specified type ")
1099            << resultType << " does not match the inferred type "
1100            << expectedType;
1101   }
1102 
1103   auto &region = op.region();
1104   unsigned rank = resultType.getRank();
1105   Block &block = region.front();
1106   if (block.getNumArguments() != rank)
1107     return op.emitError("expected the block to have ") << rank << " arguments";
1108 
1109   // Note: the number and type of yield values are checked in the YieldOp.
1110   for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
1111     if (!en.value().isIndex())
1112       return op.emitOpError("expected block argument ")
1113              << (en.index() + 1) << " to be an index";
1114   }
1115 
1116   return success();
1117 }
1118 
1119 RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
1120                                               ArrayRef<int64_t> staticLow,
1121                                               ArrayRef<int64_t> staticHigh,
1122                                               ArrayRef<int64_t> resultShape) {
1123   unsigned rank = sourceType.getRank();
1124   assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
1125   assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
1126   assert((resultShape.empty() || resultShape.size() == rank) &&
1127          "unexpected resultShape size mismatch");
1128 
1129   SmallVector<int64_t, 4> inferredShape;
1130   for (auto i : llvm::seq<unsigned>(0, rank)) {
1131     if (sourceType.isDynamicDim(i) ||
1132         staticLow[i] == ShapedType::kDynamicSize ||
1133         staticHigh[i] == ShapedType::kDynamicSize) {
1134       inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamicSize
1135                                                   : resultShape[i]);
1136     } else {
1137       int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
1138       assert((resultShape.empty() || size == resultShape[i] ||
1139               resultShape[i] == ShapedType::kDynamicSize) &&
1140              "mismatch between inferred shape and result shape");
1141       inferredShape.push_back(size);
1142     }
1143   }
1144 
1145   return RankedTensorType::get(inferredShape, sourceType.getElementType());
1146 }
1147 
1148 void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
1149                         ArrayRef<int64_t> staticLow,
1150                         ArrayRef<int64_t> staticHigh, ValueRange low,
1151                         ValueRange high, bool nofold,
1152                         ArrayRef<NamedAttribute> attrs) {
1153   auto sourceType = source.getType().cast<RankedTensorType>();
1154   auto resultType = inferResultType(sourceType, staticLow, staticHigh);
1155   build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
1156         b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr());
1157   result.addAttributes(attrs);
1158 }
1159 
1160 void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
1161                         ValueRange low, ValueRange high, bool nofold,
1162                         ArrayRef<NamedAttribute> attrs) {
1163   auto sourceType = source.getType().cast<RankedTensorType>();
1164   unsigned rank = sourceType.getRank();
1165   SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
1166   build(b, result, source, staticVector, staticVector, low, high, nofold,
1167         attrs);
1168 }
1169 
1170 void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
1171                         Value source, ArrayRef<OpFoldResult> low,
1172                         ArrayRef<OpFoldResult> high, bool nofold,
1173                         ArrayRef<NamedAttribute> attrs) {
1174   assert(resultType.isa<RankedTensorType>());
1175   auto sourceType = source.getType().cast<RankedTensorType>();
1176   SmallVector<Value, 4> dynamicLow, dynamicHigh;
1177   SmallVector<int64_t, 4> staticLow, staticHigh;
1178   // staticLow and staticHigh have full information of the padding config.
1179   // This will grow staticLow and staticHigh with 1 value. If the config is
1180   // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
1181   // value as well.
1182   dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
1183                              ShapedType::kDynamicSize);
1184   dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
1185                              ShapedType::kDynamicSize);
1186   if (!resultType) {
1187     resultType =
1188         PadTensorOp::inferResultType(sourceType, staticLow, staticHigh);
1189   }
1190   build(b, result, resultType, source, dynamicLow, dynamicHigh,
1191         b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
1192         nofold ? b.getUnitAttr() : UnitAttr());
1193   result.addAttributes(attrs);
1194 }
1195 
1196 PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad,
1197                                            ArrayRef<OpFoldResult> low,
1198                                            ArrayRef<OpFoldResult> high,
1199                                            bool nofold, Location loc,
1200                                            OpBuilder &builder) {
1201   auto padTensorOp =
1202       builder.create<linalg::PadTensorOp>(loc, type, source, low, high, nofold);
1203   int rank = padTensorOp.getResultType().getRank();
1204   SmallVector<Type, 4> blockArgTypes;
1205   blockArgTypes.assign(rank, builder.getIndexType());
1206   auto &region = padTensorOp.region();
1207   // `builder.createBlock` changes the insertion point within the block. Create
1208   // a guard to reset the insertion point of the builder after it is destroyed.
1209   OpBuilder::InsertionGuard guard(builder);
1210   builder.createBlock(&region, region.end(), blockArgTypes);
1211   builder.create<linalg::YieldOp>(loc, pad);
1212   return padTensorOp;
1213 }
1214 
1215 PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
1216                                          bool nofold, Location loc,
1217                                          OpBuilder &b) {
1218   SmallVector<OpFoldResult, 4> low, high;
1219   auto rankedTensorType = type.cast<RankedTensorType>();
1220   assert(rankedTensorType.hasStaticShape());
1221   for (const auto &en : enumerate(rankedTensorType.getShape())) {
1222     AffineExpr d0;
1223     bindDims(b.getContext(), d0);
1224     auto dimOp = b.createOrFold<tensor::DimOp>(loc, source, en.index());
1225     Value paddingWidth =
1226         makeComposedAffineApply(b, loc, en.value() - d0, {dimOp});
1227     high.push_back(paddingWidth);
1228     low.push_back(b.createOrFold<arith::ConstantIndexOp>(loc, 0));
1229   }
1230   return PadTensorOp::createPadScalarOp(type, source, pad, low, high, nofold,
1231                                         loc, b);
1232 }
1233 
1234 LogicalResult PadTensorOp::reifyResultShapes(
1235     OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1236   Location loc = getLoc();
1237   auto lowPad = getMixedLowPad();
1238   auto highPad = getMixedHighPad();
1239   SmallVector<Value> shapes;
1240   for (auto dim : llvm::seq<int64_t>(0, getSourceType().getRank())) {
1241     // Shape along each dimension is source dim + low pad + high pad.
1242     SmallVector<Value> mapOperands;
1243     mapOperands.push_back(b.createOrFold<tensor::DimOp>(loc, source(), dim));
1244     AffineExpr expr = b.getAffineDimExpr(0);
1245     unsigned numSymbols = 0;
1246     auto addOpFoldResult = [&](OpFoldResult valueOrAttr) {
1247       if (Value v = valueOrAttr.dyn_cast<Value>()) {
1248         expr = expr + b.getAffineSymbolExpr(numSymbols++);
1249         mapOperands.push_back(v);
1250         return;
1251       }
1252       int64_t staticValue =
1253           valueOrAttr.get<Attribute>().cast<IntegerAttr>().getInt();
1254       expr = expr + staticValue;
1255     };
1256     addOpFoldResult(lowPad[dim]);
1257     addOpFoldResult(highPad[dim]);
1258     shapes.push_back(applyMapToValues(
1259         b, loc, AffineMap::get(1, numSymbols, expr), mapOperands)[0]);
1260   }
1261   reifiedReturnShapes.emplace_back(std::move(shapes));
1262   return success();
1263 }
1264 
1265 //===----------------------------------------------------------------------===//
1266 // Methods related to PadTensor tiling.
1267 //===----------------------------------------------------------------------===//
1268 
1269 SmallVector<Value> PadTensorOp::getDestinationOperands(OpBuilder &b) {
1270   ReifiedRankedShapedTypeDims reifiedShapes;
1271   (void)reifyResultShapes(b, reifiedShapes);
1272   SmallVector<OpFoldResult> mixedSizes = getAsOpFoldResult(reifiedShapes[0]);
1273   Value initTensor = b.create<InitTensorOp>(getLoc(), mixedSizes,
1274                                             getResultType().getElementType());
1275   return {initTensor};
1276 }
1277 
1278 SmallVector<StringRef> PadTensorOp::getLoopIteratorTypes() {
1279   SmallVector<StringRef> iteratorTypes(getResultType().getRank(),
1280                                        getParallelIteratorTypeName());
1281   return iteratorTypes;
1282 }
1283 
1284 SmallVector<Range> PadTensorOp::getIterationDomain(OpBuilder &b) {
1285   ReifiedRankedShapedTypeDims reifiedShapes;
1286   (void)reifyResultShapes(b, reifiedShapes);
1287   Value zero = b.create<arith::ConstantIndexOp>(getLoc(), 0);
1288   Value one = b.create<arith::ConstantIndexOp>(getLoc(), 1);
1289   // Initialize all the ranges to {zero, one, one}. All the `ub`s are
1290   // overwritten.
1291   SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
1292   for (const auto &ub : enumerate(reifiedShapes[0]))
1293     loopRanges[ub.index()].size = ub.value();
1294   return loopRanges;
1295 }
1296 
1297 SmallVector<Operation *> PadTensorOp::getTiledImplementation(
1298     OpBuilder &b, ValueRange dest, ArrayRef<OpFoldResult> offsets,
1299     ArrayRef<OpFoldResult> sizes, bool /*tileDestOperands*/) {
1300   // Only constant padding value supported.
1301   Value padValue = getConstantPaddingValue();
1302   if (!padValue)
1303     return {};
1304 
1305   // Helper variables and functions for various arithmetic operations. These are
1306   // used extensively for computing new offset/length and padding values.
1307   Location loc = getLoc();
1308   AffineExpr dim0, dim1;
1309   bindDims(b.getContext(), dim0, dim1);
1310   // Add two integers.
1311   auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
1312   auto add = [&](Value v1, Value v2) {
1313     return b.createOrFold<AffineApplyOp>(loc, addMap, ValueRange{v1, v2});
1314   };
1315   // Subtract two integers.
1316   auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
1317   auto sub = [&](Value v1, Value v2) {
1318     return b.createOrFold<AffineApplyOp>(loc, subMap, ValueRange{v1, v2});
1319   };
1320   // Take the minimum of two integers.
1321   auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
1322   auto min = [&](Value v1, Value v2) {
1323     return b.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
1324   };
1325   // Take the maximum of two integers.
1326   auto max = [&](Value v1, Value v2) {
1327     return b.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
1328   };
1329   // Zero index-typed integer.
1330   auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
1331 
1332   // Helper function for filling static/dynamic low/high padding indices vectors
1333   // of PadTensorOp.
1334   auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
1335                          SmallVector<int64_t> &staticIndices) {
1336     if (auto constInt = getConstantIntValue(val)) {
1337       staticIndices.push_back(*constInt);
1338     } else {
1339       staticIndices.push_back(ShapedType::kDynamicSize);
1340       dynIndices.push_back(val);
1341     }
1342   };
1343 
1344   // Compute new offsets, lengths, low padding, high padding.
1345   SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
1346   SmallVector<Value> newLows, newHighs;
1347   SmallVector<int64_t> staticNewLows, staticNewHighs;
1348   // Set to true if the original data source is not read at all.
1349   bool hasZeroLen = false;
1350   // Same as hasZeroLen, but for dynamic dimension sizes. This condition
1351   // is true if the original data source turns out to be unused at runtime.
1352   Value dynHasZeroLenCond;
1353 
1354   int64_t rank = getSourceType().getRank();
1355   for (unsigned dim = 0; dim < rank; ++dim) {
1356     auto low = getValueOrCreateConstantIndexOp(b, loc, getMixedLowPad()[dim]);
1357     bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0);
1358     auto high = getValueOrCreateConstantIndexOp(b, loc, getMixedHighPad()[dim]);
1359     bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0);
1360     auto offset = getValueOrCreateConstantIndexOp(b, loc, offsets[dim]);
1361     auto length = getValueOrCreateConstantIndexOp(b, loc, sizes[dim]);
1362     auto srcSize = b.createOrFold<tensor::DimOp>(loc, source(), dim);
1363 
1364     // The new amount of low padding is `low - offset`. Except for the case
1365     // where none of the low padding is read. In that case, the new amount of
1366     // low padding is zero.
1367     //
1368     // Optimization: If low = 0, then newLow = 0.
1369     Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
1370     appendIndex(newLow, newLows, staticNewLows);
1371 
1372     // Start reading the data from position `offset - low`. Since the original
1373     // read may have started in the low padding zone, this value could be
1374     // negative. Therefore, start reading from:
1375     //
1376     // max(offset - low, 0)
1377     //
1378     // The original read could also have started in the high padding zone.
1379     // In that case, set the offset to the end of source tensor. The new
1380     // ExtractSliceOp length will be zero in that case. (Effectively reading no
1381     // data from the source.)
1382     //
1383     // Optimization: If low = 0, then the formula can be simplified.
1384     Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize)
1385                                 : min(offset, srcSize);
1386     newOffsets.push_back(getAsOpFoldResult(newOffset));
1387 
1388     // The original ExtractSliceOp was reading until position `offset + length`.
1389     // Therefore, the corresponding position within the source tensor is:
1390     //
1391     // offset + length - low
1392     //
1393     // In case the original ExtractSliceOp stopped reading within the low
1394     // padding zone, this value can be negative. In that case, the end position
1395     // of the read should be zero. (Similar to newOffset.)
1396     //
1397     // The original read could also have stopped in the high padding zone.
1398     // In that case, set the end positition of the read should be the end of the
1399     // source tensor. (Similar to newOffset.)
1400     //
1401     // endLoc = min(max(offset - low + length, 0), srcSize)
1402     //
1403     // The new ExtractSliceOp length is `endLoc - newOffset`.
1404     //
1405     // Optimization: If low = 0, then the formula can be simplified.
1406     Value endLoc = hasLowPad
1407                        ? min(max(add(sub(offset, low), length), zero), srcSize)
1408                        : min(add(offset, length), srcSize);
1409     Value newLength = sub(endLoc, newOffset);
1410     newLengths.push_back(getAsOpFoldResult(newLength));
1411 
1412     // Check if newLength is zero. In that case, no SubTensorOp should be
1413     // executed.
1414     if (auto newLengthInt = getConstantIntValue(newLength)) {
1415       hasZeroLen |= *newLengthInt == 0;
1416     } else {
1417       Value check = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
1418                                             newLength, zero);
1419       dynHasZeroLenCond =
1420           dynHasZeroLenCond
1421               ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
1422               : check;
1423     }
1424 
1425     // The amount of high padding is simply the number of elements remaining,
1426     // so that the result has the same length as the original ExtractSliceOp.
1427     // As an optimization, if the original high padding is zero, then the new
1428     // high padding must also be zero.
1429     Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
1430     appendIndex(newHigh, newHighs, staticNewHighs);
1431 
1432     // Only unit stride supported.
1433     newStrides.push_back(b.getIndexAttr(1));
1434   }
1435 
1436   // The shape of the result can be obtained from the sizes passed in.
1437   SmallVector<Value> dynDims;
1438   SmallVector<int64_t> shape;
1439   dispatchIndexOpFoldResults(sizes, dynDims, shape, ShapedType::kDynamicSize);
1440   RankedTensorType resultType =
1441       RankedTensorType::get(shape, getResultType().getElementType());
1442 
1443   // Insert cast to ensure that types match. (May be folded away.)
1444   auto castResult = [&](Value val) -> Operation * {
1445     auto castOp = b.create<tensor::CastOp>(loc, resultType, val);
1446     return castOp;
1447   };
1448 
1449   // In cases where the original data source is unused: Emit a GenerateOp and
1450   // do not generate a SliceOp. (The result shape of the SliceOp would
1451   // have a dimension of size 0, the semantics of which is unclear.)
1452   auto createGenerateOp = [&]() {
1453     // Create GenerateOp.
1454     auto generateOp = b.create<tensor::GenerateOp>(
1455         loc, resultType, dynDims,
1456         [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
1457           builder.create<tensor::YieldOp>(gLoc, padValue);
1458         });
1459     return castResult(generateOp);
1460   };
1461 
1462   // Emit a SliceOp and a PadTensorOp. Should not be used in cases where
1463   // the result shape of the new SliceOp has a zero dimension.
1464   auto createPadTensorOfSubTensor = [&]() {
1465     // Create pad_tensor(subtensor(x)).
1466     auto newSliceOp = b.create<tensor::ExtractSliceOp>(
1467         loc, source(), newOffsets, newLengths, newStrides);
1468     auto newPadTensorOp = b.create<PadTensorOp>(
1469         loc, newSliceOp, staticNewLows, staticNewHighs, newLows, newHighs);
1470 
1471     // Copy region to new PadTensorOp.
1472     BlockAndValueMapping bvm;
1473     region().cloneInto(&newPadTensorOp.getRegion(), bvm);
1474 
1475     // Cast result and return.
1476     return castResult(newPadTensorOp);
1477   };
1478 
1479   // Rewrite subtensor(pad_tensor(x)) into a GenerateOp it is statically known
1480   // that the original data source x is not used.
1481   if (hasZeroLen) {
1482     return {createGenerateOp()};
1483   }
1484 
1485   // If there are dynamic dimensions: Generate an scf.if check to avoid creating
1486   // SliceOps with result dimensions of size 0 at runtime.
1487   if (dynHasZeroLenCond) {
1488     auto result = b.create<scf::IfOp>(
1489         loc, resultType, dynHasZeroLenCond,
1490         /*thenBuilder=*/
1491         [&](OpBuilder &b, Location loc) {
1492           b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0));
1493         },
1494         /*elseBuilder=*/
1495         [&](OpBuilder &b, Location loc) {
1496           b.create<scf::YieldOp>(loc,
1497                                  createPadTensorOfSubTensor()->getResult(0));
1498         });
1499     return {result};
1500   }
1501   return {createPadTensorOfSubTensor()};
1502 }
1503 
1504 namespace {
1505 // Folds linalg.pad_tensor when padding is static zeros and the attribute
1506 // doesn't request otherwise.
1507 struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> {
1508   using OpRewritePattern<PadTensorOp>::OpRewritePattern;
1509 
1510   LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
1511                                 PatternRewriter &rewriter) const override {
1512     if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
1513       return failure();
1514     if (padTensorOp.nofold())
1515       return failure();
1516     rewriter.replaceOpWithNewOp<tensor::CastOp>(
1517         padTensorOp, padTensorOp.result().getType(), padTensorOp.source());
1518     return success();
1519   }
1520 };
1521 
1522 // Fold CastOp into PadTensorOp when adding static information.
1523 struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
1524   using OpRewritePattern<PadTensorOp>::OpRewritePattern;
1525 
1526   LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
1527                                 PatternRewriter &rewriter) const override {
1528     auto castOp = padTensorOp.source().getDefiningOp<tensor::CastOp>();
1529     if (!tensor::canFoldIntoConsumerOp(castOp))
1530       return failure();
1531 
1532     auto newResultType = PadTensorOp::inferResultType(
1533         castOp.source().getType().cast<RankedTensorType>(),
1534         extractFromI64ArrayAttr(padTensorOp.static_low()),
1535         extractFromI64ArrayAttr(padTensorOp.static_high()),
1536         padTensorOp.getResultType().getShape());
1537 
1538     if (newResultType == padTensorOp.getResultType()) {
1539       rewriter.updateRootInPlace(padTensorOp, [&]() {
1540         padTensorOp.sourceMutable().assign(castOp.source());
1541       });
1542     } else {
1543       auto newOp = rewriter.create<PadTensorOp>(
1544           padTensorOp->getLoc(), newResultType, padTensorOp.source(),
1545           padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(),
1546           padTensorOp.static_high(), padTensorOp.nofold());
1547       BlockAndValueMapping mapper;
1548       padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
1549 
1550       rewriter.replaceOpWithNewOp<tensor::CastOp>(
1551           padTensorOp, padTensorOp.getResultType(), newOp);
1552     }
1553     return success();
1554   }
1555 };
1556 
1557 // Fold CastOp using the result of PadTensorOp back into the latter if it adds
1558 // static information.
1559 struct FoldTargetTensorCast : public OpRewritePattern<PadTensorOp> {
1560   using OpRewritePattern<PadTensorOp>::OpRewritePattern;
1561 
1562   LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
1563                                 PatternRewriter &rewriter) const override {
1564     if (!padTensorOp.result().hasOneUse())
1565       return failure();
1566     auto tensorCastOp =
1567         dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
1568     if (!tensorCastOp)
1569       return failure();
1570     if (!tensor::preservesStaticInformation(padTensorOp.result().getType(),
1571                                             tensorCastOp.dest().getType()))
1572       return failure();
1573 
1574     auto replacementOp = rewriter.create<PadTensorOp>(
1575         padTensorOp.getLoc(), tensorCastOp.dest().getType(),
1576         padTensorOp.source(), padTensorOp.low(), padTensorOp.high(),
1577         padTensorOp.static_low(), padTensorOp.static_high(),
1578         padTensorOp.nofold());
1579     replacementOp.region().takeBody(padTensorOp.region());
1580 
1581     rewriter.replaceOp(padTensorOp, replacementOp.result());
1582     rewriter.replaceOp(tensorCastOp, replacementOp.result());
1583     return success();
1584   }
1585 };
1586 } // namespace
1587 
1588 void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
1589                                               MLIRContext *context) {
1590   results.add<FoldStaticZeroPadding, FoldSourceTensorCast>(context);
1591   results.add<FoldTargetTensorCast>(context);
1592 }
1593 
1594 /// Return the padding value of the PadTensorOp if it constant. In this context,
1595 /// "constant" means an actual constant or "defined outside of the block".
1596 ///
1597 /// Values are considered constant in three cases:
1598 ///  - A ConstantLike value.
1599 ///  - A basic block argument from a different block.
1600 ///  - A value defined outside of the block.
1601 ///
1602 /// If the padding value is not constant, an empty Value is returned.
1603 Value PadTensorOp::getConstantPaddingValue() {
1604   auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
1605   if (!yieldOp || yieldOp.values().size() != 1)
1606     return {};
1607   Value padValue = yieldOp.values().front();
1608   // Check if yield value is a constant.
1609   if (matchPattern(padValue, m_Constant()))
1610     return padValue;
1611   // Check if yield value is defined inside the PadTensorOp block.
1612   if (padValue.getParentBlock() == &getRegion().front())
1613     return {};
1614   // Else: Yield value defined outside of the PadTensorOp block.
1615   return padValue;
1616 }
1617 
1618 OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) {
1619   if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
1620       !nofold())
1621     return source();
1622   return {};
1623 }
1624 
1625 //===----------------------------------------------------------------------===//
1626 // YieldOp
1627 //===----------------------------------------------------------------------===//
1628 
1629 static void print(OpAsmPrinter &p, linalg::YieldOp op) {
1630   if (op.getNumOperands() > 0)
1631     p << ' ' << op.getOperands();
1632   p.printOptionalAttrDict(op->getAttrs());
1633   if (op.getNumOperands() > 0)
1634     p << " : " << op.getOperandTypes();
1635 }
1636 
1637 static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
1638   SmallVector<OpAsmParser::OperandType, 2> opInfo;
1639   SmallVector<Type, 2> types;
1640   llvm::SMLoc loc = parser.getCurrentLocation();
1641   return failure(parser.parseOperandList(opInfo) ||
1642                  parser.parseOptionalAttrDict(result.attributes) ||
1643                  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
1644                  parser.resolveOperands(opInfo, types, loc, result.operands));
1645 }
1646 
1647 // Check the operand number and types must match the element types of the
1648 // LinalgOp interface's shaped operands.
1649 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
1650   if (op.getNumOperands() != linalgOp.getNumOutputs())
1651     return op.emitOpError("expected number of yield values (")
1652            << linalgOp.getNumOutputs()
1653            << ") to match the number of operands of the enclosing "
1654            << "LinalgOp (" << op.getNumOperands() << ")";
1655 
1656   for (OpOperand &opOperand : op->getOpOperands()) {
1657     OpOperand *outputOperand =
1658         linalgOp.getOutputOperand(opOperand.getOperandNumber());
1659     Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
1660     if (opOperand.get().getType() != elementType)
1661       return op.emitOpError("type of yield operand ")
1662              << (opOperand.getOperandNumber() + 1) << " ("
1663              << opOperand.get().getType() << ") doesn't match "
1664              << "the element type of the enclosing linalg.generic op ("
1665              << elementType << ")";
1666   }
1667   return success();
1668 }
1669 
1670 static LogicalResult verify(linalg::YieldOp op) {
1671   auto *parentOp = op->getParentOp();
1672   if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
1673     return op.emitOpError("expected single non-empty parent region");
1674 
1675   if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
1676     return verifyYield(op, cast<LinalgOp>(parentOp));
1677 
1678   if (auto padTensorOp = dyn_cast<linalg::PadTensorOp>(parentOp)) {
1679     if (op.getNumOperands() != 1)
1680       return op.emitOpError("expected single yield operand (got ")
1681              << op->getNumOperands() << ")";
1682     if (op.getOperand(0).getType() !=
1683         padTensorOp.getType().cast<ShapedType>().getElementType())
1684       return op.emitOpError("expected yield type to match shape element type");
1685     return success();
1686   }
1687 
1688   if (auto tiledLoopOp = dyn_cast<linalg::TiledLoopOp>(parentOp)) {
1689     // Check if output args with tensor types match results types.
1690     SmallVector<Value, 2> tensorOuts;
1691     llvm::copy_if(
1692         tiledLoopOp.outputs(), std::back_inserter(tensorOuts),
1693         [&](Value out) { return out.getType().isa<RankedTensorType>(); });
1694     if (tensorOuts.size() != op.values().size())
1695       return op.emitOpError("expected number of tensor output args = ")
1696              << tensorOuts.size() << " to match the number of yield operands = "
1697              << op.values().size();
1698 
1699     TypeRange tensorTypes(llvm::makeArrayRef(tensorOuts));
1700     for (auto &item :
1701          llvm::enumerate(llvm::zip(tensorTypes, op.getOperandTypes()))) {
1702       Type outType, resultType;
1703       unsigned index = item.index();
1704       std::tie(outType, resultType) = item.value();
1705       if (outType != resultType)
1706         return op.emitOpError("expected yield operand ")
1707                << index << " with type = " << resultType
1708                << " to match output arg type = " << outType;
1709     }
1710     return success();
1711   }
1712   return op.emitOpError("expected parent op with LinalgOp interface");
1713 }
1714 
1715 //===----------------------------------------------------------------------===//
1716 // TiledLoopOp
1717 //===----------------------------------------------------------------------===//
1718 
1719 void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
1720                         ValueRange lowerBounds, ValueRange upperBounds,
1721                         ValueRange steps, ValueRange inputs, ValueRange outputs,
1722                         ArrayAttr iteratorTypes,
1723                         function_ref<void(OpBuilder &, Location, ValueRange,
1724                                           ValueRange, ValueRange)>
1725                             bodyBuilderFn) {
1726   build(builder, result, lowerBounds, upperBounds, steps, inputs, outputs,
1727         iteratorTypes, llvm::None, bodyBuilderFn);
1728 }
1729 
1730 void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
1731                         ValueRange lowerBounds, ValueRange upperBounds,
1732                         ValueRange steps, ValueRange inputs, ValueRange outputs,
1733                         ArrayAttr iteratorTypes,
1734                         Optional<ArrayAttr> distributionTypes,
1735                         function_ref<void(OpBuilder &, Location, ValueRange,
1736                                           ValueRange, ValueRange)>
1737                             bodyBuilderFn) {
1738   result.addOperands(lowerBounds);
1739   result.addOperands(upperBounds);
1740   result.addOperands(steps);
1741   result.addOperands(inputs);
1742   result.addOperands(outputs);
1743   result.addAttribute(
1744       TiledLoopOp::getOperandSegmentSizeAttr(),
1745       builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()),
1746                                 static_cast<int32_t>(upperBounds.size()),
1747                                 static_cast<int32_t>(steps.size()),
1748                                 static_cast<int32_t>(inputs.size()),
1749                                 static_cast<int32_t>(outputs.size())}));
1750   result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
1751 
1752   if (distributionTypes.hasValue())
1753     result.addAttribute(getDistributionTypesAttrName(),
1754                         distributionTypes.getValue());
1755 
1756   // Add output types for `RankedTensorType` output arguments.
1757   for (Value output : outputs) {
1758     Type outputType = output.getType();
1759     if (outputType.isa<RankedTensorType>())
1760       result.addTypes(outputType);
1761   }
1762 
1763   OpBuilder::InsertionGuard guard(builder);
1764   unsigned numIVs = steps.size();
1765   SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
1766   for (Type type : TypeRange(inputs))
1767     argTypes.push_back(type);
1768   for (Type type : TypeRange(outputs))
1769     argTypes.push_back(type);
1770   Region *bodyRegion = result.addRegion();
1771   Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes);
1772 
1773   if (bodyBuilderFn) {
1774     builder.setInsertionPointToStart(bodyBlock);
1775     bodyBuilderFn(builder, result.location,
1776                   bodyBlock->getArguments().take_front(numIVs),
1777                   bodyBlock->getArguments().slice(numIVs, inputs.size()),
1778                   bodyBlock->getArguments().take_back(outputs.size()));
1779     TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
1780   }
1781 }
1782 
1783 static void print(OpAsmPrinter &p, TiledLoopOp op) {
1784   p << " (" << op.getInductionVars() << ") = (" << op.lowerBound() << ") to ("
1785     << op.upperBound() << ") step (" << op.step() << ")";
1786 
1787   if (!op.inputs().empty()) {
1788     p << " ins (";
1789     llvm::interleaveComma(llvm::zip(op.getRegionInputArgs(), op.inputs()), p,
1790                           [&](auto it) {
1791                             p << std::get<0>(it) << " = " << std::get<1>(it)
1792                               << ": " << std::get<1>(it).getType();
1793                           });
1794     p << ")";
1795   }
1796   if (!op.outputs().empty()) {
1797     p << " outs (";
1798     llvm::interleaveComma(llvm::zip(op.getRegionOutputArgs(), op.outputs()), p,
1799                           [&](auto it) {
1800                             p << std::get<0>(it) << " = " << std::get<1>(it)
1801                               << ": " << std::get<1>(it).getType();
1802                           });
1803     p << ")";
1804   }
1805 
1806   if (llvm::any_of(op.iterator_types(), [](Attribute attr) {
1807         return attr.cast<StringAttr>().getValue() !=
1808                getParallelIteratorTypeName();
1809       }))
1810     p << " iterators" << op.iterator_types();
1811 
1812   if (op.distribution_types().hasValue())
1813     p << " distribution" << op.distribution_types().getValue();
1814 
1815   p << ' ';
1816   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
1817   p.printOptionalAttrDict(
1818       op->getAttrs(), /*elidedAttrs=*/{TiledLoopOp::getOperandSegmentSizeAttr(),
1819                                        getIteratorTypesAttrName(),
1820                                        getDistributionTypesAttrName()});
1821 }
1822 
1823 static ParseResult parseTiledLoopOp(OpAsmParser &parser,
1824                                     OperationState &result) {
1825   auto &builder = parser.getBuilder();
1826   // Parse an opening `(` followed by induction variables followed by `)`
1827   SmallVector<OpAsmParser::OperandType, 4> ivs;
1828   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
1829                                      OpAsmParser::Delimiter::Paren))
1830     return failure();
1831 
1832   // Parse loop bounds.
1833   SmallVector<OpAsmParser::OperandType, 4> lower;
1834   if (parser.parseEqual() ||
1835       parser.parseOperandList(lower, ivs.size(),
1836                               OpAsmParser::Delimiter::Paren) ||
1837       parser.resolveOperands(lower, builder.getIndexType(), result.operands))
1838     return failure();
1839 
1840   SmallVector<OpAsmParser::OperandType, 4> upper;
1841   if (parser.parseKeyword("to") ||
1842       parser.parseOperandList(upper, ivs.size(),
1843                               OpAsmParser::Delimiter::Paren) ||
1844       parser.resolveOperands(upper, builder.getIndexType(), result.operands))
1845     return failure();
1846 
1847   // Parse step values.
1848   SmallVector<OpAsmParser::OperandType, 4> steps;
1849   if (parser.parseKeyword("step") ||
1850       parser.parseOperandList(steps, ivs.size(),
1851                               OpAsmParser::Delimiter::Paren) ||
1852       parser.resolveOperands(steps, builder.getIndexType(), result.operands))
1853     return failure();
1854 
1855   // Parse input tensors.
1856   SmallVector<OpAsmParser::OperandType, 4> inputs, inputRegionArgs;
1857   SmallVector<Type, 4> inputTypes;
1858   if (succeeded(parser.parseOptionalKeyword("ins"))) {
1859     llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation();
1860 
1861     if (parser.parseAssignmentListWithTypes(inputRegionArgs, inputs,
1862                                             inputTypes))
1863       return failure();
1864 
1865     if (parser.resolveOperands(inputs, inputTypes, inputsOperandsLoc,
1866                                result.operands))
1867       return failure();
1868   }
1869 
1870   // Parse output tensors.
1871   SmallVector<OpAsmParser::OperandType, 4> outputs, outputRegionArgs;
1872   SmallVector<Type, 4> outputTypes;
1873   if (succeeded(parser.parseOptionalKeyword("outs"))) {
1874     llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation();
1875 
1876     if (parser.parseAssignmentListWithTypes(outputRegionArgs, outputs,
1877                                             outputTypes))
1878       return failure();
1879 
1880     if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc,
1881                                result.operands))
1882       return failure();
1883     for (Type outputType : outputTypes)
1884       if (outputType.isa<RankedTensorType>())
1885         result.addTypes(outputType);
1886   }
1887 
1888   // Parse attributes.
1889   SmallVector<Attribute, 4> iterTypes, distributionTypes;
1890   auto parseAttr = [&](StringRef keyword, SmallVector<Attribute, 4> *attrs) {
1891     if (succeeded(parser.parseOptionalKeyword(keyword))) {
1892       StringAttr attr;
1893 
1894       if (parser.parseLSquare() || parser.parseAttribute(attr))
1895         return failure();
1896       attrs->push_back(attr);
1897       for (int i = 1, e = ivs.size(); i < e; ++i) {
1898         if (parser.parseComma() || parser.parseAttribute(attr))
1899           return failure();
1900         attrs->push_back(attr);
1901       }
1902       if (parser.parseRSquare())
1903         return failure();
1904     }
1905     return success();
1906   };
1907   if (failed(parseAttr("iterators", &iterTypes)) ||
1908       failed(parseAttr("distribution", &distributionTypes)))
1909     return failure();
1910 
1911   // Set all loop iterator types to "parallel" if they are not printed in IR.
1912   if (iterTypes.empty()) {
1913     auto parallelIter = builder.getStringAttr(getParallelIteratorTypeName());
1914     iterTypes = SmallVector<Attribute, 4>(ivs.size(), parallelIter);
1915   }
1916   result.addAttribute(getIteratorTypesAttrName(),
1917                       builder.getArrayAttr(iterTypes));
1918   if (!distributionTypes.empty())
1919     result.addAttribute(getDistributionTypesAttrName(),
1920                         builder.getArrayAttr(distributionTypes));
1921   result.addAttribute(
1922       TiledLoopOp::getOperandSegmentSizeAttr(),
1923       builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
1924                                 static_cast<int32_t>(upper.size()),
1925                                 static_cast<int32_t>(steps.size()),
1926                                 static_cast<int32_t>(inputs.size()),
1927                                 static_cast<int32_t>(outputs.size())}));
1928 
1929   // Parse the body.
1930   Region *body = result.addRegion();
1931 
1932   SmallVector<Type, 4> regionTypes(ivs.size(), builder.getIndexType());
1933   regionTypes.append(inputTypes);
1934   regionTypes.append(outputTypes);
1935 
1936   SmallVector<OpAsmParser::OperandType, 4> regionArgs(ivs);
1937   regionArgs.append(inputRegionArgs);
1938   regionArgs.append(outputRegionArgs);
1939 
1940   if (parser.parseRegion(*body, regionArgs, regionTypes))
1941     return failure();
1942 
1943   // Parse optional attributes.
1944   parser.parseOptionalAttrDict(result.attributes);
1945 
1946   return success();
1947 }
1948 
1949 Region &TiledLoopOp::getLoopBody() { return region(); }
1950 
1951 LogicalResult TiledLoopOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
1952   for (auto *op : ops)
1953     op->moveBefore(*this);
1954   return success();
1955 }
1956 
1957 bool TiledLoopOp::isDefinedOutsideOfLoop(Value value) {
1958   return !region().isAncestor(value.getParentRegion());
1959 }
1960 
1961 static LogicalResult verify(TiledLoopOp op) {
1962   // Check if iterator types are provided for every loop dimension.
1963   if (op.iterator_types().size() != op.getNumLoops())
1964     return op.emitOpError("expected iterator types array attribute size = ")
1965            << op.iterator_types().size()
1966            << " to match the number of loops = " << op.getNumLoops();
1967 
1968   // Check if types of input arguments match region args types.
1969   for (auto &item :
1970        llvm::enumerate(llvm::zip(op.inputs(), op.getRegionInputArgs()))) {
1971     Value input, inputRegionArg;
1972     unsigned index = item.index();
1973     std::tie(input, inputRegionArg) = item.value();
1974     if (input.getType() != inputRegionArg.getType())
1975       return op.emitOpError("expected input arg ")
1976              << index << " with type = " << input.getType()
1977              << " to match region arg " << index + op.getNumLoops()
1978              << " type = " << inputRegionArg.getType();
1979   }
1980 
1981   // Check if types of input arguments match region args types.
1982   for (auto &item :
1983        llvm::enumerate(llvm::zip(op.outputs(), op.getRegionOutputArgs()))) {
1984     Value output, outputRegionArg;
1985     unsigned index = item.index();
1986     std::tie(output, outputRegionArg) = item.value();
1987     if (output.getType() != outputRegionArg.getType())
1988       return op.emitOpError("expected output arg ")
1989              << index << " with type = " << output.getType()
1990              << " to match region arg "
1991              << index + op.getNumLoops() + op.inputs().size()
1992              << " type = " << outputRegionArg.getType();
1993   }
1994   return success();
1995 }
1996 
1997 namespace {
1998 
1999 static constexpr int64_t kNoMatch = -1;
2000 
2001 // Folds away TiledLoopOp inputs if they have no uses within the body.
2002 //
2003 // Example:
2004 //
2005 // %0 = linalg.tiled_loop ...  ins (%in_ = %in: tensor<...>,
2006 //                                  %in_buf_ = %in_buf: memref<...>) {...}
2007 // Becomes
2008 //
2009 // linalg.tiled_loop ...  ins (%in_buf_ = %in_buf: memref<...>) {...}
2010 struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
2011   using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern;
2012 
2013   LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
2014                                 PatternRewriter &rewriter) const final {
2015     SmallVector<Value, 2> newInputs, regionInputTensorArgs;
2016     // Store ids of the corresponding old and new input operands.
2017     SmallVector<int64_t, 2> oldInputIdToNew(tiledLoop.inputs().size(),
2018                                             kNoMatch);
2019     for (const auto &en : llvm::enumerate(
2020              llvm::zip(tiledLoop.inputs(), tiledLoop.getRegionInputArgs()))) {
2021       Value in, bbArg;
2022       size_t index = en.index();
2023       std::tie(in, bbArg) = en.value();
2024       if (!bbArg.use_empty()) {
2025         oldInputIdToNew[index] = newInputs.size();
2026         newInputs.push_back(in);
2027       }
2028     }
2029     if (newInputs.size() == tiledLoop.inputs().size())
2030       return failure();
2031     Location loc = tiledLoop.getLoc();
2032     auto newTiledLoop = rewriter.create<TiledLoopOp>(
2033         loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
2034         newInputs, tiledLoop.outputs(), tiledLoop.iterator_types(),
2035         tiledLoop.distribution_types());
2036 
2037     // Clone the region.
2038     BlockAndValueMapping bvm;
2039     bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
2040     bvm.map(tiledLoop.getRegionOutputArgs(),
2041             newTiledLoop.getRegionOutputArgs());
2042     for (const auto &en : llvm::enumerate(oldInputIdToNew))
2043       if (en.value() != kNoMatch)
2044         bvm.map(tiledLoop.getRegionInputArgs()[en.index()],
2045                 newTiledLoop.getRegionInputArgs()[en.value()]);
2046     OpBuilder innerBuilder =
2047         OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
2048     for (auto &op : *tiledLoop.getBody())
2049       innerBuilder.clone(op, bvm);
2050     rewriter.replaceOp(tiledLoop, newTiledLoop.getResults());
2051 
2052     return success();
2053   }
2054 };
2055 
2056 } // namespace
2057 
2058 /// A simple, conservative analysis to determine if the loop is shape
2059 /// conserving. I.e., the type of the arg-th yielded value is the same as the
2060 /// type of the corresponding basic block argument of the loop.
2061 /// Note: This function handles only simple cases. Expand as needed.
2062 static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg) {
2063   auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator());
2064   if (yieldOp.values().empty())
2065     // Tiled loop either has no outputs or is a "memref-based version". In
2066     // either case, the loop is shape conserving.
2067     return true;
2068   assert(arg < static_cast<int64_t>(yieldOp.values().size()) &&
2069          "arg is out of bounds");
2070   Value value = yieldOp.values()[arg];
2071   while (value) {
2072     if (value == loopOp.getRegionOutputArgs()[arg])
2073       return true;
2074     OpResult opResult = value.dyn_cast<OpResult>();
2075     if (!opResult)
2076       return false;
2077 
2078     using tensor::InsertSliceOp;
2079     value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
2080                 .template Case<InsertSliceOp>(
2081                     [&](InsertSliceOp op) { return op.dest(); })
2082                 .template Case<TiledLoopOp>([&](TiledLoopOp loopOp) {
2083                   return isShapePreserving(loopOp, opResult.getResultNumber())
2084                              ? loopOp.outputs()[opResult.getResultNumber()]
2085                              : Value();
2086                 })
2087                 .Default([&](auto op) { return Value(); });
2088   }
2089   return false;
2090 }
2091 
2092 namespace {
2093 
2094 /// Fold dim(x) where `x` is an input/output argument of a TiledLoopOp block
2095 /// to dim(y) where `y` is the initial input/output value of the argument.
2096 ///
2097 /// E.g.:
2098 /// %y = ... : tensor<...>
2099 /// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
2100 ///   tensor.dim %x, %c0 : tensor<...>
2101 /// }
2102 ///
2103 /// is folded to:
2104 /// %y = ... : tensor<...>
2105 /// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
2106 ///   tensor.dim %y, %c0 : tensor<...>
2107 /// }
2108 ///
2109 /// Note: Dim ops are folded only if it can be proven that the runtime type of
2110 /// the yielded value (in case of outputs) does not change with loop iterations.
2111 template <typename OpTy>
2112 struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
2113   using OpRewritePattern<OpTy>::OpRewritePattern;
2114 
2115   LogicalResult matchAndRewrite(OpTy dimOp,
2116                                 PatternRewriter &rewriter) const final {
2117     auto src = dimOp.source().template dyn_cast<BlockArgument>();
2118     if (!src)
2119       return failure();
2120     auto loopOp =
2121         dyn_cast<TiledLoopOp>(src.getOwner()->getParent()->getParentOp());
2122     if (!loopOp)
2123       return failure();
2124     unsigned numLoops = loopOp.getNumLoops();
2125     unsigned numInputArgs = loopOp.getRegionInputArgs().size();
2126     if (src.getArgNumber() >= numInputArgs + numLoops &&
2127         !isShapePreserving(loopOp,
2128                            src.getArgNumber() - numInputArgs - numLoops))
2129       return failure();
2130 
2131     auto inputArgs = loopOp.getRegionInputArgs();
2132     auto it1 = llvm::find(inputArgs, src);
2133     if (it1 != inputArgs.end()) {
2134       rewriter.updateRootInPlace(dimOp, [&] {
2135         dimOp.sourceMutable().assign(loopOp.inputs()[it1 - inputArgs.begin()]);
2136       });
2137       return success();
2138     }
2139 
2140     auto outputArgs = loopOp.getRegionOutputArgs();
2141     auto it2 = llvm::find(outputArgs, src);
2142     if (it2 != outputArgs.end()) {
2143       rewriter.updateRootInPlace(dimOp, [&] {
2144         dimOp.sourceMutable().assign(
2145             loopOp.outputs()[it2 - outputArgs.begin()]);
2146       });
2147       return success();
2148     }
2149 
2150     return failure();
2151   }
2152 };
2153 
2154 /// Fold dim(r) where `r` is the result of a TiledLoopOp to dim(y) where `y`
2155 /// is the initial output value of the loop.
2156 ///
2157 /// E.g.:
2158 /// %y = ... : tensor<...>
2159 /// %r = linalg.tiled_loop ... outs(%i = %y : tensor<...>) {
2160 ///   ...
2161 /// }
2162 /// %0 = tensor.dim %r, %c0 : tensor<...>
2163 ///
2164 /// is folded to:
2165 /// %y = ... : tensor<...>
2166 /// linalg.tiled_loop ... outs(%i = %y : tensor<...>) {
2167 ///   ...
2168 /// }
2169 /// %0 = tensor.dim %y, %c0 : tensor<...>
2170 ///
2171 /// Note: Dim ops are folded only if it can be proven that the runtime type of
2172 /// the yielded value (in case of outputs) does not change with loop iterations.
2173 template <typename OpTy>
2174 struct DimOfTiledLoopResultFolder : public OpRewritePattern<OpTy> {
2175   using OpRewritePattern<OpTy>::OpRewritePattern;
2176 
2177   LogicalResult matchAndRewrite(OpTy dimOp,
2178                                 PatternRewriter &rewriter) const final {
2179     auto loopOp = dimOp.source().template getDefiningOp<TiledLoopOp>();
2180     if (!loopOp)
2181       return failure();
2182     auto opResult = dimOp.source().template cast<OpResult>();
2183     unsigned resultNumber = opResult.getResultNumber();
2184     if (!isShapePreserving(loopOp, resultNumber))
2185       return failure();
2186     rewriter.updateRootInPlace(dimOp, [&]() {
2187       dimOp.sourceMutable().assign(loopOp.outputs()[resultNumber]);
2188     });
2189     return success();
2190   }
2191 };
2192 
2193 // Folds away TiledLoopOp output tensors when the following conditions are met:
2194 // * result of `linalg.tiled_loop` has no uses
2195 // * output tensor is the argument of `linalg.yield`
2196 //
2197 // Example:
2198 //
2199 // %0 = linalg.tiled_loop ...  outs (%o_ = %out: tensor<...>,
2200 //                                   %obuf_ = %out_buf: memref<...>) {
2201 //   ...
2202 //   linalg.yield %o_ : tensor ...
2203 // }
2204 //
2205 // Becomes
2206 //
2207 // linalg.tiled_loop ...  outs (%obuf_ = %out_buf: memref<...>) {
2208 //   ...
2209 //   linalg.yield
2210 // }
2211 struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
2212   using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern;
2213 
2214   LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
2215                                 PatternRewriter &rewriter) const final {
2216     if (tiledLoop.getNumResults() == 0)
2217       return failure();
2218 
2219     Block *block = tiledLoop.getBody();
2220     auto yieldOp = cast<linalg::YieldOp>(block->getTerminator());
2221 
2222     // Match the pattern and collect output buffers that will replace the output
2223     // tensors and also the ops that will be ignored when cloning the body.
2224     SmallVector<Value, 2> newOutputOperands, newYieldArgs;
2225     int resultId = 0;
2226     // Store ids of the corresponding old and new output operands.
2227     SmallVector<int64_t, 2> oldOutputIdToNew(tiledLoop.outputs().size(),
2228                                              kNoMatch);
2229     // Store ids of the corresponding old and new results.
2230     SmallVector<int64_t, 2> oldResultIdToNew(tiledLoop.getNumResults(),
2231                                              kNoMatch);
2232     SmallVector<Value, 2> resultReplacement(tiledLoop.getNumResults());
2233     for (const auto &en : llvm::enumerate(
2234              llvm::zip(tiledLoop.outputs(), tiledLoop.getRegionOutputArgs()))) {
2235       size_t index = en.index();
2236       Value out = std::get<0>(en.value());
2237       Value outRegionArg = std::get<1>(en.value());
2238 
2239       if (!out.getType().isa<RankedTensorType>()) {
2240         oldOutputIdToNew[index] = newOutputOperands.size();
2241         newOutputOperands.push_back(out);
2242         continue;
2243       }
2244       Value result = tiledLoop.getResult(resultId);
2245       Value yieldArg = yieldOp.getOperand(resultId);
2246       if (yieldArg != outRegionArg || !result.use_empty()) {
2247         oldOutputIdToNew[index] = newOutputOperands.size();
2248         oldResultIdToNew[resultId] = newYieldArgs.size();
2249         resultReplacement[resultId] = out;
2250         newOutputOperands.push_back(out);
2251         newYieldArgs.push_back(yieldArg);
2252       }
2253       ++resultId;
2254     }
2255     if (newOutputOperands.size() == tiledLoop.outputs().size())
2256       return failure();
2257 
2258     Location loc = tiledLoop.getLoc();
2259     auto newTiledLoop = rewriter.create<TiledLoopOp>(
2260         loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
2261         tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types(),
2262         tiledLoop.distribution_types());
2263 
2264     // Clone the region.
2265     BlockAndValueMapping bvm;
2266     bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
2267     bvm.map(tiledLoop.getRegionInputArgs(), newTiledLoop.getRegionInputArgs());
2268     for (const auto &en : llvm::enumerate(oldOutputIdToNew)) {
2269       if (en.value() != kNoMatch)
2270         bvm.map(tiledLoop.getRegionOutputArgs()[en.index()],
2271                 newTiledLoop.getRegionOutputArgs()[en.value()]);
2272       else
2273         bvm.map(tiledLoop.getRegionOutputArgs()[en.index()],
2274                 tiledLoop.outputs()[en.index()]);
2275     }
2276     OpBuilder innerBuilder =
2277         OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
2278     for (auto &op : tiledLoop.getBody()->without_terminator())
2279       innerBuilder.clone(op, bvm);
2280     innerBuilder.create<linalg::YieldOp>(
2281         loc, llvm::to_vector<2>(llvm::map_range(
2282                  newYieldArgs, [&](Value arg) { return bvm.lookup(arg); })));
2283 
2284     for (const auto &en : llvm::enumerate(oldResultIdToNew))
2285       if (en.value() != kNoMatch)
2286         resultReplacement[en.index()] = newTiledLoop.getResult(en.value());
2287     rewriter.replaceOp(tiledLoop, resultReplacement);
2288 
2289     return success();
2290   }
2291 };
2292 } // namespace
2293 
2294 void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2295                                               MLIRContext *context) {
2296   results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder,
2297                  DimOfTiledLoopInsOutsFolder<tensor::DimOp>,
2298                  DimOfTiledLoopInsOutsFolder<memref::DimOp>,
2299                  DimOfTiledLoopResultFolder<tensor::DimOp>,
2300                  DimOfTiledLoopResultFolder<memref::DimOp>>(context);
2301 }
2302 
2303 LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
2304                                 SmallVectorImpl<OpFoldResult> &) {
2305   return foldMemRefCastInTiledLoopOp(*this);
2306 }
2307 
2308 //===----------------------------------------------------------------------===//
2309 // IndexOp
2310 //===----------------------------------------------------------------------===//
2311 
2312 static LogicalResult verify(IndexOp op) {
2313   auto linalgOp = dyn_cast<LinalgOp>(op->getParentOp());
2314   if (!linalgOp)
2315     return op.emitOpError("expected parent op with LinalgOp interface");
2316   if (linalgOp.getNumLoops() <= op.dim())
2317     return op.emitOpError("expected dim (")
2318            << op.dim() << ") to be lower than the number of loops ("
2319            << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2320   return success();
2321 }
2322 
2323 /////// Operations corresponding to library calls defined with Tablegen ////////
2324 
2325 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2326 
2327 #define GET_OP_CLASSES
2328 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2329 
2330 #define GET_OP_CLASSES
2331 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2332 
2333 /// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
2334 /// Assumes `op` is a LinalgOp.
2335 void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName,
2336                                  SmallVectorImpl<unsigned> &res) {
2337   if (!cast<LinalgOp>(op).iterator_types())
2338     return;
2339 
2340   unsigned dim = 0;
2341   for (auto tn :
2342        cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) {
2343     if (tn == iteratorTypeName)
2344       res.push_back(dim);
2345     ++dim;
2346   }
2347 }
2348 
2349 AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap,
2350                                              unsigned rank,
2351                                              MLIRContext *context) {
2352   if (maybeMap)
2353     return maybeMap.getValue();
2354   if (rank == 0)
2355     return AffineMap::get(context);
2356   return AffineMap::getMultiDimIdentityMap(rank, context);
2357 }
2358 
2359 SmallVector<AffineExpr, 4>
2360 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2361                                  MLIRContext *context) {
2362   SmallVector<AffineExpr, 4> res;
2363   res.reserve(num);
2364   for (unsigned i = 0; i < num; ++i)
2365     res.push_back(getAffineDimExpr(startIdx++, context));
2366   return res;
2367 }
2368 
2369 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2370                                                 ArrayRef<AffineExpr> b) {
2371   auto rangeA = llvm::make_range(a.begin(), a.end());
2372   auto rangeB = llvm::make_range(b.begin(), b.end());
2373   auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2374   return llvm::to_vector<4>(concatRanges);
2375 }
2376 
2377 static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2378   if (auto memref = t.dyn_cast<MemRefType>()) {
2379     ss << "view";
2380     for (auto size : memref.getShape())
2381       if (size < 0)
2382         ss << "sx";
2383       else
2384         ss << size << "x";
2385     appendMangledType(ss, memref.getElementType());
2386   } else if (auto vec = t.dyn_cast<VectorType>()) {
2387     ss << "vector";
2388     llvm::interleave(
2389         vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2390     appendMangledType(ss, vec.getElementType());
2391   } else if (t.isSignlessIntOrIndexOrFloat()) {
2392     ss << t;
2393   } else {
2394     llvm_unreachable("Invalid type for linalg library name mangling");
2395   }
2396 }
2397 
2398 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2399   assert(isa<LinalgOp>(op));
2400   std::string name(op->getName().getStringRef().str());
2401   name.reserve(128);
2402   std::replace(name.begin(), name.end(), '.', '_');
2403   llvm::raw_string_ostream ss(name);
2404   ss << "_";
2405   auto types = op->getOperandTypes();
2406   llvm::interleave(
2407       types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
2408       [&]() { ss << "_"; });
2409   return ss.str();
2410 }
2411 
2412 //===----------------------------------------------------------------------===//
2413 // Support for named Linalg ops defined in ods-gen.
2414 //===----------------------------------------------------------------------===//
2415 
2416 /// Generic entry point to create the block for the region of a LinalgOp.
2417 /// This is used by both named structured ops created by ods-gen and by manually
2418 /// defined C++ ops.
2419 /// This is used by both builders and parsers.
2420 /// This function creates the block in the region with arguments corresponding
2421 /// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
2422 /// to be ShapedType.
2423 template <typename NamedStructuredOpType>
2424 static void fillStructuredOpRegion(
2425     OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
2426     TypeRange outputTypes,
2427     llvm::function_ref<void(unsigned, unsigned)> errorHandler) {
2428   assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
2429 
2430   // TODO: atm all operands go through getElementTypeOrSelf,
2431   // reconsider when we have evidence we need to.
2432   SmallVector<Type, 8> argTypes;
2433   for (auto containers : {inputTypes, outputTypes})
2434     for (auto t : containers)
2435       argTypes.push_back(getElementTypeOrSelf(t));
2436 
2437   // RAII.
2438   OpBuilder::InsertionGuard guard(opBuilder);
2439   Block *body = opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes);
2440   unsigned actual = body->getNumArguments();
2441   unsigned expected = NamedStructuredOpType::getNumRegionArgs();
2442   if (expected != actual) {
2443     if (errorHandler)
2444       errorHandler(expected, actual);
2445     return;
2446   }
2447 
2448   opBuilder.setInsertionPointToStart(body);
2449   ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
2450   NamedStructuredOpType::regionBuilder(b, *body);
2451 
2452   // indexing_maps is an auto-generated method.
2453 
2454   // iterator_types is an auto-generated method.
2455 }
2456 
2457 /// Generic entry point to create both the region and the block of a LinalgOp.
2458 template <typename NamedStructuredOpType>
2459 void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
2460                                      OperationState &result,
2461                                      TypeRange inputTypes,
2462                                      TypeRange outputTypes) {
2463   Region &region = *result.addRegion();
2464   fillStructuredOpRegion<NamedStructuredOpType>(
2465       opBuilder, region, inputTypes, outputTypes,
2466       [&](unsigned expected, unsigned actual) {
2467         assert(expected != actual && "incorrect number of arguments");
2468       });
2469 }
2470 
2471 /// Common parsing used for both named structured ops created by ods-gen and by
2472 /// manually defined C++ ops. Does not handle regions.
2473 static ParseResult
2474 parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
2475                              SmallVectorImpl<Type> &inputTypes,
2476                              SmallVectorImpl<Type> &outputTypes) {
2477   llvm::SMLoc inputsOperandsLoc, outputsOperandsLoc;
2478   SmallVector<OpAsmParser::OperandType, 4> inputsOperands, outputsOperands;
2479 
2480   parser.parseOptionalAttrDict(result.attributes);
2481 
2482   if (succeeded(parser.parseOptionalKeyword("ins"))) {
2483     if (parser.parseLParen())
2484       return failure();
2485 
2486     inputsOperandsLoc = parser.getCurrentLocation();
2487     if (parser.parseOperandList(inputsOperands) ||
2488         parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2489       return failure();
2490   }
2491 
2492   if (succeeded(parser.parseOptionalKeyword("outs"))) {
2493     outputsOperandsLoc = parser.getCurrentLocation();
2494     if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
2495         parser.parseColonTypeList(outputTypes) || parser.parseRParen())
2496       return failure();
2497   }
2498 
2499   if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2500                              result.operands) ||
2501       parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
2502                              result.operands))
2503     return failure();
2504 
2505   result.addAttribute("operand_segment_sizes",
2506                       parser.getBuilder().getI32VectorAttr(
2507                           {static_cast<int32_t>(inputsOperands.size()),
2508                            static_cast<int32_t>(outputsOperands.size())}));
2509   return success();
2510 }
2511 
2512 template <typename NamedStructuredOpType>
2513 static void printCommonStructuredOpParts(OpAsmPrinter &p,
2514                                          NamedStructuredOpType op) {
2515   if (!op.inputs().empty())
2516     p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
2517   if (!op.outputs().empty())
2518     p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
2519 }
2520 
2521 //===----------------------------------------------------------------------===//
2522 // Specific parsing and printing for named structured ops created by ods-gen.
2523 //===----------------------------------------------------------------------===//
2524 
2525 template <typename NamedStructuredOpType>
2526 static ParseResult
2527 parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
2528                              TypeRange inputTypes, TypeRange outputTypes) {
2529   ParseResult res = success();
2530   OpBuilder opBuilder(parser.getContext());
2531   // Resolve `captures` into `capturedValues` at parse time so we can build the
2532   // region with captures.
2533   SmallVector<Value> capturedValues;
2534   fillStructuredOpRegion<NamedStructuredOpType>(
2535       opBuilder, region, inputTypes, outputTypes,
2536       [&](unsigned expected, unsigned actual) {
2537         res = parser.emitError(
2538             parser.getCurrentLocation(),
2539             llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
2540                           "region expects {0} args, got {1}",
2541                           expected, actual));
2542         region.front().dump();
2543       });
2544   return res;
2545 }
2546 
2547 static ParseResult
2548 parseNamedStructuredOpResults(OpAsmParser &parser,
2549                               SmallVectorImpl<Type> &resultTypes) {
2550   if (parser.parseOptionalArrowTypeList(resultTypes))
2551     return failure();
2552   return success();
2553 }
2554 
2555 template <typename NamedStructuredOpType>
2556 static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
2557                                           OperationState &result) {
2558   // TODO: Enable when ods-gen supports captures.
2559   SmallVector<Type, 1> inputTypes, outputTypes;
2560   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
2561     return failure();
2562 
2563   // TODO: consider merging results parsing into region parsing.
2564   // Need to wait for declarative assembly resolution to decide.
2565   SmallVector<Type, 1> outputTensorsTypes;
2566   if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
2567     return failure();
2568   result.addTypes(outputTensorsTypes);
2569 
2570   std::unique_ptr<Region> region = std::make_unique<Region>();
2571   if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
2572           parser, *region, inputTypes, outputTypes))
2573     return failure();
2574   result.addRegion(std::move(region));
2575 
2576   return success();
2577 }
2578 
2579 static void printNamedStructuredOpResults(OpAsmPrinter &p,
2580                                           TypeRange resultTypes) {
2581   if (resultTypes.empty())
2582     return;
2583   p.printOptionalArrowTypeList(resultTypes);
2584 }
2585 
2586 template <typename NamedStructuredOpType>
2587 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
2588   p.printOptionalAttrDict(
2589       op->getAttrs(),
2590       /*elidedAttrs=*/{"operand_segment_sizes",
2591                        // See generated code in mlir-linalg-yaml-gen.cpp
2592                        "linalg.memoized_indexing_maps"});
2593 
2594   // Printing is shared with generic ops, except for the region and
2595   // attributes.
2596   printCommonStructuredOpParts(p, op);
2597 
2598   // Results printing.
2599   printNamedStructuredOpResults(p, op.result_tensors().getTypes());
2600 
2601   // Region is elided.
2602 }
2603 
2604 template <typename NamedStructuredOpType>
2605 static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) {
2606   return verifyGenericOp<NamedStructuredOpType>(op);
2607 }
2608 
2609 //===----------------------------------------------------------------------===//
2610 // Canonicalizers and Folders.
2611 //===----------------------------------------------------------------------===//
2612 
2613 namespace {
2614 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2615   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2616 
2617   LogicalResult matchAndRewrite(LinalgOp op,
2618                                 PatternRewriter &rewriter) const override {
2619     for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
2620       // Linalg "inputs" may be either tensor or memref type.
2621       // tensor<0xelt_type> is a convention that may not always mean
2622       // "0 iterations". Only erase in cases we see memref<...x0x...>.
2623       auto mt = opOperand->get().getType().dyn_cast<MemRefType>();
2624       if (!mt)
2625         continue;
2626       if (llvm::is_contained(op.getShape(opOperand), 0)) {
2627         rewriter.eraseOp(op);
2628         return success();
2629       }
2630     }
2631     return failure();
2632   }
2633 };
2634 
2635 struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
2636   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2637 
2638   LogicalResult matchAndRewrite(LinalgOp op,
2639                                 PatternRewriter &rewriter) const override {
2640     // If no operand comes from a tensor::CastOp and can be folded then fail.
2641     bool hasTensorCastOperand =
2642         llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
2643           if (opOperand->get().isa<BlockArgument>())
2644             return false;
2645           auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
2646           return castOp && canFoldIntoConsumerOp(castOp);
2647         });
2648     if (!hasTensorCastOperand)
2649       return failure();
2650 
2651     SmallVector<Type, 4> newResultTypes;
2652     newResultTypes.reserve(op->getNumResults());
2653     SmallVector<Value, 4> newOperands;
2654     newOperands.reserve(op->getNumOperands());
2655     // Inputs may fold.
2656     for (OpOperand *opOperand : op.getInputOperands()) {
2657       auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
2658       newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
2659                                 ? tensorCastOp.source()
2660                                 : opOperand->get());
2661     }
2662     // Init tensors may fold, in which case the resultType must also change.
2663     for (OpOperand *opOperand : op.getOutputOperands()) {
2664       auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
2665       bool fold = canFoldIntoConsumerOp(tensorCastOp);
2666       newOperands.push_back(fold ? tensorCastOp.getOperand()
2667                                  : opOperand->get());
2668       newResultTypes.push_back(newOperands.back().getType());
2669     }
2670     // Clone op.
2671     Operation *newOp =
2672         op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
2673     SmallVector<Value, 4> replacements;
2674     replacements.reserve(newOp->getNumResults());
2675     for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
2676       Value oldResult = std::get<0>(result);
2677       Value newResult = std::get<1>(result);
2678       if (newResult.getType() != oldResult.getType()) {
2679         replacements.push_back(rewriter.create<tensor::CastOp>(
2680             op->getLoc(), oldResult.getType(), newResult));
2681       } else {
2682         replacements.push_back(newResult);
2683       }
2684     }
2685     rewriter.replaceOp(op, replacements);
2686 
2687     return success();
2688   }
2689 };
2690 
2691 } // namespace
2692 
2693 #define LINALGOP_FOLDERS(XXX)                                                  \
2694   LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
2695                           SmallVectorImpl<OpFoldResult> &) {                   \
2696     return foldMemRefCast(*this);                                              \
2697   }
2698 
2699 LINALGOP_FOLDERS(CopyOp)
2700 LINALGOP_FOLDERS(FillOp)
2701 LINALGOP_FOLDERS(GenericOp)
2702 
2703 // All named ops canonicalizers and folders are auto-generated in the
2704 // .cpp.inc.
2705 
2706 //===----------------------------------------------------------------------===//
2707 // LinalgDialect
2708 //===----------------------------------------------------------------------===//
2709 
2710 void LinalgDialect::getCanonicalizationPatterns(
2711     RewritePatternSet &results) const {
2712   results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
2713 }
2714 
2715 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
2716                                               Attribute value, Type type,
2717                                               Location loc) {
2718   return builder.create<arith::ConstantOp>(loc, type, value);
2719 }
2720