//===- Utils.h - Utilities to support the Linalg dialect --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LINALG_UTILS_UTILS_H
#define MLIR_DIALECT_LINALG_UTILS_UTILS_H

#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SetVector.h"

namespace mlir {
class AffineExpr;
class AffineForOp;
class AffineMap;
class PatternRewriter;

namespace tensor {
class ExtractSliceOp;
} // namespace tensor

namespace linalg {
class LinalgDependenceGraph;

//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//

/// Check if all indexing maps are projected permutations.
bool allIndexingsAreProjectedPermutation(LinalgOp op);

/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
bool hasOnlyScalarElementwiseOp(Region &r);

/// Check if a LinalgOp is an element-wise operation.
bool isElementwise(LinalgOp op);

/// Check if `permutation` is a permutation of the range
/// `[0, permutation.size())`.
bool isPermutation(ArrayRef<int64_t> permutation);

/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);

/// Given an operation, retrieves the value of each dynamic dimension through
/// constructing the necessary DimOp operators.
SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b);

/// Computes an upper bound for the result `value` of an index computation.
/// Translates AffineMinOps and AffineApplyOps along the use-def chains of the
/// index computation to affine constraints and projects out intermediate
/// values. The method sets `boundMap` to an affine map that given
/// `boundOperands` evaluates to an upper bound for the index computation.
///
/// If `constantRequired` is true, only returns the constant bounds (potentially
/// over-approximating) and fails when not possible.
///
/// Example:
/// ```
/// %dim0 = dim %tensor, %c0
/// %dim1 = dim %tensor, %c1
/// %0 = affine.min affine.map<(d0) -> (40, d0)> (%dim0)
/// %1 = affine.apply affine.map<(d0, d1) -> (d0 + d1)> (%0, %dim1)
/// ```
/// getUpperBoundForIndex(%1, boundMap, boundOperands)
/// set the output parameters to:
/// - boundMap = affine.map<(d0) -> (d0 + 40)>
/// - boundOperands = [%dim1]
void getUpperBoundForIndex(Value value, AffineMap &boundMap,
                           SmallVectorImpl<Value> &boundOperands,
                           bool constantRequired = false);

/// Returns a constant upper bound for the result `value` of an index
/// computation. Calls `getUpperBoundForIndex` and returns a constant upper
/// bound if the result of `boundMap` is a constant expression and failure
/// otherwise.
///
/// Example:
/// ```
/// %0 = affine.min affine.map<(d0) -> (40, d0)> (%d0)
/// %1 = affine.apply affine.map<(d0) -> (d0 + 2)> (%0)
/// ```
/// getConstantUpperBoundForIndex(%1) returns 42
/// (boundsMap = affine.map<() -> (42)>)
FailureOr<int64_t> getConstantUpperBoundForIndex(Value value);

/// Create an ExtractSliceOp and, if `source` is defined by an ExtractSliceOp,
/// fold it by adding the offsets.
///
/// Example:
/// ```
/// %0 = tensor.extract_slice %arg0[3, 4][3, 32][1, 1] : tensor<64x64xf32> to
///                                                      tensor<3x32xf32>
/// %1 = tensor.extract_slice %0[0, 5][3, 4][1, 1] : tensor<3x32xf32> to
///                                                  tensor<3x4xf32>
/// ```
/// folds into:
/// ```
/// %1 = tensor.extract_slice %arg0[3, 9][3, 4][1, 1] : tensor<64x64xf32> to
///                                                     tensor<3x4xf32>
/// ```
tensor::ExtractSliceOp makeComposedExtractSliceOp(
    OpBuilder &b, Location loc, Value source, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides);

