1 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/Dialect/Utils/StaticValueUtils.h"
21 #include "mlir/IR/AffineExprVisitor.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Interfaces/InferTypeOpInterface.h"
26 #include "mlir/Parser.h"
27 
28 #include "llvm/ADT/DenseMap.h"
29 #include "llvm/ADT/SetVector.h"
30 #include "llvm/ADT/SmallSet.h"
31 #include "llvm/ADT/StringSet.h"
32 #include "llvm/Support/FormatVariadic.h"
33 #include "llvm/Support/MathExtras.h"
34 #include "llvm/Support/raw_ostream.h"
35 
36 using namespace mlir;
37 using namespace mlir::linalg;
38 
39 /// Forward declarations.
40 
/// Generic entry point to create the block for the region of a LinalgOp.
/// This is used by both named structured ops created by ods-gen and by manually
/// defined C++ ops, and by both builders and parsers.
/// This function creates the block in the region with arguments corresponding
/// to the elemental types of `inputTypes` and `outputTypes`. The latter are
/// asserted to be of ShapedType.
/// `errorHandler` is optional (defaults to null); the meaning of its two
/// unsigned arguments is established by the definition, not by this
/// declaration.
template <typename NamedStructuredOpType>
static void fillStructuredOpRegion(
    OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
    TypeRange outputTypes,
    std::function<void(unsigned, unsigned)> errorHandler = nullptr);

/// Generic entry point to create both the region and the block of a LinalgOp.
/// The region is added to `result`.
template <typename NamedStructuredOpType>
static void
createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result,
                                TypeRange inputTypes, TypeRange outputTypes);

/// Common parsing and printing used for both named structured ops created by
/// ods-gen and by manually defined C++ ops. Does not handle regions.
/// The parser reports the parsed operand types through `inputTypes` and
/// `outputTypes`.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes);
template <typename NamedStructuredOpType>
static void printCommonStructuredOpParts(OpAsmPrinter &p,
                                         NamedStructuredOpType op);

/// Specific parsing and printing for named structured ops created by ods-gen.
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
                             TypeRange inputTypes, TypeRange outputTypes);

/// Parses the result types of a structured op into `resultTypes`.
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes);

/// Parses a complete named structured op into `result`.
template <typename NamedStructuredOpType>
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result);

/// Prints the result types of a structured op.
static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes);

/// Prints a complete named structured op.
template <typename NamedStructuredOpType>
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
89 
90 /// Helper function to convert a Value into an OpFoldResult, if the Value is
91 /// known to be a constant index value.
92 static SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
93   return llvm::to_vector<4>(
94       llvm::map_range(values, [](Value v) -> OpFoldResult {
95         APInt intValue;
96         if (v.getType().isa<IndexType>() &&
97             matchPattern(v, m_ConstantInt(&intValue))) {
98           return IntegerAttr::get(v.getType(), intValue.getSExtValue());
99         }
100         return v;
101       }));
102 }
103 
104 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
105 /// `Value`s.
106 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
107                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
108   return llvm::to_vector<4>(
109       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
110         if (auto attr = value.dyn_cast<Attribute>())
111           return b.create<ConstantIndexOp>(loc,
112                                            attr.cast<IntegerAttr>().getInt());
113         return value.get<Value>();
114       }));
115 }
116 
117 /// This is a common class used for patterns of the form
118 /// ```
119 ///    someop(memrefcast(%src)) -> someop(%src)
120 /// ```
121 /// It folds the source of the memref.cast into the root operation directly.
122 static LogicalResult foldMemRefCast(Operation *op) {
123   bool folded = false;
124   for (OpOperand &operand : op->getOpOperands()) {
125     auto castOp = operand.get().getDefiningOp<memref::CastOp>();
126     if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
127       operand.set(castOp.getOperand());
128       folded = true;
129     }
130   }
131   return success(folded);
132 }
133 
134 /// This is a specialization of `foldMemRefCast` used for patterns of the form
135 /// ```
136 ///    tiled_loop(memrefcast(%src)) -> tiled_loop(%src)
137 /// ```
138 /// It folds the source of the memref.cast into the root operation directly.
139 static LogicalResult foldMemRefCastInTiledLoopOp(TiledLoopOp op) {
140   bool folded = false;
141   Location loc = op->getLoc();
142 
143   Block *body = op.getBody();
144   OpBuilder b = OpBuilder::atBlockBegin(body);
145 
146   // Update `input` and `output` operands and block arguments if necessary.
147   // Operands list: [lbs, ubs, steps, inputs, outputs].
148   // Block args list: [ivs, inputs, outputs].
149   for (size_t operandIndex = op.getNumControlOperands(),
150               bbArgIndex = op.getNumLoops(), e = op.getNumOperands();
151        operandIndex < e; ++operandIndex, ++bbArgIndex) {
152     OpOperand &operand = op->getOpOperand(operandIndex);
153 
154     auto castOp = operand.get().getDefiningOp<memref::CastOp>();
155     if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
156       operand.set(castOp.getOperand());
157       BlockArgument newBbArg =
158           body->insertArgument(bbArgIndex, castOp.getOperand().getType());
159       BlockArgument oldBbArg = body->getArgument(newBbArg.getArgNumber() + 1);
160 
161       // Insert memref.cast back to the original type.
162       oldBbArg.replaceAllUsesWith(
163           b.create<memref::CastOp>(loc, oldBbArg.getType(), newBbArg));
164       body->eraseArgument(oldBbArg.getArgNumber());
165 
166       folded = true;
167     }
168   }
169   return success(folded);
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // Region builder helper.
174 // TODO: Move this to a utility library.
175 // The public methods on this class are referenced directly from generated code
176 // and bind by name to math functions in the DSL as:
177 //   `applyfn__{fnName}`
178 // Examples:
179 //   `applyfn__add`
180 //   `applyfn__mul`
181 // The naming convention is intentional in order to match snake-cased DSL names.
182 // See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
183 //
184 // Implementations of the math functions must be polymorphic over numeric types,
185 // internally performing necessary casts. If the function application makes no
186 // sense, then the only recourse is to assert and return nullptr. This can be
187 // extended later if it becomes possible to fail construction of the region. The
188 // invariant should be enforced at a higher level.
189 //
190 // TODO: These helpers are currently type polymorphic over the class of integer
191 // and floating point types, but they will not internally cast within bit
192 // widths of a class (mixed precision such as i8->i32) or across classes
193 // (i.e. mixed float and integer). Many such combinations are ambiguous or need
194 // to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
196 // fail verification, which is deemed acceptable.
197 //===----------------------------------------------------------------------===//
198 
199 namespace {
200 
201 class RegionBuilderHelper {
202 public:
203   RegionBuilderHelper(MLIRContext *context, Block &block)
204       : context(context), block(block) {}
205 
206   // Generates operations to cast the given operand to a specified type.
207   // If the cast cannot be performed, a warning will be issued and the
208   // operand returned as-is (which will presumably yield a verification
209   // issue downstream).
210   Value cast(Type toType, Value operand) {
211     OpBuilder builder = getBuilder();
212     auto loc = operand.getLoc();
213 
214     if (operand.getType() == toType)
215       return operand;
216     if (auto toIntType = toType.dyn_cast<IntegerType>()) {
217       // If operand is floating point, cast directly to the int type.
218       if (operand.getType().isa<FloatType>())
219         return builder.create<FPToSIOp>(loc, toType, operand);
220       // Cast index operands directly to the int type.
221       if (operand.getType().isIndex())
222         return builder.create<IndexCastOp>(loc, toType, operand);
223       if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
224         // Either sign extend or truncate.
225         if (toIntType.getWidth() > fromIntType.getWidth())
226           return builder.create<SignExtendIOp>(loc, toType, operand);
227         if (toIntType.getWidth() < fromIntType.getWidth())
228           return builder.create<TruncateIOp>(loc, toType, operand);
229       }
230     } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
231       // If operand is integer, cast directly to the float type.
232       // Note that it is unclear how to cast from BF16<->FP16.
233       if (operand.getType().isa<IntegerType>())
234         return builder.create<SIToFPOp>(loc, toFloatType, operand);
235       if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
236         if (toFloatType.getWidth() > fromFloatType.getWidth())
237           return builder.create<FPExtOp>(loc, toFloatType, operand);
238         if (toFloatType.getWidth() < fromFloatType.getWidth())
239           return builder.create<FPTruncOp>(loc, toFloatType, operand);
240       }
241     }
242 
243     emitWarning(operand.getLoc()) << "could not cast operand of type "
244                                   << operand.getType() << " to " << toType;
245     return operand;
246   }
247 
248   Value applyfn__add(Value lhs, Value rhs) {
249     OpBuilder builder = getBuilder();
250     if (isFloatingPoint(lhs))
251       return builder.create<AddFOp>(lhs.getLoc(), lhs, rhs);
252     if (isInteger(lhs))
253       return builder.create<AddIOp>(lhs.getLoc(), lhs, rhs);
254     llvm_unreachable("unsupported non numeric type");
255   }
256 
257   Value applyfn__sub(Value lhs, Value rhs) {
258     OpBuilder builder = getBuilder();
259     if (isFloatingPoint(lhs))
260       return builder.create<SubFOp>(lhs.getLoc(), lhs, rhs);
261     if (isInteger(lhs))
262       return builder.create<SubIOp>(lhs.getLoc(), lhs, rhs);
263     llvm_unreachable("unsupported non numeric type");
264   }
265 
266   Value applyfn__mul(Value lhs, Value rhs) {
267     OpBuilder builder = getBuilder();
268     if (isFloatingPoint(lhs))
269       return builder.create<MulFOp>(lhs.getLoc(), lhs, rhs);
270     if (isInteger(lhs))
271       return builder.create<MulIOp>(lhs.getLoc(), lhs, rhs);
272     llvm_unreachable("unsupported non numeric type");
273   }
274 
275   void yieldOutputs(ValueRange values) {
276     assert(!values.empty() && "linalg ops must yield outputs");
277     if (values.empty())
278       return;
279     Value first = values.front();
280     OpBuilder builder = getBuilder();
281     builder.create<YieldOp>(first.getLoc(), values);
282   }
283 
284   Value constant(std::string value) {
285     OpBuilder builder = getBuilder();
286     Location loc = builder.getUnknownLoc();
287     Attribute valueAttr = parseAttribute(value, builder.getContext());
288     return builder.create<ConstantOp>(loc, valueAttr.getType(), valueAttr);
289   }
290 
291   Value index(int64_t dim) {
292     OpBuilder builder = getBuilder();
293     return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
294   }
295 
296   Type getIntegerType(unsigned width) {
297     return IntegerType::get(context, width);
298   }
299 
300   Type getFloat32Type() { return Float32Type::get(context); }
301 
302   Type getFloat64Type() { return Float64Type::get(context); }
303 
304 private:
305   MLIRContext *context;
306   Block &block;
307 
308   bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
309   bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
310 
311   OpBuilder getBuilder() {
312     OpBuilder builder(context);
313     builder.setInsertionPointToEnd(&block);
314     return builder;
315   }
316 };
317 
318 } // namespace
319 
320 //===----------------------------------------------------------------------===//
321 // CopyOp
322 //===----------------------------------------------------------------------===//
323 void CopyOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
324   assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args");
325   b.create<linalg::YieldOp>(block.getArgument(0));
326 }
327 
328 void CopyOp::build(OpBuilder &builder, OperationState &result, Value input,
329                    Value output, AffineMap inputPermutation,
330                    AffineMap outputPermutation,
331                    ArrayRef<NamedAttribute> namedAttrs) {
332   result.addOperands({input, output});
333   result.addAttributes(namedAttrs);
334   if (inputPermutation)
335     result.addAttribute("inputPermutation",
336                         AffineMapAttr::get(inputPermutation));
337   if (outputPermutation)
338     result.addAttribute("outputPermutation",
339                         AffineMapAttr::get(outputPermutation));
340   result.addRegion();
341   fillStructuredOpRegion<CopyOp>(builder, *result.regions.front(),
342                                  TypeRange{input.getType()},
343                                  TypeRange{output.getType()});
344 }
345 
346 ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType,
347                               Type outputType) {
348   OpBuilder opBuilder(parser.getBuilder().getContext());
349   fillStructuredOpRegion<CopyOp>(opBuilder, r, TypeRange{inputType},
350                                  TypeRange{outputType});
351   return success();
352 }
353 
/// CopyOp region is elided when printing: this custom region printer
/// intentionally ignores all of its arguments and emits nothing.
void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
356 
357 static LogicalResult verify(CopyOp op) {
358   OpOperand *output = op.getOutputOperand(0);
359   OpOperand *input = op.getInputOperand(0);
360   if (getElementTypeOrSelf(input->get()) != getElementTypeOrSelf(output->get()))
361     return op.emitOpError("expects views of the same type");
362   if (op.getRank(input) != op.getRank(output))
363     return op.emitOpError("expects views of the same rank");
364   auto rank = op.getNumParallelLoops();
365   auto inputPermutationMap = op.inputPermutation();
366   if (inputPermutationMap) {
367     if (inputPermutationMap->getNumInputs() != rank)
368       return op.emitOpError("expects optional input_permutation map of rank ")
369              << rank;
370     if (!inputPermutationMap->isPermutation())
371       return op.emitOpError(
372           "expects optional input_permutation map to be a permutation");
373   }
374   auto outputPermutationMap = op.outputPermutation();
375   if (outputPermutationMap) {
376     if (outputPermutationMap->getNumInputs() != rank)
377       return op.emitOpError("expects optional output_permutation map of rank ")
378              << rank;
379     if (!outputPermutationMap->isPermutation())
380       return op.emitOpError(
381           "expects optional output_permutation map to be a permutation");
382   }
383   if (rank == 0 && inputPermutationMap)
384     return op.emitOpError("expected no input permutation when rank == 0");
385   if (rank == 0 && outputPermutationMap)
386     return op.emitOpError("expected no output permutation when rank == 0");
387   return success();
388 }
389 
390 void CopyOp::getEffects(
391     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
392         &effects) {
393   effects.emplace_back(MemoryEffects::Read::get(), input(),
394                        SideEffects::DefaultResource::get());
395   effects.emplace_back(MemoryEffects::Write::get(), output(),
396                        SideEffects::DefaultResource::get());
397 }
398 
399 //===----------------------------------------------------------------------===//
400 // FillOp
401 //===----------------------------------------------------------------------===//
402 void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
403   assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args");
404   b.create<linalg::YieldOp>(block.getArgument(0));
405 }
406 
407 void FillOp::build(OpBuilder &builder, OperationState &result, Value value,
408                    Value output) {
409   build(builder, result, output.getType().dyn_cast<RankedTensorType>(), value,
410         output);
411   fillStructuredOpRegion<FillOp>(builder, *result.regions.front(),
412                                  TypeRange{value.getType()},
413                                  TypeRange{output.getType()}, {});
414 }
415 
416 ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type valueType,
417                               Type outputType) {
418   OpBuilder opBuilder(parser.getBuilder().getContext());
419   fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{valueType},
420                                  TypeRange{outputType});
421   return success();
422 }
423 
/// FillOp region is elided when printing: this custom region printer
/// intentionally ignores all of its arguments and emits nothing.
void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
426 
427 static LogicalResult verify(FillOp op) {
428   OpOperand *output = op.getOutputOperand(0);
429   Type fillType = op.value().getType();
430   if (getElementTypeOrSelf(output->get()) != fillType)
431     return op.emitOpError("expects fill type to match view elemental type");
432   if (!op.getNumResults() && !output->get().getType().isa<MemRefType>()) {
433     return op.emitOpError(
434         "expected fill op with no result value to use memref type");
435   }
436   return success();
437 }
438 
439 void FillOp::getEffects(
440     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
441         &effects) {
442   if (output().getType().isa<MemRefType>())
443     effects.emplace_back(MemoryEffects::Write::get(), output(),
444                          SideEffects::DefaultResource::get());
445 }
446 
447 //===----------------------------------------------------------------------===//
448 // GenericOps
449 //===----------------------------------------------------------------------===//
450 void GenericOp::build(
451     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
452     ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
453     ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
454     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
455   build(builder, result, resultTensorTypes, inputs, outputs,
456         builder.getAffineMapArrayAttr(indexingMaps),
457         builder.getStrArrayAttr(iteratorTypes),
458         doc.empty() ? StringAttr() : builder.getStringAttr(doc),
459         libraryCall.empty() ? StringAttr()
460                             : builder.getStringAttr(libraryCall));
461   if (!bodyBuild)
462     return;
463 
464   SmallVector<Type, 4> blockArgTypes;
465   for (ValueRange container : {inputs, outputs})
466     for (Value v : container)
467       blockArgTypes.push_back(getElementTypeOrSelf(v));
468 
469   OpBuilder::InsertionGuard guard(builder);
470   auto &region = *result.regions.front();
471   Block *bodyBlock = builder.createBlock(&region, region.end(), blockArgTypes);
472   bodyBuild(builder, result.location, bodyBlock->getArguments());
473 }
474 
/// Convenience builder for a GenericOp without result tensor types: forwards
/// every argument to another `build` overload, passing an empty `TypeRange`
/// as the result types.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild);
}
483 
/// Convenience builder for a GenericOp without result tensor types, doc
/// string or library call: forwards to another `build` overload with empty
/// `doc` and `libraryCall` strings.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild);
}
493 
/// Convenience builder for a GenericOp with result tensor types but without
/// doc string or library call: forwards to another `build` overload with
/// empty `doc` and `libraryCall` strings.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild);
}
504 
505 static void print(OpAsmPrinter &p, GenericOp op) {
506   p << op.getOperationName() << " ";
507 
508   // Print extra attributes.
509   auto genericAttrNames = op.linalgTraitAttrNames();
510 
511   llvm::StringSet<> genericAttrNamesSet;
512   genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
513   SmallVector<NamedAttribute, 8> genericAttrs;
514   for (auto attr : op->getAttrs())
515     if (genericAttrNamesSet.count(attr.first.strref()) > 0)
516       genericAttrs.push_back(attr);
517   if (!genericAttrs.empty()) {
518     auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs);
519     p << genericDictAttr;
520   }
521 
522   // Printing is shared with named ops, except for the region and attributes
523   printCommonStructuredOpParts(p, op);
524 
525   genericAttrNames.push_back("operand_segment_sizes");
526   genericAttrNamesSet.insert(genericAttrNames.back());
527 
528   bool hasExtraAttrs = false;
529   for (NamedAttribute n : op->getAttrs()) {
530     if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.first.strref())))
531       break;
532   }
533   if (hasExtraAttrs) {
534     p << " attrs = ";
535     p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/genericAttrNames);
536   }
537 
538   // Print region.
539   if (!op.region().empty())
540     p.printRegion(op.region());
541 
542   // Print results.
543   printNamedStructuredOpResults(p, op.result_tensors().getTypes());
544 }
545 
546 static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
547   DictionaryAttr dictAttr;
548   // Parse the core linalg traits that must check into a dictAttr.
549   // The name is unimportant as we will overwrite result.attributes.
550   // The core linalg traits must contain the information necessary to pass the
551   // verifier.
552   if (parser.parseAttribute(dictAttr, "_", result.attributes))
553     return failure();
554   result.attributes.assign(dictAttr.getValue().begin(),
555                            dictAttr.getValue().end());
556 
557   // Parsing is shared with named ops, except for the region.
558   SmallVector<Type, 1> inputTypes, outputTypes;
559   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
560     return failure();
561 
562   // Optional attributes may be added.
563   if (succeeded(parser.parseOptionalKeyword("attrs")))
564     if (failed(parser.parseEqual()) ||
565         failed(parser.parseOptionalAttrDict(result.attributes)))
566       return failure();
567 
568   SmallVector<OpAsmParser::OperandType, 8> regionOperands;
569   std::unique_ptr<Region> region = std::make_unique<Region>();
570   SmallVector<Type, 8> operandTypes, regionTypes;
571   if (parser.parseRegion(*region, regionOperands, regionTypes))
572     return failure();
573   result.addRegion(std::move(region));
574 
575   // Generic ops may specify that a subset of its outputs are tensors. Such
576   // outputs are specified in the result type.
577   // TODO: may need to move output parsing before region parsing.
578   // Need to wait for declarative assembly resolution to decide.
579   SmallVector<Type, 1> outputTensorsTypes;
580   if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
581     return failure();
582   result.addTypes(outputTensorsTypes);
583 
584   return success();
585 }
586 
587 static void getGenericEffectsImpl(
588     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
589         &effects,
590     ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
591   for (Value value : results) {
592     effects.emplace_back(MemoryEffects::Allocate::get(), value,
593                          SideEffects::DefaultResource::get());
594   }
595   for (Value value : inputBuffers) {
596     effects.emplace_back(MemoryEffects::Read::get(), value,
597                          SideEffects::DefaultResource::get());
598   }
599   for (Value value : outputs) {
600     effects.emplace_back(MemoryEffects::Read::get(), value,
601                          SideEffects::DefaultResource::get());
602     effects.emplace_back(MemoryEffects::Write::get(), value,
603                          SideEffects::DefaultResource::get());
604   }
605 }
606 
607 void GenericOp::getEffects(
608     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
609         &effects) {
610   SmallVector<Value> inputBuffers = getInputBufferOperands();
611   SmallVector<Value> outputBuffers = getOutputBufferOperands();
612   getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
613                         outputBuffers);
614 }
615 
/// Verification hook templated over the generic op type. It currently
/// performs no checks and always succeeds.
template <typename GenericOpType>
static LogicalResult verifyGenericOp(GenericOpType op) {
  return success();
}

