//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::linalg;

/// Include the definitions of the copy operation interface.
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// Interface utility functions
//===----------------------------------------------------------------------===//
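/// Return true if `droppedOperands` can be removed from `linalgOp` while
/// keeping the remaining indexing maps invertible, i.e. every loop dimension
/// is still referenced by some remaining operand. For example (a sketch):
/// with maps `(d0, d1) -> (d0, d1)` and `(d0, d1) -> (d0)`, dropping the
/// first operand leaves `d1` unreferenced, so the concatenated map is no
/// longer invertible and the function returns false.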
bool linalg::detail::canOpOperandsBeDroppedImpl(
    linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
  SmallVector<AffineMap> indexingMaps;
  for (auto *opOperand : linalgOp.getInputAndOutputOperands()) {
    if (llvm::is_contained(droppedOperands, opOperand))
      continue;
    indexingMaps.push_back(linalgOp.getTiedIndexingMap(opOperand));
  }
  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
}

//===----------------------------------------------------------------------===//
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//

/// Return true if the use-def chain from `v` to `from` consists of 0 or more
/// unary single-operand operations.
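/// For example (a sketch): if `v` is defined by `math.exp(arith.extf(from))`,
/// every op on the chain has exactly one operand, so the function returns
/// true; it returns false as soon as a multi-operand op is encountered.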
// TODO: relax to multi-operands with constants, which are technically unary ops
// as needed (e.g. add5).
static bool isChainOfUnaryOpsFrom(Value v, Value from) {
  while (true) {
    if (v == from)
      return true;
    Operation *op = v.getDefiningOp();
    if (!op || op->getNumOperands() != 1)
      return false;
    v = op->getOperand(0);
  }
}

/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if zero or more than one instance exists.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res = nullptr;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}

/// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))`
/// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent
/// unary operations that may change the type.
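/// For example (a sketch), a mixed-precision matmul body such as:
///   ^bb0(%a: f16, %b: f16, %c: f32):
///     %0 = arith.extf %a : f16 to f32   // u3
///     %1 = arith.extf %b : f16 to f32   // u4
///     %2 = arith.mulf %0, %1 : f32      // MulOpType
///     %3 = arith.addf %c, %2 : f32      // AddOpType
///     linalg.yield %3 : f32
/// matches with (AddOpType, MulOpType) = (arith::AddFOp, arith::MulFOp), the
/// `arith.extf` casts playing the role of the unary ops.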
template <typename AddOpType, typename MulOpType>
static bool isAddMul(Block &block) {
  if (block.getNumArguments() != 3)
    return false;
  Operation *yieldOp = block.getTerminator();
  if (yieldOp->getNumOperands() != 1)
    return false;

  AddOpType addOp = getSingleOpOfType<AddOpType>(block);
  MulOpType mulOp = getSingleOpOfType<MulOpType>(block);
  if (!addOp || !mulOp)
    return false;

  Value argA = block.getArgument(0), argB = block.getArgument(1);
  Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
  Value mul = mulOp->getResult(0);
  Value argC = block.getArgument(2);
  Value c1 = addOp->getOperand(0), c2 = addOp->getOperand(1);
  Value add = addOp->getResult(0);
  Value res = yieldOp->getOperand(0);
  // The result must trace back to the add.
  auto un = isChainOfUnaryOpsFrom;
  bool success = un(res, add);
  // One operand of the add must trace back to argC, the other to the mul.
  success &= (un(c1, argC) && un(c2, mul)) || (un(c1, mul) && un(c2, argC));
  // One operand of the mul must trace back to argA, the other to argB.
  success &= (un(a, argA) && un(b, argB)) || (un(a, argB) && un(b, argA));
  return success;
}

enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchContractionResult::NotLinalgOp;
  if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
    return MatchContractionResult::WrongNumOperands;
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (linalgOp.getNumReductionLoops() == 0)
    return MatchContractionResult::NoReduction;
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return MatchContractionResult::NotProjectedPermutations;
  // TODO: more fields than add/mul.
  if (!isAddMul<arith::AddFOp, arith::MulFOp>(linalgOp->getRegion(0).front()) &&
      !isAddMul<arith::AddIOp, arith::MulIOp>(linalgOp->getRegion(0).front()) &&
      !isAddMul<complex::AddOp, complex::MulOp>(
          linalgOp->getRegion(0).front()) &&
      !isAddMul<arith::OrIOp, arith::AndIOp>(linalgOp->getRegion(0).front()))
    return MatchContractionResult::NotAddMul;
  return MatchContractionResult::Success;
}

bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
  if (!linalgOp)
    return false;
  Operation *op = linalgOp.getOperation();
  return isa<ContractionOpInterface>(op) ||
         (isContractionInterfaceImpl(op) == MatchContractionResult::Success);
}

/// Verify that a LinalgOp `op` is a contraction.
/// A Linalg contraction is defined in general terms:
///   1. Has 2 input and 1 output shapes.
///   2. Has at least one reduction dimension.
///   3. Has only projected permutation indexing maps.
///   4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
/// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
/// operations that may change the type (e.g. for mixed-precision).
/// As a consequence, when vectorization of such an op occurs, the only special
/// behavior is that the (unique) MulOpType is vectorized into a
/// `vector.contract`. All other ops are handled in a generic fashion.
/// In the future, we may wish to allow more input arguments and elementwise and
/// constant operations that do not involve the reduction dimension(s).
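///
/// For example (a sketch, with types elided), a plain matmul written as a
/// generic op satisfies all four conditions:
///   linalg.generic
///       {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
///                         affine_map<(m, n, k) -> (k, n)>,
///                         affine_map<(m, n, k) -> (m, n)>],
///        iterator_types = ["parallel", "parallel", "reduction"]}
///       ins(%A, %B : ...) outs(%C : ...) {
///     ^bb0(%a: f32, %b: f32, %c: f32):
///       %0 = arith.mulf %a, %b : f32
///       %1 = arith.addf %c, %0 : f32
///       linalg.yield %1 : f32
///   }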
LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
  auto res = isContractionInterfaceImpl(op);
  if (res == MatchContractionResult::NotLinalgOp)
    return op->emitError("expected a LinalgOp");
  if (res == MatchContractionResult::WrongNumOperands)
    return op->emitError("expected op with 2 inputs and 1 output");
  if (res == MatchContractionResult::NoReduction)
    return op->emitError("expected at least one reduction loop");
  if (res == MatchContractionResult::NotProjectedPermutations)
    return op->emitError("expected all indexings to be projected permutations");
  if (res == MatchContractionResult::NotAddMul)
    return op->emitError("(add, mul) operations not found");
  return success();
}

//===----------------------------------------------------------------------===//
// ConvolutionOpInterface implementation
//===----------------------------------------------------------------------===//

/// Of the two given expressions, return the one that is of type T (`lhs` takes
/// precedence over `rhs`); return null if neither is of type T.
template <typename T>
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) {
  return lhs.isa<T>() ? lhs.cast<T>()
                      : (rhs.isa<T>() ? rhs.cast<T>() : nullptr);
}

namespace {
/// Walk the indexing expressions for the input of a convolution operation to
/// verify it is of the right form, either
/// - AffineDimExpr
/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
///   (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
///
/// Classifies each AffineDimExpr as a convolved or an unconvolved dimension
/// and verifies that each dimension occurs only once.
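///
/// For example (a sketch): given the input access of a strided 2-D
/// convolution, say `(n, oh * 2 + kh, ow * 2 + kw, c)`, the walker classifies
/// `oh`, `ow`, `kh` and `kw` as convolved dimensions and `n` and `c` as
/// unconvolved dimensions.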
struct ConvAccessExprWalker
    : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
  llvm::SmallDenseSet<unsigned> convolvedDims;
  llvm::SmallDenseSet<unsigned> unConvolvedDims;

  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
    unsigned position = dimExpr.getPosition();
    if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
      return failure();
    }
    unConvolvedDims.insert(position);
    return success();
  }

  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }

  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }

  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
    // In a pre-order visit, the top-level op has to be an add op.
    if (binaryExpr.getKind() != AffineExprKind::Add)
      return failure();
    return success(succeeded(isDimExprOrMulExpr(binaryExpr.getLHS())) &&
                   succeeded(isDimExprOrMulExpr(binaryExpr.getRHS())));
  }

  LogicalResult isDimExprOrMulExpr(AffineExpr expr) {
    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
      unsigned dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      convolvedDims.insert(dim);
      return success();
    }
    if (auto symbolMulExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
      if (symbolMulExpr.getKind() != AffineExprKind::Mul)
        return failure();
      auto lhsExpr = symbolMulExpr.getLHS();
      auto rhsExpr = symbolMulExpr.getRHS();
      // Check for a symbol expression.
      AffineExpr mulExpr =
          getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
      // If there was no symbol expr, check for a constant expression.
      if (!mulExpr) {
        mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
      }
      auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
      if (!mulExpr || !dimExpr)
        return failure();
      unsigned dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      convolvedDims.insert(dim);
      return success();
    }
    return failure();
  }
};
} // namespace

static llvm::SmallDenseSet<unsigned> getPreservedDims(AffineMap map) {
  assert(map.isProjectedPermutation() &&
         "expected map to be a projected permutation");
  llvm::SmallDenseSet<unsigned> preservedDims;
  for (auto expr : map.getResults())
    preservedDims.insert(expr.cast<AffineDimExpr>().getPosition());
  return preservedDims;
}

enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction
};

static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchConvolutionResult::NotLinalgOp;
  if (linalgOp.getNumInputs() < 2 || linalgOp.getNumOutputs() != 1)
    return MatchConvolutionResult::WrongNumOperands;

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  if (llvm::any_of(indexingMaps[0].getResults(),
                   [&inputExprWalker](AffineExpr expr) {
                     return failed(inputExprWalker.visit(expr));
                   })) {
    return MatchConvolutionResult::WrongInputIndexingMap;
  }

  // Filter and output maps must be projected permutations.
  if (!indexingMaps[1].isProjectedPermutation() ||
      !indexingMaps.back().isProjectedPermutation())
    return MatchConvolutionResult::NotProjectedPermutations;

  auto iteratorTypesRange =
      linalgOp.iterator_types().getAsValueRange<StringAttr>();

  llvm::SmallDenseSet<unsigned> outputDims =
      getPreservedDims(indexingMaps.back());
  llvm::SmallDenseSet<unsigned> filterDims = getPreservedDims(indexingMaps[1]);
  // Make sure all loops are characterized as one of:
  // - Batch loop : present in output, as non-convolved in input, not present in
  //   filter.
  // - Output image dimension : present in output, convolved dims in input, not
  //   present in filter.
  // - Output channel dimension : present in output, not present in input,
  //   present in filter.
  // - Filter loop dimension : present in filter, convolved in input, not
  //   present in output.
  // - Input channel dimension : unconvolved in input, not present in output,
  //   present in filter.
  // - Depth multiplier : unconvolved in input, present in output, present in
  //   filter.
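  // For example (a sketch), in a 2-D NHWC/HWCF convolution with output map
  // (n, oh, ow, f), filter map (kh, kw, c, f) and input accesses
  // (n, oh + kh, ow + kw, c): `n` is a batch loop, `oh`/`ow` are output image
  // dimensions, `f` is an output channel dimension, `kh`/`kw` are filter loop
  // dimensions, and `c` is an input channel dimension.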
  llvm::SmallDenseSet<unsigned> allLoopDims;
  for (auto outputExpr : indexingMaps.back().getResults()) {
    unsigned outputDim = outputExpr.cast<AffineDimExpr>().getPosition();
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Batch dimension.
      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
          getParallelIteratorTypeName())
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.convolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Output image loop dimension.
      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
          getParallelIteratorTypeName())
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (!inputExprWalker.convolvedDims.count(outputDim) &&
        !inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Output channel dimension.
      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
          getParallelIteratorTypeName())
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Depth multiplier.
      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
          getParallelIteratorTypeName())
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  for (auto filterExpr : indexingMaps[1].getResults()) {
    unsigned filterDim = filterExpr.cast<AffineDimExpr>().getPosition();
    if (outputDims.count(filterDim) &&
        !inputExprWalker.unConvolvedDims.count(filterDim) &&
        !inputExprWalker.convolvedDims.count(filterDim)) {
      // Output channel dimension. This is already seen; continue.
      continue;
    }
    if (inputExprWalker.convolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Filter loop dimension.
      if (*std::next(iteratorTypesRange.begin(), filterDim) !=
          getReductionIteratorTypeName())
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Input channel dimension.
      if (*std::next(iteratorTypesRange.begin(), filterDim) !=
          getReductionIteratorTypeName())
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        outputDims.count(filterDim)) {
      // Depthwise loop. Already seen.
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // All loops must be covered now.
  if (allLoopDims.size() != linalgOp.getNumLoops())
    return MatchConvolutionResult::NonConvolutionLoop;

  return MatchConvolutionResult::Success;
}

LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
  auto res = isConvolutionInterfaceImpl(op);
  if (res == MatchConvolutionResult::NotLinalgOp)
    return op->emitError("expected a LinalgOp");
  if (res == MatchConvolutionResult::WrongNumOperands)
    return op->emitError("expected op with 2 inputs and 1 output");
  if (res == MatchConvolutionResult::WrongInputIndexingMap)
    return op->emitError("unexpected input index map for convolutions");
  if (res == MatchConvolutionResult::NotProjectedPermutations) {
    return op->emitError(
        "expected output/filter indexing maps to be projected permutations");
  }
  if (res == MatchConvolutionResult::NonConvolutionLoop) {
    return op->emitError("unexpected loop dimension for convolution op");
  }
  if (res == MatchConvolutionResult::OutputDimsNotParallel) {
    return op->emitError(
        "expected all iterators used to access outputs to be parallel");
  }
  if (res == MatchConvolutionResult::NonOutputDimNotReduction) {
    return op->emitError(
        "expected all iterators not used to access outputs to be reduction");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// FillOpInterface implementation
//===----------------------------------------------------------------------===//

enum class MatchFillResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NotScalarInput
};

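/// Match ops that broadcast a single scalar input into a single shaped
/// output, e.g. (a sketch):
///   linalg.fill ins(%cst : f32) outs(%t : tensor<4x8xf32>) -> tensor<4x8xf32>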
static MatchFillResult isFillInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchFillResult::NotLinalgOp;
  if (linalgOp.getNumInputs() != 1 || linalgOp.getNumOutputs() != 1)
    return MatchFillResult::WrongNumOperands;

  OpOperand *value = linalgOp.getInputOperand(0);
  if (!linalgOp.isScalar(value))
    return MatchFillResult::NotScalarInput;

  return MatchFillResult::Success;
}

LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
  auto res = isFillInterfaceImpl(op);
  if (res == MatchFillResult::NotLinalgOp)
    return op->emitError("expected a LinalgOp");
  if (res == MatchFillResult::WrongNumOperands)
    return op->emitError("expected op with 1 input and 1 output");
  if (res == MatchFillResult::NotScalarInput)
    return op->emitError("expected op with scalar input");

  return success();
}

//===----------------------------------------------------------------------===//
// StructuredOpInterface implementation
//===----------------------------------------------------------------------===//

OpOperandVector::operator SmallVector<Value>() {
  SmallVector<Value> result;
  result.reserve(this->size());
  llvm::transform(*this, std::back_inserter(result),
                  [](OpOperand *opOperand) { return opOperand->get(); });
  return result;
}

/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
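/// For example (a sketch): called with dim 0 on a value of type
/// `tensor<?x4xf32>` this builds (or folds) `tensor.dim`, while on a
/// `memref<?x4xf32>` it builds `memref.dim`.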
static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                               int64_t dim) {
  if (source.getType().isa<UnrankedMemRefType, MemRefType>())
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
                                                            Location loc) {
  SmallVector<Value, 4> res;
  for (OpOperand *opOperand : getInputAndOutputOperands()) {
    for (int64_t i = 0, e = getRank(opOperand); i < e; ++i)
      res.push_back(createOrFoldDimOp(b, loc, opOperand->get(), i));
  }
  return res;
}

SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
  SmallVector<int64_t, 4> res;
  assert(!hasDynamicShape() && "expected operands to have static shapes");
  for (OpOperand *opOperand : getInputAndOutputOperands())
    llvm::append_range(res, getShape(opOperand));
  return res;
}

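/// For example (a sketch): for a matmul `O(m, n) += A(m, k) * B(k, n)` with
/// operand sizes 4x8, 8x16 and 4x16, the loops (m, n, k) get the ranges
/// [0, 4), [0, 16) and [0, 8), each with step 1.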
SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  auto viewSizes = createFlatListOfOperandDims(b, loc);
  SmallVector<Range, 4> res(numDims);
  Value zeroVal = b.create<arith::ConstantIndexOp>(loc, 0);
  Value oneVal = b.create<arith::ConstantIndexOp>(loc, 1);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = result.dyn_cast<AffineDimExpr>()) {
      if (res[d.getPosition()].offset)
        continue;
      res[d.getPosition()] = Range{zeroVal, viewSizes[idx], oneVal};
    }
  }
  return res;
}

SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
  SmallVector<int64_t, 4> res(numDims, 0);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = result.dyn_cast<AffineDimExpr>())
      res[d.getPosition()] = allShapeSizes[idx];
  }
  return res;
}

/// Visitor to check if any of the given set of positions from AffineDimExprs
/// are used within an AffineExpr.
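/// For example (a sketch): with positions = {0}, visiting `d0 + d1 * 42`
/// returns true, while visiting `d1 + s0` returns false.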
struct HasAffineDimExprVisitor
    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
      : positions(std::move(positions)) {}

  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
    return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
  }

  bool visitDimExpr(AffineDimExpr dimExpr) {
    return positions.test(dimExpr.getPosition());
  }

  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }

  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }

private:
  llvm::SmallBitVector positions;
};

LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
  // This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d1)
  //   resultShapesFromInputShapes = subMapOfResultShapes.compose(shapesToLoopsMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d3, d3)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap();

  // From loopsToShapesMap extract the submap that represents the shape of the
  // (resultIdx, dim) needed.
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
      resultShapesSubMapPos.first,
      resultShapesSubMapPos.second - resultShapesSubMapPos.first);
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs.
  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
  Location loc = getOperation()->getLoc();
  auto allResultDimValues =
      applyMapToValues(b, loc, resultShapesFromInputShapesMap,
                       createFlatListOfOperandDims(b, loc));
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand *opOperand : getOutputOperands()) {
    SmallVector<Value> shapes;
    for (int64_t dim : llvm::seq<int64_t>(0, getRank(opOperand))) {
      if (checkDimExpr.visit(shapeExprs[pos]))
        shapes.push_back(createOrFoldDimOp(b, loc, opOperand->get(), dim));
      else
        shapes.push_back(allResultDimValues[pos]);
      pos++;
    }
    reifiedReturnShapes.emplace_back(std::move(shapes));
  }
  return success();
}

LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
  LinalgOp linalgOp = cast<LinalgOp>(op);
  // Expect at least one output operand.
  // This means an op that constructs a tensor out of indices cannot be a
  // LinalgOp at the moment. For now this will have to be a special op until we
  // have output shape operands that are not tensors.
  int64_t numInputs = linalgOp.getNumInputs();
  int64_t numOutputs = linalgOp.getNumOutputs();
  if (numOutputs == 0)
    return op->emitOpError("expected at least one output operand");
  if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs)))
    return failure();
  // Verify the number of results matches the number of output tensors.
  if (op->getNumResults() != linalgOp.getOutputTensorOperands().size())
    return op->emitOpError("expected the number of results (")
           << op->getNumResults()
           << ") to be equal to the number of output tensors ("
           << linalgOp.getOutputTensorOperands().size() << ")";

  // Check all iterator types are known.
  auto iteratorTypesRange =
      linalgOp.iterator_types().getAsValueRange<StringAttr>();
  for (StringRef iteratorType : iteratorTypesRange) {
    if (!llvm::is_contained(getAllIteratorTypeNames(), iteratorType))
      return op->emitOpError("unexpected iterator_type (")
             << iteratorType << ")";
  }

  // Before checking indexing maps, we need to make sure the attributes
  // referenced by them are valid.
  if (linalgOp.hasDynamicIndexingMaps())
    if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
      return failure();

  // All input/output operands must be indexed.
  if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
      linalgOp.getNumInputsAndOutputs())
    return op->emitOpError("expected the number of indexing_map (")
           << linalgOp.getIndexingMapsArray().size()
           << ") to be equal to the number of input/output operands ("
           << linalgOp.getNumInputsAndOutputs() << ")";

  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
    AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);

    // Symbols disallowed.
    if (indexingMap.getNumSymbols() != 0)
      return op->emitOpError("unexpected symbols in indexing_map #")
             << opOperand->getOperandNumber();

    // Domain must be consistent.
    unsigned numLoops = linalgOp.getNumLoops();
    if (indexingMap.getNumDims() != numLoops)
      return op->emitOpError("expected indexing_map #")
             << opOperand->getOperandNumber() << " to have " << numLoops
             << " dim(s) to match the number of loops";

    int64_t rank = linalgOp.getRank(opOperand);
    if (indexingMap.getNumResults() != rank)
      return op->emitOpError("expected operand rank (")
             << rank << ") to match the result rank of indexing_map #"
             << opOperand->getOperandNumber() << " ("
             << indexingMap.getNumResults() << ")";
  }

  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(redDims);

  // Simplifying assumption: either full tensor or full buffer mode.
  // This allows simpler verification of output operands vs result types
  // without premature tracking of which operand is what in mixed-mode.
  // TODO: relax when mixed-mode needs to pass verification.
  if (!linalgOp.getOutputBufferOperands().empty() &&
      !linalgOp.getOutputTensorOperands().empty())
    return op->emitOpError(
        "expected output operands to all have tensor type or "
        "all have buffer type");

  for (OpOperand *opOperand : linalgOp.getOutputTensorOperands()) {
    OpResult result = linalgOp.getTiedOpResult(opOperand);
    if (result.getType() != opOperand->get().getType())
      return op->emitOpError("expected type of operand #")
             << opOperand->getOperandNumber() << " ("
             << opOperand->get().getType() << ")"
             << " to match type of corresponding result (" << result.getType()
             << ")";
  }

  // Output tensor indexing maps may not depend on reduction indices.
  for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
    AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
    for (AffineExpr expr : indexingMap.getResults()) {
      for (unsigned pos : redDims) {
        if (expr.isFunctionOfDim(pos)) {
          std::string exprStr;
          {
            llvm::raw_string_ostream os(exprStr);
            os << expr;
          }
          return op->emitOpError(
                     "unexpected output tensor expression in indexing map #")
                 << (opOperand->getOperandNumber() - linalgOp.getNumInputs())
                 << " a.k.a '" << exprStr
                 << "' is function of reduction iterator 'd" << pos << "'";
        }
      }
    }
  }

  // Check the region has exactly one block.
  if (linalgOp->getNumRegions() != 1 ||
      !llvm::hasSingleElement(linalgOp->getRegion(0)))
    return op->emitOpError("expects to have 1 region with 1 block");

  if (!linalgOp.getShapesToLoopsMap())
    return op->emitOpError("expected the shape-to-loops map to be non-null");

  // Simplifying assumption: bbargs match 1-1 with the elemental types of the
  // shaped operands.
  // TODO: once ranked shape types are plugged in, we may want to drop the
  // corresponding bbargs, which can never be read from. This will be subject to
  // consistency discussions (i.e. what to do with output tensors whose bbarg is
  // not used).
  Block &block = linalgOp->getRegion(0).front();

  if (linalgOp.getNumInputsAndOutputs() != block.getNumArguments())
    return op->emitOpError("expected as many non-induction variable region "
                           "arguments as the number of input/output operands");

  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
    Type elementType = getElementTypeOrSelf(opOperand->get());
    Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
    if (elementType != argType)
      return op->emitOpError("expected type of bb argument #")
             << opOperand->getOperandNumber() << " (" << argType << ")"
             << " to match element or self type of the corresponding operand ("
             << elementType << ")";
  }

  // Check if given shapes match inferred shapes.
  SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
  SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);

  // Verify only static cases since we can't get exact dimension sizes and loop
  // ranges for dynamic cases at this stage.
  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
    for (int64_t &range : endLoopRangeValues)
      range -= 1;
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
      AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
      SmallVector<int64_t, 4> startIndices =
          indexingMap.compose(startLoopRangeValues);
      SmallVector<int64_t, 4> endIndices =
          indexingMap.compose(endLoopRangeValues);
      ArrayRef<int64_t> shape = linalgOp.getShape(opOperand);
      for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
        // Ignore dynamic dimensions and dimensions of size 0.
        if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
          continue;

        // Since each index range is monotonically increasing or decreasing,
        // the first and last indices of the range are its minimum and maximum
        // (in some order). Ideally, the size of each input/output operand
        // dimension would equal the maximum inferred index + 1; for
        // complicated affine expressions such as `d0 * 3 + d1` that is not
        // easy to verify exactly, so for now we only check that the inferred
        // range stays within the bounds of the operand's size. Note that
        // there are expressions this check cannot handle, e.g.
        // `(d0, d1) -> (d1 - d0)`.
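        // For example (a sketch): for `(d0, d1) -> (d0 * 3 + d1)` with static
        // loop bounds (2, 3), the inferred end index is 1 * 3 + 2 = 5, so any
        // shape of size >= 6 is accepted for that dimension.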
        int64_t inferredDimSize =
            std::max(startIndices[dim], endIndices[dim]) + 1;
        if (std::min(startIndices[dim], endIndices[dim]) < 0) {
          std::string mapStr;
          {
            llvm::raw_string_ostream os(mapStr);
            os << indexingMap;
          }
          return op->emitOpError(
                     "unexpected result less than 0 at expression #")
                 << dim << " in " << mapStr;
        }
        if (indexingMap.getResult(dim).dyn_cast<AffineDimExpr>()) {
          if (inferredDimSize != shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand->getOperandNumber()
                   << " has shape's dimension #" << dim << " to be "
                   << inferredDimSize << ", but found " << shape[dim];
          }
        } else {
          if (inferredDimSize > shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand->getOperandNumber()
                   << " has shape's dimension #" << dim
                   << " to be greater than or equal to " << inferredDimSize
                   << ", but found " << shape[dim];
          }
        }
      }
    }
  }

  return success();
}