//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Utils/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-utils"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::scf;

static bool isZero(Value v) {
  if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
    return cst.value() == 0;
  return false;
}

namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces that only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ValueRange tileSizes) : isTiled(false), tileSizes(tileSizes) {}

  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZero(tileSizes[expr.getPosition()]);
  }
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr.getLHS());
    visit(expr.getRHS());
    if (expr.getKind() == mlir::AffineExprKind::Mul)
      assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
             "nonpositive multiplying coefficient");
  }
  bool isTiled;
  ValueRange tileSizes;
};

} // namespace

static bool isTiled(AffineExpr expr, ValueRange tileSizes) {
  if (!expr)
    return false;
  TileCheck t(tileSizes);
  t.visit(expr);
  return t.isTiled;
}

// Checks whether the `map` varies with respect to a non-zero `tileSize`.
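// Delegates to the AffineExpr overload above for every result of the map.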
static bool isTiled(AffineMap map, ValueRange tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}

Optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.region();
  if (!llvm::hasSingleElement(region))
    return llvm::None;

  Block &block = region.front();
  if (block.getNumArguments() != 2 ||
      !block.getArgument(0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(1).getType().isSignlessIntOrFloat())
    return llvm::None;

  auto &ops = block.getOperations();
  if (!llvm::hasSingleElement(block.without_terminator()))
    return llvm::None;

  using mlir::matchers::m_Val;
  auto a = m_Val(block.getArgument(0));
  auto b = m_Val(block.getArgument(1));

  auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
  if (addPattern.match(&ops.back()))
    return BinaryOpKind::IAdd;

  return llvm::None;
}

/// Explicit instantiation of loop nest generator for different loop types.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;
template struct mlir::linalg::GenerateLoopNest<TiledLoopOp>;

/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
                         SmallVectorImpl<Value> &ubs,
                         SmallVectorImpl<Value> &steps) {
  for (Range range : ranges) {
    lbs.emplace_back(range.offset);
    ubs.emplace_back(range.size);
    steps.emplace_back(range.stride);
  }
}

namespace mlir {
namespace linalg {

bool isPermutation(ArrayRef<int64_t> permutation) {
  // Count the number of appearances for all indices.
  SmallVector<int64_t> indexCounts(permutation.size(), 0);
  for (auto index : permutation) {
    // Exit if the index is out-of-range.
    if (index < 0 || index >= static_cast<int64_t>(permutation.size()))
      return false;
    indexCounts[index]++;
  }
  // Return true if all indices appear once.
  return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
}

/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {
  if (source.getType().isa<UnrankedMemRefType, MemRefType>())
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

/// Given an operation, retrieves the value of each dynamic dimension through
/// constructing the necessary DimOp operators.
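/// Static dimensions are skipped; only ShapedType::kDynamicSize entries yield
/// a DimOp.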
SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b) {
  SmallVector<Value, 4> dynOperands;
  auto shapedType = val.getType().cast<ShapedType>();
  for (const auto &dim : llvm::enumerate(shapedType.getShape())) {
    if (dim.value() == ShapedType::kDynamicSize)
      dynOperands.push_back(createOrFoldDimOp(b, loc, val, dim.index()));
  }
  return dynOperands;
}

void getUpperBoundForIndex(Value value, AffineMap &boundMap,
                           SmallVectorImpl<Value> &boundOperands) {
  // Initialize `boundMap` and `boundOperands` to the identity returning
  // `value`. This combination is the default result of the method if no
  // simplification is possible.
  assert(value.getType().isIndex() && "expect value to have index type");
  boundMap = AffineMap::getMultiDimIdentityMap(1, value.getContext());
  boundOperands.assign({value});
  canonicalizeMapAndOperands(&boundMap, &boundOperands);

  // Continue only if there is an affine index computation to simplify.
  Operation *definingOp = value.getDefiningOp();
  if (!definingOp || !isa<AffineApplyOp, AffineMinOp>(definingOp))
    return;

  // Get the backward slice containing the affine index computation.
  SetVector<Operation *> backwardSlice;
  getBackwardSlice(definingOp, &backwardSlice, [](Operation *op) {
    return isa<AffineApplyOp, AffineMinOp>(op);
  });
  backwardSlice.insert(definingOp);

  // Setup a system of affine constraints that describe the index computation.
  FlatAffineValueConstraints constraints;

  // Helper to find or create an identifier for the given value.
  auto findOrCreateId = [&](Value value) {
    if (!constraints.containsId(value)) {
      constraints.appendDimId(value);
      return true;
    }
    unsigned pos;
    constraints.findId(value, &pos);
    return pos < constraints.getNumDimIds();
  };
  // Helper to get the position for the given value.
  auto getPosition = [&](Value value) {
    unsigned pos;
    bool exists = constraints.findId(value, &pos);
    (void)exists;
    assert(exists && "expect to find the identifier");
    return pos;
  };

  // Add the affine operations in `backwardSlice` to the constraints.
  for (Operation *op : llvm::reverse(backwardSlice)) {
    // Add an identifier for all op results and operands.
    if (!(llvm::all_of(op->getResults(), findOrCreateId) &&
          llvm::all_of(op->getOperands(), findOrCreateId)))
      return;
    // Add AffineApplyOps to the constraints.
    if (auto applyOp = dyn_cast<AffineApplyOp>(op)) {
      AffineValueMap valueMap(applyOp.getAffineMap(), applyOp.getOperands(),
                              applyOp.getResult());
      if (failed(constraints.composeMap(&valueMap)))
        return;
      continue;
    }
    // Add AffineMinOps to the constraints.
    auto minOp = cast<AffineMinOp>(op);
    AffineMap map = constraints.computeAlignedMap(minOp.getAffineMap(),
                                                  minOp.getOperands());
    if (failed(constraints.addBound(FlatAffineConstraints::UB,
                                    getPosition(minOp.getResult()), map)))
      return;
  }

  // Obtain an upper bound for the affine index computation by projecting out
  // all temporary results and expressing the upper bound for `value` in terms
  // of the terminals of the index computation.
  SmallVector<AffineMap> lowerBounds(1), upperBounds(1);
  constraints.getSliceBounds(getPosition(value), 1, value.getContext(),
                             &lowerBounds, &upperBounds);

  // Verify `upperBounds[0]` is valid and has at least one result.
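  // Otherwise, keep the identity bound initialized at the top of the function.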
  if (!upperBounds[0] || upperBounds[0].getNumResults() == 0)
    return;

  // Set `boundMap` and `boundOperands` to the computed upper bound.
  boundMap = upperBounds[0];
  constraints.getAllValues(&boundOperands);
  erase_value(boundOperands, value);
  canonicalizeMapAndOperands(&boundMap, &boundOperands);
}

FailureOr<int64_t> getConstantUpperBoundForIndex(Value value) {
  // Compute an upper bound for `value`.
  AffineMap boundMap;
  SmallVector<Value> boundOperands;
  getUpperBoundForIndex(value, boundMap, boundOperands);

  // Search the results of `boundMap` for constant upper bounds.
  SmallVector<int64_t> constantBounds;
  for (AffineExpr result : boundMap.getResults())
    if (auto constExpr = result.dyn_cast<AffineConstantExpr>())
      constantBounds.push_back(constExpr.getValue());

  // Return the minimal upper bound or failure if none is found.
  if (constantBounds.empty())
    return failure();
  return *std::min_element(constantBounds.begin(), constantBounds.end());
}

tensor::ExtractSliceOp makeComposedExtractSliceOp(
    OpBuilder &b, Location loc, Value source, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
  assert(source && "expect source to be nonzero");

  // Do not fold if the producer is not an ExtractSliceOp.
  auto producerOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!producerOp)
    return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                            strides);

  // Do not fold if the producer is rank reducing or if there are any non-unit
  // strides. Supporting non-unit strides complicates the offset computation
  // since the consumer offsets need to be multiplied by the producer strides.
  // TODO: support non-unit strides once there are use cases.
  SmallVector<OpFoldResult> allStrides = producerOp.getMixedStrides();
  allStrides.append(strides.begin(), strides.end());
  bool hasNonUnitStride = any_of(allStrides, [](OpFoldResult ofr) {
    return getConstantIntValue(ofr) != static_cast<int64_t>(1);
  });
  if (hasNonUnitStride ||
      producerOp.getSourceType().getRank() !=
          producerOp.getResult().getType().cast<ShapedType>().getRank())
    return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                            strides);

  // Fold the producer by adding the offsets and extracting the slice directly
  // from the producer source tensor.
  SmallVector<OpFoldResult> foldedOffsets(offsets.begin(), offsets.end());
  AffineExpr dim1, dim2;
  bindDims(b.getContext(), dim1, dim2);
  for (const auto &en : enumerate(producerOp.getMixedOffsets())) {
    SmallVector<Value> offsetValues = {
        getValueOrCreateConstantIndexOp(b, loc, foldedOffsets[en.index()]),
        getValueOrCreateConstantIndexOp(b, loc, en.value())};
    foldedOffsets[en.index()] =
        makeComposedAffineApply(b, loc, dim1 + dim2, offsetValues).getResult();
  }
  return b.create<tensor::ExtractSliceOp>(loc, producerOp.source(),
                                          foldedOffsets, sizes, strides);
}

Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold) {
  assert(type.hasStaticShape() && "expect tensor type to have static shape");

  // Exit if `source` is not defined by an ExtractSliceOp.
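  // In that case there is no padded producer to reuse and a fresh pad op is
  // built.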
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);

  // Search the `source` use-def chain for padded LinalgOps.
  Value current = sliceOp.source();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    OpResult opResult = current.cast<OpResult>();
    current = linalgOp.getOutputOperand(opResult.getResultNumber())->get();
  }
  auto padTensorOp = current ? current.getDefiningOp<PadTensorOp>() : nullptr;

  // Exit if the search fails to match a PadTensorOp at the end of the matched
  // LinalgOp sequence.
  if (!padTensorOp)
    return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padded result type does not match.
  if (sliceOp.source().getType() != type)
    return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the LinalgOps are not high padded.
  if (llvm::any_of(padTensorOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if `padTensorOpSliceOp`, which defines the slice used by
  // `padTensorOp`, is rank-reducing.
  auto padTensorOpSliceOp =
      padTensorOp.source().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padTensorOpSliceOp || sliceOp.getMixedSizes().size() !=
                                 padTensorOpSliceOp.getMixedSizes().size())
    return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
  // by `padTensorOp`.
  if (llvm::any_of(llvm::zip(sliceOp.getMixedSizes(),
                             padTensorOpSliceOp.getMixedSizes()),
                   [](std::tuple<OpFoldResult, OpFoldResult> it) {
                     return !isEqualConstantIntOrValue(std::get<0>(it),
                                                       std::get<1>(it));
                   }))
    return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padding values do not match.
  Attribute padTensorOpPadAttr, padAttr;
  Value padTensorOpPad = padTensorOp.getConstantPaddingValue();
  if (!padTensorOpPad ||
      !matchPattern(padTensorOpPad, m_Constant(&padTensorOpPadAttr)) ||
      !matchPattern(pad, m_Constant(&padAttr)) || padTensorOpPadAttr != padAttr)
    return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.source();
}

/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<Attribute> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    Optional<LinalgLoopDistributionOptions> distributionOptions,
    ArrayRef<StringRef> distributionTypes) {
  SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
  // Create procInfo so it dominates loops, if appropriate.
  SmallVector<ProcInfo, 4> procInfo;
  SmallVector<DistributionMethod, 0> distributionMethod;
  if (distributionOptions.hasValue()) {
    // Collect loop ranges for parallel dimensions.
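    // Only parallel dimensions are distributed across processors.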
    SmallVector<Range, 2> parallelLoopRanges;
    for (const auto &iteratorType : enumerate(iteratorTypes))
      if (isParallelIterator(iteratorType.value()))
        parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);

    // Get their distribution schemes.
    distributionMethod = distributionOptions->distributionMethod;
    if (distributionMethod.size() < parallelLoopRanges.size())
      parallelLoopRanges.resize(distributionMethod.size());
    procInfo = distributionOptions->procInfo(b, loc, parallelLoopRanges);
  }

  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, iterArgInitValues,
      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == linalgOp.getOutputTensorOperands().size() &&
               "expect the number of output tensors and iter args to match");
        SmallVector<Value> operandValuesToUse =
            linalgOp.getInputAndOutputOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getInputOperands();
          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (!distributionOptions || loopNest.loops.empty())
    return;

  // Filter out scf.for loops that were created out of parallel dimensions.
  SmallVector<scf::ForOp, 4> loops;
  for (const auto &iteratorType : enumerate(iteratorTypes))
    if (isParallelIterator(iteratorType.value()))
      loops.push_back(loopNest.loops[iteratorType.index()]);

  // Distribute - only supports cyclic distribution for now.
  for (auto it : llvm::zip(loops, procInfo, distributionMethod))
    if (std::get<2>(it) == DistributionMethod::Cyclic)
      mapLoopToProcessorIds(std::get<0>(it), std::get<1>(it).procId,
                            std::get<1>(it).nprocs);
}

/// Specialization to build an affine "for" nest.
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<Attribute> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    Optional<LinalgLoopDistributionOptions>, ArrayRef<StringRef>) {
  SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
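  // Extract the constant values from the defining arith.constant ops.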
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(steps.size());
  for (Value v : steps) {
    auto op = v.getDefiningOp<arith::ConstantIndexOp>();
    assert(op && "Affine loops require constant steps");
    constantSteps.push_back(op.value());
  }

  mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                            [&](OpBuilder &b, Location loc, ValueRange ivs) {
                              SmallVector<Value> operandValuesToUse =
                                  linalgOp.getInputAndOutputOperands();
                              bodyBuilderFn(b, loc, ivs, operandValuesToUse);
                            });
}

/// Specialization to build a linalg.tiled_loop.
template <>
void GenerateLoopNest<TiledLoopOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<Attribute> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    Optional<LinalgLoopDistributionOptions> distributionOptions,
    ArrayRef<StringRef> distributionTypes) {
  SmallVector<ProcInfo, 2> procInfo;
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(loopRanges, lbs, ubs, steps);

  auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc,
                              ValueRange ivs, ValueRange inputs,
                              ValueRange outputs) {
    SmallVector<Value> operandValuesToUse = inputs;
    operandValuesToUse.append(outputs.begin(), outputs.end());
    scf::ValueVector results =
        bodyBuilderFn(nestedBuilder, nestedLoc, ivs, operandValuesToUse);
    nestedBuilder.create<linalg::YieldOp>(nestedLoc, results);
  };

  SmallVector<Value> inputOperands = linalgOp.getInputOperands();
  SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
  auto tiledLoop =
      b.create<TiledLoopOp>(loc, lbs, ubs, steps, inputOperands, outputOperands,
                            b.getArrayAttr(iteratorTypes), wrappedBuilderFn);
  if (!distributionTypes.empty())
    tiledLoop.setDistributionTypes(b, distributionTypes);
}

/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                       Value nprocs, Value &lb, Value &ub,
                                       Value &step) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
  lb = makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
  step = makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
}

/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes`. Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn` that accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
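// `distributionMethod` specifies, per outermost parallel loop, how it is
// mapped to processors.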
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<Attribute> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage,
    ArrayRef<DistributionMethod> distributionMethod = {}) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // Find the outermost parallel loops and drop their types from the list.
  unsigned nLoops = iteratorTypes.size();
  unsigned nOuterPar =
      nLoops - iteratorTypes.drop_while(isParallelIterator).size();

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse. Note that we wouldn't have dropped anything from `iteratorTypes`
  // in this case.
  if (nOuterPar == 0) {
    LoopNest singleLoop = buildLoopNest(
        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
        [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(ivs.begin(), ivs.end());
          generateParallelLoopNest(b, loc, lbs.drop_front(), ubs.drop_front(),
                                   steps.drop_front(),
                                   iteratorTypes.drop_front(), bodyBuilderFn,
                                   ivStorage, distributionMethod);
        });
    return;
  }
  if (distributionMethod.empty()) {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(nOuterPar), ubs.take_front(nOuterPar),
        steps.take_front(nOuterPar),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(nOuterPar),
              ubs.drop_front(nOuterPar), steps.drop_front(nOuterPar),
              iteratorTypes.drop_front(nOuterPar), bodyBuilderFn, ivStorage,
              (distributionMethod.size() < nOuterPar)
                  ? ArrayRef<DistributionMethod>()
                  : distributionMethod.drop_front(nOuterPar));
        });
    return;
  }

  // Process all consecutive similarly distributed loops simultaneously.
  DistributionMethod methodToUse = distributionMethod[0];
  unsigned numProcessed = 1;
  for (unsigned i = 1; i < nOuterPar && i < distributionMethod.size(); ++i) {
    if (distributionMethod[i] != methodToUse)
      break;
    numProcessed++;
  }

  switch (methodToUse) {
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage,
              (distributionMethod.size() < numProcessed)
                  ? ArrayRef<DistributionMethod>()
                  : distributionMethod.drop_front(numProcessed));
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(
          b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
          steps.drop_front(numProcessed),
          iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage,
          distributionMethod.drop_front(numProcessed));
      b.create<scf::YieldOp>(loc, ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        bodyBuilderFn, ivStorage, distributionMethod.drop_front(numProcessed));
    return;
  }
}

/// Specialization for generating a mix of parallel and sequential scf loops.
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<Attribute> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    Optional<LinalgLoopDistributionOptions> distributionOptions,
    ArrayRef<StringRef> distributionTypes) {
  SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(numLoops);
  lbsStorage.reserve(numLoops);
  ubsStorage.reserve(numLoops);
  stepsStorage.reserve(numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(loopRanges, lbsStorage, ubsStorage, stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
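  // Parallel dimensions get cyclic per-processor bounds via
  // updateBoundsForCyclicDistribution.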
  SmallVector<DistributionMethod, 0> distributionMethod;
  if (distributionOptions) {
    auto &options = distributionOptions.getValue();
    distributionMethod.assign(distributionOptions->distributionMethod.begin(),
                              distributionOptions->distributionMethod.end());
    SmallVector<Range, 2> parallelLoopRanges;
    for (const auto &iteratorType : enumerate(iteratorTypes)) {
      if (isParallelIterator(iteratorType.value()))
        parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
    }
    if (distributionMethod.size() < parallelLoopRanges.size())
      parallelLoopRanges.resize(distributionMethod.size());
    SmallVector<ProcInfo, 2> procInfo =
        options.procInfo(b, loc, parallelLoopRanges);
    unsigned index = 0;
    for (const auto &iteratorType : enumerate(iteratorTypes)) {
      if (index >= procInfo.size())
        break;
      if (isParallelIterator(iteratorType.value())) {
        unsigned i = iteratorType.index();
        updateBoundsForCyclicDistribution(b, loc, procInfo[index].procId,
                                          procInfo[index].nprocs, lbsStorage[i],
                                          ubsStorage[i], stepsStorage[i]);
        index++;
      }
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange ivs) {
        SmallVector<Value> operandValuesToUse =
            linalgOp.getInputAndOutputOperands();
        bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      },
      ivs, distributionMethod);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}

static Value fullyComposeAndAffineApply(OpBuilder &b, Location loc,
                                        AffineExpr expr, ValueRange operands) {
  AffineMap map = AffineMap::inferFromExprList({expr}).front();
  SmallVector<Value> normalizedOperands(operands.begin(), operands.end());
  mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands);
  canonicalizeMapAndOperands(&map, &normalizedOperands);
  return b.createOrFold<AffineApplyOp>(loc, map, normalizedOperands);
}

Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                     ValueRange tileSizes, AffineMap map, ValueRange lbs,
                     ValueRange ubs, ValueRange subShapeSizes) {
  auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
  assert(shapedType && "only shaped types can be tiled");
  ArrayRef<int64_t> shape = shapedType.getShape();
  int64_t rank = shapedType.getRank();

  // Construct a new subview / extract_slice for the tile.
  SmallVector<OpFoldResult, 4> offsets, sizes, strides;
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: for dim#" << r);
    if (!isTiled(map.getSubMap({r}), tileSizes)) {
      offsets.push_back(builder.getIndexAttr(0));
      Value dim = createOrFoldDimOp(builder, loc, valueToTile, r);
      sizes.push_back(getAsOpFoldResult(dim));
      strides.push_back(builder.getIndexAttr(1));
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");

    // Tiling creates a new slice at the proper index; the slice step is 1
    // (i.e. the op does not subsample, stepping occurs in the loop).
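    // The offset and the closed-interval size are obtained by applying this
    // dimension's submap to `lbs` and `subShapeSizes`, respectively.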
    auto m = map.getSubMap({r});
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: submap: " << m << "\n");
    auto offset = applyMapToValues(builder, loc, m, lbs).front();
    offsets.push_back(offset);
    auto closedIntSize =
        applyMapToValues(builder, loc, m, subShapeSizes).front();
    // Resulting size needs to be made half open interval again.
    AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
    Value size =
        fullyComposeAndAffineApply(builder, loc, s0 + 1, closedIntSize);
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: raw size: " << size << "\n");

    // The size of the subview / extract_slice should be trimmed to avoid
    // out-of-bounds accesses, unless:
    // a. We statically know the subshape size divides the shape size evenly.
    // b. The subshape size is 1. According to the way the loops are set up,
    //    tensors with "0" dimensions would never be constructed.
    int64_t shapeSize = shape[r];
    auto sizeCst = size.getDefiningOp<arith::ConstantIndexOp>();
    auto hasTileSizeOne = sizeCst && sizeCst.value() == 1;
    auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
                         ((shapeSize % sizeCst.value()) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n");

      AffineExpr dim0, dim1, dim2;
      bindDims(builder.getContext(), dim0, dim1, dim2);

      // Get the dimension size for this dimension. We need to first calculate
      // the max index and then add one. This is important because for
      // convolution ops, we have its input window dimension's affine map of the
      // form `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window
      // dimension and `s0` is stride. Directly using the dimension size of the
      // output/filter window dimensions would cause incorrect calculation.
      AffineMap minusOneMap =
          AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 - 1}})
              .front();
      AffineMap plusOneMap =
          AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 + 1}})
              .front();
      auto maxIndices = llvm::to_vector<8>(llvm::map_range(ubs, [&](Value ub) {
        return makeComposedAffineApply(builder, loc, minusOneMap, {ub})
            .getResult();
      }));
      Value maxIndex = applyMapToValues(builder, loc, m, maxIndices).front();
      Value d = makeComposedAffineApply(builder, loc, plusOneMap, {maxIndex});

      // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
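      // The min map takes (size, dim, offset) and yields min(size, dim - offset).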
      AffineMap minMap = AffineMap::inferFromExprList(
                             {ArrayRef<AffineExpr>{dim0, dim1 - dim2}})
                             .front();
      SmallVector<Value, 4> operands{size, d, offset};
      fullyComposeAffineMapAndOperands(&minMap, &operands);
      canonicalizeMapAndOperands(&minMap, &operands);
      size = builder.create<AffineMinOp>(loc, builder.getIndexType(), minMap,
                                         operands);
    }

    sizes.push_back(size);
    LLVM_DEBUG(llvm::dbgs()
               << "makeTiledShape: new offset: " << offset << "\n");
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
    strides.push_back(builder.getIndexAttr(1));
  }

  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                      .Case([&](MemRefType) {
                        return builder.create<memref::SubViewOp>(
                            loc, valueToTile, offsets, sizes, strides);
                      })
                      .Case([&](RankedTensorType) {
                        return makeComposedExtractSliceOp(
                            builder, loc, valueToTile, offsets, sizes, strides);
                      })
                      .Default([](ShapedType) -> Operation * {
                        llvm_unreachable("Unexpected shaped type");
                      });
  return sliceOp->getResult(0);
}

SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
                                      ValueRange ivs, ValueRange tileSizes) {
  SmallVector<Value> offsets;
  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
    bool isTiled = !isZero(tileSizes[idx]);
    offsets.push_back(
        isTiled ? ivs[idxIvs++]
                : b.create<arith::ConstantIndexOp>(loc, 0).getResult());
    LLVM_DEBUG(llvm::dbgs()
               << "computeTileOffsets: " << offsets.back() << "\n");
  }
  return offsets;
}

SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs,
                                    ValueRange tileSizes,
                                    ArrayRef<Value> sizeBounds) {
  SmallVector<Value> sizes;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
    bool isTiled = !isZero(tileSizes[idx]);
    // Before composing, we need to make the range a closed interval.
    Value size = isTiled ? tileSizes[idx] : sizeBounds[idx];
    AffineExpr d0 = getAffineDimExpr(0, b.getContext());
    sizes.push_back(fullyComposeAndAffineApply(b, loc, d0 - 1, size));
    LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
  }
  return sizes;
}

SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
                                      LinalgOp linalgOp,
                                      ArrayRef<Value> valuesToTile,
                                      ValueRange ivs, ValueRange tileSizes,
                                      ArrayRef<Value> sizeBounds) {
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](Value v) { return !isZero(v); })) &&
         "expected as many ivs as non-zero sizes");

  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subshapes.
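  // `lbs` holds per-dimension tile offsets; `subShapeSizes` holds
  // closed-interval tile sizes.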
  SmallVector<Value> lbs = computeTileOffsets(b, loc, ivs, tileSizes);
  SmallVector<Value> subShapeSizes =
      computeTileSizes(b, loc, ivs, tileSizes, sizeBounds);

  assert(static_cast<int64_t>(valuesToTile.size()) ==
             linalgOp.getNumInputsAndOutputs() &&
         "expected one value to tile for every operand");
  SmallVector<Value, 4> tiledShapes;
  tiledShapes.reserve(valuesToTile.size());
  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
    Value shapedOp = valuesToTile[opOperand->getOperandNumber()];
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = linalgOp.getTiedIndexingMap(opOperand);
    // Use `opOperand` as is if it is not tiled and not an output tensor. Having
    // an extract/insert slice pair for all output tensors simplifies follow up
    // transformations such as padding and bufferization since the
    // extract/insert slice pairs make the accessed iteration argument
    // subdomains explicit.
    if (!isTiled(map, tileSizes) && !linalgOp.isOutputTensor(opOperand)) {
      tiledShapes.push_back(shapedOp);
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: "
                              << opOperand->get().getType() << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");

    tiledShapes.push_back(makeTiledShape(b, loc, shapedOp, tileSizes, map, lbs,
                                         sizeBounds, subShapeSizes));
  }

  return tiledShapes;
}

void addTileLoopIvsToIndexOpResults(OpBuilder &b, LinalgOp tiledOp,
                                    ArrayRef<Value> ivs) {
  if (tiledOp.hasIndexSemantics()) {
    for (IndexOp indexOp : tiledOp.getBlock()->getOps<IndexOp>()) {
      if (ivs[indexOp.dim()] == nullptr)
        continue;
      OpBuilder::InsertionGuard guard(b);
      b.setInsertionPointAfter(indexOp);
      AffineExpr index, offset;
      bindDims(b.getContext(), index, offset);
      AffineApplyOp applyOp = makeComposedAffineApply(
          b, indexOp.getLoc(), index + offset,
          ValueRange{indexOp.getResult(), ivs[indexOp.dim()]});
      indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
    }
  }
}

} // namespace linalg
} // namespace mlir