/// Verifier entry point for GenericOp; defers to `verifyGenericOp`.
static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
622 
623 //===----------------------------------------------------------------------===//
624 // InitTensorOp
625 //===----------------------------------------------------------------------===//
626 void InitTensorOp::build(OpBuilder &b, OperationState &result,
627                          ArrayRef<OpFoldResult> sizes, Type elementType,
628                          ArrayRef<NamedAttribute> attrs) {
629   unsigned rank = sizes.size();
630   SmallVector<Value, 4> dynamicSizes;
631   SmallVector<int64_t, 4> staticSizes;
632   for (unsigned i = 0; i < rank; ++i) {
633     dispatchIndexOpFoldResult(sizes[i], dynamicSizes, staticSizes,
634                               ShapedType::kDynamicSize);
635   }
636   auto resultType = RankedTensorType ::get(staticSizes, elementType);
637   build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes));
638   result.addAttributes(attrs);
639 }
640 
641 static LogicalResult verify(InitTensorOp op) {
642   RankedTensorType resultType = op.getType();
643   SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
644       op.static_sizes().cast<ArrayAttr>(),
645       [](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); }));
646 
647   if (failed(verifyListOfOperandsOrIntegers(op, "sizes", resultType.getRank(),
648                                             op.static_sizes(), op.sizes(),
649                                             ShapedType::isDynamic)))
650     return failure();
651 
652   if (op.static_sizes().size() != static_cast<unsigned>(resultType.getRank()))
653     return op->emitError("expected ")
654            << resultType.getRank() << " sizes values";
655 
656   Type expectedType =
657       InitTensorOp::inferResultType(staticSizes, resultType.getElementType());
658   if (resultType != expectedType) {
659     return op.emitError("specified type ")
660            << resultType << " does not match the inferred type "
661            << expectedType;
662   }
663   return success();
664 }
665 
666 Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
667                                    Type elementType) {
668   return RankedTensorType::get(staticSizes, elementType);
669 }
670 
671 namespace {
672 /// Change the type of the result of a `linalg.init_tensor` by making the result
673 /// type statically sized along dimension that in the original operation where
674 /// defined as dynamic, but the size was defined using a `constant` op. For
675 /// example
676 ///
677 ///  %c5 = constant 5: index
678 ///  %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
679 ///
680 ///  to
681 ///
682 ///  %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
683 struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
684   using OpRewritePattern<InitTensorOp>::OpRewritePattern;
685 
686   LogicalResult matchAndRewrite(InitTensorOp op,
687                                 PatternRewriter &rewriter) const override {
688     SmallVector<Value, 4> dynamicSizes;
689     SmallVector<int64_t, 4> staticSizes;
690     for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
691       // If the size is already static, nothing to do.
692       if (!op.isDynamicSize(i)) {
693         staticSizes.push_back(op.getStaticSize(i));
694         continue;
695       }
696 
697       // If the size is dynamic but defined using a `constant` op, get the
698       // constant value to find the static size to use.
699       unsigned operandNum = op.getIndexOfDynamicSize(i);
700       Value sizeOperand = op.getOperand(operandNum);
701       if (auto constantIndexOp = sizeOperand.getDefiningOp<ConstantIndexOp>()) {
702         staticSizes.push_back(constantIndexOp.getValue());
703         continue;
704       }
705 
706       // Fallback case. Keep the size dynamic.
707       dynamicSizes.push_back(sizeOperand);
708       staticSizes.push_back(ShapedType::kDynamicSize);
709     }
710     RankedTensorType newType =
711         RankedTensorType::get(staticSizes, op.getType().getElementType());
712     if (newType == op.getType())
713       return failure();
714     auto newOp =
715         rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
716                                       rewriter.getI64ArrayAttr(staticSizes));
717     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
718     return success();
719   }
720 };
721 } // namespace
722 
723 namespace {
724 /// Since `init_tensor` operation creates a tensor needed only for its shape, a
725 /// slice of this is also needed only for its shape. The result can be
726 /// replaced by a new init_tensor operation of the same size as the extract
727 /// slice op.
728 struct FoldInitTensorWithExtractSliceOp
729     : public OpRewritePattern<tensor::ExtractSliceOp> {
730   using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
731 
732   LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
733                                 PatternRewriter &rewriter) const override {
734     if (!sliceOp.source().getDefiningOp<linalg::InitTensorOp>())
735       return failure();
736     rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
737         sliceOp, sliceOp.sizes(),
738         llvm::to_vector<4>(llvm::map_range(
739             sliceOp.static_sizes(),
740             [](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); })),
741         sliceOp.getSourceType().getElementType());
742     return success();
743   }
744 };
745 
746 template <typename TensorReshapeOp>
747 struct FoldInitTensorWithTensorReshapeOp
748     : public OpRewritePattern<TensorReshapeOp> {
749   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
750 
751   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
752                                 PatternRewriter &rewriter) const override {
753     if (!reshapeOp.src().template getDefiningOp<InitTensorOp>())
754       return failure();
755     Location loc = reshapeOp.getLoc();
756     SmallVector<SmallVector<Value>, 4> resultShapes;
757     if (failed(reshapeOp.reifyReturnTypeShapesPerResultDim(rewriter,
758                                                            resultShapes)) ||
759         !llvm::hasSingleElement(resultShapes))
760       return failure();
761     Value initTensor = rewriter.create<InitTensorOp>(
762         loc, getAsOpFoldResult(resultShapes[0]),
763         reshapeOp.getResultType().getElementType());
764     if (initTensor.getType() != reshapeOp.getResultType()) {
765       rewriter.replaceOpWithNewOp<tensor::CastOp>(
766           reshapeOp, reshapeOp.getResultType(), initTensor);
767     } else {
768       rewriter.replaceOp(reshapeOp, initTensor);
769     }
770     return success();
771   }
772 };
773 } // namespace
774 
/// Registers the canonicalization patterns of `init_tensor` into `results`:
/// folding with `extract_slice`, folding with the tensor expand/collapse
/// reshape ops, and replacing constant dynamic sizes by static ones.
void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<FoldInitTensorWithExtractSliceOp,
              FoldInitTensorWithTensorReshapeOp<TensorExpandShapeOp>,
              FoldInitTensorWithTensorReshapeOp<TensorCollapseShapeOp>,
              ReplaceStaticShapeDims>(context);
}
782 
783 LogicalResult InitTensorOp::reifyReturnTypeShapesPerResultDim(
784     OpBuilder &builder,
785     SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
786   auto shapes = llvm::to_vector<4>(llvm::map_range(
787       llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
788         if (isDynamicSize(dim))
789           return getDynamicSize(dim);
790         return builder.create<ConstantIndexOp>(getLoc(), getStaticSize(dim));
791       }));
792   reifiedReturnShapes.emplace_back(std::move(shapes));
793   return success();
794 }
795 
796 //===----------------------------------------------------------------------===//
797 // PadTensorOp
798 //===----------------------------------------------------------------------===//
799 
800 static LogicalResult verify(PadTensorOp op) {
801   auto sourceType = op.source().getType().cast<RankedTensorType>();
802   auto resultType = op.result().getType().cast<RankedTensorType>();
803   auto expectedType = PadTensorOp::inferResultType(
804       sourceType, extractFromI64ArrayAttr(op.static_low()),
805       extractFromI64ArrayAttr(op.static_high()));
806   for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
807     if (resultType.getDimSize(i) == expectedType.getDimSize(i))
808       continue;
809     if (expectedType.isDynamicDim(i))
810       continue;
811     return op.emitError("specified type ")
812            << resultType << " does not match the inferred type "
813            << expectedType;
814   }
815 
816   auto &region = op.region();
817   unsigned rank = resultType.getRank();
818   Block &block = region.front();
819   if (block.getNumArguments() != rank)
820     return op.emitError("expected the block to have ") << rank << " arguments";
821 
822   // Note: the number and type of yield values are checked in the YieldOp.
823   for (auto en : llvm::enumerate(block.getArgumentTypes())) {
824     if (!en.value().isIndex())
825       return op.emitOpError("expected block argument ")
826              << (en.index() + 1) << " to be an index";
827   }
828 
829   return success();
830 }
831 
832 RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
833                                               ArrayRef<int64_t> staticLow,
834                                               ArrayRef<int64_t> staticHigh) {
835   unsigned rank = sourceType.getRank();
836   assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
837   assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
838 
839   SmallVector<int64_t, 4> resultShape;
840   for (auto i : llvm::seq<unsigned>(0, rank)) {
841     if (sourceType.isDynamicDim(i) ||
842         staticLow[i] == ShapedType::kDynamicSize ||
843         staticHigh[i] == ShapedType::kDynamicSize) {
844       resultShape.push_back(ShapedType::kDynamicSize);
845     } else {
846       int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
847       resultShape.push_back(size);
848     }
849   }
850 
851   return RankedTensorType::get(resultShape, sourceType.getElementType());
852 }
853 
854 void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
855                         ArrayRef<int64_t> staticLow,
856                         ArrayRef<int64_t> staticHigh, ValueRange low,
857                         ValueRange high, ArrayRef<NamedAttribute> attrs) {
858   auto sourceType = source.getType().cast<RankedTensorType>();
859   auto resultType = inferResultType(sourceType, staticLow, staticHigh);
860   build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
861         b.getI64ArrayAttr(staticHigh));
862   result.addAttributes(attrs);
863 }
864 
865 void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
866                         ValueRange low, ValueRange high,
867                         ArrayRef<NamedAttribute> attrs) {
868   auto sourceType = source.getType().cast<RankedTensorType>();
869   unsigned rank = sourceType.getRank();
870   SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
871   build(b, result, source, staticVector, staticVector, low, high, attrs);
872 }
873 
874 void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
875                         Value source, ArrayRef<OpFoldResult> low,
876                         ArrayRef<OpFoldResult> high,
877                         ArrayRef<NamedAttribute> attrs) {
878   assert(resultType.isa<RankedTensorType>());
879   auto sourceType = source.getType().cast<RankedTensorType>();
880   unsigned rank = sourceType.getRank();
881   SmallVector<Value, 4> dynamicLow, dynamicHigh;
882   SmallVector<int64_t, 4> staticLow, staticHigh;
883   for (unsigned i = 0; i < rank; ++i) {
884     // staticLow and staticHigh have full information of the padding config.
885     // This will grow staticLow and staticHigh with 1 value. If the config is
886     // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
887     // value as well.
888     dispatchIndexOpFoldResult(low[i], dynamicLow, staticLow,
889                               ShapedType::kDynamicSize);
890     dispatchIndexOpFoldResult(high[i], dynamicHigh, staticHigh,
891                               ShapedType::kDynamicSize);
892   }
893   if (!resultType) {
894     resultType =
895         PadTensorOp::inferResultType(sourceType, staticLow, staticHigh);
896   }
897   build(b, result, resultType, source, dynamicLow, dynamicHigh,
898         b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh));
899 }
900 
901 PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad,
902                                            ArrayRef<OpFoldResult> low,
903                                            ArrayRef<OpFoldResult> high,
904                                            Location loc, OpBuilder &builder) {
905   auto padTensorOp =
906       builder.create<linalg::PadTensorOp>(loc, type, source, low, high);
907   int rank = padTensorOp.getResultType().getRank();
908   SmallVector<Type, 4> blockArgTypes;
909   blockArgTypes.assign(rank, builder.getIndexType());
910   auto &region = padTensorOp.region();
911   // `builder.createBlock` changes the insertion point within the block. Create
912   // a guard to reset the insertion point of the builder after it is destroyed.
913   OpBuilder::InsertionGuard guard(builder);
914   builder.createBlock(&region, region.end(), blockArgTypes);
915   builder.create<linalg::YieldOp>(loc, pad);
916   return padTensorOp;
917 }
918 
919 PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
920                                          Location loc, OpBuilder &builder) {
921   SmallVector<OpFoldResult, 4> low, high;
922   auto rankedTensorType = type.cast<RankedTensorType>();
923   assert(rankedTensorType.hasStaticShape());
924   int rank = rankedTensorType.getRank();
925   for (int i = 0; i < rank; ++i) {
926     auto dimOp = builder.createOrFold<memref::DimOp>(loc, source, i);
927     auto resultDimSize = builder.createOrFold<ConstantIndexOp>(
928         loc, rankedTensorType.getDimSize(i));
929     auto highValue = builder.createOrFold<SubIOp>(loc, resultDimSize, dimOp);
930     high.push_back(highValue);
931     low.push_back(builder.createOrFold<ConstantIndexOp>(loc, 0));
932   }
933   return PadTensorOp::createPadScalarOp(type, source, pad, low, high, loc,
934                                         builder);
935 }
936 
937 LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim(
938     OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
939   Location loc = getLoc();
940   auto lowPad = getMixedLowPad();
941   auto highPad = getMixedHighPad();
942   SmallVector<Value> shapes;
943   for (auto dim : llvm::seq<int64_t>(0, getSourceType().getRank())) {
944     // Shape along each dimension is source dim + low pad + high pad.
945     SmallVector<Value> mapOperands;
946     mapOperands.push_back(b.createOrFold<memref::DimOp>(loc, source(), dim));
947     AffineExpr expr = b.getAffineDimExpr(0);
948     unsigned numSymbols = 0;
949     auto addOpFoldResult = [&](OpFoldResult valueOrAttr) {
950       if (Value v = valueOrAttr.dyn_cast<Value>()) {
951         expr = expr + b.getAffineSymbolExpr(numSymbols++);
952         mapOperands.push_back(v);
953         return;
954       }
955       int64_t staticValue =
956           valueOrAttr.get<Attribute>().cast<IntegerAttr>().getInt();
957       expr = expr + staticValue;
958     };
959     addOpFoldResult(lowPad[dim]);
960     addOpFoldResult(highPad[dim]);
961     shapes.push_back(applyMapToValues(
962         b, loc, AffineMap::get(1, numSymbols, expr), mapOperands)[0]);
963   }
964   reifiedReturnShapes.emplace_back(std::move(shapes));
965   return success();
966 }
967 
968 namespace {
969 // Folds linalg.pad_tensor when padding is static zeros.
970 struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> {
971   using OpRewritePattern<PadTensorOp>::OpRewritePattern;
972 
973   LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
974                                 PatternRewriter &rewriter) const override {
975     if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
976       return failure();
977     rewriter.replaceOpWithNewOp<tensor::CastOp>(
978         padTensorOp, padTensorOp.result().getType(), padTensorOp.source());
979     return success();
980   }
981 };
982 
983 } // namespace
984 
/// Registers the canonicalization patterns of `pad_tensor` into `results`;
/// currently only the folding of all-zero padding.
void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<FoldStaticZeroPadding>(context);
}
989 
990 /// Return the padding value of the PadTensorOp if it constant. In this context,
991 /// "constant" means an actual constant or "defined outside of the block".
992 ///
993 /// Values are considered constant in three cases:
994 ///  - A ConstantLike value.
995 ///  - A basic block argument from a different block.
996 ///  - A value defined outside of the block.
997 ///
998 /// If the padding value is not constant, an empty Value is returned.
999 Value PadTensorOp::getConstantPaddingValue() {
1000   auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
1001   if (!yieldOp || yieldOp.values().size() != 1)
1002     return {};
1003   Value padValue = yieldOp.values().front();
1004   // Check if yield value is a constant.
1005   if (matchPattern(padValue, m_Constant()))
1006     return padValue;
1007   // Check if yield value is defined inside the PadTensorOp block.
1008   if (padValue.getParentBlock() == &getRegion().front())
1009     return {};
1010   // Else: Yield value defined outside of the PadTensorOp block.
1011   return padValue;
1012 }
1013 
1014 OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) {
1015   if (getResultType().hasStaticShape() && getResultType() == getSourceType())
1016     return source();
1017   return {};
1018 }
1019 
1020 //===----------------------------------------------------------------------===//
1021 // ReshapeOp
1022 //===----------------------------------------------------------------------===//
1023 
1024 Optional<SmallVector<ReassociationIndices>>
1025 mlir::linalg::getReassociationIndicesForReshape(ShapedType sourceType,
1026                                                 ShapedType targetType) {
1027   // Make the sourceType greater rank than the targetType. If they are same
1028   // rank, then its an unsupported reshape op.
1029   if (sourceType.getRank() == targetType.getRank())
1030     return llvm::None;
1031   if (sourceType.getRank() < targetType.getRank())
1032     std::swap(sourceType, targetType);
1033 
1034   ArrayRef<int64_t> sourceShape = sourceType.getShape();
1035   ArrayRef<int64_t> targetShape = targetType.getShape();
1036   unsigned sourceDim = 0;
1037   SmallVector<ReassociationIndices> reassociationMap;
1038   reassociationMap.reserve(targetType.getRank());
1039 
1040   ReassociationIndices currIndices;
1041   int64_t prodOfCollapsedDims = 1;
1042   while (sourceDim < sourceShape.size()) {
1043     unsigned targetDim = reassociationMap.size();
1044 
1045     // If all the dimensions of the targetShape are exhausted, then the
1046     // remaining dims in the source shape must be all 1s. So for such cases, set
1047     // 1 as the target shape. The actual reassociation indices will be handled
1048     // later.
1049     int64_t currTargetShape =
1050         (targetDim < targetType.getRank() ? targetShape[targetDim] : 1);
1051     while (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
1052            prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape &&
1053            sourceDim < sourceShape.size()) {
1054       prodOfCollapsedDims *= sourceShape[sourceDim];
1055       currIndices.push_back(sourceDim++);
1056     }
1057 
1058     // If the current expanded dimension is dynamic, then the collapsed
1059     // dimensions should also be dynamic and product of all previous unprocessed
1060     // dimensions of the expanded shape should be 1.
1061     if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
1062         (currTargetShape != ShapedType::kDynamicSize ||
1063          prodOfCollapsedDims != 1))
1064       return llvm::None;
1065 
1066     // If the collapsed dim is dynamic, the current expanded dim should also
1067     // be dynamic.
1068     if (currTargetShape == ShapedType::kDynamicSize &&
1069         sourceShape[sourceDim] != ShapedType::kDynamicSize)
1070       return llvm::None;
1071 
1072     // For static shapes, if the product of dimensions of the expanded shape
1073     // should match the collapsed dimension shape.
1074     if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
1075       return llvm::None;
1076 
1077     currIndices.push_back(sourceDim++);
1078     // If the reassociation is empty but the currIndices is not, this by
1079     // definition is folding unit-dimensions with the result being scalar type.
1080     // So only append the `currIndices` if reassociation map is not empty.
1081     if (targetDim == targetShape.size()) {
1082       if (!reassociationMap.empty() && !currIndices.empty())
1083         reassociationMap.back().append(currIndices.begin(), currIndices.end());
1084       // Break out of the loops. We should be done here.
1085       break;
1086     }
1087     reassociationMap.emplace_back(ReassociationIndices{});
1088     std::swap(reassociationMap.back(), currIndices);
1089     prodOfCollapsedDims = 1;
1090   }
1091   // All the dimensions in the two shapes must have been processed.
1092   if (reassociationMap.size() != targetShape.size() ||
1093       sourceDim != sourceShape.size())
1094     return llvm::None;
1095   return reassociationMap;
1096 }
1097 
1098 template <typename ReshapeLikeOp>
1099 static void print(OpAsmPrinter &p, ReshapeLikeOp op) {
1100   p << op.getOperationName() << ' ' << op.src() << " [";
1101 
1102   llvm::interleaveComma(op.reassociation(), p, [&](const Attribute &attr) {
1103     p << '[';
1104     auto arrayAttr = attr.template cast<ArrayAttr>();
1105     llvm::interleaveComma(arrayAttr, p, [&](const Attribute &attr) {
1106       p << attr.cast<IntegerAttr>().getInt();
1107     });
1108     p << ']';
1109   });
1110 
1111   p << "] ";
1112   p.printOptionalAttrDict(op->getAttrs(),
1113                           /*elidedAttrs=*/{op.getReassociationAttrName()});
1114   p << ": " << op.src().getType() << " into " << op.getType();
1115 }
1116 
// Non-template `print` overloads, one per reshape op. Each simply forwards to
// the generic reshape-like printer with the op type spelled out.
// NOTE(review): presumably these exist so that the op-specific printer hooks
// resolve to a concrete overload; confirm against the op definitions.

