1 //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering of vector transfer operations to SCF. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include <type_traits> 14 15 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 16 17 #include "../PassDetail.h" 18 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h" 19 #include "mlir/Dialect/MemRef/EDSC/Intrinsics.h" 20 #include "mlir/Dialect/SCF/EDSC/Intrinsics.h" 21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 22 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" 23 #include "mlir/Dialect/Vector/VectorOps.h" 24 #include "mlir/Dialect/Vector/VectorUtils.h" 25 #include "mlir/IR/Builders.h" 26 #include "mlir/Pass/Pass.h" 27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 28 #include "mlir/Transforms/Passes.h" 29 30 using namespace mlir; 31 using namespace mlir::edsc; 32 using namespace mlir::edsc::intrinsics; 33 using vector::TransferReadOp; 34 using vector::TransferWriteOp; 35 36 namespace { 37 38 /// Attribute name used for labeling transfer ops during progressive lowering. 39 static const char kPassLabel[] = "__vector_to_scf_lowering__"; 40 41 /// Lower to 1D transfer ops. Target-specific lowering will lower those. 42 static const int64_t kTargetRank = 1; 43 44 /// Given a MemRefType with VectorType element type, unpack one dimension from 45 /// the VectorType into the MemRefType. 
///
/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
static MemRefType unpackOneDim(MemRefType type) {
  // NOTE(review): the dyn_cast result is used unchecked; callers must pass a
  // memref whose element type is a VectorType.
  auto vectorType = type.getElementType().dyn_cast<VectorType>();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  // The major vector dimension becomes the new minor memref dimension.
  newMemrefShape.push_back(vectorType.getDimSize(0));
  return MemRefType::get(newMemrefShape,
                         VectorType::get(vectorType.getShape().drop_front(),
                                         vectorType.getElementType()));
}

/// Helper data structure for data and mask buffers.
struct BufferAllocs {
  // Alloca holding the transfer op's vector value.
  Value dataBuffer;
  // Loaded mask value (set only if the transfer op has a mask).
  Value maskBuffer;
};

/// Allocate temporary buffers for data (vector) and mask (if present).
/// TODO: Parallelism and threadlocal considerations.
template <typename OpTy>
static BufferAllocs allocBuffers(OpTy xferOp) {
  auto &b = ScopedContext::getBuilderRef();
  OpBuilder::InsertionGuard guard(b);
  // Place allocas at the top of the closest enclosing allocation scope so the
  // buffers live for the duration of that scope.
  Operation *scope =
      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = memref_alloca(bufferType).value;

  if (xferOp.mask()) {
    // Store the mask into a 0-d buffer and load it back: patterns later
    // recover the buffer from the load (see getMaskBuffer).
    auto maskType = MemRefType::get({}, xferOp.mask().getType());
    Value maskBuffer = memref_alloca(maskType);
    memref_store(xferOp.mask(), maskBuffer);
    result.maskBuffer = memref_load(maskBuffer);
  }

  return result;
}

/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
/// A return value of None indicates a broadcast.
template <typename OpTy>
static Optional<int64_t> unpackedDim(OpTy xferOp) {
  auto map = xferOp.permutation_map();
  // The first permutation map result corresponds to the major vector
  // dimension, which is the one being unpacked.
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Compute the permutation map for the new (N-1)-D vector transfer op. This
/// map is identical to the current permutation map, but the first result is
/// omitted.
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpTy xferOp, OpBuilder &builder) {
  auto map = xferOp.permutation_map();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                        builder.getContext());
}

/// Calculate the indices for the new vector transfer op.
///
/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3f32>
///                              ^^^^^^
///       `iv` is the iteration variable of the (new) surrounding loop.
template <typename OpTy>
static void getXferIndices(OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  // Corresponding memref dim of the vector dim that is unpacked.
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.indices();
  indices.append(prevIndices.begin(), prevIndices.end());

  bool isBroadcast = !dim.hasValue();
  if (!isBroadcast) {
    // Offset the index of the unpacked dimension by the loop variable;
    // broadcast dims have no corresponding memref dim to offset.
    using edsc::op::operator+;
    indices[dim.getValue()] = adaptor.indices()[dim.getValue()] + iv;
  }
}