/// Create a tensor::PadOp that pads `source` to the size of the statically
/// sized `type` whose static sizes are assumed to be greater than the dynamic
/// `source` size. The padding introduces trailing `pad` values until the target
/// size is met. If `source` is defined by one or more LinalgOps that have been
/// padded with the same value and sizes, return their padded result instead of
/// creating a tensor::PadOp.
///
/// Example:
/// ```
/// %0 = tensor.extract_slice %arg0 [%iv0, %iv1] [%sz0, %sz1]
/// %1 = tensor.pad %0 low[0, 0] high[...] { tensor.yield %cst }
/// %2 = linalg.matmul ins(...) outs(%1)
/// %3 = tensor.extract_slice %2 [0, 0] [%sz0, %sz1]
/// ```
/// makeComposedPadHighOp(source=%3, pad=%cst) returns %2
/// makeComposedPadHighOp(source=%3, pad=%other_cst) returns %4
/// ```
/// %4 = tensor.pad %3 low[0, 0] high[...] { tensor.yield %other_cst }
/// ```
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold);

/// Returns a GenericOp that transposes `inputTensor` into `outputTensor` using
/// `transposeVector` to permute the `inputTensor` dimensions.
GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
                          Value outputTensor,
                          ArrayRef<int64_t> transposeVector);

/// Returns GenericOp that copies an n-D memref. Unlike the current
/// implementation of memref::CopyOp, this op can further tile, lower to loops
/// or vectorize.
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);

/// Get the reassociation maps to fold the result of an extract_slice (or source
/// of an insert_slice) operation with given offsets, and sizes to its
/// rank-reduced version. This is only done for the cases where the size is 1
/// and offset is 0. Strictly speaking the offset 0 is not required in general,
/// but non-zero offsets are not handled by SPIR-V backend at this point (and
/// potentially cannot be handled).
Optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);

//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
//===----------------------------------------------------------------------===//

/// The type of loops to be generated during tiling.
enum class LinalgTilingLoopType {
  Loops = 0,
  AffineLoops = 1,
  ParallelLoops = 2,
  TiledLoops = 3,
};

/// Checks whether the specific `producer` is the last write to exactly the
/// whole `consumedView`. This checks structural dominance, that the dependence
/// is a RAW without any interleaved write to any piece of `consumedView`.
bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                               LinalgOp consumer, Value consumedView,
                               LinalgOp producer);

/// Checks whether fusing the specific `producer` of the `consumedView` is
/// feasible. This checks `producer` is the last write of `consumedView` and
/// that no interleaved dependence would be violated (RAW, WAR or WAW).
bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
                   Value consumedView, LinalgOp producer);

/// Creates either a memref.subview or a tensor.extract_slice with the given
/// offsets/sizes/strides based on the type of `value`.
Value createSlice(OpBuilder &builder, Location loc, Value value,
                  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
                  ArrayRef<OpFoldResult> strides);

/// Computes tile offsets, given a list of loop `ivs` and `tileSizes`. In case a
/// tile size is zero (i.e., no tiling), the corresponding offset is also zero.
SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
                                      ValueRange ivs, ValueRange tileSizes);

/// Computes tile sizes, given a list of `tileSizes` and dimension
/// sizes (`sizeBounds`). In case a tile size is zero (i.e., no tiling), the
/// corresponding result size is the corresponding value from `sizeBounds`.
/// Note: The returned tile sizes are closed intervals.
SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
                                    ValueRange tileSizes,
                                    ArrayRef<Value> sizeBounds);

/// Returns the list of tensor output types produced when the given structured
/// operation `op` is applied to the given `operands`. Note that `operands` are
/// not necessarily the actual operands of `op`.
SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands);

/// Creates `insert_slice` ops that insert `results` back into larger tensors
/// they were originally extracted from with `extract_slice` before being passed
/// as `operands` to the given structured operation `op` or its clone. Note that
/// `operands` are not necessarily the actual operands of `op`, the operation
/// serves only as metadata container for operand types and positions.
SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                    LinalgOp op, ValueRange operands,
                                    ValueRange results);

