//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

static Optional<int64_t> extractConstantIndex(Value v) {
  if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
    return cstOp.value();
  if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
    if (affineApplyOp.getAffineMap().isSingleConstant())
      return affineApplyOp.getAffineMap().getSingleConstantResult();
  return None;
}

// Missing foldings of scf.if make it necessary to perform poor man's folding
// eagerly, especially in the case of unrolling. In the future, this should go
// away once scf.if folds properly.
static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
  auto maybeCstV = extractConstantIndex(v);
  auto maybeCstUb = extractConstantIndex(ub);
  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
    return Value();
  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
}

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.permutation_map().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    auto d0 = getAffineDimExpr(0, xferOp.getContext());
    auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
    Value sum =
        makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
    Value cond = createFoldedSLE(
        b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
    if (!cond)
      return;
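    // Hedged illustration (hypothetical values): for a vector<4xf32> transfer
    // from a memref<?xf32> at index %i, the per-dimension check built above
    // resembles:
    //   %sum  = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
    //   %dim  = memref.dim %source, %c0
    //   %cond = arith.cmpi sle, %sum, %dim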
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.permutation_map().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under an IfOp; this avoids
  // applying the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///      strides and:
///     a. keeping the ones that are static and equal across `aT` and `bT`.
///     b. using a dynamic shape and/or stride for the dimensions that don't
///        agree.
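///
/// As a hedged illustration of rule 2 (hypothetical types): `memref<4x8xf32>`
/// and `memref<4x16xf32>` are not cast-compatible, so the result would be
/// along the lines of `memref<4x?xf32>` with a strided layout whose leading
/// stride is dynamic, since the canonical strides [8, 1] and [16, 1] only
/// agree in the trailing position (the offsets agree and stay 0).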
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamicSize;
    resStrides[idx] = (aStrides[idx] == bStrides[idx])
                          ? aStrides[idx]
                          : ShapedType::kDynamicStrideOrOffset;
  }
  resOffset =
      (aOffset == bOffset) ? aOffset : ShapedType::kDynamicStrideOrOffset;
  return MemRefType::get(
      resShape, aT.getElementType(),
      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
}

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
                                                xferOp.source(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.indices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}});
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });
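  // Hedged illustration (hypothetical shapes): for a transfer of
  // vector<4x8xf32> on memref<?x?xf32> at indices (%i, %j), `sizes` now holds
  //   (affine_min(%dim0 - %i, 4), affine_min(%dim1 - %j, 8))
  // i.e. the per-dimension extent of the overlap between the source view and
  // the staging buffer `alloc`.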

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      linalg.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
                                 ValueRange{alloc});
        // Take a partial subview of the memref, which guarantees that no
        // dimension overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b
      .create<scf::IfOp>(
          loc, returnTypes, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res = memref;
            if (compatibleMemRefType != xferOp.getShapedType())
              res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
            scf::ValueVector viewAndIndices{res};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to `%view`.
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      linalg.copy(%3, %4)
///    }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to `%view`.
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : vector<...>, memref<A...>
///    }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    BlockAndValueMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc,
        b.create<vector::TypeCastOp>(
            loc, MemRefType::get({}, xferOp.getVector().getType()), alloc));
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

// TODO: Parallelism and thread-local considerations with a ParallelScope
// trait.
static Operation *getAutomaticAllocationScope(Operation *op) {
  // Find the closest surrounding allocation scope that is not a known looping
  // construct (putting allocas in loops doesn't always lower to deallocation
  // until the end of the loop).
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks: first, a block that decides which memref
/// and indices to use for an unmasked, in-bounds write; then, a conditional
/// block that copies the partial buffer into the final result in the slow-path
/// case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...] : vector<...>, memref<A...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
///    true]}
///    scf.if (%notInBounds) {
///      // slowpath: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope
  // to ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top-of-the-function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
                                  alloc.getType().cast<MemRefType>());
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds; it always reads from a full
    // buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  BlockAndValueMapping mapping;
  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  xferOp->erase();

  return success();
}

LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  rewriter.startRootUpdate(xferOp);
  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
    rewriter.finalizeRootUpdate(xferOp);
    return success();
  }
  rewriter.cancelRootUpdate(xferOp);
  return failure();
}
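
// A hedged usage sketch (illustration only, not part of this file's logic):
// assuming the `VectorTransferFullPartialRewriter` constructor declared
// alongside this pattern in VectorTransforms.h and the greedy pattern driver,
// a client pass could apply the splitting roughly as follows:
//
//   VectorTransformsOptions options;
//   options.vectorTransferSplit = VectorTransferSplit::VectorTransfer;
//   RewritePatternSet patterns(funcOp.getContext());
//   patterns.add<VectorTransferFullPartialRewriter>(funcOp.getContext(),
//                                                   options);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
//
// Alternatively, `splitFullAndPartialTransfer` can be invoked directly on a
// particular transfer op with a `RewriterBase` positioned at that op.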