1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
10 
11 #include "mlir/Dialect/Linalg/IR/Linalg.h"
12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13 #include "mlir/Dialect/PDL/IR/PDL.h"
14 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
15 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
16 #include "mlir/Interfaces/SideEffectInterfaces.h"
17 #include "mlir/Parser/Parser.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 #include "llvm/Support/FormatVariadic.h"
20 
21 using namespace mlir;
22 using namespace mlir::linalg;
23 using namespace mlir::transform;
24 
25 /// Extracts a vector of int64_t from an array attribute. Asserts if the
26 /// attribute contains values other than integers.
27 static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
28   SmallVector<int64_t> result;
29   result.reserve(attr.size());
30   for (APInt value : attr.getAsValueRange<IntegerAttr>())
31     result.push_back(value.getSExtValue());
32   return result;
33 }
34 
35 /// Extracts a vector of unsigned from an array attribute. Asserts if the
36 /// attribute contains values other than intergers. May truncate.
37 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
38   SmallVector<unsigned> result;
39   result.reserve(attr.size());
40   for (APInt value : attr.getAsValueRange<IntegerAttr>())
41     result.push_back(value.getZExtValue());
42   return result;
43 }
44 
45 namespace {
46 /// A simple pattern rewriter that implements no special logic.
47 class SimpleRewriter : public PatternRewriter {
48 public:
49   SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
50 };
51 } // namespace
52 
53 //===----------------------------------------------------------------------===//
54 // InterchangeOp
55 //===----------------------------------------------------------------------===//
56 
57 FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) {
58   SmallVector<unsigned> interchangeVector =
59       extractUIntArray(getIteratorInterchange());
60   // Exit early if no transformation is needed.
61   if (interchangeVector.empty())
62     return target;
63 
64   auto genericTarget = dyn_cast<GenericOp>(target.getOperation());
65   if (!genericTarget) {
66     InFlightDiagnostic diag = emitOpError()
67                               << "applies to " << GenericOp::getOperationName()
68                               << " ops";
69     diag.attachNote(target.getLoc()) << "attempted to apply to this op";
70     return diag;
71   }
72 
73   GenericOpInterchangePattern pattern(getContext(), interchangeVector);
74   SimpleRewriter rewriter(getContext());
75   rewriter.setInsertionPoint(target);
76   FailureOr<GenericOp> result =
77       pattern.returningMatchAndRewrite(genericTarget, rewriter);
78   if (failed(result))
79     return failure();
80 
81   return cast<LinalgOp>(result->getOperation());
82 }
83 
84 LogicalResult transform::InterchangeOp::verify() {
85   SmallVector<unsigned> permutation =
86       extractUIntArray(getIteratorInterchange());
87   auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
88   if (!std::is_permutation(sequence.begin(), sequence.end(),
89                            permutation.begin(), permutation.end())) {
90     return emitOpError()
91            << "expects iterator_interchange to be a permutation, found "
92            << getIteratorInterchange();
93   }
94   return success();
95 }
96 
97 //===---------------------------------------------------------------------===//
98 // PadOp
99 //===---------------------------------------------------------------------===//
100 
101 FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) {
102   // Convert the integer packing flags to booleans.
103   SmallVector<bool> packPaddings;
104   for (int64_t packPadding : extractI64Array(getPackPaddings()))
105     packPaddings.push_back(static_cast<bool>(packPadding));
106 
107   // Convert the padding values to attributes.
108   SmallVector<Attribute> paddingValues;
109   for (auto const &it :
110        llvm::zip(getPaddingValues(), target->getOperandTypes())) {
111     Attribute attr = std::get<0>(it);
112     Type elementType = getElementTypeOrSelf(std::get<1>(it));
113     // Try to parse string attributes to obtain an attribute of element type.
114     if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
115       paddingValues.push_back(
116           parseAttribute(attr.cast<StringAttr>(), elementType));
117       if (!paddingValues.back()) {
118         InFlightDiagnostic diag = emitOpError()
119                                   << "expects a padding value that parses to "
120                                   << elementType << ", got " << std::get<0>(it);
121         diag.attachNote(target.getLoc()) << "when applied to this op";
122         return diag;
123       }
124       continue;
125     }
126     // Otherwise, add the attribute directly.
127     if (attr.getType() != elementType) {
128       InFlightDiagnostic diag = emitOpError()
129                                 << "expects a padding value of type "
130                                 << elementType << ", got " << attr;
131       diag.attachNote(target.getLoc()) << "when applied to this op";
132       return diag;
133     }
134     paddingValues.push_back(attr);
135   }
136 
137   // Extract the transpose vectors.
138   SmallVector<SmallVector<int64_t>> transposePaddings;
139   for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
140     transposePaddings.push_back(
141         extractI64Array(transposeVector.cast<ArrayAttr>()));
142 
143   LinalgPaddingOptions paddingOptions;
144   paddingOptions.setPaddingValues(paddingValues);
145   paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions()));
146   paddingOptions.setPackPaddings(packPaddings);
147   paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings()));
148   paddingOptions.setTransposePaddings(transposePaddings);
149 
150   LinalgPaddingPattern pattern(getContext(), paddingOptions);
151   SimpleRewriter rewriter(getContext());
152   rewriter.setInsertionPoint(target);
153   FailureOr<LinalgOp> patternResult =
154       pattern.returningMatchAndRewrite(target, rewriter);
155   if (failed(patternResult)) {
156     InFlightDiagnostic diag = emitError()
157                               << "failed to apply pattern to target op";
158     diag.attachNote(target.getLoc()) << "target op";
159     return diag;
160   }
161   return patternResult;
162 }
163 
164 LogicalResult transform::PadOp::verify() {
165   SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings());
166   if (any_of(packPaddings, [](int64_t packPadding) {
167         return packPadding != 0 && packPadding != 1;
168       })) {
169     return emitOpError()
170            << "expects pack_paddings to contain booleans (0/1), found "
171            << getPackPaddings();
172   }
173 
174   SmallVector<int64_t> paddingDimensions =
175       extractI64Array(getPaddingDimensions());
176   if (any_of(paddingDimensions,
177              [](int64_t paddingDimension) { return paddingDimension < 0; })) {
178     return emitOpError()
179            << "expects padding_dimensions to contain positive integers, found "
180            << getPaddingDimensions();
181   }
182 
183   SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings());
184   if (any_of(hoistPaddings,
185              [](int64_t hoistPadding) { return hoistPadding < 0; })) {
186     return emitOpError()
187            << "expects hoist_paddings to contain positive integers, found "
188            << getHoistPaddings();
189   }
190 
191   ArrayAttr transposes = getTransposePaddings();
192   for (Attribute attr : transposes) {
193     SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
194     auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
195     if (!std::is_permutation(sequence.begin(), sequence.end(),
196                              transpose.begin(), transpose.end())) {
197       return emitOpError()
198              << "expects transpose_paddings to be a permutation, found "
199              << attr;
200     }
201   }
202   return success();
203 }
204 
205 //===----------------------------------------------------------------------===//
206 // ScalarizeOp
207 //===----------------------------------------------------------------------===//
208 
209 FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
210   LinalgTilingOptions tilingOptions;
211   tilingOptions.scalarizeDynamicDims();
212   // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
213   // sizes and asserts that it is not already set.
214   SmallVector<int64_t> emptyTileSizes;
215   LinalgTilingPattern pattern(getContext(), tilingOptions);
216   SimpleRewriter rewriter(getContext());
217   rewriter.setInsertionPoint(target);
218   FailureOr<TiledLinalgOp> result =
219       pattern.returningMatchAndRewrite(target, rewriter);
220   if (failed(result))
221     return failure();
222 
223   return result->op;
224 }
225 
226 //===----------------------------------------------------------------------===//
227 // TileOp
228 //===----------------------------------------------------------------------===//
229 
230 /// Apply a tiling transformation to all payload ops and store both the
231 /// tiled operation as well as the created tile loops.
232 static LogicalResult
233 applyTilingToAll(Operation *transformOp, Value target,
234                  ArrayRef<int64_t> tileSizes,
235                  transform::TransformResults &transformResults,
236                  transform::TransformState &state,
237                  function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
238   // Number of loops: Number of tiles sizes that are not zero.
239   size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
240   // All payload ops. These should all be LinalgOps for now.
241   ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);
242 
243   SmallVector<Operation *> tiledLinalgOps;
244   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
245   for (unsigned int i = 0; i < numLoops; ++i)
246     loopOps[i].reserve(payloadOps.size());
247 
248   for (Operation *target : payloadOps) {
249     auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
250     if (!linalgOp)
251       return transformOp->emitError("only LinalgOps are supported");
252 
253     FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
254     if (failed(tiled))
255       return failure();
256 
257     tiledLinalgOps.push_back(tiled->op);
258     if (tiled->loops.size() != numLoops)
259       // Not enough loops were generated. This usually means that the input size
260       // was smaller than the tiling size.
261       // TODO: LinalgTilingPattern should return failure().
262       return failure();
263     for (unsigned int i = 0; i < numLoops; ++i)
264       loopOps[i].push_back(tiled->loops[i]);
265   }
266 
267   transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
268   for (unsigned int i = 0; i < numLoops; ++i)
269     transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
270   return success();
271 }
272 
273 LogicalResult transform::TileOp::apply(TransformResults &transformResults,
274                                        TransformState &state) {
275   LinalgTilingOptions tilingOptions;
276   SmallVector<int64_t> tileSizes = extractI64Array(getSizes());
277 
278   if (!tileSizes.empty())
279     tilingOptions.setTileSizes(tileSizes);
280   tilingOptions.setInterchange(extractUIntArray(getInterchange()));
281   LinalgTilingPattern pattern(getContext(), tilingOptions);
282 
283   return applyTilingToAll(getOperation(), getTarget(), tileSizes,
284                           transformResults, state, [&](LinalgOp linalgOp) {
285                             SimpleRewriter rewriter(linalgOp.getContext());
286                             return pattern.returningMatchAndRewrite(linalgOp,
287                                                                     rewriter);
288                           });
289 }
290 
291 ParseResult transform::TileOp::parse(OpAsmParser &parser,
292                                      OperationState &result) {
293   StringRef sizesAttrName = TileOp::getSizesAttrName(result.name).getValue();
294   OpAsmParser::UnresolvedOperand targetOperand;
295   SMLoc opLoc = parser.getCurrentLocation();
296   if (parser.parseOperand(targetOperand) ||
297       parser.parseOptionalAttrDict(result.attributes))
298     return failure();
299   Attribute sizesAttr = result.attributes.get(sizesAttrName);
300   if (!sizesAttr)
301     return parser.emitError(opLoc)
302            << "expected '" << sizesAttrName << "' attribute";
303   auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
304   if (!sizesArrayAttr)
305     return parser.emitError(opLoc)
306            << "'" << sizesAttrName << "' attribute must be an array";
307   Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
308   size_t numExpectedLoops =
309       sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
310   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
311   if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
312     return failure();
313   return success();
314 }
315 
316 void TileOp::print(OpAsmPrinter &p) {
317   p << ' ';
318   p << getTarget();
319   p.printOptionalAttrDict((*this)->getAttrs());
320 }
321 
322 void TileOp::getEffects(
323     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
324         &effects) {
325   // `target` arg is consumed and can no longer be used.
326   effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
327                        TransformMappingResource::get());
328   effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
329                        TransformMappingResource::get());
330 
331   for (Value r : getResults()) {
332     effects.emplace_back(MemoryEffects::Write::get(), r,
333                          TransformMappingResource::get());
334     effects.emplace_back(MemoryEffects::Allocate::get(), r,
335                          TransformMappingResource::get());
336   }
337 
338   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
339   effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
340 }
341 
342 //===----------------------------------------------------------------------===//
343 // VectorizeOp
344 //===----------------------------------------------------------------------===//
345 
346 FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target) {
347   if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
348     InFlightDiagnostic diag = emitOpError()
349                               << "applies only to isolated-from-above targets";
350     diag.attachNote(target->getLoc()) << "non-isolated target";
351     return diag;
352   }
353 
354   MLIRContext *ctx = getContext();
355   RewritePatternSet patterns(ctx);
356   patterns.add<LinalgVectorizationPattern>(ctx);
357 
358   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
359   vector::populateVectorReductionToContractPatterns(patterns);
360   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
361                linalg::LinalgCopyVTWForwardingPattern>(ctx,
362                                                        /*benefit=*/2);
363   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
364   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
365   if (getVectorizePadding())
366     linalg::populatePadOpVectorizationPatterns(patterns);
367 
368   if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) {
369     InFlightDiagnostic diag = emitError() << "failed to apply";
370     diag.attachNote(target->getLoc()) << "target op";
371     return diag;
372   }
373   return target;
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // Transform op registration
378 //===----------------------------------------------------------------------===//
379 
namespace {
/// Transform dialect extension that registers the Linalg transform ops.
/// It declares the following dependent dialects:
///  - PDL, because the ops use PDL types for their operands and results;
///  - SCF and Vector. Presumably this is because the transformations (tiling,
///    vectorization) create ops from these dialects — TODO confirm.
class LinalgTransformDialectExtension
    : public transform::TransformDialectExtension<
          LinalgTransformDialectExtension> {
public:
  LinalgTransformDialectExtension() {
    declareDependentDialect<pdl::PDLDialect>();
    declareDependentDialect<scf::SCFDialect>();
    declareDependentDialect<vector::VectorDialect>();
    // Register every op from the ODS-generated op list (GET_OP_LIST).
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
        >();
  }
};
} // namespace
398 
399 #define GET_OP_CLASSES
400 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
401 
402 void mlir::linalg::registerTransformDialectExtension(
403     DialectRegistry &registry) {
404   registry.addExtensions<LinalgTransformDialectExtension>();
405 }
406