/// Turns an OpFoldResult into a value, creating an index-typed constant if
/// necessary.
Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
                              OpFoldResult opFoldResult);

/// Creates an extract_slice/subview op for a single `valueToTile` with
/// `builder`. This new operation extracts a tile of `valueToTile`, starting
/// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
/// controls whether to omit the partial/boundary tile condition check in cases
/// where we statically know that it is unnecessary.
Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                     ValueRange tileSizes, AffineMap map, ValueRange lbs,
                     ValueRange ubs, ValueRange subShapeSizes,
                     bool omitPartialTileCheck);

/// Creates extract_slice/subview ops for all `valuesToTile` of the given
/// `linalgOp` with `builder`, assuming `linalgOp` is being fused into a loop
/// nest for tiling with the given induction variables `ivs` and tile sizes
/// `tileSizes`. `sizeBounds` are the iteration space bounds for *all* the
/// implicit loops in `linalgOp`. `omitPartialTileCheck` controls whether to
/// omit the partial/boundary tile condition check in cases where we statically
/// know that it is unnecessary.
///
/// Note that a constant zero in `tileSizes` means no tiling at that implicit
/// loop. The number of non-zero values in `tileSizes` should be equal to the
/// number of values in `ivs`.
239 SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc, 240 LinalgOp linalgOp, 241 ArrayRef<Value> valuesToTile, 242 ValueRange ivs, ValueRange tileSizes, 243 ArrayRef<Value> sizeBounds, 244 bool omitPartialTileCheck); 245 246 /// Add the specified offsets to any `linalg.index` ops contained in the given 247 /// `linalgOp`. The offsets are provided in the same order as iteration space 248 /// dimensions. Null offests are assumed to be zero. 249 void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef<Value> offests); 250 void offsetIndices(RewriterBase &b, LinalgOp linalgOp, ArrayRef<Value> offests); 251 252 using FusableOpDependencesTy = llvm::MapVector< 253 Operation *, 254 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>; 255 FusableOpDependencesTy 256 findAllFusableDependences(ArrayRef<LinalgOp> ops, 257 const LinalgDependenceGraph &dependenceGraph); 258 259 /// A struct containing the Linalg producer before and after fusion. 260 /// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op 261 /// before the consumer Linalg op, until enough canonicalizations have applied. 262 struct FusionInfo { 263 LinalgOp originalProducer; 264 LinalgOp fusedProducer; 265 }; 266 267 /// Fuses producer into consumer if the producer is structurally feasible and 268 /// the fusion would not violate dependencies. 269 /// Implements the fusion part of the "tileAndFuse on buffers" transformation 270 /// and thus requires the `consumerOpOperand` to be a `subview` op (generally 271 /// obtained by applying the tiling transformation). 272 FailureOr<FusionInfo> fuseProducerOfBuffer(OpBuilder &b, 273 OpOperand &consumerOpOperand, 274 const LinalgDependenceGraph &graph); 275 /// Tensor counterpart of `fuseProducerOfBuffer`. 
/// This implements the fusion part of the "tileAndFuse on tensors"
/// transformation and thus requires the `consumerOpOperand` to be a
/// `extract_slice` op (generally obtained by applying the tiling
/// transformation).
FailureOr<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
                                           OpOperand &consumerOpOperand);
/// Tensor counterpart of `fuseProducerOfBuffer`.
/// This implements the fusion part of the "tileAndFuse on tensors"
/// transformation and thus requires the `consumerOpOperand` to be a
/// `extract_slice` op (generally obtained by applying the tiling
/// transformation). Assumes `producerOfTensor` is a Linalg op that produces
/// `consumerOpOperand`.
FailureOr<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
                                           OpResult producerOpResult,
                                           OpOperand &consumerOpOperand);

//===----------------------------------------------------------------------===//
// Distribution utilities
//===----------------------------------------------------------------------===//

