//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/SmallSet.h"

using namespace mlir;
using namespace mlir::linalg;

/// Include the definitions of the copy operation interface.
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//

/// Return true if the use-def chain from `v` to `from` consists of 0 or more
/// unary single-operand operations.
// TODO: relax to multi-operand ops whose extra operands are constants; these
// are technically unary ops (e.g. add5).
static bool isChainOfUnaryOpsFrom(Value v, Value from) {
  while (true) {
    if (v == from)
      return true;
    Operation *op = v.getDefiningOp();
    if (!op || op->getNumOperands() != 1)
      return false;
    v = op->getOperand(0);
  }
}

/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than one instance exists.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res = nullptr;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}

/// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))`
/// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent
/// unary operations that may change the type.
template <typename AddOpType, typename MulOpType>
static bool isAddMul(Block &block) {
  if (block.getNumArguments() != 3)
    return false;
  Operation *yieldOp = block.getTerminator();
  if (yieldOp->getNumOperands() != 1)
    return false;

  AddOpType addOp = getSingleOpOfType<AddOpType>(block);
  MulOpType mulOp = getSingleOpOfType<MulOpType>(block);
  if (!addOp || !mulOp)
    return false;

  Value argA = block.getArgument(0), argB = block.getArgument(1);
  Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
  Value mul = mulOp->getResult(0);
  Value argC = block.getArgument(2);
  Value c1 = addOp->getOperand(0), c2 = addOp->getOperand(1);
  Value add = addOp->getResult(0);
  Value res = yieldOp->getOperand(0);
  // Result traces back to add.
  auto un = isChainOfUnaryOpsFrom;
  bool success = un(res, add);
  // One of the operands of add traces back to argC, the other to the mul.
  success &= (un(c1, argC) && un(c2, mul)) || (un(c1, mul) && un(c2, argC));
  // One of the operands of mul traces back to argA, the other to argB.
  success &= (un(a, argA) && un(b, argB)) || (un(a, argB) && un(b, argA));
  return success;
}

enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};

static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchContractionResult::NotLinalgOp;
  if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
    return MatchContractionResult::WrongNumOperands;
  auto mapRange = linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>();
  if (linalgOp.getNumReductionLoops() == 0)
    return MatchContractionResult::NoReduction;
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return MatchContractionResult::NotProjectedPermutations;
  // TODO: more fields than add/mul.
  if (!isAddMul<arith::AddFOp, arith::MulFOp>(linalgOp->getRegion(0).front()) &&
      !isAddMul<arith::AddIOp, arith::MulIOp>(linalgOp->getRegion(0).front()))
    return MatchContractionResult::NotAddMul;
  return MatchContractionResult::Success;
}

bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
  if (!linalgOp)
    return false;
  Operation *op = linalgOp.getOperation();
  return isa<ContractionOpInterface>(op) ||
         (isContractionInterfaceImpl(op) == MatchContractionResult::Success);
}
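
// For illustration, the region of a matmul-style `linalg.generic` matches the
// (AddFOp, MulFOp) case above. A minimal sketch of such a region (value names
// are illustrative only):
//
//   ^bb0(%a: f32, %b: f32, %c: f32):
//     %0 = arith.mulf %a, %b : f32
//     %1 = arith.addf %c, %0 : f32
//     linalg.yield %1 : f32
//
// Mixed-precision variants may interpose unary casts (e.g. arith.extf) on the
// chains from %a, %b and %c; isChainOfUnaryOpsFrom looks through such chains.
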
/// Verify that a LinalgOp `op` is a contraction.
/// A Linalg contraction is defined in general terms:
///   1. Has 2 input operands and 1 output operand.
///   2. Has at least one reduction dimension.
///   3. Has only projected permutation indexing maps.
///   4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
///      (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar
///      unary operations that may change the type (e.g. for mixed-precision).
/// As a consequence, when vectorization of such an op occurs, the only special
/// behavior is that the (unique) MulOpType is vectorized into a
/// `vector.contract`. All other ops are handled in a generic fashion.
/// In the future, we may wish to allow more input arguments and elementwise and
/// constant operations that do not involve the reduction dimension(s).
LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
  auto res = isContractionInterfaceImpl(op);
  if (res == MatchContractionResult::NotLinalgOp)
    return op->emitError("expected a LinalgOp");
  if (res == MatchContractionResult::WrongNumOperands)
    return op->emitError("expected op with 2 inputs and 1 output");
  if (res == MatchContractionResult::NoReduction)
    return op->emitError("expected at least one reduction loop");
  if (res == MatchContractionResult::NotProjectedPermutations)
    return op->emitError("expected all indexings to be projected permutations");
  if (res == MatchContractionResult::NotAddMul)
    return op->emitError("(add, mul) operations not found");
  return success();
}

//===----------------------------------------------------------------------===//
// ConvolutionOpInterface implementation
//===----------------------------------------------------------------------===//

/// Of the given two expressions, return the one that is of type `T` (`lhs`
/// takes precedence over `rhs`).
template <typename T>
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) {
  return lhs.isa<T>() ? lhs.cast<T>()
                      : (rhs.isa<T>() ? rhs.cast<T>() : nullptr);
}

namespace {
/// Walk the indexing expressions of the input of a convolution operation to
/// verify it is of the right form, i.e. either of:
///   - AffineDimExpr
///   - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
///     (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
///
/// The walker classifies each AffineDimExpr as a convolved or unconvolved
/// dimension and verifies that each dimension occurs only once.
struct ConvAccessExprWalker
    : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
  llvm::SmallDenseSet<unsigned> convolvedDims;
  llvm::SmallDenseSet<unsigned> unConvolvedDims;

  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
    unsigned position = dimExpr.getPosition();
    if (unConvolvedDims.count(position) || convolvedDims.count(position))
      return failure();
    unConvolvedDims.insert(position);
    return success();
  }

  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }

  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }

  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
    // In a pre-order visit, the top-level op has to be an add op.
    if (binaryExpr.getKind() != AffineExprKind::Add)
      return failure();
    return success(succeeded(isDimExprOrMulExpr(binaryExpr.getLHS())) &&
                   succeeded(isDimExprOrMulExpr(binaryExpr.getRHS())));
  }

  LogicalResult isDimExprOrMulExpr(AffineExpr expr) {
    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
      unsigned dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      convolvedDims.insert(dim);
      return success();
    }
    if (auto symbolMulExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
      if (symbolMulExpr.getKind() != AffineExprKind::Mul)
        return failure();
      auto lhsExpr = symbolMulExpr.getLHS();
      auto rhsExpr = symbolMulExpr.getRHS();
      // Check for a symbol expression.
      AffineExpr mulExpr =
          getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
      // If there was no symbol expression, check for a constant expression.
      if (!mulExpr)
        mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
      auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
      if (!mulExpr || !dimExpr)
        return failure();
      unsigned dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      convolvedDims.insert(dim);
      return success();
    }
    return failure();
  }
};
} // namespace

static llvm::SmallDenseSet<unsigned> getPreservedDims(AffineMap map) {
  assert(map.isProjectedPermutation() &&
         "expected map to have projected permutations");
  llvm::SmallDenseSet<unsigned> preservedDims;
  for (auto expr : map.getResults())
    preservedDims.insert(expr.cast<AffineDimExpr>().getPosition());
  return preservedDims;
}

enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction
};
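
// For intuition, consider a 2-D convolution in the style of
// linalg.conv_2d_nhwc_hwcf with unit strides and dilations (a sketch; the
// maps of the actual named op are parameterized by strides/dilations). With
// loops (n, oh, ow, f, kh, kw, c) = (d0, ..., d6), the indexing maps are:
//   input:  (d0, ..., d6) -> (d0, d1 + d4, d2 + d5, d6)
//   filter: (d0, ..., d6) -> (d4, d5, d6, d3)
//   output: (d0, ..., d6) -> (d0, d1, d2, d3)
// ConvAccessExprWalker classifies d1, d2, d4 and d5 as convolved and d0, d6
// as unconvolved; the checks below then classify d0 as a batch loop, d1/d2 as
// output image loops, d3 as an output channel loop, d4/d5 as filter loops and
// d6 as an input channel loop.
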
static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchConvolutionResult::NotLinalgOp;
  if (linalgOp.getNumInputs() < 2 || linalgOp.getNumOutputs() != 1)
    return MatchConvolutionResult::WrongNumOperands;

  auto indexingMaps = linalgOp.getIndexingMaps();

  // Check that the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  if (llvm::any_of(indexingMaps[0].getResults(),
                   [&inputExprWalker](AffineExpr expr) {
                     return failed(inputExprWalker.visit(expr));
                   }))
    return MatchConvolutionResult::WrongInputIndexingMap;

  // Filter and output maps must be projected permutations.
  if (!indexingMaps[1].isProjectedPermutation() ||
      !indexingMaps.back().isProjectedPermutation())
    return MatchConvolutionResult::NotProjectedPermutations;

  auto iteratorTypesRange =
      linalgOp.iterator_types().getAsValueRange<StringAttr>();

  llvm::SmallDenseSet<unsigned> outputDims =
      getPreservedDims(indexingMaps.back());
  llvm::SmallDenseSet<unsigned> filterDims = getPreservedDims(indexingMaps[1]);
  // Make sure all loops are characterized as one of:
  // - Batch loop: present in output, unconvolved in input, not present in
  //   filter.
  // - Output image dimension: present in output, convolved in input, not
  //   present in filter.
  // - Output channel dimension: present in output, not present in input,
  //   present in filter.
  // - Filter loop dimension: present in filter, convolved in input, not
  //   present in output.
  // - Input channel dimension: unconvolved in input, not present in output,
  //   present in filter.
  // - Depth multiplier: unconvolved in input, present in output, present in
  //   filter.
  llvm::SmallDenseSet<unsigned> allLoopDims;
  for (auto outputExpr : indexingMaps.back().getResults()) {
    unsigned outputDim = outputExpr.cast<AffineDimExpr>().getPosition();
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Batch dimension.
      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
          getParallelIteratorTypeName())
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.convolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Output image loop dimension.
      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
          getParallelIteratorTypeName())
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (!inputExprWalker.convolvedDims.count(outputDim) &&
        !inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Output channel dimension.
      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
          getParallelIteratorTypeName())
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Depth multiplier.
      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
          getParallelIteratorTypeName())
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  for (auto filterExpr : indexingMaps[1].getResults()) {
    unsigned filterDim = filterExpr.cast<AffineDimExpr>().getPosition();
    if (outputDims.count(filterDim) &&
        !inputExprWalker.unConvolvedDims.count(filterDim) &&
        !inputExprWalker.convolvedDims.count(filterDim)) {
      // Output channel dimension. This was already seen; continue.
      continue;
    }
    if (inputExprWalker.convolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Filter loop dimension.
      if (*std::next(iteratorTypesRange.begin(), filterDim) !=
          getReductionIteratorTypeName())
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Input channel dimension.
      if (*std::next(iteratorTypesRange.begin(), filterDim) !=
          getReductionIteratorTypeName())
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        outputDims.count(filterDim)) {
      // Depth multiplier dimension. This was already seen; continue.
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // All loops must be covered now.
  if (allLoopDims.size() != linalgOp.getNumLoops())
    return MatchConvolutionResult::NonConvolutionLoop;

  return MatchConvolutionResult::Success;
}

LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
  auto res = isConvolutionInterfaceImpl(op);
  if (res == MatchConvolutionResult::NotLinalgOp)
    return op->emitError("expected a LinalgOp");
  if (res == MatchConvolutionResult::WrongNumOperands)
    return op->emitError("expected op with 2 inputs and 1 output");
  if (res == MatchConvolutionResult::WrongInputIndexingMap)
    return op->emitError("unexpected input index map for convolutions");
  if (res == MatchConvolutionResult::NotProjectedPermutations)
    return op->emitError(
        "expected output/filter indexing maps to be projected permutations");
  if (res == MatchConvolutionResult::NonConvolutionLoop)
    return op->emitError("unexpected loop dimension for convolution op");
  if (res == MatchConvolutionResult::OutputDimsNotParallel)
    return op->emitError(
        "expected all iterators used to access outputs to be parallel");
  if (res == MatchConvolutionResult::NonOutputDimNotReduction)
    return op->emitError(
        "expected all iterators not used to access outputs to be reduction");
  return success();
}

//===----------------------------------------------------------------------===//
// StructuredOpInterface implementation
//===----------------------------------------------------------------------===//

OpOperandVector::operator SmallVector<Value>() {
  SmallVector<Value> result;
  result.reserve(this->size());
  llvm::transform(*this, std::back_inserter(result),
                  [](OpOperand *opOperand) { return opOperand->get(); });
  return result;
}
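
// For illustration, this conversion allows writing, e.g.:
//   SmallVector<Value> inputs = linalgOp.getInputOperands();
// (a hypothetical snippet, assuming an accessor returning an OpOperandVector).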

/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                               int64_t dim) {
  if (source.getType().isa<UnrankedMemRefType, MemRefType>())
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
                                                            Location loc) {
  SmallVector<Value, 4> res;
  for (OpOperand *opOperand : getInputAndOutputOperands()) {
    for (int64_t i = 0, e = getRank(opOperand); i < e; ++i)
      res.push_back(createOrFoldDimOp(b, loc, opOperand->get(), i));
  }
  return res;
}

SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
  SmallVector<int64_t, 4> res;
  assert(!hasDynamicShape() && "expected operands to have static shapes");
  for (OpOperand *opOperand : getInputAndOutputOperands())
    llvm::append_range(res, getShape(opOperand));
  return res;
}

SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  auto viewSizes = createFlatListOfOperandDims(b, loc);
  SmallVector<Range, 4> res(numDims);
  Value zeroVal = b.create<arith::ConstantIndexOp>(loc, 0);
  Value oneVal = b.create<arith::ConstantIndexOp>(loc, 1);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = result.dyn_cast<AffineDimExpr>()) {
      if (res[d.getPosition()].offset)
        continue;
      res[d.getPosition()] = Range{zeroVal, viewSizes[idx], oneVal};
    }
  }
  return res;
}

SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
  SmallVector<int64_t, 4> res(numDims, 0);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = result.dyn_cast<AffineDimExpr>())
      res[d.getPosition()] = allShapeSizes[idx];
  }
  return res;
}
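
// As a worked example (a hypothetical, fully static matmul-style op): for
// C<4x16> = A<4x8> * B<8x16> with loops (d0, d1, d2) = (i, j, k):
//   getLoopsToShapesMap()               = (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)
//   createFlatListOfOperandStaticDims() = [4, 8, 8, 16, 4, 16]
//   computeStaticLoopSizes()            = [4, 16, 8]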

/// Visitor to check if any of the given set of positions from AffineDimExprs
/// are used within an AffineExpr.
struct HasAffineDimExprVisitor
    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
  HasAffineDimExprVisitor(llvm::SmallSet<unsigned, 4> &positions)
      : positions(positions) {}

  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
    return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
  }

  bool visitDimExpr(AffineDimExpr dimExpr) {
    return positions.count(dimExpr.getPosition());
  }

  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }

  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }

private:
  llvm::SmallSet<unsigned, 4> positions;
};

LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
  // We want to express the shape of dim 0 of O in terms of the shapes of the
  // inputs. This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d2)
  //   resultShapesFromInputShapes = subMapOfResultShapes.compose(shapesToLoopsMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d3, d3)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap();

  // From loopsToShapesMap extract the submap that represents the shape of the
  // (resultIdx, dim) needed.
  SmallVector<unsigned, 4> resultPosRange =
      llvm::to_vector<4>(llvm::seq<unsigned>(resultShapesSubMapPos.first,
                                             resultShapesSubMapPos.second));
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSubMap(resultPosRange);
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs.
  llvm::SmallSet<unsigned, 4> outputDims;
  llvm::for_each(resultPosRange,
                 [&outputDims](unsigned dim) { outputDims.insert(dim); });
  HasAffineDimExprVisitor checkDimExpr(outputDims);
  Location loc = getOperation()->getLoc();
  auto allResultDimValues =
      applyMapToValues(b, loc, resultShapesFromInputShapesMap,
                       createFlatListOfOperandDims(b, loc));
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand *opOperand : getOutputOperands()) {
    SmallVector<Value> shapes;
    for (int64_t dim : llvm::seq<int64_t>(0, getRank(opOperand))) {
      if (checkDimExpr.visit(shapeExprs[pos]))
        shapes.push_back(createOrFoldDimOp(b, loc, opOperand->get(), dim));
      else
        shapes.push_back(allResultDimValues[pos]);
      pos++;
    }
    reifiedReturnShapes.emplace_back(std::move(shapes));
  }
  return success();
}

LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
  LinalgOp linalgOp = cast<LinalgOp>(op);
  // Expect at least one output operand.
  // This means an op that constructs a tensor out of indices cannot be a
  // LinalgOp at the moment. For now this will have to be a special op until we
  // have output shape operands that are not tensors.
  int64_t numInputs = linalgOp.getNumInputs();
  int64_t numOutputs = linalgOp.getNumOutputs();
  if (numOutputs == 0)
    return op->emitOpError("expected at least one output operand");
  if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs)))
    return failure();
  // Verify that the number of results matches the number of output tensors.
  if (op->getNumResults() != linalgOp.getOutputTensorOperands().size())
    return op->emitOpError("expected the number of results (")
           << op->getNumResults()
           << ") to be equal to the number of output tensors ("
           << linalgOp.getOutputTensorOperands().size() << ")";

  // Before checking indexing maps, we need to make sure the attributes
  // referenced by them are valid.
  if (linalgOp.hasDynamicIndexingMaps())
    if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
      return failure();

  // All input/output operands must be indexed.
  if (static_cast<int64_t>(linalgOp.indexing_maps().size()) !=
      linalgOp.getNumInputsAndOutputs())
    return op->emitOpError("expected the number of indexing_maps (")
           << linalgOp.indexing_maps().size()
           << ") to be equal to the number of input/output operands ("
           << linalgOp.getNumInputsAndOutputs() << ")";

  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
    AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);

    // Symbols are disallowed.
    if (indexingMap.getNumSymbols() != 0)
      return op->emitOpError("unexpected symbols in indexing_map #")
             << opOperand->getOperandNumber();

    // Domain must be consistent.
    unsigned numLoops = linalgOp.getNumLoops();
    if (indexingMap.getNumDims() != numLoops)
      return op->emitOpError("expected indexing_map #")
             << opOperand->getOperandNumber() << " to have " << numLoops
             << " dim(s) to match the number of loops";

    int64_t rank = linalgOp.getRank(opOperand);
    if (indexingMap.getNumResults() != rank)
      return op->emitOpError("expected operand rank (")
             << rank << ") to match the result rank of indexing_map #"
             << opOperand->getOperandNumber() << " ("
             << indexingMap.getNumResults() << ")";
  }

  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(redDims);

  // Simplifying assumption: either full tensor or full buffer mode.
  // This allows simpler verification of output operands vs result types
  // without premature tracking of which operand is what in mixed-mode.
  // TODO: relax when mixed-mode needs to pass verification.
  if (!linalgOp.getOutputBufferOperands().empty() &&
      !linalgOp.getOutputTensorOperands().empty())
    return op->emitOpError(
        "expected output operands to all have tensor type or "
        "all have buffer type");

  for (OpOperand *opOperand : linalgOp.getOutputTensorOperands()) {
    OpResult result = linalgOp.getTiedOpResult(opOperand);
    if (result.getType() != opOperand->get().getType())
      return op->emitOpError("expected type of operand #")
             << opOperand->getOperandNumber() << " ("
             << opOperand->get().getType() << ")"
             << " to match type of corresponding result (" << result.getType()
             << ")";
  }

  // Output tensor indexing maps may not depend on reduction indices.
  for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
    AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
    for (AffineExpr expr : indexingMap.getResults()) {
      for (unsigned pos : redDims) {
        if (expr.isFunctionOfDim(pos)) {
          std::string exprStr;
          {
            llvm::raw_string_ostream os(exprStr);
            os << expr;
          }
          return op->emitOpError(
                     "unexpected output tensor expression in indexing map #")
                 << (opOperand->getOperandNumber() - linalgOp.getNumInputs())
                 << " a.k.a '" << exprStr
                 << "' is function of reduction iterator 'd" << pos << "'";
        }
      }
    }
  }
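
  // As an example of the check above: with iterator types
  // ["parallel", "reduction"], an output indexing map (d0, d1) -> (d0, d1) is
  // rejected because d1 is a reduction iterator; a sum-reduction body instead
  // uses (d0, d1) -> (d0) for its output operand.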

  // Check that the region has exactly one block.
  if (linalgOp->getNumRegions() != 1 ||
      !llvm::hasSingleElement(linalgOp->getRegion(0)))
    return op->emitOpError("expects to have 1 region with 1 block");

  if (!linalgOp.getShapesToLoopsMap())
    return op->emitOpError("expected the shape-to-loops map to be non-null");

  // Simplifying assumption: bbargs match 1-1 with the elemental types of the
  // shaped operands.
  // TODO: once ranked shape types are plugged in, we may want to drop the
  // corresponding bbargs, which can never be read from. This will be subject
  // to consistency discussions (i.e. what to do with output tensors whose
  // bbarg is not used).
  Block &block = linalgOp->getRegion(0).front();

  if (linalgOp.getNumInputsAndOutputs() != block.getNumArguments())
    return op->emitOpError("expected as many non-induction variable region "
                           "arguments as the number of input/output operands");

  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
    Type elementType = getElementTypeOrSelf(opOperand->get());
    Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
    if (elementType != argType)
      return op->emitOpError("expected type of bb argument #")
             << opOperand->getOperandNumber() << " (" << argType << ")"
             << " to match element or self type of the corresponding operand ("
             << elementType << ")";
  }

  // Check if the given shapes match the inferred shapes.
  Optional<SmallVector<int64_t, 4>> endLoopRangeValues =
      linalgOp.getStaticLoopRanges();
  if (!endLoopRangeValues)
    return op->emitOpError("unable to find loop range for operation");
  SmallVector<int64_t, 4> startLoopRangeValues((*endLoopRangeValues).size(), 0);

  // Verify only static cases since we can't get exact dimension sizes and loop
  // ranges for dynamic cases at this stage.
  if (llvm::none_of(*endLoopRangeValues, ShapedType::isDynamic)) {
    for (int64_t &range : *endLoopRangeValues)
      range -= 1;
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
      AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
      SmallVector<int64_t, 4> startIndices =
          indexingMap.compose(startLoopRangeValues);
      SmallVector<int64_t, 4> endIndices =
          indexingMap.compose(*endLoopRangeValues);
      ArrayRef<int64_t> shape = linalgOp.getShape(opOperand);
      for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
        // Ignore dynamic dimensions and dimensions of size 0.
        if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
          continue;

        // Since the indices in the inferred range increase or decrease
        // monotonically, either the first or the last index is the extremum.
        // Ideally, the static size of each input/output dimension would equal
        // the maximum inferred index value plus one. However, for complicated
        // affine expressions such as d0 * 3 + d1, we only verify that the
        // inferred range stays within the bounds of the operand's size.
        // There are cases this check cannot catch, e.g. the map
        // (d0, d1) -> (d1 - d0).
        int64_t inferredDimSize =
            std::max(startIndices[dim], endIndices[dim]) + 1;
        if (std::min(startIndices[dim], endIndices[dim]) < 0) {
          std::string mapStr;
          {
            llvm::raw_string_ostream os(mapStr);
            os << indexingMap;
          }
          return op->emitOpError(
                     "unexpected result less than 0 at expression #")
                 << dim << " in " << mapStr;
        }
        if (indexingMap.getResult(dim).dyn_cast<AffineDimExpr>()) {
          if (inferredDimSize != shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand->getOperandNumber()
                   << " has shape's dimension #" << dim << " to be "
                   << inferredDimSize << ", but found " << shape[dim];
          }
        } else {
          if (inferredDimSize > shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand->getOperandNumber()
                   << " has shape's dimension #" << dim
                   << " to be greater than or equal to " << inferredDimSize
                   << ", but found " << shape[dim];
          }
        }
      }
    }
  }

  return success();
}