//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

static Optional<int64_t> extractConstantIndex(Value v) {
  if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
    return cstOp.value();
  if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
    if (affineApplyOp.getAffineMap().isSingleConstant())
      return affineApplyOp.getAffineMap().getSingleConstantResult();
  return None;
}

// Missing foldings of scf.if make it necessary to perform poor man's folding
// eagerly, especially in the case of unrolling. In the future, this should go
// away once scf.if folds properly.
static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
  auto maybeCstV = extractConstantIndex(v);
  auto maybeCstUb = extractConstantIndex(ub);
  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
    return Value();
  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
}

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.permutation_map().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
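    // For instance, for a vector<8xf32> read at index %i of a memref<?xf32>,
    // the emitted check is roughly the following (an illustrative sketch only;
    // %i, %src, %sum, %dim and %cond are hypothetical names, and parts may
    // fold to constants):
    //   %sum  = affine.apply affine_map<(d0) -> (d0 + 8)>(%i)
    //   %dim  = memref.dim %src, %c0 : memref<?xf32>
    //   %cond = arith.cmpi sle, %sum, %dim : index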
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    auto d0 = getAffineDimExpr(0, xferOp.getContext());
    auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
    Value sum =
        makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
    Value cond = createFoldedSLE(
        b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
    if (!cond)
      return;
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or memref.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-the-function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.permutation_map().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under IfOp; this avoids applying
  // the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///      strides and:
///        a. keeping the ones that are static and equal across `aT` and `bT`.
///        b. using a dynamic shape and/or stride for the dimensions that don't
///           agree.
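///
/// For example (an illustrative combination, not a case taken from a test):
/// combining `memref<4x8xf32>` and `memref<4x16xf32>`, both with the identity
/// layout, would yield
/// `memref<4x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>`: static sizes
/// and strides are kept where the two types agree and replaced by dynamic
/// values where they differ.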
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamicSize;
    resStrides[idx] = (aStrides[idx] == bStrides[idx])
                          ? aStrides[idx]
                          : ShapedType::kDynamicStrideOrOffset;
  }
  resOffset =
      (aOffset == bOffset) ? aOffset : ShapedType::kDynamicStrideOrOffset;
  return MemRefType::get(
      resShape, aT.getElementType(),
      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
}

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
                                                xferOp.source(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.indices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}});
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}
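
// As an illustration of createSubViewIntersection (a sketch only; the value
// names are made up and the subview result types are elided), a 2-D
// transfer_read of %A at indices [%i, %j] produces roughly:
//   %sz0 = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%dimA0, %i, %dimAlloc0)
//   %sz1 = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%dimA1, %j, %dimAlloc1)
//   %src = memref.subview %A[%i, %j] [%sz0, %sz1] [1, 1]
//   %dst = memref.subview %alloc[0, 0] [%sz0, %sz1] [1, 1]
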
/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      memref.copy %3, %4
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.source();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                              xferOp.indices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        b.create<linalg::FillOp>(loc, xferOp.padding(), alloc);
        // Take a partial subview of the memref to guarantee that no
        // dimension overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      memref.store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.source();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                              xferOp.indices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.source();
  return b
      .create<scf::IfOp>(
          loc, returnTypes, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res = memref;
            if (compatibleMemRefType != xferOp.getShapedType())
              res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
            scf::ValueVector viewAndIndices{res};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.indices().begin(),
                                  xferOp.indices().end());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      memref.copy %3, %4
///    }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = memref.load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : vector<...>, memref<A...>
///    }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    BlockAndValueMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc, b.create<vector::TypeCastOp>(
                 loc, MemRefType::get({}, xferOp.vector().getType()), alloc));
    mapping.map(xferOp.vector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

// TODO: Parallelism and thread-local considerations with a ParallelScope trait.
static Operation *getAutomaticAllocationScope(Operation *op) {
  Operation *scope =
      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or memref.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-the-function alloca'ed buffer of one vector.
///
/// For vector.transfer_write:
/// There are two conditional blocks: first, a block that decides which memref
/// and indices to use for the unmasked, in-bounds write; then, a conditional
/// block that copies the partial buffer into the final result in the
/// slow-path case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...] : vector<...>, memref<A...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
///    scf.if (%notInBounds) {
///      // slowpath: not in-bounds vector.transfer or memref.copy.
///    }
/// ```
/// where `alloc` is a top-of-the-function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope
  // to ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.mask())
      return failure();
    if (xferReadOp && xferReadOp.mask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top-of-the-function `alloc` for transient storage.
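  // For example, for a transfer of vector<4x8xf32> this emits roughly the
  // following (an illustrative sketch; the exact printed form may differ):
  //   %alloc = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>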
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
                                  alloc.getType().cast<MemRefType>());
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds; it always reads from a full
    // buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  BlockAndValueMapping mapping;
  mapping.map(xferWriteOp.source(), memrefAndIndices.front());
  mapping.map(xferWriteOp.indices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  xferOp->erase();

  return success();
}

LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  rewriter.startRootUpdate(xferOp);
  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
    rewriter.finalizeRootUpdate(xferOp);
    return success();
  }
  rewriter.cancelRootUpdate(xferOp);
  return failure();
}
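
// Example driver (an illustrative sketch only, not part of this file; `ctx`
// and `funcOp` are assumed to be an MLIRContext* and the enclosing FuncOp, and
// the option-setter spelling is an assumption): the split is typically applied
// by registering the pattern above and running the greedy rewrite driver, e.g.:
//
//   RewritePatternSet patterns(ctx);
//   patterns.add<vector::VectorTransferFullPartialRewriter>(
//       ctx, VectorTransformsOptions().setVectorTransferSplit(
//                VectorTransferSplit::VectorTransfer));
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));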