//===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector transfer operations to SCF.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using vector::TransferReadOp;
using vector::TransferWriteOp;

namespace {

/// Attribute name used for labeling transfer ops during progressive lowering.
static const char kPassLabel[] = "__vector_to_scf_lowering__";

/// Patterns that inherit from this struct have access to
/// VectorTransferToSCFOptions.
template <typename OpTy>
struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
  explicit VectorToSCFPattern(MLIRContext *context,
                              VectorTransferToSCFOptions opt)
      : OpRewritePattern<OpTy>(context), options(opt) {}

  VectorTransferToSCFOptions options;
};

/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
/// A return value of None indicates a broadcast.
template <typename OpTy>
static Optional<int64_t> unpackedDim(OpTy xferOp) {
  auto map = xferOp.permutation_map();
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Compute the permutation map for the new (N-1)-D vector transfer op. This
/// map is identical to the current permutation map, but the first result is
/// omitted.
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
  auto map = xferOp.permutation_map();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                        b.getContext());
}

/// Calculate the indices for the new vector transfer op.
///
/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
///                                 ^^^^^^^
///       `iv` is the iteration variable of the (new) surrounding loop.
template <typename OpTy>
static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  // Corresponding memref dim of the vector dim that is unpacked.
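  // For illustration only: in the non-broadcast case, the updated index for
  // that dim is composed into a single affine.apply below, roughly
  // (value names are hypothetical):
  //   %new_idx = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%offset, %iv)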
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.indices();
  indices.append(prevIndices.begin(), prevIndices.end());

  Location loc = xferOp.getLoc();
  bool isBroadcast = !dim.hasValue();
  if (!isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = adaptor.indices()[dim.getValue()];
    indices[dim.getValue()] =
        makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
  }
}

static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (hasRetVal) {
    assert(value && "Expected non-empty value");
    b.create<scf::YieldOp>(loc, value);
  } else {
    b.create<scf::YieldOp>(loc);
  }
}

/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true. No such check is generated under the following
/// circumstances:
/// * xferOp does not have a mask.
/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
///   computed and attached to the new transfer op in the pattern.)
/// * The to-be-unpacked dim of xferOp is a broadcast.
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  if (!xferOp.mask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  Location loc = xferOp.getLoc();
  Value ivI32 =
      b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
  return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), ivI32);
}

/// Helper function for TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
/// ```
/// %vec = vector.transfer_read %A[%a, %b], %cst
///     : memref<?x?xf32>, vector<5x4xf32>
/// ```
/// An if check similar to this will be generated inside the loop:
/// ```
/// %d = memref.dim %A, %c0 : memref<?x?xf32>
/// if (%a + iv < %d) {
///   (in-bounds case)
/// } else {
///   (out-of-bounds case)
/// }
/// ```
///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked out elements.
///
/// This function variant returns the value returned by `inBoundsCase` or
/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built...

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
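  // For illustration only: when both the bounds check and the mask check
  // below apply (non-broadcast dim and a 1-D mask), the combined condition
  // looks roughly like this (value names and types are illustrative, not
  // produced verbatim):
  //   %d    = memref.dim %A, %c0 : memref<?x?xf32>
  //   %idx  = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%a, %iv)
  //   %inb  = cmpi sgt, %d, %idx : index
  //   %m    = vector.extractelement %mask[%iv_i32 : i32] : vector<9xi1>
  //   %cond = and %inb, %m : i1
  //   scf.if %cond { /* in-bounds case */ } else { /* out-of-bounds case */ }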
  Location loc = xferOp.getLoc();
  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    Value memrefDim = lb.create<memref::DimOp>(xferOp.source(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.indices()[dim.getValue()];
    Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
    cond = lb.create<CmpIOp>(CmpIPredicate::sgt, memrefDim, memrefIdx);
  }

  // Condition check 2: Masked in?
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    if (cond)
      cond = lb.create<AndOp>(cond, maskCond);
    else
      cond = maskCond;
  }

  // If the condition is non-empty, generate an SCF::IfOp.
  if (cond) {
    auto check = lb.create<scf::IfOp>(
        resultTypes, cond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
          } else {
            b.create<scf::YieldOp>(loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an SCF::IfOp.
  return inBoundsCase(b, loc);
}

/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
/// a return value. Consequently, this function does not have a return value.
template <typename OpTy>
static void generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    function_ref<void(OpBuilder &, Location)> inBoundsCase,
    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  generateInBoundsCheck(
      b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
      /*inBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        inBoundsCase(b, loc);
        return Value();
      },
      /*outOfBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        if (outOfBoundsCase)
          outOfBoundsCase(b, loc);
        return Value();
      });
}

/// Given an ArrayAttr, return a copy where the first element is dropped.
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
  if (!attr)
    return attr;
  return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
}

/// Add the pass label to a vector transfer op if its rank is greater than the
/// target rank.
template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
                                unsigned targetRank) {
  if (newXferOp.getVectorType().getRank() > targetRank)
    newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}

/// Return true if this transfer op operates on a source tensor.
template <typename OpTy>
static bool isTensorOp(OpTy xferOp) {
  if (xferOp.getShapedType().template isa<RankedTensorType>()) {
    if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
      // TransferWriteOps on tensors have a result.
      assert(xferOp->getNumResults() > 0);
    }
    return true;
  }
  return false;
}

namespace lowering_n_d {

/// Helper data structure for data and mask buffers.
struct BufferAllocs {
  Value dataBuffer;
  Value maskBuffer;
};

/// Allocate temporary buffers for data (vector) and mask (if present).
/// TODO: Parallelism and threadlocal considerations.
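///
/// For illustration, for a masked transfer op of type vector<5x4xf32> with
/// mask type vector<5x4xi1>, the generated allocations look roughly like this
/// (simplified; value names are illustrative):
/// ```
/// %dataBuf = memref.alloca() : memref<vector<5x4xf32>>
/// %maskBuf = memref.alloca() : memref<vector<5x4xi1>>
/// memref.store %mask, %maskBuf[] : memref<vector<5x4xi1>>
/// %maskVal = memref.load %maskBuf[] : memref<vector<5x4xi1>>
/// ```
/// The allocas are placed at the start of the enclosing automatic allocation
/// scope; the store/load of the mask happen right before the transfer op.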
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope =
      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);

  if (xferOp.mask()) {
    auto maskType = MemRefType::get({}, xferOp.mask().getType());
    auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
    b.setInsertionPoint(xferOp);
    b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
    result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
  }

  return result;
}

/// Given a MemRefType with VectorType element type, unpack one dimension from
/// the VectorType into the MemRefType.
///
/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
static MemRefType unpackOneDim(MemRefType type) {
  auto vectorType = type.getElementType().dyn_cast<VectorType>();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));
  return MemRefType::get(newMemrefShape,
                         VectorType::get(vectorType.getShape().drop_front(),
                                         vectorType.getElementType()));
}

/// Given a transfer op, find the memref from which the mask is loaded. This
/// is similar to Strategy<TransferWriteOp>::getBuffer.
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.mask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}

/// Codegen strategy, depending on the operation.
template <typename OpTy>
struct Strategy;

/// Codegen strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation. All labeled TransferReadOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
  /// memref.store %vec, %buf[...] ...
  /// ```
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
  ///    variable `iv`.
  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
  ///
  /// E.g.:
  /// ```
  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
  ///     : memref<?x?x?xf32>, vector<4x3xf32>
  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
  ///       : memref<?x?x?xf32>, vector<3xf32>
  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// }
  /// ```
  ///
  /// Note: The loop and type cast are generated in TransferOpConversion.
  ///       The original TransferReadOp and store op are deleted in `cleanup`.
  /// Note: The `mask` operand is set in TransferOpConversion.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto newXferOp = b.create<vector::TransferReadOp>(
        loc, vecType, xferOp.source(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
        Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// padding value to the temporary buffer.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
    b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);

    return Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};

/// Codegen strategy for vector TransferWriteOp.
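///
/// For illustration (mirroring the TransferReadOp example above; simplified),
/// a labeled TransferWriteOp such as:
/// ```
/// %vec = memref.load %buf[%i] : memref<5xvector<4x3xf32>>
/// vector.transfer_write %vec, %A[%a+%i, %b, %c]
///     : vector<4x3xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %casted = vector.type_cast %buf
///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
/// for %j = 0 to 4 {
///   %vec = memref.load %casted[%i, %j] : memref<5x4xvector<3xf32>>
///   vector.transfer_write %vec, %A[%a+%i, %b+%j, %c]
///       : vector<3xf32>, memref<?x?x?xf32>
/// }
/// ```
/// (The loop and the type cast are generated in TransferOpConversion, as in
/// the read case.)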
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  /// ```
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto source = loopState.empty() ? xferOp.source() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.source() : Value();
  }
};

template <typename OpTy>
LogicalResult checkPrepareXferOp(OpTy xferOp,
                                 VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return failure();
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return failure();
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return failure();
  // Transfer ops that modify the element type are not supported atm.
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return failure();
  return success();
}

/// Prepare a TransferReadOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
/// 3. Store the result of the TransferReadOp into the temporary buffer.
/// 4. Load the result from the temporary buffer and replace all uses of the
///    original TransferReadOp with this load.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : memref<?x?x?xf32>, vector<5x4xf32>
/// memref.store %1, %0[] : memref<vector<5x4xf32>>
/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    auto buffers = allocBuffers(rewriter, xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.mask()) {
      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
          buffers.maskBuffer);
    }

    Location loc = xferOp.getLoc();
    rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
                                     buffers.dataBuffer);
    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);

    return success();
  }
};

/// Prepare a TransferWriteOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Store the vector into the buffer.
/// 3. Load the vector from the buffer again.
/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
///    marking it eligible for progressive lowering via TransferOpConversion.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
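///
/// For illustration, with a mask the rewrite roughly becomes (simplified;
/// illustrative types; the mask is routed through its own buffer by
/// `allocBuffers`):
/// ```
/// %mb = memref.alloca() : memref<vector<5x4xi1>>
/// memref.store %mask, %mb[] : memref<vector<5x4xi1>>
/// %m = memref.load %mb[] : memref<vector<5x4xi1>>
/// vector.transfer_write %1, %A[%a, %b, %c], %m { __vector_to_scf_lowering__ }
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```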
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
    auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.vectorMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.mask()) {
      rewriter.updateRootInPlace(
          xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
    }

    return success();
  }
};

/// Progressive lowering of vector transfer ops: Unpack one dimension.
///
/// 1. Unpack one dimension from the current buffer type and cast the buffer
///    to that new type. E.g.:
///    ```
///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
///    vector.transfer_write %vec ...
///    ```
///    The following cast is generated:
///    ```
///    %casted = vector.type_cast %0
///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
///    ```
/// 2. Generate a for loop and rewrite the transfer op according to the
///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
///    out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
///
/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
/// source (as opposed to a memref source), then each iteration of the
/// generated scf.for loop yields the new tensor value. E.g.:
/// ```
/// %result = scf.for i = 0 to 5 {
///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
///   %1 = vector.transfer_write %0, %source[...]
///       : vector<4x3xf32>, tensor<5x4x3xf32>
///   scf.yield %1 : tensor<5x4x3xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
    auto castedDataType = unpackOneDim(dataBufferType);
    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);

    // If the xferOp has a mask: Find and cast mask buffer.
    Value castedMaskBuffer;
    if (xferOp.mask()) {
      auto maskBuffer = getMaskBuffer(xferOp);
      auto maskBufferType =
          maskBuffer.getType().template dyn_cast<MemRefType>();
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask, if:
        // * To-be-unpacked transfer op dimension is a broadcast.
        // * Mask is 1D, i.e., the mask cannot be further unpacked.
        //   (That means that all remaining dimensions of the transfer op must
        //   be broadcasted.)
        castedMaskBuffer = maskBuffer;
      } else {
        auto castedMaskType = unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step.
    auto lb = locB.create<ConstantIndexOp>(0);
    auto ub = locB.create<ConstantIndexOp>(
        castedDataType.getDimSize(castedDataType.getRank() - 1));
    auto step = locB.create<ConstantIndexOp>(1);
    // TransferWriteOps that operate on tensors return the modified tensor and
    // require a loop state.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate for loop.
    auto result = locB.create<scf::ForOp>(
        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create new transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op.
                if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
                                      xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
                  // In case of broadcast: Use same indices to load from
                  // memref as before.
                  if (!xferOp.isBroadcastDim(0))
                    loadIndices.push_back(iv);

                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  rewriter.updateRootInPlace(
                      newXfer, [&]() { newXfer.maskMutable().assign(mask); });
                }

                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};

} // namespace lowering_n_d

namespace lowering_n_d_unrolled {

/// If the original transfer op has a mask, compute the mask of the new
/// transfer op (for the current iteration `i`) and assign it.
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.mask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // To-be-unpacked dimension is a broadcast, which does not have a
    // corresponding mask dimension. Mask attribute remains unchanged.
    newXferOp.maskMutable().assign(xferOp.mask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Unpack one dimension of the mask.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(newXferOp); // Insert the extract op before newXferOp.
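    // For illustration (assuming a mask of type vector<5x4xi1>), iteration i
    // extracts the i-th (N-1)-D mask subvector, roughly:
    //   %newMask = vector.extract %mask[i] : vector<5x4xi1>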

    llvm::SmallVector<int64_t, 1> indices({i});
    Location loc = xferOp.getLoc();
    auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
    newXferOp.maskMutable().assign(newMask);
  }

  // If we end up here: The mask of the old transfer op is 1D and the unpacked
  // dim is not a broadcast, so no mask is needed on the new transfer op.
  // `generateInBoundsCheck` will have evaluated the mask already.
}

/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v_init = splat %padding : vector<5x4xf32>
/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
/// ...
/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
/// ```
///
/// Note: As an optimization, if the result of the original TransferReadOp
/// was directly inserted into another vector, no new %v_init vector is
/// created. Instead, the new TransferReadOp results are inserted into that
/// vector.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  /// Return the vector into which the newly created TransferReadOp results
  /// are inserted.
  Value getResultVector(TransferReadOp xferOp,
                        PatternRewriter &rewriter) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.dest();
    Location loc = xferOp.getLoc();
    return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
                                    xferOp.padding());
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }

    return vector::InsertOp();
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation's indices.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVector<int64_t, 8> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      llvm::for_each(insertOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
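  ///
  /// When the unpacked dimension may be out-of-bounds, each unrolled step is
  /// guarded by an scf.if that forwards the previously built vector on the
  /// out-of-bounds path, roughly (simplified; illustrative names):
  /// ```
  /// %inb = cmpi sgt, %dim, %idx : index
  /// %vi = scf.if %inb -> (vector<5x4xf32>) {
  ///   ... (N-1)-D transfer_read + vector.insert ...
  ///   scf.yield %inserted : vector<5x4xf32>
  /// } else {
  ///   scf.yield %v_prev : vector<5x4xf32>
  /// }
  /// ```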
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto insertOp = getInsertOp(xferOp);
    auto vec = getResultVector(xferOp, rewriter);
    auto vecType = vec.getType().dyn_cast<VectorType>();
    auto xferVecType = xferOp.getVectorType();
    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
                                          xferVecType.getElementType());
    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<ConstantIndexOp>(loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.insert op.
            SmallVector<int64_t, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(i);

            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferReadOp>(
                loc, newXferVecType, xferOp.source(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                xferOp.padding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);
            return b.create<vector::InsertOp>(loc, newXferOp, vec,
                                              insertionIndices);
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Loop through original (unmodified) vector.
            return vec;
          });
    }

    if (insertOp) {
      // Rewrite single user of the old TransferReadOp, which was an InsertOp.
      rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
    } else {
      rewriter.replaceOp(xferOp, vec);
    }

    return success();
  }
};

/// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v0 = vector.extract %vec[0] : vector<5x4xf32>
/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
/// %v1 = vector.extract %vec[1] : vector<5x4xf32>
/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
/// ...
/// %v4 = vector.extract %vec[4] : vector<5x4xf32>
/// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
/// ```
///
/// Note: As an optimization, if the vector of the original TransferWriteOp
/// was directly extracted from another vector via an ExtractOp `a`, extract
/// the vectors for the newly generated TransferWriteOps from `a`'s input. By
/// doing so, `a` may become dead, and the number of ExtractOps generated
/// during recursive application of this pattern will be minimal.
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  /// Return the vector from which newly generated ExtractOps will extract.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.vector();
    return xferOp.vector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.vector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return its
  /// indices.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVector<int64_t, 8> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      llvm::for_each(extractOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    auto xferVecType = xferOp.getVectorType();
    int64_t dimSize = xferVecType.getShape()[0];
    auto source = xferOp.source(); // memref or tensor to be written to.
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<ConstantIndexOp>(loc, i);

      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.extract op.
            SmallVector<int64_t, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(i);

            auto extracted =
                b.create<vector::ExtractOp>(loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferWriteOp>(
                loc, sourceType, extracted, source, xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                inBoundsAttr);

            maybeAssignMask(b, xferOp, newXferOp, i);

            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            return isTensorOp(xferOp) ? source : Value();
          });

      if (isTensorOp(xferOp))
        source = updatedSource;
    }

    if (isTensorOp(xferOp))
      rewriter.replaceOp(xferOp, source);
    else
      rewriter.eraseOp(xferOp);

    return success();
  }
};

} // namespace lowering_n_d_unrolled

namespace lowering_1_d {

/// Compute the indices into the memref for the LoadOp/StoreOp generated as
/// part of TransferOp1dConversion. Return the memref dimension on which
/// the transfer is operating. A return value of None indicates a broadcast.
template <typename OpTy>
static Optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVector<Value, 8> &memrefIndices) {
  auto indices = xferOp.indices();
  auto map = xferOp.permutation_map();

  memrefIndices.append(indices.begin(), indices.end());
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    Location loc = xferOp.getLoc();
    auto dim = expr.getPosition();
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = memrefIndices[dim];
    memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
    return dim;
  }

  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Codegen strategy for TransferOp1dConversion, depending on the
/// operation.
template <typename OpTy>
struct Strategy1d;

/// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    Value ivI32 =
        b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
    auto vec = loopState[0];

    // In case of out-of-bounds access, leave `vec` as is (it was initialized
    // with the padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
          return b.create<vector::InsertElementOp>(loc, val, vec, ivI32);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    b.create<scf::YieldOp>(loc, nextVec);
  }

  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize the vector with the padding value.
    Location loc = xferOp.getLoc();
    return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
  }
};

/// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    Value ivI32 =
        b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);

    // Nothing to do in case of out-of-bounds access.
    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val =
              b.create<vector::ExtractElementOp>(loc, xferOp.vector(), ivI32);
          b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
        });
    b.create<scf::YieldOp>(loc);
  }

  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};

/// Return true if the last dimension of the MemRefType has unit stride.
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
/// necessary in cases where a 1D vector transfer op cannot be lowered into
/// vector load/stores due to non-unit strides or broadcasts:
///
/// * Transfer dimension is not the last memref dimension
/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
/// * Memref has a layout map with non-unit stride on the last dimension
///
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
///    depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
///       can be generated instead of TransferOp1dConversion. Add such a
///       pattern to ConvertVectorToLLVM.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b]
///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
///    : vector<9xf32>, memref<?x?xf32>
/// ```
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
///   %t = vector.extractelement %vec[i] : vector<9xf32>
///   memref.store %t, %A[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    auto map = xferOp.permutation_map();
    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();

    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
      return failure(); // Handled by ConvertVectorToLLVM

    // Loop bounds, step, state...
    Location loc = xferOp.getLoc();
    auto vecType = xferOp.getVectorType();
    auto lb = rewriter.create<ConstantIndexOp>(loc, 0);
    auto ub = rewriter.create<ConstantIndexOp>(loc, vecType.getDimSize(0));
    auto step = rewriter.create<ConstantIndexOp>(loc, 1);
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate for loop.
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
        });

    return success();
  }
};

} // namespace lowering_1_d
} // namespace

namespace mlir {

void populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  if (options.unroll) {
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
        patterns.getContext(), options);
  } else {
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }

  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
}

} // namespace mlir

namespace {

struct ConvertVectorToSCFPass
    : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerPermutationMaps = options.lowerPermutationMaps;
    this->lowerTensors = options.lowerTensors;
  }

  void runOnFunction() override {
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerPermutationMaps = lowerPermutationMaps;
    options.lowerTensors = lowerTensors;

    // Lower permutation maps first.
    if (lowerPermutationMaps) {
      RewritePatternSet lowerTransferPatterns(getFunction().getContext());
      mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
          lowerTransferPatterns);
      (void)applyPatternsAndFoldGreedily(getFunction(),
                                         std::move(lowerTransferPatterns));
    }

    RewritePatternSet patterns(getFunction().getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};

} // namespace

std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}