/// Generate an scf.yield with `value` if one was requested, otherwise yield
/// nothing.
static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
                            Value value) {
  if (hasRetVal) {
    builder.create<scf::YieldOp>(loc, value);
  } else {
    builder.create<scf::YieldOp>(loc);
  }
}

/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true.
/// No such check is generated under following circumstances:
/// * xferOp does not have a mask.
/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
///   computed and attached to the new transfer op in the pattern.)
/// * The to-be-unpacked dim of xferOp is a broadcast.
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &builder, OpTy xferOp, Value iv) {
  if (!xferOp.mask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  // vector.extractelement takes an i32 position, so cast the index first.
  auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
  return vector_extract_element(xferOp.mask(), ivI32).value;
}

/// Helper function for TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
/// ```
/// %vec = vector.transfer_read %A[%a, %b] %cst
///     : vector<5x4xf32>, memref<?x?xf32>
/// ```
/// An if check similar to this will be generated inside the loop:
/// ```
/// %d = memref.dim %A, %c0 : memref<?x?xf32>
/// if (%a + iv < %d) {
///   (in-bounds case)
/// } else {
///   (out-of-bounds case)
/// }
/// ```
///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked out elements.
///
/// This function variant returns the value returned by `inBoundsCase` or
/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpTy xferOp, Value iv, OpBuilder &builder, Optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  // Overall condition; stays null if no check is needed at all.
  Value cond;

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    // cond: index on the unpacked dim (base index + iv) < memref dim size.
    auto memrefDim =
        memref_dim(xferOp.source(), std_constant_index(dim.getValue()));
    using edsc::op::operator+;
    auto memrefIdx = xferOp.indices()[dim.getValue()] + iv;
    cond = std_cmpi_sgt(memrefDim.value, memrefIdx);
  }

  // Condition check 2: Masked in?
  if (auto maskCond = generateMaskCheck(builder, xferOp, iv)) {
    // Conjoin with the in-bounds condition if one was generated.
    if (cond) {
      cond = builder.create<AndOp>(xferOp.getLoc(), cond, maskCond);
    } else {
      cond = maskCond;
    }
  }

  // If the condition is non-empty, generate an SCF::IfOp.
  if (cond) {
    auto check = builder.create<scf::IfOp>(
        xferOp.getLoc(), resultTypes, cond,
        /*thenBuilder=*/
        [&](OpBuilder &builder, Location loc) {
          maybeYieldValue(hasRetVal, builder, loc, inBoundsCase(builder, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &builder, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(hasRetVal, builder, loc,
                            outOfBoundsCase(builder, loc));
          } else {
            // No out-of-bounds handler: yield nothing from the else branch.
            builder.create<scf::YieldOp>(loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an SCF::IfOp.
  return inBoundsCase(builder, xferOp.getLoc());
}

/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
/// a return value. Consequently, this function does not have a return value.
243 template <typename OpTy> 244 static void generateInBoundsCheck( 245 OpTy xferOp, Value iv, OpBuilder &builder, Optional<int64_t> dim, 246 function_ref<void(OpBuilder &, Location)> inBoundsCase, 247 function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) { 248 generateInBoundsCheck( 249 xferOp, iv, builder, dim, /*resultTypes=*/TypeRange(), 250 /*inBoundsCase=*/ 251 [&](OpBuilder &builder, Location loc) { 252 inBoundsCase(builder, loc); 253 return Value(); 254 }, 255 /*outOfBoundsCase=*/ 256 [&](OpBuilder &builder, Location loc) { 257 if (outOfBoundsCase) 258 outOfBoundsCase(builder, loc); 259 return Value(); 260 }); 261 } 262 263 /// Given an ArrayAttr, return a copy where the first element is dropped. 264 static ArrayAttr dropFirstElem(OpBuilder &builder, ArrayAttr attr) { 265 if (!attr) 266 return attr; 267 return ArrayAttr::get(builder.getContext(), attr.getValue().drop_front()); 268 } 269 270 /// Add the pass label to a vector transfer op if its rank is not the target 271 /// rank. 272 template <typename OpTy> 273 static void maybeApplyPassLabel(OpBuilder &builder, OpTy newXferOp) { 274 if (newXferOp.getVectorType().getRank() > kTargetRank) 275 newXferOp->setAttr(kPassLabel, builder.getUnitAttr()); 276 } 277 278 /// Given a transfer op, find the memref from which the mask is loaded. This 279 /// is similar to Strategy<TransferWriteOp>::getBuffer. 280 template <typename OpTy> 281 static Value getMaskBuffer(OpTy xferOp) { 282 assert(xferOp.mask() && "Expected that transfer op has mask"); 283 auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>(); 284 assert(loadOp && "Expected transfer op mask produced by LoadOp"); 285 return loadOp.getMemRef(); 286 } 287 288 /// Codegen strategy, depending on the operation. 289 template <typename OpTy> 290 struct Strategy; 291 292 /// Code strategy for vector TransferReadOp. 
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation. All labeled TransferReadOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
  /// memref.store %vec, %buf[...] ...
  /// ```
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
  ///    variable `iv`.
  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
  ///
  /// E.g.:
  /// ```
  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
  ///     : memref<?x?x?xf32>, vector<4x3xf32>
  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
  ///       : memref<?x?x?xf32>, vector<3xf32>
  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// }
  /// ```
  ///
  /// Note: The loop and type cast are generated in TransferOpConversion.
  ///       The original TransferReadOp and store op are deleted in `cleanup`.
  /// Note: The `mask` operand is set in TransferOpConversion.
  static TransferReadOp rewriteOp(OpBuilder &builder, TransferReadOp xferOp,
                                  Value buffer, Value iv) {
    // Buffer indices: the enclosing store's indices plus the loop variable.
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    // Source memref indices, offset by `iv` on the unpacked dimension.
    SmallVector<Value, 8> xferIndices;
    getXferIndices(xferOp, iv, xferIndices);

    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr());
    auto newXfer =
        vector_transfer_read(
            vecType, xferOp.source(), xferIndices,
            AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)),
            xferOp.padding(), Value(), inBoundsAttr)
            .value;

    // Label the new op if it still has a rank above the target rank.
    maybeApplyPassLabel(builder,
                        dyn_cast<TransferReadOp>(newXfer.getDefiningOp()));

    memref_store(newXfer, buffer, storeIndices);
    return newXfer.getDefiningOp<TransferReadOp>();
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// padding value to the temporary buffer.
  static void handleOutOfBoundsDim(OpBuilder & /*builder*/,
                                   TransferReadOp xferOp, Value buffer,
                                   Value iv) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    // Splat the padding value and store it at the current buffer position.
    auto vec = std_splat(vecType, xferOp.padding());
    memref_store(vec, buffer, storeIndices);
  }

  /// Cleanup after rewriting the op: erase the labeled op and its store.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }
};

/// Codegen strategy for vector TransferWriteOp.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  /// ```
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
  static TransferWriteOp rewriteOp(OpBuilder &builder, TransferWriteOp xferOp,
                                   Value buffer, Value iv) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(xferOp, iv, xferIndices);

    auto vec = memref_load(buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr());
    auto newXfer = vector_transfer_write(
        Type(), vec, xferOp.source(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), Value(),
        inBoundsAttr);

    // Label the new op if it still has a rank above the target rank.
    maybeApplyPassLabel(builder, newXfer.op);

    return newXfer;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  /// Intentionally a no-op: nothing is written for out-of-bounds positions.
  static void handleOutOfBoundsDim(OpBuilder &builder, TransferWriteOp xferOp,
                                   Value buffer, Value iv) {}

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) {
    rewriter.eraseOp(xferOp);
  }
};

