1 //===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/IR/AffineExprVisitor.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/TypeUtilities.h"
16 #include "llvm/ADT/SmallSet.h"
17 
18 using namespace mlir;
19 using namespace mlir::linalg;
20 
/// Include the generated definitions of the Linalg interfaces.
22 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
23 
24 //===----------------------------------------------------------------------===//
25 // ContractionOpInterface implementation
26 //===----------------------------------------------------------------------===//
27 
28 /// Return true if the use-def chain from `v` to `from` consists of 0 or more
29 /// unary single-operand operations.
30 // TODO: relax to multi-operands with constants, which are technically unary ops
31 // as needed (e.g. add5).
32 static bool isChainOfUnaryOpsFrom(Value v, Value from) {
33   while (true) {
34     if (v == from)
35       return true;
36     Operation *op = v.getDefiningOp();
37     if (!op || op->getNumOperands() != 1)
38       return false;
39     v = op->getOperand(0);
40   };
41 }
42 
43 /// Return the unique instance of OpType in `block` if it is indeed unique.
44 /// Return null if none or more than 1 instances exist.
45 template <typename OpType>
46 static OpType getSingleOpOfType(Block &block) {
47   OpType res = nullptr;
48   block.walk([&](OpType op) {
49     if (res) {
50       res = nullptr;
51       return WalkResult::interrupt();
52     }
53     res = op;
54     return WalkResult::advance();
55   });
56   return res;
57 }
58 
59 /// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))`
60 /// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent
61 /// unary operations that may change the type.
62 template <typename AddOpType, typename MulOpType>
63 static bool isAddMul(Block &block) {
64   if (block.getNumArguments() != 3)
65     return false;
66   Operation *yieldOp = block.getTerminator();
67   if (yieldOp->getNumOperands() != 1)
68     return false;
69 
70   AddOpType addOp = getSingleOpOfType<AddOpType>(block);
71   MulOpType mulOp = getSingleOpOfType<MulOpType>(block);
72   if (!addOp || !mulOp)
73     return false;
74 
75   Value argA = block.getArgument(0), argB = block.getArgument(1);
76   Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
77   Value mul = mulOp->getResult(0);
78   Value argC = block.getArgument(2);
79   Value c1 = addOp->getOperand(0), c2 = addOp->getOperand(1);
80   Value add = addOp->getResult(0);
81   Value res = yieldOp->getOperand(0);
82   // Result traces back to add.
83   auto un = isChainOfUnaryOpsFrom;
84   bool success = un(res, add);
85   // One of the operands of add traces back to argC, the other to the mul.
86   success |= (un(c1, argC) && un(c2, mul)) || ((un(c1, mul)) && un(c2, argC));
87   // One of the operands of mul traces back to argA, the other to argB.
88   success |= (un(a, argA) && un(b, argB)) || ((un(a, argB)) && un(b, argA));
89   return success;
90 }
91 
/// Outcome of matching an operation against the contraction structure. Every
/// value other than `Success` names the first structural requirement that was
/// found to be violated. Scoped so that generic names such as `Success` do not
/// leak into the enclosing namespace.
enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
100 static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
101   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
102   if (!linalgOp)
103     return MatchContractionResult::NotLinalgOp;
104   if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
105     return MatchContractionResult::WrongNumOperands;
106   auto mapRange = linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>();
107   if (linalgOp.getNumReductionLoops() == 0)
108     return MatchContractionResult::NoReduction;
109   if (llvm::any_of(mapRange,
110                    [](AffineMap m) { return !m.isProjectedPermutation(); }))
111     return MatchContractionResult::NotProjectedPermutations;
112   // TODO: more fields than add/mul.
113   if (!isAddMul<AddFOp, MulFOp>(linalgOp->getRegion(0).front()) &&
114       !isAddMul<AddIOp, MulIOp>(linalgOp->getRegion(0).front()))
115     return MatchContractionResult::NotAddMul;
116   return MatchContractionResult::Success;
117 }
118 
119 bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
120   if (!linalgOp)
121     return false;
122   Operation *op = linalgOp.getOperation();
123   return isa<ContractionOpInterface>(op) ||
124          (isContractionInterfaceImpl(op) == MatchContractionResult::Success);
125 }
126 
127 /// Verify that a LinalgOp `op` is a contraction.
128 /// A Linalg contraction is defined in general terms:
129 ///   1. Has 2 input and 1 output shapes.
130 ///   2. Has at least one reduction dimension.
131 ///   3. Has only projected permutation indexing maps.
132 ///   4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
133 ///   (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
134 ///   operations that may change the type (e.g. for mixed-precision).
135 /// As a consequence, when vectorization of such an op occurs, the only special
136 /// behavior is that the (unique) MulOpType is vectorized into a
137 /// `vector.contract`. All other ops are handled in a generic fashion.
138 /// In the future, we may wish to allow more input arguments and elementwise and
139 /// constant operations that do not involve the reduction dimension(s).
140 LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
141   auto res = isContractionInterfaceImpl(op);
142   if (res == MatchContractionResult::NotLinalgOp)
143     return op->emitError("expected a LinalgOp");
144   if (res == MatchContractionResult::WrongNumOperands)
145     return op->emitError("expected op with 2 inputs and 1 outputs");
146   if (res == MatchContractionResult::NoReduction)
147     return op->emitError("expected at least a reduction loop");
148   if (res == MatchContractionResult::NotProjectedPermutations)
149     return op->emitError("expected all indexings to be projected permutations");
150   if (res == MatchContractionResult::NotAddMul)
151     return op->emitError("(add, mul) operations not found");
152   return success();
153 }
154 
155 //===----------------------------------------------------------------------===//
156 // StructuredOpInterface implementation
157 //===----------------------------------------------------------------------===//
158 
159 OpOperandVector::operator SmallVector<Value>() {
160   SmallVector<Value> result;
161   result.reserve(this->size());
162   llvm::transform(*this, std::back_inserter(result),
163                   [](OpOperand *opOperand) { return opOperand->get(); });
164   return result;
165 }
166 
167 /// Fully compose map with operands and canonicalize the result.
168 /// Return the `createOrFold`'ed AffineApply op.
169 static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
170                                              AffineMap map,
171                                              ValueRange operandsRef) {
172   SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end());
173   fullyComposeAffineMapAndOperands(&map, &operands);
174   canonicalizeMapAndOperands(&map, &operands);
175   return b.createOrFold<AffineApplyOp>(loc, map, operands);
176 }
177 
178 SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc,
179                                                      AffineMap map,
180                                                      ValueRange values) {
181   SmallVector<Value, 4> res;
182   res.reserve(map.getNumResults());
183   unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols();
184   // For each `expr` in `map`, applies the `expr` to the values extracted from
185   // ranges. If the resulting application can be folded into a Value, the
186   // folding occurs eagerly.
187   for (auto expr : map.getResults()) {
188     AffineMap map = AffineMap::get(numDims, numSym, expr);
189     res.push_back(createFoldedComposedAffineApply(b, loc, map, values));
190   }
191   return res;
192 }
193 
194 SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
195                                                             Location loc) {
196   SmallVector<Value, 4> res;
197   for (OpOperand *opOperand : getInputAndOutputOperands()) {
198     for (int64_t i = 0, e = getRank(opOperand); i < e; ++i)
199       res.push_back(b.createOrFold<memref::DimOp>(loc, opOperand->get(), i));
200   }
201   return res;
202 }
203 
204 SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
205   SmallVector<int64_t, 4> res;
206   assert(!hasDynamicShape() && "expected operands to have static shapes");
207   for (OpOperand *opOperand : getInputAndOutputOperands())
208     llvm::append_range(res, getShape(opOperand));
209   return res;
210 }
211 
212 SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
213   AffineMap map = getLoopsToShapesMap();
214   unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
215   auto viewSizes = createFlatListOfOperandDims(b, loc);
216   SmallVector<Range, 4> res(numDims);
217   Value zeroVal = b.create<ConstantIndexOp>(loc, 0);
218   Value oneVal = b.create<ConstantIndexOp>(loc, 1);
219   for (unsigned idx = 0; idx < numRes; ++idx) {
220     auto result = map.getResult(idx);
221     if (auto d = result.dyn_cast<AffineDimExpr>()) {
222       if (res[d.getPosition()].offset)
223         continue;
224       res[d.getPosition()] = Range{zeroVal, viewSizes[idx], oneVal};
225     }
226   }
227   return res;
228 }
229 
230 SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
231   AffineMap map = getLoopsToShapesMap();
232   unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
233   SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
234   SmallVector<int64_t, 4> res(numDims, 0);
235   for (unsigned idx = 0; idx < numRes; ++idx) {
236     auto result = map.getResult(idx);
237     if (auto d = result.dyn_cast<AffineDimExpr>())
238       res[d.getPosition()] = allShapeSizes[idx];
239   }
240   return res;
241 }
242 
243 /// Visitor to check if any of the given set of positions from AffineDimExprs
244 /// are used within an AffineExpr.
245 struct HasAffineDimExprVisitor
246     : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
247   HasAffineDimExprVisitor(llvm::SmallSet<unsigned, 4> &positions)
248       : positions(positions) {}
249 
250   bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
251     return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
252   }
253 
254   bool visitDimExpr(AffineDimExpr dimExpr) {
255     return positions.count(dimExpr.getPosition());
256   }
257 
258   bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
259 
260   bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
261 
262 private:
263   llvm::SmallSet<unsigned, 4> positions;
264 };
265 
266 LogicalResult LinalgOp::reifyReturnTypeShapesPerResultDim(
267     OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
268   // An example that helps understand the logic below.
269   // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
270   // We want to express the shape of dim 0 of O in terms of shape of the inputs.
271   // This is achieved as follows.
272   //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
273   //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
274   //   shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
275   //   resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
276   //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
277   AffineMap loopsToShapesMap = getLoopsToShapesMap();
278 
279   // Find the position in the above map that represents the shape of the
280   // result:dim being inferred.
281   auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap();
282 
283   /// From loopsToShapesMap extract the submap that represents the shape of the
284   /// (resultIdx, dim) needed.
285   SmallVector<unsigned, 4> resultPosRange =
286       llvm::to_vector<4>(llvm::seq<unsigned>(resultShapesSubMapPos.first,
287                                              resultShapesSubMapPos.second));
288   AffineMap loopToResultsShapeMap = loopsToShapesMap.getSubMap(resultPosRange);
289   AffineMap resultShapesFromInputShapesMap =
290       loopToResultsShapeMap.compose(getShapesToLoopsMap());
291 
292   // Check that the result dim map does not contain the positions corresponding
293   // to the outputs.
294   llvm::SmallSet<unsigned, 4> outputDims;
295   llvm::for_each(resultPosRange,
296                  [&outputDims](unsigned dim) { outputDims.insert(dim); });
297   HasAffineDimExprVisitor checkDimExpr(outputDims);
298   Location loc = getOperation()->getLoc();
299   auto allResultDimValues =
300       applyMapToValues(b, loc, resultShapesFromInputShapesMap,
301                        createFlatListOfOperandDims(b, loc));
302   int64_t pos = 0;
303   ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
304   for (OpOperand *opOperand : getOutputOperands()) {
305     SmallVector<Value> shapes;
306     for (int64_t dim : llvm::seq<int64_t>(0, getRank(opOperand))) {
307       if (checkDimExpr.visit(shapeExprs[pos]))
308         shapes.push_back(
309             b.createOrFold<memref::DimOp>(loc, opOperand->get(), dim));
310       else
311         shapes.push_back(allResultDimValues[pos]);
312       pos++;
313     }
314     reifiedReturnShapes.emplace_back(std::move(shapes));
315   }
316   return success();
317 }
318 
319 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
320   LinalgOp linalgOp = cast<LinalgOp>(op);
321   // Expect at least one output operand.
322   // This means an op that constructs a tensor out of indices cannot be a
323   // LinalgOp at the moment. For now this will have to be a special op until we
324   // have output shape operands that are not tensors.
325   int64_t numInputs = linalgOp.getNumInputs();
326   int64_t numOutputs = linalgOp.getNumOutputs();
327   if (numOutputs == 0)
328     return op->emitOpError("expected at least one output operand");
329   if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs)))
330     return failure();
331   // Should have at least one output tensor per result tensor.
332   // Can also have outbut buffers that do not correspond to results.
333   if (op->getNumResults() > linalgOp.getOutputTensorOperands().size())
334     return op->emitOpError("unexpected #results > #outputs");
335 
336   // Before checking indexing maps, we need to make sure the attributes
337   // referenced by it are valid.
338   if (linalgOp.hasDynamicIndexingMaps())
339     if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
340       return failure();
341 
342   // All input/output operands must be indexed.
343   if (static_cast<int64_t>(linalgOp.indexing_maps().size()) !=
344       linalgOp.getNumInputsAndOutputs())
345     return op->emitOpError("expected the number of indexing_map (")
346            << linalgOp.indexing_maps().size()
347            << ") to be equal to the number of input/output operands ("
348            << linalgOp.getNumInputsAndOutputs() << ")";
349 
350   for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
351     AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
352 
353     // Symbols disallowed.
354     if (indexingMap.getNumSymbols() != 0)
355       return op->emitOpError("unexpected symbols in indexing_map #")
356              << opOperand->getOperandNumber();
357 
358     // Domain must be consistent.
359     unsigned numLoops = linalgOp.getNumLoops();
360     if (indexingMap.getNumDims() != numLoops)
361       return op->emitOpError("expected indexing_map #")
362              << opOperand->getOperandNumber() << " to have " << numLoops
363              << " dim(s) to match the number of loops";
364 
365     int64_t rank = linalgOp.getRank(opOperand);
366     if (indexingMap.getNumResults() != rank)
367       return op->emitOpError("expected operand rank (")
368              << rank << ") to match the result rank of indexing_map #"
369              << opOperand->getOperandNumber() << " ("
370              << indexingMap.getNumResults() << ")";
371   }
372 
373   SmallVector<AffineExpr> redDims;
374   linalgOp.getReductionDims(redDims);
375 
376   // Simplifying assumption: either full tensor or full buffer mode.
377   // This allows simpler verification of output operands vs result types
378   // without premature tracking of which operand is what in mixed-mode.
379   // TODO: relax when mixed-mode needs to pass verification.
380   if (!linalgOp.getOutputBufferOperands().empty() &&
381       !linalgOp.getOutputTensorOperands().empty())
382     return op->emitOpError(
383         "expected output operands to all have tensor type or "
384         "all have buffer type");
385 
386   for (OpOperand *opOperand : linalgOp.getOutputTensorOperands()) {
387     // TODO: Enforce one output tensor per result?
388     if (opOperand->getOperandNumber() - linalgOp.getNumInputs() >=
389         linalgOp->getNumResults())
390       continue;
391     OpResult result = linalgOp.getTiedOpResult(opOperand);
392     if (result.getType() != opOperand->get().getType())
393       return op->emitOpError("expected type of operand #")
394              << opOperand->getOperandNumber() << " ("
395              << opOperand->get().getType() << ")"
396              << " to match type of corresponding result (" << result.getType()
397              << ")";
398   }
399 
400   // Output tensor indexing map may not depend on reduction indices.
401   for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
402     AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
403     for (auto expr : indexingMap.getResults()) {
404       for (auto dim : redDims) {
405         unsigned pos = dim.cast<AffineDimExpr>().getPosition();
406         if (expr.isFunctionOfDim(pos)) {
407           std::string exprStr;
408           {
409             llvm::raw_string_ostream os(exprStr);
410             os << expr;
411           }
412           return op->emitOpError(
413                      "unexpected output tensor expression in indexing map #")
414                  << (opOperand->getOperandNumber() - linalgOp.getNumInputs())
415                  << " a.k.a '" << exprStr
416                  << "' is function of reduction iterator 'd" << pos << "'";
417         }
418       }
419     }
420   }
421 
422   // Named ops that are defined manually have a region builder but no region at
423   // this time. Assume the region is well-formed by specification.
424   // TODO: use linalg-ods-gen for all ops when we have enough expressive power.
425   if (linalgOp->getNumRegions() == 0) {
426     assert(!linalgOp.getRegionBuilder() && "regionBuilder but no region");
427     return success();
428   }
429 
430   auto &region = linalgOp->getRegion(0);
431   if (linalgOp->getNumRegions() > 1 || !llvm::hasSingleElement(region))
432     return op->emitOpError("expected 1 region with 1 block");
433 
434   if (!linalgOp.getShapesToLoopsMap())
435     return op->emitOpError("expected the shape-to-loops map to be non-null");
436 
437   // Simplifying assumption: bbargs match 1-1 with shape operands elemental
438   // types.
439   // TODO: once ranked shape types are plugged in, we may want to drop the
440   // corresponding bbargs, that can never be read from. This will be subject to
441   // consistency discussions (i.e. what to do with output tensors whose bbarg is
442   // not used).
443   Block &block = linalgOp->getRegion(0).front();
444 
445   if (linalgOp.getNumInputsAndOutputs() != block.getNumArguments())
446     return op->emitOpError("expected as many non-induction variable region "
447                            "arguments as the number of input/output operands");
448 
449   for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
450     Type elementType = getElementTypeOrSelf(opOperand->get());
451     Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
452     if (elementType != argType)
453       return op->emitOpError("expected type of bb argument #")
454              << opOperand->getOperandNumber() << " (" << argType << ")"
455              << " to match element or self type of the corresponding operand ("
456              << elementType << ")";
457   }
458 
459   // Check if given shapes match to inferred shapes.
460   Optional<SmallVector<int64_t, 4>> endLoopRangeValues =
461       linalgOp.getStaticLoopRanges();
462   if (!endLoopRangeValues)
463     return op->emitOpError("unable to find loop range for operation");
464   SmallVector<int64_t, 4> startLoopRangeValues((*endLoopRangeValues).size(), 0);
465 
466   // Verify only static cases since we can't get exact dimension sizes and loop
467   // ranges for dynamic cases in this stage.
468   if (llvm::none_of(*endLoopRangeValues, ShapedType::isDynamic)) {
469     for (int64_t &range : *endLoopRangeValues)
470       range -= 1;
471     for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
472       AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
473       SmallVector<int64_t, 4> startIndices =
474           indexingMap.compose(startLoopRangeValues);
475       SmallVector<int64_t, 4> endIndices =
476           indexingMap.compose(*endLoopRangeValues);
477       ArrayRef<int64_t> shape = linalgOp.getShape(opOperand);
478       for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
479         // Ignore dynamic dimension or the case that the dimension size is 0
480         if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
481           continue;
482 
483         // The first index or last index should be the maximum or the minimum in
484         // the inferred index ranges since the range is increasing or
485         // decreasing. The size of dimensions of input/output operands and the
486         // maximum value + 1 in the inferred range should be the same. But, for
487         // now we check if the inferred ranges are in boundary of input/output
488         // operands' size or not in case that Affine Expressions are complicated
489         // such as d0 * 3
490         // + d1 since it is not easy to handle the issues.
491         // Found the case that this solution can't check, for example, (d0, d1)
492         // -> (d1 - d0)
493         int64_t inferredDimSize =
494             std::max(startIndices[dim], endIndices[dim]) + 1;
495         if (std::min(startIndices[dim], endIndices[dim]) < 0) {
496           std::string mapStr;
497           {
498             llvm::raw_string_ostream os(mapStr);
499             os << indexingMap;
500           }
501           return op->emitOpError(
502                      "unexpected result less than 0 at expression #")
503                  << dim << " in " << mapStr;
504         }
505         if (indexingMap.getResult(dim).dyn_cast<AffineDimExpr>()) {
506           if (inferredDimSize != shape[dim]) {
507             return op->emitOpError("inferred input/output operand #")
508                    << opOperand->getOperandNumber()
509                    << " has shape's dimension #" << dim << " to be "
510                    << inferredDimSize << ", but found " << shape[dim];
511           }
512         } else {
513           if (inferredDimSize > shape[dim]) {
514             return op->emitOpError("inferred input/output operand #")
515                    << opOperand->getOperandNumber()
516                    << " has shape's dimension #" << dim
517                    << " to be greater than or equal to " << inferredDimSize
518                    << ", but found " << shape[dim];
519           }
520         }
521       }
522     }
523   }
524 
525   return success();
526 }
527