1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
10 
11 #include "mlir/Dialect/Linalg/IR/Linalg.h"
12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13 #include "mlir/Dialect/PDL/IR/PDL.h"
14 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
15 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
16 #include "mlir/Interfaces/SideEffectInterfaces.h"
17 #include "mlir/Parser/Parser.h"
18 #include "llvm/Support/FormatVariadic.h"
19 
20 using namespace mlir;
21 using namespace mlir::linalg;
22 using namespace mlir::transform;
23 
24 /// Extracts a vector of int64_t from an array attribute. Asserts if the
25 /// attribute contains values other than integers.
26 static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
27   SmallVector<int64_t> result;
28   result.reserve(attr.size());
29   for (APInt value : attr.getAsValueRange<IntegerAttr>())
30     result.push_back(value.getSExtValue());
31   return result;
32 }
33 
34 /// Extracts a vector of unsigned from an array attribute. Asserts if the
35 /// attribute contains values other than intergers. May truncate.
36 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
37   SmallVector<unsigned> result;
38   result.reserve(attr.size());
39   for (APInt value : attr.getAsValueRange<IntegerAttr>())
40     result.push_back(value.getZExtValue());
41   return result;
42 }
43 
44 namespace {
45 /// A simple pattern rewriter that implements no special logic.
46 class SimpleRewriter : public PatternRewriter {
47 public:
48   SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
49 };
50 } // namespace
51 
52 //===----------------------------------------------------------------------===//
53 // TileOp
54 //===----------------------------------------------------------------------===//
55 
56 /// Apply a tiling transformation to all payload ops and store both the
57 /// tiled operation as well as the created tile loops.
58 static LogicalResult
59 applyTilingToAll(Operation *transformOp, Value target,
60                  ArrayRef<int64_t> tileSizes,
61                  transform::TransformResults &transformResults,
62                  transform::TransformState &state,
63                  function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
64   // Number of loops: Number of tiles sizes that are not zero.
65   size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
66   // All payload ops. These should all be LinalgOps for now.
67   ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);
68 
69   SmallVector<Operation *> tiledLinalgOps;
70   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
71   for (unsigned int i = 0; i < numLoops; ++i)
72     loopOps[i].reserve(payloadOps.size());
73 
74   for (Operation *target : payloadOps) {
75     auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
76     if (!linalgOp)
77       return transformOp->emitError("only LinalgOps are supported");
78 
79     FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
80     if (failed(tiled))
81       return failure();
82 
83     tiledLinalgOps.push_back(tiled->op);
84     if (tiled->loops.size() != numLoops)
85       // Not enough loops were generated. This usually means that the input size
86       // was smaller than the tiling size.
87       // TODO: LinalgTilingPattern should return failure().
88       return failure();
89     for (unsigned int i = 0; i < numLoops; ++i)
90       loopOps[i].push_back(tiled->loops[i]);
91   }
92 
93   transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
94   for (unsigned int i = 0; i < numLoops; ++i)
95     transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
96   return success();
97 }
98 
99 LogicalResult transform::TileOp::apply(TransformResults &transformResults,
100                                        TransformState &state) {
101   LinalgTilingOptions tilingOptions;
102   SmallVector<int64_t> tileSizes = extractI64Array(getSizes());
103 
104   if (!tileSizes.empty())
105     tilingOptions.setTileSizes(tileSizes);
106   tilingOptions.setInterchange(extractUIntArray(getInterchange()));
107   LinalgTilingPattern pattern(getContext(), tilingOptions);
108 
109   return applyTilingToAll(getOperation(), getTarget(), tileSizes,
110                           transformResults, state, [&](LinalgOp linalgOp) {
111                             SimpleRewriter rewriter(linalgOp.getContext());
112                             return pattern.returningMatchAndRewrite(linalgOp,
113                                                                     rewriter);
114                           });
115 }
116 
117 ParseResult transform::TileOp::parse(OpAsmParser &parser,
118                                      OperationState &result) {
119   StringRef sizesAttrName = TileOp::getSizesAttrName(result.name).getValue();
120   OpAsmParser::UnresolvedOperand targetOperand;
121   SMLoc opLoc = parser.getCurrentLocation();
122   if (parser.parseOperand(targetOperand) ||
123       parser.parseOptionalAttrDict(result.attributes))
124     return failure();
125   Attribute sizesAttr = result.attributes.get(sizesAttrName);
126   if (!sizesAttr)
127     return parser.emitError(opLoc)
128            << "expected '" << sizesAttrName << "' attribute";
129   auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
130   if (!sizesArrayAttr)
131     return parser.emitError(opLoc)
132            << "'" << sizesAttrName << "' attribute must be an array";
133   Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
134   size_t numExpectedLoops =
135       sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
136   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
137   if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
138     return failure();
139   return success();
140 }
141 
142 void TileOp::print(OpAsmPrinter &p) {
143   p << ' ';
144   p << getTarget();
145   p.printOptionalAttrDict((*this)->getAttrs());
146 }
147 
148 void TileOp::getEffects(
149     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
150         &effects) {
151   // `target` arg is consumed and can no longer be used.
152   effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
153                        TransformMappingResource::get());
154   effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
155                        TransformMappingResource::get());
156 
157   for (Value r : getResults()) {
158     effects.emplace_back(MemoryEffects::Write::get(), r,
159                          TransformMappingResource::get());
160     effects.emplace_back(MemoryEffects::Allocate::get(), r,
161                          TransformMappingResource::get());
162   }
163 
164   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
165   effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
166 }
167 
168 //===----------------------------------------------------------------------===//
169 // Transform op registration
170 //===----------------------------------------------------------------------===//
171 
namespace {
/// Transform dialect extension for Linalg. On construction it declares the
/// PDL and SCF dialects as dependent dialects and registers the transform ops
/// listed in the generated LinalgTransformOps.cpp.inc op list. PDL is a
/// dependency because the additional ops use PDL types for operands and
/// results.
class LinalgTransformDialectExtension
    : public transform::TransformDialectExtension<
          LinalgTransformDialectExtension> {
public:
  LinalgTransformDialectExtension() {
    declareDependentDialect<pdl::PDLDialect>();
    // NOTE(review): presumably required because the tiling transformation
    // creates SCF loop ops in the payload IR; confirm.
    declareDependentDialect<scf::SCFDialect>();
    // The op list is expanded from the tablegen-generated include below.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
        >();
  }
};
} // namespace
189 
190 #define GET_OP_CLASSES
191 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
192 
/// Adds the Linalg transform dialect extension defined in this file
/// (LinalgTransformDialectExtension) to `registry`.
void mlir::linalg::registerTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<LinalgTransformDialectExtension>();
}
197