1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
10 
11 #include "mlir/Dialect/Linalg/IR/Linalg.h"
12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13 #include "mlir/Dialect/PDL/IR/PDL.h"
14 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
15 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
16 #include "mlir/Parser/Parser.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 
19 using namespace mlir;
20 using namespace mlir::linalg;
21 using namespace mlir::transform;
22 
23 /// Extracts a vector of int64_t from an array attribute. Asserts if the
24 /// attribute contains values other than integers.
25 static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
26   SmallVector<int64_t> result;
27   result.reserve(attr.size());
28   for (APInt value : attr.getAsValueRange<IntegerAttr>())
29     result.push_back(value.getSExtValue());
30   return result;
31 }
32 
33 /// Extracts a vector of unsigned from an array attribute. Asserts if the
34 /// attribute contains values other than intergers. May truncate.
35 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
36   SmallVector<unsigned> result;
37   result.reserve(attr.size());
38   for (APInt value : attr.getAsValueRange<IntegerAttr>())
39     result.push_back(value.getZExtValue());
40   return result;
41 }
42 
43 namespace {
44 /// A simple pattern rewriter that implements no special logic.
45 class SimpleRewriter : public PatternRewriter {
46 public:
47   SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
48 };
49 } // namespace
50 
51 /// Attempts to apply the pattern specified as template argument to the given
52 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
53 /// function that returns the "main" result or failure. Returns failure if the
54 /// pattern failed to apply. Extra arguments are forwarded to the pattern
55 /// constructor.
56 template <typename PatternTy, typename... Args>
57 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
58   // Check if the given operation has the type expected by the pattern.
59   using OpTy = typename llvm::function_traits<
60       decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
61   auto op = dyn_cast<OpTy>(operation);
62   if (!op)
63     return failure();
64 
65   // Apply the pattern directly to the op.
66   PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
67   SimpleRewriter rewriter(operation->getContext());
68   rewriter.setInsertionPoint(operation);
69   auto result = pattern.returningMatchAndRewrite(op, rewriter);
70   if (failed(result))
71     return failure();
72   return cast<LinalgOp>(result->getOperation());
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // DecomposeOp
77 //===----------------------------------------------------------------------===//
78 
79 FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target) {
80   FailureOr<LinalgOp> windowed =
81       tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
82   if (succeeded(windowed))
83     return windowed;
84 
85   FailureOr<LinalgOp> depthwise =
86       tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
87   if (succeeded(depthwise))
88     return depthwise;
89 
90   return reportUnknownTransformError(target);
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // FuseOp
95 //===----------------------------------------------------------------------===//
96 
97 /// Apply a tiling transformation to all payload ops and store both the
98 /// tiled operation as well as the created tile loops.
99 static LogicalResult
100 applyTilingToAll(Operation *transformOp, Value target,
101                  ArrayRef<int64_t> tileSizes,
102                  transform::TransformResults &transformResults,
103                  transform::TransformState &state,
104                  function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
105   // Number of loops: Number of tiles sizes that are not zero.
106   size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
107   // All payload ops. These should all be LinalgOps for now.
108   ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);
109 
110   SmallVector<Operation *> tiledLinalgOps;
111   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
112   for (unsigned int i = 0; i < numLoops; ++i)
113     loopOps[i].reserve(payloadOps.size());
114 
115   for (Operation *target : payloadOps) {
116     auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
117     if (!linalgOp)
118       return transformOp->emitError("only LinalgOps are supported");
119 
120     FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
121     if (failed(tiled))
122       return failure();
123 
124     tiledLinalgOps.push_back(tiled->op);
125     if (tiled->loops.size() != numLoops)
126       // Not enough loops were generated. This usually means that the input size
127       // was smaller than the tiling size.
128       // TODO: LinalgTilingPattern should return failure().
129       return failure();
130     for (unsigned int i = 0; i < numLoops; ++i)
131       loopOps[i].push_back(tiled->loops[i]);
132   }
133 
134   transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
135   for (unsigned int i = 0; i < numLoops; ++i)
136     transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
137   return success();
138 }
139 
140 /// Parse a tiling-like operation that returns the tiled op as well as the
141 /// created tile loops. The function counts the non-zero tile sizes to compute
142 /// the number of results.
143 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
144                                    StringRef sizesAttrName) {
145   OpAsmParser::UnresolvedOperand targetOperand;
146   SMLoc opLoc = parser.getCurrentLocation();
147   if (parser.parseOperand(targetOperand) ||
148       parser.parseOptionalAttrDict(result.attributes))
149     return failure();
150   Attribute sizesAttr = result.attributes.get(sizesAttrName);
151   if (!sizesAttr)
152     return parser.emitError(opLoc)
153            << "expected '" << sizesAttrName << "' attribute";
154   auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
155   if (!sizesArrayAttr)
156     return parser.emitError(opLoc)
157            << "'" << sizesAttrName << "' attribute must be an array";
158   Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
159   size_t numExpectedLoops =
160       sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
161   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
162   if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
163     return failure();
164   return success();
165 }
166 
167 DiagnosedSilencableFailure
168 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
169                          mlir::transform::TransformState &state) {
170   LinalgTilingAndFusionOptions fusionOptions;
171   fusionOptions.tileSizes = extractI64Array(getTileSizes());
172   fusionOptions.tileInterchange = extractI64Array(getTileInterchange());
173 
174   LogicalResult result = applyTilingToAll(
175       getOperation(), getTarget(), fusionOptions.tileSizes, transformResults,
176       state, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
177         LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
178         SimpleRewriter rewriter(getContext());
179         rewriter.setInsertionPoint(linalgOp);
180         FailureOr<TileLoopNest> tileLoopNest =
181             pattern.returningMatchAndRewrite(linalgOp, rewriter);
182         if (failed(tileLoopNest))
183           return failure();
184 
185         TiledLinalgOp tiledLinalgOp;
186         tiledLinalgOp.op = tileLoopNest->getRootOp();
187         tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
188                                tileLoopNest->getLoopOps().end()};
189         return tiledLinalgOp;
190       });
191   return failed(result) ? DiagnosedSilencableFailure::definiteFailure()
192                         : DiagnosedSilencableFailure::success();
193 }
194 
195 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
196                                      OperationState &result) {
197   return parseTileLikeOp(
198       parser, result,
199       transform::FuseOp::getTileSizesAttrName(result.name).getValue());
200 }
201 
202 void transform::FuseOp::print(OpAsmPrinter &p) {
203   p << ' ';
204   p << getTarget();
205   p.printOptionalAttrDict((*this)->getAttrs());
206 }
207 
208 LogicalResult transform::FuseOp::verify() {
209   SmallVector<int64_t> permutation = extractI64Array(getTileInterchange());
210   auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
211   if (!std::is_permutation(sequence.begin(), sequence.end(),
212                            permutation.begin(), permutation.end())) {
213     return emitOpError() << "expects interchange to be a permutation, found "
214                          << getTileInterchange();
215   }
216   return success();
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // GeneralizeOp
221 //===----------------------------------------------------------------------===//
222 
223 FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) {
224   // Exit early if no transformation is needed.
225   if (isa<GenericOp>(target))
226     return target;
227 
228   FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
229   if (succeeded(generic))
230     return generic;
231 
232   return reportUnknownTransformError(target);
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // InterchangeOp
237 //===----------------------------------------------------------------------===//
238 
239 FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) {
240   SmallVector<unsigned> interchangeVector =
241       extractUIntArray(getIteratorInterchange());
242   // Exit early if no transformation is needed.
243   if (interchangeVector.empty())
244     return target;
245 
246   auto genericTarget = dyn_cast<GenericOp>(target.getOperation());
247   if (!genericTarget) {
248     InFlightDiagnostic diag = emitOpError()
249                               << "applies to " << GenericOp::getOperationName()
250                               << " ops";
251     diag.attachNote(target.getLoc()) << "attempted to apply to this op";
252     return diag;
253   }
254 
255   return tryApply<GenericOpInterchangePattern>(target, interchangeVector);
256 }
257 
258 LogicalResult transform::InterchangeOp::verify() {
259   SmallVector<unsigned> permutation =
260       extractUIntArray(getIteratorInterchange());
261   auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
262   if (!std::is_permutation(sequence.begin(), sequence.end(),
263                            permutation.begin(), permutation.end())) {
264     return emitOpError()
265            << "expects iterator_interchange to be a permutation, found "
266            << getIteratorInterchange();
267   }
268   return success();
269 }
270 
271 //===---------------------------------------------------------------------===//
272 // PadOp
273 //===---------------------------------------------------------------------===//
274 
275 FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) {
276   // Convert the integer packing flags to booleans.
277   SmallVector<bool> packPaddings;
278   for (int64_t packPadding : extractI64Array(getPackPaddings()))
279     packPaddings.push_back(static_cast<bool>(packPadding));
280 
281   // Convert the padding values to attributes.
282   SmallVector<Attribute> paddingValues;
283   for (auto const &it :
284        llvm::zip(getPaddingValues(), target->getOperandTypes())) {
285     Attribute attr = std::get<0>(it);
286     Type elementType = getElementTypeOrSelf(std::get<1>(it));
287     // Try to parse string attributes to obtain an attribute of element type.
288     if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
289       paddingValues.push_back(
290           parseAttribute(attr.cast<StringAttr>(), elementType));
291       if (!paddingValues.back()) {
292         InFlightDiagnostic diag = emitOpError()
293                                   << "expects a padding value that parses to "
294                                   << elementType << ", got " << std::get<0>(it);
295         diag.attachNote(target.getLoc()) << "when applied to this op";
296         return diag;
297       }
298       continue;
299     }
300     // Otherwise, add the attribute directly.
301     if (attr.getType() != elementType) {
302       InFlightDiagnostic diag = emitOpError()
303                                 << "expects a padding value of type "
304                                 << elementType << ", got " << attr;
305       diag.attachNote(target.getLoc()) << "when applied to this op";
306       return diag;
307     }
308     paddingValues.push_back(attr);
309   }
310 
311   // Extract the transpose vectors.
312   SmallVector<SmallVector<int64_t>> transposePaddings;
313   for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
314     transposePaddings.push_back(
315         extractI64Array(transposeVector.cast<ArrayAttr>()));
316 
317   LinalgPaddingOptions paddingOptions;
318   paddingOptions.setPaddingValues(paddingValues);
319   paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions()));
320   paddingOptions.setPackPaddings(packPaddings);
321   paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings()));
322   paddingOptions.setTransposePaddings(transposePaddings);
323 
324   FailureOr<LinalgOp> result =
325       tryApply<LinalgPaddingPattern>(target, paddingOptions);
326   if (succeeded(result))
327     return result;
328 
329   InFlightDiagnostic diag = emitError()
330                             << "failed to apply pattern to target op";
331   diag.attachNote(target.getLoc()) << "target op";
332   return diag;
333 }
334 
335 LogicalResult transform::PadOp::verify() {
336   SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings());
337   if (any_of(packPaddings, [](int64_t packPadding) {
338         return packPadding != 0 && packPadding != 1;
339       })) {
340     return emitOpError()
341            << "expects pack_paddings to contain booleans (0/1), found "
342            << getPackPaddings();
343   }
344 
345   SmallVector<int64_t> paddingDimensions =
346       extractI64Array(getPaddingDimensions());
347   if (any_of(paddingDimensions,
348              [](int64_t paddingDimension) { return paddingDimension < 0; })) {
349     return emitOpError()
350            << "expects padding_dimensions to contain positive integers, found "
351            << getPaddingDimensions();
352   }
353 
354   SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings());
355   if (any_of(hoistPaddings,
356              [](int64_t hoistPadding) { return hoistPadding < 0; })) {
357     return emitOpError()
358            << "expects hoist_paddings to contain positive integers, found "
359            << getHoistPaddings();
360   }
361 
362   ArrayAttr transposes = getTransposePaddings();
363   for (Attribute attr : transposes) {
364     SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
365     auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
366     if (!std::is_permutation(sequence.begin(), sequence.end(),
367                              transpose.begin(), transpose.end())) {
368       return emitOpError()
369              << "expects transpose_paddings to be a permutation, found "
370              << attr;
371     }
372   }
373   return success();
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // ScalarizeOp
378 //===----------------------------------------------------------------------===//
379 
380 FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
381   LinalgTilingOptions tilingOptions;
382   tilingOptions.scalarizeDynamicDims();
383   // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
384   // sizes and asserts that it is not already set.
385   SmallVector<int64_t> emptyTileSizes;
386   LinalgTilingPattern pattern(getContext(), tilingOptions);
387   SimpleRewriter rewriter(getContext());
388   rewriter.setInsertionPoint(target);
389   FailureOr<TiledLinalgOp> result =
390       pattern.returningMatchAndRewrite(target, rewriter);
391   if (failed(result))
392     return failure();
393 
394   return result->op;
395 }
396 
397 //===----------------------------------------------------------------------===//
398 // TileOp
399 //===----------------------------------------------------------------------===//
400 
401 DiagnosedSilencableFailure
402 transform::TileOp::apply(TransformResults &transformResults,
403                          TransformState &state) {
404   LinalgTilingOptions tilingOptions;
405   SmallVector<int64_t> tileSizes = extractI64Array(getSizes());
406 
407   if (!tileSizes.empty())
408     tilingOptions.setTileSizes(tileSizes);
409   tilingOptions.setInterchange(extractUIntArray(getInterchange()));
410   LinalgTilingPattern pattern(getContext(), tilingOptions);
411 
412   LogicalResult result = applyTilingToAll(
413       getOperation(), getTarget(), tileSizes, transformResults, state,
414       [&](LinalgOp linalgOp) {
415         SimpleRewriter rewriter(linalgOp.getContext());
416         return pattern.returningMatchAndRewrite(linalgOp, rewriter);
417       });
418   return DiagnosedSilencableFailure(result);
419 }
420 
421 ParseResult transform::TileOp::parse(OpAsmParser &parser,
422                                      OperationState &result) {
423   return parseTileLikeOp(parser, result,
424                          TileOp::getSizesAttrName(result.name).getValue());
425 }
426 
427 void TileOp::print(OpAsmPrinter &p) {
428   p << ' ';
429   p << getTarget();
430   p.printOptionalAttrDict((*this)->getAttrs());
431 }
432 
433 //===----------------------------------------------------------------------===//
434 // VectorizeOp
435 //===----------------------------------------------------------------------===//
436 
437 FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target) {
438   if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
439     InFlightDiagnostic diag = emitOpError()
440                               << "applies only to isolated-from-above targets";
441     diag.attachNote(target->getLoc()) << "non-isolated target";
442     return diag;
443   }
444 
445   MLIRContext *ctx = getContext();
446   RewritePatternSet patterns(ctx);
447   patterns.add<LinalgVectorizationPattern>(ctx);
448 
449   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
450   vector::populateVectorReductionToContractPatterns(patterns);
451   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
452                linalg::LinalgCopyVTWForwardingPattern>(ctx,
453                                                        /*benefit=*/2);
454   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
455   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
456   if (getVectorizePadding())
457     linalg::populatePadOpVectorizationPatterns(patterns);
458 
459   if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
460     return reportUnknownTransformError(target);
461   return target;
462 }
463 
464 //===----------------------------------------------------------------------===//
465 // Transform op registration
466 //===----------------------------------------------------------------------===//
467 
namespace {
/// Transform dialect extension that registers every op listed in the
/// generated LinalgTransformOps.cpp.inc (GET_OP_LIST) and declares the
/// dialects those ops depend on. PDL is declared because the ops use PDL types
/// for their operands and results; SCF and Vector are declared as well.
class LinalgTransformDialectExtension
    : public transform::TransformDialectExtension<
          LinalgTransformDialectExtension> {
public:
  LinalgTransformDialectExtension() {
    declareDependentDialect<pdl::PDLDialect>();
    // NOTE(review): presumably SCF and Vector are needed because tiling
    // creates SCF loops and vectorization creates vector ops — confirm.
    declareDependentDialect<scf::SCFDialect>();
    declareDependentDialect<vector::VectorDialect>();
    // The generated op list expands in place into the template argument list.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
        >();
  }
};
} // namespace
486 
487 #define GET_OP_CLASSES
488 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
489 
490 void mlir::linalg::registerTransformDialectExtension(
491     DialectRegistry &registry) {
492   registry.addExtensions<LinalgTransformDialectExtension>();
493 }
494