//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/SmallSet.h"

using namespace mlir;
using namespace mlir::linalg;

/// Include the definitions of the copy operation interface.
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//

/// Return true if the use-def chain from `v` to `from` consists of 0 or more
/// single-operand (unary) operations.
// TODO: relax to multi-operand ops whose other operands are constants; such
// ops are effectively unary (e.g. add5), as needed.
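// For example (illustrative): if %v is defined as
//   %t = arith.extf %from : f16 to f32
//   %v = arith.negf %t : f32
// then isChainOfUnaryOpsFrom(%v, %from) returns true, since every op on the
// chain has a single operand.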
static bool isChainOfUnaryOpsFrom(Value v, Value from) {
  while (true) {
    if (v == from)
      return true;
    Operation *op = v.getDefiningOp();
    if (!op || op->getNumOperands() != 1)
      return false;
    v = op->getOperand(0);
  }
}

/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if zero or more than one instance exists.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res = nullptr;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}

/// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))`
/// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent
/// unary operations that may change the type.
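///
/// For example, a mixed-precision body of the following form (illustrative)
/// matches, with u3 = u4 = arith.extf and u1, u2, u5 the identity:
///   ^bb0(%a: f16, %b: f16, %c: f32):
///     %0 = arith.extf %a : f16 to f32
///     %1 = arith.extf %b : f16 to f32
///     %2 = arith.mulf %0, %1 : f32
///     %3 = arith.addf %c, %2 : f32
///     linalg.yield %3 : f32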
template <typename AddOpType, typename MulOpType>
static bool isAddMul(Block &block) {
  if (block.getNumArguments() != 3)
    return false;
  Operation *yieldOp = block.getTerminator();
  if (yieldOp->getNumOperands() != 1)
    return false;

  AddOpType addOp = getSingleOpOfType<AddOpType>(block);
  MulOpType mulOp = getSingleOpOfType<MulOpType>(block);
  if (!addOp || !mulOp)
    return false;

  Value argA = block.getArgument(0), argB = block.getArgument(1);
  Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
  Value mul = mulOp->getResult(0);
  Value argC = block.getArgument(2);
  Value c1 = addOp->getOperand(0), c2 = addOp->getOperand(1);
  Value add = addOp->getResult(0);
  Value res = yieldOp->getOperand(0);
  // Result traces back to add.
  auto un = isChainOfUnaryOpsFrom;
  bool success = un(res, add);
  // One of the operands of add traces back to argC, the other to the mul.
  success &= (un(c1, argC) && un(c2, mul)) || (un(c1, mul) && un(c2, argC));
  // One of the operands of mul traces back to argA, the other to argB.
  success &= (un(a, argA) && un(b, argB)) || (un(a, argB) && un(b, argA));
  return success;
}

enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchContractionResult::NotLinalgOp;
  if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
    return MatchContractionResult::WrongNumOperands;
  if (linalgOp.getNumReductionLoops() == 0)
    return MatchContractionResult::NoReduction;
  auto mapRange = linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>();
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return MatchContractionResult::NotProjectedPermutations;
  // TODO: more fields than add/mul.
  if (!isAddMul<arith::AddFOp, arith::MulFOp>(linalgOp->getRegion(0).front()) &&
      !isAddMul<arith::AddIOp, arith::MulIOp>(linalgOp->getRegion(0).front()))
    return MatchContractionResult::NotAddMul;
  return MatchContractionResult::Success;
}

bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
  if (!linalgOp)
    return false;
  Operation *op = linalgOp.getOperation();
  return isa<ContractionOpInterface>(op) ||
         (isContractionInterfaceImpl(op) == MatchContractionResult::Success);
}

/// Verify that a LinalgOp `op` is a contraction.
/// A Linalg contraction is defined in general terms:
///   1. Has 2 inputs and 1 output operand.
///   2. Has at least one reduction dimension.
///   3. Has only projected permutation indexing maps.
///   4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
///   (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar
///   unary operations that may change the type (e.g. for mixed-precision).
/// As a consequence, when vectorization of such an op occurs, the only special
/// behavior is that the (unique) MulOpType is vectorized into a
/// `vector.contract`. All other ops are handled in a generic fashion.
/// In the future, we may wish to allow more input arguments and elementwise
/// and constant operations that do not involve the reduction dimension(s).
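///
/// For example, the following generic op (illustrative) satisfies all four
/// conditions and verifies as a contraction:
///   linalg.generic {
///       indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
///                        affine_map<(m, n, k) -> (k, n)>,
///                        affine_map<(m, n, k) -> (m, n)>],
///       iterator_types = ["parallel", "parallel", "reduction"]}
///     ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
///     outs(%C : memref<?x?xf32>) {
///   ^bb0(%a: f32, %b: f32, %c: f32):
///     %0 = arith.mulf %a, %b : f32
///     %1 = arith.addf %c, %0 : f32
///     linalg.yield %1 : f32
///   }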
LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
  auto res = isContractionInterfaceImpl(op);
  if (res == MatchContractionResult::NotLinalgOp)
    return op->emitError("expected a LinalgOp");
  if (res == MatchContractionResult::WrongNumOperands)
    return op->emitError("expected op with 2 inputs and 1 output");
  if (res == MatchContractionResult::NoReduction)
    return op->emitError("expected at least one reduction loop");
  if (res == MatchContractionResult::NotProjectedPermutations)
    return op->emitError("expected all indexings to be projected permutations");
  if (res == MatchContractionResult::NotAddMul)
    return op->emitError("(add, mul) operations not found");
  return success();
}

//===----------------------------------------------------------------------===//
// ConvolutionOpInterface implementation
//===----------------------------------------------------------------------===//

/// Of the two given expressions, return the one that is of type T (`lhs` gets
/// preference over `rhs`).
template <typename T>
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) {
  return lhs.isa<T>() ? lhs.cast<T>()
                      : (rhs.isa<T>() ? rhs.cast<T>() : nullptr);
}

namespace {
/// Walk the indexing expressions for the input of a convolution operation to
/// verify it is of the right form, either
/// - AffineDimExpr
/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
///      (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
///
/// It classifies each AffineDimExpr as a convolved or an unconvolved
/// dimension and verifies that each dimension occurs only once.
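///
/// For example (illustrative):
/// - `d0`           classifies d0 as unconvolved;
/// - `d0 + d2`      classifies d0 and d2 as convolved;
/// - `d0 * 2 + d2`  classifies d0 and d2 as convolved (constant stride 2);
/// - `d0 + d1 * s0` classifies d0 and d1 as convolved (symbolic stride).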
struct ConvAccessExprWalker
    : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
  llvm::SmallDenseSet<unsigned> convolvedDims;
  llvm::SmallDenseSet<unsigned> unConvolvedDims;

  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
    unsigned position = dimExpr.getPosition();
    if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
      return failure();
    }
    unConvolvedDims.insert(position);
    return success();
  }

  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }

  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }

  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
    // In a pre-order visit, the top-level op has to be an add op.
    if (binaryExpr.getKind() != AffineExprKind::Add)
      return failure();
    return success(succeeded(isDimExprOrMulExpr(binaryExpr.getLHS())) &&
                   succeeded(isDimExprOrMulExpr(binaryExpr.getRHS())));
  }

  LogicalResult isDimExprOrMulExpr(AffineExpr expr) {
    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
      unsigned dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      convolvedDims.insert(dim);
      return success();
    }
    if (auto symbolMulExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
      if (symbolMulExpr.getKind() != AffineExprKind::Mul)
        return failure();
      auto lhsExpr = symbolMulExpr.getLHS();
      auto rhsExpr = symbolMulExpr.getRHS();
      // Check for symbol expression.
      AffineExpr mulExpr =
          getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
      // If there was no symbol expr, check for constant expression.
      if (!mulExpr) {
        mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
      }
      auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
      if (!mulExpr || !dimExpr)
        return failure();
      unsigned dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      convolvedDims.insert(dim);
      return success();
    }
    return failure();
  }
};
} // namespace

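/// Return the set of loop dimension positions that appear in the results of
/// `map`, e.g. (d0, d1, d2) -> (d2, d0) preserves {0, 2}.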
static llvm::SmallDenseSet<unsigned> getPreservedDims(AffineMap map) {
  assert(map.isProjectedPermutation() &&
         "expected map to be a projected permutation");
  llvm::SmallDenseSet<unsigned> preservedDims;
  for (auto expr : map.getResults())
    preservedDims.insert(expr.cast<AffineDimExpr>().getPosition());
  return preservedDims;
}

enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction
};

static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchConvolutionResult::NotLinalgOp;
  if (linalgOp.getNumInputs() < 2 || linalgOp.getNumOutputs() != 1)
    return MatchConvolutionResult::WrongNumOperands;

  auto indexingMaps = linalgOp.getIndexingMaps();

  // Check that the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  if (llvm::any_of(indexingMaps[0].getResults(),
                   [&inputExprWalker](AffineExpr expr) {
                     return failed(inputExprWalker.visit(expr));
                   })) {
    return MatchConvolutionResult::WrongInputIndexingMap;
  }

  // Filter and output maps must be projected permutations.
  if (!indexingMaps[1].isProjectedPermutation() ||
      !indexingMaps.back().isProjectedPermutation())
    return MatchConvolutionResult::NotProjectedPermutations;

  auto iteratorTypesRange =
      linalgOp.iterator_types().getAsValueRange<StringAttr>();

  llvm::SmallDenseSet<unsigned> outputDims =
      getPreservedDims(indexingMaps.back());
  llvm::SmallDenseSet<unsigned> filterDims = getPreservedDims(indexingMaps[1]);
  // Make sure all loops are characterized as one of:
  // - Batch loop : present in output, as a non-convolved dim in input, not
  //   present in filter.
  // - Output image dimension : present in output, convolved dim in input, not
  //   present in filter.
  // - Output channel dimension : present in output, not present in input,
  //   present in filter.
  // - Filter loop dimension : present in filter, convolved dim in input, not
  //   present in output.
  // - Input channel dimension : unconvolved dim in input, not present in
  //   output, present in filter.
  // - Depth multiplier : unconvolved dim in input, present in output, present
  //   in filter.
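  // For example (illustrative), a 1-D convolution with maps
  //   input:  (d0, d1) -> (d0 + d1)
  //   filter: (d0, d1) -> (d1)
  //   output: (d0, d1) -> (d0)
  // has d0 as an output image dimension (convolved in input, absent from
  // filter) and d1 as a filter loop dimension (convolved in input, absent
  // from output).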
  llvm::SmallDenseSet<unsigned> allLoopDims;
  for (auto outputExpr : indexingMaps.back().getResults()) {
    unsigned outputDim = outputExpr.cast<AffineDimExpr>().getPosition();
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Batch dimension.
      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
          getParallelIteratorTypeName())
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.convolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Output image loop dimension.
      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
          getParallelIteratorTypeName())
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (!inputExprWalker.convolvedDims.count(outputDim) &&
        !inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Output channel dimension.
      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
          getParallelIteratorTypeName())
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Depth multiplier.
      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
          getParallelIteratorTypeName())
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  for (auto filterExpr : indexingMaps[1].getResults()) {
    unsigned filterDim = filterExpr.cast<AffineDimExpr>().getPosition();
    if (outputDims.count(filterDim) &&
        !inputExprWalker.unConvolvedDims.count(filterDim) &&
        !inputExprWalker.convolvedDims.count(filterDim)) {
      // Output channel dimension. This was already seen; continue.
      continue;
    }
    if (inputExprWalker.convolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Filter loop dimension.
      if (*std::next(iteratorTypesRange.begin(), filterDim) !=
          getReductionIteratorTypeName())
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Input channel dimension.
      if (*std::next(iteratorTypesRange.begin(), filterDim) !=
          getReductionIteratorTypeName())
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        outputDims.count(filterDim)) {
      // Depthwise loop. Already seen.
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // All loops must be covered now.
  if (allLoopDims.size() != linalgOp.getNumLoops())
    return MatchConvolutionResult::NonConvolutionLoop;

  return MatchConvolutionResult::Success;
}

LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
  auto res = isConvolutionInterfaceImpl(op);
  if (res == MatchConvolutionResult::NotLinalgOp)
    return op->emitError("expected a LinalgOp");
  if (res == MatchConvolutionResult::WrongNumOperands)
    return op->emitError("expected op with 2 inputs and 1 output");
  if (res == MatchConvolutionResult::WrongInputIndexingMap)
    return op->emitError("unexpected input index map for convolutions");
  if (res == MatchConvolutionResult::NotProjectedPermutations) {
    return op->emitError(
        "expected output/filter indexing maps to be projected permutations");
  }
  if (res == MatchConvolutionResult::NonConvolutionLoop) {
    return op->emitError("unexpected loop dimension for convolution op");
  }
  if (res == MatchConvolutionResult::OutputDimsNotParallel) {
    return op->emitError(
        "expected all iterators used to access outputs to be parallel");
  }
  if (res == MatchConvolutionResult::NonOutputDimNotReduction) {
    return op->emitError(
        "expected all iterators not used to access outputs to be reduction");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// StructuredOpInterface implementation
//===----------------------------------------------------------------------===//

OpOperandVector::operator SmallVector<Value>() {
  SmallVector<Value> result;
  result.reserve(this->size());
  llvm::transform(*this, std::back_inserter(result),
                  [](OpOperand *opOperand) { return opOperand->get(); });
  return result;
}

/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
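/// For example (illustrative), a source of type `memref<?x4xf32>` yields a
/// `memref.dim` (created or folded), and a `tensor<?x4xf32>` source yields a
/// `tensor.dim`.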
static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                               int64_t dim) {
  if (source.getType().isa<UnrankedMemRefType, MemRefType>())
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
                                                            Location loc) {
  SmallVector<Value, 4> res;
  for (OpOperand *opOperand : getInputAndOutputOperands()) {
    for (int64_t i = 0, e = getRank(opOperand); i < e; ++i)
      res.push_back(createOrFoldDimOp(b, loc, opOperand->get(), i));
  }
  return res;
}

SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
  SmallVector<int64_t, 4> res;
  assert(!hasDynamicShape() && "expected operands to have static shapes");
  for (OpOperand *opOperand : getInputAndOutputOperands())
    llvm::append_range(res, getShape(opOperand));
  return res;
}

SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  auto viewSizes = createFlatListOfOperandDims(b, loc);
  SmallVector<Range, 4> res(numDims);
  Value zeroVal = b.create<arith::ConstantIndexOp>(loc, 0);
  Value oneVal = b.create<arith::ConstantIndexOp>(loc, 1);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = result.dyn_cast<AffineDimExpr>()) {
      if (res[d.getPosition()].offset)
        continue;
      res[d.getPosition()] = Range{zeroVal, viewSizes[idx], oneVal};
    }
  }
  return res;
}

SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
  SmallVector<int64_t, 4> res(numDims, 0);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = result.dyn_cast<AffineDimExpr>())
      res[d.getPosition()] = allShapeSizes[idx];
  }
  return res;
}

/// Visitor to check whether an AffineExpr uses any AffineDimExpr whose
/// position is in the given set of positions.
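/// For example (illustrative), visiting `d0 * 3 + d2` with positions = {2}
/// returns true, while visiting it with positions = {1} returns false.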
struct HasAffineDimExprVisitor
    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
  HasAffineDimExprVisitor(const llvm::SmallSet<unsigned, 4> &positions)
      : positions(positions) {}

  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
    return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
  }

  bool visitDimExpr(AffineDimExpr dimExpr) {
    return positions.count(dimExpr.getPosition());
  }

  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }

  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }

private:
  llvm::SmallSet<unsigned, 4> positions;
};

LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j).
  // We want to express the shape of dim 0 of O in terms of the shapes of the
  // inputs. This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d2)
  //   resultShapesFromInputShapes
  //     = subMapOfResultShapes.compose(shapesToLoopsMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d3, d3)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap();

  // From loopsToShapesMap extract the submap that represents the shape of the
  // (resultIdx, dim) needed.
  SmallVector<unsigned, 4> resultPosRange =
      llvm::to_vector<4>(llvm::seq<unsigned>(resultShapesSubMapPos.first,
                                             resultShapesSubMapPos.second));
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSubMap(resultPosRange);
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs.
  llvm::SmallSet<unsigned, 4> outputDims;
  llvm::for_each(resultPosRange,
                 [&outputDims](unsigned dim) { outputDims.insert(dim); });
  HasAffineDimExprVisitor checkDimExpr(outputDims);
  Location loc = getOperation()->getLoc();
  auto allResultDimValues =
      applyMapToValues(b, loc, resultShapesFromInputShapesMap,
                       createFlatListOfOperandDims(b, loc));
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand *opOperand : getOutputOperands()) {
    SmallVector<Value> shapes;
    for (int64_t dim : llvm::seq<int64_t>(0, getRank(opOperand))) {
      if (checkDimExpr.visit(shapeExprs[pos]))
        shapes.push_back(createOrFoldDimOp(b, loc, opOperand->get(), dim));
      else
        shapes.push_back(allResultDimValues[pos]);
      pos++;
    }
    reifiedReturnShapes.emplace_back(std::move(shapes));
  }
  return success();
}

LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
  LinalgOp linalgOp = cast<LinalgOp>(op);
  // Expect at least one output operand.
  // This means an op that constructs a tensor out of indices cannot be a
  // LinalgOp at the moment. For now this will have to be a special op until we
  // have output shape operands that are not tensors.
  int64_t numInputs = linalgOp.getNumInputs();
  int64_t numOutputs = linalgOp.getNumOutputs();
  if (numOutputs == 0)
    return op->emitOpError("expected at least one output operand");
  if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs)))
    return failure();
  // Verify the number of results matches the number of output tensors.
  if (op->getNumResults() != linalgOp.getOutputTensorOperands().size())
    return op->emitOpError("expected the number of results (")
           << op->getNumResults()
           << ") to be equal to the number of output tensors ("
           << linalgOp.getOutputTensorOperands().size() << ")";

  // Before checking indexing maps, we need to make sure the attributes
  // referenced by them are valid.
  if (linalgOp.hasDynamicIndexingMaps())
    if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
      return failure();

  // All input/output operands must be indexed.
  if (static_cast<int64_t>(linalgOp.indexing_maps().size()) !=
      linalgOp.getNumInputsAndOutputs())
    return op->emitOpError("expected the number of indexing_maps (")
           << linalgOp.indexing_maps().size()
           << ") to be equal to the number of input/output operands ("
           << linalgOp.getNumInputsAndOutputs() << ")";

  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
    AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);

    // Symbols disallowed.
    if (indexingMap.getNumSymbols() != 0)
      return op->emitOpError("unexpected symbols in indexing_map #")
             << opOperand->getOperandNumber();

    // Domain must be consistent.
    unsigned numLoops = linalgOp.getNumLoops();
    if (indexingMap.getNumDims() != numLoops)
      return op->emitOpError("expected indexing_map #")
             << opOperand->getOperandNumber() << " to have " << numLoops
             << " dim(s) to match the number of loops";

    int64_t rank = linalgOp.getRank(opOperand);
    if (indexingMap.getNumResults() != rank)
      return op->emitOpError("expected operand rank (")
             << rank << ") to match the result rank of indexing_map #"
             << opOperand->getOperandNumber() << " ("
             << indexingMap.getNumResults() << ")";
  }

  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(redDims);

  // Simplifying assumption: either full tensor or full buffer mode.
  // This allows simpler verification of output operands vs result types
  // without premature tracking of which operand is what in mixed-mode.
  // TODO: relax when mixed-mode needs to pass verification.
  if (!linalgOp.getOutputBufferOperands().empty() &&
      !linalgOp.getOutputTensorOperands().empty())
    return op->emitOpError(
        "expected output operands to all have tensor type or "
        "all have buffer type");

  for (OpOperand *opOperand : linalgOp.getOutputTensorOperands()) {
    OpResult result = linalgOp.getTiedOpResult(opOperand);
    if (result.getType() != opOperand->get().getType())
      return op->emitOpError("expected type of operand #")
             << opOperand->getOperandNumber() << " ("
             << opOperand->get().getType() << ")"
             << " to match type of corresponding result (" << result.getType()
             << ")";
  }

  // Output tensor indexing map may not depend on reduction indices.
  for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
    AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
    for (AffineExpr expr : indexingMap.getResults()) {
      for (unsigned pos : redDims) {
        if (expr.isFunctionOfDim(pos)) {
          std::string exprStr;
          {
            llvm::raw_string_ostream os(exprStr);
            os << expr;
          }
          return op->emitOpError(
                     "unexpected output tensor expression in indexing map #")
                 << (opOperand->getOperandNumber() - linalgOp.getNumInputs())
                 << " a.k.a '" << exprStr
                 << "' is function of reduction iterator 'd" << pos << "'";
        }
      }
    }
  }

  // Check the region has exactly one block.
  if (linalgOp->getNumRegions() != 1 ||
      !llvm::hasSingleElement(linalgOp->getRegion(0)))
    return op->emitOpError("expects to have 1 region with 1 block");

  if (!linalgOp.getShapesToLoopsMap())
    return op->emitOpError("expected the shape-to-loops map to be non-null");

  // Simplifying assumption: bbargs match 1-1 with the elemental types of the
  // shape operands.
  // TODO: once ranked shape types are plugged in, we may want to drop the
  // corresponding bbargs, which can never be read from. This will be subject
  // to consistency discussions (i.e. what to do with output tensors whose
  // bbarg is not used).
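  // For example (illustrative), an op with an input of type `tensor<4x8xf32>`
  // and an output of type `memref<?xi32>` must have a body starting with
  // `^bb0(%a: f32, %b: i32)`.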
  Block &block = linalgOp->getRegion(0).front();

  if (linalgOp.getNumInputsAndOutputs() != block.getNumArguments())
    return op->emitOpError("expected as many non-induction variable region "
                           "arguments as the number of input/output operands");

  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
    Type elementType = getElementTypeOrSelf(opOperand->get());
    Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
    if (elementType != argType)
      return op->emitOpError("expected type of bb argument #")
             << opOperand->getOperandNumber() << " (" << argType << ")"
             << " to match element or self type of the corresponding operand ("
             << elementType << ")";
  }

  // Check if the given shapes match the inferred shapes.
  Optional<SmallVector<int64_t, 4>> endLoopRangeValues =
      linalgOp.getStaticLoopRanges();
  if (!endLoopRangeValues)
    return op->emitOpError("unable to find loop range for operation");
  SmallVector<int64_t, 4> startLoopRangeValues((*endLoopRangeValues).size(), 0);

  // Verify only static cases since we can't get exact dimension sizes and
  // loop ranges for dynamic cases at this stage.
  if (llvm::none_of(*endLoopRangeValues, ShapedType::isDynamic)) {
    for (int64_t &range : *endLoopRangeValues)
      range -= 1;
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
      AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
      SmallVector<int64_t, 4> startIndices =
          indexingMap.compose(startLoopRangeValues);
      SmallVector<int64_t, 4> endIndices =
          indexingMap.compose(*endLoopRangeValues);
      ArrayRef<int64_t> shape = linalgOp.getShape(opOperand);
      for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
        // Ignore dynamic dimensions and dimensions of size 0.
        if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
          continue;

        // Since the affine expression in each result of the indexing map is
        // monotonically increasing or decreasing in the loop variables, its
        // extreme values over the loop ranges occur at the first or last
        // index. Ideally, the dimension size of the operand should equal the
        // maximum inferred index plus 1. For complicated affine expressions
        // such as d0 * 3 + d1, however, we only check that the inferred range
        // stays within the bounds of the operand's dimension size. There are
        // also cases this check cannot handle, e.g. (d0, d1) -> (d1 - d0).
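        // For example (illustrative): with static loop sizes 3 and 2 (so
        // d0 in [0, 2] and d1 in [0, 1]), the expression d0 * 3 + d1 gives
        // endIndices[dim] = 7 and inferredDimSize = 8; since the expression
        // is not a plain AffineDimExpr, we only require shape[dim] >= 8.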
        int64_t inferredDimSize =
            std::max(startIndices[dim], endIndices[dim]) + 1;
        if (std::min(startIndices[dim], endIndices[dim]) < 0) {
          std::string mapStr;
          {
            llvm::raw_string_ostream os(mapStr);
            os << indexingMap;
          }
          return op->emitOpError(
                     "unexpected result less than 0 at expression #")
                 << dim << " in " << mapStr;
        }
        if (indexingMap.getResult(dim).dyn_cast<AffineDimExpr>()) {
          if (inferredDimSize != shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand->getOperandNumber()
                   << " has shape's dimension #" << dim << " to be "
                   << inferredDimSize << ", but found " << shape[dim];
          }
        } else {
          if (inferredDimSize > shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand->getOperandNumber()
                   << " has shape's dimension #" << dim
                   << " to be greater than or equal to " << inferredDimSize
                   << ", but found " << shape[dim];
          }
        }
      }
    }
  }

  return success();
}