/// Return success if `xferOp` is a candidate for the "prepare" patterns:
/// not yet labeled and with a vector rank above the target rank.
template <typename OpTy>
LogicalResult checkPrepareXferOp(OpTy xferOp) {
  if (xferOp->hasAttr(kPassLabel))
    return failure();
  if (xferOp.getVectorType().getRank() <= kTargetRank)
    return failure();
  return success();
}

/// Prepare a TransferReadOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
/// 3. Store the result of the TransferReadOp into the temporary buffer.
/// 4.
Load the result from the temporary buffer and replace all uses of the 477 /// original TransferReadOp with this load. 478 /// 479 /// E.g.: 480 /// ``` 481 /// %vec = vector.transfer_read %A[%a, %b, %c], %cst 482 /// : vector<5x4xf32>, memref<?x?x?xf32> 483 /// ``` 484 /// is rewritten to: 485 /// ``` 486 /// %0 = memref.alloca() : memref<vector<5x4xf32>> 487 /// %1 = vector.transfer_read %A[%a, %b, %c], %cst 488 /// { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32> 489 /// memref.store %1, %0[] : memref<vector<5x4xf32>> 490 /// %vec = memref.load %0[] : memref<vector<5x4xf32>> 491 /// ``` 492 /// 493 /// Note: A second temporary buffer may be allocated for the `mask` operand. 494 struct PrepareTransferReadConversion : public OpRewritePattern<TransferReadOp> { 495 using OpRewritePattern<TransferReadOp>::OpRewritePattern; 496 497 LogicalResult matchAndRewrite(TransferReadOp xferOp, 498 PatternRewriter &rewriter) const override { 499 if (checkPrepareXferOp(xferOp).failed()) 500 return failure(); 501 502 ScopedContext scope(rewriter, xferOp.getLoc()); 503 auto buffers = allocBuffers(xferOp); 504 auto *newXfer = rewriter.clone(*xferOp.getOperation()); 505 newXfer->setAttr(kPassLabel, rewriter.getUnitAttr()); 506 if (xferOp.mask()) { 507 dyn_cast<TransferReadOp>(newXfer).maskMutable().assign( 508 buffers.maskBuffer); 509 } 510 511 memref_store(newXfer->getResult(0), buffers.dataBuffer); 512 rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer); 513 514 return success(); 515 } 516 }; 517 518 /// Prepare a TransferWriteOp for progressive lowering. 519 /// 520 /// 1. Allocate a temporary buffer. 521 /// 2. Store the vector into the buffer. 522 /// 3. Load the vector from the buffer again. 523 /// 4. Use the loaded vector as a TransferWriteOp operand and label the op, 524 /// marking it eligible for progressive lowering via TransferOpConversion. 
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferWriteConversion
    : public OpRewritePattern<TransferWriteOp> {
  using OpRewritePattern<TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp).failed())
      return failure();

    ScopedContext scope(rewriter, xferOp.getLoc());
    auto buffers = allocBuffers(xferOp);
    // Round-trip the vector through the buffer so that TransferOpConversion
    // can later recover the buffer from the load feeding the transfer op.
    memref_store(xferOp.vector(), buffers.dataBuffer);
    auto loadedVec = memref_load(buffers.dataBuffer);
    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.vectorMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.mask()) {
      rewriter.updateRootInPlace(
          xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
    }

    return success();
  }
};

