1 //===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" 10 11 #include "mlir/Dialect/Affine/IR/AffineOps.h" 12 #include "mlir/Dialect/MemRef/IR/MemRef.h" 13 #include "mlir/IR/AffineExprVisitor.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/TypeUtilities.h" 16 #include "llvm/ADT/SmallSet.h" 17 18 using namespace mlir; 19 using namespace mlir::linalg; 20 21 /// Include the definitions of the copy operation interface. 22 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc" 23 24 //===----------------------------------------------------------------------===// 25 // ContractionOpInterface implementation 26 //===----------------------------------------------------------------------===// 27 28 /// Return true if the use-def chain from `v` to `from` consists of 0 or more 29 /// unary single-operand operations. 30 // TODO: relax to multi-operands with constants, which are technically unary ops 31 // as needed (e.g. add5). 32 static bool isChainOfUnaryOpsFrom(Value v, Value from) { 33 while (true) { 34 if (v == from) 35 return true; 36 Operation *op = v.getDefiningOp(); 37 if (!op || op->getNumOperands() != 1) 38 return false; 39 v = op->getOperand(0); 40 }; 41 } 42 43 /// Return the unique instance of OpType in `block` if it is indeed unique. 44 /// Return null if none or more than 1 instances exist. 45 template <typename OpType> 46 static OpType getSingleOpOfType(Block &block) { 47 OpType res = nullptr; 48 block.walk([&](OpType op) { 49 if (res) { 50 res = nullptr; 51 return WalkResult::interrupt(); 52 } 53 res = op; 54 return WalkResult::advance(); 55 }); 56 return res; 57 } 58 59 /// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))` 60 /// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent 61 /// unary operations that may change the type. 62 template <typename AddOpType, typename MulOpType> 63 static bool isAddMul(Block &block) { 64 if (block.getNumArguments() != 3) 65 return false; 66 Operation *yieldOp = block.getTerminator(); 67 if (yieldOp->getNumOperands() != 1) 68 return false; 69 70 AddOpType addOp = getSingleOpOfType<AddOpType>(block); 71 MulOpType mulOp = getSingleOpOfType<MulOpType>(block); 72 if (!addOp || !mulOp) 73 return false; 74 75 Value argA = block.getArgument(0), argB = block.getArgument(1); 76 Value a = mulOp->getOperand(0), b = mulOp->getOperand(1); 77 Value mul = mulOp->getResult(0); 78 Value argC = block.getArgument(2); 79 Value c1 = addOp->getOperand(0), c2 = addOp->getOperand(1); 80 Value add = addOp->getResult(0); 81 Value res = yieldOp->getOperand(0); 82 // Result traces back to add. 83 auto un = isChainOfUnaryOpsFrom; 84 bool success = un(res, add); 85 // One of the operands of add traces back to argC, the other to the mul. 86 success |= (un(c1, argC) && un(c2, mul)) || ((un(c1, mul)) && un(c2, argC)); 87 // One of the operands of mul traces back to argA, the other to argB. 88 success |= (un(a, argA) && un(b, argB)) || ((un(a, argB)) && un(b, argA)); 89 return success; 90 } 91 92 enum MatchContractionResult { 93 Success = 0, 94 NotLinalgOp, 95 WrongNumOperands, 96 NoReduction, 97 NotProjectedPermutations, 98 NotAddMul 99 }; 100 static MatchContractionResult isContractionInterfaceImpl(Operation *op) { 101 auto linalgOp = dyn_cast<linalg::LinalgOp>(op); 102 if (!linalgOp) 103 return MatchContractionResult::NotLinalgOp; 104 if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1) 105 return MatchContractionResult::WrongNumOperands; 106 auto mapRange = linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>(); 107 if (linalgOp.getNumReductionLoops() == 0) 108 return MatchContractionResult::NoReduction; 109 if (llvm::any_of(mapRange, 110 [](AffineMap m) { return !m.isProjectedPermutation(); })) 111 return MatchContractionResult::NotProjectedPermutations; 112 // TODO: more fields than add/mul. 113 if (!isAddMul<AddFOp, MulFOp>(linalgOp->getRegion(0).front()) && 114 !isAddMul<AddIOp, MulIOp>(linalgOp->getRegion(0).front())) 115 return MatchContractionResult::NotAddMul; 116 return MatchContractionResult::Success; 117 } 118 119 bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) { 120 if (!linalgOp) 121 return false; 122 Operation *op = linalgOp.getOperation(); 123 return isa<ContractionOpInterface>(op) || 124 (isContractionInterfaceImpl(op) == MatchContractionResult::Success); 125 } 126 127 /// Verify that a LinalgOp `op` is a contraction. 128 /// A Linalg contraction is defined in general terms: 129 /// 1. Has 2 input and 1 output shapes. 130 /// 2. Has at least one reduction dimension. 131 /// 3. Has only projected permutation indexing maps. 132 /// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field 133 /// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary 134 /// operations that may change the type (e.g. for mixed-precision). 135 /// As a consequence, when vectorization of such an op occurs, the only special 136 /// behavior is that the (unique) MulOpType is vectorized into a 137 /// `vector.contract`. All other ops are handled in a generic fashion. 138 /// In the future, we may wish to allow more input arguments and elementwise and 139 /// constant operations that do not involve the reduction dimension(s). 140 LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) { 141 auto res = isContractionInterfaceImpl(op); 142 if (res == MatchContractionResult::NotLinalgOp) 143 return op->emitError("expected a LinalgOp"); 144 if (res == MatchContractionResult::WrongNumOperands) 145 return op->emitError("expected op with 2 inputs and 1 outputs"); 146 if (res == MatchContractionResult::NoReduction) 147 return op->emitError("expected at least a reduction loop"); 148 if (res == MatchContractionResult::NotProjectedPermutations) 149 return op->emitError("expected all indexings to be projected permutations"); 150 if (res == MatchContractionResult::NotAddMul) 151 return op->emitError("(add, mul) operations not found"); 152 return success(); 153 } 154 155 //===----------------------------------------------------------------------===// 156 // StructuredOpInterface implementation 157 //===----------------------------------------------------------------------===// 158 159 OpOperandVector::operator SmallVector<Value>() { 160 SmallVector<Value> result; 161 result.reserve(this->size()); 162 llvm::transform(*this, std::back_inserter(result), 163 [](OpOperand *opOperand) { return opOperand->get(); }); 164 return result; 165 } 166 167 /// Fully compose map with operands and canonicalize the result. 168 /// Return the `createOrFold`'ed AffineApply op. 169 static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc, 170 AffineMap map, 171 ValueRange operandsRef) { 172 SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end()); 173 fullyComposeAffineMapAndOperands(&map, &operands); 174 canonicalizeMapAndOperands(&map, &operands); 175 return b.createOrFold<AffineApplyOp>(loc, map, operands); 176 } 177 178 SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc, 179 AffineMap map, 180 ValueRange values) { 181 SmallVector<Value, 4> res; 182 res.reserve(map.getNumResults()); 183 unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols(); 184 // For each `expr` in `map`, applies the `expr` to the values extracted from 185 // ranges. If the resulting application can be folded into a Value, the 186 // folding occurs eagerly. 187 for (auto expr : map.getResults()) { 188 AffineMap map = AffineMap::get(numDims, numSym, expr); 189 res.push_back(createFoldedComposedAffineApply(b, loc, map, values)); 190 } 191 return res; 192 } 193 194 SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b, 195 Location loc) { 196 SmallVector<Value, 4> res; 197 for (OpOperand *opOperand : getInputAndOutputOperands()) { 198 for (int64_t i = 0, e = getRank(opOperand); i < e; ++i) 199 res.push_back(b.createOrFold<memref::DimOp>(loc, opOperand->get(), i)); 200 } 201 return res; 202 } 203 204 SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() { 205 SmallVector<int64_t, 4> res; 206 assert(!hasDynamicShape() && "expected operands to have static shapes"); 207 for (OpOperand *opOperand : getInputAndOutputOperands()) 208 llvm::append_range(res, getShape(opOperand)); 209 return res; 210 } 211 212 SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) { 213 AffineMap map = getLoopsToShapesMap(); 214 unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); 215 auto viewSizes = createFlatListOfOperandDims(b, loc); 216 SmallVector<Range, 4> res(numDims); 217 Value zeroVal = b.create<ConstantIndexOp>(loc, 0); 218 Value oneVal = b.create<ConstantIndexOp>(loc, 1); 219 for (unsigned idx = 0; idx < numRes; ++idx) { 220 auto result = map.getResult(idx); 221 if (auto d = result.dyn_cast<AffineDimExpr>()) { 222 if (res[d.getPosition()].offset) 223 continue; 224 res[d.getPosition()] = Range{zeroVal, viewSizes[idx], oneVal}; 225 } 226 } 227 return res; 228 } 229 230 SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() { 231 AffineMap map = getLoopsToShapesMap(); 232 unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); 233 SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims(); 234 SmallVector<int64_t, 4> res(numDims, 0); 235 for (unsigned idx = 0; idx < numRes; ++idx) { 236 auto result = map.getResult(idx); 237 if (auto d = result.dyn_cast<AffineDimExpr>()) 238 res[d.getPosition()] = allShapeSizes[idx]; 239 } 240 return res; 241 } 242 243 /// Visitor to check if any of the given set of positions from AffineDimExprs 244 /// are used within an AffineExpr. 245 struct HasAffineDimExprVisitor 246 : public AffineExprVisitor<HasAffineDimExprVisitor, bool> { 247 HasAffineDimExprVisitor(llvm::SmallSet<unsigned, 4> &positions) 248 : positions(positions) {} 249 250 bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) { 251 return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS()); 252 } 253 254 bool visitDimExpr(AffineDimExpr dimExpr) { 255 return positions.count(dimExpr.getPosition()); 256 } 257 258 bool visitConstantExpr(AffineConstantExpr constExpr) { return false; } 259 260 bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; } 261 262 private: 263 llvm::SmallSet<unsigned, 4> positions; 264 }; 265 266 LogicalResult LinalgOp::reifyReturnTypeShapesPerResultDim( 267 OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) { 268 // An example that helps understand the logic below. 269 // Consider the following expression O(i+j, j) += A(i,k) * B(k, j) 270 // We want to express the shape of dim 0 of O in terms of shape of the inputs. 271 // This is achieved as follows. 272 // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1) 273 // subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1) 274 // shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2) 275 // resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap) 276 // = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1) 277 AffineMap loopsToShapesMap = getLoopsToShapesMap(); 278 279 // Find the position in the above map that represents the shape of the 280 // result:dim being inferred. 281 auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(); 282 283 /// From loopsToShapesMap extract the submap that represents the shape of the 284 /// (resultIdx, dim) needed. 285 SmallVector<unsigned, 4> resultPosRange = 286 llvm::to_vector<4>(llvm::seq<unsigned>(resultShapesSubMapPos.first, 287 resultShapesSubMapPos.second)); 288 AffineMap loopToResultsShapeMap = loopsToShapesMap.getSubMap(resultPosRange); 289 AffineMap resultShapesFromInputShapesMap = 290 loopToResultsShapeMap.compose(getShapesToLoopsMap()); 291 292 // Check that the result dim map does not contain the positions corresponding 293 // to the outputs. 294 llvm::SmallSet<unsigned, 4> outputDims; 295 llvm::for_each(resultPosRange, 296 [&outputDims](unsigned dim) { outputDims.insert(dim); }); 297 HasAffineDimExprVisitor checkDimExpr(outputDims); 298 Location loc = getOperation()->getLoc(); 299 auto allResultDimValues = 300 applyMapToValues(b, loc, resultShapesFromInputShapesMap, 301 createFlatListOfOperandDims(b, loc)); 302 int64_t pos = 0; 303 ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults(); 304 for (OpOperand *opOperand : getOutputOperands()) { 305 SmallVector<Value> shapes; 306 for (int64_t dim : llvm::seq<int64_t>(0, getRank(opOperand))) { 307 if (checkDimExpr.visit(shapeExprs[pos])) 308 shapes.push_back( 309 b.createOrFold<memref::DimOp>(loc, opOperand->get(), dim)); 310 else 311 shapes.push_back(allResultDimValues[pos]); 312 pos++; 313 } 314 reifiedReturnShapes.emplace_back(std::move(shapes)); 315 } 316 return success(); 317 } 318 319 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { 320 LinalgOp linalgOp = cast<LinalgOp>(op); 321 // Expect at least one output operand. 322 // This means an op that constructs a tensor out of indices cannot be a 323 // LinalgOp at the moment. For now this will have to be a special op until we 324 // have output shape operands that are not tensors. 325 int64_t numInputs = linalgOp.getNumInputs(); 326 int64_t numOutputs = linalgOp.getNumOutputs(); 327 if (numOutputs == 0) 328 return op->emitOpError("expected at least one output operand"); 329 if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs))) 330 return failure(); 331 // Should have at least one output tensor per result tensor. 332 // Can also have outbut buffers that do not correspond to results. 333 if (op->getNumResults() > linalgOp.getOutputTensorOperands().size()) 334 return op->emitOpError("unexpected #results > #outputs"); 335 336 // Before checking indexing maps, we need to make sure the attributes 337 // referenced by it are valid. 338 if (linalgOp.hasDynamicIndexingMaps()) 339 if (failed(linalgOp.verifyIndexingMapRequiredAttributes())) 340 return failure(); 341 342 // All input/output operands must be indexed. 343 if (static_cast<int64_t>(linalgOp.indexing_maps().size()) != 344 linalgOp.getNumInputsAndOutputs()) 345 return op->emitOpError("expected the number of indexing_map (") 346 << linalgOp.indexing_maps().size() 347 << ") to be equal to the number of input/output operands (" 348 << linalgOp.getNumInputsAndOutputs() << ")"; 349 350 for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { 351 AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand); 352 353 // Symbols disallowed. 354 if (indexingMap.getNumSymbols() != 0) 355 return op->emitOpError("unexpected symbols in indexing_map #") 356 << opOperand->getOperandNumber(); 357 358 // Domain must be consistent. 359 unsigned numLoops = linalgOp.getNumLoops(); 360 if (indexingMap.getNumDims() != numLoops) 361 return op->emitOpError("expected indexing_map #") 362 << opOperand->getOperandNumber() << " to have " << numLoops 363 << " dim(s) to match the number of loops"; 364 365 int64_t rank = linalgOp.getRank(opOperand); 366 if (indexingMap.getNumResults() != rank) 367 return op->emitOpError("expected operand rank (") 368 << rank << ") to match the result rank of indexing_map #" 369 << opOperand->getOperandNumber() << " (" 370 << indexingMap.getNumResults() << ")"; 371 } 372 373 SmallVector<AffineExpr> redDims; 374 linalgOp.getReductionDims(redDims); 375 376 // Simplifying assumption: either full tensor or full buffer mode. 377 // This allows simpler verification of output operands vs result types 378 // without premature tracking of which operand is what in mixed-mode. 379 // TODO: relax when mixed-mode needs to pass verification. 380 if (!linalgOp.getOutputBufferOperands().empty() && 381 !linalgOp.getOutputTensorOperands().empty()) 382 return op->emitOpError( 383 "expected output operands to all have tensor type or " 384 "all have buffer type"); 385 386 for (OpOperand *opOperand : linalgOp.getOutputTensorOperands()) { 387 // TODO: Enforce one output tensor per result? 388 if (opOperand->getOperandNumber() - linalgOp.getNumInputs() >= 389 linalgOp->getNumResults()) 390 continue; 391 OpResult result = linalgOp.getTiedOpResult(opOperand); 392 if (result.getType() != opOperand->get().getType()) 393 return op->emitOpError("expected type of operand #") 394 << opOperand->getOperandNumber() << " (" 395 << opOperand->get().getType() << ")" 396 << " to match type of corresponding result (" << result.getType() 397 << ")"; 398 } 399 400 // Output tensor indexing map may not depend on reduction indices. 401 for (OpOperand *opOperand : linalgOp.getOutputOperands()) { 402 AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand); 403 for (auto expr : indexingMap.getResults()) { 404 for (auto dim : redDims) { 405 unsigned pos = dim.cast<AffineDimExpr>().getPosition(); 406 if (expr.isFunctionOfDim(pos)) { 407 std::string exprStr; 408 { 409 llvm::raw_string_ostream os(exprStr); 410 os << expr; 411 } 412 return op->emitOpError( 413 "unexpected output tensor expression in indexing map #") 414 << (opOperand->getOperandNumber() - linalgOp.getNumInputs()) 415 << " a.k.a '" << exprStr 416 << "' is function of reduction iterator 'd" << pos << "'"; 417 } 418 } 419 } 420 } 421 422 // Named ops that are defined manually have a region builder but no region at 423 // this time. Assume the region is well-formed by specification. 424 // TODO: use linalg-ods-gen for all ops when we have enough expressive power. 425 if (linalgOp->getNumRegions() == 0) { 426 assert(!linalgOp.getRegionBuilder() && "regionBuilder but no region"); 427 return success(); 428 } 429 430 auto ®ion = linalgOp->getRegion(0); 431 if (linalgOp->getNumRegions() > 1 || !llvm::hasSingleElement(region)) 432 return op->emitOpError("expected 1 region with 1 block"); 433 434 if (!linalgOp.getShapesToLoopsMap()) 435 return op->emitOpError("expected the shape-to-loops map to be non-null"); 436 437 // Simplifying assumption: bbargs match 1-1 with shape operands elemental 438 // types. 439 // TODO: once ranked shape types are plugged in, we may want to drop the 440 // corresponding bbargs, that can never be read from. This will be subject to 441 // consistency discussions (i.e. what to do with output tensors whose bbarg is 442 // not used). 443 Block &block = linalgOp->getRegion(0).front(); 444 445 if (linalgOp.getNumInputsAndOutputs() != block.getNumArguments()) 446 return op->emitOpError("expected as many non-induction variable region " 447 "arguments as the number of input/output operands"); 448 449 for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { 450 Type elementType = getElementTypeOrSelf(opOperand->get()); 451 Type argType = block.getArgument(opOperand->getOperandNumber()).getType(); 452 if (elementType != argType) 453 return op->emitOpError("expected type of bb argument #") 454 << opOperand->getOperandNumber() << " (" << argType << ")" 455 << " to match element or self type of the corresponding operand (" 456 << elementType << ")"; 457 } 458 459 // Check if given shapes match to inferred shapes. 460 Optional<SmallVector<int64_t, 4>> endLoopRangeValues = 461 linalgOp.getStaticLoopRanges(); 462 if (!endLoopRangeValues) 463 return op->emitOpError("unable to find loop range for operation"); 464 SmallVector<int64_t, 4> startLoopRangeValues((*endLoopRangeValues).size(), 0); 465 466 // Verify only static cases since we can't get exact dimension sizes and loop 467 // ranges for dynamic cases in this stage. 468 if (llvm::none_of(*endLoopRangeValues, ShapedType::isDynamic)) { 469 for (int64_t &range : *endLoopRangeValues) 470 range -= 1; 471 for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { 472 AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand); 473 SmallVector<int64_t, 4> startIndices = 474 indexingMap.compose(startLoopRangeValues); 475 SmallVector<int64_t, 4> endIndices = 476 indexingMap.compose(*endLoopRangeValues); 477 ArrayRef<int64_t> shape = linalgOp.getShape(opOperand); 478 for (auto dim : llvm::seq<int64_t>(0, shape.size())) { 479 // Ignore dynamic dimension or the case that the dimension size is 0 480 if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0) 481 continue; 482 483 // The first index or last index should be the maximum or the minimum in 484 // the inferred index ranges since the range is increasing or 485 // decreasing. The size of dimensions of input/output operands and the 486 // maximum value + 1 in the inferred range should be the same. But, for 487 // now we check if the inferred ranges are in boundary of input/output 488 // operands' size or not in case that Affine Expressions are complicated 489 // such as d0 * 3 490 // + d1 since it is not easy to handle the issues. 491 // Found the case that this solution can't check, for example, (d0, d1) 492 // -> (d1 - d0) 493 int64_t inferredDimSize = 494 std::max(startIndices[dim], endIndices[dim]) + 1; 495 if (std::min(startIndices[dim], endIndices[dim]) < 0) { 496 std::string mapStr; 497 { 498 llvm::raw_string_ostream os(mapStr); 499 os << indexingMap; 500 } 501 return op->emitOpError( 502 "unexpected result less than 0 at expression #") 503 << dim << " in " << mapStr; 504 } 505 if (indexingMap.getResult(dim).dyn_cast<AffineDimExpr>()) { 506 if (inferredDimSize != shape[dim]) { 507 return op->emitOpError("inferred input/output operand #") 508 << opOperand->getOperandNumber() 509 << " has shape's dimension #" << dim << " to be " 510 << inferredDimSize << ", but found " << shape[dim]; 511 } 512 } else { 513 if (inferredDimSize > shape[dim]) { 514 return op->emitOpError("inferred input/output operand #") 515 << opOperand->getOperandNumber() 516 << " has shape's dimension #" << dim 517 << " to be greater than or equal to " << inferredDimSize 518 << ", but found " << shape[dim]; 519 } 520 } 521 } 522 } 523 } 524 525 return success(); 526 } 527