/// Prints a memref expand-shape op using the generic reshape-like format.
static void print(OpAsmPrinter &p, linalg::ExpandShapeOp op) {
  print<linalg::ExpandShapeOp>(p, op);
}

/// Prints a memref collapse-shape op using the generic reshape-like format.
static void print(OpAsmPrinter &p, linalg::CollapseShapeOp op) {
  print<linalg::CollapseShapeOp>(p, op);
}

/// Prints a tensor expand-shape op using the generic reshape-like format.
static void print(OpAsmPrinter &p, linalg::TensorExpandShapeOp op) {
  print<linalg::TensorExpandShapeOp>(p, op);
}

/// Prints a tensor collapse-shape op using the generic reshape-like format.
static void print(OpAsmPrinter &p, linalg::TensorCollapseShapeOp op) {
  print<linalg::TensorCollapseShapeOp>(p, op);
}
1132 
/// Name of the attribute under which the reshape-like op parser stores the
/// parsed reassociation indices.
static constexpr StringRef getReassociationAttrName() {
  return "reassociation";
}
1136 
1137 static ParseResult parseReshapeLikeOp(OpAsmParser &parser,
1138                                       OperationState &result) {
1139   // Parse the operand.
1140   OpAsmParser::OperandType src;
1141   if (parser.parseOperand(src))
1142     return failure();
1143 
1144   // Parse reassociation indices.
1145   Builder &b = parser.getBuilder();
1146   SmallVector<Attribute, 4> reassociation;
1147   if (parser.parseLSquare())
1148     return failure();
1149 
1150   while (true) {
1151     if (succeeded(parser.parseOptionalRSquare()))
1152       break;
1153     if (parser.parseLSquare())
1154       return failure();
1155     SmallVector<int64_t> indices;
1156     while (true) {
1157       int64_t index;
1158       if (parser.parseInteger(index))
1159         return failure();
1160       indices.push_back(index);
1161 
1162       if (succeeded(parser.parseOptionalComma()))
1163         continue;
1164       if (failed(parser.parseRSquare()))
1165         return failure();
1166       break;
1167     }
1168     reassociation.push_back(b.getI64ArrayAttr(indices));
1169     if (succeeded(parser.parseOptionalComma()))
1170       continue;
1171     if (failed(parser.parseRSquare()))
1172       return failure();
1173     break;
1174   }
1175 
1176   result.addAttribute(getReassociationAttrName(),
1177                       b.getArrayAttr(reassociation));
1178 
1179   // Parse optional attributes.
1180   parser.parseOptionalAttrDict(result.attributes);
1181 
1182   // Parse types.
1183   Type srcType;
1184   Type resultType;
1185   if (parser.parseColon() || parser.parseType(srcType) ||
1186       parser.resolveOperand(src, srcType, result.operands) ||
1187       parser.parseKeyword("into") || parser.parseType(resultType))
1188     return failure();
1189   result.addTypes(resultType);
1190   return success();
1191 }
1192 
1193 /// Collapse reassociation maps that are used in pair of reshape ops where one
1194 /// is a producer and other is the consumer. Only valid to use this method when
1195 /// both the producer and consumer are collapsing dimensions or both are
1196 /// expanding dimensions.
1197 ///
1198 /// For example,
1199 ///   mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
1200 ///                   affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
1201 ///                   affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
1202 ///   mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
1203 ///                   affine_map<(d0, d1, d2) -> (d2)>]
1204 ///
1205 /// is folded into
1206 ///
1207 ///   result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
1208 ///             affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
1209 static Optional<SmallVector<ReassociationIndices>>
1210 collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
1211                              ArrayRef<AffineMap> mapsConsumer,
1212                              MLIRContext *context) {
1213   // Make the producer the larger sized vector. If they are of same size, the
1214   // resulting reshape is not a supported reshape op.
1215   if (mapsProducer.size() == mapsConsumer.size())
1216     return llvm::None;
1217   if (mapsProducer.size() < mapsConsumer.size())
1218     std::swap(mapsProducer, mapsConsumer);
1219 
1220   // Handle the corner case of the result being a rank 0 shaped type. Return an
1221   // empty reassociation.
1222   if (mapsConsumer.empty())
1223     return SmallVector<ReassociationIndices>{};
1224   if (mapsProducer.size() != mapsConsumer[0].getNumDims())
1225     return llvm::None;
1226 
1227   unsigned currDim = 0;
1228   SmallVector<ReassociationIndices> reassociationMaps;
1229   for (AffineMap rhs : mapsConsumer) {
1230     ReassociationIndices reassociations;
1231     for (AffineExpr rhsExpr : rhs.getResults()) {
1232       AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
1233       for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
1234            i < e; ++i)
1235         reassociations.push_back(currDim++);
1236     }
1237     reassociationMaps.push_back(std::move(reassociations));
1238   }
1239   return reassociationMaps;
1240 }
1241 
1242 namespace {
1243 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
1244 /// dimensions or are both expanding dimensions.
1245 template <typename ReshapeOpTy>
1246 struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
1247   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
1248   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
1249                                 PatternRewriter &rewriter) const override {
1250     auto srcReshapeOp = reshapeOp.src().template getDefiningOp<ReshapeOpTy>();
1251     if (!srcReshapeOp)
1252       return failure();
1253 
1254     ShapedType resultType = reshapeOp.getResultType();
1255     Optional<SmallVector<ReassociationIndices>> reassociationIndices =
1256         collapseReassociationIndices(srcReshapeOp.getReassociationMaps(),
1257                                      reshapeOp.getReassociationMaps(),
1258                                      rewriter.getContext());
1259     if (!reassociationIndices)
1260       return failure();
1261     rewriter.replaceOpWithNewOp<ReshapeOpTy>(
1262         reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
1263     return success();
1264   }
1265 };
1266 
1267 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
1268 /// dimensions or are both expanding dimensions.
1269 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
1270 struct CollapseMixedReshapeOps : public OpRewritePattern<ReshapeOpTy> {
1271   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
1272   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
1273                                 PatternRewriter &rewriter) const override {
1274     auto srcReshapeOp =
1275         reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
1276     if (!srcReshapeOp)
1277       return failure();
1278 
1279     ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
1280     ShapedType intermediateType = reshapeOp.getSrcType();
1281     ShapedType resultType = reshapeOp.getResultType();
1282 
1283     // If the source reshape can be collapsed/expanded into the target reshape
1284     // they can still be folded. This can only be reasoned about statically
1285     // for cases where
1286     // - either all shapes are static, or
1287     // - The number of dynamic dimensions matches in the source of source and
1288     //   result with all other dimensions being 1.
1289     Optional<SmallVector<ReassociationIndices>> reassociationIndices =
1290         getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
1291     if (!reassociationIndices)
1292       return failure();
1293     bool originalOpExpands =
1294         intermediateType.getRank() > srcReshapeSrcType.getRank();
1295     bool resultingOpExpands =
1296         resultType.getRank() > srcReshapeSrcType.getRank();
1297     if (!(resultingOpExpands ^ originalOpExpands))
1298       rewriter.replaceOpWithNewOp<InverseReshapeOpTy>(
1299           reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
1300     else
1301       rewriter.replaceOpWithNewOp<ReshapeOpTy>(
1302           reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
1303     return success();
1304   }
1305 };
1306 } // namespace
1307 
1308 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
1309 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
1310                                   ArrayRef<Attribute> operands) {
1311   // Fold producer-consumer reshape ops that where the operand type of the
1312   // producer is same as the return type of the consumer.
1313   auto reshapeSrcOp =
1314       reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
1315   if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
1316     return reshapeSrcOp.src();
1317   // Reshape of a constant can be replaced with a new constant.
1318   if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
1319     return elements.reshape(
1320         reshapeOp.getResult().getType().template cast<ShapedType>());
1321   }
1322   return nullptr;
1323 }
1324 
1325 /// Return true if the reassociation specification is valid, false otherwise.
1326 /// When false, the `invalidIndex` integer pointer is optionally filled with the
1327 /// index of the offending reassociation map.
1328 static bool isReassociationValid(ArrayRef<AffineMap> reassociation,
1329                                  int *invalidIndex = nullptr) {
1330   if (reassociation.empty())
1331     return true;
1332   unsigned nDims = reassociation[0].getNumDims();
1333   unsigned nextExpectedDim = 0;
1334   for (auto it : llvm::enumerate(reassociation)) {
1335     auto m = it.value();
1336     if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
1337       if (invalidIndex)
1338         *invalidIndex = it.index();
1339       return false;
1340     }
1341     for (auto e : m.getResults()) {
1342       auto d = e.dyn_cast<AffineDimExpr>();
1343       if (!d || d.getPosition() != nextExpectedDim++) {
1344         if (invalidIndex)
1345           *invalidIndex = it.index();
1346         return false;
1347       }
1348     }
1349   }
1350   if (nextExpectedDim != nDims) {
1351     if (invalidIndex)
1352       *invalidIndex = reassociation.size() - 1;
1353     return false;
1354   }
1355   return true;
1356 }
1357 
1358 /// Detect whether memref dims [dim, dim + extent) can be reshaped without
1359 /// copies.
1360 static bool isReshapableDimBand(unsigned dim, unsigned extent,
1361                                 ArrayRef<int64_t> sizes,
1362                                 ArrayRef<AffineExpr> strides) {
1363   assert(sizes.size() == strides.size() && "mismatched ranks");
1364   // off by 1 indexing to avoid out of bounds
1365   //                       V
1366   for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
1367     // Only bands of static shapes are reshapable. This is due to the fact that
1368     // there is no relation between dynamic sizes and dynamic strides: we do not
1369     // have enough information to know whether a "-1" size corresponds to the
1370     // proper symbol in the AffineExpr of a stride.
1371     if (ShapedType::isDynamic(sizes[dim + 1]))
1372       return false;
1373     // TODO: Refine this by passing the proper nDims and nSymbols so we can
1374     // simplify on the fly and catch more reshapable cases.
1375     if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
1376       return false;
1377   }
1378   return true;
1379 }
1380 
/// Compute the MemRefType obtained by applying the `reassociation` (which is
/// expected to be valid) to `type`.
/// If `type` is Contiguous MemRefType, this always produce a contiguous
/// MemRefType.
///
/// Each reassociation map describes one band of consecutive dimensions of
/// `type` that becomes a single result dimension. For a reshapable band the
/// result size is the product of the band's sizes and the result stride is
/// the stride of the band's innermost dimension; otherwise both are dynamic.
/// `type` is asserted to be a strided memref.
static MemRefType
computeReshapeCollapsedType(MemRefType type,
                            ArrayRef<AffineMap> reassociation) {
  auto sizes = type.getShape();
  AffineExpr offset;
  SmallVector<AffineExpr, 4> strides;
  auto status = getStridesAndOffset(type, strides, offset);
  // `status` is only consumed by the assert; silence unused warnings in
  // builds where asserts are compiled out.
  (void)status;
  assert(succeeded(status) && "expected strided memref");

  SmallVector<int64_t, 4> newSizes;
  newSizes.reserve(reassociation.size());
  SmallVector<AffineExpr, 4> newStrides;
  newStrides.reserve(reassociation.size());

  // Use the fact that reassociation is valid to simplify the logic: only use
  // each map's rank.
  assert(isReassociationValid(reassociation) && "invalid reassociation");
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    // Number of source dimensions folded into this result dimension.
    unsigned dim = m.getNumResults();
    int64_t size = 1;
    // Stride of the innermost dimension of the band.
    AffineExpr stride = strides[currentDim + dim - 1];
    if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
      // Not reshapable: mark the size dynamic and use a null stride, which is
      // turned into a dynamic stride further below.
      size = ShapedType::kDynamicSize;
      stride = AffineExpr();
    } else {
      for (unsigned d = 0; d < dim; ++d)
        size *= sizes[currentDim + d];
    }
    newSizes.push_back(size);
    newStrides.push_back(stride);
    currentDim += dim;
  }

  // Early-exit: if `type` is contiguous, the result must be contiguous.
  if (canonicalizeStridedLayout(type).getAffineMaps().empty())
    return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});

  // Convert back to int64_t because we don't have enough information to create
  // new strided layouts from AffineExpr only. This corresponds to a case where
  // copies may be necessary.
  int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
  if (auto o = offset.dyn_cast<AffineConstantExpr>())
    intOffset = o.getValue();
  SmallVector<int64_t, 4> intStrides;
  // NOTE(review): reserves for the source rank although only
  // `newStrides.size()` entries are pushed; harmless over-reservation.
  intStrides.reserve(strides.size());
  for (auto stride : newStrides) {
    // Null (non-reshapable) and non-constant strides both become dynamic.
    if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
      intStrides.push_back(cst.getValue());
    else
      intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  auto layout =
      makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
  return canonicalizeStridedLayout(
      MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
}
1443 
1444 template <typename AffineExprTy>
1445 unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
1446   unsigned pos = 0;
1447   for (const auto &exprs : exprArrays) {
1448     for (auto expr : exprs) {
1449       expr.walk([&pos](AffineExpr e) {
1450         if (auto d = e.dyn_cast<AffineExprTy>())
1451           pos = std::max(pos, d.getPosition());
1452       });
1453     }
1454   }
1455   return pos;
1456 }
1457 
1458 static SmallVector<AffineMap, 4>
1459 getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
1460   unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
1461   assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
1462          "Expected symbol-less expressions");
1463   SmallVector<AffineMap, 4> maps;
1464   maps.reserve(reassociation.size());
1465   for (const auto &exprs : reassociation) {
1466     assert(!exprs.empty());
1467     maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
1468   }
1469   return maps;
1470 }
1471 
1472 static SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
1473     OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
1474   SmallVector<ReassociationIndices, 2> reassociationIndices;
1475   for (const auto &exprs : reassociationExprs) {
1476     ReassociationIndices indices;
1477     indices.reserve(exprs.size());
1478     for (const auto &expr : exprs)
1479       indices.push_back(expr.cast<AffineDimExpr>().getPosition());
1480     reassociationIndices.push_back(indices);
1481   }
1482   return reassociationIndices;
1483 }
1484 
1485 static SmallVector<SmallVector<AffineExpr, 2>, 2>
1486 convertReassociationIndicesToExprs(
1487     OpBuilder &b, ArrayRef<ReassociationIndices> reassociationIndices) {
1488   SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
1489   for (const auto &indices : reassociationIndices) {
1490     SmallVector<AffineExpr, 2> reassociationMap;
1491     reassociationMap.reserve(indices.size());
1492     for (int64_t index : indices)
1493       reassociationMap.push_back(b.getAffineDimExpr(index));
1494     reassociationMaps.push_back(std::move(reassociationMap));
1495   }
1496   return reassociationMaps;
1497 }
1498 
/// Returns the reassociation of this op as symbol-less affine maps, built from
/// the expressions returned by getReassociationExprs().
SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
/// Returns the reassociation of this op as groups of affine dimension
/// expressions, converted from getReassociationIndices().
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  // A builder is only needed to create the dimension expressions.
  OpBuilder b(this->getContext());
  return convertReassociationIndicesToExprs(b, getReassociationIndices());
}
/// Returns the reassociation of this op as symbol-less affine maps, built from
/// the expressions returned by getReassociationExprs().
SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
/// Returns the reassociation of this op as groups of affine dimension
/// expressions, converted from getReassociationIndices().
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  // A builder is only needed to create the dimension expressions.
  OpBuilder b(this->getContext());
  return convertReassociationIndicesToExprs(b, getReassociationIndices());
}
1513 
/// Returns the reassociation of this op as symbol-less affine maps, built from
/// the expressions returned by getReassociationExprs().
SmallVector<AffineMap, 4> TensorCollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
/// Returns the reassociation of this op as groups of affine dimension
/// expressions, converted from getReassociationIndices().
SmallVector<ReassociationExprs, 4>
TensorCollapseShapeOp::getReassociationExprs() {
  // A builder is only needed to create the dimension expressions.
  OpBuilder b(this->getContext());
  return convertReassociationIndicesToExprs(b, getReassociationIndices());
}
/// Returns the reassociation of this op as symbol-less affine maps, built from
/// the expressions returned by getReassociationExprs().
SmallVector<AffineMap, 4> TensorExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
/// Returns the reassociation of this op as groups of affine dimension
/// expressions, converted from getReassociationIndices().
SmallVector<ReassociationExprs, 4>
TensorExpandShapeOp::getReassociationExprs() {
  // A builder is only needed to create the dimension expressions.
  OpBuilder b(this->getContext());
  return convertReassociationIndicesToExprs(b, getReassociationIndices());
}
1530 
1531 /// For reshape op compute the shape at dimension `dimIndex` of the output in
1532 /// terms of shape of the `src`, when the reshape op is a collapsing
1533 /// operation. It is the product of the shape of the collapsed dimensions of the
1534 /// `src`.
1535 static OpFoldResult
1536 getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
1537                                     int64_t dimIndex, Value src,
1538                                     ArrayRef<AffineMap> reassociationMap) {
1539   AffineMap map = reassociationMap[dimIndex];
1540   unsigned startPos =
1541       map.getResults().front().cast<AffineDimExpr>().getPosition();
1542   unsigned endPos = map.getResults().back().cast<AffineDimExpr>().getPosition();
1543   AffineExpr expr;
1544   SmallVector<Value, 2> dynamicDims;
1545   for (auto dim : llvm::seq(startPos, endPos + 1)) {
1546     dynamicDims.push_back(builder.createOrFold<memref::DimOp>(loc, src, dim));
1547     AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
1548     expr = (expr ? expr * currExpr : currExpr);
1549   }
1550   return applyMapToValues(builder, loc,
1551                           AffineMap::get(0, endPos - startPos + 1, expr),
1552                           dynamicDims)[0];
1553 }
1554 
1555 /// Given the `src` of a collapsing reshape op and its reassociation maps,
1556 /// compute the shape of the result of the reshape.
1557 static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
1558     OpBuilder &builder, Location loc, Value src,
1559     ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
1560   return llvm::to_vector<4>(llvm::map_range(
1561       llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
1562         return getCollapsedOutputDimFromInputShape(builder, loc, dim, src,
1563                                                    reassociation);
1564       }));
1565 }
1566 
1567 /// Compute a map that for a given dimension of the expanded type gives the
1568 /// dimension in the collapsed type it maps to. Essentially its the inverse of
1569 /// the `reassocation` maps.
1570 static llvm::DenseMap<int64_t, int64_t>
1571 getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
1572   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1573   for (auto map : enumerate(reassociation)) {
1574     unsigned startPos =
1575         map.value().getResults().front().cast<AffineDimExpr>().getPosition();
1576     unsigned endPos =
1577         map.value().getResults().back().cast<AffineDimExpr>().getPosition();
1578     for (auto dim : llvm::seq(startPos, endPos + 1)) {
1579       expandedDimToCollapsedDim[dim] = map.index();
1580     }
1581   }
1582   return expandedDimToCollapsedDim;
1583 }
1584 
1585 /// For an expanding reshape op, compute the value for a dimension of the output
1586 /// from the shape of the input.
1587 static OpFoldResult getExpandedOutputDimFromInputShape(
1588     OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
1589     ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
1590     llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
1591   if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
1592     return builder.getI64IntegerAttr(dstStaticShape[dimIndex]);
1593   }
1594   unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
1595   unsigned startPos = reassociation[sourceDimPos]
1596                           .getResults()
1597                           .front()
1598                           .cast<AffineDimExpr>()
1599                           .getPosition();
1600   unsigned endPos = reassociation[sourceDimPos]
1601                         .getResults()
1602                         .back()
1603                         .cast<AffineDimExpr>()
1604                         .getPosition();
1605   int64_t linearizedStaticDim = 1;
1606   for (auto d :
1607        llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
1608     if (d.index() + startPos == static_cast<unsigned>(dimIndex))
1609       continue;
1610     assert(!ShapedType::isDynamic(d.value()) &&
1611            "single dimension cannot be expanded into multiple dynamic "
1612            "dimensions");
1613     linearizedStaticDim *= d.value();
1614   }
1615   Value sourceDim = builder.create<memref::DimOp>(loc, src, sourceDimPos);
1616   return applyMapToValues(
1617       builder, loc,
1618       AffineMap::get(
1619           0, 1, builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
1620       sourceDim)[0];
1621 }
1622 
1623 /// Given the `src` of an expanding reshape op, the reassociation maps and the
1624 /// result type, compute the shape of the result of the reshape.
1625 static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape(
1626     OpBuilder &builder, Location loc, Value src,
1627     ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
1628   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
1629       getExpandedDimToCollapsedDimMap(reassociation);
1630   return llvm::to_vector<4>(llvm::map_range(
1631       llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
1632         return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
1633                                                   dstStaticShape, reassociation,
1634                                                   expandedDimToCollapsedDim);
1635       }));
1636 }
1637 
1638 static SmallVector<OpFoldResult, 4>
1639 getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
1640                                     ArrayRef<int64_t> dstStaticShape,
1641                                     ArrayRef<AffineMap> reassocation) {
1642   return dstStaticShape.size() >
1643                  static_cast<size_t>(src.getType().cast<ShapedType>().getRank())
1644              ? getExpandedOutputShapeFromInputShape(
1645                    builder, loc, src, dstStaticShape, reassocation)
1646              : getCollapsedOutputShapeFromInputShape(
1647                    builder, loc, src, dstStaticShape, reassocation);
1648 }
1649 
1650 static ArrayAttr
1651 getReassociationIndicesAttribute(OpBuilder &b,
1652                                  ArrayRef<ReassociationIndices> reassociation) {
1653   SmallVector<Attribute, 4> reassociationAttr =
1654       llvm::to_vector<4>(llvm::map_range(
1655           reassociation, [&](ReassociationIndices indices) -> Attribute {
1656             return b.getI64ArrayAttr(indices).cast<Attribute>();
1657           }));
1658   return b.getArrayAttr(reassociationAttr);
1659 }
1660 
1661 void mlir::linalg::ExpandShapeOp::build(
1662     OpBuilder &b, OperationState &result, Value src,
1663     ArrayRef<ReassociationIndices> reassociation,
1664     ArrayRef<NamedAttribute> attrs) {
1665   auto memRefType = src.getType().cast<MemRefType>();
1666   auto resultType = computeReshapeCollapsedType(
1667       memRefType, getSymbolLessAffineMaps(
1668                       convertReassociationIndicesToExprs(b, reassociation)));
1669   build(b, result, resultType, src, attrs);
1670   result.addAttribute(getReassociationAttrName(),
1671                       getReassociationIndicesAttribute(b, reassociation));
1672 }
1673 
/// Returns the `src` operand as the source of the view produced by this op.
Value mlir::linalg::ExpandShapeOp::getViewSource() { return src(); }
1675 
1676 void mlir::linalg::CollapseShapeOp::build(
1677     OpBuilder &b, OperationState &result, Value src,
1678     ArrayRef<ReassociationIndices> reassociation,
1679     ArrayRef<NamedAttribute> attrs) {
1680   auto memRefType = src.getType().cast<MemRefType>();
1681   auto resultType = computeReshapeCollapsedType(
1682       memRefType, getSymbolLessAffineMaps(
1683                       convertReassociationIndicesToExprs(b, reassociation)));
1684   build(b, result, resultType, src, attrs);
1685   result.addAttribute(getReassociationAttrName(),
1686                       getReassociationIndicesAttribute(b, reassociation));
1687 }
1688 
/// Returns the `src` operand as the source of the view produced by this op.
Value mlir::linalg::CollapseShapeOp::getViewSource() { return src(); }
1690 
/// Verify the shapes of the reshaped types using the following rules:
/// 1) if a dimension in the collapsed type is static, then the corresponding
///    dimensions in the expanded shape should be
///    a) static
///    b) the product should be the same as the collapsed shape.
/// 2) if a dimension in the collapsed type is dynamic, one and only one of the
///    corresponding dimensions in the expanded type should be dynamic. This
///    rule is only needed with reshape operations that are expanding.
/// The reassociation maps of `op` are walked in order; map #i covers the next
/// `getNumResults()` consecutive dimensions of `expandedType` and is checked
/// against dimension i of `collapsedType`. Emits an op error and returns it on
/// the first violation, otherwise returns success.
template <typename OpTy>
static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
                                             ShapedType expandedType,
                                             bool isExpandingReshape) {
  ArrayRef<int64_t> collapsedShape = collapsedType.getShape();
  ArrayRef<int64_t> expandedShape = expandedType.getShape();
  // Index of the first expanded dimension covered by the current map.
  unsigned expandedDimStart = 0;
  for (auto map : llvm::enumerate(op.getReassociationMaps())) {
    // Index (relative to `expandedDimStart`) of the last dynamic dimension
    // seen in this group, if any.
    Optional<int64_t> dynamicShape;
    // Product of the static extents of this group.
    int64_t linearizedStaticShape = 1;
    for (auto dim : llvm::enumerate(expandedShape.slice(
             expandedDimStart, map.value().getNumResults()))) {
      if (ShapedType::isDynamic(dim.value())) {
        // A second dynamic dimension in the group is only rejected for
        // expanding reshapes.
        if (isExpandingReshape && dynamicShape) {
          return op->emitOpError("invalid to have a single dimension (")
                 << map.index() << ") expanded into multiple dynamic dims ("
                 << expandedDimStart + dynamicShape.getValue() << ","
                 << expandedDimStart + dim.index() << ")";
        }
        dynamicShape = dim.index();
      } else {
        linearizedStaticShape *= dim.value();
      }
    }
    if (dynamicShape) {
      // A group with a dynamic dimension requires a dynamic collapsed
      // dimension.
      if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
        return op->emitOpError("expected dimension ")
               << map.index()
               << " of collapsed type to be dynamic since one or more of the "
                  "corresponding dimensions in the expanded type is dynamic";
      }
    } else {
      // A fully static group must multiply out to the collapsed dimension.
      if (collapsedShape[map.index()] != linearizedStaticShape) {
        return op->emitOpError("expected dimension ")
               << map.index() << " of collapsed type to be static value of "
               << linearizedStaticShape << " ";
      }
    }
    expandedDimStart += map.value().getNumResults();
  }
  return success();
}
1741 
1742 // Common verifier for reshape-like types. Fills `expandedType` and
1743 // `collapsedType` with the proper `src` or `result` type.
1744 template <typename Op, typename T,
1745           bool isExpansion = std::is_same<Op, TensorExpandShapeOp>::value ||
1746                              std::is_same<Op, ExpandShapeOp>::value>
1747 static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
1748                                             T collapsedType) {
1749   unsigned expandedRank = expandedType.getRank();
1750   unsigned collapsedRank = collapsedType.getRank();
1751   if (expandedRank < collapsedRank)
1752     return op.emitOpError("expected the type ")
1753            << expandedType
1754            << " to have higher rank than the type = " << collapsedType;
1755   if (expandedRank == 0)
1756     return op.emitOpError("expected non-zero memref ranks");
1757   if (expandedRank == collapsedRank)
1758     return op.emitOpError("expected to collapse or expand dims");
1759 
1760   if (collapsedRank == 0) {
1761     // If collapsed rank is 0, then expanded type must be static shaped and of
1762     // sizes 1.
1763     if (llvm::any_of(expandedType.getShape(),
1764                      [](int64_t dim) -> bool { return dim != 1; }))
1765       return op.emitOpError("invalid to reshape tensor/memref with non-unit "
1766                             "extent dimensions to zero-rank tensor/memref");
1767     return success();
1768   }
1769   if (collapsedRank != op.reassociation().size())
1770     return op.emitOpError("expected rank of the collapsed type(")
1771            << collapsedRank << ") to be the number of reassociation maps("
1772            << op.reassociation().size() << ")";
1773   auto maps = op.getReassociationMaps();
1774   for (auto it : llvm::enumerate(maps))
1775     if (it.value().getNumDims() != expandedRank)
1776       return op.emitOpError("expected reassociation map #")
1777              << it.index() << " of same rank as expanded memref("
1778              << expandedRank << "), but got " << it.value().getNumDims();
1779   int invalidIdx = 0;
1780   if (!isReassociationValid(maps, &invalidIdx))
1781     return op.emitOpError("expected reassociation map #")
1782            << invalidIdx << " to be valid and contiguous";
1783   return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
1784 }
1785 
1786 template <typename TensorReshapeOp>
1787 static LogicalResult verifyReshapeOp(TensorReshapeOp op,
1788                                      MemRefType expandedType,
1789                                      MemRefType collapsedType) {
1790   if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
1791     return failure();
1792   auto maps = op.getReassociationMaps();
1793   MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
1794   if (collapsedType != expectedType)
1795     return op.emitOpError("expected collapsed type to be ")
1796            << expectedType << ", but got " << collapsedType;
1797   return success();
1798 }
1799 
/// Verifies an ExpandShapeOp: the result type is the expanded type and the
/// source type is the collapsed type.
static LogicalResult verify(ExpandShapeOp op) {
  return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
}
1803 
/// Adds the canonicalization patterns of ExpandShapeOp to `results`.
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<CollapseReshapeOps<ExpandShapeOp>,
              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
}
1809 
/// Verifies a CollapseShapeOp: the source type is the expanded type and the
/// result type is the collapsed type.
static LogicalResult verify(CollapseShapeOp op) {
  return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
}
1813 
/// Adds the canonicalization patterns of CollapseShapeOp to `results`.
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CollapseReshapeOps<CollapseShapeOp>,
              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>>(context);
}
1819 
1820 //===----------------------------------------------------------------------===//
1821 // TensorReshapeOp
1822 //===----------------------------------------------------------------------===//
1823 
1824 /// Compute the RankedTensorType obtained by applying `reassociation` to `type`.
1825 static RankedTensorType
1826 computeTensorReshapeCollapsedType(RankedTensorType type,
1827                                   ArrayRef<AffineMap> reassociation) {
1828   auto shape = type.getShape();
1829   SmallVector<int64_t, 4> newShape;
1830   newShape.reserve(reassociation.size());
1831 
1832   // Use the fact that reassociation is valid to simplify the logic: only use
1833   // each map's rank.
1834   assert(isReassociationValid(reassociation) && "invalid reassociation");
1835   unsigned currentDim = 0;
1836   for (AffineMap m : reassociation) {
1837     unsigned dim = m.getNumResults();
1838     auto band = shape.slice(currentDim, dim);
1839     int64_t size = 1;
1840     if (llvm::is_contained(band, ShapedType::kDynamicSize))
1841       size = ShapedType::kDynamicSize;
1842     else
1843       for (unsigned d = 0; d < dim; ++d)
1844         size *= shape[currentDim + d];
1845     newShape.push_back(size);
1846     currentDim += dim;
1847   }
1848 
1849   return RankedTensorType::get(newShape, type.getElementType());
1850 }
1851 
1852 void mlir::linalg::TensorCollapseShapeOp::build(
1853     OpBuilder &b, OperationState &result, Value src,
1854     ArrayRef<ReassociationIndices> reassociation,
1855     ArrayRef<NamedAttribute> attrs) {
1856   auto resultType = computeTensorReshapeCollapsedType(
1857       src.getType().cast<RankedTensorType>(),
1858       getSymbolLessAffineMaps(
1859           convertReassociationIndicesToExprs(b, reassociation)));
1860   build(b, result, resultType, src, attrs);
1861   result.addAttribute(getReassociationAttrName(),
1862                       getReassociationIndicesAttribute(b, reassociation));
1863 }
1864 
1865 void mlir::linalg::TensorExpandShapeOp::build(
1866     OpBuilder &b, OperationState &result, Value src,
1867     ArrayRef<ReassociationIndices> reassociation,
1868     ArrayRef<NamedAttribute> attrs) {
1869   auto resultType = computeTensorReshapeCollapsedType(
1870       src.getType().cast<RankedTensorType>(),
1871       getSymbolLessAffineMaps(
1872           convertReassociationIndicesToExprs(b, reassociation)));
1873   build(b, result, resultType, src, attrs);
1874   result.addAttribute(getReassociationAttrName(),
1875                       getReassociationIndicesAttribute(b, reassociation));
1876 }
1877 
1878 template <typename TensorReshapeOp>
1879 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
1880                                            RankedTensorType expandedType,
1881                                            RankedTensorType collapsedType) {
1882   if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
1883     return failure();
1884 
1885   auto maps = op.getReassociationMaps();
1886   RankedTensorType expectedType =
1887       computeTensorReshapeCollapsedType(expandedType, maps);
1888   if (collapsedType != expectedType)
1889     return op.emitOpError("expected collapsed type to be ")
1890            << expectedType << ", but got " << collapsedType;
1891   return success();
1892 }
1893 
/// Verifies a TensorExpandShapeOp: the result type is the expanded type and
/// the source type is the collapsed type.
static LogicalResult verify(TensorExpandShapeOp op) {
  return verifyTensorReshapeOp(op, op.getResultType(), op.getSrcType());
}
1897 
/// Verifies a TensorCollapseShapeOp: the source type is the expanded type and
/// the result type is the collapsed type.
static LogicalResult verify(TensorCollapseShapeOp op) {
  return verifyTensorReshapeOp(op, op.getSrcType(), op.getResultType());
}
1901 
1902 namespace {
1903 /// Reshape of a splat constant can be replaced with a constant of the result
1904 /// type.
1905 template <typename TensorReshapeOp>
1906 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
1907   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1908   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1909                                 PatternRewriter &rewriter) const override {
1910     DenseElementsAttr attr;
1911     if (!matchPattern(reshapeOp.src(), m_Constant(&attr)))
1912       return failure();
1913     if (!attr || !attr.isSplat())
1914       return failure();
1915     DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
1916         reshapeOp.getResultType(), attr.getRawData(), true);
1917     rewriter.replaceOpWithNewOp<ConstantOp>(reshapeOp, newAttr);
1918     return success();
1919   }
1920 };
1921 
1922 /// Fold linalg.fill -> linalg.tensor_reshape chain.
1923 ///
1924 /// For such op chains, we can create new linalg.fill ops with the result
1925 /// type of the linalg.tensor_reshape op.
1926 template <typename TensorReshapeOp>
1927 struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
1928   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1929   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1930                                 PatternRewriter &rewriter) const override {
1931     auto oldFill = reshapeOp.src().template getDefiningOp<FillOp>();
1932     if (!oldFill)
1933       return failure();
1934 
1935     Location loc = oldFill.getLoc();
1936     auto newInit = rewriter.create<TensorReshapeOp>(
1937         loc, reshapeOp.getResultType(), oldFill.output(),
1938         reshapeOp.reassociation());
1939     rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, oldFill.value(), newInit);
1940 
1941     return success();
1942   }
1943 };
1944 } // namespace
1945 
/// Adds the canonicalization patterns of TensorExpandShapeOp to `results`.
void TensorExpandShapeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results
      .add<CollapseReshapeOps<TensorExpandShapeOp>,
           CollapseMixedReshapeOps<TensorExpandShapeOp, TensorCollapseShapeOp>,
           FoldFillWithTensorReshape<TensorExpandShapeOp>,
           FoldInitTensorWithTensorReshapeOp<TensorExpandShapeOp>,
           FoldReshapeWithConstant<TensorExpandShapeOp>>(context);
}
1955 
/// Adds the canonicalization patterns of TensorCollapseShapeOp to `results`.
void TensorCollapseShapeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results
      .add<CollapseReshapeOps<TensorCollapseShapeOp>,
           CollapseMixedReshapeOps<TensorCollapseShapeOp, TensorExpandShapeOp>,
           FoldFillWithTensorReshape<TensorCollapseShapeOp>,
           FoldInitTensorWithTensorReshapeOp<TensorCollapseShapeOp>,
           FoldReshapeWithConstant<TensorCollapseShapeOp>>(context);
}
1965 
1966 LogicalResult TensorExpandShapeOp::reifyReturnTypeShapesPerResultDim(
1967     OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
1968   auto resultShape =
1969       getAsValues(b, getLoc(),
1970                   getReshapeOutputShapeFromInputShape(
1971                       b, getLoc(), src(), getResultType().getShape(),
1972                       getReassociationMaps()));
1973   reifiedReturnShapes.emplace_back(std::move(resultShape));
1974   return success();
1975 }
1976 
1977 LogicalResult TensorCollapseShapeOp::reifyReturnTypeShapesPerResultDim(
1978     OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
1979   auto resultShape =
1980       getAsValues(b, getLoc(),
1981                   getReshapeOutputShapeFromInputShape(
1982                       b, getLoc(), src(), getResultType().getShape(),
1983                       getReassociationMaps()));
1984   reifiedReturnShapes.emplace_back(std::move(resultShape));
1985   return success();
1986 }
1987 
1988 //===----------------------------------------------------------------------===//
1989 // YieldOp
1990 //===----------------------------------------------------------------------===//
1991 
1992 static void print(OpAsmPrinter &p, linalg::YieldOp op) {
1993   p << op.getOperationName();
1994   if (op.getNumOperands() > 0)
1995     p << ' ' << op.getOperands();
1996   p.printOptionalAttrDict(op->getAttrs());
1997   if (op.getNumOperands() > 0)
1998     p << " : " << op.getOperandTypes();
1999 }
2000 
2001 static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
2002   SmallVector<OpAsmParser::OperandType, 2> opInfo;
2003   SmallVector<Type, 2> types;
2004   llvm::SMLoc loc = parser.getCurrentLocation();
2005   return failure(parser.parseOperandList(opInfo) ||
2006                  parser.parseOptionalAttrDict(result.attributes) ||
2007                  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2008                  parser.resolveOperands(opInfo, types, loc, result.operands));
2009 }
2010 
2011 // Check the operand number and types must match the element types of the
2012 // LinalgOp interface's shaped operands.
2013 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2014   if (op.getNumOperands() != linalgOp.getNumOutputs())
2015     return op.emitOpError("expected number of yield values (")
2016            << linalgOp.getNumOutputs()
2017            << ") to match the number of operands of the enclosing "
2018            << "LinalgOp (" << op.getNumOperands() << ")";
2019 
2020   for (OpOperand &opOperand : op->getOpOperands()) {
2021     OpOperand *outputOperand =
2022         linalgOp.getOutputOperand(opOperand.getOperandNumber());
2023     Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
2024     if (opOperand.get().getType() != elementType)
2025       return op.emitOpError("type of yield operand ")
2026              << (opOperand.getOperandNumber() + 1) << " ("
2027              << opOperand.get().getType() << ") doesn't match "
2028              << "the element type of the enclosing linalg.generic op ("
2029              << elementType << ")";
2030   }
2031   return success();
2032 }
2033 
2034 static LogicalResult verify(linalg::YieldOp op) {
2035   auto *parentOp = op->getParentOp();
2036   if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2037     return op.emitOpError("expected single non-empty parent region");
2038 
2039   if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2040     return verifyYield(op, cast<LinalgOp>(parentOp));
2041 
2042   if (auto padTensorOp = dyn_cast<linalg::PadTensorOp>(parentOp)) {
2043     if (op.getNumOperands() != 1)
2044       return op.emitOpError("expected single yield operand (got ")
2045              << op->getNumOperands() << ")";
2046     if (op.getOperand(0).getType() !=
2047         padTensorOp.getType().cast<ShapedType>().getElementType())
2048       return op.emitOpError("expected yield type to match shape element type");
2049     return success();
2050   }
2051 
2052   if (auto tiledLoopOp = dyn_cast<linalg::TiledLoopOp>(parentOp)) {
2053     // Check if output args with tensor types match results types.
2054     SmallVector<Value, 2> tensorOuts;
2055     llvm::copy_if(
2056         tiledLoopOp.outputs(), std::back_inserter(tensorOuts),
2057         [&](Value out) { return out.getType().isa<RankedTensorType>(); });
2058     if (tensorOuts.size() != op.values().size())
2059       return op.emitOpError("expected number of tensor output args = ")
2060              << tensorOuts.size() << " to match the number of yield operands = "
2061              << op.values().size();
2062 
2063     TypeRange tensorTypes(llvm::makeArrayRef(tensorOuts));
2064     for (auto &item :
2065          llvm::enumerate(llvm::zip(tensorTypes, op.getOperandTypes()))) {
2066       Type outType, resultType;
2067       unsigned index = item.index();
2068       std::tie(outType, resultType) = item.value();
2069       if (outType != resultType)
2070         return op.emitOpError("expected yield operand ")
2071                << index << " with type = " << resultType
2072                << " to match output arg type = " << outType;
2073     }
2074     return success();
2075   }
2076   return op.emitOpError("expected parent op with LinalgOp interface");
2077 }
2078 
2079 //===----------------------------------------------------------------------===//
2080 // TiledLoopOp
2081 //===----------------------------------------------------------------------===//
2082 
/// Convenience builder without distribution types: forwards to the overload
/// taking an `Optional<ArrayAttr>` of distribution types, passing `llvm::None`.
void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
                        ValueRange lowerBounds, ValueRange upperBounds,
                        ValueRange steps, ValueRange inputs, ValueRange outputs,
                        ArrayAttr iteratorTypes,
                        function_ref<void(OpBuilder &, Location, ValueRange,
                                          ValueRange, ValueRange)>
                            bodyBuilderFn) {
  build(builder, result, lowerBounds, upperBounds, steps, inputs, outputs,
        iteratorTypes, llvm::None, bodyBuilderFn);
}
2093 
2094 void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
2095                         ValueRange lowerBounds, ValueRange upperBounds,
2096                         ValueRange steps, ValueRange inputs, ValueRange outputs,
2097                         ArrayAttr iteratorTypes,
2098                         Optional<ArrayAttr> distributionTypes,
2099                         function_ref<void(OpBuilder &, Location, ValueRange,
2100                                           ValueRange, ValueRange)>
2101                             bodyBuilderFn) {
2102   result.addOperands(lowerBounds);
2103   result.addOperands(upperBounds);
2104   result.addOperands(steps);
2105   result.addOperands(inputs);
2106   result.addOperands(outputs);
2107   result.addAttribute(
2108       TiledLoopOp::getOperandSegmentSizeAttr(),
2109       builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()),
2110                                 static_cast<int32_t>(upperBounds.size()),
2111                                 static_cast<int32_t>(steps.size()),
2112                                 static_cast<int32_t>(inputs.size()),
2113                                 static_cast<int32_t>(outputs.size())}));
2114   result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
2115 
2116   if (distributionTypes.hasValue())
2117     result.addAttribute(getDistributionTypesAttrName(),
2118                         distributionTypes.getValue());
2119 
2120   // Add output types for `RankedTensorType` output arguments.
2121   for (Value output : outputs) {
2122     Type outputType = output.getType();
2123     if (outputType.isa<RankedTensorType>())
2124       result.addTypes(outputType);
2125   }
2126 
2127   OpBuilder::InsertionGuard guard(builder);
2128   unsigned numIVs = steps.size();
2129   SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2130   for (Type type : TypeRange(inputs))
2131     argTypes.push_back(type);
2132   for (Type type : TypeRange(outputs))
2133     argTypes.push_back(type);
2134   Region *bodyRegion = result.addRegion();
2135   Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes);
2136 
2137   if (bodyBuilderFn) {
2138     builder.setInsertionPointToStart(bodyBlock);
2139     bodyBuilderFn(builder, result.location,
2140                   bodyBlock->getArguments().take_front(numIVs),
2141                   bodyBlock->getArguments().slice(numIVs, inputs.size()),
2142                   bodyBlock->getArguments().take_back(outputs.size()));
2143     TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
2144   }
2145 }
2146 
2147 static void print(OpAsmPrinter &p, TiledLoopOp op) {
2148   p << op.getOperationName() << " (" << op.getInductionVars() << ") = ("
2149     << op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step()
2150     << ")";
2151 
2152   if (!op.inputs().empty()) {
2153     p << " ins (";
2154     llvm::interleaveComma(llvm::zip(op.getRegionInputArgs(), op.inputs()), p,
2155                           [&](auto it) {
2156                             p << std::get<0>(it) << " = " << std::get<1>(it)
2157                               << ": " << std::get<1>(it).getType();
2158                           });
2159     p << ")";
2160   }
2161   if (!op.outputs().empty()) {
2162     p << " outs (";
2163     llvm::interleaveComma(llvm::zip(op.getRegionOutputArgs(), op.outputs()), p,
2164                           [&](auto it) {
2165                             p << std::get<0>(it) << " = " << std::get<1>(it)
2166                               << ": " << std::get<1>(it).getType();
2167                           });
2168     p << ")";
2169   }
2170 
2171   if (llvm::any_of(op.iterator_types(), [](Attribute attr) {
2172         return attr.cast<StringAttr>().getValue() !=
2173                getParallelIteratorTypeName();
2174       }))
2175     p << " iterators" << op.iterator_types() << "";
2176 
2177   if (op.distribution_types().hasValue())
2178     p << " distribution" << op.distribution_types().getValue() << "";
2179 
2180   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
2181   p.printOptionalAttrDict(
2182       op->getAttrs(), /*elidedAttrs=*/{TiledLoopOp::getOperandSegmentSizeAttr(),
2183                                        getIteratorTypesAttrName(),
2184                                        getDistributionTypesAttrName()});
2185 }
2186 
2187 static ParseResult parseTiledLoopOp(OpAsmParser &parser,
2188                                     OperationState &result) {
2189   auto &builder = parser.getBuilder();
2190   // Parse an opening `(` followed by induction variables followed by `)`
2191   SmallVector<OpAsmParser::OperandType, 4> ivs;
2192   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
2193                                      OpAsmParser::Delimiter::Paren))
2194     return failure();
2195 
2196   // Parse loop bounds.
2197   SmallVector<OpAsmParser::OperandType, 4> lower;
2198   if (parser.parseEqual() ||
2199       parser.parseOperandList(lower, ivs.size(),
2200                               OpAsmParser::Delimiter::Paren) ||
2201       parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2202     return failure();
2203 
2204   SmallVector<OpAsmParser::OperandType, 4> upper;
2205   if (parser.parseKeyword("to") ||
2206       parser.parseOperandList(upper, ivs.size(),
2207                               OpAsmParser::Delimiter::Paren) ||
2208       parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2209     return failure();
2210 
2211   // Parse step values.
2212   SmallVector<OpAsmParser::OperandType, 4> steps;
2213   if (parser.parseKeyword("step") ||
2214       parser.parseOperandList(steps, ivs.size(),
2215                               OpAsmParser::Delimiter::Paren) ||
2216       parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2217     return failure();
2218 
2219   // Parse input tensors.
2220   SmallVector<OpAsmParser::OperandType, 4> inputs, input_region_args;
2221   SmallVector<Type, 4> inputTypes;
2222   if (succeeded(parser.parseOptionalKeyword("ins"))) {
2223     llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation();
2224 
2225     if (parser.parseAssignmentListWithTypes(input_region_args, inputs,
2226                                             inputTypes))
2227       return failure();
2228 
2229     if (parser.resolveOperands(inputs, inputTypes, inputsOperandsLoc,
2230                                result.operands))
2231       return failure();
2232   }
2233 
2234   // Parse output tensors.
2235   SmallVector<OpAsmParser::OperandType, 4> outputs, output_region_args;
2236   SmallVector<Type, 4> outputTypes;
2237   if (succeeded(parser.parseOptionalKeyword("outs"))) {
2238     llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation();
2239 
2240     if (parser.parseAssignmentListWithTypes(output_region_args, outputs,
2241                                             outputTypes))
2242       return failure();
2243 
2244     if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc,
2245                                result.operands))
2246       return failure();
2247     for (Type outputType : outputTypes)
2248       if (outputType.isa<RankedTensorType>())
2249         result.addTypes(outputType);
2250   }
2251 
2252   // Parse attributes.
2253   SmallVector<Attribute, 4> iterTypes, distributionTypes;
2254   auto parseAttr = [&](StringRef keyword, SmallVector<Attribute, 4> *attrs) {
2255     if (succeeded(parser.parseOptionalKeyword(keyword))) {
2256       StringAttr attr;
2257 
2258       if (parser.parseLSquare() || parser.parseAttribute(attr))
2259         return failure();
2260       attrs->push_back(attr);
2261       for (int i = 1, e = ivs.size(); i < e; ++i) {
2262         if (parser.parseComma() || parser.parseAttribute(attr))
2263           return failure();
2264         attrs->push_back(attr);
2265       }
2266       if (parser.parseRSquare())
2267         return failure();
2268     }
2269     return success();
2270   };
2271   if (failed(parseAttr("iterators", &iterTypes)) ||
2272       failed(parseAttr("distribution", &distributionTypes)))
2273     return failure();
2274 
2275   // Set all loop iterator types to "parallel" if they are not printed in IR.
2276   if (iterTypes.empty()) {
2277     auto parallelIter = builder.getStringAttr(getParallelIteratorTypeName());
2278     iterTypes = SmallVector<Attribute, 4>(ivs.size(), parallelIter);
2279   }
2280   result.addAttribute(getIteratorTypesAttrName(),
2281                       builder.getArrayAttr(iterTypes));
2282   if (!distributionTypes.empty())
2283     result.addAttribute(getDistributionTypesAttrName(),
2284                         builder.getArrayAttr(distributionTypes));
2285   result.addAttribute(
2286       TiledLoopOp::getOperandSegmentSizeAttr(),
2287       builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
2288                                 static_cast<int32_t>(upper.size()),
2289                                 static_cast<int32_t>(steps.size()),
2290                                 static_cast<int32_t>(inputs.size()),
2291                                 static_cast<int32_t>(outputs.size())}));
2292 
2293   // Parse the body.
2294   Region *body = result.addRegion();
2295 
2296   SmallVector<Type, 4> region_types(ivs.size(), builder.getIndexType());
2297   region_types.append(inputTypes);
2298   region_types.append(outputTypes);
2299 
2300   SmallVector<OpAsmParser::OperandType, 4> region_args(ivs);
2301   region_args.append(input_region_args);
2302   region_args.append(output_region_args);
2303 
2304   if (parser.parseRegion(*body, region_args, region_types))
2305     return failure();
2306 
2307   // Parse optional attributes.
2308   parser.parseOptionalAttrDict(result.attributes);
2309 
2310   return success();
2311 }
2312 
/// Returns this op's `region()` as the loop body.
Region &TiledLoopOp::getLoopBody() { return region(); }
2314 
2315 LogicalResult TiledLoopOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
2316   for (auto *op : ops)
2317     op->moveBefore(*this);
2318   return success();
2319 }
2320 
2321 bool TiledLoopOp::isDefinedOutsideOfLoop(Value value) {
2322   return !region().isAncestor(value.getParentRegion());
2323 }
2324 
2325 static LogicalResult verify(TiledLoopOp op) {
2326   // Check if iterator types are provided for every loop dimension.
2327   if (op.iterator_types().size() != op.getNumLoops())
2328     return op.emitOpError("expected iterator types array attribute size = ")
2329            << op.iterator_types().size()
2330            << " to match the number of loops = " << op.getNumLoops();
2331 
2332   // Check if types of input arguments match region args types.
2333   for (auto &item :
2334        llvm::enumerate(llvm::zip(op.inputs(), op.getRegionInputArgs()))) {
2335     Value input, inputRegionArg;
2336     unsigned index = item.index();
2337     std::tie(input, inputRegionArg) = item.value();
2338     if (input.getType() != inputRegionArg.getType())
2339       return op.emitOpError("expected input arg ")
2340              << index << " with type = " << input.getType()
2341              << " to match region arg " << index + op.getNumLoops()
2342              << " type = " << inputRegionArg.getType();
2343   }
2344 
2345   // Check if types of input arguments match region args types.
2346   for (auto &item :
2347        llvm::enumerate(llvm::zip(op.outputs(), op.getRegionOutputArgs()))) {
2348     Value output, outputRegionArg;
2349     unsigned index = item.index();
2350     std::tie(output, outputRegionArg) = item.value();
2351     if (output.getType() != outputRegionArg.getType())
2352       return op.emitOpError("expected output arg ")
2353              << index << " with type = " << output.getType()
2354              << " to match region arg "
2355              << index + op.getNumLoops() + op.inputs().size()
2356              << " type = " << outputRegionArg.getType();
2357   }
2358   return success();
2359 }
2360 
2361 namespace {
2362 
// Sentinel stored in the old-to-new index tables of the folding patterns in
// this namespace to mark an old operand/result with no counterpart in the
// rewritten op.
static constexpr int64_t kNoMatch = -1;
2364 
2365 // Folds away TiledLoopOp inputs if they have no uses within the body.
2366 //
2367 // Example:
2368 //
2369 // %0 = linalg.tiled_loop ...  ins (%in_ = %in: tensor<...>,
2370 //                                  %in_buf_ = %in_buf: memref<...>) {...}
2371 // Becomes
2372 //
2373 // linalg.tiled_loop ...  ins (%in_buf_ = %in_buf: memref<...>) {...}
2374 struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
2375   using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern;
2376 
2377   LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
2378                                 PatternRewriter &rewriter) const final {
2379     SmallVector<Value, 2> newInputs, regionInputTensorArgs;
2380     // Store ids of the corresponding old and new input operands.
2381     SmallVector<int64_t, 2> oldInputIdToNew(tiledLoop.inputs().size(),
2382                                             kNoMatch);
2383     for (auto en : llvm::enumerate(
2384              llvm::zip(tiledLoop.inputs(), tiledLoop.getRegionInputArgs()))) {
2385       Value in, bbArg;
2386       size_t index = en.index();
2387       std::tie(in, bbArg) = en.value();
2388       if (!bbArg.use_empty()) {
2389         oldInputIdToNew[index] = newInputs.size();
2390         newInputs.push_back(in);
2391       }
2392     }
2393     if (newInputs.size() == tiledLoop.inputs().size())
2394       return failure();
2395     Location loc = tiledLoop.getLoc();
2396     auto newTiledLoop = rewriter.create<TiledLoopOp>(
2397         loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
2398         newInputs, tiledLoop.outputs(), tiledLoop.iterator_types(),
2399         tiledLoop.distribution_types());
2400 
2401     // Clone the region.
2402     BlockAndValueMapping bvm;
2403     bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
2404     bvm.map(tiledLoop.getRegionOutputArgs(),
2405             newTiledLoop.getRegionOutputArgs());
2406     for (const auto &en : llvm::enumerate(oldInputIdToNew))
2407       if (en.value() != kNoMatch)
2408         bvm.map(tiledLoop.getRegionInputArgs()[en.index()],
2409                 newTiledLoop.getRegionInputArgs()[en.value()]);
2410     OpBuilder innerBuilder =
2411         OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
2412     for (auto &op : *tiledLoop.getBody())
2413       innerBuilder.clone(op, bvm);
2414     rewriter.replaceOp(tiledLoop, newTiledLoop.getResults());
2415 
2416     return success();
2417   }
2418 };
2419 
// Folds away TiledLoopOp output tensors when the following conditions are met:
// * result of `linalg.tiled_loop` has no uses
// * output tensor is the argument of `linalg.yield`
//
// Example:
//
// %0 = linalg.tiled_loop ...  outs (%o_ = %out: tensor<...>,
//                                   %obuf_ = %out_buf: memref<...>) {
//   ...
//   linalg.yield %o_ : tensor ...
// }
//
// Becomes
//
// linalg.tiled_loop ...  outs (%obuf_ = %out_buf: memref<...>) {
//   ...
//   linalg.yield
// }
struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
  using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
                                PatternRewriter &rewriter) const final {
    // Without results there is no tensor output that could be folded.
    if (tiledLoop.getNumResults() == 0)
      return failure();

    Block *block = tiledLoop.getBody();
    auto yieldOp = cast<linalg::YieldOp>(block->getTerminator());

    // Match the pattern and collect output buffers that will replace the output
    // tensors and also the ops that will be ignored when cloning the body.
    SmallVector<Value, 2> newOutputOperands, newYieldArgs;
    // Index of the result / yield operand that belongs to the current tensor
    // output; only advanced for `RankedTensorType` outputs.
    int resultId = 0;
    // Store ids of the corresponding old and new output operands.
    SmallVector<int64_t, 2> oldOutputIdToNew(tiledLoop.outputs().size(),
                                             kNoMatch);
    // Store ids of the corresponding old and new results.
    SmallVector<int64_t, 2> oldResultIdToNew(tiledLoop.getNumResults(),
                                             kNoMatch);
    // Replacement values for the old results. Entries of folded results stay
    // default-constructed; those results have no uses by construction.
    SmallVector<Value, 2> resultReplacement(tiledLoop.getNumResults());
    for (auto en : llvm::enumerate(
             llvm::zip(tiledLoop.outputs(), tiledLoop.getRegionOutputArgs()))) {
      size_t index = en.index();
      Value out = std::get<0>(en.value());
      Value outRegionArg = std::get<1>(en.value());

      // Outputs that are not ranked tensors have no associated result and are
      // always kept.
      if (!out.getType().isa<RankedTensorType>()) {
        oldOutputIdToNew[index] = newOutputOperands.size();
        newOutputOperands.push_back(out);
        continue;
      }
      Value result = tiledLoop.getResult(resultId);
      Value yieldArg = yieldOp.getOperand(resultId);
      // Keep the tensor output unless it is yielded unchanged and its result
      // is unused.
      if (yieldArg != outRegionArg || !result.use_empty()) {
        oldOutputIdToNew[index] = newOutputOperands.size();
        oldResultIdToNew[resultId] = newYieldArgs.size();
        resultReplacement[resultId] = out;
        newOutputOperands.push_back(out);
        newYieldArgs.push_back(yieldArg);
      }
      ++resultId;
    }
    // Every output is kept: nothing to fold.
    if (newOutputOperands.size() == tiledLoop.outputs().size())
      return failure();

    Location loc = tiledLoop.getLoc();
    auto newTiledLoop = rewriter.create<TiledLoopOp>(
        loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
        tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types(),
        tiledLoop.distribution_types());

    // Clone the region.
    BlockAndValueMapping bvm;
    bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
    bvm.map(tiledLoop.getRegionInputArgs(), newTiledLoop.getRegionInputArgs());
    for (const auto &en : llvm::enumerate(oldOutputIdToNew)) {
      // Kept outputs map to the new region argument; folded outputs map to the
      // original output operand of the old loop.
      if (en.value() != kNoMatch)
        bvm.map(tiledLoop.getRegionOutputArgs()[en.index()],
                newTiledLoop.getRegionOutputArgs()[en.value()]);
      else
        bvm.map(tiledLoop.getRegionOutputArgs()[en.index()],
                tiledLoop.outputs()[en.index()]);
    }
    OpBuilder innerBuilder =
        OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
    // The old terminator is skipped; a new yield with only the kept operands
    // is created below.
    for (auto &op : tiledLoop.getBody()->without_terminator())
      innerBuilder.clone(op, bvm);
    innerBuilder.create<linalg::YieldOp>(
        loc, llvm::to_vector<2>(llvm::map_range(
                 newYieldArgs, [&](Value arg) { return bvm.lookup(arg); })));

    // Kept results are replaced by the matching results of the new loop.
    for (const auto &en : llvm::enumerate(oldResultIdToNew))
      if (en.value() != kNoMatch)
        resultReplacement[en.index()] = newTiledLoop.getResult(en.value());
    rewriter.replaceOp(tiledLoop, resultReplacement);

    return success();
  }
};
2519 } // namespace
2520 
2521 void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2522                                               MLIRContext *context) {
2523   results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder>(context);
2524 }
2525 
/// Folder hook for the tiled loop. Ignores both parameters and returns the
/// result of `foldMemRefCastInTiledLoopOp(*this)`, which is defined outside
/// this excerpt.
// NOTE(review): judging by its name the helper folds memref casts feeding the
// loop operands in place — confirm against its definition.
LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
                                SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCastInTiledLoopOp(*this);
}
2530 
2531 //===----------------------------------------------------------------------===//
2532 // IndexOp
2533 //===----------------------------------------------------------------------===//
2534 
2535 static LogicalResult verify(IndexOp op) {
2536   auto linalgOp = dyn_cast<LinalgOp>(op->getParentOp());
2537   if (!linalgOp)
2538     return op.emitOpError("expected parent op with LinalgOp interface");
2539   if (linalgOp.getNumLoops() <= op.dim())
2540     return op.emitOpError("expected dim (")
2541            << op.dim() << ") to be lower than the number of loops ("
2542            << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2543   return success();
2544 }
2545 
2546 /////// Operations corresponding to library calls defined with Tablegen ////////
2547 
2548 template <typename LinalgPoolingOp>
2549 static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
2550                                             ArrayRef<Attribute> attrs,
2551                                             bool isStride) {
2552   auto strideOrDilation = isStride ? "stride" : "dilation";
2553   if (attrs.size() != op.getNumWindowLoops())
2554     return op.emitOpError("expects num ")
2555            << strideOrDilation
2556            << "s equal to number of window dimensions: " << attrs.size()
2557            << " vs " << op.getNumWindowLoops();
2558   return success();
2559 }
2560 
2561 void ConvOp::getEffects(
2562     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2563         &effects) {
2564   effects.emplace_back(MemoryEffects::Read::get(), input(),
2565                        SideEffects::DefaultResource::get());
2566   effects.emplace_back(MemoryEffects::Read::get(), filter(),
2567                        SideEffects::DefaultResource::get());
2568   effects.emplace_back(MemoryEffects::Write::get(), output(),
2569                        SideEffects::DefaultResource::get());
2570 }
2571 
2572 static LogicalResult verify(ConvOp op) {
2573   auto oType = op.output().getType().cast<MemRefType>();
2574   auto fType = op.filter().getType().cast<MemRefType>();
2575   auto iType = op.input().getType().cast<MemRefType>();
2576   if (oType.getElementType() != iType.getElementType() ||
2577       oType.getElementType() != fType.getElementType())
2578     return op.emitOpError("expects memref elemental types to match");
2579   if (oType.getRank() != iType.getRank() || oType.getRank() != fType.getRank())
2580     return op.emitOpError("expects memref ranks to match");
2581   if (auto strides = op.strides()) {
2582     if (failed(verifyStrideOrDilation(op, strides->getValue(),
2583                                       /*isStride=*/true)))
2584       return failure();
2585   }
2586   if (auto dilations = op.dilations()) {
2587     if (failed(verifyStrideOrDilation(op, dilations->getValue(),
2588                                       /*isStride=*/false)))
2589       return failure();
2590   }
2591   return success();
2592 }
2593 
2594 template <typename PoolingOp>
2595 static LogicalResult verifySingleInputPoolingOp(PoolingOp op) {
2596   auto inputType = op.input().getType().template cast<MemRefType>();
2597   auto outputType = op.output().getType().template cast<MemRefType>();
2598   if (outputType.getElementType() != inputType.getElementType())
2599     return op.emitOpError("expects memref elemental types to match");
2600 
2601   auto windowDimsType = op.windowDims().getType().template cast<MemRefType>();
2602   if (outputType.getRank() != inputType.getRank() ||
2603       outputType.getRank() != windowDimsType.getRank())
2604     return op.emitOpError("expects memref ranks to match");
2605 
2606   if (auto strides = op.strides()) {
2607     if (failed(verifyStrideOrDilation(op, strides->getValue(),
2608                                       /*isStride=*/true)))
2609       return failure();
2610   }
2611   if (auto dilations = op.dilations()) {
2612     if (failed(verifyStrideOrDilation(op, dilations->getValue(),
2613                                       /*isStride=*/false)))
2614       return failure();
2615   }
2616   return success();
2617 }
2618 
// Defines `OP_NAME::getEffects`, which reports a read of `input()` and a write
// of `output()`, both on the default resource.
#define DEFINE_POOLING_OP_GET_EFFECTS(OP_NAME)                                 \
  void OP_NAME::getEffects(                                                    \
      SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>      \
          &effects) {                                                          \
    effects.emplace_back(MemoryEffects::Read::get(), input(),                  \
                         SideEffects::DefaultResource::get());                 \
    effects.emplace_back(MemoryEffects::Write::get(), output(),                \
                         SideEffects::DefaultResource::get());                 \
  }
