//===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector transfer operations to SCF.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using vector::TransferReadOp;
using vector::TransferWriteOp;

namespace {

/// Attribute name used for labeling transfer ops during progressive lowering.
static const char kPassLabel[] = "__vector_to_scf_lowering__";

/// Patterns that inherit from this struct have access to
/// VectorTransferToSCFOptions.
template <typename OpTy>
struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
  explicit VectorToSCFPattern(MLIRContext *context,
                              VectorTransferToSCFOptions opt)
      : OpRewritePattern<OpTy>(context), options(opt) {}

  VectorTransferToSCFOptions options;
};

/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
/// A return value of None indicates a broadcast.
template <typename OpTy>
static Optional<int64_t> unpackedDim(OpTy xferOp) {
  auto map = xferOp.permutation_map();
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Compute the permutation map for the new (N-1)-D vector transfer op. This
/// map is identical to the current permutation map, but the first result is
/// omitted.
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
  auto map = xferOp.permutation_map();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                        b.getContext());
}

/// Calculate the indices for the new vector transfer op.
///
/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
///                                 ^^^^^^
///       `iv` is the iteration variable of the (new) surrounding loop.
template <typename OpTy>
static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  // Corresponding memref dim of the vector dim that is unpacked.
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.indices();
  indices.append(prevIndices.begin(), prevIndices.end());

  Location loc = xferOp.getLoc();
  bool isBroadcast = !dim.hasValue();
  if (!isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = adaptor.indices()[dim.getValue()];
    indices[dim.getValue()] =
        makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
  }
}

static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (hasRetVal) {
    assert(value && "Expected non-empty value");
    b.create<scf::YieldOp>(loc, value);
  } else {
    b.create<scf::YieldOp>(loc);
  }
}

/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true. No such check is generated under following circumstances:
/// * xferOp does not have a mask.
/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
///   computed and attached to the new transfer op in the pattern.)
/// * The to-be-unpacked dim of xferOp is a broadcast.
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  if (!xferOp.mask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  Location loc = xferOp.getLoc();
  Value ivI32 = b.create<arith::IndexCastOp>(
      loc, IntegerType::get(b.getContext(), 32), iv);
  return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), ivI32);
}

/// Helper function for TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
/// ```
/// %vec = vector.transfer_read %A[%a, %b], %cst
///     : memref<?x?xf32>, vector<5x4xf32>
/// ```
/// An if check similar to this will be generated inside the loop:
/// ```
/// %d = memref.dim %A, %c0 : memref<?x?xf32>
/// if (%a + iv < %d) {
///   (in-bounds case)
/// } else {
///   (out-of-bounds case)
/// }
/// ```
///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked-out elements.
///
/// This function variant returns the value returned by `inBoundsCase` or
/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built...

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
  Location loc = xferOp.getLoc();
  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.source(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.indices()[dim.getValue()];
    Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
    cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
                                    memrefIdx);
  }

  // Condition check 2: Masked in?
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    if (cond)
      cond = lb.create<arith::AndIOp>(cond, maskCond);
    else
      cond = maskCond;
  }

  // If the condition is non-empty, generate an SCF::IfOp.
  if (cond) {
    auto check = lb.create<scf::IfOp>(
        resultTypes, cond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
          } else {
            b.create<scf::YieldOp>(loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an SCF::IfOp.
  return inBoundsCase(b, loc);
}

/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
/// a return value. Consequently, this function does not have a return value.
template <typename OpTy>
static void generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    function_ref<void(OpBuilder &, Location)> inBoundsCase,
    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  generateInBoundsCheck(
      b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
      /*inBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        inBoundsCase(b, loc);
        return Value();
      },
      /*outOfBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        if (outOfBoundsCase)
          outOfBoundsCase(b, loc);
        return Value();
      });
}

/// Given an ArrayAttr, return a copy where the first element is dropped.
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
  if (!attr)
    return attr;
  return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
}

/// Add the pass label to a vector transfer op if its rank is greater than the
/// target rank.
template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
                                unsigned targetRank) {
  if (newXferOp.getVectorType().getRank() > targetRank)
    newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}

/// Return true if this transfer op operates on a source tensor.
template <typename OpTy>
static bool isTensorOp(OpTy xferOp) {
  if (xferOp.getShapedType().template isa<RankedTensorType>()) {
    if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
      // TransferWriteOps on tensors have a result.
      assert(xferOp->getNumResults() > 0);
    }
    return true;
  }
  return false;
}

namespace lowering_n_d {

/// Helper data structure for data and mask buffers.
struct BufferAllocs {
  Value dataBuffer;
  Value maskBuffer;
};

/// Allocate temporary buffers for data (vector) and mask (if present).
/// TODO: Parallelism and threadlocal considerations.
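///
/// For illustration, for a masked transfer op on vector<5x4xf32> the generated
/// buffers look roughly as follows (simplified sketch; value names are made up
/// and the exact mask shape depends on the permutation map):
/// ```
/// %dataBuf = memref.alloca() : memref<vector<5x4xf32>>
/// %maskBuf = memref.alloca() : memref<vector<5x4xi1>>
/// memref.store %mask, %maskBuf[] : memref<vector<5x4xi1>>
/// %maskVal = memref.load %maskBuf[] : memref<vector<5x4xi1>>
/// ```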
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope =
      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);

  if (xferOp.mask()) {
    auto maskType = MemRefType::get({}, xferOp.mask().getType());
    auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
    b.setInsertionPoint(xferOp);
    b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
    result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
  }

  return result;
}

/// Given a MemRefType with VectorType element type, unpack one dimension from
/// the VectorType into the MemRefType.
///
/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
static MemRefType unpackOneDim(MemRefType type) {
  auto vectorType = type.getElementType().dyn_cast<VectorType>();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));
  return MemRefType::get(newMemrefShape,
                         VectorType::get(vectorType.getShape().drop_front(),
                                         vectorType.getElementType()));
}

/// Given a transfer op, find the memref from which the mask is loaded. This
/// is similar to Strategy<TransferWriteOp>::getBuffer.
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.mask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}

/// Codegen strategy, depending on the operation.
template <typename OpTy>
struct Strategy;

/// Codegen strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation. All labeled TransferReadOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
  /// memref.store %vec, %buf[...] ...
  /// ```
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
  ///    variable `iv`.
  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
  ///
  /// E.g.:
  /// ```
  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
  ///     : memref<?x?x?xf32>, vector<4x3xf32>
  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
  ///       : memref<?x?x?xf32>, vector<3xf32>
  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// }
  /// ```
  ///
  /// Note: The loop and type cast are generated in TransferOpConversion.
  ///       The original TransferReadOp and store op are deleted in `cleanup`.
  /// Note: The `mask` operand is set in TransferOpConversion.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto newXferOp = b.create<vector::TransferReadOp>(
        loc, vecType, xferOp.source(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
        Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// padding value to the temporary buffer.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
    b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);

    return Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};

/// Codegen strategy for vector TransferWriteOp.
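///
/// For illustration, the write-side rewrite mirrors the read-side example
/// above. A simplified sketch (the loop and type cast are again created by
/// TransferOpConversion):
/// ```
/// %casted = vector.type_cast %buf
///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
/// for %j = 0 to 4 {
///   %vec = memref.load %casted[%i, %j] : memref<5x4xvector<3xf32>>
///   vector.transfer_write %vec, %A[%a+%i, %b+%j, %c]
///       : vector<3xf32>, memref<?x?x?xf32>
/// }
/// ```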
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  /// ```
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto source = loopState.empty() ? xferOp.source() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.source() : Value();
  }
};

/// Check whether the given transfer op is eligible for progressive lowering:
/// it must not already be labeled, its rank must exceed the target rank, and
/// tensor sources are only handled when `lowerTensors` is set.
template <typename OpTy>
LogicalResult checkPrepareXferOp(OpTy xferOp,
                                 VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return failure();
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return failure();
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return failure();
  // Transfer ops that modify the element type are not supported atm.
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return failure();
  return success();
}

/// Prepare a TransferReadOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
/// 3. Store the result of the TransferReadOp into the temporary buffer.
/// 4. Load the result from the temporary buffer and replace all uses of the
///    original TransferReadOp with this load.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : memref<?x?x?xf32>, vector<5x4xf32>
/// memref.store %1, %0[] : memref<vector<5x4xf32>>
/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    auto buffers = allocBuffers(rewriter, xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.mask()) {
      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
          buffers.maskBuffer);
    }

    Location loc = xferOp.getLoc();
    rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
                                     buffers.dataBuffer);
    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);

    return success();
  }
};

/// Prepare a TransferWriteOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Store the vector into the buffer.
/// 3. Load the vector from the buffer again.
/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
///    marking it eligible for progressive lowering via TransferOpConversion.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
    auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.vectorMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.mask()) {
      rewriter.updateRootInPlace(
          xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
    }

    return success();
  }
};

/// Progressive lowering of vector transfer ops: Unpack one dimension.
///
/// 1. Unpack one dimension from the current buffer type and cast the buffer
///    to that new type. E.g.:
///    ```
///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
///    vector.transfer_write %vec ...
///    ```
///    The following cast is generated:
///    ```
///    %casted = vector.type_cast %0
///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
///    ```
/// 2. Generate a for loop and rewrite the transfer op according to the
///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
///    out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
///
/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
/// source (as opposed to a memref source), then each iteration of the
/// generated scf.for loop yields the new tensor value. E.g.:
/// ```
/// %result = scf.for i = 0 to 5 {
///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
///   %1 = vector.transfer_write %0, %source[...]
///       : vector<4x3xf32>, tensor<5x4x3xf32>
///   scf.yield %1 : tensor<5x4x3xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    this->setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
    auto castedDataType = unpackOneDim(dataBufferType);
    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);

    // If the xferOp has a mask: Find and cast mask buffer.
    Value castedMaskBuffer;
    if (xferOp.mask()) {
      auto maskBuffer = getMaskBuffer(xferOp);
      auto maskBufferType =
          maskBuffer.getType().template dyn_cast<MemRefType>();
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask, if:
        // * To-be-unpacked transfer op dimension is a broadcast.
        // * Mask is 1D, i.e., the mask cannot be further unpacked.
        //   (That means that all remaining dimensions of the transfer op must
        //   be broadcasted.)
        castedMaskBuffer = maskBuffer;
      } else {
        auto castedMaskType = unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step.
    auto lb = locB.create<arith::ConstantIndexOp>(0);
    auto ub = locB.create<arith::ConstantIndexOp>(
        castedDataType.getDimSize(castedDataType.getRank() - 1));
    auto step = locB.create<arith::ConstantIndexOp>(1);
    // TransferWriteOps that operate on tensors return the modified tensor and
    // require a loop state.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate for loop.
    auto result = locB.create<scf::ForOp>(
        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create new transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op.
                if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
                                      xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
                  // In case of broadcast: Use same indices to load from memref
                  // as before.
                  if (!xferOp.isBroadcastDim(0))
                    loadIndices.push_back(iv);

                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  rewriter.updateRootInPlace(
                      newXfer, [&]() { newXfer.maskMutable().assign(mask); });
                }

                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};

} // namespace lowering_n_d

namespace lowering_n_d_unrolled {

/// If the original transfer op has a mask, compute the mask of the new
/// transfer op (for the current iteration `i`) and assign it.
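///
/// For illustration (a hypothetical sketch, not tied to specific IR in this
/// file): when unpacking iteration `i` of a transfer op with a 2-D mask, the
/// new mask is obtained with something like:
/// ```
/// %newMask = vector.extract %mask[i] : vector<5x4xi1>
/// ```
/// yielding a vector<4xi1> that is assigned to the new (N-1)-D transfer op.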
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.mask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // To-be-unpacked dimension is a broadcast, which does not have a
    // corresponding mask dimension. Mask attribute remains unchanged.
    newXferOp.maskMutable().assign(xferOp.mask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Unpack one dimension of the mask.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(newXferOp); // Insert load before newXfer.

    llvm::SmallVector<int64_t, 1> indices({i});
    Location loc = xferOp.getLoc();
    auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
    newXferOp.maskMutable().assign(newMask);
  }

  // If we end up here: The mask of the old transfer op is 1D and the unpacked
  // dim is not a broadcast, so no mask is needed on the new transfer op.
  // `generateInBoundsCheck` will have evaluated the mask already.
}

/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v_init = splat %padding : vector<5x4xf32>
/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
/// ...
/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
/// ```
///
/// Note: As an optimization, if the result of the original TransferReadOp
/// was directly inserted into another vector, no new %v_init vector is
/// created. Instead, the new TransferReadOp results are inserted into that
/// vector.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector into which the newly created TransferReadOp results
  /// are inserted.
  Value getResultVector(TransferReadOp xferOp,
                        PatternRewriter &rewriter) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.dest();
    Location loc = xferOp.getLoc();
    return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
                                    xferOp.padding());
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }

    return vector::InsertOp();
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation's indices.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVector<int64_t, 8> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      llvm::for_each(insertOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto insertOp = getInsertOp(xferOp);
    auto vec = getResultVector(xferOp, rewriter);
    auto vecType = vec.getType().dyn_cast<VectorType>();
    auto xferVecType = xferOp.getVectorType();
    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
                                          xferVecType.getElementType());
    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.insert op.
            SmallVector<int64_t, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(i);

            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferReadOp>(
                loc, newXferVecType, xferOp.source(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                xferOp.padding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);
            return b.create<vector::InsertOp>(loc, newXferOp, vec,
                                              insertionIndices);
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Pass through the original (unmodified) vector.
            return vec;
          });
    }

    if (insertOp) {
      // Rewrite single user of the old TransferReadOp, which was an InsertOp.
      rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
    } else {
      rewriter.replaceOp(xferOp, vec);
    }

    return success();
  }
};

/// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v0 = vector.extract %vec[0] : vector<5x4xf32>
/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
/// %v1 = vector.extract %vec[1] : vector<5x4xf32>
/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
/// ...
/// %v4 = vector.extract %vec[4] : vector<5x4xf32>
/// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
/// ```
///
/// Note: As an optimization, if the vector of the original TransferWriteOp
/// was directly extracted from another vector via an ExtractOp `a`, extract
/// the vectors for the newly generated TransferWriteOps from `a`'s input. By
/// doing so, `a` may become dead, and the number of ExtractOps generated
/// during recursive application of this pattern will be minimal.
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector from which the newly generated ExtractOps will extract.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.vector();
    return xferOp.vector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.vector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return its
  /// indices.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVector<int64_t, 8> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      llvm::for_each(extractOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    auto xferVecType = xferOp.getVectorType();
    int64_t dimSize = xferVecType.getShape()[0];
    auto source = xferOp.source(); // memref or tensor to be written to.
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.extract op.
            SmallVector<int64_t, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(i);

            auto extracted =
                b.create<vector::ExtractOp>(loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferWriteOp>(
                loc, sourceType, extracted, source, xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                inBoundsAttr);

            maybeAssignMask(b, xferOp, newXferOp, i);

            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            return isTensorOp(xferOp) ? source : Value();
          });

      if (isTensorOp(xferOp))
        source = updatedSource;
    }

    if (isTensorOp(xferOp))
      rewriter.replaceOp(xferOp, source);
    else
      rewriter.eraseOp(xferOp);

    return success();
  }
};

} // namespace lowering_n_d_unrolled

namespace lowering_1_d {

/// Compute the indices into the memref for the LoadOp/StoreOp generated as
/// part of TransferOp1dConversion. Return the memref dimension on which
/// the transfer is operating. A return value of None indicates a broadcast.
template <typename OpTy>
static Optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVector<Value, 8> &memrefIndices) {
  auto indices = xferOp.indices();
  auto map = xferOp.permutation_map();

  memrefIndices.append(indices.begin(), indices.end());
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    Location loc = xferOp.getLoc();
    auto dim = expr.getPosition();
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = memrefIndices[dim];
    memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
    return dim;
  }

  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Codegen strategy for TransferOp1dConversion, depending on the
/// operation.
template <typename OpTy>
struct Strategy1d;

/// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    Value ivI32 = b.create<arith::IndexCastOp>(
        loc, IntegerType::get(b.getContext(), 32), iv);
    auto vec = loopState[0];

    // In case of out-of-bounds access, leave `vec` as is (was initialized with
    // padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
          return b.create<vector::InsertElementOp>(loc, val, vec, ivI32);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    b.create<scf::YieldOp>(loc, nextVec);
  }

  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize vector with padding value.
    Location loc = xferOp.getLoc();
    return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
  }
};

/// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    Value ivI32 = b.create<arith::IndexCastOp>(
        loc, IntegerType::get(b.getContext(), 32), iv);

    // Nothing to do in case of out-of-bounds access.
    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val =
              b.create<vector::ExtractElementOp>(loc, xferOp.vector(), ivI32);
          b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
        });
    b.create<scf::YieldOp>(loc);
  }

  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};

/// Return true if the last dimension of the MemRefType has unit stride.
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
/// necessary in cases where a 1D vector transfer op cannot be lowered into
/// vector load/stores due to non-unit strides or broadcasts:
///
/// * Transfer dimension is not the last memref dimension
/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
/// * Memref has a layout map with non-unit stride on the last dimension
///
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
///    depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
/// can be generated instead of TransferOp1dConversion. Add such a pattern
/// to ConvertVectorToLLVM.
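///
/// For a transfer_read, the generated loop performs scalar memref.load +
/// vector.insertelement into a vector initialized with the padding value.
/// A simplified, hypothetical sketch (pseudo-IR):
/// ```
/// %init = splat %padding : vector<9xf32>
/// %vec = for i = 0 to 9 iter_args(%v = %init) {
///   %t = memref.load %A[%a + i, %b] : memref<?x?xf32>
///   %v_next = vector.insertelement %t, %v[i] : vector<9xf32>
///   yield %v_next
/// }
/// ```
/// The transfer_write case is illustrated in the example below.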
1195 /// 1196 /// E.g.: 1197 /// ``` 1198 /// vector.transfer_write %vec, %A[%a, %b] 1199 /// {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} 1200 /// : vector<9xf32>, memref<?x?xf32> 1201 /// ``` 1202 /// Is rewritten to approximately the following pseudo-IR: 1203 /// ``` 1204 /// for i = 0 to 9 { 1205 /// %t = vector.extractelement %vec[i] : vector<9xf32> 1206 /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32> 1207 /// } 1208 /// ``` 1209 template <typename OpTy> 1210 struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> { 1211 using VectorToSCFPattern<OpTy>::VectorToSCFPattern; 1212 1213 LogicalResult matchAndRewrite(OpTy xferOp, 1214 PatternRewriter &rewriter) const override { 1215 auto map = xferOp.permutation_map(); 1216 auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>(); 1217 1218 if (!memRefType) 1219 return failure(); 1220 if (xferOp.getVectorType().getRank() != 1) 1221 return failure(); 1222 if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType)) 1223 return failure(); // Handled by ConvertVectorToLLVM 1224 1225 // Loop bounds, step, state... 1226 Location loc = xferOp.getLoc(); 1227 auto vecType = xferOp.getVectorType(); 1228 auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0); 1229 auto ub = 1230 rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0)); 1231 auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); 1232 auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp); 1233 1234 // Generate for loop. 1235 rewriter.replaceOpWithNewOp<scf::ForOp>( 1236 xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(), 1237 [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { 1238 Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState); 1239 }); 1240 1241 return success(); 1242 } 1243 }; 1244 1245 } // namespace lowering_1_d 1246 } // namespace 1247 1248 namespace mlir { 1249 1250 void populateVectorToSCFConversionPatterns( 1251 RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) { 1252 if (options.unroll) { 1253 patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion, 1254 lowering_n_d_unrolled::UnrollTransferWriteConversion>( 1255 patterns.getContext(), options); 1256 } else { 1257 patterns.add<lowering_n_d::PrepareTransferReadConversion, 1258 lowering_n_d::PrepareTransferWriteConversion, 1259 lowering_n_d::TransferOpConversion<TransferReadOp>, 1260 lowering_n_d::TransferOpConversion<TransferWriteOp>>( 1261 patterns.getContext(), options); 1262 } 1263 1264 if (options.targetRank == 1) { 1265 patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>, 1266 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>( 1267 patterns.getContext(), options); 1268 } 1269 } 1270 1271 } // namespace mlir 1272 1273 namespace { 1274 1275 struct ConvertVectorToSCFPass 1276 : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> { 1277 ConvertVectorToSCFPass() = default; 1278 ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) { 1279 this->fullUnroll = options.unroll; 1280 this->targetRank = options.targetRank; 1281 this->lowerPermutationMaps = options.lowerPermutationMaps; 1282 this->lowerTensors = options.lowerTensors; 1283 } 1284 1285 void runOnFunction() override { 1286 VectorTransferToSCFOptions options; 1287 options.unroll = fullUnroll; 1288 options.targetRank = targetRank; 1289 options.lowerPermutationMaps = lowerPermutationMaps; 1290 options.lowerTensors = lowerTensors; 1291 1292 // Lower 
    if (lowerPermutationMaps) {
      RewritePatternSet lowerTransferPatterns(getFunction().getContext());
      mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
          lowerTransferPatterns);
      (void)applyPatternsAndFoldGreedily(getFunction(),
                                         std::move(lowerTransferPatterns));
    }

    RewritePatternSet patterns(getFunction().getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};

} // namespace

std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}