//===- Transforms.h - Linalg transformations as patterns --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H

#include <utility>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"

namespace mlir {
namespace bufferization {
class BufferizeTypeConverter;
} // namespace bufferization

class FrozenRewritePatternSet;

namespace linalg {

struct LinalgElementwiseFusionOptions;
struct LinalgFusionOptions;
struct LinalgTilingOptions;

//===----------------------------------------------------------------------===//
// Transformations exposed as function calls.
//===----------------------------------------------------------------------===//
using LinalgLoops = SmallVector<Operation *, 4>;

void populatePadTensorTilingPatterns(RewritePatternSet &patterns,
                                     const LinalgTilingOptions &options);

/// Populate patterns for splitting a `LinalgOp` with multiple statements
/// within its payload into multiple `GenericOp`s that each have a single
/// statement.
void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns);

/// Populate patterns for vectorizing low-D convolution ops. This is a step in
/// progressive lowering for convolution ops; it assumes high-D convolution ops
/// were decomposed previously.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);

/// Populate patterns that convert `ElementwiseMappable` ops to linalg
/// parallel loops.
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);

/// Populate patterns that are only useful in the context of sparse tensors.
void populateSparseTensorRewriting(RewritePatternSet &patterns);

/// Function type which is used to control when to stop fusion. It is expected
/// that the OpOperand is not modified in the callback. The OpOperand is not
/// marked as const to allow callers to use non-const methods.
using ControlFusionFn =
    std::function<bool(const OpResult &producer, OpOperand &consumer)>;

/// Patterns for fusing linalg operations on tensors.

/// Pattern to fuse `linalg.generic` -> `linalg.generic` operations
/// when both operations are fusable elementwise operations.
void populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlElementwiseOpFusion);
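// A minimal usage sketch (hedged; `funcOp` and the always-true control
// function are hypothetical, not part of this header): register the
// elementwise fusion patterns and apply them greedily.
//
//   RewritePatternSet fusionPatterns(funcOp.getContext());
//   populateElementwiseOpsFusionPatterns(
//       fusionPatterns,
//       [](const OpResult &producer, OpOperand &consumer) { return true; });
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));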
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with
/// its producer (consumer) generic operation by expanding the dimensionality
/// of the loop in the generic op.
void populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes);

/// Patterns to fold an expanding tensor.expand_shape operation with its
/// producer generic operation by collapsing the dimensions of the generic op.
void populateFoldReshapeOpsByCollapsingPatterns(
    RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes);

/// Patterns to constant fold Linalg operations.
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
                                          const ControlFusionFn &controlFn);

/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
/// if the producer is a `linalg` operation with all parallel iterator types.
void populateFuseTensorPadWithProducerLinalgOpPatterns(
    RewritePatternSet &patterns);

/// Patterns to convert from one named op to another. These can be seen as
/// canonicalizations of named ops into another named op.
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);

/// Patterns to fold unit-extent dimensions in operands/results of linalg ops
/// on tensors.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);

/// Patterns that are used to inline constant operands into linalg generic ops.
void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);

/// Patterns that are used to bubble up extract slice op above linalg op.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);

/// Split the given `op` into two parts along the given iteration space
/// `dimension` at the specified `splitPoint`, and return the two parts.
///
/// For example, the following op:
///
///   linalg.matmul ins(%0, %1 : tensor<128x32xf32>, tensor<32x64xf32>)
///                 outs(%2 : tensor<128x64xf32>)
///
/// split along the first dimension at position 42 will result in:
///
///   %3 = tensor.extract_slice %0[0, 0][42, 32][1, 1]
///   %4 = tensor.extract_slice %2[0, 0][42, 64][1, 1]
///   %5 = linalg.matmul ins(%3, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
///                      outs(%4 : tensor<42x64xf32>)
///   %6 = tensor.insert_slice %5 into %2[0, 0][42, 64][1, 1]
///
///   %7 = tensor.extract_slice %0[42, 0][86, 32][1, 1]
///   %8 = tensor.extract_slice %6[42, 0][86, 64][1, 1]
///   %9 = linalg.matmul ins(%7, %1 : tensor<86x32xf32>, tensor<32x64xf32>)
///                      outs(%8 : tensor<86x64xf32>)
///   tensor.insert_slice %9 into %6[42, 0][86, 64][1, 1]
///
/// Note that there is no simplification other than constant propagation
/// applied to slice extraction and insertion.
std::pair<LinalgOp, LinalgOp> splitOp(RewriterBase &rewriter, LinalgOp op,
                                      unsigned dimension,
                                      OpFoldResult splitPoint);
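// A minimal call sketch (hedged; `rewriter` and `matmulOp` are assumed to
// exist at the call site): split the op at a static point along dimension 0.
//
//   rewriter.setInsertionPoint(matmulOp);
//   std::pair<LinalgOp, LinalgOp> parts =
//       splitOp(rewriter, matmulOp, /*dimension=*/0,
//               /*splitPoint=*/rewriter.getIndexAttr(42));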
/// Perform standalone tiling of a single LinalgOp by `tileSizes` and permute
/// the loop nest according to `interchangeVector`. The permutation is
/// expressed as a list of integers that specify the new ordering of the loop
/// nest. The length of `interchangeVector` must be equal to the length of
/// `tileSizes`. An empty vector is interpreted as the identity permutation
/// and the transformation returns early.
///
/// Return a struct containing the tiled loops in the specified order and the
/// cloned op if successful, failure otherwise.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`tileSizes.size()` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
struct TiledLinalgOp {
  LinalgOp op;
  SmallVector<Operation *, 8> loops;
  SmallVector<Value, 4> tensorResults;
};
FailureOr<TiledLinalgOp> tileLinalgOp(RewriterBase &b, LinalgOp op,
                                      const LinalgTilingOptions &options);

/// Peel and canonicalize `loops`.
void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);

/// Peel the loops of a TiledLinalgOp.
void peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
                       ArrayRef<int64_t> peeledLoops,
                       LinalgTilingLoopType loopType);

/// Interchange the `iterator_types` and `indexing_maps` dimensions and adapt
/// the index accesses of `op`. This is an in-place transformation controlled
/// by `interchangeVector`. An empty vector is interpreted as the identity
/// permutation and the transformation returns early.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`op.rank` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
///
/// Return failure if the permutation is not valid.
FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
                                          GenericOp genericOp,
                                          ArrayRef<unsigned> interchangeVector);

/// Create a GenericOp from the given named operation `namedOp` and replace
/// `namedOp`.
/// Return failure if `namedOp` is a GenericOp or misses a region builder.
FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
                                       LinalgOp namedOp);
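// A minimal sketch (hedged; `rewriter` and `namedOp` are assumed): generalize
// a named op to `linalg.generic`, then interchange its loops as
// (i,j,k) -> (j,k,i).
//
//   FailureOr<GenericOp> generic = generalizeNamedOp(rewriter, namedOp);
//   if (succeeded(generic))
//     (void)interchangeGenericOp(rewriter, *generic,
//                                /*interchangeVector=*/{1, 2, 0});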
/// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewSize` a best attempt is made to find the
/// smallest constant value for the size of the buffer needed for each
/// dimension. If that is not possible, it contains the dynamic size of the
/// subview. The callback should return the buffer to use.
using AllocBufferCallbackFn = std::function<Optional<Value>(
    OpBuilder &b, memref::SubViewOp subView,
    ArrayRef<Value> boundingSubViewSize, DataLayout &layout)>;

/// Callback function type used to deallocate the buffers used to hold the
/// promoted subview.
using DeallocBufferCallbackFn =
    std::function<LogicalResult(OpBuilder &b, Value buffer)>;

/// Callback function type used to insert a copy from the original subview to
/// the subview of the promoted region for the read operands, and from the
/// subview of the promoted region to the original subview for the results.
/// The copy has to happen from `src` to `dst`.
using CopyCallbackFn =
    std::function<LogicalResult(OpBuilder &b, Value src, Value dst)>;

struct LinalgPromotionOptions {
  /// Indices of subViews to promote. If `None`, try to promote all operands.
  Optional<DenseSet<unsigned>> operandsToPromote = None;
  LinalgPromotionOptions &setOperandsToPromote(ArrayRef<int64_t> operands) {
    operandsToPromote = DenseSet<unsigned>();
    operandsToPromote->insert(operands.begin(), operands.end());
    return *this;
  }
  /// If the ith element of `useFullTiles` is true, the full view should be
  /// used for the promoted buffer of the ith operand in `operandsToPromote`.
  /// Otherwise the partial view will be used. The decision defaults to
  /// `useFullTileBuffersDefault` when `useFullTileBuffers` is None and for
  /// operands missing from `useFullTileBuffers`.
  Optional<llvm::SmallBitVector> useFullTileBuffers = None;
  LinalgPromotionOptions &setUseFullTileBuffers(ArrayRef<bool> useFullTiles) {
    unsigned size = useFullTiles.size();
    llvm::SmallBitVector tmp(size, false);
    for (unsigned i = 0; i < size; ++i)
      tmp[i] = useFullTiles[i];
    useFullTileBuffers = tmp;
    return *this;
  }
  /// If true, all operands unspecified by `useFullTileBuffers` will use the
  /// full view, otherwise the partial view.
  bool useFullTileBuffersDefault = false;
  LinalgPromotionOptions &setUseFullTileBuffersByDefault(bool use) {
    useFullTileBuffersDefault = use;
    return *this;
  }
  /// Alignment of promoted buffer. If `None` do not specify alignment.
  Optional<unsigned> alignment = None;
  LinalgPromotionOptions &setAlignment(unsigned align) {
    alignment = align;
    return *this;
  }
  /// Use alloca with the default allocation scheme.
  bool useAlloca = false;
  LinalgPromotionOptions &setUseAlloca(bool use) {
    useAlloca = use;
    return *this;
  }
  /// Callback function to do the allocation of the promoted buffer. If None,
  /// then the default allocation scheme of allocating a memref<?xi8> buffer
  /// followed by a view operation is used.
  Optional<AllocBufferCallbackFn> allocationFn = None;
  Optional<DeallocBufferCallbackFn> deallocationFn = None;
  LinalgPromotionOptions &
  setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn,
                               DeallocBufferCallbackFn const &deallocFn) {
    allocationFn = allocFn;
    deallocationFn = deallocFn;
    return *this;
  }
  /// Callback function to do the copy of data to and from the promoted
  /// subview. If None then a memref.copy is used.
  Optional<CopyCallbackFn> copyInFn = None;
  Optional<CopyCallbackFn> copyOutFn = None;
  LinalgPromotionOptions &setCopyInOutFns(CopyCallbackFn const &copyIn,
                                          CopyCallbackFn const &copyOut) {
    copyInFn = copyIn;
    copyOutFn = copyOut;
    return *this;
  }
};
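// A minimal configuration sketch (hedged; which operands to promote and the
// alignment value are workload-specific and purely illustrative):
//
//   LinalgPromotionOptions promotionOptions;
//   promotionOptions.setOperandsToPromote({0, 1})
//       .setUseFullTileBuffersByDefault(true)
//       .setAlignment(16)
//       .setUseAlloca(false);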
/// Create a new buffer using the `allocationFn` provided. The size of this
/// buffer is the smallest constant bounding size along each dimension that can
/// be computed for the size of the result of `subView`. Returns the allocated
/// buffer as `fullLocalView` and the view that matches the size of the result
/// of the subview operation as `partialLocalView`.
struct PromotionInfo {
  Value fullLocalView;
  Value partialLocalView;
};
FailureOr<PromotionInfo>
promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
                          const AllocBufferCallbackFn &allocationFn,
                          DataLayout &layout);

/// Promote the `subViews` into a new buffer allocated at the insertion point
/// `b`. Promotion occurs in 3 steps:
///   1. Create a new buffer for a full tile (i.e. not clipped at the
///      boundary).
///   2. Take a full view on the buffer.
///   3. Take a partial slice of the full view in step 2. and copy into it.
///
/// Return the modified linalg op (the modification happens in place) as well
/// as all the copy ops created.
FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
                                    const LinalgPromotionOptions &options);

/// Emit a suitable vector form for a Linalg op with fully static shape.
LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp);

/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);

/// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
                                       LinalgOp linalgOp);

/// Emit a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
                                               LinalgOp linalgOp);

/// Emit a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
                                             LinalgOp linalgOp);
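// A minimal driver sketch for `vectorize` above (hedged; `funcOp` is assumed
// and the fully-static-shape precondition check is elided):
//
//   funcOp.walk([&](LinalgOp linalgOp) {
//     IRRewriter rewriter(linalgOp->getContext());
//     rewriter.setInsertionPoint(linalgOp);
//     (void)vectorize(rewriter, linalgOp); // Fails gracefully if unsupported.
//   });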
//===----------------------------------------------------------------------===//
// Preconditions that ensure the corresponding transformation succeeds and can
// be applied as a rewrite pattern.
//===----------------------------------------------------------------------===//
/// Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult promoteSubviewsPrecondition(Operation *op,
                                          LinalgPromotionOptions options);

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
struct LinalgTransforms {
  static const StringLiteral kLinalgTransformMarker;
};

/// Helper class to control application of linalg transformation patterns.
/// Control comes in 2 forms:
///   1. attribute matching and setting behavior using the attribute named
///      `kLinalgTransformMarker`. This can be used to build a state machine
///      using attributes and incrementally applying patterns to advance
///      states.
///   2. filter function, which is a simple lambda on the Operation* that
///      returns a LogicalResult.
struct LinalgTransformationFilter {
  using FilterFunction = std::function<LogicalResult(Operation *)>;

  explicit LinalgTransformationFilter(
      ArrayRef<StringAttr> matchDisjunction = {},
      Optional<StringAttr> replacement = None);

  explicit LinalgTransformationFilter(
      const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction = {},
      Optional<StringAttr> replacement = None);

  LinalgTransformationFilter(LinalgTransformationFilter &&) = default;
  LinalgTransformationFilter(const LinalgTransformationFilter &) = default;
  LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
  void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                         Operation *op) const;
  bool hasReplacementFilter(Operation *op) const;

  LinalgTransformationFilter &addFilter(const FilterFunction &f) {
    if (f)
      filters.push_back(f);
    return *this;
  }

  template <typename... OpTypes>
  LinalgTransformationFilter &addOpFilter() {
    return addFilter(
        [](Operation *op) { return success(isa<OpTypes...>(op)); });
  }

  LinalgTransformationFilter &addOpNameFilter(StringRef opName) {
    return addFilter([opName](Operation *op) {
      return success(op->getName().getStringRef() == opName);
    });
  }

  LinalgTransformationFilter &setMatchByDefault() {
    matchByDefault = true;
    return *this;
  }

private:
  SmallVector<FilterFunction> filters;
  SmallVector<StringAttr> matchDisjunction;
  Optional<StringAttr> replacement;
  /// When set to true, if the attribute is not set, it will be treated as
  /// a match. Default is false.
  bool matchByDefault;
};
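// A minimal staging sketch (hedged; the marker names "tile" and "vectorize"
// are hypothetical): only match ops carrying the "tile" marker and, on
// success, advance them to the "vectorize" state.
//
//   MLIRContext *ctx = ...; // Assumed available.
//   LinalgTransformationFilter filter(
//       /*matchDisjunction=*/{StringAttr::get(ctx, "tile")},
//       /*replacement=*/StringAttr::get(ctx, "vectorize"));
//   filter.addOpFilter<MatmulOp, GenericOp>(); // Restrict to specific ops.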
using TileSizeComputationFunction =
    std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;

/// Creates a number of ranges equal to the number of non-zero entries in
/// `tileSizes`, one for each loop of the LinalgOp that is tiled. The
/// `tileSizes` argument has one entry per surrounding loop. It uses zero as
/// the convention that a particular loop is not tiled. This convention
/// simplifies implementations by avoiding affine map manipulations. The
/// returned ranges correspond to the loop ranges, in the proper order, that
/// are tiled and for which new loops will be created. The function also
/// returns a map from loop indices of the LinalgOp to the corresponding
/// non-empty range indices of newly created loops.
using LoopIndexToRangeIndexMap = DenseMap<int, int>;
std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
                    ValueRange allShapeSizes, ValueRange allTileSizes);

/// A description of a multi-size tiling comprising tile sizes and numbers of
/// tiles, expressed as Values which may or may not be constant. Multi-size
/// currently means two-size.
struct MultiSizeSpecification {
  /// Tile sizes.
  Value lowTileSize, highTileSize;
  /// Number of tiles associated with each size.
  Value lowTripCount, highTripCount;
};

/// Emits the IR computing the multi-sized tiling specification with two tile
/// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
/// that there exist numbers of tiles with these sizes that fully cover the
/// given iteration space `dimension` of the structured `op`.
///
/// The computation is as follows:
///
///   b = originalTripCount floordiv sizeDivisor
///   t = (targetSize + sizeDivisor - 1) floordiv sizeDivisor
///   d = (b + t - 1) floordiv t
///   s = (b floordiv d) * sizeDivisor
///   v = b % d
///   u = d - v
///
/// where the tile sizes are `s` and `s` + `sizeDivisor`, and the numbers of
/// the corresponding tiles are `u` and `v`, respectively. Alternatively,
///
///   s * u + (s + sizeDivisor) * v == original size,
///   where s mod sizeDivisor = 0.
///
/// Expects all values to be positive. In some cases with the target tile size
/// sufficiently close to the dimension shape and non-unit divisor, it is
/// impossible to compute such sizes. If `emitAssertions` is set, also emit
/// the assertion that size computation succeeded.
///
/// Returns the specification consisting of both tile values and the number of
/// tiles of each size.
FailureOr<MultiSizeSpecification>
computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
                      OpFoldResult targetSize, OpFoldResult divisor,
                      bool emitAssertions = true);
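// A worked numeric example of the formulas above (hedged; the numbers are
// illustrative): for an original trip count of 54, targetSize = 10, and
// sizeDivisor = 2:
//
//   b = 54 floordiv 2 = 27
//   t = (10 + 2 - 1) floordiv 2 = 5
//   d = (27 + 5 - 1) floordiv 5 = 6
//   s = (27 floordiv 6) * 2 = 8
//   v = 27 mod 6 = 3,  u = 6 - 3 = 3
//
// i.e. three tiles of size 8 and three tiles of size 10, and indeed
// 8 * 3 + 10 * 3 == 54.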
/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`, applying
/// tiling by `numThreads`.
/// If non-empty, the `threadDimMapping` is added as an attribute to the
/// resulting `scf.foreach_thread`.
/// Zero tile sizes indicate that the dimension is not tiled, and can be
/// thought of as tiling by the full size of data.
/// It is the user's responsibility to ensure that `numThreads` is a valid
/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in
/// the Linalg case).
struct ForeachThreadTilingResult {
  Operation *tileOp;
  Operation *tiledOp;
};
FailureOr<ForeachThreadTilingResult>
tileToForeachThreadOp(RewriterBase &builder, TilingInterface op,
                      ArrayRef<OpFoldResult> numThreads,
                      ArrayRef<int64_t> threadDimMapping = {});

/// Same as `tileToForeachThreadOp`, but calculate the number of threads
/// required using the given tileSizes.
FailureOr<ForeachThreadTilingResult>
tileToForeachThreadOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
                                    ArrayRef<OpFoldResult> tileSizes,
                                    ArrayRef<int64_t> threadDimMapping = {});
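// A minimal call sketch (hedged; `rewriter` and `tilingInterfaceOp` are
// assumed, and the 8x4 thread grid is illustrative):
//
//   SmallVector<OpFoldResult> numThreads = {rewriter.getIndexAttr(8),
//                                           rewriter.getIndexAttr(4)};
//   FailureOr<ForeachThreadTilingResult> result =
//       tileToForeachThreadOp(rewriter, tilingInterfaceOp, numThreads);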
/// All indices returned by IndexOp should be invariant with respect to
/// tiling. Therefore, if an operation is tiled, we have to transform the
/// indices accordingly, i.e. offset them by the values of the corresponding
/// induction variables that are captured implicitly in the body of the op.
///
/// Example. `linalg.generic` before tiling:
///
///   #id_2d = (i, j) -> (i, j)
///   #pointwise_2d_trait = {
///     indexing_maps = [#id_2d, #id_2d],
///     iterator_types = ["parallel", "parallel"]
///   }
///   linalg.generic #pointwise_2d_trait %operand, %result {
///   ^bb0(%operand_in: f32, %result_in: f32):
///     %i = linalg.index 0 : index
///     %j = linalg.index 1 : index
///     <some operations that use %i, %j>
///   }: memref<50x100xf32>, memref<50x100xf32>
///
/// After a tiling pass with tile sizes 10 and 25:
///
///   #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
///
///   %c1 = arith.constant 1 : index
///   %c0 = arith.constant 0 : index
///   %c25 = arith.constant 25 : index
///   %c10 = arith.constant 10 : index
///   operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
///   operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
///   scf.for %k = %c0 to operand_dim_0 step %c10 {
///     scf.for %l = %c0 to operand_dim_1 step %c25 {
///       %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
///         : memref<50x100xf32> to memref<?x?xf32, #strided>
///       %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
///         : memref<50x100xf32> to memref<?x?xf32, #strided>
///       linalg.generic pointwise_2d_trait %4, %5 {
///       ^bb0(%operand_in: f32, %result_in: f32):
///         %i = linalg.index 0 : index
///         %j = linalg.index 1 : index
///         // Indices `k` and `l` are implicitly captured in the body.
///         %transformed_i = arith.addi %i, %k : index // `i` is offset by %k.
///         %transformed_j = arith.addi %j, %l : index // `j` is offset by %l.
///         // Every use of %i, %j is replaced with %transformed_i,
///         // %transformed_j.
///         <some operations that use %transformed_i, %transformed_j>
///       }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
///     }
///   }
///
/// TODO: Investigate whether mixing implicit and explicit indices
/// does not lead to losing information.
void transformIndexOps(RewriterBase &b, LinalgOp op,
                       SmallVectorImpl<Value> &ivs,
                       const LoopIndexToRangeIndexMap &loopIndexToRangeIndex);

struct LinalgPaddingOptions {
  /// A padding value for every operand.
  SmallVector<Attribute> paddingValues;
  LinalgPaddingOptions &setPaddingValues(ArrayRef<Attribute> pv) {
    paddingValues.assign(pv.begin(), pv.end());
    return *this;
  }
  /// A list of iterator dimensions to pad.
  SmallVector<int64_t> paddingDimensions;
  LinalgPaddingOptions &setPaddingDimensions(ArrayRef<int64_t> pd) {
    paddingDimensions.assign(pd.begin(), pd.end());
    return *this;
  }
  /// A flag for every operand to mark the PadOp as nofold which enables
  /// packing for statically shaped operands.
  SmallVector<bool> packPaddings;
  LinalgPaddingOptions &setPackPaddings(ArrayRef<bool> pp) {
    packPaddings.assign(pp.begin(), pp.end());
    return *this;
  }
  /// A number of loops to hoist the PadOp out for every operand.
  SmallVector<int64_t> hoistPaddings;
  LinalgPaddingOptions &setHoistPaddings(ArrayRef<int64_t> hp) {
    hoistPaddings.assign(hp.begin(), hp.end());
    return *this;
  }
  /// A permutation vector for every operand used to transpose the packed
  /// PadOp results.
  SmallVector<SmallVector<int64_t>> transposePaddings;
  LinalgPaddingOptions &
  setTransposePaddings(ArrayRef<SmallVector<int64_t>> tp) {
    transposePaddings.assign(tp.begin(), tp.end());
    return *this;
  }
};
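// A minimal configuration sketch for a 3-operand op (hedged; the per-operand
// flags are illustrative and `zero` is a hypothetical f32 zero attribute
// built at the call site):
//
//   LinalgPaddingOptions paddingOptions;
//   paddingOptions.setPaddingValues({zero, zero, zero})
//       .setPaddingDimensions({0, 1, 2})
//       .setPackPaddings({true, true, false})
//       .setHoistPaddings({2, 2, 0});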
struct LinalgTilingAndFusionOptions {
  /// Tile sizes used to tile the root operation.
  SmallVector<int64_t> tileSizes;
  LinalgTilingAndFusionOptions &setTileSizes(ArrayRef<int64_t> ts) {
    tileSizes.assign(ts.begin(), ts.end());
    return *this;
  }
  /// Tile interchange used to permute the tile loops.
  SmallVector<int64_t> tileInterchange;
  /// When specified, specifies distribution of generated tile loops to
  /// processors.
  Optional<LinalgLoopDistributionOptions> tileDistribution = None;
  LinalgTilingAndFusionOptions &
  setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
    tileDistribution = std::move(distributionOptions);
    return *this;
  }
};

struct LinalgTilingOptions {
  /// Computation function that returns the tile sizes for each operation.
  /// Delayed construction of constant tile sizes should occur to interoperate
  /// with folding.
  TileSizeComputationFunction tileSizeComputationFunction = nullptr;

  LinalgTilingOptions &
  setTileSizeComputationFunction(TileSizeComputationFunction fun) {
    tileSizeComputationFunction = std::move(fun);
    return *this;
  }
  /// Set the `tileSizeComputationFunction` to return the values `ts`. The
  /// values must not fold away when tiling. Otherwise, use a more robust
  /// `tileSizeComputationFunction`.
  LinalgTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
    tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
    return *this;
  }
  /// Convenience function to set the `tileSizeComputationFunction` to a
  /// function that computes tile sizes at the point they are needed. Allows
  /// proper interaction with folding.
  LinalgTilingOptions &setTileSizes(ArrayRef<int64_t> ts);

  /// Tile all dynamic dimensions by 1. I.e., scalarize those dimensions.
  /// Note: `scalarizeDynamicDims` and `setTileSizes` cannot be used together.
  LinalgTilingOptions &scalarizeDynamicDims();

  /// The interchange vector to reorder the tiled loops.
  SmallVector<unsigned, 4> interchangeVector = {};

  LinalgTilingOptions &setInterchange(ArrayRef<unsigned> interchange) {
    interchangeVector.assign(interchange.begin(), interchange.end());
    return *this;
  }

  /// The type of tile loops to generate.
  LinalgTilingLoopType loopType = LinalgTilingLoopType::Loops;

  LinalgTilingOptions &setLoopType(LinalgTilingLoopType lt) {
    loopType = lt;
    return *this;
  }

  /// When specified, specifies distribution of generated tile loops to
  /// processors.
  Optional<LinalgLoopDistributionOptions> distribution = None;

  LinalgTilingOptions &
  setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
    distribution = std::move(distributionOptions);
    return *this;
  }

  /// Specification markers of how to distribute the `linalg.tiled_loop`.
  SmallVector<StringRef, 2> distributionTypes = {};

  LinalgTilingOptions &setDistributionTypes(ArrayRef<StringRef> types) {
    distributionTypes.assign(types.begin(), types.end());
    return *this;
  }

  /// Peel the specified loops.
  SmallVector<int64_t> peeledLoops;

  LinalgTilingOptions &setPeeledLoops(ArrayRef<int64_t> loops) {
    peeledLoops.clear();
    peeledLoops.append(loops.begin(), loops.end());
    return *this;
  }
};

/// Canonicalization patterns relevant to apply after tiling patterns. These
/// are applied automatically by the tiling pass but need to be applied
/// manually when tiling is called programmatically.
RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
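// A minimal programmatic-tiling sketch (hedged; `rewriter` and `matmulOp` are
// assumed, and the tile sizes are illustrative). Note that the
// canonicalization patterns above must be applied separately when tiling this
// way.
//
//   LinalgTilingOptions tilingOptions;
//   tilingOptions.setTileSizes({8, 16, 32})
//       .setInterchange({1, 2, 0})
//       .setLoopType(LinalgTilingLoopType::Loops);
//   FailureOr<TiledLinalgOp> tiled =
//       tileLinalgOp(rewriter, matmulOp, tilingOptions);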
///
/// Linalg tiling pattern.
///
/// Apply the `tiling` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `tiling` for more details.
// TODO: TiledOpInterface
struct LinalgTilingPattern : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
  LinalgTilingPattern(
      MLIRContext *context, LinalgTilingOptions options,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  /// Construct a pattern specifically applied to `opName`.
  LinalgTilingPattern(
      StringRef opName, MLIRContext *context, LinalgTilingOptions options,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  /// `matchAndRewrite` implementation that returns the significant
  /// transformed pieces of IR.
  FailureOr<TiledLinalgOp>
  returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(op, rewriter);
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
  /// Options to control tiling.
  LinalgTilingOptions options;
};

///
/// Linalg padding pattern.
///
/// Apply the `padding` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `padding` for more details.
struct LinalgPaddingPattern : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
  LinalgPaddingPattern(
      MLIRContext *context,
      LinalgPaddingOptions options = LinalgPaddingOptions(),
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  /// Construct a pattern specifically applied to `opName`.
  LinalgPaddingPattern(
      StringRef opName, MLIRContext *context,
      LinalgPaddingOptions options = LinalgPaddingOptions(),
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  /// `matchAndRewrite` implementation that returns the significant
  /// transformed pieces of IR.
  FailureOr<LinalgOp>
  returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(op, rewriter);
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
  /// Options to control padding and hoisting.
  LinalgPaddingOptions options;
};

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
  DownscaleSizeOneWindowed2DConvolution(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
        filter(std::move(f)) {}

  FailureOr<Conv1DNwcWcfOp>
  returningMatchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                           PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(convOp, rewriter);
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  DownscaleDepthwiseConv2DNhwcHwcOp(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
        filter(std::move(f)) {}

  FailureOr<DepthwiseConv1DNwcWcOp>
  returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                           PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(convOp, rewriter);
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};
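// A minimal registration sketch (hedged; `funcOp` is assumed): downscale
// eligible 2-D convolutions, then vectorize the resulting 1-D ops with the
// convolution vectorization patterns declared earlier in this header.
//
//   RewritePatternSet convPatterns(funcOp.getContext());
//   convPatterns.add<DownscaleSizeOneWindowed2DConvolution,
//                    DownscaleDepthwiseConv2DNhwcHwcOp>(funcOp.getContext());
//   populateConvolutionVectorizationPatterns(convPatterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(convPatterns));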
///
/// Linalg tile and fuse tensor ops pattern.
///
/// Apply tiling and fusion as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `tileConsumerAndFuseProducers` for more details.
struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
  // Entry point to match any LinalgOp.
  LinalgTileAndFuseTensorOpsPattern(
      MLIRContext *context, LinalgTilingAndFusionOptions options,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);
  // Entry point to match a specific LinalgOp.
  LinalgTileAndFuseTensorOpsPattern(
      StringRef opName, MLIRContext *context,
      LinalgTilingAndFusionOptions options,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  /// `matchAndRewrite` implementation that returns the significant
  /// transformed pieces of IR.
  FailureOr<TileLoopNest>
  returningMatchAndRewrite(Operation *op, PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(op, rewriter);
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
  /// Tile sizes and interchange used to tile the root operation.
  LinalgTilingAndFusionOptions options;
};

///
/// Linalg generic interchange pattern.
///
/// Apply the `interchange` transformation on a RewriterBase.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `interchange` for more details.
struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  /// GenericOp-specific constructor with an optional `filter`.
  GenericOpInterchangePattern(
      MLIRContext *context, ArrayRef<unsigned> interchangeVector,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  /// `matchAndRewrite` implementation that returns the significant
  /// transformed pieces of IR.
  FailureOr<GenericOp>
  returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(op, rewriter);
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
  /// The interchange vector to reorder the iterators and indexing_maps dims.
  SmallVector<unsigned, 8> interchangeVector;
};

///
/// Linalg generalization pattern.
///
/// Apply the `generalization` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `generalization` for more details.
struct LinalgGeneralizationPattern
    : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
  LinalgGeneralizationPattern(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  /// Construct a pattern specifically applied to `opName`.
  LinalgGeneralizationPattern(
      StringRef opName, MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  /// `matchAndRewrite` implementation that returns the significant
  /// transformed pieces of IR.
  FailureOr<GenericOp>
  returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(op, rewriter);
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};
///
/// Linalg peeling patterns.
///

/// Compute the loops to peel and return them in a SmallVector. Loops will be
/// peeled in order of appearance in the SmallVector. This order will impact
/// the output IR. If an inner-to-outer order is provided, the peeled
/// iterations of the outer loops will also contain the peeled inner loops. If
/// an outer-to-inner order is provided, the peeled iterations of the outer
/// loops will not contain any peeled inner loops.
using LoopsToPeelComputationFunction = std::function<void(
    OpBuilder &, Operation *, SmallVectorImpl<scf::ForOp> &)>;

struct LinalgPeelOptions {
  LoopsToPeelComputationFunction loopsToPeelComputationFunction = nullptr;
};

/// `filter` controls LinalgTransformMarker matching and update when specified.
struct LinalgPeelingPattern : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
  LinalgPeelingPattern(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      LinalgPeelOptions options = LinalgPeelOptions(),
      PatternBenefit benefit = 1);

  /// Construct a pattern specifically applied to `opName`.
  LinalgPeelingPattern(
      StringRef opName, MLIRContext *context,
      LinalgPeelOptions options = LinalgPeelOptions(),
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override;

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  const LinalgTransformationFilter filter;
  /// Peeling options.
  const LinalgPeelOptions options;
};
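// A minimal options sketch (hedged; peeling every enclosing scf.for is
// illustrative, not a recommendation): collect the loops surrounding the
// matched op, outermost first.
//
//   LinalgPeelOptions peelOptions;
//   peelOptions.loopsToPeelComputationFunction =
//       [](OpBuilder &b, Operation *op,
//          SmallVectorImpl<scf::ForOp> &loopsToPeel) {
//         for (Operation *parent = op->getParentOp(); parent;
//              parent = parent->getParentOp())
//           if (auto forOp = dyn_cast<scf::ForOp>(parent))
//             loopsToPeel.push_back(forOp);
//         // Loops were collected innermost-to-outermost; reverse for
//         // outer-to-inner peeling.
//         std::reverse(loopsToPeel.begin(), loopsToPeel.end());
//       };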
///
/// Linalg vectorization patterns.
///
/// Empty for now, used for SFINAE purposes only.
struct LinalgVectorizationOptions {};

/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `vectorizeLinalgOp` for more details.
struct LinalgVectorizationPattern
    : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
  LinalgVectorizationPattern(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      LinalgVectorizationOptions options = LinalgVectorizationOptions(),
      PatternBenefit benefit = 1);

  /// Construct a pattern specifically applied to `opName`.
  LinalgVectorizationPattern(
      StringRef opName, MLIRContext *context,
      LinalgVectorizationOptions options = LinalgVectorizationOptions(),
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override;

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `vectorizeLinalgOp` for more details.
struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
  using OpRewritePattern<memref::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::CopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};

/// Return vector::CombiningKind for the given op.
llvm::Optional<vector::CombiningKind> getCombinerOpKind(Operation *combinerOp);

//===----------------------------------------------------------------------===//
// Transformation and lowering options exposed as auxiliary structs.
//===----------------------------------------------------------------------===//
/// Options to control the application of enabling transformations.
/// Hoisting transformations are always deemed beneficial and must be disabled
/// explicitly.
struct LinalgEnablingOptions {
  /// Enable LICM.
  bool licm = true;
  LinalgEnablingOptions &enableLICM(bool val = true) {
    licm = val;
    return *this;
  }
  /// Enable hoisting of redundant vector transfer ops.
  bool hoistRedundantVectorTransfers = true;
  LinalgEnablingOptions &enableHoistRedundantVectorTransfers(bool val = true) {
    hoistRedundantVectorTransfers = val;
    return *this;
  }
  /// Enable hoisting of redundant vector transfer ops on tensor.
  bool hoistRedundantVectorTransfersOnTensor = true;
  LinalgEnablingOptions &
  enableHoistRedundantVectorTransfersOnTensor(bool val = true) {
    hoistRedundantVectorTransfersOnTensor = val;
    return *this;
  }
};
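// A minimal configuration sketch (hedged; disabling LICM is purely
// illustrative):
//
//   LinalgEnablingOptions enablingOptions;
//   enablingOptions.enableLICM(false)
//       .enableHoistRedundantVectorTransfers()
//       .enableHoistRedundantVectorTransfersOnTensor();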
/// Vector lowering options control how ops are lowered down to 1-D and
/// scf.for form.
struct LinalgVectorLoweringOptions {
  /// Enable lowering of vector.contract.
  /// In a progressive lowering of vectors, this would be the 1st step.
  bool contractionLowering = false;
  LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) {
    contractionLowering = val;
    return *this;
  }
  /// Enable lowering of vector.multi_reduce.
  /// In a progressive lowering of vectors, this would be the 2nd step.
  bool multiReductionLowering = false;
  LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) {
    multiReductionLowering = val;
    return *this;
  }
  /// Trigger full / partial vector.transfer splits.
  /// In a progressive lowering of vectors, this would be the 3rd step.
  bool transferPartialRewrite = false;
  LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
    transferPartialRewrite = val;
    return *this;
  }
  /// Enable lowering of vector.transfer to scf.
  /// In a progressive lowering of vectors, this would be the 4th step.
  bool transferToSCFConversion = false;
  LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val = true) {
    transferToSCFConversion = val;
    return *this;
  }
  /// Maximal transfer rank under which we do not lower further.
  int64_t maxTransferRank = 1;
  LinalgVectorLoweringOptions &setMaxTransferRank(int64_t val) {
    maxTransferRank = val;
    return *this;
  }
  /// Vector lowering operations may result in surprising behavior when
  /// composing multiple codegen strategies and must be enabled explicitly.
  /// In a progressive lowering of vectors, this would be the 5th step.
  bool transferLowering = true;
  LinalgVectorLoweringOptions &enableTransferLowering(bool val = true) {
    transferLowering = val;
    return *this;
  }
  /// Enable lowering of vector.shape_cast to insert/extract.
  /// In a progressive lowering of vectors, this would be the 6th step.
  bool shapeCastLowering = true;
  LinalgVectorLoweringOptions &enableShapeCastLowering(bool val = true) {
    shapeCastLowering = val;
    return *this;
  }
  /// Enable lowering of vector.transpose.
  /// In a progressive lowering of vectors, this would be the 7th step.
  bool transposeLowering = false;
  LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) {
    transposeLowering = val;
    return *this;
  }
  /// Enable AVX2-specific lowerings.
  bool avx2Lowering = false;
  LinalgVectorLoweringOptions &enableAVX2Lowering(bool val = true) {
    avx2Lowering = val;
    return *this;
  }

  /// Configure the post staged-patterns late vector.transfer to scf
  /// conversion.
  VectorTransferToSCFOptions vectorTransferToSCFOptions;
  LinalgVectorLoweringOptions &
  setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
    vectorTransferToSCFOptions = options;
    return *this;
  }
  /// Configure late vector transformations.
  vector::VectorTransformsOptions vectorTransformOptions;
  LinalgVectorLoweringOptions &
  setVectorTransformsOptions(vector::VectorTransformsOptions options) {
    vectorTransformOptions = options;
    return *this;
  }
  /// Configure specialized vector lowerings.
  x86vector::avx2::LoweringOptions avx2LoweringOptions;
  LinalgVectorLoweringOptions &
  setAVX2LoweringOptions(x86vector::avx2::LoweringOptions options) {
    avx2LoweringOptions = options;
    return *this;
  }
};
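// A minimal configuration sketch enabling the first four progressive steps
// (hedged; which steps to enable depends on the surrounding codegen
// strategy):
//
//   LinalgVectorLoweringOptions loweringOptions;
//   loweringOptions.enableContractionLowering()
//       .enableMultiReductionLowering()
//       .enableTransferPartialRewrite()
//       .enableTransferToSCFConversion()
//       .setMaxTransferRank(1);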
//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
///
/// Linalg lowering patterns.
///
/// Apply the `linalgLowerOpToLoops` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `linalgLowerOpToLoops` for more details.
enum class LinalgLoweringType {
  LibraryCall = 0,
  Loops = 1,
  AffineLoops = 2,
  ParallelLoops = 3
};

template <typename OpTy>
struct LinalgLoweringPattern : public RewritePattern {
  LinalgLoweringPattern(
      MLIRContext *context, LinalgLoweringType loweringType,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : RewritePattern(OpTy::getOperationName(), benefit, context),
        filter(std::move(f)), loweringType(loweringType) {}

  // TODO: Move implementation to .cpp once named ops are auto-generated.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp)
      return failure();
    if (failed(filter.checkAndNotify(rewriter, linalgOp)))
      return failure();

    switch (loweringType) {
    case LinalgLoweringType::LibraryCall:
      // TODO: Move lowering to library calls here.
      return failure();
    case LinalgLoweringType::Loops:
      if (failed(linalgOpToLoops(rewriter, op)))
        return failure();
      break;
    case LinalgLoweringType::AffineLoops:
      if (failed(linalgOpToAffineLoops(rewriter, op)))
        return failure();
      break;
    case LinalgLoweringType::ParallelLoops:
      if (failed(linalgOpToParallelLoops(rewriter, op)))
        return failure();
      break;
    }

    rewriter.eraseOp(op);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
  /// Controls whether the pattern lowers to library calls, scf.for, affine.for
  /// or scf.parallel.
  LinalgLoweringType loweringType;
};

/// Linalg generalization patterns

/// Populates `patterns` with patterns to convert spec-generated named ops to
/// linalg.generic ops.
void populateLinalgNamedOpsGeneralizationPatterns(
    RewritePatternSet &patterns,
    const LinalgTransformationFilter &filter = LinalgTransformationFilter());

/// Linalg decompose convolutions patterns

/// Populates patterns to decompose high-D convolution ops into low-D ones.
/// This is a step in progressive lowering for convolution ops; afterwards the
/// low-D convolution ops can be vectorized.
void populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns,
    const LinalgTransformationFilter &filter = LinalgTransformationFilter(),
    PatternBenefit benefit = 1);

//===----------------------------------------------------------------------===//
// Op-specific patterns.
//===----------------------------------------------------------------------===//

/// tensor::PadOp is not canonicalized away yet, so we provide a transformation
/// to `linalg.generic`.
struct PadOpTransformationPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override;
};

/// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands
/// to a static bounding box. Use `paddingValues` and `packPaddings` to set
/// the padding value and the nofold attribute of the created tensor::PadOps,
/// respectively. Update `paddedOp` to the cloned operation with statically
/// shaped `paddingDimensions` and return the extracted dynamically shaped
/// results. If padding fails, return failure.
FailureOr<SmallVector<Value>>
rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                  ArrayRef<int64_t> paddingDimensions,
                  ArrayRef<Attribute> paddingValues,
                  ArrayRef<bool> packPaddings, LinalgOp &paddedOp);
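// A minimal call sketch (hedged; `b`, `opToPad`, and the f32 zero padding
// value are assumed/illustrative):
//
//   Attribute zero = b.getZeroAttr(b.getF32Type());
//   LinalgOp paddedOp;
//   FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
//       b, opToPad, /*paddingDimensions=*/{0, 1, 2},
//       /*paddingValues=*/{zero, zero, zero},
//       /*packPaddings=*/{true, true, false}, paddedOp);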
using OptimizeCopyFn =
    std::function<LogicalResult(PatternRewriter &, tensor::PadOp, Value)>;

/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp and
/// InsertSliceOp. For now, only constant padding values are supported.
/// `OptimizeCopyFn` can be used to customize copying step optimization.
struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
  GeneralizePadOpPattern(MLIRContext *context,
                         OptimizeCopyFn optimizeCopyFn = nullptr,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        optimizeCopyFn(std::move(optimizeCopyFn)) {}
  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override;

protected:
  OptimizeCopyFn optimizeCopyFn;
  Value createFillOrGenerateOp(PatternRewriter &rewriter, tensor::PadOp padOp,
                               Value dest,
                               const SmallVector<Value> &dynSizes) const;
};

/// Populates `patterns` with patterns that vectorize tensor.pad.
/// These patterns are meant to apply in a complementary fashion. Benefits
/// are used to encode a certain ordering of pattern application. To avoid
/// scattering magic constants throughout the code base, the patterns must be
/// added with this function. `baseBenefit` can be used to offset the benefit
/// of all tensor::PadOp vectorization patterns by a certain value.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns,
                                        PatternBenefit baseBenefit = 1);

/// Match and rewrite for the pattern:
/// ```
///    %alloc = ...
///    [optional] %view = memref.view %alloc ...
///    %subView = subview %allocOrView ...
///    [optional] linalg.fill(%allocOrView, %cst) ...
///    ...
///    memref.copy(%in, %subView) ...
///    vector.transfer_read %allocOrView[...], %cst ...
/// ```
/// into
/// ```
///    [unchanged] %alloc = ...
///    [unchanged] [optional] %view = memref.view %alloc ...
///    [unchanged] %subView = subview %allocOrView ...
///    ...
///    vector.transfer_read %in[...], %cst ...
/// ```
/// Where there is no interleaved use between memref.copy and transfer_read as
/// well as no interleaved use between linalg.fill and memref.copy (if
/// linalg.fill is specified).
/// This is a custom rewrite to forward partial reads (with optional fills) to
/// vector.transfer_read.
struct LinalgCopyVTRForwardingPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override;
};
/// Match and rewrite for the pattern:
/// ```
///    %alloc = ...
///    [optional] %view = memref.view %alloc ...
///    %subView = subview %allocOrView...
///    ...
///    vector.transfer_write %..., %allocOrView[...]
///    memref.copy(%subView, %out)
/// ```
/// into
/// ```
///    [unchanged] %alloc = ...
///    [unchanged] [optional] %view = memref.view %alloc ...
///    [unchanged] %subView = subview %allocOrView...
///    ...
///    vector.transfer_write %..., %out[...]
/// ```
/// Where there is no interleaved use between transfer_write and memref.copy.
/// This is a custom rewrite to forward partial writes to
/// vector.transfer_write.
struct LinalgCopyVTWForwardingPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override;
};

//===----------------------------------------------------------------------===//
// Support for staged pattern application.
//===----------------------------------------------------------------------===//
/// Helper function to allow applying rewrite patterns, interleaved with more
/// global transformations, in a staged fashion:
///   1. the first stage consists of a list of FrozenRewritePatternSet. Each
///      FrozenRewritePatternSet in this list is applied once, in order.
///   2. the second stage consists of a single RewritePattern that is applied
///      greedily until convergence.
///   3. the third stage consists of applying a lambda, generally used for
///      non-local transformation effects. This allows creating custom fused
///      transformations where patterns can be ordered and applied at a finer
///      granularity than a sequence of traditional compiler passes.
LogicalResult applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
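// A minimal staging sketch (hedged; `funcOp` and the pattern sets
// `tilingPatterns` and `canonicalizationPatterns` are placeholders built by
// the caller):
//
//   SmallVector<FrozenRewritePatternSet> stage1;
//   stage1.emplace_back(std::move(tilingPatterns));
//   FrozenRewritePatternSet stage2(std::move(canonicalizationPatterns));
//   LogicalResult status = applyStagedPatterns(
//       funcOp, stage1, stage2,
//       [](Operation *op) { return success(); }); // Stage 3: no-op here.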
template <typename... OpTypes>
class TilingPatterns;

template <>
class TilingPatterns<> {
public:
  static void insert(RewritePatternSet &patterns,
                     const LinalgTilingOptions &options,
                     const LinalgTransformationFilter &f) {}
};

template <typename OpTy, typename... OpTypes>
class TilingPatterns<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns,
                     const LinalgTilingOptions &options,
                     const LinalgTransformationFilter &f) {
    patterns.add<LinalgTilingPattern>(OpTy::getOperationName(),
                                      patterns.getContext(), options, f);
    TilingPatterns<OpTypes...>::insert(patterns, options, f);
  }
};

/// Function signature to control reduction splitting. This returns a pair
/// containing a ratio and a dimension index. The ratio is used to split the
/// reduction dimension. The dimension index is used to control where the
/// extra dimension is added to the intermediate tensor shape. If the ratio
/// value is less than or equal to 1, nothing will be done.
// TODO: don't use unsigned unless doing bit manipulation.
using ControlSplitReductionFn =
    std::function<std::pair<int64_t, unsigned>(LinalgOp op)>;

/// Patterns to apply `splitReduction` below.
void populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn,
    const LinalgTransformationFilter &f = LinalgTransformationFilter(),
    bool useAlloc = false);
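// For illustration, a control function that unconditionally splits eligible
// reductions by a ratio of 4 and inserts the extra parallel dimension at
// index 0 could be written as follows (a minimal sketch; the constants are
// arbitrary and a real client would typically inspect `op` before deciding):
//
//   ControlSplitReductionFn control = [](LinalgOp op) {
//     return std::pair<int64_t, unsigned>(4, 0);
//   };
//   populateSplitReductionPattern(patterns, control);
//
// Returning a ratio less than or equal to 1 from the callback leaves the op
// untouched.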
/// Apply a transformation to split the single linalg op reduction into a
/// parallel and a reduction dimension. Then create a new linalg.generic op
/// doing the rest of the reduction. Return the new linalg op with an extra
/// parallel dimension, or failure if the transformation didn't happen.
///
/// Example:
/// ```
///  %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
///                                        affine_map<(d0) -> ()>],
///       iterator_types = ["reduction"]}
///   ins(%in : tensor<32xf32>)
///   outs(%out : tensor<f32>) {
///   ^bb0(%arg1: f32, %arg2: f32):
///     %y = arith.addf %arg1, %arg2 : f32
///     linalg.yield %y : f32
///  } -> tensor<f32>
/// ```
/// To:
/// ```
///  %cst = arith.constant 0.000000e+00 : f32
///  %0 = tensor.expand_shape %in [[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
///  %1 = linalg.init_tensor [4] : tensor<4xf32>
///  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32>
///  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
///                                        affine_map<(d0, d1) -> (d0)>],
///    iterator_types = ["parallel", "reduction"]}
///    ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) {
///    ^bb0(%arg3: f32, %arg4: f32):
///      %5 = arith.addf %arg3, %arg4 : f32
///      linalg.yield %5 : f32
///  } -> tensor<4xf32>
///  %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
///                                        affine_map<(d0) -> ()>],
///       iterator_types = ["reduction"]}
///    ins(%3 : tensor<4xf32>) outs(%out : tensor<f32>) {
///    ^bb0(%arg3: f32, %arg4: f32):
///      %5 = arith.addf %arg3, %arg4 : f32
///      linalg.yield %5 : f32
///  } -> tensor<f32>
/// ```
FailureOr<LinalgOp>
splitReduction(PatternRewriter &b, LinalgOp op,
               const ControlSplitReductionFn &controlSplitReductionFn,
               const LinalgTransformationFilter &f, bool useAlloc = false);

/// Filterless version of the above.
/// Returns both the new linalg ops as well as the fillOp needed to initialize
/// the temporary expanded tensor with the proper neutral element.
struct SplitReductionResult {
  Operation *initOrAlloc;
  FillOp fillOp;
  LinalgOp splitLinalgOp;
  LinalgOp resultCombiningLinalgOp;
};
FailureOr<SplitReductionResult>
splitReduction(PatternRewriter &b, LinalgOp op,
               const ControlSplitReductionFn &controlSplitReductionFn,
               bool useAlloc = false);
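// A caller typically invokes the filterless variant from within a rewrite and
// then works with the pieces of the result (a minimal sketch; `rewriter`,
// `op`, and `control` are assumed to come from the enclosing pattern):
//
//   FailureOr<SplitReductionResult> result =
//       splitReduction(rewriter, op, control);
//   if (failed(result))
//     return failure();
//   // The partial reduction carrying the extra parallel dimension...
//   LinalgOp split = result->splitLinalgOp;
//   // ...and the op that combines the partial results into the final shape.
//   LinalgOp combine = result->resultCombiningLinalgOp;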
/// Scaling-based implementation of the split reduction transformation.
/// Instead of introducing an ExpandShapeOp, this rewrites a reduction
/// dimension `k` into `k * scale + kk`.
///
/// Example:
/// ```
///  %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
///    outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
/// ```
///
/// Is transformed to:
///
/// ```
///  #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2 * 4 + d3)>
///  #map1 = affine_map<(d0, d1, d2, d3) -> (d2 * 4 + d3, d1)>
///  #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
///  #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
///  #map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
///  #map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
///  %0 = linalg.init_tensor [16, 32, 64] : tensor<16x32x64xf32>
///  %cst = arith.constant 0.000000e+00 : f32
///  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>) ->
///     tensor<16x32x64xf32>
///  %2 = linalg.init_tensor [64, 4] : tensor<64x4xi1>
///
///  %3 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3],
///    iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
///    ins(%A, %B, %2 : tensor<16x256xf32>, tensor<256x32xf32>, tensor<64x4xi1>)
///    outs(%1 : tensor<16x32x64xf32>) {
///    ^bb0(%arg3: f32, %arg4: f32, %arg5: i1, %arg6: f32):
///      %5 = arith.mulf %arg3, %arg4 : f32
///      %6 = arith.addf %arg6, %5 : f32
///      linalg.yield %6 : f32
///  } -> tensor<16x32x64xf32>
///
///  %4 = linalg.generic {indexing_maps = [#map4, #map5],
///    iterator_types = ["parallel", "parallel", "reduction"]}
///    ins(%3 : tensor<16x32x64xf32>)
///    outs(%C : tensor<16x32xf32>) {
///    ^bb0(%arg3: f32, %arg4: f32):
///      %5 = arith.addf %arg3, %arg4 : f32
///      linalg.yield %5 : f32
///  } -> tensor<16x32xf32>
///
///  return %4 : tensor<16x32xf32>
/// ```
FailureOr<SplitReductionResult>
splitReductionByScaling(PatternRewriter &b, LinalgOp op,
                        const ControlSplitReductionFn &controlSplitReductionFn,
                        bool useAlloc = false);

} // namespace linalg
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H