2628 
// Verifiers for the max/min/sum pooling ops; each one forwards to
// `verifySingleInputPoolingOp`.
static LogicalResult verify(PoolingMaxOp op) {
  return verifySingleInputPoolingOp(op);
}
static LogicalResult verify(PoolingMinOp op) {
  return verifySingleInputPoolingOp(op);
}
static LogicalResult verify(PoolingSumOp op) {
  return verifySingleInputPoolingOp(op);
}
2638 
// Define `getEffects` for the three pooling ops.
DEFINE_POOLING_OP_GET_EFFECTS(PoolingMaxOp)
DEFINE_POOLING_OP_GET_EFFECTS(PoolingMinOp)
DEFINE_POOLING_OP_GET_EFFECTS(PoolingSumOp)
2642 
2643 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.cpp.inc"
2644 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2645 
2646 #define GET_OP_CLASSES
2647 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2648 
2649 #define GET_OP_CLASSES
2650 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2651 
2652 /// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
2653 /// Assumes `op` is a LinalgOp.
2654 void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName,
2655                                  SmallVectorImpl<AffineExpr> &res) {
2656   if (!cast<LinalgOp>(op).iterator_types())
2657     return;
2658 
2659   unsigned dim = 0;
2660   MLIRContext *ctx = op->getContext();
2661   for (auto tn :
2662        cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) {
2663     if (tn == iteratorTypeName)
2664       res.push_back(getAffineDimExpr(dim, ctx));
2665     ++dim;
2666   }
2667 }
2668 
2669 AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap,
2670                                              unsigned rank,
2671                                              MLIRContext *context) {
2672   if (maybeMap)
2673     return maybeMap.getValue();
2674   if (rank == 0)
2675     return AffineMap::get(context);
2676   return AffineMap::getMultiDimIdentityMap(rank, context);
2677 }
2678 
2679 SmallVector<AffineExpr, 4>
2680 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2681                                  MLIRContext *context) {
2682   SmallVector<AffineExpr, 4> res;
2683   res.reserve(num);
2684   for (unsigned i = 0; i < num; ++i)
2685     res.push_back(getAffineDimExpr(startIdx++, context));
2686   return res;
2687 }
2688 
2689 template <typename PoolingOp>
2690 SmallVector<AffineExpr, 4>
2691 mlir::linalg::weightedPoolingInputIndex(PoolingOp op,
2692                                         ArrayRef<AffineExpr> outputDims,
2693                                         ArrayRef<AffineExpr> windowDims) {
2694   assert(outputDims.size() == windowDims.size());
2695   SmallVector<AffineExpr, 4> res;
2696   res.reserve(outputDims.size());
2697   for (unsigned i = 0, e = outputDims.size(); i < e; ++i) {
2698     // TODO: add a level of indirection to linalg.generic.
2699     auto expr = op.getStride(i) * outputDims[i] +
2700                 op.getDilation(i) * windowDims[i] - op.getLowPad(i);
2701     res.push_back(expr);
2702   }
2703   return res;
2704 }
2705 
// Explicitly instantiates `weightedPoolingInputIndex` for `OP_TYPE`.
#define INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(OP_TYPE)                      \
  template SmallVector<AffineExpr, 4>                                          \
  mlir::linalg::weightedPoolingInputIndex<OP_TYPE>(                            \
      OP_TYPE op, ArrayRef<AffineExpr> outputDims,                             \
      ArrayRef<AffineExpr> windowDims);