/// Progressive lowering of vector transfer ops: Unpack one dimension.
///
/// 1. Unpack one dimension from the current buffer type and cast the buffer
///    to that new type. E.g.:
///    ```
///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
///    vector.transfer_write %vec ...
///    ```
///    The following cast is generated:
///    ```
///    %casted = vector.type_cast %0
///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
///    ```
/// 2.
/// Generate a for loop and rewrite the transfer op according to the
///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
///    out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
template <typename OpTy>
struct TransferOpConversion : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    // Only ops labeled by the "prepare" patterns are handled here.
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    ScopedContext scope(rewriter, xferOp.getLoc());

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
    auto castedDataType = unpackOneDim(dataBufferType);
    auto castedDataBuffer = vector_type_cast(castedDataType, dataBuffer);

    // If the xferOp has a mask: Find and cast mask buffer.
    Value castedMaskBuffer;
    if (xferOp.mask()) {
      auto maskBuffer = getMaskBuffer(xferOp);
      auto maskBufferType =
          maskBuffer.getType().template dyn_cast<MemRefType>();
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask, if:
        // * To-be-unpacked transfer op dimension is a broadcast.
        // * Mask is 1D, i.e., the mask cannot be further unpacked.
        //   (That means that all remaining dimensions of the transfer op must
        //   be broadcasted.)
        castedMaskBuffer = maskBuffer;
      } else {
        auto castedMaskType = unpackOneDim(maskBufferType);
        castedMaskBuffer = vector_type_cast(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step: iterate over the unpacked (last) dimension of the
    // casted buffer type.
    auto lb = std_constant_index(0).value;
    auto ub = std_constant_index(
                  castedDataType.getDimSize(castedDataType.getRank() - 1))
                  .value;
    auto step = std_constant_index(1).value;

    // Generate for loop.
    rewriter.create<scf::ForOp>(
        xferOp.getLoc(), lb, ub, step, ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange /*loopState*/) {
          ScopedContext scope(b, loc);
          generateInBoundsCheck(
              xferOp, iv, b, unpackedDim(xferOp),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                // Create new transfer op.
                OpTy newXfer =
                    Strategy<OpTy>::rewriteOp(b, xferOp, castedDataBuffer, iv);

                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op (generateInBoundsCheck evaluated it).
                if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
                                      xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
                  // In case of broadcast: Use same indices to load from memref
                  // as before.
                  if (!xferOp.isBroadcastDim(0))
                    loadIndices.push_back(iv);

                  auto mask = memref_load(castedMaskBuffer, loadIndices);
                  rewriter.updateRootInPlace(
                      newXfer, [&]() { newXfer.maskMutable().assign(mask); });
                }
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp,
                                                     castedDataBuffer, iv);
              });
          b.create<scf::YieldOp>(loc);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp);
    return success();
  }
};