/// Scheme used to distribute loops to processors.
enum class DistributionMethod {
  /// Cyclic distribution where no assumption is made about the dynamic
  /// relationship between number of processors and number of iterations of the
  /// distributed loop. Distributes the following loop
  ///
  /// scf.parallel (%iv) = (%lb) to (%ub) step (%step)
  ///
  /// to
  ///
  /// scf.parallel(%iv)= (%lb + %procId * %step) to (%ub) step (%step * %nprocs)
  Cyclic = 0,

  /// Cyclic distribution where the number of processors can be assumed to be
  /// more than or equal to the number of iterations of the distributed loop. In
  /// such cases, a simple in-bounds check is enough (instead of materializing a
  /// loop). Distributes the following loop
  ///
  /// scf.parallel (%iv) = (%lb) to (%ub) step (%step)
  ///
  /// to
  ///
  /// %iv = %lb + %procId * %step
  /// %cond = arith.cmpi "slt", %iv, %ub
  /// scf.if %cond {
  ///   ...
  /// }
  CyclicNumProcsGeNumIters = 1,

  /// Cyclic distribution where the number of processors can be assumed to be
  /// equal to the number of iterations of the distributed loop. In such cases,
  /// no bounds check is needed. Distributes the following loop
  ///
  /// scf.parallel (%iv) = (%lb) to (%ub) step (%step)
  ///
  /// to
  ///
  /// %iv = %lb + %procId * %step
  CyclicNumProcsEqNumIters = 2
};

/// Callback function type used to get processor ID, and number of processors
/// used for distribution for all parallel loops generated.
struct ProcInfo {
  Value procId;
  Value nprocs;
};
using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo, 2>(
    OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges)>;
using OneDimProcInfoCallBackFn =
    std::function<ProcInfo(OpBuilder &b, Location loc)>;

/// Options that allow distribution of loops generated in Linalg transforms to
/// processors while generating the loops.
struct LinalgLoopDistributionOptions {
  /// Callback function that returns the Values for processor ID (`procId`), and
  /// number of processors (`nprocs`) used to execute the parallel loops. The
  /// number of `{procId, nprocs}` pairs returned must be equal to the number of
  /// `parallelLoopRanges` passed into the callback, which in turn is the same
  /// as the number of parallel loops for which the `distributionMethod` is
  /// specified below.
  ProcInfoCallBackFn procInfo;
  /// Specification of how to distribute the `scf.parallel` loops that are
  /// generated.
  /// As the `scf.parallel` loop is generated, the elements of this
  /// vector is used (from left to right) and the specified distribution is
  /// applied. If the vector is less than the number of `scf.parallel` loops
  /// generated, then no distribution is applied.
  SmallVector<DistributionMethod, 0> distributionMethod = {};

  /// The map keyed by the distribution type that contains callback functions
  /// that return the Values for processor ID (`procId`), and number of
  /// processors (`nprocs`) used to execute the parallel loops.
  DenseMap<StringRef, OneDimProcInfoCallBackFn> procInfoMap;
};

/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
void updateBoundsForCyclicDistribution(OpBuilder &builder, Location loc,
                                       Value procId, Value nprocs, Value &lb,
                                       Value &ub, Value &step);

//===----------------------------------------------------------------------===//
// Fusion on tensor utilities
//===----------------------------------------------------------------------===//

/// A struct to manage the tile loop nest specific information.
class TileLoopNest {
public:
  TileLoopNest(LinalgOp rootOp) : rootOp(rootOp) {}

  /// Tile the root operation using the given `tileSizes` and `tileInterchange`,
  /// and `tileDistribution`.
  LogicalResult
  tileRootOp(OpBuilder &b, ArrayRef<int64_t> tileSizes,
             ArrayRef<int64_t> tileInterchange,
             Optional<LinalgLoopDistributionOptions> tileDistribution);