// Instantiations for the convolution op and the three pooling ops.
INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(ConvOp)
INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMaxOp)
INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMinOp)
INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingSumOp)
2716 
2717 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2718                                                 ArrayRef<AffineExpr> b) {
2719   auto rangeA = llvm::make_range(a.begin(), a.end());
2720   auto rangeB = llvm::make_range(b.begin(), b.end());
2721   auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2722   return llvm::to_vector<4>(concatRanges);
2723 }
2724 
2725 static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2726   if (auto memref = t.dyn_cast<MemRefType>()) {
2727     ss << "view";
2728     for (auto size : memref.getShape())
2729       if (size < 0)
2730         ss << "sx";
2731       else
2732         ss << size << "x";
2733     appendMangledType(ss, memref.getElementType());
2734   } else if (auto vec = t.dyn_cast<VectorType>()) {
2735     ss << "vector";
2736     llvm::interleave(
2737         vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2738     appendMangledType(ss, vec.getElementType());
2739   } else if (t.isSignlessIntOrIndexOrFloat()) {
2740     ss << t;
2741   } else {
2742     llvm_unreachable("Invalid type for linalg library name mangling");
2743   }
2744 }
2745 
2746 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2747   assert(isa<LinalgOp>(op));
2748   std::string name(op->getName().getStringRef().str());
2749   name.reserve(128);
2750   std::replace(name.begin(), name.end(), '.', '_');
2751   llvm::raw_string_ostream ss(name);
2752   ss << "_";
2753   auto types = op->getOperandTypes();
2754   llvm::interleave(
2755       types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
2756       [&]() { ss << "_"; });
2757   return ss.str();
2758 }
2759 
// TODO: Consider making all this boilerplate easy to autogenerate
// with Tablegen. This seems a desirable property in the context of
// OpInterfaces where a Linalg "named" op **isa** LinalgOp.

