//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

static Optional<int64_t> extractConstantIndex(Value v) {
  if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
    return cstOp.value();
  if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
    if (affineApplyOp.getAffineMap().isSingleConstant())
      return affineApplyOp.getAffineMap().getSingleConstantResult();
  return None;
}

// Missing foldings of scf.if make it necessary to perform poor man's folding
// eagerly, especially in the case of unrolling. In the future, this should go
// away once scf.if folds properly.
static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
  auto maybeCstV = extractConstantIndex(v);
  auto maybeCstUb = extractConstantIndex(ub);
  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
    return Value();
  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
}

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.permutation_map().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    auto d0 = getAffineDimExpr(0, xferOp.getContext());
    auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
    Value sum =
        makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
    Value cond = createFoldedSLE(
        b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
    if (!cond)
      return;
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}
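
// For illustration only (a sketch with hypothetical shapes, not part of the
// rewrite itself): for a 1-D read of vector<8xf32> at index %i from a
// memref<?xf32> %A, the condition built by createInBoundsCond resembles:
//
//   %end = affine.apply affine_map<(d0) -> (d0 + 8)>(%i)
//   %dim = memref.dim %A, %c0 : memref<?xf32>
//   %inBounds = arith.cmpi sle, %end, %dim : index
//
// With several potentially out-of-bounds dimensions, the per-dimension
// comparisons are combined with arith.andi; dimensions that fold to a
// statically in-bounds check contribute nothing.
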
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or memref.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-the-function `alloca`'d buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
///  must be equal. This will be relaxed in the future but requires
///  rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.permutation_map().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under an scf.if op; this avoids
  // applying the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///      strides and:
///        a. keeping the ones that are static and equal across `aT` and `bT`.
///        b. using a dynamic shape and/or stride for the dimensions that don't
///           agree.
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamicSize;
    resStrides[idx] = (aStrides[idx] == bStrides[idx])
                          ? aStrides[idx]
                          : ShapedType::kDynamicStrideOrOffset;
  }
  resOffset =
      (aOffset == bOffset) ? aOffset : ShapedType::kDynamicStrideOrOffset;
  return MemRefType::get(
      resShape, aT.getElementType(),
      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
}
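
// For illustration only (a hypothetical pair of types, not taken from the
// code above): with identity layouts,
//   aT = memref<4x8xf32>   (strides [8, 1],  offset 0)
//   bT = memref<4x16xf32>  (strides [16, 1], offset 0)
// are not directly cast-compatible, so the per-dimension merge kicks in: the
// size and stride that disagree become dynamic, giving roughly
//   memref<4x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>
// to which both original types can be memref.cast.
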
/// Operates under a scoped context to build the intersection between the
/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
                                                xferOp.source(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.indices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}});
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}
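
// For illustration only (hypothetical shapes): intersecting a transfer of
// vector<4x8xf32> at indices (%i, %j) on a memref<?x?xf32> source with a
// 4x8 `alloc` produces per-dimension sizes of the form
//   %s0 = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>
//           (%srcDim0, %i, %allocDim0)
//   %s1 = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>
//           (%srcDim1, %j, %allocDim1)
// where %allocDimN is simply the static alloc size (4 and 8 here), i.e.
// min(remaining elements in the source, elements in the alloc). These sizes
// are what keep the subsequent memref.copy in bounds on both views.
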
/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      memref.copy %3, %4
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.source();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                              xferOp.indices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        b.create<linalg::FillOp>(loc, xferOp.padding(), alloc);
        // Take a partial subview of the memref that guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}
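
// For illustration only (hypothetical sizes): reading vector<4xf32> at %i
// from a memref<7xf32> with %i = 5 takes the slow path. The 4-element alloca
// is first filled with %pad, then the two remaining source elements are
// copied into it, so the subsequent in-bounds transfer_read of the alloca at
// [0] yields [%A[5], %A[6], %pad, %pad] -- exactly the padded semantics of
// the original out-of-bounds read.
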
/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      memref.store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.source();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                              xferOp.indices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}
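
// For illustration only: the store-through-type_cast idiom used in the slow
// path above, for a hypothetical vector<4x8xf32>, resembles
//   %vcast = vector.type_cast %alloc
//     : memref<4x8xf32> to memref<vector<4x8xf32>>
//   memref.store %v, %vcast[] : memref<vector<4x8xf32>>
// i.e. the padded vector produced by the cloned (out-of-bounds) transfer_read
// is written into the alloca as a single vector value, which the rewritten
// in-bounds read then loads back element-wise.
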
/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.source();
  return b
      .create<scf::IfOp>(
          loc, returnTypes, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res = memref;
            if (compatibleMemRefType != xferOp.getShapedType())
              res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
            scf::ValueVector viewAndIndices{res};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.indices().begin(),
                                  xferOp.indices().end());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to `%view`.
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      memref.copy %3, %4
///    }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to `%view`.
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
///    }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    BlockAndValueMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc, b.create<vector::TypeCastOp>(
                 loc, MemRefType::get({}, xferOp.vector().getType()), alloc));
    mapping.map(xferOp.vector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

// TODO: Parallelism and thread-local considerations with a ParallelScope
// trait.
static Operation *getAutomaticAllocationScope(Operation *op) {
  Operation *scope =
      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}
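
// For illustration only (a sketch, assuming the transfer sits in a loop nest
// inside a function, which carries the AutomaticAllocationScope trait): the
// transient buffer created by splitFullAndPartialTransfer below is hoisted to
// the top of that scope, resembling
//   func @f(...) {
//     %alloc = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
//     scf.for ... {
//       ... split transfer reads/writes go through %alloc ...
//     }
//   }
// so no allocation happens inside the loop nest itself.
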
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or memref.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-the-function `alloca`'d buffer of one vector.
///
/// For vector.transfer_write:
/// There are two conditional blocks: first a block that decides which memref
/// and indices to use for an unmasked, in-bounds write, then a conditional
/// block that copies the partial buffer into the final result in the slow-path
/// case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...] : vector<...>, memref<A...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
///    scf.if (%notInBounds) {
///      // slowpath: not in-bounds vector.transfer or memref.copy.
///    }
/// ```
/// where `alloc` is a top-of-the-function `alloca`'d buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
///  must be equal. This will be relaxed in the future but requires
///  rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope
  // so they are not accidentally used in the wrong scope further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.mask())
      return failure();
    if (xferReadOp && xferReadOp.mask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top-of-the-function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
                                  alloc.getType().cast<MemRefType>());
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds; it always reads from a full
    // buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  BlockAndValueMapping mapping;
  mapping.map(xferWriteOp.source(), memrefAndIndices.front());
  mapping.map(xferWriteOp.indices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  xferOp->erase();

  return success();
}

LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  rewriter.startRootUpdate(xferOp);
  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
    rewriter.finalizeRootUpdate(xferOp);
    return success();
  }
  rewriter.cancelRootUpdate(xferOp);
  return failure();
}
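
// Usage sketch (illustrative only, not part of this file's API surface): the
// split is typically driven either by adding VectorTransferFullPartialRewriter
// to a pattern set and running a rewrite driver, or by walking candidate
// transfer ops and calling splitFullAndPartialTransfer directly. The driver
// below is hypothetical; `funcOp` and the chosen option value are assumptions.
//
//   VectorTransformsOptions options;
//   options.vectorTransferSplit = VectorTransferSplit::VectorTransfer;
//   IRRewriter rewriter(funcOp.getContext());
//   // Collect first, then rewrite, since the write case erases the op.
//   SmallVector<Operation *> candidates;
//   funcOp.walk([&](VectorTransferOpInterface xferOp) {
//     if (succeeded(splitFullAndPartialTransferPrecondition(xferOp)))
//       candidates.push_back(xferOp);
//   });
//   for (Operation *op : candidates)
//     (void)splitFullAndPartialTransfer(
//         rewriter, cast<VectorTransferOpInterface>(op), options);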