//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::linalg;

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`. All output types are asserted to be
/// ShapedType.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));

  // TODO: At the moment all operands go through getElementTypeOrSelf;
  // reconsider when we have evidence we need to do otherwise.
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(getElementTypeOrSelf(t));

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              llvm::Optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            [](Type type) { return type.isa<RankedTensorType>(); });

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);
  state.addAttributes(attributes);
  state.addAttribute(
      "operand_segment_sizes",
      b.getI32VectorAttr({static_cast<int32_t>(inputs.size()),
                          static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
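///
/// For reference, the common part parsed here looks roughly like the following
/// (schematic example; operand names and types are illustrative only):
/// ```
///   ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
///   outs(%init : tensor<?xf32>)
/// ```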
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes) {
  SMLoc inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  result.addAttribute("operand_segment_sizes",
                      parser.getBuilder().getI32VectorAttr(
                          {static_cast<int32_t>(inputsOperands.size()),
                           static_cast<int32_t>(outputsOperands.size())}));
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs) {
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"operand_segment_sizes",
                       // See generated code in mlir-linalg-yaml-gen.cpp
                       "linalg.memoized_indexing_maps"});

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

/// This is a common utility used for patterns of the form
/// ```
///    someop(memrefcast(%src)) -> someop(%src)
/// ```
/// It folds the source of the memref.cast into the root operation directly.
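///
/// For instance (schematic IR; unrelated operands elided):
/// ```
///   %0 = memref.cast %src : memref<4xf32> to memref<?xf32>
///   someop ... ins(%0 : memref<?xf32>) ...
/// ```
/// becomes
/// ```
///   someop ... ins(%src : memref<4xf32>) ...
/// ```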
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code.
// They are helpers that build the unary, binary, and type conversion functions
// defined by the DSL. See mlir-linalg-ods-yaml-gen.cpp for the code that uses
// this class.
//
// Implementations of the math functions must be polymorphic over numeric types,
// internally performing necessary casts. If the function application makes no
// sense, then the only recourse is to assert and return nullptr. This can be
// extended later if it becomes possible to fail construction of the region. The
// invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//
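
// As a hedged sketch (not the actual generated code), an ods-gen produced
// regionBuilder for an elementwise multiply-accumulate would use the helper
// roughly as follows, where `b` is the ImplicitLocOpBuilder and `block` the
// payload block handed to the regionBuilder:
//
//   RegionBuilderHelper helper(b.getContext(), block);
//   Value prod = helper.buildBinaryFn(BinaryFn::mul, block.getArgument(0),
//                                     block.getArgument(1));
//   Value sum = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), prod);
//   helper.yieldOutputs({sum});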

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(MLIRContext *context, Block &block)
      : context(context), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder builder = getBuilder();
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder builder = getBuilder();
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder builder = getBuilder();
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder builder = getBuilder();
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
                                             valueAttr);
  }

  Value index(int64_t dim) {
    OpBuilder builder = getBuilder();
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(context, width);
  }

  Type getFloat32Type() { return Float32Type::get(context); }
  Type getFloat64Type() { return Float64Type::get(context); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder builder = getBuilder();
    auto loc = operand.getLoc();

    if (operand.getType() == toType)
      return operand;
    if (auto toIntType = toType.dyn_cast<IntegerType>()) {
      // If operand is floating point, cast directly to the int type.
      if (operand.getType().isa<FloatType>()) {
        if (isUnsignedCast)
          return builder.create<arith::FPToUIOp>(loc, toType, operand);
        return builder.create<arith::FPToSIOp>(loc, toType, operand);
      }
      // Cast index operands directly to the int type.
      if (operand.getType().isIndex())
        return builder.create<arith::IndexCastOp>(loc, toType, operand);
      if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
        // Either extend or truncate.
        if (toIntType.getWidth() > fromIntType.getWidth()) {
          if (isUnsignedCast)
            return builder.create<arith::ExtUIOp>(loc, toType, operand);
          return builder.create<arith::ExtSIOp>(loc, toType, operand);
        }
        if (toIntType.getWidth() < fromIntType.getWidth())
          return builder.create<arith::TruncIOp>(loc, toType, operand);
      }
    } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
      // If operand is integer, cast directly to the float type.
      // Note that it is unclear how to cast from BF16<->FP16.
      if (operand.getType().isa<IntegerType>()) {
        if (isUnsignedCast)
          return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
        return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
      }
      if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
        if (toFloatType.getWidth() > fromFloatType.getWidth())
          return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
        if (toFloatType.getWidth() < fromFloatType.getWidth())
          return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
      }
    }

    emitWarning(operand.getLoc()) << "could not cast operand of type "
                                  << operand.getType() << " to " << toType;
    return operand;
  }

  bool isComplex(Value value) { return value.getType().isa<ComplexType>(); }
  bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
  bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }

  OpBuilder getBuilder() {
    OpBuilder builder(context);
    builder.setInsertionPointToEnd(&block);
    return builder;
  }

  MLIRContext *context;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
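///
/// Schematic example (types and reassociation indices are illustrative only):
/// ```
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<6x4xf32>) -> tensor<6x4xf32>
///   %r = tensor.collapse_shape %fill [[0, 1]] : tensor<6x4xf32> into tensor<24xf32>
/// ```
/// is rewritten to
/// ```
///   %i = tensor.collapse_shape %init [[0, 1]] : tensor<6x4xf32> into tensor<24xf32>
///   %r = linalg.fill ins(%cst : f32) outs(%i : tensor<24xf32>) -> tensor<24xf32>
/// ```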
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    auto newInit = rewriter.create<TensorReshapeOp>(
        loc, reshapeOp.getResultType(), oldFill.output(),
        reshapeOp.getReassociation());
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});

    return success();
  }
};

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
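///
/// Schematic example (illustrative types and pad amounts):
/// ```
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<8xf32>) -> tensor<8xf32>
///   %pad = tensor.pad %fill low[2] high[2] {
///     ^bb0(%i: index):
///       tensor.yield %cst : f32
///   } : tensor<8xf32> to tensor<12xf32>
/// ```
/// becomes a linalg.fill of a larger linalg.init_tensor, plus a tensor.cast
/// back to the original result type when needed.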
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    ReifyRankedShapedTypeOpInterface interface =
        cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation());
    if (failed(interface.reifyResultShapes(rewriter, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto oldResultType = padOp.getResultType();
    SmallVector<int64_t, 4> staticShape(oldResultType.getRank(),
                                        ShapedType::kDynamicSize);
    auto newInitOp = rewriter.create<InitTensorOp>(
        padOp.getLoc(), reifiedShape.front(), staticShape,
        oldResultType.getElementType());
    auto newFillOp = rewriter.create<FillOp>(
        fillOp.getLoc(), ValueRange{padValue}, ValueRange{newInitOp});
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldResultType,
                                                newFillOp.result());

    return success();
  }
};

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
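///
/// The fold only applies when the inserted ranges provably do not overlap with
/// earlier insert_slices into the same destination chain. Schematically
/// (pseudo-notation, illustrative offsets only): padding an input and
/// inserting it into a filled tensor is the same as inserting the unpadded
/// input at offsets shifted by the low padding:
/// ```
///   tensor.insert_slice(tensor.pad(%src, low = 2), %fill, offset = 0)
///     ==> tensor.insert_slice(%src, %fill, offset = 2)
/// ```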
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the range of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has dynamic offset/size, we cannot guarantee
        // disjoint. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusively for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. It should be the old offsets
    // plus low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      Value padValue = getValueOrCreateConstantIndexOp(
          rewriter, srcPadOp.getLoc(), std::get<0>(p));
      Value offsetValue = getValueOrCreateConstantIndexOp(
          rewriter, insertOp.getLoc(), std::get<1>(p));
      newOffsets.push_back(
          applyMapToValues(rewriter, loc, addMap, {offsetValue, padValue})[0]);
    }

    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadOp.getSourceType().getRank(); i < e; ++i) {
      newSizes.push_back(
          rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
              .getResult());
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results
      .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
           FoldInsertPadIntoFill>(context);
}

//===----------------------------------------------------------------------===//
// GenericOps
//===----------------------------------------------------------------------===//
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (!bodyBuild)
    return;

  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      blockArgTypes.push_back(getElementTypeOrSelf(v));
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, result.location, bodyBlock->getArguments());
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getStrArrayAttr(iteratorTypes),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
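
// For orientation, the custom form produced/consumed by print/parse below
// looks roughly like this (schematic example; maps, types, and attributes are
// illustrative only):
//
//   %res = linalg.generic {indexing_maps = [#map0, #map1],
//                          iterator_types = ["parallel"]}
//       ins(%in : tensor<?xf32>) outs(%init : tensor<?xf32>) {
//     ^bb0(%a: f32, %b: f32):
//       linalg.yield %a : f32
//   } -> tensor<?xf32>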

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs())
    if (genericAttrNamesSet.count(attr.getName().strref()) > 0)
      genericAttrs.push_back(attr);
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, inputs(), outputs());

  genericAttrNames.push_back("operand_segment_sizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!region().empty()) {
    p << ' ';
    p.printRegion(region());
  }

  // Print results.
  printNamedStructuredOpResults(p, result_tensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits into a dictAttr; the attribute name is
  // unimportant as we will overwrite result.attributes. The core linalg traits
  // must contain the information necessary to pass the verifier.
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of their outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
  for (Value value : inputBuffers) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         SideEffects::DefaultResource::get());
  }
  for (Value value : outputs) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), value,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  SmallVector<Value> inputBuffers = getInputBufferOperands();
  SmallVector<Value> outputBuffers = getOutputBufferOperands();
  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
                        outputBuffers);
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

struct DeduplicateAndRemoveDeadOperandsAndResults
    : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Create a map from argument position in the original op to the argument
    // position in the new op. If the argument is dropped it won't have an
    // entry.
    SmallVector<OpOperand *> droppedOpOperands;

    // Information needed to build the new op.
    SmallVector<Value> newInputOperands, newOutputOperands;
    SmallVector<AffineMap> newIndexingMaps;

    // Gather information about duplicate input operands.
    llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
        deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
                                 newIndexingMaps);

    // Gather information about the dropped outputs.
    llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
        deduplicateOutputOperands(genericOp, droppedOpOperands,
                                  newOutputOperands, newIndexingMaps);

    // Check if there is any change to operands.
    if (newInputOperands.size() + newOutputOperands.size() ==
        static_cast<size_t>(genericOp.getNumInputsAndOutputs()))
      return failure();

    // Create the new op with the body being empty.
    Location loc = genericOp.getLoc();
    SmallVector<Type> newResultTypes;
    if (genericOp.hasTensorSemantics()) {
      newResultTypes = llvm::to_vector(llvm::map_range(
          newOutputOperands, [](Value v) { return v.getType(); }));
    }
    auto newOp = rewriter.create<GenericOp>(
        loc, newResultTypes, newInputOperands, newOutputOperands,
        rewriter.getAffineMapArrayAttr(newIndexingMaps),
        genericOp.iterator_types(), genericOp.docAttr(),
        genericOp.library_callAttr(),
        [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
          return;
        });
    // Copy over unknown attributes. They might be load bearing for some flow.
    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
    for (NamedAttribute kv : genericOp->getAttrs())
      if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
        newOp->setAttr(kv.getName(), kv.getValue());

    // Fix up the payload of the canonicalized operation.
    populateOpPayload(genericOp, newOp, origInsToNewInsPos,
                      origOutsToNewOutsPos, rewriter);

    // Replace all live uses of the op.
    SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
    for (auto result : llvm::enumerate(genericOp.getResults())) {
      auto it = origOutsToNewOutsPos.find(result.index());
      if (it == origOutsToNewOutsPos.end())
        continue;
      replacementsVals[result.index()] = newOp.getResult(it->second);
    }
    rewriter.replaceOp(genericOp, replacementsVals);
    return success();
  }

private:
  // Deduplicate input operands, and return the
  // - Mapping from operand position in the original op, to operand position in
  //   the canonicalized op.
  // - The preserved input operands list (by reference).
  llvm::SmallDenseMap<unsigned, unsigned>
  deduplicateInputOperands(GenericOp genericOp,
                           SmallVector<OpOperand *> &droppedOpOperands,
                           SmallVector<Value> &newInputOperands,
                           SmallVector<AffineMap> &newIndexingMaps) const {
    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
    for (auto inputOpOperand : llvm::enumerate(genericOp.getInputOperands())) {
      // Check if the operand is dead and whether dropping its indexing map
      // would invalidate the computation of loop bounds from operand shapes.
      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand.value())) {
        // Add the current operand to the list of potentially droppable
        // operands. If it cannot be dropped, it needs to be popped back.
        droppedOpOperands.push_back(inputOpOperand.value());
        if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
          continue;
        droppedOpOperands.pop_back();
      }

      // Check if this operand is a duplicate.
      AffineMap indexingMap =
          genericOp.getTiedIndexingMap(inputOpOperand.value());
      auto it = dedupedInputs.find(
          std::make_pair(inputOpOperand.value()->get(), indexingMap));
      if (it != dedupedInputs.end()) {
        origToNewPos[inputOpOperand.index()] = it->second;
        droppedOpOperands.push_back(inputOpOperand.value());
        continue;
      }

      // This is a preserved argument.
      origToNewPos[inputOpOperand.index()] = newInputOperands.size();
      dedupedInputs[{inputOpOperand.value()->get(), indexingMap}] =
          newInputOperands.size();
      newInputOperands.push_back(inputOpOperand.value()->get());
      newIndexingMaps.push_back(indexingMap);
    }
    return origToNewPos;
  }

  // Deduplicate output operands, and return the
  // - Mapping from operand position in the original op, to operand position in
  //   the canonicalized op.
  // - The preserved output operands list (by reference).
  llvm::SmallDenseMap<unsigned, unsigned>
  deduplicateOutputOperands(GenericOp genericOp,
                            SmallVector<OpOperand *> &droppedOpOperands,
                            SmallVector<Value> &newOutputOperands,
                            SmallVector<AffineMap> &newIndexingMaps) const {
    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
    llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
        dedupedOutputs;
    // If the op doesn't have tensor semantics, keep all the outputs as
    // preserved.
    if (!genericOp.hasTensorSemantics()) {
      for (auto outputOpOperand :
           llvm::enumerate(genericOp.getOutputOperands())) {
        origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
        newOutputOperands.push_back(outputOpOperand.value()->get());
        newIndexingMaps.push_back(
            genericOp.getTiedIndexingMap(outputOpOperand.value()));
      }
    } else {
      // Output argument can be dropped if the result has
      // - no users, and
      // - it is not used in the payload, and
      // - the corresponding indexing maps are not needed for loop bound
      //   computation.
      auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
      for (auto outputOpOperand :
           llvm::enumerate(genericOp.getOutputOperands())) {
        Value result = genericOp.getResult(outputOpOperand.index());
        AffineMap indexingMap =
            genericOp.getTiedIndexingMap(outputOpOperand.value());
        auto key =
            std::make_tuple(outputOpOperand.value()->get(), indexingMap,
                            yieldOp->getOperand(outputOpOperand.index()));

        // Do not drop an out if its value is used in the payload.
        if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
          if (result.use_empty()) {
            // Check if the OpOperand can be dropped without affecting loop
            // bound computation. Add the operand to the list of dropped op
            // operands for checking. If it cannot be dropped, the value needs
            // to be popped back.
            droppedOpOperands.push_back(outputOpOperand.value());
            if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
              continue;
            }
            droppedOpOperands.pop_back();
          }

          // The out operand can also be dropped if it is computed redundantly
          // by another result; the conditions for that are:
          // - The same operand is used as the out operand.
          // - The same indexing map is used.
          // - The same yield value is used.
          auto it = dedupedOutputs.find(key);
          if (it != dedupedOutputs.end()) {
            origToNewPos[outputOpOperand.index()] = it->second;
            droppedOpOperands.push_back(outputOpOperand.value());
            continue;
          }
        }

        origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
        dedupedOutputs[key] = newOutputOperands.size();
        newOutputOperands.push_back(outputOpOperand.value()->get());
        newIndexingMaps.push_back(
            genericOp.getTiedIndexingMap(outputOpOperand.value()));
      }
    }

    return origToNewPos;
  }

  // Populate the body of the canonicalized operation.
  void populateOpPayload(
      GenericOp genericOp, GenericOp newOp,
      const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
      const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
      PatternRewriter &rewriter) const {
    // Merge the body of the original op with the new op.
    Block *newOpBlock = &newOp.region().front();
    assert(newOpBlock->empty() && "expected new op to have an empty payload");
    Block *origOpBlock = &genericOp.region().front();
    SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);

    // Replace all arguments in the original op, with arguments from the
    // canonicalized op.
    auto updateReplacements =
        [&](OpOperandVector &origOperands, OpOperandVector &newOperands,
            const llvm::SmallDenseMap<unsigned, unsigned> &map) {
          for (auto origOperand : llvm::enumerate(origOperands)) {
            auto it = map.find(origOperand.index());
            if (it == map.end())
              continue;
            OpOperand *newOperand = newOperands[it->second];
            replacements[origOperand.value()->getOperandNumber()] =
                newOpBlock->getArgument(newOperand->getOperandNumber());
          }
        };

    OpOperandVector origInputOperands = genericOp.getInputOperands();
    OpOperandVector newInputOperands = newOp.getInputOperands();
    updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);

    OpOperandVector origOutputOperands = genericOp.getOutputOperands();
    OpOperandVector newOutputOperands = newOp.getOutputOperands();
    updateReplacements(origOutputOperands, newOutputOperands,
                       origOutsToNewOutsPos);

    rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);

    // Drop the unused yield args.
    if (newOp.getNumOutputs() != genericOp.getNumOutputs()) {
      OpBuilder::InsertionGuard g(rewriter);
      YieldOp origYieldOp = cast<YieldOp>(newOpBlock->getTerminator());
      rewriter.setInsertionPoint(origYieldOp);

      SmallVector<Value> newYieldVals(newOp.getNumOutputs(), nullptr);
      for (const auto &yieldOpOperands :
           llvm::enumerate(origYieldOp.values())) {
        auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
        if (it == origOutsToNewOutsPos.end())
          continue;
        newYieldVals[it->second] = yieldOpOperands.value();
      }
      rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
    }
  }
};

/// Remove generic operations (on tensors) that are just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
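///
/// For example (schematic; maps and types abbreviated):
/// ```
///   %res = linalg.generic {indexing_maps = [#identity, #identity],
///                          iterator_types = ["parallel"]}
///       ins(%in : tensor<?xf32>) outs(%out : tensor<?xf32>) {
///     ^bb0(%a: f32, %b: f32):
///       linalg.yield %a : f32
///   } -> tensor<?xf32>
/// ```
/// can be replaced by %in directly (possibly via a tensor.cast or
/// sparse_tensor.convert when the result type differs).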
struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Check all indexing maps are identity.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [](AffineMap map) { return !map.isIdentity(); }))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = genericOp.region().front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (genericOp.hasBufferSemantics()) {
      if (genericOp.getNumInputs() == 1 && genericOp.getNumOutputs() == 1 &&
          genericOp.getInputOperand(0)->get() ==
              genericOp.getOutputOperand(0)->get()) {
        rewriter.eraseOp(genericOp);
        return success();
      }
      return failure();
    }

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.values())) {
      auto yieldArg = yieldVal.value().dyn_cast<BlockArgument>();
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = genericOp->getOperand(argumentNumber);
      Type resultType = genericOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              genericOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              genericOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != genericOp->getNumResults())
      return failure();
    rewriter.replaceOp(genericOp, returnedArgs);
    return success();
  }
};
} // namespace
1172 
1173 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1174                                             MLIRContext *context) {
1175   results
1176       .add<DeduplicateAndRemoveDeadOperandsAndResults, EraseIdentityGenericOp>(
1177           context);
1178 }
1179 
1180 LogicalResult GenericOp::fold(ArrayRef<Attribute>,
1181                               SmallVectorImpl<OpFoldResult> &) {
1182   return foldMemRefCast(*this);
1183 }
1184 
1185 //===----------------------------------------------------------------------===//
1186 // InitTensorOp
1187 //===----------------------------------------------------------------------===//
1188 
1189 void InitTensorOp::build(OpBuilder &b, OperationState &result,
1190                          ArrayRef<OpFoldResult> sizes, Type elementType,
1191                          ArrayRef<NamedAttribute> attrs) {
1192   SmallVector<Value, 4> dynamicSizes;
1193   SmallVector<int64_t, 4> staticSizes;
1194   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1195                              ShapedType::kDynamicSize);
1196   auto resultType = RankedTensorType::get(staticSizes, elementType);
1197   build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes));
1198   result.addAttributes(attrs);
1199 }
1200 
1201 LogicalResult InitTensorOp::verify() {
1202   RankedTensorType resultType = getType();
1203   SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
1204       static_sizes().cast<ArrayAttr>(),
1205       [](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); }));
1206 
1207   if (failed(verifyListOfOperandsOrIntegers(
1208           *this, "sizes", resultType.getRank(), static_sizes(), sizes(),
1209           ShapedType::isDynamic)))
1210     return failure();
1211 
1212   if (static_sizes().size() != static_cast<unsigned>(resultType.getRank()))
1213     return emitError("expected ") << resultType.getRank() << " sizes values";
1214 
1215   Type expectedType = InitTensorOp::inferResultType(
1216       staticSizes, resultType.getElementType(), resultType.getEncoding());
1217   if (resultType != expectedType) {
1218     return emitError("specified type ")
1219            << resultType << " does not match the inferred type "
1220            << expectedType;
1221   }
1222   return success();
1223 }
1224 
1225 Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
1226                                    Type elementType, Attribute encoding) {
1227   return RankedTensorType::get(staticSizes, elementType, encoding);
1228 }
1229 
1230 SmallVector<OpFoldResult> InitTensorOp::getMixedSizes() {
1231   SmallVector<OpFoldResult> mixedSizes;
1232   mixedSizes.reserve(getType().getRank());
1233   unsigned dynamicValIndex = 0;
1234   for (Attribute attr : static_sizes()) {
1235     auto intAttr = attr.cast<IntegerAttr>();
1236     if (!ShapedType::isDynamic(intAttr.getInt())) {
1237       mixedSizes.push_back(intAttr);
1238       continue;
1239     }
1240     mixedSizes.push_back(sizes()[dynamicValIndex++]);
1241   }
1242   return mixedSizes;
1243 }
1244 
1245 namespace {
1246 /// Change the type of the result of a `linalg.init_tensor` by making the result
1247 /// type statically sized along dimensions that in the original operation were
1248 /// defined as dynamic, but whose size was defined using a `constant` op. For
1249 /// example
1250 ///
1251 ///  %c5 = arith.constant 5: index
1252 ///  %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
1253 ///
1254 ///  to
1255 ///
1256 ///  %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
1257 struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
1258   using OpRewritePattern<InitTensorOp>::OpRewritePattern;
1259 
1260   LogicalResult matchAndRewrite(InitTensorOp op,
1261                                 PatternRewriter &rewriter) const override {
1262     SmallVector<Value, 4> dynamicSizes;
1263     SmallVector<int64_t, 4> staticSizes;
1264     for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
1265       // If the size is already static, nothing to do.
1266       if (!op.isDynamicSize(i)) {
1267         staticSizes.push_back(op.getStaticSize(i));
1268         continue;
1269       }
1270 
1271       // If the size is dynamic but defined using a `constant` op, get the
1272       // constant value to find the static size to use.
1273       unsigned operandNum = op.getIndexOfDynamicSize(i);
1274       Value sizeOperand = op.getOperand(operandNum);
1275       if (auto constantIndexOp =
1276               sizeOperand.getDefiningOp<arith::ConstantIndexOp>()) {
1277         staticSizes.push_back(constantIndexOp.value());
1278         continue;
1279       }
1280 
1281       // Fallback case. Keep the size dynamic.
1282       dynamicSizes.push_back(sizeOperand);
1283       staticSizes.push_back(ShapedType::kDynamicSize);
1284     }
1285     RankedTensorType newType =
1286         RankedTensorType::get(staticSizes, op.getType().getElementType());
1287     if (newType == op.getType())
1288       return failure();
1289     auto newOp =
1290         rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
1291                                       rewriter.getI64ArrayAttr(staticSizes));
1292     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1293     return success();
1294   }
1295 };
1296 } // namespace
1297 
1298 namespace {
1299 /// Since the `init_tensor` operation creates a tensor needed only for its
1300 /// shape, a slice of it is also needed only for its shape. The result can be
1301 /// replaced by a new init_tensor operation of the same size as the extract
1302 /// slice op.
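///
/// A hedged illustration (placeholder IR):
///
/// ```mlir
///   %0 = linalg.init_tensor [%d0, 20] : tensor<?x20xf32>
///   %1 = tensor.extract_slice %0[0, 0] [%d1, 10] [1, 1]
///       : tensor<?x20xf32> to tensor<?x10xf32>
/// ```
///
/// can be rewritten into
///
/// ```mlir
///   %1 = linalg.init_tensor [%d1, 10] : tensor<?x10xf32>
/// ```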
1303 struct FoldInitTensorWithExtractSliceOp
1304     : public OpRewritePattern<tensor::ExtractSliceOp> {
1305   using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
1306 
1307   LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
1308                                 PatternRewriter &rewriter) const override {
1309     if (!sliceOp.getSource().getDefiningOp<linalg::InitTensorOp>())
1310       return failure();
1311     // ExtractSliceOp may be rank-reducing; its dynamic sizes must be preserved
1312     // as well as its result type.
1313     rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
1314         sliceOp, sliceOp.getSizes(),
1315         sliceOp.getResult().getType().cast<RankedTensorType>().getShape(),
1316         sliceOp.getSourceType().getElementType());
1317     return success();
1318   }
1319 };
1320 
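/// Similarly, a reshape (`tensor.expand_shape` / `tensor.collapse_shape`) of an
/// `init_tensor` only carries shape information, so it can be replaced by an
/// `init_tensor` of the reshaped type. A rough sketch (placeholder IR):
///
/// ```mlir
///   %0 = linalg.init_tensor [6, 4] : tensor<6x4xf32>
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///       : tensor<6x4xf32> into tensor<24xf32>
/// ```
///
/// becomes
///
/// ```mlir
///   %1 = linalg.init_tensor [24] : tensor<24xf32>
/// ```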
1321 template <typename TensorReshapeOp>
1322 struct FoldInitTensorWithTensorReshapeOp
1323     : public OpRewritePattern<TensorReshapeOp> {
1324   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1325 
1326   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1327                                 PatternRewriter &rewriter) const override {
1328     if (!reshapeOp.getSrc().template getDefiningOp<InitTensorOp>())
1329       return failure();
1330     Location loc = reshapeOp.getLoc();
1331     ReifiedRankedShapedTypeDims resultShapes;
1332     ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
1333         cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
1334     if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
1335                                                           resultShapes)) ||
1336         !llvm::hasSingleElement(resultShapes))
1337       return failure();
1338     Value initTensor = rewriter.create<InitTensorOp>(
1339         loc, getAsOpFoldResult(resultShapes[0]),
1340         reshapeOp.getResultType().getElementType());
1341     if (initTensor.getType() != reshapeOp.getResultType()) {
1342       rewriter.replaceOpWithNewOp<tensor::CastOp>(
1343           reshapeOp, reshapeOp.getResultType(), initTensor);
1344     } else {
1345       rewriter.replaceOp(reshapeOp, initTensor);
1346     }
1347     return success();
1348   }
1349 };
1350 
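/// A `tensor.dim` of a dynamic dimension of an `init_tensor` folds to the SSA
/// value that was passed in as that size. A minimal sketch (placeholder names,
/// `%c0` being an `arith.constant 0 : index`):
///
/// ```mlir
///   %0 = linalg.init_tensor [%sz, 4] : tensor<?x4xf32>
///   %d = tensor.dim %0, %c0 : tensor<?x4xf32>
/// ```
///
/// folds `%d` into `%sz`; static dimensions are left to `tensor.dim`'s own
/// folder.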
1351 struct FoldInitTensorWithDimOp : public OpRewritePattern<tensor::DimOp> {
1352   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1353 
1354   LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1355                                 PatternRewriter &rewriter) const override {
1356     Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1357     auto initTensorOp = dimOp.getSource().getDefiningOp<linalg::InitTensorOp>();
1358     if (!initTensorOp || !maybeConstantIndex)
1359       return failure();
1360     if (!initTensorOp.isDynamicSize(*maybeConstantIndex))
1361       return failure();
1362     rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(*maybeConstantIndex));
1363     return success();
1364   }
1365 };
1366 
1367 /// Canonicalize
1368 ///
1369 /// ```mlir
1370 ///   %0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
1371 ///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1372 /// ```
1373 ///
1374 /// into
1375 ///
1376 /// ```mlir
1377 ///   %0 = linalg.init_tensor [4, %d1] : tensor<4x?xf32>
1378 /// ```
1379 ///
1380 /// This assumes the input program is correct in terms of its shape. So it
1381 /// is safe to assume that `%d0` is in fact 4. If that were not the case, the
1382 /// input program is wrong to begin with, so it is undefined behavior anyway
1383 /// (i.e. this optimization can still trigger without violating program semantics).
1384 struct FoldInitTensorWithTensorCastOp
1385     : public OpRewritePattern<tensor::CastOp> {
1386   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1387 
1388   LogicalResult matchAndRewrite(tensor::CastOp castOp,
1389                                 PatternRewriter &rewriter) const override {
1390     if (!canFoldIntoProducerOp(castOp))
1391       return failure();
1392     auto producer = castOp.getSource().getDefiningOp<InitTensorOp>();
1393     if (!producer)
1394       return failure();
1395 
1396     auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
1397     ArrayRef<int64_t> resultShape = resultType.getShape();
1398     SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1399     SmallVector<OpFoldResult> newMixedSizes;
1400     newMixedSizes.reserve(currMixedSizes.size());
1401     assert(resultShape.size() == currMixedSizes.size() &&
1402            "mismatch in result shape and sizes of init_tensor op");
1403     for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1404       int64_t newDim = std::get<0>(it);
1405       OpFoldResult currDim = std::get<1>(it);
1406       // Case 1: The init tensor dim is static. Check that the tensor cast
1407       // result dim matches.
1408       if (auto attr = currDim.dyn_cast<Attribute>()) {
1409         if (ShapedType::isDynamic(newDim) ||
1410             newDim != attr.cast<IntegerAttr>().getInt()) {
1411           // Something is off: the cast result shape cannot be more dynamic than
1412           // the init tensor result shape (enforced by `canFoldIntoProducerOp`).
1413           // Abort for now.
1414           return rewriter.notifyMatchFailure(
1415               producer, "mismatch in static value of shape of init "
1416                         "tensor result and cast result");
1417         }
1418         newMixedSizes.push_back(attr);
1419         continue;
1420       }
1421 
1422       // Case 2: The tensor cast shape is static, but the init tensor result
1423       // shape is dynamic.
1424       if (!ShapedType::isDynamic(newDim)) {
1425         newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1426         continue;
1427       }
1428 
1429       // Case 3: Both the tensor cast shape and the init tensor result shape
1430       // are dynamic. Use the dynamic value from the init tensor op.
1431       newMixedSizes.push_back(currDim);
1432     }
1433 
1434     rewriter.replaceOpWithNewOp<InitTensorOp>(castOp, newMixedSizes,
1435                                               resultType.getElementType());
1436     return success();
1437   }
1438 };
1439 
1440 } // namespace
1441 
1442 void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
1443                                                MLIRContext *context) {
1444   results.add<FoldInitTensorWithTensorCastOp, FoldInitTensorWithDimOp,
1445               FoldInitTensorWithExtractSliceOp,
1446               FoldInitTensorWithTensorReshapeOp<tensor::ExpandShapeOp>,
1447               FoldInitTensorWithTensorReshapeOp<tensor::CollapseShapeOp>,
1448               ReplaceStaticShapeDims>(context);
1449 }
1450 
1451 LogicalResult InitTensorOp::reifyResultShapes(
1452     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1453   auto shapes = llvm::to_vector<4>(llvm::map_range(
1454       llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
1455         if (isDynamicSize(dim))
1456           return getDynamicSize(dim);
1457         return builder.create<arith::ConstantIndexOp>(getLoc(),
1458                                                       getStaticSize(dim));
1459       }));
1460   reifiedReturnShapes.emplace_back(std::move(shapes));
1461   return success();
1462 }
1463 
1464 //===----------------------------------------------------------------------===//
1465 // YieldOp
1466 //===----------------------------------------------------------------------===//
1467 
1468 void linalg::YieldOp::print(OpAsmPrinter &p) {
1469   if (getNumOperands() > 0)
1470     p << ' ' << getOperands();
1471   p.printOptionalAttrDict((*this)->getAttrs());
1472   if (getNumOperands() > 0)
1473     p << " : " << getOperandTypes();
1474 }
1475 
1476 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
1477   SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
1478   SmallVector<Type, 2> types;
1479   SMLoc loc = parser.getCurrentLocation();
1480   return failure(parser.parseOperandList(opInfo) ||
1481                  parser.parseOptionalAttrDict(result.attributes) ||
1482                  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
1483                  parser.resolveOperands(opInfo, types, loc, result.operands));
1484 }
1485 
1486 // Check that the number and types of the yield operands match the element
1487 // types of the LinalgOp interface's shaped output operands.
1488 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
1489   if (op.getNumOperands() != linalgOp.getNumOutputs())
1490     return op.emitOpError("expected number of yield values (")
1491            << op.getNumOperands()
1492            << ") to match the number of output operands of the enclosing "
1493            << "LinalgOp (" << linalgOp.getNumOutputs() << ")";
1494 
1495   for (OpOperand &opOperand : op->getOpOperands()) {
1496     OpOperand *outputOperand =
1497         linalgOp.getOutputOperand(opOperand.getOperandNumber());
1498     Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
1499     if (opOperand.get().getType() != elementType)
1500       return op.emitOpError("type of yield operand ")
1501              << (opOperand.getOperandNumber() + 1) << " ("
1502              << opOperand.get().getType() << ") doesn't match "
1503              << "the element type of the enclosing linalg.generic op ("
1504              << elementType << ")";
1505   }
1506   return success();
1507 }
1508 
1509 LogicalResult linalg::YieldOp::verify() {
1510   auto *parentOp = (*this)->getParentOp();
1511   if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
1512     return emitOpError("expected single non-empty parent region");
1513 
1514   if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
1515     return verifyYield(*this, linalgOp);
1516 
1517   return emitOpError("expected parent op with LinalgOp interface");
1518 }
1519 
1520 //===----------------------------------------------------------------------===//
1521 // IndexOp
1522 //===----------------------------------------------------------------------===//
1523 
1524 LogicalResult IndexOp::verify() {
1525   auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
1526   if (!linalgOp)
1527     return emitOpError("expected parent op with LinalgOp interface");
1528   if (linalgOp.getNumLoops() <= dim())
1529     return emitOpError("expected dim (")
1530            << dim() << ") to be lower than the number of loops ("
1531            << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
1532   return success();
1533 }
1534 
1535 /////// Operations corresponding to library calls defined with Tablegen ////////
1536 
1537 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
1538 
1539 #define GET_OP_CLASSES
1540 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
1541 
1542 #define GET_OP_CLASSES
1543 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
1544 
1545 /// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
1546 /// Assumes `op` is a LinalgOp.
1547 void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName,
1548                                  SmallVectorImpl<unsigned> &res) {
1549   if (!cast<LinalgOp>(op).iterator_types())
1550     return;
1551 
1552   unsigned dim = 0;
1553   for (auto tn :
1554        cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) {
1555     if (tn == iteratorTypeName)
1556       res.push_back(dim);
1557     ++dim;
1558   }
1559 }
1560 
1561 AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap,
1562                                              unsigned rank,
1563                                              MLIRContext *context) {
1564   if (maybeMap)
1565     return *maybeMap;
1566   if (rank == 0)
1567     return AffineMap::get(context);
1568   return AffineMap::getMultiDimIdentityMap(rank, context);
1569 }
1570 
1571 SmallVector<AffineExpr, 4>
1572 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
1573                                  MLIRContext *context) {
1574   SmallVector<AffineExpr, 4> res;
1575   res.reserve(num);
1576   for (unsigned i = 0; i < num; ++i)
1577     res.push_back(getAffineDimExpr(startIdx++, context));
1578   return res;
1579 }
1580 
1581 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
1582                                                 ArrayRef<AffineExpr> b) {
1583   auto rangeA = llvm::make_range(a.begin(), a.end());
1584   auto rangeB = llvm::make_range(b.begin(), b.end());
1585   auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
1586   return llvm::to_vector<4>(concatRanges);
1587 }
1588 
1589 static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
1590   if (auto memref = t.dyn_cast<MemRefType>()) {
1591     ss << "view";
1592     for (auto size : memref.getShape())
1593       if (size < 0)
1594         ss << "sx";
1595       else
1596         ss << size << "x";
1597     appendMangledType(ss, memref.getElementType());
1598   } else if (auto vec = t.dyn_cast<VectorType>()) {
1599     ss << "vector";
1600     llvm::interleave(
1601         vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
1602     appendMangledType(ss, vec.getElementType());
1603   } else if (t.isSignlessIntOrIndexOrFloat()) {
1604     ss << t;
1605   } else {
1606     llvm_unreachable("Invalid type for linalg library name mangling");
1607   }
1608 }
1609 
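/// Generate the library call name for `op` by mangling the op name and the
/// operand types using the scheme above. As an illustration (derived from the
/// mangling rules, not an authoritative ABI), a `linalg.matmul` on three
/// `memref<?x?xf32>` operands yields a name along the lines of
/// `linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32`.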
1610 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
1611   assert(isa<LinalgOp>(op));
1612   std::string name(op->getName().getStringRef().str());
1613   name.reserve(128);
1614   std::replace(name.begin(), name.end(), '.', '_');
1615   llvm::raw_string_ostream ss(name);
1616   ss << "_";
1617   auto types = op->getOperandTypes();
1618   llvm::interleave(
1619       types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
1620       [&]() { ss << "_"; });
1621   return ss.str();
1622 }
1623 
1624 //===----------------------------------------------------------------------===//
1625 // Canonicalizers and Folders.
1626 //===----------------------------------------------------------------------===//
1627 
1628 namespace {
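/// Erase LinalgOps that have a memref operand with a zero-sized dimension,
/// since such ops perform no work. The tensor case is deliberately left alone;
/// see the comment in the body below.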
1629 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
1630   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
1631 
1632   LogicalResult matchAndRewrite(LinalgOp op,
1633                                 PatternRewriter &rewriter) const override {
1634     for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
1635       // Linalg "inputs" may be either tensor or memref type.
1636       // tensor<0xelt_type> is a convention that may not always mean
1637       // "0 iterations". Only erase in cases where we see memref<...x0x...>.
1638       auto mt = opOperand->get().getType().dyn_cast<MemRefType>();
1639       if (!mt)
1640         continue;
1641       if (llvm::is_contained(op.getShape(opOperand), 0)) {
1642         rewriter.eraseOp(op);
1643         return success();
1644       }
1645     }
1646     return failure();
1647   }
1648 };
1649 
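/// Fold `tensor.cast` producers into a LinalgOp when the cast source is more
/// static than the cast result, so that the op operates on the more static
/// type. A hedged illustration (placeholder IR):
///
/// ```mlir
///   %0 = tensor.cast %arg0 : tensor<4x8xf32> to tensor<?x?xf32>
///   %1 = linalg.generic ... ins(%0 : tensor<?x?xf32>) ...
/// ```
///
/// is rewritten to consume `%arg0` directly; when an `outs` operand folds, the
/// result type becomes more static and a `tensor.cast` back to the original
/// result type is inserted for existing users.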
1650 struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
1651   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
1652 
1653   LogicalResult matchAndRewrite(LinalgOp op,
1654                                 PatternRewriter &rewriter) const override {
1655     // If none of the operands comes from a tensor::CastOp that can be folded, fail.
1656     bool hasTensorCastOperand =
1657         llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
1658           if (opOperand->get().isa<BlockArgument>())
1659             return false;
1660           auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
1661           return castOp && canFoldIntoConsumerOp(castOp);
1662         });
1663     if (!hasTensorCastOperand)
1664       return failure();
1665 
1666     SmallVector<Type, 4> newResultTypes;
1667     newResultTypes.reserve(op->getNumResults());
1668     SmallVector<Value, 4> newOperands;
1669     newOperands.reserve(op->getNumOperands());
1670     // Inputs may fold.
1671     for (OpOperand *opOperand : op.getInputOperands()) {
1672       auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
1673       newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
1674                                 ? tensorCastOp.getSource()
1675                                 : opOperand->get());
1676     }
1677     // Init tensors may fold, in which case the resultType must also change.
1678     for (OpOperand *opOperand : op.getOutputOperands()) {
1679       auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
1680       bool fold = canFoldIntoConsumerOp(tensorCastOp);
1681       newOperands.push_back(fold ? tensorCastOp.getOperand()
1682                                  : opOperand->get());
1683       newResultTypes.push_back(newOperands.back().getType());
1684     }
1685     // Clone op.
1686     Operation *newOp =
1687         op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
1688     SmallVector<Value, 4> replacements;
1689     replacements.reserve(newOp->getNumResults());
1690     for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
1691       Value oldResult = std::get<0>(result);
1692       Value newResult = std::get<1>(result);
1693       if (newResult.getType() != oldResult.getType()) {
1694         replacements.push_back(rewriter.create<tensor::CastOp>(
1695             op->getLoc(), oldResult.getType(), newResult));
1696       } else {
1697         replacements.push_back(newResult);
1698       }
1699     }
1700     rewriter.replaceOp(op, replacements);
1701 
1702     return success();
1703   }
1704 };
1705 
1706 /// Fold LinalgOps with a `tensor.cast` consumer if the `tensor.cast` has a
1707 /// result that is more static than the linalg op's result.
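///
/// A rough sketch (placeholder IR):
///
/// ```mlir
///   %0 = linalg.generic ... outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
/// ```
///
/// is rewritten so that the linalg op produces `tensor<4x8xf32>` directly (its
/// `outs` operand is cast to that type) and a cast back to `tensor<?x?xf32>` is
/// kept for any remaining users of `%0`.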
1708 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
1709   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1710 
1711   LogicalResult matchAndRewrite(tensor::CastOp castOp,
1712                                 PatternRewriter &rewriter) const override {
1713     if (!tensor::canFoldIntoProducerOp(castOp))
1714       return failure();
1715 
1716     auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
1717     if (!linalgOp)
1718       return failure();
1719 
1720     // The cast can be in a conditionally reachable region, in which case folding
1721     // will generate invalid code. Only conservatively fold ops in the same block
1722     // for now.
1723     if (castOp->getBlock() != linalgOp->getBlock())
1724       return failure();
1725 
1726     OpBuilder::InsertionGuard guard(rewriter);
1727     rewriter.setInsertionPoint(linalgOp);
1728 
1729     Location loc = linalgOp.getLoc();
1730     OpResult resultValue = castOp.getSource().cast<OpResult>();
1731     unsigned resultNumber = resultValue.getResultNumber();
1732     auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
1733     // Replace the `outs` for the result with a `tensor.cast`. This cast is now
1734     // going from a more dynamic shape to a less dynamic shape. If the producer
1735     // of this cast, i.e. the producer of the out operand, is also an operation
1736     // that folds with a tensor.cast consumer (like this pattern), the cast will
1737     // continue to propagate as far up the stack as it can go.
1738     OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
1739     Value newOperand =
1740         rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
1741     SmallVector<Value> newOperands = linalgOp.getInputOperands();
1742     SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
1743     outputOperands[resultNumber] = newOperand;
1744     newOperands.append(outputOperands.begin(), outputOperands.end());
1745 
1746     SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
1747                                   linalgOp->result_type_end());
1748     resultTypes[resultNumber] = resultType;
1749     Operation *newOp = linalgOp.clone(rewriter, loc, resultTypes, newOperands);
1750 
1751     // Create a tensor.cast operation back to the original type.
1752     Value castBack = rewriter.create<tensor::CastOp>(
1753         loc, resultValue.getType(), newOp->getResult(resultNumber));
1754 
1755     SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
1756     results[resultNumber] = castBack;
1757     rewriter.replaceOp(linalgOp, results);
1758     rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
1759     return success();
1760   }
1761 };
1762 
1763 /// For each operand in `operands`, this function maps the affine dim expressions
1764 /// of its statically sized dimensions to the corresponding static sizes.
1765 static void populateMap(LinalgOp linalgOp, ArrayRef<OpOperand *> operands,
1766                         llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
1767   for (OpOperand *opOperand : operands) {
1768     if (linalgOp.isScalar(opOperand))
1769       continue;
1770     Value src = opOperand->get();
1771     auto sourceType = src.getType().cast<RankedTensorType>();
1772     auto sourceMap = linalgOp.getTiedIndexingMap(opOperand);
1773 
1774     // Get the `sourceShape` of the `sourceType`. If the operand is a result of a
1775     // `tensor.cast` operation and the source of the cast operation has a static
1776     // shape, then assign that shape to `sourceShape`.
1777     auto *parentOp = src.getDefiningOp();
1778     ArrayRef<int64_t> sourceShape = sourceType.getShape();
1779     if (parentOp) {
1780       if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
1781         Value castSource = castOp.getSource();
1782         auto castSourceType = castSource.getType().cast<RankedTensorType>();
1783         if (castSourceType.hasStaticShape())
1784           sourceShape = castSourceType.getShape();
1785       }
1786     }
1787 
1788     // If a dimension of the source shape is static, map its affine dim
1789     // expression to the known static size.
1790     for (unsigned i = 0; i < sourceShape.size(); i++) {
1791       if (sourceType.isDynamicDim(i))
1792         continue;
1793       if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
1794         affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
1795     }
1796   }
1797 }
1798 
1799 /// Creates a new operand for `opOperand` of `linalgOp` using the static sizes
1800 /// mapped in `affineExprToSize`. New operands are appended to `newOperands` and
1801 /// their result types are stored in `resultTypes`. If `opOperand` requires no
1802 /// change, `changeNeeded` is left untouched and the same operand is added to
1803 /// the `newOperands` list.
1804 static void createNewOperandWithStaticSizes(
1805     Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
1806     llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
1807     SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
1808     bool &changeNeeded) {
1809   Value src = opOperand->get();
1810   newOperands.push_back(src);
1811   if (linalgOp.isScalar(opOperand))
1812     return;
1813   auto sourceType = src.getType().cast<RankedTensorType>();
1814   Type resultType = sourceType;
1815   if (sourceType.hasStaticShape() && linalgOp.isOutputTensor(opOperand)) {
1816     resultTypes.push_back(resultType);
1817     return;
1818   }
1819   ArrayRef<int64_t> sourceShape = sourceType.getShape();
1820   AffineMap sourceMap = linalgOp.getTiedIndexingMap(opOperand);
1821   SmallVector<int64_t> newShape;
1822   // If the operand is updated with a new shape, `newOperandNeeded` will be
1823   // set to true.
1824   bool newOperandNeeded = false;
1825   for (unsigned i = 0; i < sourceShape.size(); i++) {
1826     int64_t dimShape = sourceShape[i];
1827     AffineExpr dimExpr = sourceMap.getResult(i);
1828     if (affineExprToSize.find(dimExpr) == affineExprToSize.end() ||
1829         !sourceType.isDynamicDim(i)) {
1830       newShape.push_back(dimShape);
1831       continue;
1832     }
1833     // The dimension has a dynamic shape and the corresponding affine dim
1834     // expression is present in the map, so assign the size mapped to that
1835     // affine dim expression to the dimension.
1836     newShape.push_back(affineExprToSize[dimExpr]);
1837     newOperandNeeded = true;
1838   }
1839   resultType = RankedTensorType::get(newShape, sourceType.getElementType());
1840   if (newOperandNeeded) {
1841     changeNeeded = true;
1842     // Get the new operand value given its size and element type by
1843     // casting it.
1844     Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
1845     unsigned index = opOperand->getOperandNumber();
1846     newOperands[index] = newOperand;
1847   }
1848   if (linalgOp.isOutputTensor(opOperand))
1849     resultTypes.push_back(resultType);
1850 }
1851 
1852 /// Static shapes for the operands can be inferred if any one of the operands
1853 /// has a static shape. This can be done by referring to the affine dim
1854 /// expressions for the operand.
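///
/// For instance, with identity indexing maps, an input of type `tensor<4x?xf32>`
/// and an output of type `tensor<?x8xf32>` can both be refined (via inserted
/// `tensor.cast` ops) to `tensor<4x8xf32>`, since each dimension size is known
/// from at least one operand. (An illustrative sketch, not an exhaustive
/// description of the pattern.)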
1855 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
1856   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
1857 
1858   LogicalResult matchAndRewrite(LinalgOp linalgOp,
1859                                 PatternRewriter &rewriter) const override {
1860     if (!linalgOp.hasTensorSemantics())
1861       return failure();
1862 
1863     // Maps must be projected permutations.
1864     if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
1865           return !map.isProjectedPermutation();
1866         }))
1867       return failure();
1868 
1869     // Maps affine dim expressions to the static size of that dimension.
1870     llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
1871     Location loc = linalgOp.getLoc();
1872 
1873     // For each affine dim expression, check if the size is known. If it is
1874     // known, add it to the map.
1875     populateMap(linalgOp, linalgOp.getInputAndOutputOperands(),
1876                 affineExprToSize);
1877 
1878     SmallVector<Value> newOperands;
1879     SmallVector<Type> resultTypes;
1880 
1881     // `changeNeeded` is `false` if the operands of `linalgOp` require no
1882     // change in their types.
1883     bool changeNeeded = false;
1884     newOperands.reserve(linalgOp.getNumInputsAndOutputs());
1885     resultTypes.reserve(linalgOp.getNumOutputs());
1886 
1887     // Iterate over all the operands and update the static sizes.
1888     for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
1889       createNewOperandWithStaticSizes(loc, rewriter, opOperand,
1890                                       affineExprToSize, linalgOp, newOperands,
1891                                       resultTypes, changeNeeded);
1892     }
1893 
1894     // If the generic op already has all the required static information, no
1895     // canonicalization is needed.
1896     if (!changeNeeded)
1897       return failure();
1898 
1899     // Clone op.
1900     Operation *newOp =
1901         linalgOp.clone(rewriter, linalgOp->getLoc(), resultTypes, newOperands);
1902     SmallVector<Value> replacements;
1903     replacements.reserve(newOp->getNumResults());
1904     for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
1905       Value newResult = std::get<1>(it);
1906       Value oldResult = std::get<0>(it);
1907       Type newType = newResult.getType();
1908       Type oldType = oldResult.getType();
1909       replacements.push_back(
1910           (newType != oldType)
1911               ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
1912               : newResult);
1913     }
1914     rewriter.replaceOp(linalgOp, replacements);
1915     return success();
1916   }
1917 };
1918 
1919 } // namespace
1920 
1921 // All named op canonicalizers and folders are auto-generated in the
1922 // .cpp.inc.
1923 
1924 //===----------------------------------------------------------------------===//
1925 // LinalgDialect
1926 //===----------------------------------------------------------------------===//
1927 
1928 void LinalgDialect::getCanonicalizationPatterns(
1929     RewritePatternSet &results) const {
1930   results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
1931               FoldTensorCastProducerOp, InferStaticShapeOfOperands>(
1932       getContext());
1933 }
1934 
1935 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
1936                                               Attribute value, Type type,
1937                                               Location loc) {
1938   return builder.create<arith::ConstantOp>(loc, type, value);
1939 }
1940