//===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector transfer operations to SCF.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using vector::TransferReadOp;
using vector::TransferWriteOp;

namespace {

/// Attribute name used for labeling transfer ops during progressive lowering.
static const char kPassLabel[] = "__vector_to_scf_lowering__";

/// Patterns that inherit from this struct have access to
/// VectorTransferToSCFOptions.
template <typename OpTy>
struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
  explicit VectorToSCFPattern(MLIRContext *context,
                              VectorTransferToSCFOptions opt)
      : OpRewritePattern<OpTy>(context), options(opt) {}

  VectorTransferToSCFOptions options;
};

/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
/// A return value of None indicates a broadcast.
template <typename OpTy>
static Optional<int64_t> unpackedDim(OpTy xferOp) {
  auto map = xferOp.permutation_map();
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Compute the permutation map for the new (N-1)-D vector transfer op. This
/// map is identical to the current permutation map, but the first result is
/// omitted.
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
  auto map = xferOp.permutation_map();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                        b.getContext());
}

/// Calculate the indices for the new vector transfer op.
///
/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
///                              ^^^^^^^
///       `iv` is the iteration variable of the (new) surrounding loop.
template <typename OpTy>
static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  // Corresponding memref dim of the vector dim that is unpacked.
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.indices();
  indices.append(prevIndices.begin(), prevIndices.end());

  Location loc = xferOp.getLoc();
  bool isBroadcast = !dim.hasValue();
  if (!isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = adaptor.indices()[dim.getValue()];
    indices[dim.getValue()] =
        makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
  }
}

static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (hasRetVal) {
    assert(value && "Expected non-empty value");
    b.create<scf::YieldOp>(loc, value);
  } else {
    b.create<scf::YieldOp>(loc);
  }
}

/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true. No such check is generated under the following
/// circumstances:
/// * xferOp does not have a mask.
/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
///   computed and attached to the new transfer op in the pattern.)
/// * The to-be-unpacked dim of xferOp is a broadcast.
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  if (!xferOp.mask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  Location loc = xferOp.getLoc();
  Value ivI32 = b.create<arith::IndexCastOp>(
      loc, IntegerType::get(b.getContext(), 32), iv);
  return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), ivI32);
}

/// Helper function for TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
/// ```
/// %vec = vector.transfer_read %A[%a, %b] %cst
///     : vector<5x4xf32>, memref<?x?xf32>
/// ```
/// An if check similar to this will be generated inside the loop:
/// ```
/// %d = memref.dim %A, %c0 : memref<?x?xf32>
/// if (%a + iv < %d) {
///   (in-bounds case)
/// } else {
///   (out-of-bounds case)
/// }
/// ```
///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked-out elements.
///
/// This function variant returns the value returned by `inBoundsCase` or
/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built...

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
  Location loc = xferOp.getLoc();
  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.source(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.indices()[dim.getValue()];
    Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
    cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
                                    memrefIdx);
  }

  // Condition check 2: Masked in?
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    if (cond)
      cond = lb.create<arith::AndIOp>(cond, maskCond);
    else
      cond = maskCond;
  }

  // If the condition is non-empty, generate an SCF::IfOp.
  if (cond) {
    auto check = lb.create<scf::IfOp>(
        resultTypes, cond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
          } else {
            b.create<scf::YieldOp>(loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an SCF::IfOp.
  return inBoundsCase(b, loc);
}

/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
/// a return value. Consequently, this function does not have a return value.
template <typename OpTy>
static void generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    function_ref<void(OpBuilder &, Location)> inBoundsCase,
    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  generateInBoundsCheck(
      b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
      /*inBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        inBoundsCase(b, loc);
        return Value();
      },
      /*outOfBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        if (outOfBoundsCase)
          outOfBoundsCase(b, loc);
        return Value();
      });
}

/// Given an ArrayAttr, return a copy where the first element is dropped.
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
  if (!attr)
    return attr;
  return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
}

/// Add the pass label to a vector transfer op if its rank exceeds the target
/// rank.
template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
                                unsigned targetRank) {
  if (newXferOp.getVectorType().getRank() > targetRank)
    newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}

/// Return true if this transfer op operates on a source tensor.
template <typename OpTy>
static bool isTensorOp(OpTy xferOp) {
  if (xferOp.getShapedType().template isa<RankedTensorType>()) {
    if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
      // TransferWriteOps on tensors have a result.
      assert(xferOp->getNumResults() > 0);
    }
    return true;
  }
  return false;
}

namespace lowering_n_d {

/// Helper data structure for data and mask buffers.
struct BufferAllocs {
  Value dataBuffer;
  Value maskBuffer;
};

/// Allocate temporary buffers for data (vector) and mask (if present).
/// TODO: Parallelism and threadlocal considerations.
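///
/// As an illustrative sketch (the shapes are hypothetical, not taken from the
/// surrounding code): for a transfer op with vector type vector<5x4xf32> and
/// mask type vector<5xi1>, this generates roughly
/// ```
/// %data = memref.alloca() : memref<vector<5x4xf32>>
/// %mask = memref.alloca() : memref<vector<5xi1>>
/// memref.store %the_mask, %mask[] : memref<vector<5xi1>>
/// %mask_val = memref.load %mask[] : memref<vector<5xi1>>
/// ```
/// where `dataBuffer` is the data alloca and `maskBuffer` is the re-loaded
/// mask value.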
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope =
      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);

  if (xferOp.mask()) {
    auto maskType = MemRefType::get({}, xferOp.mask().getType());
    auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
    b.setInsertionPoint(xferOp);
    b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
    result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
  }

  return result;
}

/// Given a MemRefType with VectorType element type, unpack one dimension from
/// the VectorType into the MemRefType.
///
/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
static MemRefType unpackOneDim(MemRefType type) {
  auto vectorType = type.getElementType().dyn_cast<VectorType>();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));
  return MemRefType::get(newMemrefShape,
                         VectorType::get(vectorType.getShape().drop_front(),
                                         vectorType.getElementType()));
}

/// Given a transfer op, find the memref from which the mask is loaded. This
/// is similar to Strategy<TransferWriteOp>::getBuffer.
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.mask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}

/// Codegen strategy, depending on the operation.
template <typename OpTy>
struct Strategy;

/// Codegen strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation. All labeled TransferReadOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
  /// memref.store %vec, %buf[...] ...
  /// ```
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
  ///    variable `iv`.
  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
  ///
  /// E.g.:
  /// ```
  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
  ///     : memref<?x?x?xf32>, vector<4x3xf32>
  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
  ///       : memref<?x?x?xf32>, vector<3xf32>
  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// }
  /// ```
  ///
  /// Note: The loop and type cast are generated in TransferOpConversion.
  ///       The original TransferReadOp and store op are deleted in `cleanup`.
  /// Note: The `mask` operand is set in TransferOpConversion.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto newXferOp = b.create<vector::TransferReadOp>(
        loc, vecType, xferOp.source(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
        Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// padding value to the temporary buffer.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
    b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);

    return Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};

/// Codegen strategy for vector TransferWriteOp.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  /// ```
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto source = loopState.empty() ? xferOp.source() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.source() : Value();
  }
};

template <typename OpTy>
LogicalResult checkPrepareXferOp(OpTy xferOp,
                                 VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return failure();
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return failure();
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return failure();
  // Transfer ops that modify the element type are not supported atm.
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return failure();
  return success();
}

/// Prepare a TransferReadOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
/// 3. Store the result of the TransferReadOp into the temporary buffer.
/// 4. Load the result from the temporary buffer and replace all uses of the
///    original TransferReadOp with this load.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
/// memref.store %1, %0[] : memref<vector<5x4xf32>>
/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    auto buffers = allocBuffers(rewriter, xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.mask()) {
      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
          buffers.maskBuffer);
    }

    Location loc = xferOp.getLoc();
    rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
                                     buffers.dataBuffer);
    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);

    return success();
  }
};

/// Prepare a TransferWriteOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Store the vector into the buffer.
/// 3. Load the vector from the buffer again.
/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
///    marking it eligible for progressive lowering via TransferOpConversion.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
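///
/// For illustration only (a hypothetical mask, not part of the example above):
/// if the op has a mask %mask : vector<5xi1>, the rewrite additionally
/// produces
/// ```
/// %m0 = memref.alloca() : memref<vector<5xi1>>
/// memref.store %mask, %m0[] : memref<vector<5xi1>>
/// %m1 = memref.load %m0[] : memref<vector<5xi1>>
/// ```
/// and %m1 becomes the mask operand of the labeled transfer_write.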
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
    auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.vectorMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.mask()) {
      rewriter.updateRootInPlace(
          xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
    }

    return success();
  }
};

/// Progressive lowering of vector transfer ops: Unpack one dimension.
///
/// 1. Unpack one dimension from the current buffer type and cast the buffer
///    to that new type. E.g.:
///    ```
///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
///    vector.transfer_write %vec ...
///    ```
///    The following cast is generated:
///    ```
///    %casted = vector.type_cast %0
///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
///    ```
/// 2. Generate a for loop and rewrite the transfer op according to the
///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
///    out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
///
/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
/// source (as opposed to a memref source), then each iteration of the
/// generated scf.for loop yields the new tensor value. E.g.:
/// ```
/// %result = scf.for i = 0 to 5 {
///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
///   %1 = vector.transfer_write %0, %source[...]
///       : vector<4x3xf32>, tensor<5x4x3xf32>
///   scf.yield %1 : tensor<5x4x3xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    this->setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
    auto castedDataType = unpackOneDim(dataBufferType);
    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);

    // If the xferOp has a mask: Find and cast mask buffer.
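    // E.g. (hypothetical shapes, for illustration only): a mask buffer of
    // type memref<vector<5x4xi1>> would be cast to memref<5xvector<4xi1>>,
    // so that the loop below can load one vector<4xi1> mask slice per
    // iteration.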
    Value castedMaskBuffer;
    if (xferOp.mask()) {
      auto maskBuffer = getMaskBuffer(xferOp);
      auto maskBufferType =
          maskBuffer.getType().template dyn_cast<MemRefType>();
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask, if:
        // * To-be-unpacked transfer op dimension is a broadcast.
        // * Mask is 1D, i.e., the mask cannot be further unpacked.
        //   (That means that all remaining dimensions of the transfer op must
        //   be broadcasted.)
        castedMaskBuffer = maskBuffer;
      } else {
        auto castedMaskType = unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step.
    auto lb = locB.create<arith::ConstantIndexOp>(0);
    auto ub = locB.create<arith::ConstantIndexOp>(
        castedDataType.getDimSize(castedDataType.getRank() - 1));
    auto step = locB.create<arith::ConstantIndexOp>(1);
    // TransferWriteOps that operate on tensors return the modified tensor and
    // require a loop state.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate for loop.
    auto result = locB.create<scf::ForOp>(
        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create new transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op.
                if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
                                      xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
                  // In case of broadcast: Use same indices to load from memref
                  // as before.
                  if (!xferOp.isBroadcastDim(0))
                    loadIndices.push_back(iv);

                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  rewriter.updateRootInPlace(
                      newXfer, [&]() { newXfer.maskMutable().assign(mask); });
                }

                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};

} // namespace lowering_n_d

namespace lowering_n_d_unrolled {

/// If the original transfer op has a mask, compute the mask of the new
/// transfer op (for the current iteration `i`) and assign it.
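///
/// For illustration (assuming a hypothetical mask of type vector<5x4xi1> and
/// iteration i = 2), the unpacked mask assigned to the new transfer op would
/// be computed as:
/// ```
/// %new_mask = vector.extract %mask[2] : vector<5x4xi1>
/// ```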
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.mask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // To-be-unpacked dimension is a broadcast, which does not have a
    // corresponding mask dimension. Mask attribute remains unchanged.
    newXferOp.maskMutable().assign(xferOp.mask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Unpack one dimension of the mask.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(newXferOp); // Insert load before newXfer.

    llvm::SmallVector<int64_t, 1> indices({i});
    Location loc = xferOp.getLoc();
    auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
    newXferOp.maskMutable().assign(newMask);
  }

  // If we end up here: The mask of the old transfer op is 1D and the unpacked
  // dim is not a broadcast, so no mask is needed on the new transfer op.
  // `generateInBoundsCheck` will have evaluated the mask already.
}

/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v_init = splat %padding : vector<5x4xf32>
/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
/// ...
/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
/// ```
///
/// Note: As an optimization, if the result of the original TransferReadOp
/// was directly inserted into another vector, no new %v_init vector is
/// created. Instead, the new TransferReadOp results are inserted into that
/// vector.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector into which the newly created TransferReadOp results
  /// are inserted.
  Value getResultVector(TransferReadOp xferOp,
                        PatternRewriter &rewriter) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.dest();
    Location loc = xferOp.getLoc();
    return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
                                    xferOp.padding());
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }

    return vector::InsertOp();
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation's indices.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVector<int64_t, 8> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      llvm::for_each(insertOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto insertOp = getInsertOp(xferOp);
    auto vec = getResultVector(xferOp, rewriter);
    auto vecType = vec.getType().dyn_cast<VectorType>();
    auto xferVecType = xferOp.getVectorType();
    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
                                          xferVecType.getElementType());
    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.insert op.
            SmallVector<int64_t, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(i);

            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferReadOp>(
                loc, newXferVecType, xferOp.source(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                xferOp.padding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);
            return b.create<vector::InsertOp>(loc, newXferOp, vec,
                                              insertionIndices);
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Loop through original (unmodified) vector.
            return vec;
          });
    }

    if (insertOp) {
      // Rewrite single user of the old TransferReadOp, which was an InsertOp.
      rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
    } else {
      rewriter.replaceOp(xferOp, vec);
    }

    return success();
  }
};

/// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v0 = vector.extract %vec[0] : vector<5x4xf32>
/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
/// %v1 = vector.extract %vec[1] : vector<5x4xf32>
/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
/// ...
/// %v4 = vector.extract %vec[4] : vector<5x4xf32>
/// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
/// ```
///
/// Note: As an optimization, if the vector of the original TransferWriteOp
/// was directly extracted from another vector via an ExtractOp `a`, extract
/// the vectors for the newly generated TransferWriteOps from `a`'s input. By
/// doing so, `a` may become dead, and the number of ExtractOps generated
/// during recursive application of this pattern will be minimal.
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector from which newly generated ExtractOps will extract.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.vector();
    return xferOp.vector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.vector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return its
  /// indices.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVector<int64_t, 8> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      llvm::for_each(extractOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    auto xferVecType = xferOp.getVectorType();
    int64_t dimSize = xferVecType.getShape()[0];
    auto source = xferOp.source(); // memref or tensor to be written to.
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.extract op.
            SmallVector<int64_t, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(i);

            auto extracted =
                b.create<vector::ExtractOp>(loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferWriteOp>(
                loc, sourceType, extracted, source, xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                inBoundsAttr);

            maybeAssignMask(b, xferOp, newXferOp, i);

            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            return isTensorOp(xferOp) ? source : Value();
          });

      if (isTensorOp(xferOp))
        source = updatedSource;
    }

    if (isTensorOp(xferOp))
      rewriter.replaceOp(xferOp, source);
    else
      rewriter.eraseOp(xferOp);

    return success();
  }
};

} // namespace lowering_n_d_unrolled

namespace lowering_1_d {

/// Compute the indices into the memref for the LoadOp/StoreOp generated as
/// part of TransferOp1dConversion. Return the memref dimension on which
/// the transfer is operating. A return value of None indicates a broadcast.
template <typename OpTy>
static Optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVector<Value, 8> &memrefIndices) {
  auto indices = xferOp.indices();
  auto map = xferOp.permutation_map();

  memrefIndices.append(indices.begin(), indices.end());
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    Location loc = xferOp.getLoc();
    auto dim = expr.getPosition();
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = memrefIndices[dim];
    memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
    return dim;
  }

  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Codegen strategy for TransferOp1dConversion, depending on the
/// operation.
template <typename OpTy>
struct Strategy1d;

/// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    Value ivI32 = b.create<arith::IndexCastOp>(
        loc, IntegerType::get(b.getContext(), 32), iv);
    auto vec = loopState[0];

    // In case of out-of-bounds access, leave `vec` as is (was initialized with
    // padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
          return b.create<vector::InsertElementOp>(loc, val, vec, ivI32);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    b.create<scf::YieldOp>(loc, nextVec);
  }

  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize vector with padding value.
    Location loc = xferOp.getLoc();
    return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
  }
};

/// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    Value ivI32 = b.create<arith::IndexCastOp>(
        loc, IntegerType::get(b.getContext(), 32), iv);

    // Nothing to do in case of out-of-bounds access.
    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val =
              b.create<vector::ExtractElementOp>(loc, xferOp.vector(), ivI32);
          b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
        });
    b.create<scf::YieldOp>(loc);
  }

  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};

/// Return true if the last dimension of the MemRefType has unit stride.
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
/// necessary in cases where a 1D vector transfer op cannot be lowered into
/// vector load/stores due to non-unit strides or broadcasts:
///
/// * Transfer dimension is not the last memref dimension
/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
/// * Memref has a layout map with non-unit stride on the last dimension
///
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
///    depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
///       can be generated instead of TransferOp1dConversion. Add such a
///       pattern to ConvertVectorToLLVM.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b]
///     {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
///     : vector<9xf32>, memref<?x?xf32>
/// ```
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
///   %t = vector.extractelement %vec[i] : vector<9xf32>
///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    auto map = xferOp.permutation_map();
    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();

    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
      return failure(); // Handled by ConvertVectorToLLVM

    // Loop bounds, step, state...
    Location loc = xferOp.getLoc();
    auto vecType = xferOp.getVectorType();
    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto ub =
        rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate for loop.
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
        });

    return success();
  }
};

} // namespace lowering_1_d
} // namespace

namespace mlir {

void populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  if (options.unroll) {
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
        patterns.getContext(), options);
  } else {
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }

  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
}

} // namespace mlir

namespace {

struct ConvertVectorToSCFPass
    : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerPermutationMaps = options.lowerPermutationMaps;
    this->lowerTensors = options.lowerTensors;
  }

  void runOnFunction() override {
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerPermutationMaps = lowerPermutationMaps;
    options.lowerTensors = lowerTensors;

    // Lower permutation maps first.
    if (lowerPermutationMaps) {
      RewritePatternSet lowerTransferPatterns(getFunction().getContext());
      mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
          lowerTransferPatterns);
      (void)applyPatternsAndFoldGreedily(getFunction(),
                                         std::move(lowerTransferPatterns));
    }

    RewritePatternSet patterns(getFunction().getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};

} // namespace

std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}
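// Example usage (a sketch only; the exact pass-manager setup depends on the
// client, and `module` is a hypothetical ModuleOp):
//
//   VectorTransferToSCFOptions options;
//   options.targetRank = 1;
//   mlir::PassManager pm(module.getContext());
//   pm.addNestedPass<mlir::FuncOp>(
//       mlir::createConvertVectorToSCFPass(options));
//   (void)pm.run(module);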