// Folders for the reshape ops. The memref variants first try
// `foldMemRefCast` and return the op's own result when that succeeds; all
// variants otherwise defer to `foldReshapeOp`, instantiated with the op type
// and its inverse reshape op type.
OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
}
OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
}
OpFoldResult TensorExpandShapeOp::fold(ArrayRef<Attribute> operands) {
  return foldReshapeOp<TensorExpandShapeOp, TensorCollapseShapeOp>(*this,
                                                                   operands);
}
OpFoldResult TensorCollapseShapeOp::fold(ArrayRef<Attribute> operands) {
  return foldReshapeOp<TensorCollapseShapeOp, TensorExpandShapeOp>(*this,
                                                                   operands);
}
2781 
2782 //===----------------------------------------------------------------------===//
2783 // Support for named Linalg ops defined in ods-gen.
2784 //===----------------------------------------------------------------------===//
2785 
2786 /// Generic entry point to create the block for the region of a LinalgOp.
2787 /// This is used by both named structured ops created by ods-gen and by manually
2788 /// defined C++ ops.
2789 /// This is used by both builders and parsers.
2790 /// This function creates the block in the region with arguments corresponding
2791 /// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
2792 /// to be ShapedType.
2793 template <typename NamedStructuredOpType>
2794 static void
2795 fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
2796                        TypeRange inputTypes, TypeRange outputTypes,
2797                        std::function<void(unsigned, unsigned)> errorHandler) {
2798   assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
2799 
2800   // TODO: atm all operands go through getElementTypeOrSelf,
2801   // reconsider when we have evidence we need to.
2802   SmallVector<Type, 8> argTypes;
2803   for (auto containers : {inputTypes, outputTypes})
2804     for (auto t : containers)
2805       argTypes.push_back(getElementTypeOrSelf(t));
2806 
2807   // RAII.
2808   OpBuilder::InsertionGuard guard(opBuilder);
2809   Block *body = opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes);
2810   unsigned actual = body->getNumArguments();
2811   unsigned expected = NamedStructuredOpType::getNumRegionArgs();
2812   if (expected != actual) {
2813     if (errorHandler)
2814       errorHandler(expected, actual);
2815     return;
2816   }
2817 
2818   opBuilder.setInsertionPointToStart(body);
2819   ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
2820   NamedStructuredOpType::regionBuilder(b, *body);
2821 
2822   // indexing_maps is an auto-generated method.
2823 
2824   // iterator_types is an auto-generated method.
2825 }
2826 
2827 /// Generic entry point to create both the region and the block of a LinalgOp.
2828 template <typename NamedStructuredOpType>
2829 void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
2830                                      OperationState &result,
2831                                      TypeRange inputTypes,
2832                                      TypeRange outputTypes) {
2833   Region &region = *result.addRegion();
2834   fillStructuredOpRegion<NamedStructuredOpType>(
2835       opBuilder, region, inputTypes, outputTypes,
2836       [&](unsigned expected, unsigned actual) {
2837         assert(expected != actual && "incorrect number of arguments");
2838       });
2839 }
2840 
2841 /// Common parsing used for both named structured ops created by ods-gen and by
2842 /// manually defined C++ ops. Does not handle regions.
2843 static ParseResult
2844 parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
2845                              SmallVectorImpl<Type> &inputTypes,
2846                              SmallVectorImpl<Type> &outputTypes) {
2847   llvm::SMLoc inputsOperandsLoc, outputsOperandsLoc;
2848   SmallVector<OpAsmParser::OperandType, 4> inputsOperands, outputsOperands;
2849 
2850   parser.parseOptionalAttrDict(result.attributes);
2851 
2852   if (succeeded(parser.parseOptionalKeyword("ins"))) {
2853     if (parser.parseLParen())
2854       return failure();
2855 
2856     inputsOperandsLoc = parser.getCurrentLocation();
2857     if (parser.parseOperandList(inputsOperands) ||
2858         parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2859       return failure();
2860   }
2861 
2862   if (succeeded(parser.parseOptionalKeyword("outs"))) {
2863     outputsOperandsLoc = parser.getCurrentLocation();
2864     if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
2865         parser.parseColonTypeList(outputTypes) || parser.parseRParen())
2866       return failure();
2867   }
2868 
2869   if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2870                              result.operands) ||
2871       parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
2872                              result.operands))
2873     return failure();
2874 
2875   result.addAttribute("operand_segment_sizes",
2876                       parser.getBuilder().getI32VectorAttr(
2877                           {static_cast<int32_t>(inputsOperands.size()),
2878                            static_cast<int32_t>(outputsOperands.size())}));
2879   return success();
2880 }
2881 
2882 template <typename NamedStructuredOpType>
2883 static void printCommonStructuredOpParts(OpAsmPrinter &p,
2884                                          NamedStructuredOpType op) {
2885   if (!op.inputs().empty())
2886     p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
2887   if (!op.outputs().empty())
2888     p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
2889 }
2890 
2891 //===----------------------------------------------------------------------===//
2892 // Specific parsing and printing for named structured ops created by ods-gen.
2893 //===----------------------------------------------------------------------===//
2894 
2895 template <typename NamedStructuredOpType>
2896 static ParseResult
2897 parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
2898                              TypeRange inputTypes, TypeRange outputTypes) {
2899   ParseResult res = success();
2900   OpBuilder opBuilder(parser.getBuilder().getContext());
2901   // Resolve `captures` into `capturedValues` at parse time so we can build the
2902   // region with captures.
2903   SmallVector<Value> capturedValues;
2904   fillStructuredOpRegion<NamedStructuredOpType>(
2905       opBuilder, region, inputTypes, outputTypes,
2906       [&](unsigned expected, unsigned actual) {
2907         res = parser.emitError(
2908             parser.getCurrentLocation(),
2909             llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
2910                           "region expects {0} args, got {1}",
2911                           expected, actual));
2912         region.front().dump();
2913       });
2914   return res;
2915 }
2916 
2917 static ParseResult
2918 parseNamedStructuredOpResults(OpAsmParser &parser,
2919                               SmallVectorImpl<Type> &resultTypes) {
2920   if (parser.parseOptionalArrowTypeList(resultTypes))
2921     return failure();
2922   return success();
2923 }
2924 
2925 template <typename NamedStructuredOpType>
2926 static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
2927                                           OperationState &result) {
2928   // TODO: Enable when ods-gen supports captures.
2929   SmallVector<Type, 1> inputTypes, outputTypes;
2930   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
2931     return failure();
2932 
2933   // TODO: consider merging results parsing into region parsing.
2934   // Need to wait for declarative assembly resolution to decide.
2935   SmallVector<Type, 1> outputTensorsTypes;
2936   if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
2937     return failure();
2938   result.addTypes(outputTensorsTypes);
2939 
2940   std::unique_ptr<Region> region = std::make_unique<Region>();
2941   if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
2942           parser, *region, inputTypes, outputTypes))
2943     return failure();
2944   result.addRegion(std::move(region));
2945 
2946   return success();
2947 }
2948 
2949 static void printNamedStructuredOpResults(OpAsmPrinter &p,
2950                                           TypeRange resultTypes) {
2951   if (resultTypes.empty())
2952     return;
2953   p.printOptionalArrowTypeList(resultTypes);
2954 }
2955 
2956 template <typename NamedStructuredOpType>
2957 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
2958   p << op.getOperationName();
2959   p.printOptionalAttrDict(
2960       op->getAttrs(),
2961       /*elidedAttrs=*/{"operand_segment_sizes",
2962                        // See generated code in mlir-linalg-yaml-gen.cpp
2963                        "linalg.memoized_indexing_maps"});
2964 
2965   // Printing is shared with generic ops, except for the region and
2966   // attributes.
2967   printCommonStructuredOpParts(p, op);
2968 
2969   // Results printing.
2970   printNamedStructuredOpResults(p, op.result_tensors().getTypes());
2971 
2972   // Region is elided.
2973 }
2974 
/// Verifier for named structured ops. Performs no checks of its own: it
/// forwards `op` to `verifyGenericOp`, instantiated with the same op type, and
/// returns that result.
template <typename NamedStructuredOpType>
static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) {
  return verifyGenericOp<NamedStructuredOpType>(op);
}
2979 
2980 //===----------------------------------------------------------------------===//
2981 // Canonicalizers and Folders.
2982 //===----------------------------------------------------------------------===//
2983 
2984 namespace {
2985 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2986   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2987 
2988   LogicalResult matchAndRewrite(LinalgOp op,
2989                                 PatternRewriter &rewriter) const override {
2990     for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
2991       // Linalg "inputs" may be either tensor or memref type.
2992       // tensor<0xelt_type> is a convention that may not always mean
2993       // "0 iterations". Only erase in cases we see memref<...x0x...>.
2994       auto mt = opOperand->get().getType().dyn_cast<MemRefType>();
2995       if (!mt)
2996         continue;
2997       if (llvm::is_contained(op.getShape(opOperand), 0)) {
2998         rewriter.eraseOp(op);
2999         return success();
3000       }
3001     }
3002     return failure();
3003   }
3004 };
3005 
3006 struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
3007   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
3008 
3009   LogicalResult matchAndRewrite(LinalgOp op,
3010                                 PatternRewriter &rewriter) const override {
3011     // If no operand comes from a tensor::CastOp and can be folded then fail.
3012     bool hasTensorCastOperand =
3013         llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
3014           if (opOperand->get().isa<BlockArgument>())
3015             return false;
3016           auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
3017           return castOp && canFoldIntoConsumerOp(castOp);
3018         });
3019     if (!hasTensorCastOperand)
3020       return failure();
3021 
3022     SmallVector<Type, 4> newResultTypes;
3023     newResultTypes.reserve(op->getNumResults());
3024     SmallVector<Value, 4> newOperands;
3025     newOperands.reserve(op->getNumOperands());
3026     // Inputs may fold.
3027     for (OpOperand *opOperand : op.getInputOperands()) {
3028       auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
3029       newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
3030                                 ? tensorCastOp.source()
3031                                 : opOperand->get());
3032     }
3033     // Init tensors may fold, in which case the resultType must also change.
3034     for (OpOperand *opOperand : op.getOutputOperands()) {
3035       auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
3036       bool fold = canFoldIntoConsumerOp(tensorCastOp);
3037       newOperands.push_back(fold ? tensorCastOp.getOperand()
3038                                  : opOperand->get());
3039       newResultTypes.push_back(newOperands.back().getType());
3040     }
3041     // Clone op.
3042     Operation *newOp =
3043         op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
3044     SmallVector<Value, 4> replacements;
3045     replacements.reserve(newOp->getNumResults());
3046     for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
3047       Value oldResult = std::get<0>(result);
3048       Value newResult = std::get<1>(result);
3049       if (newResult.getType() != oldResult.getType()) {
3050         replacements.push_back(rewriter.create<tensor::CastOp>(
3051             op->getLoc(), oldResult.getType(), newResult));
3052       } else {
3053         replacements.push_back(newResult);
3054       }
3055     }
3056     rewriter.replaceOp(op, replacements);
3057 
3058     return success();
3059   }
3060 };
3061 } // namespace
3062 
3063 namespace {
3064 // Deduplicate redundant args of a linalg op.
3065 // An arg is redundant if it has the same Value and indexing map as another.
3066 struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
3067   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
3068 
3069   LogicalResult matchAndRewrite(LinalgOp op,
3070                                 PatternRewriter &rewriter) const override {
3071     // This pattern reduces the number of arguments of an op, which breaks
3072     // the invariants of semantically charged named ops.
3073     if (!isa<GenericOp>(op))
3074       return failure();
3075 
3076     // Associate each input to an equivalent "canonical" input that has the same
3077     // Value and indexing map.
3078     //
3079     // In the non-duplicate case, input `i` will have canonical input `i`. But
3080     // in the case of duplicated inputs, the canonical input could be some other
3081     // input `< i`. That is, a later input will have some earlier input as its
3082     // canonical input.
3083     llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
3084     // For later remapping tasks like deduplicating payload block arguments,
3085     // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
3086     // convenient.
3087     SmallVector<unsigned> canonicalInputIndices;
3088     for (OpOperand *opOperand : op.getInputOperands()) {
3089       AffineMap indexingMap = op.getTiedIndexingMap(opOperand);
3090       // STL-like maps have a convenient behavior for our use case here. In the
3091       // case of duplicate keys, the insertion is rejected, and the returned
3092       // iterator gives access to the value already in the map.
3093       auto pair = canonicalInput.insert(
3094           {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
3095       canonicalInputIndices.push_back(pair.first->second);
3096     }
3097 
3098     // If there are no duplicate args, then bail out.
3099     if (canonicalInput.size() == op.getNumInputs())
3100       return failure();
3101 
3102     // The operands for the newly canonicalized op.
3103     SmallVector<Value> newOperands;
3104     for (OpOperand *opOperand : op.getInputOperands())
3105       if (canonicalInputIndices[opOperand->getOperandNumber()] ==
3106           opOperand->getOperandNumber())
3107         newOperands.push_back(opOperand->get());
3108     SmallVector<Value> outputOperands = op.getOutputOperands();
3109     llvm::append_range(newOperands, outputOperands);
3110 
3111     // Repair the indexing maps by filtering out the ones that have been
3112     // eliminated.
3113     SmallVector<AffineMap> newIndexingMaps;
3114     for (OpOperand *opOperand : op.getInputOperands())
3115       if (canonicalInputIndices[opOperand->getOperandNumber()] ==
3116           opOperand->getOperandNumber())
3117         newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
3118     for (OpOperand *opOperand : op.getOutputOperands())
3119       newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
3120 
3121     // Clone the old op with new operands.
3122     Operation *newOp =
3123         op.clone(rewriter, op->getLoc(), op->getResultTypes(), newOperands);
3124     auto newLinalgOp = cast<LinalgOp>(newOp);
3125     newOp->setAttr("indexing_maps",
3126                    rewriter.getAffineMapArrayAttr(newIndexingMaps));
3127 
3128     // Set the number of inputs to the new value. The `clone` call above kept
3129     // the value from the original op.
3130     newLinalgOp.setNumInputs(canonicalInput.size());
3131 
3132     // Repair the payload entry block by RAUW'ing redundant arguments and
3133     // erasing them.
3134     Block &payload = newOp->getRegion(0).front();
3135     SmallVector<OpOperand *> inputOperands = op.getInputOperands();
3136     for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
3137       // Iterate in reverse, so that we erase later args first, preventing the
3138       // argument list from shifting unexpectedly and invalidating all our
3139       // indices.
3140       unsigned operandNumber = opOperand->getOperandNumber();
3141       if (canonicalInputIndices[operandNumber] == operandNumber)
3142         continue;
3143       payload.getArgument(operandNumber)
3144           .replaceAllUsesWith(
3145               payload.getArgument(canonicalInputIndices[operandNumber]));
3146       payload.eraseArgument(operandNumber);
3147     }
3148 
3149     rewriter.replaceOp(op, newOp->getResults());
3150     return success();
3151   }
3152 };
3153 
3154 /// Remove generic operations (on tensors) that are just copying
3155 /// the values from inputs to the results. Requirements are
3156 /// 1) All iterator types are parallel
3157 /// 2) The body contains just a yield operation with the yielded values being
3158 ///    the arguments corresponding to the operands.
3159 struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
3160   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
3161 
3162   LogicalResult matchAndRewrite(LinalgOp op,
3163                                 PatternRewriter &rewriter) const override {
3164     if (auto copyOp = dyn_cast<CopyOp>(*op)) {
3165       assert(copyOp.hasBufferSemantics());
3166       if (copyOp.input() == copyOp.output() &&
3167           copyOp.inputPermutation() == copyOp.outputPermutation()) {
3168         rewriter.eraseOp(op);
3169         return success();
3170       }
3171     }
3172 
3173     if (!isa<GenericOp>(op))
3174       return failure();
3175     if (!op.hasTensorSemantics())
3176       return failure();
3177     // Check all indexing maps are identity.
3178     if (llvm::any_of(op.getIndexingMaps(),
3179                      [](AffineMap map) { return !map.isIdentity(); }))
3180       return failure();
3181 
3182     // Check that the body of the linalg operation is just a linalg.yield
3183     // operation.
3184     Block &body = op->getRegion(0).front();
3185     if (!llvm::hasSingleElement(body))
3186       return failure();
3187     auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
3188     if (!yieldOp)
3189       return failure();
3190 
3191     // Get the argument number of the returned values. That is the operand
3192     // number to use for replacing uses of this operation.
3193     SmallVector<Value, 4> returnedArgs;
3194     for (Value yieldVal : yieldOp.values()) {
3195       auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
3196       if (!yieldArg || yieldArg.getOwner() != &body)
3197         return failure();
3198       unsigned argumentNumber = yieldArg.getArgNumber();
3199       returnedArgs.push_back(op->getOperand(argumentNumber));
3200     }
3201     if (returnedArgs.size() != op.getOperation()->getNumResults())
3202       return failure();
3203     rewriter.replaceOp(op, returnedArgs);
3204     return success();
3205   }
3206 };
3207 } // namespace
3208 
// Defines the multi-result `fold` hook of op class `XXX`. The generated
// folder ignores both of its (unnamed) parameters and simply returns
// `foldMemRefCast(*this)`.
#define LINALGOP_FOLDERS(XXX)                                                  \
  LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
                          SmallVectorImpl<OpFoldResult> &) {                   \
    return foldMemRefCast(*this);                                              \
  }

// Instantiate the folder above for each op class listed below.
LINALGOP_FOLDERS(ConvOp)
LINALGOP_FOLDERS(PoolingMaxOp)
LINALGOP_FOLDERS(PoolingMinOp)
LINALGOP_FOLDERS(PoolingSumOp)
LINALGOP_FOLDERS(CopyOp)
LINALGOP_FOLDERS(FillOp)
LINALGOP_FOLDERS(GenericOp)
3222 
3223 // All named ops canonicalizers and folders are auto-generated in the
3224 // .cpp.inc.
3225 
3226 //===----------------------------------------------------------------------===//
3227 // LinalgDialect
3228 //===----------------------------------------------------------------------===//
3229 
3230 void LinalgDialect::getCanonicalizationPatterns(
3231     RewritePatternSet &results) const {
3232   results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp,
3233               RemoveIdentityLinalgOps>(getContext());
3234 }
3235