//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::linalg;

/// Include the definitions of the copy operation interface.
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// Interface utility functions
//===----------------------------------------------------------------------===//
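/// An operand can be dropped only if the indexing maps of the remaining
/// operands, once concatenated, still form an invertible map, i.e. every loop
/// dimension of the op is still accessed through at least one remaining
/// operand.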
bool linalg::detail::canOpOperandsBeDroppedImpl(
    linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
  SmallVector<AffineMap> indexingMaps;
  for (auto *opOperand : linalgOp.getInputAndOutputOperands()) {
    if (llvm::is_contained(droppedOperands, opOperand))
      continue;
    indexingMaps.push_back(linalgOp.getTiedIndexingMap(opOperand));
  }
  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
}

//===----------------------------------------------------------------------===//
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//

/// Return true if the use-def chain from `v` to `from` consists of 0 or more
/// unary single-operand operations.
// TODO: relax to multi-operand ops with constants, which are technically unary
// ops (e.g. add5), as needed.
static bool isChainOfUnaryOpsFrom(Value v, Value from) {
  while (true) {
    if (v == from)
      return true;
    Operation *op = v.getDefiningOp();
    if (!op || op->getNumOperands() != 1)
      return false;
    v = op->getOperand(0);
  }
}

/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than one instance exists.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res = nullptr;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}

/// Detect whether `res` is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))`
/// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent
/// unary operations that may change the type.
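///
/// For example (illustrative IR), a mixed-precision multiply-accumulate body
/// matches with u3 = u4 = arith.extf and u1 = u2 = u5 = identity:
///
///   ^bb0(%a: f16, %b: f16, %c: f32):
///     %0 = arith.extf %a : f16 to f32
///     %1 = arith.extf %b : f16 to f32
///     %2 = arith.mulf %0, %1 : f32
///     %3 = arith.addf %c, %2 : f32
///     linalg.yield %3 : f32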
template <typename AddOpType, typename MulOpType>
static bool isAddMul(Block &block) {
  if (block.getNumArguments() != 3)
    return false;
  Operation *yieldOp = block.getTerminator();
  if (yieldOp->getNumOperands() != 1)
    return false;

  AddOpType addOp = getSingleOpOfType<AddOpType>(block);
  MulOpType mulOp = getSingleOpOfType<MulOpType>(block);
  if (!addOp || !mulOp)
    return false;

  Value argA = block.getArgument(0), argB = block.getArgument(1);
  Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
  Value mul = mulOp->getResult(0);
  Value argC = block.getArgument(2);
  Value c1 = addOp->getOperand(0), c2 = addOp->getOperand(1);
  Value add = addOp->getResult(0);
  Value res = yieldOp->getOperand(0);
  // All three conditions below must hold, hence the conjunctions.
  // The result traces back to the add.
  auto un = isChainOfUnaryOpsFrom;
  bool success = un(res, add);
  // One of the operands of add traces back to argC, the other to the mul.
  success &= (un(c1, argC) && un(c2, mul)) || (un(c1, mul) && un(c2, argC));
  // One of the operands of mul traces back to argA, the other to argB.
  success &= (un(a, argA) && un(b, argB)) || (un(a, argB) && un(b, argA));
  return success;
}

enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchContractionResult::NotLinalgOp;
  if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
    return MatchContractionResult::WrongNumOperands;
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (linalgOp.getNumReductionLoops() == 0)
    return MatchContractionResult::NoReduction;
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return MatchContractionResult::NotProjectedPermutations;
  // TODO: more fields than add/mul.
  if (!isAddMul<arith::AddFOp, arith::MulFOp>(linalgOp->getRegion(0).front()) &&
      !isAddMul<arith::AddIOp, arith::MulIOp>(linalgOp->getRegion(0).front()) &&
      !isAddMul<complex::AddOp, complex::MulOp>(
          linalgOp->getRegion(0).front()) &&
      !isAddMul<arith::OrIOp, arith::AndIOp>(linalgOp->getRegion(0).front()))
    return MatchContractionResult::NotAddMul;
  return MatchContractionResult::Success;
}

bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
  if (!linalgOp)
    return false;
  Operation *op = linalgOp.getOperation();
  return isa<ContractionOpInterface>(op) ||
         (isContractionInterfaceImpl(op) == MatchContractionResult::Success);
}

/// Verify that a LinalgOp `op` is a contraction.
/// A Linalg contraction is defined in general terms:
///   1. It has 2 input and 1 output shapes.
///   2. It has at least one reduction dimension.
///   3. All its indexing maps are projected permutations.
///   4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
///   (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
///   operations that may change the type (e.g. for mixed-precision).
/// As a consequence, when vectorization of such an op occurs, the only special
/// behavior is that the (unique) MulOpType is vectorized into a
/// `vector.contract`. All other ops are handled in a generic fashion.
/// In the future, we may wish to allow more input arguments and elementwise and
/// constant operations that do not involve the reduction dimension(s).
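///
/// For example (illustrative IR), a plain matmul expressed as a
/// `linalg.generic` satisfies all four conditions, with u1..u5 the identity:
///
///   linalg.generic
///       {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
///                         affine_map<(m, n, k) -> (k, n)>,
///                         affine_map<(m, n, k) -> (m, n)>],
///        iterator_types = ["parallel", "parallel", "reduction"]}
///       ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
///       outs(%C : tensor<?x?xf32>) {
///     ^bb0(%a: f32, %b: f32, %c: f32):
///       %0 = arith.mulf %a, %b : f32
///       %1 = arith.addf %c, %0 : f32
///       linalg.yield %1 : f32
///   } -> tensor<?x?xf32>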
LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
  auto res = isContractionInterfaceImpl(op);
  if (res == MatchContractionResult::NotLinalgOp)
    return op->emitError("expected a LinalgOp");
  if (res == MatchContractionResult::WrongNumOperands)
    return op->emitError("expected op with 2 inputs and 1 output");
  if (res == MatchContractionResult::NoReduction)
    return op->emitError("expected at least one reduction loop");
  if (res == MatchContractionResult::NotProjectedPermutations)
    return op->emitError("expected all indexings to be projected permutations");
  if (res == MatchContractionResult::NotAddMul)
    return op->emitError("(add, mul) operations not found");
  return success();
}

//===----------------------------------------------------------------------===//
// ConvolutionOpInterface implementation
//===----------------------------------------------------------------------===//

/// Of the two given expressions, return the one that is of type T (`lhs` takes
/// precedence over `rhs`). Return null if neither is of type T.
template <typename T>
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) {
  return lhs.isa<T>() ? lhs.cast<T>()
                      : (rhs.isa<T>() ? rhs.cast<T>() : nullptr);
}

namespace {
/// Walk the indexing expressions for the input of a convolution operation to
/// verify it is of the right form, either
/// - AffineDimExpr
/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
///      (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
///
/// It classifies each AffineDimExpr as a convolved or an unconvolved dimension
/// and verifies that each dimension occurs only once.
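///
/// For example, in a strided access such as `d0 * 2 + d2`, both `d0` and `d2`
/// are classified as convolved dimensions, whereas a bare access such as `d3`
/// leaves `d3` unconvolved.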
struct ConvAccessExprWalker
    : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
  llvm::SmallDenseSet<unsigned> convolvedDims;
  llvm::SmallDenseSet<unsigned> unConvolvedDims;

  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
    unsigned position = dimExpr.getPosition();
    if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
      return failure();
    }
    unConvolvedDims.insert(position);
    return success();
  }

  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }

  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }

  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
    // In a pre-order visit, the top-level op has to be an add op.
    if (binaryExpr.getKind() != AffineExprKind::Add)
      return failure();
    return success(succeeded(isDimExprOrMulExpr(binaryExpr.getLHS())) &&
                   succeeded(isDimExprOrMulExpr(binaryExpr.getRHS())));
  }

  LogicalResult isDimExprOrMulExpr(AffineExpr expr) {
    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
      unsigned dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      convolvedDims.insert(dim);
      return success();
    }
    if (auto symbolMulExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
      if (symbolMulExpr.getKind() != AffineExprKind::Mul)
        return failure();
      auto lhsExpr = symbolMulExpr.getLHS();
      auto rhsExpr = symbolMulExpr.getRHS();
      // Check for a symbol expression.
      AffineExpr mulExpr =
          getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
      // If there was no symbol expr, check for a constant expression.
      if (!mulExpr) {
        mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
      }
      auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
      if (!mulExpr || !dimExpr)
        return failure();
      unsigned dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      convolvedDims.insert(dim);
      return success();
    }
    return failure();
  }
};
} // namespace

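/// Return the set of loop dimensions that appear as results of `map`; e.g.,
/// for `(d0, d1, d2, d3) -> (d3, d0)` this returns {0, 3}.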
static llvm::SmallDenseSet<unsigned> getPreservedDims(AffineMap map) {
  assert(map.isProjectedPermutation() &&
         "expected map to be a projected permutation");
  llvm::SmallDenseSet<unsigned> preservedDims;
  for (auto expr : map.getResults())
    preservedDims.insert(expr.cast<AffineDimExpr>().getPosition());
  return preservedDims;
}

enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction
};

static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchConvolutionResult::NotLinalgOp;
  if (linalgOp.getNumInputs() < 2 || linalgOp.getNumOutputs() != 1)
    return MatchConvolutionResult::WrongNumOperands;

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  if (llvm::any_of(indexingMaps[0].getResults(),
                   [&inputExprWalker](AffineExpr expr) {
                     return failed(inputExprWalker.visit(expr));
                   })) {
    return MatchConvolutionResult::WrongInputIndexingMap;
  }

  // Filter and output maps must be projected permutations.
  if (!indexingMaps[1].isProjectedPermutation() ||
      !indexingMaps.back().isProjectedPermutation())
    return MatchConvolutionResult::NotProjectedPermutations;

  auto iteratorTypesRange =
      linalgOp.iterator_types().getAsValueRange<StringAttr>();

  llvm::SmallDenseSet<unsigned> outputDims =
      getPreservedDims(indexingMaps.back());
  llvm::SmallDenseSet<unsigned> filterDims = getPreservedDims(indexingMaps[1]);
  // Make sure all loops are characterized as one of:
  // - Batch loop : present in output, as a non-convolved dim in input, not
  //   present in filter.
  // - Output image dimension : present in output, as a convolved dim in input,
  //   not present in filter.
  // - Output channel dimension : present in output, not present in input,
  //   present in filter.
  // - Filter loop dimension : present in filter, as a convolved dim in input,
  //   not present in output.
  // - Input channel dimension : unconvolved in input, not present in output,
  //   present in filter.
  // - Depth multiplier : unconvolved in input, present in output, present in
  //   filter.
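  //
  // For example, in `linalg.conv_2d_nhwc_hwcf`, d0 is the batch loop, d1 and
  // d2 are the output image dimensions, d3 is the output channel, d4 and d5
  // are the filter loop dimensions, and d6 is the input channel.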
  llvm::SmallDenseSet<unsigned> allLoopDims;
  for (auto outputExpr : indexingMaps.back().getResults()) {
    unsigned outputDim = outputExpr.cast<AffineDimExpr>().getPosition();
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Batch dimension.
      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
          getParallelIteratorTypeName())
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.convolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Output image loop dimension.
      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
          getParallelIteratorTypeName())
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (!inputExprWalker.convolvedDims.count(outputDim) &&
        !inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Output channel dimension.
      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
          getParallelIteratorTypeName())
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Depth multiplier.
      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
          getParallelIteratorTypeName())
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  for (auto filterExpr : indexingMaps[1].getResults()) {
    unsigned filterDim = filterExpr.cast<AffineDimExpr>().getPosition();
    if (outputDims.count(filterDim) &&
        !inputExprWalker.unConvolvedDims.count(filterDim) &&
        !inputExprWalker.convolvedDims.count(filterDim)) {
      // Output channel dimension. This was already seen above; continue.
      continue;
    }
    if (inputExprWalker.convolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Filter loop dimension.
      if (*std::next(iteratorTypesRange.begin(), filterDim) !=
          getReductionIteratorTypeName())
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Input channel dimension.
      if (*std::next(iteratorTypesRange.begin(), filterDim) !=
          getReductionIteratorTypeName())
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        outputDims.count(filterDim)) {
      // Depthwise loop. Already seen above.
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // All loops must be covered now.
  if (allLoopDims.size() != linalgOp.getNumLoops())
    return MatchConvolutionResult::NonConvolutionLoop;

  return MatchConvolutionResult::Success;
}

LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
  auto res = isConvolutionInterfaceImpl(op);
  if (res == MatchConvolutionResult::NotLinalgOp)
    return op->emitError("expected a LinalgOp");
  if (res == MatchConvolutionResult::WrongNumOperands)
    return op->emitError("expected op with 2 inputs and 1 output");
  if (res == MatchConvolutionResult::WrongInputIndexingMap)
    return op->emitError("unexpected input index map for convolutions");
  if (res == MatchConvolutionResult::NotProjectedPermutations) {
    return op->emitError(
        "expected output/filter indexing maps to be projected permutations");
  }
  if (res == MatchConvolutionResult::NonConvolutionLoop) {
    return op->emitError("unexpected loop dimension for convolution op");
  }
  if (res == MatchConvolutionResult::OutputDimsNotParallel) {
    return op->emitError(
        "expected all iterators used to access outputs to be parallel");
  }
  if (res == MatchConvolutionResult::NonOutputDimNotReduction) {
    return op->emitError(
        "expected all iterators not used to access outputs to be reduction");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// FillOpInterface implementation
//===----------------------------------------------------------------------===//

enum class MatchFillResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NotScalarInput
};

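/// A fill operation broadcasts a single scalar input into every element of its
/// single output, so matching only requires one scalar input and one output.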
static MatchFillResult isFillInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchFillResult::NotLinalgOp;
  if (linalgOp.getNumInputs() != 1 || linalgOp.getNumOutputs() != 1)
    return MatchFillResult::WrongNumOperands;

  OpOperand *value = linalgOp.getInputOperand(0);
  if (!linalgOp.isScalar(value))
    return MatchFillResult::NotScalarInput;

  return MatchFillResult::Success;
}

LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
  auto res = isFillInterfaceImpl(op);
  if (res == MatchFillResult::NotLinalgOp)
    return op->emitError("expected a LinalgOp");
  if (res == MatchFillResult::WrongNumOperands)
    return op->emitError("expected op with 1 input and 1 output");
  if (res == MatchFillResult::NotScalarInput)
    return op->emitError("expected op with scalar input");

  return success();
}

//===----------------------------------------------------------------------===//
// StructuredOpInterface implementation
//===----------------------------------------------------------------------===//

OpOperandVector::operator SmallVector<Value>() {
  SmallVector<Value> result;
  result.reserve(this->size());
  llvm::transform(*this, std::back_inserter(result),
                  [](OpOperand *opOperand) { return opOperand->get(); });
  return result;
}

/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                               int64_t dim) {
  if (source.getType().isa<UnrankedMemRefType, MemRefType>())
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
                                                            Location loc) {
  SmallVector<Value, 4> res;
  for (OpOperand *opOperand : getInputAndOutputOperands()) {
    for (int64_t i = 0, e = getRank(opOperand); i < e; ++i)
      res.push_back(createOrFoldDimOp(b, loc, opOperand->get(), i));
  }
  return res;
}

SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
  SmallVector<int64_t, 4> res;
  assert(!hasDynamicShape() && "expected operands to have static shapes");
  for (OpOperand *opOperand : getInputAndOutputOperands())
    llvm::append_range(res, getShape(opOperand));
  return res;
}

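/// The ranges are built with offset 0 and step 1; each loop's size is taken
/// from the first operand dimension whose indexing expression is exactly that
/// loop dimension.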
SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  auto viewSizes = createFlatListOfOperandDims(b, loc);
  SmallVector<Range, 4> res(numDims);
  Value zeroVal = b.create<arith::ConstantIndexOp>(loc, 0);
  Value oneVal = b.create<arith::ConstantIndexOp>(loc, 1);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = result.dyn_cast<AffineDimExpr>()) {
      if (res[d.getPosition()].offset)
        continue;
      res[d.getPosition()] = Range{zeroVal, viewSizes[idx], oneVal};
    }
  }
  return res;
}

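/// For example, for a static `linalg.matmul` with operand shapes 4x16, 16x8,
/// and 4x8, this returns [4, 8, 16] for the (m, n, k) loops.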
SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
  SmallVector<int64_t, 4> res(numDims, 0);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = result.dyn_cast<AffineDimExpr>())
      res[d.getPosition()] = allShapeSizes[idx];
  }
  return res;
}

/// Visitor to check if any of the given set of positions from AffineDimExprs
/// are used within an AffineExpr.
struct HasAffineDimExprVisitor
    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
      : positions(std::move(positions)) {}

  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
    return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
  }

  bool visitDimExpr(AffineDimExpr dimExpr) {
    return positions.test(dimExpr.getPosition());
  }

  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }

  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }

private:
  llvm::SmallBitVector positions;
};

LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
  // This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d2)
  //   resultShapesFromInputShapes = subMapOfResultShapes.compose(shapesToLoopsMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d3, d3)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap();

  // From loopsToShapesMap extract the submap that represents the shape of the
  // (resultIdx, dim) needed.
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
      resultShapesSubMapPos.first,
      resultShapesSubMapPos.second - resultShapesSubMapPos.first);
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs.
  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
  Location loc = getOperation()->getLoc();
  auto allResultDimValues =
      applyMapToValues(b, loc, resultShapesFromInputShapesMap,
                       createFlatListOfOperandDims(b, loc));
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand *opOperand : getOutputOperands()) {
    SmallVector<Value> shapes;
    for (int64_t dim : llvm::seq<int64_t>(0, getRank(opOperand))) {
      if (checkDimExpr.visit(shapeExprs[pos]))
        shapes.push_back(createOrFoldDimOp(b, loc, opOperand->get(), dim));
      else
        shapes.push_back(allResultDimValues[pos]);
      pos++;
    }
    reifiedReturnShapes.emplace_back(std::move(shapes));
  }
  return success();
}

LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
  LinalgOp linalgOp = cast<LinalgOp>(op);
  // Expect at least one output operand.
  // This means an op that constructs a tensor out of indices cannot be a
  // LinalgOp at the moment. For now this will have to be a special op until we
  // have output shape operands that are not tensors.
  int64_t numInputs = linalgOp.getNumInputs();
  int64_t numOutputs = linalgOp.getNumOutputs();
  if (numOutputs == 0)
    return op->emitOpError("expected at least one output operand");
  if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs)))
    return failure();
  // Verify the number of results matches the number of output tensors.
  if (op->getNumResults() != linalgOp.getOutputTensorOperands().size())
    return op->emitOpError("expected the number of results (")
           << op->getNumResults()
           << ") to be equal to the number of output tensors ("
           << linalgOp.getOutputTensorOperands().size() << ")";

  // Check all iterator types are known.
  auto iteratorTypesRange =
      linalgOp.iterator_types().getAsValueRange<StringAttr>();
  for (StringRef iteratorType : iteratorTypesRange) {
    if (!llvm::is_contained(getAllIteratorTypeNames(), iteratorType))
      return op->emitOpError("unexpected iterator_type (")
             << iteratorType << ")";
  }

  // Before checking indexing maps, we need to make sure the attributes
  // referenced by them are valid.
  if (linalgOp.hasDynamicIndexingMaps())
    if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
      return failure();

  // All input/output operands must be indexed.
  if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
      linalgOp.getNumInputsAndOutputs())
    return op->emitOpError("expected the number of indexing_map (")
           << linalgOp.getIndexingMapsArray().size()
           << ") to be equal to the number of input/output operands ("
           << linalgOp.getNumInputsAndOutputs() << ")";

  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
    AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);

    // Symbols disallowed.
    if (indexingMap.getNumSymbols() != 0)
      return op->emitOpError("unexpected symbols in indexing_map #")
             << opOperand->getOperandNumber();

    // Domain must be consistent.
    unsigned numLoops = linalgOp.getNumLoops();
    if (indexingMap.getNumDims() != numLoops)
      return op->emitOpError("expected indexing_map #")
             << opOperand->getOperandNumber() << " to have " << numLoops
             << " dim(s) to match the number of loops";

    int64_t rank = linalgOp.getRank(opOperand);
    if (indexingMap.getNumResults() != rank)
      return op->emitOpError("expected operand rank (")
             << rank << ") to match the result rank of indexing_map #"
             << opOperand->getOperandNumber() << " ("
             << indexingMap.getNumResults() << ")";
  }

  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(redDims);

  // Simplifying assumption: either full tensor or full buffer mode.
  // This allows simpler verification of output operands vs result types
  // without premature tracking of which operand is what in mixed-mode.
  // TODO: relax when mixed-mode needs to pass verification.
  if (!linalgOp.getOutputBufferOperands().empty() &&
      !linalgOp.getOutputTensorOperands().empty())
    return op->emitOpError(
        "expected output operands to all have tensor type or "
        "all have buffer type");

  for (OpOperand *opOperand : linalgOp.getOutputTensorOperands()) {
    OpResult result = linalgOp.getTiedOpResult(opOperand);
    if (result.getType() != opOperand->get().getType())
      return op->emitOpError("expected type of operand #")
             << opOperand->getOperandNumber() << " ("
             << opOperand->get().getType() << ")"
             << " to match type of corresponding result (" << result.getType()
             << ")";
  }

  // Output tensor indexing maps may not depend on reduction indices.
  for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
    AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
    for (AffineExpr expr : indexingMap.getResults()) {
      for (unsigned pos : redDims) {
        if (expr.isFunctionOfDim(pos)) {
          std::string exprStr;
          {
            llvm::raw_string_ostream os(exprStr);
            os << expr;
          }
          return op->emitOpError(
                     "unexpected output tensor expression in indexing map #")
                 << (opOperand->getOperandNumber() - linalgOp.getNumInputs())
                 << " a.k.a '" << exprStr
                 << "' is function of reduction iterator 'd" << pos << "'";
        }
      }
    }
  }

  // Check the region has exactly one block.
  if (linalgOp->getNumRegions() != 1 ||
      !llvm::hasSingleElement(linalgOp->getRegion(0)))
    return op->emitOpError("expects to have 1 region with 1 block");

  if (!linalgOp.getShapesToLoopsMap())
    return op->emitOpError("expected the shape-to-loops map to be non-null");

  // Simplifying assumption: bbargs match 1-1 with shape operands' elemental
  // types.
  // TODO: once ranked shape types are plugged in, we may want to drop the
  // corresponding bbargs, which can never be read from. This will be subject to
  // consistency discussions (i.e. what to do with output tensors whose bbarg is
  // not used).
  Block &block = linalgOp->getRegion(0).front();

  if (linalgOp.getNumInputsAndOutputs() != block.getNumArguments())
    return op->emitOpError("expected as many non-induction variable region "
                           "arguments as the number of input/output operands");

  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
    Type elementType = getElementTypeOrSelf(opOperand->get());
    Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
    if (elementType != argType)
      return op->emitOpError("expected type of bb argument #")
             << opOperand->getOperandNumber() << " (" << argType << ")"
             << " to match element or self type of the corresponding operand ("
             << elementType << ")";
  }

  // Check if the given shapes match the inferred shapes.
  SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
  SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);

  // Verify only the static cases since we can't get exact dimension sizes and
  // loop ranges for dynamic cases at this stage.
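  //
  // For example, for a `linalg.matmul` with static operand shapes 4x8, 8x16,
  // and 4x16, the static loop ranges are [4, 16, 8]. Composing the first
  // operand's indexing map (d0, d2) with the last iteration [3, 15, 7] yields
  // inferred sizes {4, 8}, which must match that operand's shape.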
  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
    for (int64_t &range : endLoopRangeValues)
      range -= 1;
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
      AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
      SmallVector<int64_t, 4> startIndices =
          indexingMap.compose(startLoopRangeValues);
      SmallVector<int64_t, 4> endIndices =
          indexingMap.compose(endLoopRangeValues);
      ArrayRef<int64_t> shape = linalgOp.getShape(opOperand);
      for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
        // Ignore dynamic dimensions and dimensions of size 0.
        if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
          continue;

        // Since each index range is monotonically increasing or decreasing,
        // it is delimited by the first and the last composed indices. Each
        // operand dimension size should equal the maximum inferred index + 1.
        // For complicated affine expressions such as d0 * 3 + d1, we only
        // check that the inferred range stays within the bounds of the
        // operand's size, since an exact check is not easy to perform. There
        // are also cases this check cannot catch, for example
        // (d0, d1) -> (d1 - d0).
        int64_t inferredDimSize =
            std::max(startIndices[dim], endIndices[dim]) + 1;
        if (std::min(startIndices[dim], endIndices[dim]) < 0) {
          std::string mapStr;
          {
            llvm::raw_string_ostream os(mapStr);
            os << indexingMap;
          }
          return op->emitOpError(
                     "unexpected result less than 0 at expression #")
                 << dim << " in " << mapStr;
        }
        if (indexingMap.getResult(dim).dyn_cast<AffineDimExpr>()) {
          if (inferredDimSize != shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand->getOperandNumber()
                   << " has shape's dimension #" << dim << " to be "
                   << inferredDimSize << ", but found " << shape[dim];
          }
        } else {
          if (inferredDimSize > shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand->getOperandNumber()
                   << " has shape's dimension #" << dim
                   << " to be greater than or equal to " << inferredDimSize
                   << ", but found " << shape[dim];
          }
        }
      }
    }
  }

  return success();
}