//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

static Optional<int64_t> extractConstantIndex(Value v) {
  if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
    return cstOp.value();
  if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
    if (affineApplyOp.getAffineMap().isSingleConstant())
      return affineApplyOp.getAffineMap().getSingleConstantResult();
  return None;
}

// Missing foldings of scf.if make it necessary to perform poor man's folding
// eagerly, especially in the case of unrolling. In the future, this should go
// away once scf.if folds properly.
static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
  auto maybeCstV = extractConstantIndex(v);
  auto maybeCstUb = extractConstantIndex(ub);
  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
    return Value();
  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
}

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.permutation_map().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
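    // For example, for a vector dimension of size 8 read at index %i from
    // source dimension %d, the emitted IR is roughly (names illustrative):
    //   %sum = affine.apply affine_map<(d0) -> (d0 + 8)>(%i)
    //   %cond = arith.cmpi sle, %sum, %dim_d
    // unless both sides are constant and the comparison folds away.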
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    auto d0 = getAffineDimExpr(0, xferOp.getContext());
    auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
    Value sum =
        makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
    Value cond = createFoldedSLE(
        b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
    if (!cond)
      return;
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-the-function `alloca`'d buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.permutation_map().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under IfOp, this avoids applying
  // the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///      strides and:
///        a. keeping the ones that are static and equal across `aT` and `bT`.
///        b. using a dynamic shape and/or stride for the dimensions that don't
///           agree.
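/// For example, `memref<4x8xf32>` and `memref<4x16xf32>` (both with the
/// identity layout) are not directly cast-compatible and combine into
/// `memref<4x?xf32>` with a dynamic stride on the outer dimension.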
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamicSize;
    resStrides[idx] = (aStrides[idx] == bStrides[idx])
                          ? aStrides[idx]
                          : ShapedType::kDynamicStrideOrOffset;
  }
  resOffset =
      (aOffset == bOffset) ? aOffset : ShapedType::kDynamicStrideOrOffset;
  return MemRefType::get(
      resShape, aT.getElementType(),
      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
}

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
                                                xferOp.source(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.indices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}});
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      linalg.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.source();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                              xferOp.indices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        b.create<linalg::FillOp>(loc, xferOp.padding(), alloc);
        // Take a partial subview of the memref, which guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  scf::IfOp fullPartialIfOp;
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.source();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                              xferOp.indices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.source();
  return b
      .create<scf::IfOp>(
          loc, returnTypes, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res = memref;
            if (compatibleMemRefType != xferOp.getShapedType())
              res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
            scf::ValueVector viewAndIndices{res};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.indices().begin(),
                                  xferOp.indices().end());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      linalg.copy(%3, %4)
///    }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
///    }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    BlockAndValueMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc, b.create<vector::TypeCastOp>(
                 loc, MemRefType::get({}, xferOp.vector().getType()), alloc));
    mapping.map(xferOp.vector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-the-function `alloca`'d buffer of one vector.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First, a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then, a conditional block
/// to further copy a partial buffer into the final result in the slow path
/// case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...] : vector<...>, memref<A...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
///                                                                    true]}
///    scf.if (%notInBounds) {
///      // slowpath: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a top-of-the-function `alloca`'d buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope to
  // ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.mask())
      return failure();
    if (xferReadOp && xferReadOp.mask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top of the function `alloc` for transient storage.
  Value alloc;
  {
    FuncOp funcOp = xferOp->getParentOfType<FuncOp>();
    RewriterBase::InsertionGuard guard(b);
    b.setInsertionPointToStart(&funcOp.getRegion().front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = b.create<memref::AllocaOp>(funcOp.getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
                                  alloc.getType().cast<MemRefType>());
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
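    // Depending on options.vectorTransferSplit, the slow path is materialized
    // either with a (still out-of-bounds) clone of the transfer_read stored
    // into the staging alloca, or with a subview intersection followed by a
    // memref.copy of the valid sub-rectangle.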
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds; it always reads from a full
    // buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  BlockAndValueMapping mapping;
  mapping.map(xferWriteOp.source(), memrefAndIndices.front());
  mapping.map(xferWriteOp.indices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  xferOp->erase();

  return success();
}

LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  rewriter.startRootUpdate(xferOp);
  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
    rewriter.finalizeRootUpdate(xferOp);
    return success();
  }
  rewriter.cancelRootUpdate(xferOp);
  return failure();
}
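
// Usage sketch (illustrative only; `ctx` and `funcOp` below are hypothetical
// placeholders for an MLIRContext and the enclosing function op): the rewrite
// is typically driven greedily, e.g.
//
//   RewritePatternSet patterns(ctx);
//   patterns.add<vector::VectorTransferFullPartialRewriter>(
//       ctx, VectorTransformsOptions().setVectorTransferSplit(
//                VectorTransferSplit::VectorTransfer));
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
//
// With VectorTransferSplit::LinalgCopy instead, the slow path is lowered to a
// subview intersection plus memref.copy rather than a full vector.transfer.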