/// If the original transfer op has a mask, compute the mask of the new transfer
/// op (for the current iteration `i`) and assign it.
template <typename OpTy>
static void maybeAssignMask(OpBuilder &builder, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.mask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // To-be-unpacked dimension is a broadcast, which does not have a
    // corresponding mask dimension. Mask attribute remains unchanged.
    newXferOp.maskMutable().assign(xferOp.mask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Unpack one dimension of the mask.
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPoint(newXferOp); // Insert load before newXfer.

    llvm::SmallVector<int64_t, 1> indices({i});
    auto newMask = vector_extract(xferOp.mask(), indices).value;
    newXferOp.maskMutable().assign(newMask);
  }

  // Otherwise: The mask of the old transfer op is 1D and the unpacked
  // dim is not a broadcast, so no mask is needed on the new transfer op.
  // `generateInBoundsCheck` will have evaluated the mask already.
}

/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v_init = splat %padding : vector<5x4xf32>
/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
/// ...
/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
/// ```
///
/// Note: As an optimization, if the result of the original TransferReadOp
/// was directly inserted into another vector, no new %v_init vector is created.
/// Instead, the new TransferReadOp results are inserted into that vector.
struct UnrollTransferReadConversion : public OpRewritePattern<TransferReadOp> {
  using OpRewritePattern<TransferReadOp>::OpRewritePattern;

  /// Return the vector into which the newly created TransferReadOp results
  /// are inserted.
  Value getResultVector(TransferReadOp xferOp,
                        PatternRewriter &rewriter) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.dest();
    // No existing InsertOp user: start from a vector filled with padding.
    return std_splat(xferOp.getVectorType(), xferOp.padding()).value;
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }

    return vector::InsertOp();
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation's indices.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVector<int64_t, 8> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      // NOTE(review): position attributes are assumed to be IntegerAttrs;
      // the dyn_cast result is used unchecked.
      llvm::for_each(insertOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= kTargetRank)
      return failure();

    ScopedContext scope(rewriter, xferOp.getLoc());
    auto insertOp = getInsertOp(xferOp);
    auto vec = getResultVector(xferOp, rewriter);
    auto vecType = vec.getType().dyn_cast<VectorType>();
    auto xferVecType = xferOp.getVectorType();
    // The new transfer ops read (N-1)-d vectors: drop the major dimension.
    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
                                          xferVecType.getElementType());
    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled loop of transfer ops.
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = std_constant_index(i);

      vec = generateInBoundsCheck(
          xferOp, iv, rewriter, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            ScopedContext scope(b, loc);

            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(xferOp, iv, xferIndices);

            // Indices for the new vector.insert op.
            SmallVector<int64_t, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(i);

            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOpVal =
                vector_transfer_read(
                    newXferVecType, xferOp.source(), xferIndices,
                    AffineMapAttr::get(unpackedPermutationMap(xferOp, b)),
                    xferOp.padding(), Value(), inBoundsAttr)
                    .value;
            auto newXferOp =
                dyn_cast<TransferReadOp>(newXferOpVal.getDefiningOp());

            maybeAssignMask(b, xferOp, newXferOp, i);

            return vector_insert(newXferOp, vec, insertionIndices).value;
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Loop through original (unmodified) vector.
            return vec;
          });
    }

    if (insertOp) {
      // Rewrite single user of the old TransferReadOp, which was an InsertOp.
      rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
    } else {
      rewriter.replaceOp(xferOp, vec);
    }

    return success();
  }
};

/// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v0 = vector.extract %vec[0] : vector<5x4xf32>
/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
/// %v1 = vector.extract %vec[1] : vector<5x4xf32>
/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
/// ...
/// %v4 = vector.extract %vec[4] : vector<5x4xf32>
/// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
/// ```
///
/// Note: As an optimization, if the vector of the original TransferWriteOp
/// was directly extracted from another vector via an ExtractOp `a`, extract
/// the vectors for the newly generated TransferWriteOps from `a`'s input. By
/// doing so, `a` may become dead, and the number of ExtractOps generated during
/// recursive application of this pattern will be minimal.
struct UnrollTransferWriteConversion
    : public OpRewritePattern<TransferWriteOp> {
  using OpRewritePattern<TransferWriteOp>::OpRewritePattern;

  /// Return the vector from which newly generated ExtractOps will extract.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.vector();
    return xferOp.vector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
875 vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const { 876 if (auto *op = xferOp.vector().getDefiningOp()) 877 return dyn_cast<vector::ExtractOp>(op); 878 return vector::ExtractOp(); 879 } 880 881 /// If the input of the given TransferWriteOp is an ExtractOp, return its 882 /// indices. 883 void getExtractionIndices(TransferWriteOp xferOp, 884 SmallVector<int64_t, 8> &indices) const { 885 if (auto extractOp = getExtractOp(xferOp)) { 886 llvm::for_each(extractOp.position(), [&](Attribute attr) { 887 indices.push_back(attr.dyn_cast<IntegerAttr>().getInt()); 888 }); 889 } 890 } 891 892 /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds 893 /// accesses, and broadcasts and transposes in permutation maps. 894 LogicalResult matchAndRewrite(TransferWriteOp xferOp, 895 PatternRewriter &rewriter) const override { 896 if (xferOp.getVectorType().getRank() <= kTargetRank) 897 return failure(); 898 899 ScopedContext scope(rewriter, xferOp.getLoc()); 900 auto vec = getDataVector(xferOp); 901 auto xferVecType = xferOp.getVectorType(); 902 int64_t dimSize = xferVecType.getShape()[0]; 903 904 // Generate fully unrolled loop of transfer ops. 905 for (int64_t i = 0; i < dimSize; ++i) { 906 Value iv = std_constant_index(i); 907 908 generateInBoundsCheck( 909 xferOp, iv, rewriter, unpackedDim(xferOp), 910 /*inBoundsCase=*/[&](OpBuilder &b, Location loc) { 911 ScopedContext scope(b, loc); 912 913 // Indices for the new transfer op. 914 SmallVector<Value, 8> xferIndices; 915 getXferIndices(xferOp, iv, xferIndices); 916 917 // Indices for the new vector.extract op. 
918 SmallVector<int64_t, 8> extractionIndices; 919 getExtractionIndices(xferOp, extractionIndices); 920 extractionIndices.push_back(i); 921 922 auto extracted = vector_extract(vec, extractionIndices).value; 923 auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr()); 924 925 auto newXferOp = 926 vector_transfer_write( 927 Type(), extracted, xferOp.source(), xferIndices, 928 AffineMapAttr::get(unpackedPermutationMap(xferOp, b)), 929 Value(), inBoundsAttr) 930 .op; 931 932 maybeAssignMask(b, xferOp, newXferOp, i); 933 }); 934 } 935 936 rewriter.eraseOp(xferOp); 937 return success(); 938 } 939 }; 940 941 /// Compute the indices into the memref for the LoadOp/StoreOp generated as 942 /// part of TransferOp1dConversion. Return the memref dimension on which 943 /// the transfer is operating. A return value of None indicates a broadcast. 944 template <typename OpTy> 945 static Optional<int64_t> 946 get1dMemrefIndices(OpTy xferOp, Value iv, 947 SmallVector<Value, 8> &memrefIndices) { 948 auto indices = xferOp.indices(); 949 auto map = xferOp.permutation_map(); 950 951 memrefIndices.append(indices.begin(), indices.end()); 952 assert(map.getNumResults() == 1 && 953 "Expected 1 permutation map result for 1D transfer"); 954 if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) { 955 auto dim = expr.getPosition(); 956 using edsc::op::operator+; 957 memrefIndices[dim] = memrefIndices[dim] + iv; 958 return dim; 959 } 960 961 assert(xferOp.isBroadcastDim(0) && 962 "Expected AffineDimExpr or AffineConstantExpr"); 963 return None; 964 } 965 966 /// Codegen strategy for TransferOp1dConversion, depending on the 967 /// operation. 968 template <typename OpTy> 969 struct Strategy1d; 970 971 /// Codegen strategy for TransferReadOp. 
972 template <> 973 struct Strategy1d<TransferReadOp> { 974 static void generateForLoopBody(OpBuilder &builder, Location loc, 975 TransferReadOp xferOp, Value iv, 976 ValueRange loopState) { 977 SmallVector<Value, 8> indices; 978 auto dim = get1dMemrefIndices(xferOp, iv, indices); 979 auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv); 980 auto vec = loopState[0]; 981 982 // In case of out-of-bounds access, leave `vec` as is (was initialized with 983 // padding value). 984 auto nextVec = generateInBoundsCheck( 985 xferOp, iv, builder, dim, TypeRange(xferOp.getVectorType()), 986 /*inBoundsCase=*/ 987 [&](OpBuilder & /*b*/, Location loc) { 988 auto val = memref_load(xferOp.source(), indices); 989 return vector_insert_element(val, vec, ivI32.value).value; 990 }, 991 /*outOfBoundsCase=*/ 992 [&](OpBuilder & /*b*/, Location loc) { return vec; }); 993 builder.create<scf::YieldOp>(loc, nextVec); 994 } 995 996 static Value initialLoopState(TransferReadOp xferOp) { 997 // Inititalize vector with padding value. 998 return std_splat(xferOp.getVectorType(), xferOp.padding()).value; 999 } 1000 }; 1001 1002 /// Codegen strategy for TransferWriteOp. 1003 template <> 1004 struct Strategy1d<TransferWriteOp> { 1005 static void generateForLoopBody(OpBuilder &builder, Location loc, 1006 TransferWriteOp xferOp, Value iv, 1007 ValueRange /*loopState*/) { 1008 SmallVector<Value, 8> indices; 1009 auto dim = get1dMemrefIndices(xferOp, iv, indices); 1010 auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv); 1011 1012 // Nothing to do in case of out-of-bounds access. 
1013 generateInBoundsCheck( 1014 xferOp, iv, builder, dim, 1015 /*inBoundsCase=*/[&](OpBuilder & /*b*/, Location loc) { 1016 auto val = vector_extract_element(xferOp.vector(), ivI32.value); 1017 memref_store(val, xferOp.source(), indices); 1018 }); 1019 builder.create<scf::YieldOp>(loc); 1020 } 1021 1022 static Value initialLoopState(TransferWriteOp xferOp) { return Value(); } 1023 }; 1024 1025 /// Return true if the last dimension of the MemRefType has unit stride. 1026 static bool isLastMemrefDimUnitStride(MemRefType type) { 1027 int64_t offset; 1028 SmallVector<int64_t, 4> strides; 1029 auto successStrides = getStridesAndOffset(type, strides, offset); 1030 return succeeded(successStrides) && strides.back() == 1; 1031 } 1032 1033 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is 1034 /// necessary in cases where a 1D vector transfer op cannot be lowered into 1035 /// vector load/stores due to non-unit strides or broadcasts: 1036 /// 1037 /// * Transfer dimension is not the last memref dimension 1038 /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast) 1039 /// * Memref has a layout map with non-unit stride on the last dimension 1040 /// 1041 /// This pattern generates IR as follows: 1042 /// 1043 /// 1. Generate a for loop iterating over each vector element. 1044 /// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp, 1045 /// depending on OpTy. 1046 /// 1047 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp 1048 /// can be generated instead of TransferOp1dConversion. Add such a pattern 1049 /// to ConvertVectorToLLVM. 
1050 /// 1051 /// E.g.: 1052 /// ``` 1053 /// vector.transfer_write %vec, %A[%a, %b] 1054 /// {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} 1055 /// : vector<9xf32>, memref<?x?xf32> 1056 /// ``` 1057 /// Is rewritten to approximately the following pseudo-IR: 1058 /// ``` 1059 /// for i = 0 to 9 { 1060 /// %t = vector.extractelement %vec[i] : vector<9xf32> 1061 /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32> 1062 /// } 1063 /// ``` 1064 template <typename OpTy> 1065 struct TransferOp1dConversion : public OpRewritePattern<OpTy> { 1066 using OpRewritePattern<OpTy>::OpRewritePattern; 1067 1068 LogicalResult matchAndRewrite(OpTy xferOp, 1069 PatternRewriter &rewriter) const override { 1070 ScopedContext scope(rewriter, xferOp.getLoc()); 1071 auto map = xferOp.permutation_map(); 1072 auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>(); 1073 1074 if (!memRefType) 1075 return failure(); 1076 if (xferOp.getVectorType().getRank() != 1) 1077 return failure(); 1078 if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType)) 1079 return failure(); // Handled by ConvertVectorToLLVM 1080 1081 // Loop bounds, step, state... 1082 auto vecType = xferOp.getVectorType(); 1083 auto lb = std_constant_index(0); 1084 auto ub = std_constant_index(vecType.getDimSize(0)); 1085 auto step = std_constant_index(1); 1086 auto loopState = Strategy1d<OpTy>::initialLoopState(xferOp); 1087 1088 // Generate for loop. 1089 rewriter.replaceOpWithNewOp<scf::ForOp>( 1090 xferOp, lb, ub, step, loopState ? 
ValueRange(loopState) : ValueRange(), 1091 [&](OpBuilder &builder, Location loc, Value iv, ValueRange loopState) { 1092 ScopedContext nestedScope(builder, loc); 1093 Strategy1d<OpTy>::generateForLoopBody(builder, loc, xferOp, iv, 1094 loopState); 1095 }); 1096 1097 return success(); 1098 } 1099 }; 1100 1101 } // namespace 1102 1103 namespace mlir { 1104 1105 void populateVectorToSCFConversionPatterns( 1106 RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) { 1107 if (options.unroll) { 1108 patterns.add<UnrollTransferReadConversion, UnrollTransferWriteConversion>( 1109 patterns.getContext()); 1110 } else { 1111 patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion, 1112 TransferOpConversion<TransferReadOp>, 1113 TransferOpConversion<TransferWriteOp>>(patterns.getContext()); 1114 } 1115 1116 if (kTargetRank == 1) { 1117 patterns.add<TransferOp1dConversion<TransferReadOp>, 1118 TransferOp1dConversion<TransferWriteOp>>( 1119 patterns.getContext()); 1120 } 1121 } 1122 1123 } // namespace mlir 1124 1125 namespace { 1126 1127 struct ConvertVectorToSCFPass 1128 : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> { 1129 ConvertVectorToSCFPass() = default; 1130 ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) { 1131 this->fullUnroll = options.unroll; 1132 } 1133 1134 void runOnFunction() override { 1135 RewritePatternSet patterns(getFunction().getContext()); 1136 populateVectorToSCFConversionPatterns( 1137 patterns, VectorTransferToSCFOptions().setUnroll(fullUnroll)); 1138 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 1139 } 1140 }; 1141 1142 } // namespace 1143 1144 std::unique_ptr<Pass> 1145 mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) { 1146 return std::make_unique<ConvertVectorToSCFPass>(options); 1147 } 1148