1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
10 
11 #include "mlir/Dialect/Linalg/IR/Linalg.h"
12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13 #include "mlir/Dialect/PDL/IR/PDL.h"
14 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
15 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
16 #include "mlir/Interfaces/SideEffectInterfaces.h"
17 #include "mlir/Parser/Parser.h"
18 #include "llvm/Support/FormatVariadic.h"
19 
20 using namespace mlir;
21 using namespace mlir::linalg;
22 using namespace mlir::transform;
23 
24 /// Extracts a vector of int64_t from an array attribute. Asserts if the
25 /// attribute contains values other than integers.
26 static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
27   SmallVector<int64_t> result;
28   result.reserve(attr.size());
29   for (APInt value : attr.getAsValueRange<IntegerAttr>())
30     result.push_back(value.getSExtValue());
31   return result;
32 }
33 
34 /// Extracts a vector of unsigned from an array attribute. Asserts if the
35 /// attribute contains values other than intergers. May truncate.
36 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
37   SmallVector<unsigned> result;
38   result.reserve(attr.size());
39   for (APInt value : attr.getAsValueRange<IntegerAttr>())
40     result.push_back(value.getZExtValue());
41   return result;
42 }
43 
44 namespace {
45 /// A simple pattern rewriter that implements no special logic.
46 class SimpleRewriter : public PatternRewriter {
47 public:
48   SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
49 };
50 } // namespace
51 
52 //===----------------------------------------------------------------------===//
53 // TileOp
54 //===----------------------------------------------------------------------===//
55 
56 /// Apply a tiling transformation to all payload ops and store both the
57 /// tiled operation as well as the created tile loops.
58 static LogicalResult
59 applyTilingToAll(Operation *transformOp, Value target,
60                  ArrayRef<int64_t> tileSizes,
61                  transform::TransformResults &transformResults,
62                  transform::TransformState &state,
63                  function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
64   // Number of loops: Number of tiles sizes that are not zero.
65   size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
66   // All payload ops. These should all be LinalgOps for now.
67   ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);
68 
69   SmallVector<Operation *> tiledLinalgOps;
70   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
71   for (unsigned int i = 0; i < numLoops; ++i)
72     loopOps[i].reserve(payloadOps.size());
73 
74   for (Operation *target : payloadOps) {
75     auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
76     if (!linalgOp)
77       return transformOp->emitError("only LinalgOps are supported");
78 
79     FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
80     if (failed(tiled))
81       return failure();
82 
83     tiledLinalgOps.push_back(tiled->op);
84     if (tiled->loops.size() != numLoops)
85       // Not enough loops were generated. This usually means that the input size
86       // was smaller than the tiling size.
87       // TODO: LinalgTilingPattern should return failure().
88       return failure();
89     for (unsigned int i = 0; i < numLoops; ++i)
90       loopOps[i].push_back(tiled->loops[i]);
91   }
92 
93   transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
94   for (unsigned int i = 0; i < numLoops; ++i)
95     transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
96   return success();
97 }
98 
99 LogicalResult transform::TileOp::apply(TransformResults &transformResults,
100                                        TransformState &state) {
101   LinalgTilingOptions tilingOptions;
102   SmallVector<int64_t> tileSizes = extractI64Array(getSizes());
103 
104   if (!tileSizes.empty())
105     tilingOptions.setTileSizes(tileSizes);
106   tilingOptions.setInterchange(extractUIntArray(getInterchange()));
107   LinalgTilingPattern pattern(getContext(), tilingOptions);
108 
109   return applyTilingToAll(getOperation(), getTarget(), tileSizes,
110                           transformResults, state, [&](LinalgOp linalgOp) {
111                             SimpleRewriter rewriter(linalgOp.getContext());
112                             return pattern.returningMatchAndRewrite(linalgOp,
113                                                                     rewriter);
114                           });
115 }
116 
117 ParseResult transform::TileOp::parse(OpAsmParser &parser,
118                                      OperationState &result) {
119   StringRef sizesAttrName = TileOp::getSizesAttrName(result.name).getValue();
120   OpAsmParser::UnresolvedOperand targetOperand;
121   SMLoc opLoc;
122   parser.getCurrentLocation(&opLoc);
123   if (parser.parseOperand(targetOperand))
124     return parser.emitError(opLoc, "expected 'target' operand");
125   if (parser.parseOptionalAttrDict(result.attributes))
126     return failure();
127   Attribute sizesAttr = result.attributes.get(sizesAttrName);
128   if (!sizesAttr)
129     return parser.emitError(opLoc)
130            << "expected '" << sizesAttrName << "' attribute";
131   auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
132   if (!sizesArrayAttr)
133     return parser.emitError(opLoc)
134            << "'" << sizesAttrName << "' attribute must be an array";
135   Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
136   size_t numExpectedLoops =
137       sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
138   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
139   if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
140     return failure();
141   return success();
142 }
143 
144 void TileOp::print(OpAsmPrinter &p) {
145   p << ' ';
146   p << getTarget();
147   p.printOptionalAttrDict((*this)->getAttrs());
148 }
149 
150 void TileOp::getEffects(
151     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
152         &effects) {
153   // `target` arg is consumed and can no longer be used.
154   effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
155                        TransformMappingResource::get());
156   effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
157                        TransformMappingResource::get());
158 
159   for (Value r : getResults()) {
160     effects.emplace_back(MemoryEffects::Write::get(), r,
161                          TransformMappingResource::get());
162     effects.emplace_back(MemoryEffects::Allocate::get(), r,
163                          TransformMappingResource::get());
164   }
165 
166   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
167   effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
168 }
169 
170 //===----------------------------------------------------------------------===//
171 // Transform op registration
172 //===----------------------------------------------------------------------===//
173 
174 namespace {
175 /// Registers new ops and declares PDL as dependent dialect since the additional
176 /// ops are using PDL types for operands and results.
177 class LinalgTransformDialectExtension
178     : public transform::TransformDialectExtension<
179           LinalgTransformDialectExtension> {
180 public:
181   LinalgTransformDialectExtension() {
182     declareDependentDialect<pdl::PDLDialect>();
183     declareDependentDialect<scf::SCFDialect>();
184     registerTransformOps<
185 #define GET_OP_LIST
186 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
187         >();
188   }
189 };
190 } // namespace
191 
192 #define GET_OP_CLASSES
193 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
194 
195 void mlir::linalg::registerTransformDialectExtension(
196     DialectRegistry &registry) {
197   registry.addExtensions<LinalgTransformDialectExtension>();
198 }
199