//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::linalg;

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`. All output types are asserted to be
/// ShapedType.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));

  // TODO: at the moment all operands go through getElementTypeOrSelf;
  // reconsider when we have evidence we need to.
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(getElementTypeOrSelf(t));

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              llvm::Optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            [](Type type) { return type.isa<RankedTensorType>(); });

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);
  state.addAttributes(attributes);
  state.addAttribute(
      "operand_segment_sizes",
      b.getI32VectorAttr({static_cast<int32_t>(inputs.size()),
                          static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes) {
  SMLoc inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  result.addAttribute("operand_segment_sizes",
                      parser.getBuilder().getI32VectorAttr(
                          {static_cast<int32_t>(inputsOperands.size()),
                           static_cast<int32_t>(outputsOperands.size())}));
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs) {
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"operand_segment_sizes",
                       // See generated code in mlir-linalg-yaml-gen.cpp
                       "linalg.memoized_indexing_maps"});

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

/// This is a common helper used for patterns of the form
/// ```
///   someop(memrefcast(%src)) -> someop(%src)
/// ```
/// It folds the source of the memref.cast into the root operation directly.
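///
/// For illustration, a sketch of the rewrite (the surrounding op and the
/// types shown are hypothetical):
///
///   %0 = memref.cast %src : memref<4x4xf32> to memref<?x?xf32>
///   linalg.generic ... ins(%0 : memref<?x?xf32>) ...
///
/// becomes
///
///   linalg.generic ... ins(%src : memref<4x4xf32>) ...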
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code.
// Helpers to build the unary, binary, and type conversion functions defined by
// the DSL. See mlir-linalg-ods-yaml-gen.cpp for the code that uses this class.
//
// Implementations of the math functions must be polymorphic over numeric types,
// internally performing necessary casts. If the function application makes no
// sense, then the only recourse is to assert and return nullptr. This can be
// extended later if it becomes possible to fail construction of the region. The
// invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(MLIRContext *context, Block &block)
      : context(context), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder builder = getBuilder();
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder builder = getBuilder();
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder builder = getBuilder();
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder builder = getBuilder();
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
                                             valueAttr);
  }

  Value index(int64_t dim) {
    OpBuilder builder = getBuilder();
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(context, width);
  }

  Type getFloat32Type() { return Float32Type::get(context); }
  Type getFloat64Type() { return Float64Type::get(context); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder builder = getBuilder();
    auto loc = operand.getLoc();

    if (operand.getType() == toType)
      return operand;
    if (auto toIntType = toType.dyn_cast<IntegerType>()) {
      // If operand is floating point, cast directly to the int type.
      if (operand.getType().isa<FloatType>()) {
        if (isUnsignedCast)
          return builder.create<arith::FPToUIOp>(loc, toType, operand);
        return builder.create<arith::FPToSIOp>(loc, toType, operand);
      }
      // Cast index operands directly to the int type.
      if (operand.getType().isIndex())
        return builder.create<arith::IndexCastOp>(loc, toType, operand);
      if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
        // Either extend or truncate.
        if (toIntType.getWidth() > fromIntType.getWidth()) {
          if (isUnsignedCast)
            return builder.create<arith::ExtUIOp>(loc, toType, operand);
          return builder.create<arith::ExtSIOp>(loc, toType, operand);
        }
        if (toIntType.getWidth() < fromIntType.getWidth())
          return builder.create<arith::TruncIOp>(loc, toType, operand);
      }
    } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
      // If operand is integer, cast directly to the float type.
      // Note that it is unclear how to cast from BF16<->FP16.
      if (operand.getType().isa<IntegerType>()) {
        if (isUnsignedCast)
          return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
        return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
      }
      if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
        if (toFloatType.getWidth() > fromFloatType.getWidth())
          return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
        if (toFloatType.getWidth() < fromFloatType.getWidth())
          return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
      }
    }

    emitWarning(operand.getLoc()) << "could not cast operand of type "
                                  << operand.getType() << " to " << toType;
    return operand;
  }

  bool isComplex(Value value) { return value.getType().isa<ComplexType>(); }
  bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
  bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }

  OpBuilder getBuilder() {
    OpBuilder builder(context);
    builder.setInsertionPointToEnd(&block);
    return builder;
  }

  MLIRContext *context;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
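///
/// For illustration, a sketch with hypothetical shapes and an expand_shape:
///
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<6xf32>)
///       -> tensor<6xf32>
///   %0 = tensor.expand_shape %fill [[0, 1]]
///       : tensor<6xf32> into tensor<2x3xf32>
///
/// folds into
///
///   %0 = tensor.expand_shape %init [[0, 1]]
///       : tensor<6xf32> into tensor<2x3xf32>
///   %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x3xf32>)
///       -> tensor<2x3xf32>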
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    auto newInit = rewriter.create<TensorReshapeOp>(
        loc, reshapeOp.getResultType(), oldFill.output(),
        reshapeOp.getReassociation());
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});

    return success();
  }
};

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
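///
/// For illustration, a sketch with hypothetical static shapes (the actual
/// pattern creates a dynamically sized init_tensor from the reified pad shape
/// and casts back to the original result type when needed):
///
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4x4xf32>)
///       -> tensor<4x4xf32>
///   %pad = tensor.pad %fill low[1, 1] high[1, 1] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
///   } : tensor<4x4xf32> to tensor<6x6xf32>
///
/// folds into a fill of an init_tensor of the padded size:
///
///   %0 = linalg.init_tensor [6, 6] : tensor<6x6xf32>
///   %pad = linalg.fill ins(%cst : f32) outs(%0 : tensor<6x6xf32>)
///       -> tensor<6x6xf32>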
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    ReifyRankedShapedTypeOpInterface interface =
        cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation());
    if (failed(interface.reifyResultShapes(rewriter, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto oldResultType = padOp.getResultType();
    SmallVector<int64_t, 4> staticShape(oldResultType.getRank(),
                                        ShapedType::kDynamicSize);
    auto newInitOp = rewriter.create<InitTensorOp>(
        padOp.getLoc(), reifiedShape.front(), staticShape,
        oldResultType.getElementType());
    auto newFillOp = rewriter.create<FillOp>(
        fillOp.getLoc(), ValueRange{padValue}, ValueRange{newInitOp});
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldResultType,
                                                newFillOp.result());

    return success();
  }
};

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
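///
/// For illustration, a sketch with hypothetical shapes:
///
///   %pad = tensor.pad %src low[1, 1] high[1, 1] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
///   } : tensor<2x2xf32> to tensor<4x4xf32>
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<8x8xf32>)
///       -> tensor<8x8xf32>
///   %0 = tensor.insert_slice %pad into %fill[2, 2] [4, 4] [1, 1]
///       : tensor<4x4xf32> into tensor<8x8xf32>
///
/// folds into an insert of the unpadded source, with the offsets shifted by
/// the low padding:
///
///   %0 = tensor.insert_slice %src into %fill[3, 3] [2, 2] [1, 1]
///       : tensor<2x2xf32> into tensor<8x8xf32>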
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the ranges of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has dynamic offset/size, we cannot guarantee
        // disjointness. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusively for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. They should be the old offsets
    // plus the low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      Value padValue = getValueOrCreateConstantIndexOp(
          rewriter, srcPadOp.getLoc(), std::get<0>(p));
      Value offsetValue = getValueOrCreateConstantIndexOp(
          rewriter, insertOp.getLoc(), std::get<1>(p));
      newOffsets.push_back(
          applyMapToValues(rewriter, loc, addMap, {offsetValue, padValue})[0]);
    }

    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadOp.getSourceType().getRank(); i < e; ++i) {
      newSizes.push_back(
          rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
              .getResult());
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results
      .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
           FoldInsertPadIntoFill>(context);
}

//===----------------------------------------------------------------------===//
// GenericOps
//===----------------------------------------------------------------------===//
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (!bodyBuild)
    return;

  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      blockArgTypes.push_back(getElementTypeOrSelf(v));
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, result.location, bodyBlock->getArguments());
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getStrArrayAttr(iteratorTypes),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs())
    if (genericAttrNamesSet.count(attr.getName().strref()) > 0)
      genericAttrs.push_back(attr);
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, inputs(), outputs());

  genericAttrNames.push_back("operand_segment_sizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!region().empty()) {
    p << ' ';
    p.printRegion(region());
  }

  // Print results.
  printNamedStructuredOpResults(p, result_tensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of their outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
  for (Value value : inputBuffers) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         SideEffects::DefaultResource::get());
  }
  for (Value value : outputs) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), value,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  SmallVector<Value> inputBuffers = getInputBufferOperands();
  SmallVector<Value> outputBuffers = getOutputBufferOperands();
  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
                        outputBuffers);
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

struct DeduplicateAndRemoveDeadOperandsAndResults
    : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Create a map from argument position in the original op to the argument
    // position in the new op. If the argument is dropped it won't have an
    // entry.
    SmallVector<OpOperand *> droppedOpOperands;

    // Information needed to build the new op.
    SmallVector<Value> newInputOperands, newOutputOperands;
    SmallVector<AffineMap> newIndexingMaps;

    // Gather information about duplicate input operands.
    llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
        deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
                                 newIndexingMaps);

    // Gather information about the dropped outputs.
    llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
        deduplicateOutputOperands(genericOp, droppedOpOperands,
                                  newOutputOperands, newIndexingMaps);

    // Check if there is any change to operands.
    if (newInputOperands.size() + newOutputOperands.size() ==
        static_cast<size_t>(genericOp.getNumInputsAndOutputs()))
      return failure();

    // Create the new op with the body being empty.
    Location loc = genericOp.getLoc();
    SmallVector<Type> newResultTypes;
    if (genericOp.hasTensorSemantics()) {
      newResultTypes = llvm::to_vector(llvm::map_range(
          newOutputOperands, [](Value v) { return v.getType(); }));
    }
    auto newOp = rewriter.create<GenericOp>(
        loc, newResultTypes, newInputOperands, newOutputOperands,
        rewriter.getAffineMapArrayAttr(newIndexingMaps),
        genericOp.iterator_types(), genericOp.docAttr(),
        genericOp.library_callAttr(),
        [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
          return;
        });
    // Copy over unknown attributes. They might be load-bearing for some flow.
    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
    for (NamedAttribute kv : genericOp->getAttrs())
      if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
        newOp->setAttr(kv.getName(), kv.getValue());

    // Fix up the payload of the canonicalized operation.
    populateOpPayload(genericOp, newOp, origInsToNewInsPos,
                      origOutsToNewOutsPos, rewriter);

    // Replace all live uses of the op.
    SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
    for (auto result : llvm::enumerate(genericOp.getResults())) {
      auto it = origOutsToNewOutsPos.find(result.index());
      if (it == origOutsToNewOutsPos.end())
        continue;
      replacementsVals[result.index()] = newOp.getResult(it->second);
    }
    rewriter.replaceOp(genericOp, replacementsVals);
    return success();
  }

private:
  // Deduplicate input operands, and return the
  // - Mapping from operand position in the original op, to operand position in
  //   the canonicalized op.
  // - The preserved input operands list (by reference).
  llvm::SmallDenseMap<unsigned, unsigned>
  deduplicateInputOperands(GenericOp genericOp,
                           SmallVector<OpOperand *> &droppedOpOperands,
                           SmallVector<Value> &newInputOperands,
                           SmallVector<AffineMap> &newIndexingMaps) const {
    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
    for (auto inputOpOperand : llvm::enumerate(genericOp.getInputOperands())) {
      // Check if the operand is dead and if dropping its indexing map would
      // make the loop bound computation invalid.
      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand.value())) {
        // Add the current operand to the list of potentially droppable
        // operands. If it cannot be dropped, this needs to be popped back.
        droppedOpOperands.push_back(inputOpOperand.value());
        if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
          continue;
        droppedOpOperands.pop_back();
      }

      // Check if this operand is a duplicate.
      AffineMap indexingMap =
          genericOp.getTiedIndexingMap(inputOpOperand.value());
      auto it = dedupedInputs.find(
          std::make_pair(inputOpOperand.value()->get(), indexingMap));
      if (it != dedupedInputs.end()) {
        origToNewPos[inputOpOperand.index()] = it->second;
        droppedOpOperands.push_back(inputOpOperand.value());
        continue;
      }

      // This is a preserved argument.
      origToNewPos[inputOpOperand.index()] = newInputOperands.size();
      dedupedInputs[{inputOpOperand.value()->get(), indexingMap}] =
          newInputOperands.size();
      newInputOperands.push_back(inputOpOperand.value()->get());
      newIndexingMaps.push_back(indexingMap);
    }
    return origToNewPos;
  }

  // Deduplicate output operands, and return the
  // - Mapping from operand position in the original op, to operand position in
  //   the canonicalized op.
  // - The preserved output operands list (by reference).
  llvm::SmallDenseMap<unsigned, unsigned>
  deduplicateOutputOperands(GenericOp genericOp,
                            SmallVector<OpOperand *> &droppedOpOperands,
                            SmallVector<Value> &newOutputOperands,
                            SmallVector<AffineMap> &newIndexingMaps) const {
    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
    llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
        dedupedOutputs;
    // If the op doesn't have tensor semantics, keep all the outputs as
    // preserved.
    if (!genericOp.hasTensorSemantics()) {
      for (auto outputOpOperand :
           llvm::enumerate(genericOp.getOutputOperands())) {
        origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
        newOutputOperands.push_back(outputOpOperand.value()->get());
        newIndexingMaps.push_back(
            genericOp.getTiedIndexingMap(outputOpOperand.value()));
      }
    } else {
      // An output argument can be dropped if the result has
      // - no users, and
      // - it is not used in the payload, and
      // - the corresponding indexing maps are not needed for loop bound
      //   computation.
      auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
      for (auto outputOpOperand :
           llvm::enumerate(genericOp.getOutputOperands())) {
        Value result = genericOp.getResult(outputOpOperand.index());
        AffineMap indexingMap =
            genericOp.getTiedIndexingMap(outputOpOperand.value());
        auto key =
            std::make_tuple(outputOpOperand.value()->get(), indexingMap,
                            yieldOp->getOperand(outputOpOperand.index()));

        // Do not drop an output operand if its value is used in the payload.
        if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
          if (result.use_empty()) {
            // Check if the op operand can be dropped without affecting loop
            // bound computation. Add the operand to the list of dropped op
            // operands for checking. If it cannot be dropped, the value needs
            // to be popped back.
            droppedOpOperands.push_back(outputOpOperand.value());
            if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
              continue;
            }
            droppedOpOperands.pop_back();
          }

          // The output operand can also be dropped if it is computed
          // redundantly by another result; the conditions for that are
          // - the same operand is used as the output operand,
          // - the same indexing map is used, and
          // - the same yield value is used.
          auto it = dedupedOutputs.find(key);
          if (it != dedupedOutputs.end()) {
            origToNewPos[outputOpOperand.index()] = it->second;
            droppedOpOperands.push_back(outputOpOperand.value());
            continue;
          }
        }

        origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
        dedupedOutputs[key] = newOutputOperands.size();
        newOutputOperands.push_back(outputOpOperand.value()->get());
        newIndexingMaps.push_back(
            genericOp.getTiedIndexingMap(outputOpOperand.value()));
      }
    }

    return origToNewPos;
  }

  // Populate the body of the canonicalized operation.
  void populateOpPayload(
      GenericOp genericOp, GenericOp newOp,
      const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
      const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
      PatternRewriter &rewriter) const {
    // Merge the body of the original op with the new op.
    Block *newOpBlock = &newOp.region().front();
    assert(newOpBlock->empty() && "expected new op to have an empty payload");
    Block *origOpBlock = &genericOp.region().front();
    SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);

    // Replace all arguments in the original op, with arguments from the
    // canonicalized op.
    auto updateReplacements =
        [&](OpOperandVector &origOperands, OpOperandVector &newOperands,
            const llvm::SmallDenseMap<unsigned, unsigned> &map) {
          for (auto origOperand : llvm::enumerate(origOperands)) {
            auto it = map.find(origOperand.index());
            if (it == map.end())
              continue;
            OpOperand *newOperand = newOperands[it->second];
            replacements[origOperand.value()->getOperandNumber()] =
                newOpBlock->getArgument(newOperand->getOperandNumber());
          }
        };

    OpOperandVector origInputOperands = genericOp.getInputOperands();
    OpOperandVector newInputOperands = newOp.getInputOperands();
    updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);

    OpOperandVector origOutputOperands = genericOp.getOutputOperands();
    OpOperandVector newOutputOperands = newOp.getOutputOperands();
    updateReplacements(origOutputOperands, newOutputOperands,
                       origOutsToNewOutsPos);

    rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);

    // Drop the unused yield args.
    if (newOp.getNumOutputs() != genericOp.getNumOutputs()) {
      OpBuilder::InsertionGuard g(rewriter);
      YieldOp origYieldOp = cast<YieldOp>(newOpBlock->getTerminator());
      rewriter.setInsertionPoint(origYieldOp);

      SmallVector<Value> newYieldVals(newOp.getNumOutputs(), nullptr);
      for (const auto &yieldOpOperands :
           llvm::enumerate(origYieldOp.values())) {
        auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
        if (it == origOutsToNewOutsPos.end())
          continue;
        newYieldVals[it->second] = yieldOpOperands.value();
      }
      rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
    }
  }
};

/// Remove generic operations (on tensors) that are just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
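///
/// For illustration, a sketch (types are hypothetical):
///
///   %0 = linalg.generic {
///       indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
///       iterator_types = ["parallel"]}
///       ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<?xf32>) {
///   ^bb0(%in: f32, %out: f32):
///     linalg.yield %in : f32
///   } -> tensor<?xf32>
///
/// is replaced by %arg0 directly (inserting a tensor.cast or
/// sparse_tensor.convert when the result type differs from the operand type).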
struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Check all indexing maps are identity.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [](AffineMap map) { return !map.isIdentity(); }))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = genericOp.region().front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (genericOp.hasBufferSemantics()) {
      if (genericOp.getNumInputs() == 1 && genericOp.getNumOutputs() == 1 &&
          genericOp.getInputOperand(0)->get() ==
              genericOp.getOutputOperand(0)->get()) {
        rewriter.eraseOp(genericOp);
        return success();
      }
      return failure();
    }

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.values())) {
      auto yieldArg = yieldVal.value().dyn_cast<BlockArgument>();
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = genericOp->getOperand(argumentNumber);
      Type resultType = genericOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion and dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              genericOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              genericOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != genericOp->getNumResults())
      return failure();
    rewriter.replaceOp(genericOp, returnedArgs);
    return success();
  }
};
} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<DeduplicateAndRemoveDeadOperandsAndResults, EraseIdentityGenericOp>(
          context);
}

LogicalResult GenericOp::fold(ArrayRef<Attribute>,
                              SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// InitTensorOp
//===----------------------------------------------------------------------===//

void InitTensorOp::build(OpBuilder &b, OperationState &result,
                         ArrayRef<OpFoldResult> sizes, Type elementType,
                         ArrayRef<NamedAttribute> attrs) {
  SmallVector<Value, 4> dynamicSizes;
  SmallVector<int64_t, 4> staticSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  auto resultType = RankedTensorType::get(staticSizes, elementType);
  build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes));
  result.addAttributes(attrs);
}

LogicalResult InitTensorOp::verify() {
  RankedTensorType resultType = getType();
  SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
      static_sizes().cast<ArrayAttr>(),
      [](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); }));

  if (failed(verifyListOfOperandsOrIntegers(
          *this, "sizes", resultType.getRank(), static_sizes(), sizes(),
          ShapedType::isDynamic)))
    return failure();

  if (static_sizes().size() != static_cast<unsigned>(resultType.getRank()))
    return emitError("expected ") << resultType.getRank() << " sizes values";

  Type expectedType = InitTensorOp::inferResultType(
      staticSizes, resultType.getElementType(), resultType.getEncoding());
  if (resultType != expectedType) {
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  return success();
}

Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
                                   Type elementType, Attribute encoding) {
  return RankedTensorType::get(staticSizes, elementType, encoding);
}

SmallVector<OpFoldResult> InitTensorOp::getMixedSizes() {
  SmallVector<OpFoldResult> mixedSizes;
  mixedSizes.reserve(getType().getRank());
  unsigned dynamicValIndex = 0;
  for (Attribute attr : static_sizes()) {
    auto intAttr = attr.cast<IntegerAttr>();
    if (!ShapedType::isDynamic(intAttr.getInt())) {
      mixedSizes.push_back(intAttr);
      continue;
    }
    mixedSizes.push_back(sizes()[dynamicValIndex++]);
  }
  return mixedSizes;
}

namespace {
/// Change the type of the result of a `linalg.init_tensor` by making the
/// result type statically sized along dimensions that in the original
/// operation were defined as dynamic, but whose size was defined by a
/// `constant` op. For example
///
///  %c5 = arith.constant 5: index
///  %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
///
///  to
///
///  %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
  using OpRewritePattern<InitTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InitTensorOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 4> dynamicSizes;
    SmallVector<int64_t, 4> staticSizes;
    for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
      // If the size is already static, nothing to do.
      if (!op.isDynamicSize(i)) {
        staticSizes.push_back(op.getStaticSize(i));
        continue;
      }

      // If the size is dynamic but defined using a `constant` op, get the
      // constant value to find the static size to use.
      unsigned operandNum = op.getIndexOfDynamicSize(i);
      Value sizeOperand = op.getOperand(operandNum);
      if (auto constantIndexOp =
              sizeOperand.getDefiningOp<arith::ConstantIndexOp>()) {
        staticSizes.push_back(constantIndexOp.value());
        continue;
      }

      // Fallback case. Keep the size dynamic.
      dynamicSizes.push_back(sizeOperand);
      staticSizes.push_back(ShapedType::kDynamicSize);
    }
    RankedTensorType newType =
        RankedTensorType::get(staticSizes, op.getType().getElementType());
    if (newType == op.getType())
      return failure();
    auto newOp =
        rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
                                      rewriter.getI64ArrayAttr(staticSizes));
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
} // namespace

namespace {
/// Since the `init_tensor` operation creates a tensor needed only for its
/// shape, a slice of it is also needed only for its shape. The result can
/// therefore be replaced by a new `init_tensor` operation of the same size as
/// the extract slice op.
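///
/// For illustration, a sketch with hypothetical shapes:
///
///   %0 = linalg.init_tensor [%d0, 16] : tensor<?x16xf32>
///   %1 = tensor.extract_slice %0[0, 0] [%sz, 8] [1, 1]
///       : tensor<?x16xf32> to tensor<?x8xf32>
///
/// folds into
///
///   %1 = linalg.init_tensor [%sz, 8] : tensor<?x8xf32>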
1303 struct FoldInitTensorWithExtractSliceOp
1304 : public OpRewritePattern<tensor::ExtractSliceOp> {
1305 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
1306
1307 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
1308 PatternRewriter &rewriter) const override {
1309 if (!sliceOp.getSource().getDefiningOp<linalg::InitTensorOp>())
1310 return failure();
1311 // ExtractSliceOp may be rank-reducing; its dynamic sizes must be preserved
1312 // as well as its result type.
1313 rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
1314 sliceOp, sliceOp.getSizes(),
1315 sliceOp.getResult().getType().cast<RankedTensorType>().getShape(),
1316 sliceOp.getSourceType().getElementType());
1317 return success();
1318 }
1319 };
1320
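/// A reshape (`tensor.expand_shape` / `tensor.collapse_shape`) of an
/// `init_tensor` is likewise needed only for its shape, so it can be replaced
/// by an `init_tensor` of the reshaped size. A sketch (shapes are
/// illustrative):
///
/// ```mlir
///   %0 = linalg.init_tensor [6, %d] : tensor<6x?xf32>
///   %1 = tensor.expand_shape %0 [[0, 1], [2]]
///       : tensor<6x?xf32> into tensor<2x3x?xf32>
/// ```
///
/// becomes
///
/// ```mlir
///   %1 = linalg.init_tensor [2, 3, %d] : tensor<2x3x?xf32>
/// ```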
1321 template <typename TensorReshapeOp>
1322 struct FoldInitTensorWithTensorReshapeOp
1323 : public OpRewritePattern<TensorReshapeOp> {
1324 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1325
1326 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1327 PatternRewriter &rewriter) const override {
1328 if (!reshapeOp.getSrc().template getDefiningOp<InitTensorOp>())
1329 return failure();
1330 Location loc = reshapeOp.getLoc();
1331 ReifiedRankedShapedTypeDims resultShapes;
1332 ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
1333 cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
1334 if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
1335 resultShapes)) ||
1336 !llvm::hasSingleElement(resultShapes))
1337 return failure();
1338 Value initTensor = rewriter.create<InitTensorOp>(
1339 loc, getAsOpFoldResult(resultShapes[0]),
1340 reshapeOp.getResultType().getElementType());
1341 if (initTensor.getType() != reshapeOp.getResultType()) {
1342 rewriter.replaceOpWithNewOp<tensor::CastOp>(
1343 reshapeOp, reshapeOp.getResultType(), initTensor);
1344 } else {
1345 rewriter.replaceOp(reshapeOp, initTensor);
1346 }
1347 return success();
1348 }
1349 };
1350
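/// A `tensor.dim` of a dynamic dimension of an `init_tensor` folds to the size
/// operand that was passed for that dimension. A sketch (values are
/// illustrative):
///
/// ```mlir
///   %0 = linalg.init_tensor [%s, 10] : tensor<?x10xf32>
///   %d0 = tensor.dim %0, %c0 : tensor<?x10xf32>
/// ```
///
/// folds `%d0` to `%s` (only the dynamic dimension 0 qualifies here).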
1351 struct FoldInitTensorWithDimOp : public OpRewritePattern<tensor::DimOp> {
1352 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1353
1354 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1355 PatternRewriter &rewriter) const override {
1356 Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1357 auto initTensorOp = dimOp.getSource().getDefiningOp<linalg::InitTensorOp>();
1358 if (!initTensorOp || !maybeConstantIndex)
1359 return failure();
1360 if (!initTensorOp.isDynamicSize(*maybeConstantIndex))
1361 return failure();
1362 rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(*maybeConstantIndex));
1363 return success();
1364 }
1365 };
1366
1367 /// Canonicalize
1368 ///
1369 /// ```mlir
1370 /// %0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
1371 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1372 /// ```
1373 ///
1374 /// into
1375 ///
1376 /// ```mlir
1377 /// %0 = linalg.init_tensor [4, %d1] : tensor<4x?xf32>
1378 /// ```
1379 ///
1380 /// This assumes the input program is correct in terms of its shape. So it
1381 /// is safe to assume that `%d0` is in fact 4. If that were not the case, the
1382 /// input program would be wrong to begin with, so it is undefined behavior anyway
1383 /// (i.e. this optimization can still trigger without violating program semantics).
1384 struct FoldInitTensorWithTensorCastOp
1385 : public OpRewritePattern<tensor::CastOp> {
1386 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1387
1388 LogicalResult matchAndRewrite(tensor::CastOp castOp,
1389 PatternRewriter &rewriter) const override {
1390 if (!canFoldIntoProducerOp(castOp))
1391 return failure();
1392 auto producer = castOp.getSource().getDefiningOp<InitTensorOp>();
1393 if (!producer)
1394 return failure();
1395
1396 auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
1397 ArrayRef<int64_t> resultShape = resultType.getShape();
1398 SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1399 SmallVector<OpFoldResult> newMixedSizes;
1400 newMixedSizes.reserve(currMixedSizes.size());
1401 assert(resultShape.size() == currMixedSizes.size() &&
1402 "mismatch in result shape and sizes of init_tensor op");
1403 for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1404 int64_t newDim = std::get<0>(it);
1405 OpFoldResult currDim = std::get<1>(it);
1406 // Case 1: The init tensor dim is static. Check that the tensor cast
1407 // result dim matches.
1408 if (auto attr = currDim.dyn_cast<Attribute>()) {
1409 if (ShapedType::isDynamic(newDim) ||
1410 newDim != attr.cast<IntegerAttr>().getInt()) {
1411 // Something is off, the cast result shape cannot be more dynamic than
1412 // the init tensor result shape (enforced by `canFoldIntoProducerOp`).
1413 // Abort for now.
1414 return rewriter.notifyMatchFailure(
1415 producer, "mismatch in static value of shape of init "
1416 "tensor result and cast result");
1417 }
1418 newMixedSizes.push_back(attr);
1419 continue;
1420 }
1421
1422 // Case 2 : The tensor cast shape is static, but init tensor result shape
1423 // is dynamic.
1424 if (!ShapedType::isDynamic(newDim)) {
1425 newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1426 continue;
1427 }
1428
1429 // Case 3 : The tensor cast shape is dynamic and init tensor result shape
1430 // is dynamic. Use the dynamic value from the init tensor op.
1431 newMixedSizes.push_back(currDim);
1432 }
1433
1434 rewriter.replaceOpWithNewOp<InitTensorOp>(castOp, newMixedSizes,
1435 resultType.getElementType());
1436 return success();
1437 }
1438 };
1439
1440 } // namespace
1441
1442 void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
1443 MLIRContext *context) {
1444 results.add<FoldInitTensorWithTensorCastOp, FoldInitTensorWithDimOp,
1445 FoldInitTensorWithExtractSliceOp,
1446 FoldInitTensorWithTensorReshapeOp<tensor::ExpandShapeOp>,
1447 FoldInitTensorWithTensorReshapeOp<tensor::CollapseShapeOp>,
1448 ReplaceStaticShapeDims>(context);
1449 }
1450
1451 LogicalResult InitTensorOp::reifyResultShapes(
1452 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1453 auto shapes = llvm::to_vector<4>(llvm::map_range(
1454 llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
1455 if (isDynamicSize(dim))
1456 return getDynamicSize(dim);
1457 return builder.create<arith::ConstantIndexOp>(getLoc(),
1458 getStaticSize(dim));
1459 }));
1460 reifiedReturnShapes.emplace_back(std::move(shapes));
1461 return success();
1462 }
1463
1464 //===----------------------------------------------------------------------===//
1465 // YieldOp
1466 //===----------------------------------------------------------------------===//
1467
1468 void linalg::YieldOp::print(OpAsmPrinter &p) {
1469 if (getNumOperands() > 0)
1470 p << ' ' << getOperands();
1471 p.printOptionalAttrDict((*this)->getAttrs());
1472 if (getNumOperands() > 0)
1473 p << " : " << getOperandTypes();
1474 }
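// For reference, the custom assembly produced above (and consumed by the
// parser below) reads, e.g. (operand names are illustrative):
//   linalg.yield %sum : f32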
1475
1476 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
1477 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
1478 SmallVector<Type, 2> types;
1479 SMLoc loc = parser.getCurrentLocation();
1480 return failure(parser.parseOperandList(opInfo) ||
1481 parser.parseOptionalAttrDict(result.attributes) ||
1482 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
1483 parser.resolveOperands(opInfo, types, loc, result.operands));
1484 }
1485
1486 // Check that the number and types of the yielded operands match the element
1487 // types of the LinalgOp interface's shaped output operands.
1488 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
1489 if (op.getNumOperands() != linalgOp.getNumOutputs())
1490 return op.emitOpError("expected number of yield values (")
1491 << linalgOp.getNumOutputs()
1492 << ") to match the number of operands of the enclosing "
1493 << "LinalgOp (" << op.getNumOperands() << ")";
1494
1495 for (OpOperand &opOperand : op->getOpOperands()) {
1496 OpOperand *outputOperand =
1497 linalgOp.getOutputOperand(opOperand.getOperandNumber());
1498 Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
1499 if (opOperand.get().getType() != elementType)
1500 return op.emitOpError("type of yield operand ")
1501 << (opOperand.getOperandNumber() + 1) << " ("
1502 << opOperand.get().getType() << ") doesn't match "
1503 << "the element type of the enclosing linalg.generic op ("
1504 << elementType << ")";
1505 }
1506 return success();
1507 }
1508
1509 LogicalResult linalg::YieldOp::verify() {
1510 auto *parentOp = (*this)->getParentOp();
1511 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
1512 return emitOpError("expected single non-empty parent region");
1513
1514 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
1515 return verifyYield(*this, linalgOp);
1516
1517 return emitOpError("expected parent op with LinalgOp interface");
1518 }
1519
1520 //===----------------------------------------------------------------------===//
1521 // IndexOp
1522 //===----------------------------------------------------------------------===//
1523
1524 LogicalResult IndexOp::verify() {
1525 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
1526 if (!linalgOp)
1527 return emitOpError("expected parent op with LinalgOp interface");
1528 if (linalgOp.getNumLoops() <= dim())
1529 return emitOpError("expected dim (")
1530 << dim() << ") to be lower than the number of loops ("
1531 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
1532 return success();
1533 }
1534
1535 /////// Operations corresponding to library calls defined with Tablegen ////////
1536
1537 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
1538
1539 #define GET_OP_CLASSES
1540 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
1541
1542 #define GET_OP_CLASSES
1543 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
1544
1545 /// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
1546 /// Assumes `op` is a LinalgOp.
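/// For example, for a matmul-like op whose `iterator_types` are
/// ["parallel", "parallel", "reduction"], querying "reduction" appends {2} to
/// `res` (an illustrative case, not tied to a specific op).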
1547 void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName,
1548 SmallVectorImpl<unsigned> &res) {
1549 if (!cast<LinalgOp>(op).iterator_types())
1550 return;
1551
1552 unsigned dim = 0;
1553 for (auto tn :
1554 cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) {
1555 if (tn == iteratorTypeName)
1556 res.push_back(dim);
1557 ++dim;
1558 }
1559 }
1560
1561 AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap,
1562 unsigned rank,
1563 MLIRContext *context) {
1564 if (maybeMap)
1565 return *maybeMap;
1566 if (rank == 0)
1567 return AffineMap::get(context);
1568 return AffineMap::getMultiDimIdentityMap(rank, context);
1569 }
1570
1571 SmallVector<AffineExpr, 4>
1572 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
1573 MLIRContext *context) {
1574 SmallVector<AffineExpr, 4> res;
1575 res.reserve(num);
1576 for (unsigned i = 0; i < num; ++i)
1577 res.push_back(getAffineDimExpr(startIdx++, context));
1578 return res;
1579 }
1580
1581 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
1582 ArrayRef<AffineExpr> b) {
1583 auto rangeA = llvm::make_range(a.begin(), a.end());
1584 auto rangeB = llvm::make_range(b.begin(), b.end());
1585 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
1586 return llvm::to_vector<4>(concatRanges);
1587 }
1588
1589 static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
1590 if (auto memref = t.dyn_cast<MemRefType>()) {
1591 ss << "view";
1592 for (auto size : memref.getShape())
1593 if (size < 0)
1594 ss << "sx";
1595 else
1596 ss << size << "x";
1597 appendMangledType(ss, memref.getElementType());
1598 } else if (auto vec = t.dyn_cast<VectorType>()) {
1599 ss << "vector";
1600 llvm::interleave(
1601 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
1602 appendMangledType(ss, vec.getElementType());
1603 } else if (t.isSignlessIntOrIndexOrFloat()) {
1604 ss << t;
1605 } else {
1606 llvm_unreachable("Invalid type for linalg library name mangling");
1607 }
1608 }
1609
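/// Returns the name of the library call that this op maps to, mangled from the
/// op name and the operand types via `appendMangledType` above. An illustrative
/// example (assuming a linalg.matmul whose three operands are memref<?x?xf32>):
///   linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32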
1610 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
1611 assert(isa<LinalgOp>(op));
1612 std::string name(op->getName().getStringRef().str());
1613 name.reserve(128);
1614 std::replace(name.begin(), name.end(), '.', '_');
1615 llvm::raw_string_ostream ss(name);
1616 ss << "_";
1617 auto types = op->getOperandTypes();
1618 llvm::interleave(
1619 types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
1620 [&]() { ss << "_"; });
1621 return ss.str();
1622 }
1623
1624 //===----------------------------------------------------------------------===//
1625 // Canonicalizers and Folders.
1626 //===----------------------------------------------------------------------===//
1627
1628 namespace {
1629 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
1630 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
1631
1632 LogicalResult matchAndRewrite(LinalgOp op,
1633 PatternRewriter &rewriter) const override {
1634 for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
1635 // Linalg "inputs" may be either tensor or memref type.
1636 // tensor<0xelt_type> is a convention that may not always mean
1637 // "0 iterations". Only erase in cases we see memref<...x0x...>.
1638 auto mt = opOperand->get().getType().dyn_cast<MemRefType>();
1639 if (!mt)
1640 continue;
1641 if (llvm::is_contained(op.getShape(opOperand), 0)) {
1642 rewriter.eraseOp(op);
1643 return success();
1644 }
1645 }
1646 return failure();
1647 }
1648 };
1649
1650 struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
1651 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
1652
1653 LogicalResult matchAndRewrite(LinalgOp op,
1654 PatternRewriter &rewriter) const override {
1655 // If no operand comes from a tensor::CastOp that can be folded, then fail.
1656 bool hasTensorCastOperand =
1657 llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
1658 if (opOperand->get().isa<BlockArgument>())
1659 return false;
1660 auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
1661 return castOp && canFoldIntoConsumerOp(castOp);
1662 });
1663 if (!hasTensorCastOperand)
1664 return failure();
1665
1666 SmallVector<Type, 4> newResultTypes;
1667 newResultTypes.reserve(op->getNumResults());
1668 SmallVector<Value, 4> newOperands;
1669 newOperands.reserve(op->getNumOperands());
1670 // Inputs may fold.
1671 for (OpOperand *opOperand : op.getInputOperands()) {
1672 auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
1673 newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
1674 ? tensorCastOp.getSource()
1675 : opOperand->get());
1676 }
1677 // Init tensors may fold, in which case the resultType must also change.
1678 for (OpOperand *opOperand : op.getOutputOperands()) {
1679 auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
1680 bool fold = canFoldIntoConsumerOp(tensorCastOp);
1681 newOperands.push_back(fold ? tensorCastOp.getOperand()
1682 : opOperand->get());
1683 newResultTypes.push_back(newOperands.back().getType());
1684 }
1685 // Clone op.
1686 Operation *newOp =
1687 op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
1688 SmallVector<Value, 4> replacements;
1689 replacements.reserve(newOp->getNumResults());
1690 for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
1691 Value oldResult = std::get<0>(result);
1692 Value newResult = std::get<1>(result);
1693 if (newResult.getType() != oldResult.getType()) {
1694 replacements.push_back(rewriter.create<tensor::CastOp>(
1695 op->getLoc(), oldResult.getType(), newResult));
1696 } else {
1697 replacements.push_back(newResult);
1698 }
1699 }
1700 rewriter.replaceOp(op, replacements);
1701
1702 return success();
1703 }
1704 };
1705
1706 /// Fold a LinalgOp with a `tensor.cast` consumer if the `tensor.cast` has a
1707 /// result that is more static than the linalg op.
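/// A sketch of the rewrite (types are illustrative and the generic op body is
/// elided):
///
/// ```mlir
///   %0 = linalg.generic ... outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
/// ```
///
/// becomes
///
/// ```mlir
///   %out = tensor.cast %init : tensor<?x?xf32> to tensor<4x8xf32>
///   %1 = linalg.generic ... outs(%out : tensor<4x8xf32>) -> tensor<4x8xf32>
/// ```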
1708 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
1709 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1710
1711 LogicalResult matchAndRewrite(tensor::CastOp castOp,
1712 PatternRewriter &rewriter) const override {
1713 if (!tensor::canFoldIntoProducerOp(castOp))
1714 return failure();
1715
1716 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
1717 if (!linalgOp)
1718 return failure();
1719
1720 // The cast can be in a conditionally reachable region, in which case folding
1721 // will generate invalid code. Only conservatively fold ops in the same block
1722 // for now.
1723 if (castOp->getBlock() != linalgOp->getBlock())
1724 return failure();
1725
1726 OpBuilder::InsertionGuard guard(rewriter);
1727 rewriter.setInsertionPoint(linalgOp);
1728
1729 Location loc = linalgOp.getLoc();
1730 OpResult resultValue = castOp.getSource().cast<OpResult>();
1731 unsigned resultNumber = resultValue.getResultNumber();
1732 auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
1733 // Replace the `outs` for the result with a `tensor.cast`. This cast is now
1734 // going from a more dynamic shape to a less dynamic shape. If the producer
1735 // for this cast, i.e. producer of the out operand, is also an operation
1736 // that folds with tensor.cast consumer (like this pattern), the cast will
1737 // continue to propagate as far up the stack as it can go.
1738 OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
1739 Value newOperand =
1740 rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
1741 SmallVector<Value> newOperands = linalgOp.getInputOperands();
1742 SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
1743 outputOperands[resultNumber] = newOperand;
1744 newOperands.append(outputOperands.begin(), outputOperands.end());
1745
1746 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
1747 linalgOp->result_type_end());
1748 resultTypes[resultNumber] = resultType;
1749 Operation *newOp = linalgOp.clone(rewriter, loc, resultTypes, newOperands);
1750
1751 // Create a tensor.cast operation back to the original type.
1752 Value castBack = rewriter.create<tensor::CastOp>(
1753 loc, resultValue.getType(), newOp->getResult(resultNumber));
1754
1755 SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
1756 results[resultNumber] = castBack;
1757 rewriter.replaceOp(linalgOp, results);
1758 rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
1759 return success();
1760 }
1761 };
1762
1763 /// For each operand in `operands`, this function maps the affine dim expressions
1764 /// of its indexing map to the statically known sizes of the corresponding dimensions.
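/// For instance (an illustrative case), an operand of type tensor<4x?xf32>
/// whose indexing map is affine_map<(d0, d1) -> (d0, d1)> contributes the
/// entry d0 -> 4; nothing is recorded for the dynamic dimension d1.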
1765 static void populateMap(LinalgOp linalgOp, ArrayRef<OpOperand *> operands,
1766 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
1767 for (OpOperand *opOperand : operands) {
1768 if (linalgOp.isScalar(opOperand))
1769 continue;
1770 Value src = opOperand->get();
1771 auto sourceType = src.getType().cast<RankedTensorType>();
1772 auto sourceMap = linalgOp.getTiedIndexingMap(opOperand);
1773
1774 // Get the `sourceShape` of the `sourceType`. If the operand is a result of a
1775 // `tensor.cast` operation and the source of the cast operation has a static
1776 // shape, then assign that shape to `sourceShape`.
1777 auto *parentOp = src.getDefiningOp();
1778 ArrayRef<int64_t> sourceShape = sourceType.getShape();
1779 if (parentOp) {
1780 if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
1781 Value castSource = castOp.getSource();
1782 auto castSourceType = castSource.getType().cast<RankedTensorType>();
1783 if (castSourceType.hasStaticShape())
1784 sourceShape = castSourceType.getShape();
1785 }
1786 }
1787
1788 // If a dimension of the source shape is statically known, map its affine dim
1789 // expression to the known static size.
1790 for (unsigned i = 0; i < sourceShape.size(); i++) {
1791 if (sourceType.isDynamicDim(i))
1792 continue;
1793 if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
1794 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
1795 }
1796 }
1797 }
1798
1799 /// Creates a new operand w.r.t. 'opOperand' of `linalgOp` with the static sizes
1800 /// mapped in `affineExprToSize`. New operands are created in `newOperands` and
1801 /// their result types are stored in `resultTypes`. If `opOperand` requires no
1802 /// change then `changeNeeded` is left untouched and the same operand is added to
1803 /// the `newOperands` list.
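/// For instance (illustrative values), if `affineExprToSize` contains d0 -> 4,
/// an operand of type tensor<?x8xf32> indexed by
/// affine_map<(d0, d1) -> (d0, d1)> is replaced by
/// `tensor.cast %operand : tensor<?x8xf32> to tensor<4x8xf32>` and the new
/// static type is recorded.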
1804 static void createNewOperandWithStaticSizes(
1805 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
1806 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
1807 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
1808 bool &changeNeeded) {
1809 Value src = opOperand->get();
1810 newOperands.push_back(src);
1811 if (linalgOp.isScalar(opOperand))
1812 return;
1813 auto sourceType = src.getType().cast<RankedTensorType>();
1814 Type resultType = sourceType;
1815 if (sourceType.hasStaticShape() && linalgOp.isOutputTensor(opOperand)) {
1816 resultTypes.push_back(resultType);
1817 return;
1818 }
1819 ArrayRef<int64_t> sourceShape = sourceType.getShape();
1820 AffineMap sourceMap = linalgOp.getTiedIndexingMap(opOperand);
1821 SmallVector<int64_t> newShape;
1822 // If the operand is updated with a new shape, `newOperandNeeded` will be
1823 // set to true.
1824 bool newOperandNeeded = false;
1825 for (unsigned i = 0; i < sourceShape.size(); i++) {
1826 int64_t dimShape = sourceShape[i];
1827 AffineExpr dimExpr = sourceMap.getResult(i);
1828 if (affineExprToSize.find(dimExpr) == affineExprToSize.end() ||
1829 !sourceType.isDynamicDim(i)) {
1830 newShape.push_back(dimShape);
1831 continue;
1832 }
1833 // Dimension has a dynamic shape and corresponding affine dim
1834 // expression is present in the map. So assign the size for the
1835 // given affine dim expression to the dimension.
1836 newShape.push_back(affineExprToSize[dimExpr]);
1837 newOperandNeeded = true;
1838 }
1839 resultType = RankedTensorType::get(newShape, sourceType.getElementType());
1840 if (newOperandNeeded) {
1841 changeNeeded = true;
1842 // Get the new operand value given its size and element type by
1843 // casting it.
1844 Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
1845 unsigned index = opOperand->getOperandNumber();
1846 newOperands[index] = newOperand;
1847 }
1848 if (linalgOp.isOutputTensor(opOperand))
1849 resultTypes.push_back(resultType);
1850 }
1851
1852 /// Static shapes for the operands can be inferred if any one of the operands
1853 /// has a static shape. This is done by referring to the affine dim
1854 /// expressions of the operands' indexing maps.
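/// A sketch of the effect (shapes and names are illustrative): for a
/// `linalg.generic` with identity indexing maps and operands
/// `ins(%a : tensor<4x8xf32>) outs(%b : tensor<?x?xf32>)`, the pattern casts
/// `%b` to `tensor<4x8xf32>`, clones the op with the now-static result type,
/// and inserts a `tensor.cast` back to `tensor<?x?xf32>` for the original
/// uses.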
1855 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
1856 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
1857
1858 LogicalResult matchAndRewrite(LinalgOp linalgOp,
1859 PatternRewriter &rewriter) const override {
1860 if (!linalgOp.hasTensorSemantics())
1861 return failure();
1862
1863 // Maps must be projected permutations.
1864 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
1865 return !map.isProjectedPermutation();
1866 }))
1867 return failure();
1868
1869 // Maps affine dim expressions to the static size of that dimension.
1870 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
1871 Location loc = linalgOp.getLoc();
1872
1873 // For each affine dim expression, check if the size is known. If it is,
1874 // add it to the map.
1875 populateMap(linalgOp, linalgOp.getInputAndOutputOperands(),
1876 affineExprToSize);
1877
1878 SmallVector<Value> newOperands;
1879 SmallVector<Type> resultTypes;
1880
1881 // `changeNeeded` is `false` if the operands of `linalgOp` require no
1882 // change in their types.
1883 bool changeNeeded = false;
1884 newOperands.reserve(linalgOp.getNumInputsAndOutputs());
1885 resultTypes.reserve(linalgOp.getNumOutputs());
1886
1887 // Iterate over all the operands and update the static sizes.
1888 for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
1889 createNewOperandWithStaticSizes(loc, rewriter, opOperand,
1890 affineExprToSize, linalgOp, newOperands,
1891 resultTypes, changeNeeded);
1892 }
1893
1894 // If the generic op has all the required static information, no
1895 // canonicalization needed.
1896 if (!changeNeeded)
1897 return failure();
1898
1899 // Clone op.
1900 Operation *newOp =
1901 linalgOp.clone(rewriter, linalgOp->getLoc(), resultTypes, newOperands);
1902 SmallVector<Value> replacements;
1903 replacements.reserve(newOp->getNumResults());
1904 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
1905 Value newResult = std::get<1>(it);
1906 Value oldResult = std::get<0>(it);
1907 Type newType = newResult.getType();
1908 Type oldType = oldResult.getType();
1909 replacements.push_back(
1910 (newType != oldType)
1911 ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
1912 : newResult);
1913 }
1914 rewriter.replaceOp(linalgOp, replacements);
1915 return success();
1916 }
1917 };
1918
1919 } // namespace
1920
1921 // All named ops' canonicalizers and folders are auto-generated in the
1922 // .cpp.inc.
1923
1924 //===----------------------------------------------------------------------===//
1925 // LinalgDialect
1926 //===----------------------------------------------------------------------===//
1927
1928 void LinalgDialect::getCanonicalizationPatterns(
1929 RewritePatternSet &results) const {
1930 results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
1931 FoldTensorCastProducerOp, InferStaticShapeOfOperands>(
1932 getContext());
1933 }
1934
1935 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
1936 Attribute value, Type type,
1937 Location loc) {
1938 return builder.create<arith::ConstantOp>(loc, type, value);
1939 }
1940