//===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector transfer operations to SCF.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using vector::TransferReadOp;
using vector::TransferWriteOp;

namespace {

/// Attribute name used for labeling transfer ops during progressive lowering.
static const char kPassLabel[] = "__vector_to_scf_lowering__";

/// Patterns that inherit from this struct have access to
/// VectorTransferToSCFOptions.
template <typename OpTy>
struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
  explicit VectorToSCFPattern(MLIRContext *context,
                              VectorTransferToSCFOptions opt)
      : OpRewritePattern<OpTy>(context), options(opt) {}

  VectorTransferToSCFOptions options;
};

/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
/// A return value of None indicates a broadcast.
template <typename OpTy>
static Optional<int64_t> unpackedDim(OpTy xferOp) {
  auto map = xferOp.permutation_map();
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Compute the permutation map for the new (N-1)-D vector transfer op. This
/// map is identical to the current permutation map, but the first result is
/// omitted.
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
  auto map = xferOp.permutation_map();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                        b.getContext());
}

/// Calculate the indices for the new vector transfer op.
///
/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
///                             ^^^^^^^
///       `iv` is the iteration variable of the (new) surrounding loop.
template <typename OpTy>
static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  // Corresponding memref dim of the vector dim that is unpacked.
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.indices();
  indices.append(prevIndices.begin(), prevIndices.end());

  Location loc = xferOp.getLoc();
  bool isBroadcast = !dim.hasValue();
  if (!isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = adaptor.indices()[dim.getValue()];
    indices[dim.getValue()] =
        makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
  }
}

static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (hasRetVal) {
    assert(value && "Expected non-empty value");
    b.create<scf::YieldOp>(loc, value);
  } else {
    b.create<scf::YieldOp>(loc);
  }
}

/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true. No such check is generated under the following
/// circumstances:
/// * xferOp does not have a mask.
/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
///   computed and attached to the new transfer op in the pattern.)
/// * The to-be-unpacked dim of xferOp is a broadcast.
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  if (!xferOp.mask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  Location loc = xferOp.getLoc();
  Value ivI32 =
      b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
  return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), ivI32);
}

/// Helper function for TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
/// ```
/// %vec = vector.transfer_read %A[%a, %b] %cst
///     : vector<5x4xf32>, memref<?x?xf32>
/// ```
/// An if check similar to this will be generated inside the loop:
/// ```
/// %d = memref.dim %A, %c0 : memref<?x?xf32>
/// if (%a + iv < %d) {
///   (in-bounds case)
/// } else {
///   (out-of-bounds case)
/// }
/// ```
///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked-out elements.
///
/// This function variant returns the value returned by `inBoundsCase` or
/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built...

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
  Location loc = xferOp.getLoc();
  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.source(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.indices()[dim.getValue()];
    Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
    cond = lb.create<CmpIOp>(CmpIPredicate::sgt, memrefDim, memrefIdx);
  }

  // Condition check 2: Masked in?
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    if (cond)
      cond = lb.create<AndOp>(cond, maskCond);
    else
      cond = maskCond;
  }

  // If the condition is non-empty, generate an SCF::IfOp.
  if (cond) {
    auto check = lb.create<scf::IfOp>(
        resultTypes, cond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
          } else {
            b.create<scf::YieldOp>(loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an SCF::IfOp.
  return inBoundsCase(b, loc);
}

/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
/// a return value. Consequently, this function does not have a return value.
template <typename OpTy>
static void generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    function_ref<void(OpBuilder &, Location)> inBoundsCase,
    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  generateInBoundsCheck(
      b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
      /*inBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        inBoundsCase(b, loc);
        return Value();
      },
      /*outOfBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        if (outOfBoundsCase)
          outOfBoundsCase(b, loc);
        return Value();
      });
}

/// Given an ArrayAttr, return a copy where the first element is dropped.
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
  if (!attr)
    return attr;
  return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
}

/// Add the pass label to a vector transfer op if its rank is not the target
/// rank.
template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
                                unsigned targetRank) {
  if (newXferOp.getVectorType().getRank() > targetRank)
    newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}

/// Return true if this transfer op operates on a source tensor.
template <typename OpTy>
static bool isTensorOp(OpTy xferOp) {
  if (xferOp.getShapedType().template isa<RankedTensorType>()) {
    if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
      // TransferWriteOps on tensors have a result.
      assert(xferOp->getNumResults() > 0);
    }
    return true;
  }
  return false;
}

namespace lowering_n_d {

/// Helper data structure for data and mask buffers.
struct BufferAllocs {
  Value dataBuffer;
  Value maskBuffer;
};

/// Allocate temporary buffers for data (vector) and mask (if present).
/// TODO: Parallelism and threadlocal considerations.
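///
/// A rough sketch of the IR generated for a masked transfer op (shapes below
/// are illustrative only, not taken from any particular input):
/// ```
/// %dataBuf = memref.alloca() : memref<vector<5x4xf32>>
/// %maskBuf = memref.alloca() : memref<vector<5xi1>>
/// memref.store %mask, %maskBuf[] : memref<vector<5xi1>>
/// %maskVal = memref.load %maskBuf[] : memref<vector<5xi1>>
/// ```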
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope =
      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);

  if (xferOp.mask()) {
    auto maskType = MemRefType::get({}, xferOp.mask().getType());
    auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
    b.setInsertionPoint(xferOp);
    b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
    result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
  }

  return result;
}

/// Given a MemRefType with VectorType element type, unpack one dimension from
/// the VectorType into the MemRefType.
///
/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
static MemRefType unpackOneDim(MemRefType type) {
  auto vectorType = type.getElementType().dyn_cast<VectorType>();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));
  return MemRefType::get(newMemrefShape,
                         VectorType::get(vectorType.getShape().drop_front(),
                                         vectorType.getElementType()));
}

/// Given a transfer op, find the memref from which the mask is loaded. This
/// is similar to Strategy<TransferWriteOp>::getBuffer.
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.mask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}

/// Codegen strategy, depending on the operation.
template <typename OpTy>
struct Strategy;

/// Codegen strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation. All labeled TransferReadOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
  /// memref.store %vec, %buf[...] ...
  /// ```
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
  ///    variable `iv`.
  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
  ///
  /// E.g.:
  /// ```
  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
  ///     : memref<?x?x?xf32>, vector<4x3xf32>
  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
  ///       : memref<?x?x?xf32>, vector<3xf32>
  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// }
  /// ```
  ///
  /// Note: The loop and type cast are generated in TransferOpConversion.
  ///       The original TransferReadOp and store op are deleted in `cleanup`.
  /// Note: The `mask` operand is set in TransferOpConversion.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto newXferOp = b.create<vector::TransferReadOp>(
        loc, vecType, xferOp.source(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
        Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// padding value to the temporary buffer.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
    b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);

    return Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};

/// Codegen strategy for vector TransferWriteOp.
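///
/// As a rough sketch (mirroring the TransferReadOp example above; names and
/// shapes are illustrative only), a labeled write such as
/// ```
/// %vec = memref.load %buf[%i] : memref<5xvector<4x3xf32>>
/// vector.transfer_write %vec, %A[%a+%i, %b, %c]
///     : vector<4x3xf32>, memref<?x?x?xf32>
/// ```
/// becomes a loop that loads an (N-1)-D vector from the type-cast buffer and
/// writes it back to memory:
/// ```
/// %casted = vector.type_cast %buf
///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
/// for %j = 0 to 4 {
///   %vec = memref.load %casted[%i, %j] : memref<5x4xvector<3xf32>>
///   vector.transfer_write %vec, %A[%a+%i, %b+%j, %c]
///       : vector<3xf32>, memref<?x?x?xf32>
/// }
/// ```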
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  /// ```
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto source = loopState.empty() ? xferOp.source() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.source() : Value();
  }
};

template <typename OpTy>
LogicalResult checkPrepareXferOp(OpTy xferOp,
                                 VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return failure();
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return failure();
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return failure();
  // Transfer ops that modify the element type are not supported atm.
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return failure();
  return success();
}

/// Prepare a TransferReadOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
/// 3. Store the result of the TransferReadOp into the temporary buffer.
/// 4. Load the result from the temporary buffer and replace all uses of the
///    original TransferReadOp with this load.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
/// memref.store %1, %0[] : memref<vector<5x4xf32>>
/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    auto buffers = allocBuffers(rewriter, xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.mask()) {
      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
          buffers.maskBuffer);
    }

    Location loc = xferOp.getLoc();
    rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
                                     buffers.dataBuffer);
    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);

    return success();
  }
};

/// Prepare a TransferWriteOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Store the vector into the buffer.
/// 3. Load the vector from the buffer again.
/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
///    marking it eligible for progressive lowering via TransferOpConversion.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
    auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.vectorMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.mask()) {
      rewriter.updateRootInPlace(
          xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
    }

    return success();
  }
};

/// Progressive lowering of vector transfer ops: Unpack one dimension.
///
/// 1. Unpack one dimension from the current buffer type and cast the buffer
///    to that new type. E.g.:
///    ```
///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
///    vector.transfer_write %vec ...
///    ```
///    The following cast is generated:
///    ```
///    %casted = vector.type_cast %0
///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
///    ```
/// 2. Generate a for loop and rewrite the transfer op according to the
///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
///    out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
///
/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
/// source (as opposed to a memref source), then each iteration of the
/// generated scf.for loop yields the new tensor value. E.g.:
/// ```
/// %result = scf.for i = 0 to 5 {
///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
///   %1 = vector.transfer_write %0, %source[...]
///       : vector<4x3xf32>, tensor<5x4x3xf32>
///   scf.yield %1 : tensor<5x4x3xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    this->setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
    auto castedDataType = unpackOneDim(dataBufferType);
    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);

    // If the xferOp has a mask: Find and cast mask buffer.
    Value castedMaskBuffer;
    if (xferOp.mask()) {
      auto maskBuffer = getMaskBuffer(xferOp);
      auto maskBufferType =
          maskBuffer.getType().template dyn_cast<MemRefType>();
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask, if:
        // * To-be-unpacked transfer op dimension is a broadcast.
        // * Mask is 1D, i.e., the mask cannot be further unpacked.
        //   (That means that all remaining dimensions of the transfer op must
        //   be broadcasted.)
        castedMaskBuffer = maskBuffer;
      } else {
        auto castedMaskType = unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step.
    auto lb = locB.create<ConstantIndexOp>(0);
    auto ub = locB.create<ConstantIndexOp>(
        castedDataType.getDimSize(castedDataType.getRank() - 1));
    auto step = locB.create<ConstantIndexOp>(1);
    // TransferWriteOps that operate on tensors return the modified tensor and
    // require a loop state.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate for loop.
    auto result = locB.create<scf::ForOp>(
        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create new transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op.
                if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
                                      xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
                  // In case of broadcast: Use same indices to load from
                  // memref as before.
                  if (!xferOp.isBroadcastDim(0))
                    loadIndices.push_back(iv);

                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  rewriter.updateRootInPlace(
                      newXfer, [&]() { newXfer.maskMutable().assign(mask); });
                }

                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};

} // namespace lowering_n_d

namespace lowering_n_d_unrolled {

/// If the original transfer op has a mask, compute the mask of the new
/// transfer op (for the current iteration `i`) and assign it.
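///
/// E.g., for an illustrative 2-D mask %mask and iteration i = 1, the new mask
/// is (roughly) computed as:
/// ```
/// %newMask = vector.extract %mask[1] : vector<5x4xi1>
/// ```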
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.mask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // To-be-unpacked dimension is a broadcast, which does not have a
    // corresponding mask dimension. Mask attribute remains unchanged.
    newXferOp.maskMutable().assign(xferOp.mask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Unpack one dimension of the mask.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(newXferOp); // Insert load before newXfer.

    llvm::SmallVector<int64_t, 1> indices({i});
    Location loc = xferOp.getLoc();
    auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
    newXferOp.maskMutable().assign(newMask);
  }

  // If we end up here: The mask of the old transfer op is 1D and the unpacked
  // dim is not a broadcast, so no mask is needed on the new transfer op.
  // `generateInBoundsCheck` will have evaluated the mask already.
}

/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v_init = splat %padding : vector<5x4xf32>
/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
/// ...
/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
/// ```
///
/// Note: As an optimization, if the result of the original TransferReadOp
/// was directly inserted into another vector, no new %v_init vector is
/// created. Instead, the new TransferReadOp results are inserted into that
/// vector.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector into which the newly created TransferReadOp results
  /// are inserted.
  Value getResultVector(TransferReadOp xferOp,
                        PatternRewriter &rewriter) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.dest();
    Location loc = xferOp.getLoc();
    return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
                                    xferOp.padding());
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }

    return vector::InsertOp();
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation's indices.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVector<int64_t, 8> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      llvm::for_each(insertOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto insertOp = getInsertOp(xferOp);
    auto vec = getResultVector(xferOp, rewriter);
    auto vecType = vec.getType().dyn_cast<VectorType>();
    auto xferVecType = xferOp.getVectorType();
    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
                                          xferVecType.getElementType());
    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<ConstantIndexOp>(loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.insert op.
            SmallVector<int64_t, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(i);

            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferReadOp>(
                loc, newXferVecType, xferOp.source(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                xferOp.padding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);
            return b.create<vector::InsertOp>(loc, newXferOp, vec,
                                              insertionIndices);
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Loop through original (unmodified) vector.
            return vec;
          });
    }

    if (insertOp) {
      // Rewrite single user of the old TransferReadOp, which was an InsertOp.
      rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
    } else {
      rewriter.replaceOp(xferOp, vec);
    }

    return success();
  }
};

/// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v0 = vector.extract %vec[0] : vector<5x4xf32>
/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
/// %v1 = vector.extract %vec[1] : vector<5x4xf32>
/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
/// ...
/// %v4 = vector.extract %vec[4] : vector<5x4xf32>
/// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
/// ```
///
/// Note: As an optimization, if the vector of the original TransferWriteOp
/// was directly extracted from another vector via an ExtractOp `a`, extract
/// the vectors for the newly generated TransferWriteOps from `a`'s input. By
/// doing so, `a` may become dead, and the number of ExtractOps generated
/// during recursive application of this pattern will be minimal.
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector from which newly generated ExtractOps will extract.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.vector();
    return xferOp.vector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.vector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return its
  /// indices.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVector<int64_t, 8> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      llvm::for_each(extractOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    auto xferVecType = xferOp.getVectorType();
    int64_t dimSize = xferVecType.getShape()[0];
    auto source = xferOp.source(); // memref or tensor to be written to.
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<ConstantIndexOp>(loc, i);

      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.extract op.
            SmallVector<int64_t, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(i);

            auto extracted =
                b.create<vector::ExtractOp>(loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferWriteOp>(
                loc, sourceType, extracted, source, xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                inBoundsAttr);

            maybeAssignMask(b, xferOp, newXferOp, i);

            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            return isTensorOp(xferOp) ? source : Value();
          });

      if (isTensorOp(xferOp))
        source = updatedSource;
    }

    if (isTensorOp(xferOp))
      rewriter.replaceOp(xferOp, source);
    else
      rewriter.eraseOp(xferOp);

    return success();
  }
};

} // namespace lowering_n_d_unrolled

namespace lowering_1_d {

/// Compute the indices into the memref for the LoadOp/StoreOp generated as
/// part of TransferOp1dConversion. Return the memref dimension on which
/// the transfer is operating. A return value of None indicates a broadcast.
template <typename OpTy>
static Optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVector<Value, 8> &memrefIndices) {
  auto indices = xferOp.indices();
  auto map = xferOp.permutation_map();

  memrefIndices.append(indices.begin(), indices.end());
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    Location loc = xferOp.getLoc();
    auto dim = expr.getPosition();
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = memrefIndices[dim];
    memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
    return dim;
  }

  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Codegen strategy for TransferOp1dConversion, depending on the
/// operation.
template <typename OpTy>
struct Strategy1d;

/// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    Value ivI32 =
        b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
    auto vec = loopState[0];

    // In case of out-of-bounds access, leave `vec` as is (was initialized with
    // padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
          return b.create<vector::InsertElementOp>(loc, val, vec, ivI32);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    b.create<scf::YieldOp>(loc, nextVec);
  }

  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize vector with padding value.
    Location loc = xferOp.getLoc();
    return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
  }
};

/// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    Value ivI32 =
        b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);

    // Nothing to do in case of out-of-bounds access.
    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val =
              b.create<vector::ExtractElementOp>(loc, xferOp.vector(), ivI32);
          b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
        });
    b.create<scf::YieldOp>(loc);
  }

  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};

/// Return true if the last dimension of the MemRefType has unit stride.
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
/// necessary in cases where a 1D vector transfer op cannot be lowered into
/// vector load/stores due to non-unit strides or broadcasts:
///
/// * Transfer dimension is not the last memref dimension
/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
/// * Memref has a layout map with non-unit stride on the last dimension
///
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
///    depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
///       can be generated instead of TransferOp1dConversion. Add such a
///       pattern to ConvertVectorToLLVM.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b]
///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
///    : vector<9xf32>, memref<?x?xf32>
/// ```
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
///   %t = vector.extractelement %vec[i] : vector<9xf32>
///   memref.store %t, %A[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    auto map = xferOp.permutation_map();
    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();

    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
      return failure(); // Handled by ConvertVectorToLLVM
    
    // Loop bounds, step, state...
    Location loc = xferOp.getLoc();
    auto vecType = xferOp.getVectorType();
    auto lb = rewriter.create<ConstantIndexOp>(loc, 0);
    auto ub = rewriter.create<ConstantIndexOp>(loc, vecType.getDimSize(0));
    auto step = rewriter.create<ConstantIndexOp>(loc, 1);
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate for loop.
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
        });

    return success();
  }
};

} // namespace lowering_1_d
} // namespace

namespace mlir {

void populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  if (options.unroll) {
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
        patterns.getContext(), options);
  } else {
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }

  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
}

} // namespace mlir

namespace {

struct ConvertVectorToSCFPass
    : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerPermutationMaps = options.lowerPermutationMaps;
    this->lowerTensors = options.lowerTensors;
  }

  void runOnFunction() override {
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerPermutationMaps = lowerPermutationMaps;
    options.lowerTensors = lowerTensors;

    // Lower permutation maps first.
    if (lowerPermutationMaps) {
      RewritePatternSet lowerTransferPatterns(getFunction().getContext());
      mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
          lowerTransferPatterns);
      (void)applyPatternsAndFoldGreedily(getFunction(),
                                         std::move(lowerTransferPatterns));
    }

    RewritePatternSet patterns(getFunction().getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};

} // namespace

std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}
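
// Note: A minimal usage sketch (not part of this file's functionality; the
// pass manager setup below is an assumption for illustration). The pass runs
// on functions, so it is typically nested on FuncOp:
//   VectorTransferToSCFOptions options;
//   options.unroll = true;  // fully unroll the generated loops
//   pm.addNestedPass<FuncOp>(mlir::createConvertVectorToSCFPass(options));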