  /// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns
  /// the fused producer or fails if fusion is not possible.
  FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand);

  /// Returns the replacement results for the original untiled root operation.
  ValueRange getRootOpReplacementResults();

  /// Returns the tiled root operation.
  LinalgOp getRootOp() { return rootOp; }

  /// Returns the tiled root operation and the fused producers.
  SmallVector<LinalgOp> getAllTiledAndFusedOps();

  /// Returns the loop ops generated from tiling.
  ArrayRef<scf::ForOp> getLoopOps() { return tileLoopOps; }

  /// Returns true if the tile loop nest has no tile loops.
  bool isEmpty();

private:
  /// Returns true if the tile loop nest invariants are satisfied:
  /// - The `rootOp` has been tiled at least once.
  /// - The number of tile loop operations and dimensions match.
  /// - The innermost tile loop is the parent of `tiledOp`.
  /// - The tile loops are directly nested.
  // TODO: relax to support additional control flow, e.g., IfOp.
  bool isValid();

  /// Searches the block arguments tied to a block argument `bbArg` of the
  /// innermost tile loop. Returns the block argument from outermost to
  /// innermost or an empty vector if none are found.
  SmallVector<BlockArgument> getTiedBBArgs(BlockArgument bbArg);

  /// Returns the iteration argument of the outermost tile loop mapped to a
  /// block argument `bbArg` of the innermost tile loop.
  OpOperand *getTiedIterArg(BlockArgument bbArg);

  /// Returns true if `bbArg` has other uses than `sliceOp` and its
  /// dependencies. Only if there are no other uses, the producer output
  /// iteration argument may be reused to pass the producer result after fusion.
  bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp);

  // The root operation (tiled in place as the nest is built).
  LinalgOp rootOp;
  // The scf.for loops of the nest, outermost first.
  SmallVector<scf::ForOp> tileLoopOps;
  // For the tiled root and each fused op, the loop dimensions it was tiled in.
  DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
};

/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
/// `tileSizes`, `tileInterchange`, and `tileDistribution` parameters to control
/// the tiling.
FailureOr<TileLoopNest> tileConsumerAndFuseProducers(
    OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
    ArrayRef<int64_t> tileInterchange,
    const Optional<LinalgLoopDistributionOptions> &tileDistribution);

//===----------------------------------------------------------------------===//
// Generic op region utilities
//===----------------------------------------------------------------------===//

/// A struct containing common matchers over linalg op's region.
struct RegionMatcher {
  enum class BinaryOpKind {
    IAdd,
  };

  /// Matches the given linalg op if its body is performing binary operation on
  /// int or float scalar values and returns the binary op kind.
  ///
  /// The linalg op's region is expected to be
  /// ```
  /// {
  ///   ^bb(%a: <scalar-type>, %b: <scalar-type>):
  ///     %0 = <binary-op> %a, %b: <scalar-type>
  ///     linalg.yield %0: <scalar-type>
  /// }
  /// ```
  static Optional<BinaryOpKind> matchAsScalarBinaryOp(GenericOp op);
};

//===----------------------------------------------------------------------===//
// Loop nest utilities
//===----------------------------------------------------------------------===//

/// Utility class used to generate nested loops with ranges described by
/// `loopRanges` and loop type described by the `iteratorTypes`. `bodyBuilderFn`
/// is used to generate the body of the innermost loop. It is passed a range
/// of loop induction variables and a range of operand values to use.
template <typename LoopTy>
struct GenerateLoopNest {
  static void doit(OpBuilder &b, Location loc, ArrayRef<Range> loopRanges,
                   LinalgOp linalgOp, ArrayRef<Attribute> iteratorTypes,
                   function_ref<scf::ValueVector(OpBuilder &, Location,
                                                 ValueRange, ValueRange)>
                       bodyBuilderFn,
                   Optional<LinalgLoopDistributionOptions> = None,
                   ArrayRef<StringRef> distributionTypes = {});
};

} // namespace linalg
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_UTILS_UTILS_H