1 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/IR/Linalg.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
18 #include "mlir/Dialect/Math/IR/Math.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/SCF/SCF.h"
21 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
24 #include "mlir/Dialect/Utils/StaticValueUtils.h"
25 #include "mlir/IR/AffineExprVisitor.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/IR/OpImplementation.h"
28 #include "mlir/IR/PatternMatch.h"
29 #include "mlir/Interfaces/InferTypeOpInterface.h"
30 #include "mlir/Parser/Parser.h"
31 
32 #include "llvm/ADT/DenseMap.h"
33 #include "llvm/ADT/SetVector.h"
34 #include "llvm/ADT/SmallSet.h"
35 #include "llvm/ADT/StringSet.h"
36 #include "llvm/ADT/TypeSwitch.h"
37 #include "llvm/Support/FormatVariadic.h"
38 #include "llvm/Support/MathExtras.h"
39 #include "llvm/Support/raw_ostream.h"
40 
41 using namespace mlir;
42 using namespace mlir::linalg;
43 
44 //===----------------------------------------------------------------------===//
45 // Support for named Linalg ops defined in ods-gen.
46 //===----------------------------------------------------------------------===//
47 
/// Callback used to populate the body block of a structured op. It is invoked
/// with a builder carrying an implicit location, the block to fill, and the
/// attribute list forwarded by the caller.
using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;
50 
51 /// Fills the region of a structured operation using the provided
52 /// `regionBuilder`. The method is used by both named structured ops created by
53 /// ods-gen and by manually defined C++ ops. It is called by both builders and
54 /// parsers and creates a block with arguments corresponding to the elemental
55 /// types of `inputTypes` and `outputTypes`. All output types are asserted to be
56 /// ShapedType.
57 static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
58                                    TypeRange inputTypes, TypeRange outputTypes,
59                                    ArrayRef<NamedAttribute> attrs,
60                                    RegionBuilderFn regionBuilder) {
61   assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
62 
63   // TODO: atm all operands go through getElementTypeOrSelf,
64   // reconsider when we have evidence we need to.
65   SmallVector<Type, 8> argTypes;
66   SmallVector<Location, 8> argLocs;
67   for (auto containers : {inputTypes, outputTypes}) {
68     for (auto t : containers) {
69       argTypes.push_back(getElementTypeOrSelf(t));
70 
71       // TODO: Pass in a proper location here.
72       argLocs.push_back(opBuilder.getUnknownLoc());
73     }
74   }
75 
76   // RAII.
77   OpBuilder::InsertionGuard guard(opBuilder);
78   Block *body =
79       opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
80 
81   opBuilder.setInsertionPointToStart(body);
82   ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
83   regionBuilder(b, *body, attrs);
84 
85   // indexing_maps is an auto-generated method.
86 
87   // iterator_types is an auto-generated method.
88 }
89 
90 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
91 /// The result types are derived automatically if `resultTensorTypes` is none.
92 /// The body of the operation is filled using `regionBuilder`. All ods-gen
93 /// created structured operations use the method to implement their builders.
94 static void buildStructuredOp(OpBuilder &b, OperationState &state,
95                               llvm::Optional<TypeRange> resultTensorTypes,
96                               ValueRange inputs, ValueRange outputs,
97                               ArrayRef<NamedAttribute> attributes,
98                               RegionBuilderFn regionBuilder) {
99   // Derive the result types if needed.
100   SmallVector<Type> derivedResultTypes =
101       resultTensorTypes.getValueOr(TypeRange());
102   if (!resultTensorTypes.hasValue())
103     copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
104             [](Type type) { return type.isa<RankedTensorType>(); });
105 
106   state.addOperands(inputs);
107   state.addOperands(outputs);
108   state.addTypes(derivedResultTypes);
109   state.addAttributes(attributes);
110   state.addAttribute(
111       "operand_segment_sizes",
112       b.getI32VectorAttr({static_cast<int32_t>(inputs.size()),
113                           static_cast<int32_t>(outputs.size())}));
114 
115   // Create and fill the region of the structured operation.
116   Region &region = *state.addRegion();
117   fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
118                          state.attributes.getAttrs(), regionBuilder);
119 }
120 
121 /// Common parsing used for both named structured ops created by ods-gen and by
122 /// manually defined C++ ops. Does not handle regions.
123 static ParseResult
124 parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
125                              SmallVectorImpl<Type> &inputTypes,
126                              SmallVectorImpl<Type> &outputTypes) {
127   SMLoc inputsOperandsLoc, outputsOperandsLoc;
128   SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
129       outputsOperands;
130 
131   parser.parseOptionalAttrDict(result.attributes);
132 
133   if (succeeded(parser.parseOptionalKeyword("ins"))) {
134     if (parser.parseLParen())
135       return failure();
136 
137     inputsOperandsLoc = parser.getCurrentLocation();
138     if (parser.parseOperandList(inputsOperands) ||
139         parser.parseColonTypeList(inputTypes) || parser.parseRParen())
140       return failure();
141   }
142 
143   if (succeeded(parser.parseOptionalKeyword("outs"))) {
144     outputsOperandsLoc = parser.getCurrentLocation();
145     if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
146         parser.parseColonTypeList(outputTypes) || parser.parseRParen())
147       return failure();
148   }
149 
150   if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
151                              result.operands) ||
152       parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
153                              result.operands))
154     return failure();
155 
156   result.addAttribute("operand_segment_sizes",
157                       parser.getBuilder().getI32VectorAttr(
158                           {static_cast<int32_t>(inputsOperands.size()),
159                            static_cast<int32_t>(outputsOperands.size())}));
160   return success();
161 }
162 
163 static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
164                                          ValueRange outputs) {
165   if (!inputs.empty())
166     p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
167   if (!outputs.empty())
168     p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
169 }
170 
171 //===----------------------------------------------------------------------===//
172 // Specific parsing and printing for named structured ops created by ods-gen.
173 //===----------------------------------------------------------------------===//
174 
175 static ParseResult parseNamedStructuredOpRegion(
176     OpAsmParser &parser, Region &region, unsigned numRegionArgs,
177     TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
178     RegionBuilderFn regionBuilder) {
179   if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
180     return parser.emitError(
181         parser.getCurrentLocation(),
182         llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
183                       "region expects {0} args, got {1}",
184                       numRegionArgs, inputTypes.size() + outputTypes.size()));
185   }
186 
187   OpBuilder opBuilder(parser.getContext());
188   fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
189                          regionBuilder);
190   return success();
191 }
192 
193 static ParseResult
194 parseNamedStructuredOpResults(OpAsmParser &parser,
195                               SmallVectorImpl<Type> &resultTypes) {
196   if (parser.parseOptionalArrowTypeList(resultTypes))
197     return failure();
198   return success();
199 }
200 
201 static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
202                                           OperationState &result,
203                                           unsigned numRegionArgs,
204                                           RegionBuilderFn regionBuilder) {
205   // TODO: Enable when ods-gen supports captures.
206   SmallVector<Type, 1> inputTypes, outputTypes;
207   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
208     return failure();
209 
210   // TODO: consider merging results parsing into region parsing.
211   // Need to wait for declarative assembly resolution to decide.
212   SmallVector<Type, 1> outputTensorsTypes;
213   if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
214     return failure();
215   result.addTypes(outputTensorsTypes);
216 
217   std::unique_ptr<Region> region = std::make_unique<Region>();
218   if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
219                                    outputTypes, result.attributes.getAttrs(),
220                                    regionBuilder))
221     return failure();
222   result.addRegion(std::move(region));
223 
224   return success();
225 }
226 
227 static void printNamedStructuredOpResults(OpAsmPrinter &p,
228                                           TypeRange resultTypes) {
229   if (resultTypes.empty())
230     return;
231   p.printOptionalArrowTypeList(resultTypes);
232 }
233 
234 static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
235                                    ValueRange inputs, ValueRange outputs) {
236   p.printOptionalAttrDict(
237       op->getAttrs(),
238       /*elidedAttrs=*/{"operand_segment_sizes",
239                        // See generated code in mlir-linalg-yaml-gen.cpp
240                        "linalg.memoized_indexing_maps"});
241 
242   // Printing is shared with generic ops, except for the region and
243   // attributes.
244   printCommonStructuredOpParts(p, inputs, outputs);
245 
246   // Results printing.
247   printNamedStructuredOpResults(p, op->getResultTypes());
248 
249   // Region is elided.
250 }
251 
252 /// This is a common class used for patterns of the form
253 /// ```
254 ///    someop(memrefcast(%src)) -> someop(%src)
255 /// ```
256 /// It folds the source of the memref.cast into the root operation directly.
257 static LogicalResult foldMemRefCast(Operation *op) {
258   bool folded = false;
259   for (OpOperand &operand : op->getOpOperands()) {
260     auto castOp = operand.get().getDefiningOp<memref::CastOp>();
261     if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
262       operand.set(castOp.getOperand());
263       folded = true;
264     }
265   }
266   return success(folded);
267 }
268 
269 /// Helper function to find if there is atleast one dimension in an AffineMap
270 /// testMap that is contained in `testMapLocation` of  `maps` but not in any
271 /// other locations
272 static bool hasaUniqueDim(ArrayRef<AffineMap> maps, unsigned testMapLocation) {
273   AffineMap testMap = maps[testMapLocation];
274   llvm::SmallDenseSet<unsigned> dimsToCheck;
275   for (auto result : testMap.getResults()) {
276     auto expr = result.dyn_cast<AffineDimExpr>();
277     if (expr != nullptr)
278       dimsToCheck.insert(expr.getPosition());
279   }
280   for (const auto &it : llvm::enumerate(maps)) {
281     if (it.index() == testMapLocation)
282       continue;
283     auto map = it.value();
284     for (auto result : map.getResults()) {
285       auto expr = result.dyn_cast<AffineDimExpr>();
286       if (expr != nullptr) {
287         dimsToCheck.erase(expr.getPosition());
288       }
289       if (dimsToCheck.empty())
290         return false;
291     }
292   }
293   return true;
294 }
295 
296 //===----------------------------------------------------------------------===//
297 // Region builder helper.
298 // TODO: Move this to a utility library.
299 // The public methods on this class are referenced directly from generated code.
// Helps build the unary, binary, and type conversion functions defined by the
// DSL. See mlir-linalg-ods-yaml-gen.cpp for the code that uses this class.
302 //
303 // Implementations of the math functions must be polymorphic over numeric types,
304 // internally performing necessary casts. If the function application makes no
305 // sense, then the only recourse is to assert and return nullptr. This can be
306 // extended later if it becomes possible to fail construction of the region. The
307 // invariant should be enforced at a higher level.
308 //
309 // TODO: These helpers are currently type polymorphic over the class of integer
310 // and floating point types, but they will not internally cast within bit
311 // widths of a class (mixed precision such as i8->i32) or across classes
312 // (i.e. mixed float and integer). Many such combinations are ambiguous or need
313 // to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
316 //===----------------------------------------------------------------------===//
317 
318 namespace {
319 
320 class RegionBuilderHelper {
321 public:
322   RegionBuilderHelper(MLIRContext *context, Block &block)
323       : context(context), block(block) {}
324 
325   // Build the unary functions defined by OpDSL.
326   Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
327     if (!isFloatingPoint(arg))
328       llvm_unreachable("unsupported non numeric type");
329     OpBuilder builder = getBuilder();
330     switch (unaryFn) {
331     case UnaryFn::exp:
332       return builder.create<math::ExpOp>(arg.getLoc(), arg);
333     case UnaryFn::log:
334       return builder.create<math::LogOp>(arg.getLoc(), arg);
335     case UnaryFn::abs:
336       return builder.create<math::AbsOp>(arg.getLoc(), arg);
337     case UnaryFn::ceil:
338       return builder.create<math::CeilOp>(arg.getLoc(), arg);
339     case UnaryFn::floor:
340       return builder.create<math::FloorOp>(arg.getLoc(), arg);
341     case UnaryFn::negf:
342       return builder.create<arith::NegFOp>(arg.getLoc(), arg);
343     }
344     llvm_unreachable("unsupported unary function");
345   }
346 
347   // Build the binary functions defined by OpDSL.
348   Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
349     bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
350     bool allInteger = isInteger(arg0) && isInteger(arg1);
351     if (!allFloatingPoint && !allInteger)
352       llvm_unreachable("unsupported non numeric type");
353     OpBuilder builder = getBuilder();
354     switch (binaryFn) {
355     case BinaryFn::add:
356       if (allFloatingPoint)
357         return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
358       return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
359     case BinaryFn::sub:
360       if (allFloatingPoint)
361         return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
362       return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
363     case BinaryFn::mul:
364       if (allFloatingPoint)
365         return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
366       return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
367     case BinaryFn::max_signed:
368       if (allFloatingPoint)
369         return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
370       return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
371     case BinaryFn::min_signed:
372       if (allFloatingPoint)
373         return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
374       return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
375     case BinaryFn::max_unsigned:
376       if (allFloatingPoint)
377         return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
378       return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
379     case BinaryFn::min_unsigned:
380       if (allFloatingPoint)
381         return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
382       return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
383     }
384     llvm_unreachable("unsupported binary function");
385   }
386 
387   // Build the type functions defined by OpDSL.
388   Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
389     switch (typeFn) {
390     case TypeFn::cast_signed:
391       return cast(toType, operand, false);
392     case TypeFn::cast_unsigned:
393       return cast(toType, operand, true);
394     }
395     llvm_unreachable("unsupported type conversion function");
396   }
397 
398   void yieldOutputs(ValueRange values) {
399     OpBuilder builder = getBuilder();
400     Location loc = builder.getUnknownLoc();
401     builder.create<YieldOp>(loc, values);
402   }
403 
404   Value constant(const std::string &value) {
405     OpBuilder builder = getBuilder();
406     Location loc = builder.getUnknownLoc();
407     Attribute valueAttr = parseAttribute(value, builder.getContext());
408     return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
409                                              valueAttr);
410   }
411 
412   Value index(int64_t dim) {
413     OpBuilder builder = getBuilder();
414     return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
415   }
416 
417   Type getIntegerType(unsigned width) {
418     return IntegerType::get(context, width);
419   }
420 
421   Type getFloat32Type() { return Float32Type::get(context); }
422   Type getFloat64Type() { return Float64Type::get(context); }
423 
424 private:
425   // Generates operations to cast the given operand to a specified type.
426   // If the cast cannot be performed, a warning will be issued and the
427   // operand returned as-is (which will presumably yield a verification
428   // issue downstream).
429   Value cast(Type toType, Value operand, bool isUnsignedCast) {
430     OpBuilder builder = getBuilder();
431     auto loc = operand.getLoc();
432 
433     if (operand.getType() == toType)
434       return operand;
435     if (auto toIntType = toType.dyn_cast<IntegerType>()) {
436       // If operand is floating point, cast directly to the int type.
437       if (operand.getType().isa<FloatType>()) {
438         if (isUnsignedCast)
439           return builder.create<arith::FPToUIOp>(loc, toType, operand);
440         return builder.create<arith::FPToSIOp>(loc, toType, operand);
441       }
442       // Cast index operands directly to the int type.
443       if (operand.getType().isIndex())
444         return builder.create<arith::IndexCastOp>(loc, toType, operand);
445       if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
446         // Either extend or truncate.
447         if (toIntType.getWidth() > fromIntType.getWidth()) {
448           if (isUnsignedCast)
449             return builder.create<arith::ExtUIOp>(loc, toType, operand);
450           return builder.create<arith::ExtSIOp>(loc, toType, operand);
451         }
452         if (toIntType.getWidth() < fromIntType.getWidth())
453           return builder.create<arith::TruncIOp>(loc, toType, operand);
454       }
455     } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
456       // If operand is integer, cast directly to the float type.
457       // Note that it is unclear how to cast from BF16<->FP16.
458       if (operand.getType().isa<IntegerType>()) {
459         if (isUnsignedCast)
460           return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
461         return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
462       }
463       if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
464         if (toFloatType.getWidth() > fromFloatType.getWidth())
465           return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
466         if (toFloatType.getWidth() < fromFloatType.getWidth())
467           return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
468       }
469     }
470 
471     emitWarning(operand.getLoc()) << "could not cast operand of type "
472                                   << operand.getType() << " to " << toType;
473     return operand;
474   }
475 
476   bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
477   bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
478 
479   OpBuilder getBuilder() {
480     OpBuilder builder(context);
481     builder.setInsertionPointToEnd(&block);
482     return builder;
483   }
484 
485   MLIRContext *context;
486   Block &block;
487 };
488 
489 } // namespace
490 
491 //===----------------------------------------------------------------------===//
492 // FillOp
493 //===----------------------------------------------------------------------===//
494 
495 namespace {
496 
497 /// Fold linalg.fill -> tensor.expand/collapse_shape chain.
498 ///
499 /// For such op chains, we can create new linalg.fill ops with the result
500 /// type of the tensor.expand/collapse_shape op.
501 template <typename TensorReshapeOp>
502 struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
503   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
504   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
505                                 PatternRewriter &rewriter) const override {
506     auto oldFill = reshapeOp.src().template getDefiningOp<FillOp>();
507     if (!oldFill)
508       return failure();
509 
510     Location loc = oldFill.getLoc();
511     auto newInit = rewriter.create<TensorReshapeOp>(
512         loc, reshapeOp.getResultType(), oldFill.output(),
513         reshapeOp.reassociation());
514     rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
515                                         ValueRange{newInit});
516 
517     return success();
518   }
519 };
520 
521 /// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
522 /// filling value are the same.
523 struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
524   using OpRewritePattern::OpRewritePattern;
525 
526   LogicalResult matchAndRewrite(tensor::PadOp padOp,
527                                 PatternRewriter &rewriter) const override {
528     auto fillOp = padOp.source().getDefiningOp<linalg::FillOp>();
529     if (!fillOp)
530       return failure();
531 
532     // We can only fold if the padding value is the same as the original
533     // filling value.
534     Value padValue = padOp.getConstantPaddingValue();
535     if (!padValue || fillOp.value() != padValue)
536       return failure();
537 
538     ReifiedRankedShapedTypeDims reifiedShape;
539     ReifyRankedShapedTypeOpInterface interface =
540         cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation());
541     if (failed(interface.reifyResultShapes(rewriter, reifiedShape)))
542       return rewriter.notifyMatchFailure(
543           padOp, "failed to reify tensor.pad op result shape");
544 
545     auto oldResultType = padOp.getResultType();
546     SmallVector<int64_t, 4> staticShape(oldResultType.getRank(),
547                                         ShapedType::kDynamicSize);
548     auto newInitOp = rewriter.create<InitTensorOp>(
549         padOp.getLoc(), reifiedShape.front(), staticShape,
550         oldResultType.getElementType());
551     auto newFillOp = rewriter.create<FillOp>(
552         fillOp.getLoc(), ValueRange{padValue}, ValueRange{newInitOp});
553     rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldResultType,
554                                                 newFillOp.result());
555 
556     return success();
557   }
558 };
559 
560 /// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
561 /// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
562 /// filling value are the same.
563 struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
564   using OpRewritePattern::OpRewritePattern;
565 
566   LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
567                                 PatternRewriter &rewriter) const override {
568     auto srcPadOp = insertOp.source().getDefiningOp<tensor::PadOp>();
569     if (!srcPadOp)
570       return failure();
571 
572     if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
573       return failure();
574 
575     // Walk back the tensor.insert_slice chain and find the first destination
576     // value at the start of the chain.
577     Value firstDest = insertOp.dest();
578     while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
579       if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
580         return failure();
581 
582       // Make sure the range of values accessed are disjoint. Without this, we
583       // cannot fold tensor.pad away.
584       bool disjoint = false;
585       for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
586         // If the dimension has dynamic offset/size, we cannot guarantee
587         // disjoint. So just skip it.
588         if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
589             insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
590             prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
591           continue;
592 
593         // Get the range start and end, inclusively for both.
594         int64_t prevStart = prevOp.getStaticOffset(i);
595         int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
596                                           prevOp.getStaticStride(i);
597         int64_t nextStart = insertOp.getStaticOffset(i);
598         int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
599                                           insertOp.getStaticStride(i);
600         if (prevEnd < nextStart || nextEnd < prevStart) {
601           disjoint = true;
602           break;
603         }
604       }
605 
606       if (!disjoint)
607         break;
608       firstDest = prevOp.dest();
609     }
610 
611     // Check whether the first destination is a fill op. For overlapped cases,
612     // this also cannot be true.
613     auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
614     if (!dstFillOp)
615       return failure();
616 
617     // We can only fold if the padding value is the same as the original
618     // filling value.
619     Value padValue = srcPadOp.getConstantPaddingValue();
620     if (!padValue || dstFillOp.value() != padValue)
621       return failure();
622 
623     SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
624     SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
625 
626     Location loc = insertOp.getLoc();
627     MLIRContext *context = getContext();
628 
629     AffineExpr sym0, sym1;
630     bindSymbols(context, sym0, sym1);
631     auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
632 
633     // Calculate the new offsets for the insert. It should be the old offsets
634     // plus low padding sizes.
635     SmallVector<OpFoldResult, 4> newOffsets;
636     for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
637       Value padValue = getValueOrCreateConstantIndexOp(
638           rewriter, srcPadOp.getLoc(), std::get<0>(p));
639       Value offsetValue = getValueOrCreateConstantIndexOp(
640           rewriter, insertOp.getLoc(), std::get<1>(p));
641       newOffsets.push_back(
642           applyMapToValues(rewriter, loc, addMap, {offsetValue, padValue})[0]);
643     }
644 
645     SmallVector<OpFoldResult, 4> newSizes;
646     for (int i = 0, e = srcPadOp.getSourceType().getRank(); i < e; ++i) {
647       newSizes.push_back(
648           rewriter.create<tensor::DimOp>(loc, srcPadOp.source(), i).result());
649     }
650 
651     rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
652         insertOp, srcPadOp.source(), insertOp.dest(), newOffsets, newSizes,
653         insertOp.getMixedStrides());
654     return success();
655   }
656 };
657 
658 } // namespace
659 
660 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
661                                          MLIRContext *context) {
662   results
663       .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
664            FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
665            FoldInsertPadIntoFill>(context);
666 }
667 
668 //===----------------------------------------------------------------------===//
669 // GenericOps
670 //===----------------------------------------------------------------------===//
671 void GenericOp::build(
672     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
673     ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
674     ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
675     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
676     ArrayRef<NamedAttribute> attributes) {
677   build(builder, result, resultTensorTypes, inputs, outputs,
678         builder.getAffineMapArrayAttr(indexingMaps),
679         builder.getStrArrayAttr(iteratorTypes),
680         doc.empty() ? StringAttr() : builder.getStringAttr(doc),
681         libraryCall.empty() ? StringAttr()
682                             : builder.getStringAttr(libraryCall));
683   result.addAttributes(attributes);
684   if (!bodyBuild)
685     return;
686 
687   SmallVector<Type, 4> blockArgTypes;
688   SmallVector<Location, 4> blockArgLocs;
689   for (ValueRange container : {inputs, outputs}) {
690     for (Value v : container) {
691       blockArgTypes.push_back(getElementTypeOrSelf(v));
692       blockArgLocs.push_back(v.getLoc());
693     }
694   }
695 
696   OpBuilder::InsertionGuard guard(builder);
697   auto &region = *result.regions.front();
698   Block *bodyBlock =
699       builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
700   bodyBuild(builder, result.location, bodyBlock->getArguments());
701 }
702 
703 void GenericOp::build(
704     OpBuilder &builder, OperationState &result, ValueRange inputs,
705     ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
706     ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
707     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
708     ArrayRef<NamedAttribute> attributes) {
709   build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
710         iteratorTypes, doc, libraryCall, bodyBuild, attributes);
711 }
712 
713 void GenericOp::build(
714     OpBuilder &builder, OperationState &result, ValueRange inputs,
715     ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
716     ArrayRef<StringRef> iteratorTypes,
717     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
718     ArrayRef<NamedAttribute> attributes) {
719   build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
720         /*doc=*/"",
721         /*libraryCall=*/"", bodyBuild, attributes);
722 }
723 
724 void GenericOp::build(
725     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
726     ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
727     ArrayRef<StringRef> iteratorTypes,
728     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
729     ArrayRef<NamedAttribute> attributes) {
730   build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
731         iteratorTypes,
732         /*doc=*/"",
733         /*libraryCall=*/"", bodyBuild, attributes);
734 }
735 
736 void GenericOp::print(OpAsmPrinter &p) {
737   p << " ";
738 
739   // Print extra attributes.
740   auto genericAttrNames = linalgTraitAttrNames();
741 
742   llvm::StringSet<> genericAttrNamesSet;
743   genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
744   SmallVector<NamedAttribute, 8> genericAttrs;
745   for (auto attr : (*this)->getAttrs())
746     if (genericAttrNamesSet.count(attr.getName().strref()) > 0)
747       genericAttrs.push_back(attr);
748   if (!genericAttrs.empty()) {
749     auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
750     p << genericDictAttr;
751   }
752 
753   // Printing is shared with named ops, except for the region and attributes
754   printCommonStructuredOpParts(p, inputs(), outputs());
755 
756   genericAttrNames.push_back("operand_segment_sizes");
757   genericAttrNamesSet.insert(genericAttrNames.back());
758 
759   bool hasExtraAttrs = false;
760   for (NamedAttribute n : (*this)->getAttrs()) {
761     if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
762       break;
763   }
764   if (hasExtraAttrs) {
765     p << " attrs = ";
766     p.printOptionalAttrDict((*this)->getAttrs(),
767                             /*elidedAttrs=*/genericAttrNames);
768   }
769 
770   // Print region.
771   if (!region().empty()) {
772     p << ' ';
773     p.printRegion(region());
774   }
775 
776   // Print results.
777   printNamedStructuredOpResults(p, result_tensors().getTypes());
778 }
779 
780 ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
781   DictionaryAttr dictAttr;
782   // Parse the core linalg traits that must check into a dictAttr.
783   // The name is unimportant as we will overwrite result.attributes.
784   // The core linalg traits must contain the information necessary to pass the
785   // verifier.
786   if (parser.parseAttribute(dictAttr, "_", result.attributes))
787     return failure();
788   result.attributes.assign(dictAttr.getValue().begin(),
789                            dictAttr.getValue().end());
790 
791   // Parsing is shared with named ops, except for the region.
792   SmallVector<Type, 1> inputTypes, outputTypes;
793   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
794     return failure();
795 
796   // Optional attributes may be added.
797   if (succeeded(parser.parseOptionalKeyword("attrs")))
798     if (failed(parser.parseEqual()) ||
799         failed(parser.parseOptionalAttrDict(result.attributes)))
800       return failure();
801 
802   SmallVector<OpAsmParser::UnresolvedOperand, 8> regionOperands;
803   std::unique_ptr<Region> region = std::make_unique<Region>();
804   SmallVector<Type, 8> operandTypes, regionTypes;
805   if (parser.parseRegion(*region, regionOperands, regionTypes))
806     return failure();
807   result.addRegion(std::move(region));
808 
809   // Generic ops may specify that a subset of its outputs are tensors. Such
810   // outputs are specified in the result type.
811   // TODO: may need to move output parsing before region parsing.
812   // Need to wait for declarative assembly resolution to decide.
813   SmallVector<Type, 1> outputTensorsTypes;
814   if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
815     return failure();
816   result.addTypes(outputTensorsTypes);
817 
818   return success();
819 }
820 
821 static void getGenericEffectsImpl(
822     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
823         &effects,
824     ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
825   for (Value value : inputBuffers) {
826     effects.emplace_back(MemoryEffects::Read::get(), value,
827                          SideEffects::DefaultResource::get());
828   }
829   for (Value value : outputs) {
830     effects.emplace_back(MemoryEffects::Read::get(), value,
831                          SideEffects::DefaultResource::get());
832     effects.emplace_back(MemoryEffects::Write::get(), value,
833                          SideEffects::DefaultResource::get());
834   }
835 }
836 
837 void GenericOp::getEffects(
838     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
839         &effects) {
840   SmallVector<Value> inputBuffers = getInputBufferOperands();
841   SmallVector<Value> outputBuffers = getOutputBufferOperands();
842   getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
843                         outputBuffers);
844 }
845 
846 LogicalResult GenericOp::verify() { return success(); }
847 
848 namespace {
849 // Deduplicate redundant args of a linalg generic op.
850 // An arg is redundant if it has the same Value and indexing map as another.
851 struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> {
852   using OpRewritePattern<GenericOp>::OpRewritePattern;
853 
854   LogicalResult matchAndRewrite(GenericOp genericOp,
855                                 PatternRewriter &rewriter) const override {
856     // Associate each input to an equivalent "canonical" input that has the same
857     // Value and indexing map.
858     //
859     // In the non-duplicate case, input `i` will have canonical input `i`. But
860     // in the case of duplicated inputs, the canonical input could be some other
861     // input `< i`. That is, a later input will have some earlier input as its
862     // canonical input.
863     llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
864     // For later remapping tasks like deduplicating payload block arguments,
865     // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
866     // convenient.
867     SmallVector<unsigned> canonicalInputIndices;
868     for (OpOperand *opOperand : genericOp.getInputOperands()) {
869       AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
870       // STL-like maps have a convenient behavior for our use case here. In the
871       // case of duplicate keys, the insertion is rejected, and the returned
872       // iterator gives access to the value already in the map.
873       auto pair = canonicalInput.insert(
874           {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
875       canonicalInputIndices.push_back(pair.first->second);
876     }
877 
878     // If there are no duplicate args, then bail out.
879     if (canonicalInput.size() == genericOp.getNumInputs())
880       return failure();
881 
882     // The operands for the newly canonicalized op.
883     SmallVector<Value> newInputOperands;
884     for (OpOperand *opOperand : genericOp.getInputOperands())
885       if (canonicalInputIndices[opOperand->getOperandNumber()] ==
886           opOperand->getOperandNumber())
887         newInputOperands.push_back(opOperand->get());
888 
889     // Repair the indexing maps by filtering out the ones that have been
890     // eliminated.
891     SmallVector<AffineMap> newIndexingMaps;
892     for (OpOperand *opOperand : genericOp.getInputOperands())
893       if (canonicalInputIndices[opOperand->getOperandNumber()] ==
894           opOperand->getOperandNumber())
895         newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
896     for (OpOperand *opOperand : genericOp.getOutputOperands())
897       newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
898 
899     // Clone the old op with new operands.
900     SmallVector<Value> outputOperands = genericOp.getOutputOperands();
901     auto newOp = rewriter.create<GenericOp>(
902         genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
903         outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
904         genericOp.iterator_types(), genericOp.docAttr(),
905         genericOp.library_callAttr());
906 
907     // Copy over unknown attributes. They might be load bearing for some flow.
908     ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
909     for (NamedAttribute kv : genericOp->getAttrs()) {
910       if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) {
911         newOp->setAttr(kv.getName(), kv.getValue());
912       }
913     }
914 
915     rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
916                                 newOp.region().begin());
917 
918     // Repair the payload entry block by RAUW'ing redundant arguments and
919     // erasing them.
920     Block &payload = newOp.region().front();
921     SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
922     for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
923       // Iterate in reverse, so that we erase later args first, preventing the
924       // argument list from shifting unexpectedly and invalidating all our
925       // indices.
926       unsigned operandNumber = opOperand->getOperandNumber();
927       if (canonicalInputIndices[operandNumber] == operandNumber)
928         continue;
929       payload.getArgument(operandNumber)
930           .replaceAllUsesWith(
931               payload.getArgument(canonicalInputIndices[operandNumber]));
932       payload.eraseArgument(operandNumber);
933     }
934 
935     rewriter.replaceOp(genericOp, newOp->getResults());
936     return success();
937   }
938 };
939 
940 /// Remove generic operations (on tensors) that are just copying
941 /// the values from inputs to the results. Requirements are
942 /// 1) All iterator types are parallel
943 /// 2) The body contains just a yield operation with the yielded values being
944 ///    the arguments corresponding to the operands.
945 struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
946   using OpRewritePattern<GenericOp>::OpRewritePattern;
947 
948   LogicalResult matchAndRewrite(GenericOp genericOp,
949                                 PatternRewriter &rewriter) const override {
950     // Check all indexing maps are identity.
951     if (llvm::any_of(genericOp.getIndexingMaps(),
952                      [](AffineMap map) { return !map.isIdentity(); }))
953       return failure();
954 
955     // Check that the body of the linalg operation is just a linalg.yield
956     // operation.
957     Block &body = genericOp.region().front();
958     if (!llvm::hasSingleElement(body))
959       return failure();
960     auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
961     if (!yieldOp)
962       return failure();
963 
964     // In the buffer case, we need to check exact buffer equality.
965     if (genericOp.hasBufferSemantics()) {
966       if (genericOp.getNumInputs() == 1 && genericOp.getNumOutputs() == 1 &&
967           genericOp.getInputOperand(0)->get() ==
968               genericOp.getOutputOperand(0)->get()) {
969         rewriter.eraseOp(genericOp);
970         return success();
971       }
972       return failure();
973     }
974 
975     // Get the argument number of the returned values. That is the operand
976     // number to use for replacing uses of this operation.
977     SmallVector<Value> returnedArgs;
978     for (const auto &yieldVal : llvm::enumerate(yieldOp.values())) {
979       auto yieldArg = yieldVal.value().dyn_cast<BlockArgument>();
980       if (!yieldArg || yieldArg.getOwner() != &body)
981         return failure();
982       unsigned argumentNumber = yieldArg.getArgNumber();
983       Value returnedArg = genericOp->getOperand(argumentNumber);
984       Type resultType = genericOp->getResult(yieldVal.index()).getType();
985       // The input can have a different type than the result, e.g. a dynamic
986       // input dimension can be turned into a static output dimension.
987       Type returnType = returnedArg.getType();
988       if (returnType != resultType) {
989         // Distinguish between sparse conversion or dense tensor casting.
990         // TODO: unify the two ops?
991         if (sparse_tensor::getSparseTensorEncoding(returnType) ||
992             sparse_tensor::getSparseTensorEncoding(resultType))
993           returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
994               genericOp.getLoc(), resultType, returnedArg);
995         else {
996           if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
997                                                  resultType))
998             return failure();
999           returnedArg = rewriter.create<tensor::CastOp>(
1000               genericOp.getLoc(), resultType, returnedArg);
1001         }
1002       }
1003       returnedArgs.push_back(returnedArg);
1004     }
1005 
1006     if (returnedArgs.size() != genericOp->getNumResults())
1007       return failure();
1008     rewriter.replaceOp(genericOp, returnedArgs);
1009     return success();
1010   }
1011 };
1012 
1013 /// Drop dead args of a linalg generic op.
1014 /// An arg is dead if it has zero uses in the op region.
1015 struct DeadArgsGenericOpInputs : public OpRewritePattern<GenericOp> {
1016   using OpRewritePattern<GenericOp>::OpRewritePattern;
1017   LogicalResult matchAndRewrite(GenericOp genericOp,
1018                                 PatternRewriter &rewriter) const override {
1019     SmallVector<AffineMap> oldIndexingMaps = genericOp.getIndexingMaps();
1020     // Maps must be projected permutations.
1021     if (llvm::any_of(genericOp.getIndexingMaps(), [](AffineMap map) {
1022           return !map.isProjectedPermutation();
1023         }))
1024       return failure();
1025     Block &payload = genericOp.region().front();
1026     SmallVector<Value> newInputOperands;
1027     SmallVector<AffineMap> newIndexingMaps;
1028     bool deadArgFound = false;
1029     int inputSize = genericOp.getInputOperands().size();
1030     for (int i = inputSize - 1; i >= 0; i--) {
1031       OpOperand *opOperand = genericOp.getInputOperand(i);
1032       // Iterate in reverse, so that we erase later args first, preventing the
1033       // argument list from shifting unexpectedly and invalidating all our
1034       // indices.
1035       if (payload.getArgument(i).use_empty() &&
1036           !hasaUniqueDim(oldIndexingMaps, i)) {
1037         payload.eraseArgument(i);
1038         deadArgFound = true;
1039         // remove this indexing map out of consideration for hasaUniqueDim check
1040         oldIndexingMaps.erase(oldIndexingMaps.begin() + i);
1041       } else {
1042         newInputOperands.insert(newInputOperands.begin(), opOperand->get());
1043         newIndexingMaps.insert(newIndexingMaps.begin(),
1044                                genericOp.getTiedIndexingMap(opOperand));
1045       }
1046     }
1047     // Bail out if there are no dead args.
1048     if (!deadArgFound)
1049       return failure();
1050     for (OpOperand *opOperand : genericOp.getOutputOperands())
1051       newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
1052     SmallVector<Value> outputOperands = genericOp.getOutputOperands();
1053 
1054     auto newOp = rewriter.create<GenericOp>(
1055         genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
1056         outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
1057         genericOp.iterator_types(), genericOp.docAttr(),
1058         genericOp.library_callAttr());
1059     // Copy over unknown attributes. They might be load bearing for some flow.
1060     ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
1061     for (NamedAttribute kv : genericOp->getAttrs()) {
1062       if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) {
1063         newOp->setAttr(kv.getName(), kv.getValue());
1064       }
1065     }
1066     rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
1067                                 newOp.region().begin());
1068     rewriter.replaceOp(genericOp, newOp->getResults());
1069     return success();
1070   }
1071 };
1072 } // namespace
1073 
1074 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1075                                             MLIRContext *context) {
1076   results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp,
1077               DeadArgsGenericOpInputs>(context);
1078 }
1079 
1080 LogicalResult GenericOp::fold(ArrayRef<Attribute>,
1081                               SmallVectorImpl<OpFoldResult> &) {
1082   return foldMemRefCast(*this);
1083 }
1084 
1085 //===----------------------------------------------------------------------===//
1086 // InitTensorOp
1087 //===----------------------------------------------------------------------===//
1088 
1089 void InitTensorOp::build(OpBuilder &b, OperationState &result,
1090                          ArrayRef<OpFoldResult> sizes, Type elementType,
1091                          ArrayRef<NamedAttribute> attrs) {
1092   SmallVector<Value, 4> dynamicSizes;
1093   SmallVector<int64_t, 4> staticSizes;
1094   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1095                              ShapedType::kDynamicSize);
1096   auto resultType = RankedTensorType ::get(staticSizes, elementType);
1097   build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes));
1098   result.addAttributes(attrs);
1099 }
1100 
1101 LogicalResult InitTensorOp::verify() {
1102   RankedTensorType resultType = getType();
1103   SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
1104       static_sizes().cast<ArrayAttr>(),
1105       [](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); }));
1106 
1107   if (failed(verifyListOfOperandsOrIntegers(
1108           *this, "sizes", resultType.getRank(), static_sizes(), sizes(),
1109           ShapedType::isDynamic)))
1110     return failure();
1111 
1112   if (static_sizes().size() != static_cast<unsigned>(resultType.getRank()))
1113     return emitError("expected ") << resultType.getRank() << " sizes values";
1114 
1115   Type expectedType = InitTensorOp::inferResultType(
1116       staticSizes, resultType.getElementType(), resultType.getEncoding());
1117   if (resultType != expectedType) {
1118     return emitError("specified type ")
1119            << resultType << " does not match the inferred type "
1120            << expectedType;
1121   }
1122   return success();
1123 }
1124 
1125 Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
1126                                    Type elementType, Attribute encoding) {
1127   return RankedTensorType::get(staticSizes, elementType, encoding);
1128 }
1129 
1130 SmallVector<OpFoldResult> InitTensorOp::getMixedSizes() {
1131   SmallVector<OpFoldResult> mixedSizes;
1132   mixedSizes.reserve(getType().getRank());
1133   unsigned dynamicValIndex = 0;
1134   for (Attribute attr : static_sizes()) {
1135     auto intAttr = attr.cast<IntegerAttr>();
1136     if (!ShapedType::isDynamic(intAttr.getInt())) {
1137       mixedSizes.push_back(intAttr);
1138       continue;
1139     }
1140     mixedSizes.push_back(sizes()[dynamicValIndex++]);
1141   }
1142   return mixedSizes;
1143 }
1144 
1145 namespace {
1146 /// Change the type of the result of a `linalg.init_tensor` by making the result
1147 /// type statically sized along dimension that in the original operation where
1148 /// defined as dynamic, but the size was defined using a `constant` op. For
1149 /// example
1150 ///
1151 ///  %c5 = arith.constant 5: index
1152 ///  %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
1153 ///
1154 ///  to
1155 ///
1156 ///  %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
1157 struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
1158   using OpRewritePattern<InitTensorOp>::OpRewritePattern;
1159 
1160   LogicalResult matchAndRewrite(InitTensorOp op,
1161                                 PatternRewriter &rewriter) const override {
1162     SmallVector<Value, 4> dynamicSizes;
1163     SmallVector<int64_t, 4> staticSizes;
1164     for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
1165       // If the size is already static, nothing to do.
1166       if (!op.isDynamicSize(i)) {
1167         staticSizes.push_back(op.getStaticSize(i));
1168         continue;
1169       }
1170 
1171       // If the size is dynamic but defined using a `constant` op, get the
1172       // constant value to find the static size to use.
1173       unsigned operandNum = op.getIndexOfDynamicSize(i);
1174       Value sizeOperand = op.getOperand(operandNum);
1175       if (auto constantIndexOp =
1176               sizeOperand.getDefiningOp<arith::ConstantIndexOp>()) {
1177         staticSizes.push_back(constantIndexOp.value());
1178         continue;
1179       }
1180 
1181       // Fallback case. Keep the size dynamic.
1182       dynamicSizes.push_back(sizeOperand);
1183       staticSizes.push_back(ShapedType::kDynamicSize);
1184     }
1185     RankedTensorType newType =
1186         RankedTensorType::get(staticSizes, op.getType().getElementType());
1187     if (newType == op.getType())
1188       return failure();
1189     auto newOp =
1190         rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
1191                                       rewriter.getI64ArrayAttr(staticSizes));
1192     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1193     return success();
1194   }
1195 };
1196 } // namespace
1197 
1198 namespace {
1199 /// Since `init_tensor` operation creates a tensor needed only for its shape, a
1200 /// slice of this is also needed only for its shape. The result can be
1201 /// replaced by a new init_tensor operation of the same size as the extract
1202 /// slice op.
1203 struct FoldInitTensorWithExtractSliceOp
1204     : public OpRewritePattern<tensor::ExtractSliceOp> {
1205   using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
1206 
1207   LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
1208                                 PatternRewriter &rewriter) const override {
1209     if (!sliceOp.source().getDefiningOp<linalg::InitTensorOp>())
1210       return failure();
1211     // ExtractSliceOp may be rank-reducing; its dynamic sizes must be preserved
1212     // as well as its result type.
1213     rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
1214         sliceOp, sliceOp.sizes(),
1215         sliceOp.result().getType().cast<RankedTensorType>().getShape(),
1216         sliceOp.getSourceType().getElementType());
1217     return success();
1218   }
1219 };
1220 
1221 template <typename TensorReshapeOp>
1222 struct FoldInitTensorWithTensorReshapeOp
1223     : public OpRewritePattern<TensorReshapeOp> {
1224   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1225 
1226   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1227                                 PatternRewriter &rewriter) const override {
1228     if (!reshapeOp.src().template getDefiningOp<InitTensorOp>())
1229       return failure();
1230     Location loc = reshapeOp.getLoc();
1231     ReifiedRankedShapedTypeDims resultShapes;
1232     ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
1233         cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
1234     if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
1235                                                           resultShapes)) ||
1236         !llvm::hasSingleElement(resultShapes))
1237       return failure();
1238     Value initTensor = rewriter.create<InitTensorOp>(
1239         loc, getAsOpFoldResult(resultShapes[0]),
1240         reshapeOp.getResultType().getElementType());
1241     if (initTensor.getType() != reshapeOp.getResultType()) {
1242       rewriter.replaceOpWithNewOp<tensor::CastOp>(
1243           reshapeOp, reshapeOp.getResultType(), initTensor);
1244     } else {
1245       rewriter.replaceOp(reshapeOp, initTensor);
1246     }
1247     return success();
1248   }
1249 };
1250 
1251 struct FoldInitTensorWithDimOp : public OpRewritePattern<tensor::DimOp> {
1252   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1253 
1254   LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1255                                 PatternRewriter &rewriter) const override {
1256     Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1257     auto initTensorOp = dimOp.source().getDefiningOp<linalg::InitTensorOp>();
1258     if (!initTensorOp || !maybeConstantIndex)
1259       return failure();
1260     if (!initTensorOp.isDynamicSize(*maybeConstantIndex))
1261       return failure();
1262     rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(*maybeConstantIndex));
1263     return success();
1264   }
1265 };
1266 
1267 /// Canonicalize
1268 ///
1269 /// ```mlir
1270 ///   %0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
1271 ///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1272 /// ```
1273 ///
1274 /// into
1275 ///
1276 /// ```mlir
1277 ///   %0 = linalg.init_tensor [4, %d1] : tensor<4x?xf32>
1278 /// ```
1279 ///
1280 /// This assumes the input program is correct in terms of its shape. So it
1281 /// is safe to assume that `%d0` is in fact 4. If that was not the case, the
1282 /// input program is wrong to begin with, so its undefined behavior anyway (i.e.
1283 /// this optimization can still triggering without violating program semantics).
1284 struct FoldInitTensorWithTensorCastOp
1285     : public OpRewritePattern<tensor::CastOp> {
1286   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1287 
1288   LogicalResult matchAndRewrite(tensor::CastOp castOp,
1289                                 PatternRewriter &rewriter) const override {
1290     if (!canFoldIntoProducerOp(castOp))
1291       return failure();
1292     auto producer = castOp.source().getDefiningOp<InitTensorOp>();
1293     if (!producer)
1294       return failure();
1295 
1296     auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
1297     ArrayRef<int64_t> resultShape = resultType.getShape();
1298     SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1299     SmallVector<OpFoldResult> newMixedSizes;
1300     newMixedSizes.reserve(currMixedSizes.size());
1301     assert(resultShape.size() == currMixedSizes.size() &&
1302            "mismatch in result shape and sizes of init_tensor op");
1303     for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1304       int64_t newDim = std::get<0>(it);
1305       OpFoldResult currDim = std::get<1>(it);
1306       // Case 1: The init tensor dim is static. Check that the tensor cast
1307       // result dim matches.
1308       if (auto attr = currDim.dyn_cast<Attribute>()) {
1309         if (ShapedType::isDynamic(newDim) ||
1310             newDim != attr.cast<IntegerAttr>().getInt()) {
1311           // Something is off, the cast result shape cannot be more dynamic than
1312           // the init tensor result shape (enforced by `canFoldIntoProducer`).
1313           // Abort for now.
1314           return rewriter.notifyMatchFailure(
1315               producer, "mismatch in static value of shape of init "
1316                         "tensor result and cast result");
1317         }
1318         newMixedSizes.push_back(attr);
1319         continue;
1320       }
1321 
1322       // Case 2 : The tensor cast shape is static, but init tensor result shape
1323       // is dynamic.
1324       if (!ShapedType::isDynamic(newDim)) {
1325         newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1326         continue;
1327       }
1328 
1329       // Case 3 : The tensor cast shape is dynamic and init tensor result shape
1330       // is dynamic. Use the dynamic value from the init tensor op.
1331       newMixedSizes.push_back(currDim);
1332     }
1333 
1334     rewriter.replaceOpWithNewOp<InitTensorOp>(castOp, newMixedSizes,
1335                                               resultType.getElementType());
1336     return success();
1337   }
1338 };
1339 
1340 } // namespace
1341 
1342 void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
1343                                                MLIRContext *context) {
1344   results.add<FoldInitTensorWithTensorCastOp, FoldInitTensorWithDimOp,
1345               FoldInitTensorWithExtractSliceOp,
1346               FoldInitTensorWithTensorReshapeOp<tensor::ExpandShapeOp>,
1347               FoldInitTensorWithTensorReshapeOp<tensor::CollapseShapeOp>,
1348               ReplaceStaticShapeDims>(context);
1349 }
1350 
1351 LogicalResult InitTensorOp::reifyResultShapes(
1352     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1353   auto shapes = llvm::to_vector<4>(llvm::map_range(
1354       llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
1355         if (isDynamicSize(dim))
1356           return getDynamicSize(dim);
1357         return builder.create<arith::ConstantIndexOp>(getLoc(),
1358                                                       getStaticSize(dim));
1359       }));
1360   reifiedReturnShapes.emplace_back(std::move(shapes));
1361   return success();
1362 }
1363 
1364 //===----------------------------------------------------------------------===//
1365 // YieldOp
1366 //===----------------------------------------------------------------------===//
1367 
1368 void linalg::YieldOp::print(OpAsmPrinter &p) {
1369   if (getNumOperands() > 0)
1370     p << ' ' << getOperands();
1371   p.printOptionalAttrDict((*this)->getAttrs());
1372   if (getNumOperands() > 0)
1373     p << " : " << getOperandTypes();
1374 }
1375 
1376 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
1377   SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
1378   SmallVector<Type, 2> types;
1379   SMLoc loc = parser.getCurrentLocation();
1380   return failure(parser.parseOperandList(opInfo) ||
1381                  parser.parseOptionalAttrDict(result.attributes) ||
1382                  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
1383                  parser.resolveOperands(opInfo, types, loc, result.operands));
1384 }
1385 
1386 // Check the operand number and types must match the element types of the
1387 // LinalgOp interface's shaped operands.
1388 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
1389   if (op.getNumOperands() != linalgOp.getNumOutputs())
1390     return op.emitOpError("expected number of yield values (")
1391            << linalgOp.getNumOutputs()
1392            << ") to match the number of operands of the enclosing "
1393            << "LinalgOp (" << op.getNumOperands() << ")";
1394 
1395   for (OpOperand &opOperand : op->getOpOperands()) {
1396     OpOperand *outputOperand =
1397         linalgOp.getOutputOperand(opOperand.getOperandNumber());
1398     Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
1399     if (opOperand.get().getType() != elementType)
1400       return op.emitOpError("type of yield operand ")
1401              << (opOperand.getOperandNumber() + 1) << " ("
1402              << opOperand.get().getType() << ") doesn't match "
1403              << "the element type of the enclosing linalg.generic op ("
1404              << elementType << ")";
1405   }
1406   return success();
1407 }
1408 
1409 LogicalResult linalg::YieldOp::verify() {
1410   auto *parentOp = (*this)->getParentOp();
1411   if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
1412     return emitOpError("expected single non-empty parent region");
1413 
1414   if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
1415     return verifyYield(*this, linalgOp);
1416 
1417   return emitOpError("expected parent op with LinalgOp interface");
1418 }
1419 
1420 //===----------------------------------------------------------------------===//
1421 // IndexOp
1422 //===----------------------------------------------------------------------===//
1423 
1424 LogicalResult IndexOp::verify() {
1425   auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
1426   if (!linalgOp)
1427     return emitOpError("expected parent op with LinalgOp interface");
1428   if (linalgOp.getNumLoops() <= dim())
1429     return emitOpError("expected dim (")
1430            << dim() << ") to be lower than the number of loops ("
1431            << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
1432   return success();
1433 }
1434 
1435 /////// Operations corresponding to library calls defined with Tablegen ////////
1436 
1437 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
1438 
1439 #define GET_OP_CLASSES
1440 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
1441 
1442 #define GET_OP_CLASSES
1443 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
1444 
1445 /// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
1446 /// Assumes `op` is a LinalgOp.
1447 void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName,
1448                                  SmallVectorImpl<unsigned> &res) {
1449   if (!cast<LinalgOp>(op).iterator_types())
1450     return;
1451 
1452   unsigned dim = 0;
1453   for (auto tn :
1454        cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) {
1455     if (tn == iteratorTypeName)
1456       res.push_back(dim);
1457     ++dim;
1458   }
1459 }
1460 
1461 AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap,
1462                                              unsigned rank,
1463                                              MLIRContext *context) {
1464   if (maybeMap)
1465     return maybeMap.getValue();
1466   if (rank == 0)
1467     return AffineMap::get(context);
1468   return AffineMap::getMultiDimIdentityMap(rank, context);
1469 }
1470 
1471 SmallVector<AffineExpr, 4>
1472 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
1473                                  MLIRContext *context) {
1474   SmallVector<AffineExpr, 4> res;
1475   res.reserve(num);
1476   for (unsigned i = 0; i < num; ++i)
1477     res.push_back(getAffineDimExpr(startIdx++, context));
1478   return res;
1479 }
1480 
1481 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
1482                                                 ArrayRef<AffineExpr> b) {
1483   auto rangeA = llvm::make_range(a.begin(), a.end());
1484   auto rangeB = llvm::make_range(b.begin(), b.end());
1485   auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
1486   return llvm::to_vector<4>(concatRanges);
1487 }
1488 
1489 static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
1490   if (auto memref = t.dyn_cast<MemRefType>()) {
1491     ss << "view";
1492     for (auto size : memref.getShape())
1493       if (size < 0)
1494         ss << "sx";
1495       else
1496         ss << size << "x";
1497     appendMangledType(ss, memref.getElementType());
1498   } else if (auto vec = t.dyn_cast<VectorType>()) {
1499     ss << "vector";
1500     llvm::interleave(
1501         vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
1502     appendMangledType(ss, vec.getElementType());
1503   } else if (t.isSignlessIntOrIndexOrFloat()) {
1504     ss << t;
1505   } else {
1506     llvm_unreachable("Invalid type for linalg library name mangling");
1507   }
1508 }
1509 
1510 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
1511   assert(isa<LinalgOp>(op));
1512   std::string name(op->getName().getStringRef().str());
1513   name.reserve(128);
1514   std::replace(name.begin(), name.end(), '.', '_');
1515   llvm::raw_string_ostream ss(name);
1516   ss << "_";
1517   auto types = op->getOperandTypes();
1518   llvm::interleave(
1519       types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
1520       [&]() { ss << "_"; });
1521   return ss.str();
1522 }
1523 
1524 //===----------------------------------------------------------------------===//
1525 // Canonicalizers and Folders.
1526 //===----------------------------------------------------------------------===//
1527 
1528 namespace {
1529 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
1530   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
1531 
1532   LogicalResult matchAndRewrite(LinalgOp op,
1533                                 PatternRewriter &rewriter) const override {
1534     for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
1535       // Linalg "inputs" may be either tensor or memref type.
1536       // tensor<0xelt_type> is a convention that may not always mean
1537       // "0 iterations". Only erase in cases we see memref<...x0x...>.
1538       auto mt = opOperand->get().getType().dyn_cast<MemRefType>();
1539       if (!mt)
1540         continue;
1541       if (llvm::is_contained(op.getShape(opOperand), 0)) {
1542         rewriter.eraseOp(op);
1543         return success();
1544       }
1545     }
1546     return failure();
1547   }
1548 };
1549 
1550 struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
1551   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
1552 
1553   LogicalResult matchAndRewrite(LinalgOp op,
1554                                 PatternRewriter &rewriter) const override {
1555     // If no operand comes from a tensor::CastOp and can be folded then fail.
1556     bool hasTensorCastOperand =
1557         llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
1558           if (opOperand->get().isa<BlockArgument>())
1559             return false;
1560           auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
1561           return castOp && canFoldIntoConsumerOp(castOp);
1562         });
1563     if (!hasTensorCastOperand)
1564       return failure();
1565 
1566     SmallVector<Type, 4> newResultTypes;
1567     newResultTypes.reserve(op->getNumResults());
1568     SmallVector<Value, 4> newOperands;
1569     newOperands.reserve(op->getNumOperands());
1570     // Inputs may fold.
1571     for (OpOperand *opOperand : op.getInputOperands()) {
1572       auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
1573       newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
1574                                 ? tensorCastOp.source()
1575                                 : opOperand->get());
1576     }
1577     // Init tensors may fold, in which case the resultType must also change.
1578     for (OpOperand *opOperand : op.getOutputOperands()) {
1579       auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
1580       bool fold = canFoldIntoConsumerOp(tensorCastOp);
1581       newOperands.push_back(fold ? tensorCastOp.getOperand()
1582                                  : opOperand->get());
1583       newResultTypes.push_back(newOperands.back().getType());
1584     }
1585     // Clone op.
1586     Operation *newOp =
1587         op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
1588     SmallVector<Value, 4> replacements;
1589     replacements.reserve(newOp->getNumResults());
1590     for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
1591       Value oldResult = std::get<0>(result);
1592       Value newResult = std::get<1>(result);
1593       if (newResult.getType() != oldResult.getType()) {
1594         replacements.push_back(rewriter.create<tensor::CastOp>(
1595             op->getLoc(), oldResult.getType(), newResult));
1596       } else {
1597         replacements.push_back(newResult);
1598       }
1599     }
1600     rewriter.replaceOp(op, replacements);
1601 
1602     return success();
1603   }
1604 };
1605 
1606 /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
1607 /// result that is more static than the linalg op.
1608 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
1609   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1610 
1611   LogicalResult matchAndRewrite(tensor::CastOp castOp,
1612                                 PatternRewriter &rewriter) const override {
1613     if (!tensor::canFoldIntoProducerOp(castOp))
1614       return failure();
1615     auto linalgOp = castOp.source().getDefiningOp<LinalgOp>();
1616     if (!linalgOp)
1617       return failure();
1618 
1619     OpBuilder::InsertionGuard guard(rewriter);
1620     rewriter.setInsertionPoint(linalgOp);
1621 
1622     Location loc = linalgOp.getLoc();
1623     OpResult resultValue = castOp.source().cast<OpResult>();
1624     unsigned resultNumber = resultValue.getResultNumber();
1625     auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
1626     // Replace the `outs` for the result with a `tensor.cast`. This cast is now
1627     // going from a more dynamic shape to a less dynamic shape. If the producer
1628     // for this cast, i.e. producer of the out operand, is also an operation
1629     // that folds with tensor.cast consumer (like this pattern), the cast will
1630     // continue to propagate as far up the stack as it can go.
1631     OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
1632     Value newOperand =
1633         rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
1634     SmallVector<Value> newOperands = linalgOp.getInputOperands();
1635     SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
1636     outputOperands[resultNumber] = newOperand;
1637     newOperands.append(outputOperands.begin(), outputOperands.end());
1638 
1639     SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
1640                                   linalgOp->result_type_end());
1641     resultTypes[resultNumber] = resultType;
1642     Operation *newOp = linalgOp.clone(rewriter, loc, resultTypes, newOperands);
1643 
1644     // Create a tensor.cast operation back to the original type.
1645     Value castBack = rewriter.create<tensor::CastOp>(
1646         loc, resultValue.getType(), newOp->getResult(resultNumber));
1647 
1648     SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
1649     results[resultNumber] = castBack;
1650     rewriter.replaceOp(linalgOp, results);
1651     rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
1652     return success();
1653   }
1654 };
1655 
1656 /// For each of the operand in `operands` this function maps the static sizes of
1657 /// dimensions to their affine dim expressions.
1658 static void populateMap(LinalgOp linalgOp, ArrayRef<OpOperand *> operands,
1659                         llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
1660   for (OpOperand *opOperand : operands) {
1661     if (linalgOp.isScalar(opOperand))
1662       continue;
1663     Value src = opOperand->get();
1664     auto sourceType = src.getType().cast<RankedTensorType>();
1665     auto sourceMap = linalgOp.getTiedIndexingMap(opOperand);
1666 
1667     // Get the `sourceShape` of the `sourceType`. If the operand is a result of
1668     // `tensor.cast` operation and source of the cast operation has a static
1669     // shape, then assign it to the `sourceShape`.
1670     auto *parentOp = src.getDefiningOp();
1671     ArrayRef<int64_t> sourceShape = sourceType.getShape();
1672     if (parentOp) {
1673       if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
1674         Value castSource = castOp.source();
1675         auto castSourceType = castSource.getType().cast<RankedTensorType>();
1676         if (castSourceType.hasStaticShape())
1677           sourceShape = castSourceType.getShape();
1678       }
1679     }
1680 
1681     // If the source shape's dimension has a static shape, map the affine dim
1682     // expression to the known static size.
1683     for (unsigned i = 0; i < sourceShape.size(); i++) {
1684       if (sourceType.isDynamicDim(i))
1685         continue;
1686       if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
1687         affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
1688     }
1689   }
1690 }
1691 
1692 /// Creates new operand w.r.t 'opOperand' of `linalgOp` with static sizes
1693 /// mapped in `affineExprToSize`. New operands are created in `newOperands` and
1694 /// their result types is stored in `resultTypes`. If `opOperand` requires no
1695 /// change then `changeNeeded` is false and same operand is added in the
1696 /// `newOperands` list.
1697 static void createNewOperandWithStaticSizes(
1698     Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
1699     llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
1700     SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
1701     bool &changeNeeded) {
1702   Value src = opOperand->get();
1703   newOperands.push_back(src);
1704   if (linalgOp.isScalar(opOperand))
1705     return;
1706   auto sourceType = src.getType().cast<RankedTensorType>();
1707   Type resultType = sourceType;
1708   if (sourceType.hasStaticShape() && linalgOp.isOutputTensor(opOperand)) {
1709     resultTypes.push_back(resultType);
1710     return;
1711   }
1712   ArrayRef<int64_t> sourceShape = sourceType.getShape();
1713   AffineMap sourceMap = linalgOp.getTiedIndexingMap(opOperand);
1714   SmallVector<int64_t> newShape;
1715   // If operand is updated with new shape, `newOperandNeeded` will be
1716   // true.
1717   bool newOperandNeeded = false;
1718   for (unsigned i = 0; i < sourceShape.size(); i++) {
1719     int64_t dimShape = sourceShape[i];
1720     AffineExpr dimExpr = sourceMap.getResult(i);
1721     if (affineExprToSize.find(dimExpr) == affineExprToSize.end() ||
1722         !sourceType.isDynamicDim(i)) {
1723       newShape.push_back(dimShape);
1724       continue;
1725     }
1726     // Dimension has a dynamic shape and corresponding affine dim
1727     // expression is present in the map. So assign the size for the
1728     // given affine dim expression to the dimension.
1729     newShape.push_back(affineExprToSize[dimExpr]);
1730     newOperandNeeded = true;
1731   }
1732   resultType = RankedTensorType::get(newShape, sourceType.getElementType());
1733   if (newOperandNeeded) {
1734     changeNeeded = true;
1735     // Get the new operand value given its size and element type by
1736     // casting it.
1737     Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
1738     unsigned index = opOperand->getOperandNumber();
1739     newOperands[index] = newOperand;
1740   }
1741   if (linalgOp.isOutputTensor(opOperand))
1742     resultTypes.push_back(resultType);
1743 }
1744 
1745 /// Static shapes for the operands can be inferred if any one of the operands
1746 /// have a static shape. This can be done by referring to the affine dim
1747 /// expressions for the operand.
1748 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
1749   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
1750 
1751   LogicalResult matchAndRewrite(LinalgOp linalgOp,
1752                                 PatternRewriter &rewriter) const override {
1753     if (!linalgOp.hasTensorSemantics())
1754       return failure();
1755 
1756     // Maps must be projected permutations.
1757     if (llvm::any_of(linalgOp.getIndexingMaps(), [](AffineMap map) {
1758           return !map.isProjectedPermutation();
1759         }))
1760       return failure();
1761 
1762     // Maps affine dim expressions to the static size of that dimension.
1763     llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
1764     Location loc = linalgOp.getLoc();
1765 
1766     // For each of the affine dim expression, check if the size is known. If
1767     // known add that in the map.
1768     populateMap(linalgOp, linalgOp.getInputAndOutputOperands(),
1769                 affineExprToSize);
1770 
1771     SmallVector<Value> newOperands;
1772     SmallVector<Type> resultTypes;
1773 
1774     // `changeNeeded` is `false` if the operands of `linalgOp` require no
1775     // change in their types.
1776     bool changeNeeded = false;
1777     newOperands.reserve(linalgOp.getNumInputsAndOutputs());
1778     resultTypes.reserve(linalgOp.getNumOutputs());
1779 
1780     // Iterate over all the operands and update the static sizes.
1781     for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
1782       createNewOperandWithStaticSizes(loc, rewriter, opOperand,
1783                                       affineExprToSize, linalgOp, newOperands,
1784                                       resultTypes, changeNeeded);
1785     }
1786 
1787     // If the generic op has all the required static information, no
1788     // canonicalization needed.
1789     if (!changeNeeded)
1790       return failure();
1791 
1792     // Clone op.
1793     Operation *newOp =
1794         linalgOp.clone(rewriter, linalgOp->getLoc(), resultTypes, newOperands);
1795     SmallVector<Value> replacements;
1796     replacements.reserve(newOp->getNumResults());
1797     for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
1798       Value newResult = std::get<1>(it);
1799       Value oldResult = std::get<0>(it);
1800       Type newType = newResult.getType();
1801       Type oldType = oldResult.getType();
1802       replacements.push_back(
1803           (newType != oldType)
1804               ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
1805               : newResult);
1806     }
1807     rewriter.replaceOp(linalgOp, replacements);
1808     return success();
1809   }
1810 };
1811 
1812 } // namespace
1813 
1814 // All named ops canonicalizers and folders are auto-generated in the
1815 // .cpp.inc.
1816 
1817 //===----------------------------------------------------------------------===//
1818 // LinalgDialect
1819 //===----------------------------------------------------------------------===//
1820 
1821 void LinalgDialect::getCanonicalizationPatterns(
1822     RewritePatternSet &results) const {
1823   results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
1824               FoldTensorCastProducerOp, InferStaticShapeOfOperands>(
1825       getContext());
1826 }
1827 
1828 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
1829                                               Attribute value, Type type,
1830                                               Location loc) {
1831   return builder.create<arith::ConstantOp>(loc, type, value);
1832 }
1833