1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
10 
11 #include "mlir/Dialect/Linalg/IR/Linalg.h"
12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13 #include "mlir/Dialect/PDL/IR/PDL.h"
14 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
15 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
16 #include "mlir/Interfaces/SideEffectInterfaces.h"
17 #include "mlir/Parser/Parser.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 #include "llvm/Support/FormatVariadic.h"
20 
21 using namespace mlir;
22 using namespace mlir::linalg;
23 using namespace mlir::transform;
24 
25 /// Extracts a vector of int64_t from an array attribute. Asserts if the
26 /// attribute contains values other than integers.
27 static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
28   SmallVector<int64_t> result;
29   result.reserve(attr.size());
30   for (APInt value : attr.getAsValueRange<IntegerAttr>())
31     result.push_back(value.getSExtValue());
32   return result;
33 }
34 
35 /// Extracts a vector of unsigned from an array attribute. Asserts if the
36 /// attribute contains values other than intergers. May truncate.
37 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
38   SmallVector<unsigned> result;
39   result.reserve(attr.size());
40   for (APInt value : attr.getAsValueRange<IntegerAttr>())
41     result.push_back(value.getZExtValue());
42   return result;
43 }
44 
45 namespace {
46 /// A simple pattern rewriter that implements no special logic.
47 class SimpleRewriter : public PatternRewriter {
48 public:
49   SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
50 };
51 } // namespace
52 
53 /// Attempts to apply the pattern specified as template argument to the given
54 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
55 /// function that returns the "main" result or failure. Returns failure if the
56 /// pattern failed to apply. Extra arguments are forwarded to the pattern
57 /// constructor.
58 template <typename PatternTy, typename... Args>
59 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
60   // Check if the given operation has the type expected by the pattern.
61   using OpTy = typename llvm::function_traits<
62       decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
63   auto op = dyn_cast<OpTy>(operation);
64   if (!op)
65     return failure();
66 
67   // Apply the pattern directly to the op.
68   PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
69   SimpleRewriter rewriter(operation->getContext());
70   rewriter.setInsertionPoint(operation);
71   auto result = pattern.returningMatchAndRewrite(op, rewriter);
72   if (failed(result))
73     return failure();
74   return cast<LinalgOp>(result->getOperation());
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // DecomposeOp
79 //===----------------------------------------------------------------------===//
80 
81 FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target) {
82   FailureOr<LinalgOp> windowed =
83       tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
84   if (succeeded(windowed))
85     return windowed;
86 
87   FailureOr<LinalgOp> depthwise =
88       tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
89   if (succeeded(depthwise))
90     return depthwise;
91 
92   InFlightDiagnostic diag = emitError() << "failed to apply";
93   diag.attachNote(target.getLoc()) << "attempted to apply to this op";
94   return diag;
95 }
96 
97 //===----------------------------------------------------------------------===//
98 // GeneralizeOp
99 //===----------------------------------------------------------------------===//
100 
101 FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) {
102   // Exit early if no transformation is needed.
103   if (isa<GenericOp>(target))
104     return target;
105 
106   FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
107   if (succeeded(generic))
108     return generic;
109 
110   InFlightDiagnostic diag = emitError() << "failed to apply";
111   diag.attachNote(target.getLoc()) << "attempted to apply to this op";
112   return diag;
113 }
114 
115 //===----------------------------------------------------------------------===//
116 // InterchangeOp
117 //===----------------------------------------------------------------------===//
118 
119 FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) {
120   SmallVector<unsigned> interchangeVector =
121       extractUIntArray(getIteratorInterchange());
122   // Exit early if no transformation is needed.
123   if (interchangeVector.empty())
124     return target;
125 
126   auto genericTarget = dyn_cast<GenericOp>(target.getOperation());
127   if (!genericTarget) {
128     InFlightDiagnostic diag = emitOpError()
129                               << "applies to " << GenericOp::getOperationName()
130                               << " ops";
131     diag.attachNote(target.getLoc()) << "attempted to apply to this op";
132     return diag;
133   }
134 
135   return tryApply<GenericOpInterchangePattern>(target, interchangeVector);
136 }
137 
138 LogicalResult transform::InterchangeOp::verify() {
139   SmallVector<unsigned> permutation =
140       extractUIntArray(getIteratorInterchange());
141   auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
142   if (!std::is_permutation(sequence.begin(), sequence.end(),
143                            permutation.begin(), permutation.end())) {
144     return emitOpError()
145            << "expects iterator_interchange to be a permutation, found "
146            << getIteratorInterchange();
147   }
148   return success();
149 }
150 
151 //===---------------------------------------------------------------------===//
152 // PadOp
153 //===---------------------------------------------------------------------===//
154 
155 FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) {
156   // Convert the integer packing flags to booleans.
157   SmallVector<bool> packPaddings;
158   for (int64_t packPadding : extractI64Array(getPackPaddings()))
159     packPaddings.push_back(static_cast<bool>(packPadding));
160 
161   // Convert the padding values to attributes.
162   SmallVector<Attribute> paddingValues;
163   for (auto const &it :
164        llvm::zip(getPaddingValues(), target->getOperandTypes())) {
165     Attribute attr = std::get<0>(it);
166     Type elementType = getElementTypeOrSelf(std::get<1>(it));
167     // Try to parse string attributes to obtain an attribute of element type.
168     if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
169       paddingValues.push_back(
170           parseAttribute(attr.cast<StringAttr>(), elementType));
171       if (!paddingValues.back()) {
172         InFlightDiagnostic diag = emitOpError()
173                                   << "expects a padding value that parses to "
174                                   << elementType << ", got " << std::get<0>(it);
175         diag.attachNote(target.getLoc()) << "when applied to this op";
176         return diag;
177       }
178       continue;
179     }
180     // Otherwise, add the attribute directly.
181     if (attr.getType() != elementType) {
182       InFlightDiagnostic diag = emitOpError()
183                                 << "expects a padding value of type "
184                                 << elementType << ", got " << attr;
185       diag.attachNote(target.getLoc()) << "when applied to this op";
186       return diag;
187     }
188     paddingValues.push_back(attr);
189   }
190 
191   // Extract the transpose vectors.
192   SmallVector<SmallVector<int64_t>> transposePaddings;
193   for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
194     transposePaddings.push_back(
195         extractI64Array(transposeVector.cast<ArrayAttr>()));
196 
197   LinalgPaddingOptions paddingOptions;
198   paddingOptions.setPaddingValues(paddingValues);
199   paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions()));
200   paddingOptions.setPackPaddings(packPaddings);
201   paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings()));
202   paddingOptions.setTransposePaddings(transposePaddings);
203 
204   FailureOr<LinalgOp> result =
205       tryApply<LinalgPaddingPattern>(target, paddingOptions);
206   if (succeeded(result))
207     return result;
208 
209   InFlightDiagnostic diag = emitError()
210                             << "failed to apply pattern to target op";
211   diag.attachNote(target.getLoc()) << "target op";
212   return diag;
213 }
214 
215 LogicalResult transform::PadOp::verify() {
216   SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings());
217   if (any_of(packPaddings, [](int64_t packPadding) {
218         return packPadding != 0 && packPadding != 1;
219       })) {
220     return emitOpError()
221            << "expects pack_paddings to contain booleans (0/1), found "
222            << getPackPaddings();
223   }
224 
225   SmallVector<int64_t> paddingDimensions =
226       extractI64Array(getPaddingDimensions());
227   if (any_of(paddingDimensions,
228              [](int64_t paddingDimension) { return paddingDimension < 0; })) {
229     return emitOpError()
230            << "expects padding_dimensions to contain positive integers, found "
231            << getPaddingDimensions();
232   }
233 
234   SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings());
235   if (any_of(hoistPaddings,
236              [](int64_t hoistPadding) { return hoistPadding < 0; })) {
237     return emitOpError()
238            << "expects hoist_paddings to contain positive integers, found "
239            << getHoistPaddings();
240   }
241 
242   ArrayAttr transposes = getTransposePaddings();
243   for (Attribute attr : transposes) {
244     SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
245     auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
246     if (!std::is_permutation(sequence.begin(), sequence.end(),
247                              transpose.begin(), transpose.end())) {
248       return emitOpError()
249              << "expects transpose_paddings to be a permutation, found "
250              << attr;
251     }
252   }
253   return success();
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // ScalarizeOp
258 //===----------------------------------------------------------------------===//
259 
260 FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
261   LinalgTilingOptions tilingOptions;
262   tilingOptions.scalarizeDynamicDims();
263   // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
264   // sizes and asserts that it is not already set.
265   SmallVector<int64_t> emptyTileSizes;
266   LinalgTilingPattern pattern(getContext(), tilingOptions);
267   SimpleRewriter rewriter(getContext());
268   rewriter.setInsertionPoint(target);
269   FailureOr<TiledLinalgOp> result =
270       pattern.returningMatchAndRewrite(target, rewriter);
271   if (failed(result))
272     return failure();
273 
274   return result->op;
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // TileOp
279 //===----------------------------------------------------------------------===//
280 
281 /// Apply a tiling transformation to all payload ops and store both the
282 /// tiled operation as well as the created tile loops.
283 static LogicalResult
284 applyTilingToAll(Operation *transformOp, Value target,
285                  ArrayRef<int64_t> tileSizes,
286                  transform::TransformResults &transformResults,
287                  transform::TransformState &state,
288                  function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
289   // Number of loops: Number of tiles sizes that are not zero.
290   size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
291   // All payload ops. These should all be LinalgOps for now.
292   ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);
293 
294   SmallVector<Operation *> tiledLinalgOps;
295   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
296   for (unsigned int i = 0; i < numLoops; ++i)
297     loopOps[i].reserve(payloadOps.size());
298 
299   for (Operation *target : payloadOps) {
300     auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
301     if (!linalgOp)
302       return transformOp->emitError("only LinalgOps are supported");
303 
304     FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
305     if (failed(tiled))
306       return failure();
307 
308     tiledLinalgOps.push_back(tiled->op);
309     if (tiled->loops.size() != numLoops)
310       // Not enough loops were generated. This usually means that the input size
311       // was smaller than the tiling size.
312       // TODO: LinalgTilingPattern should return failure().
313       return failure();
314     for (unsigned int i = 0; i < numLoops; ++i)
315       loopOps[i].push_back(tiled->loops[i]);
316   }
317 
318   transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
319   for (unsigned int i = 0; i < numLoops; ++i)
320     transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
321   return success();
322 }
323 
324 LogicalResult transform::TileOp::apply(TransformResults &transformResults,
325                                        TransformState &state) {
326   LinalgTilingOptions tilingOptions;
327   SmallVector<int64_t> tileSizes = extractI64Array(getSizes());
328 
329   if (!tileSizes.empty())
330     tilingOptions.setTileSizes(tileSizes);
331   tilingOptions.setInterchange(extractUIntArray(getInterchange()));
332   LinalgTilingPattern pattern(getContext(), tilingOptions);
333 
334   return applyTilingToAll(getOperation(), getTarget(), tileSizes,
335                           transformResults, state, [&](LinalgOp linalgOp) {
336                             SimpleRewriter rewriter(linalgOp.getContext());
337                             return pattern.returningMatchAndRewrite(linalgOp,
338                                                                     rewriter);
339                           });
340 }
341 
342 ParseResult transform::TileOp::parse(OpAsmParser &parser,
343                                      OperationState &result) {
344   StringRef sizesAttrName = TileOp::getSizesAttrName(result.name).getValue();
345   OpAsmParser::UnresolvedOperand targetOperand;
346   SMLoc opLoc = parser.getCurrentLocation();
347   if (parser.parseOperand(targetOperand) ||
348       parser.parseOptionalAttrDict(result.attributes))
349     return failure();
350   Attribute sizesAttr = result.attributes.get(sizesAttrName);
351   if (!sizesAttr)
352     return parser.emitError(opLoc)
353            << "expected '" << sizesAttrName << "' attribute";
354   auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
355   if (!sizesArrayAttr)
356     return parser.emitError(opLoc)
357            << "'" << sizesAttrName << "' attribute must be an array";
358   Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
359   size_t numExpectedLoops =
360       sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
361   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
362   if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
363     return failure();
364   return success();
365 }
366 
367 void TileOp::print(OpAsmPrinter &p) {
368   p << ' ';
369   p << getTarget();
370   p.printOptionalAttrDict((*this)->getAttrs());
371 }
372 
373 void TileOp::getEffects(
374     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
375         &effects) {
376   // `target` arg is consumed and can no longer be used.
377   effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
378                        TransformMappingResource::get());
379   effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
380                        TransformMappingResource::get());
381 
382   for (Value r : getResults()) {
383     effects.emplace_back(MemoryEffects::Write::get(), r,
384                          TransformMappingResource::get());
385     effects.emplace_back(MemoryEffects::Allocate::get(), r,
386                          TransformMappingResource::get());
387   }
388 
389   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
390   effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // VectorizeOp
395 //===----------------------------------------------------------------------===//
396 
397 FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target) {
398   if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
399     InFlightDiagnostic diag = emitOpError()
400                               << "applies only to isolated-from-above targets";
401     diag.attachNote(target->getLoc()) << "non-isolated target";
402     return diag;
403   }
404 
405   MLIRContext *ctx = getContext();
406   RewritePatternSet patterns(ctx);
407   patterns.add<LinalgVectorizationPattern>(ctx);
408 
409   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
410   vector::populateVectorReductionToContractPatterns(patterns);
411   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
412                linalg::LinalgCopyVTWForwardingPattern>(ctx,
413                                                        /*benefit=*/2);
414   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
415   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
416   if (getVectorizePadding())
417     linalg::populatePadOpVectorizationPatterns(patterns);
418 
419   if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) {
420     InFlightDiagnostic diag = emitError() << "failed to apply";
421     diag.attachNote(target->getLoc()) << "target op";
422     return diag;
423   }
424   return target;
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // Transform op registration
429 //===----------------------------------------------------------------------===//
430 
namespace {
/// Registers new ops and declares PDL as dependent dialect since the additional
/// ops are using PDL types for operands and results. The SCF and Vector
/// dialects are declared as dependent dialects as well (presumably because the
/// transformations implemented in this file create ops from these dialects).
class LinalgTransformDialectExtension
    : public transform::TransformDialectExtension<
          LinalgTransformDialectExtension> {
public:
  LinalgTransformDialectExtension() {
    declareDependentDialect<pdl::PDLDialect>();
    declareDependentDialect<scf::SCFDialect>();
    declareDependentDialect<vector::VectorDialect>();
    // The list of op classes to register is expanded from the included .inc
    // file, selected by the GET_OP_LIST macro.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
        >();
  }
};
} // namespace
449 
450 #define GET_OP_CLASSES
451 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
452 
453 void mlir::linalg::registerTransformDialectExtension(
454     DialectRegistry &registry) {
455   registry.addExtensions<LinalgTransformDialectExtension>();
456 }
457