//===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector transfer operations to SCF.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using vector::TransferReadOp;
using vector::TransferWriteOp;

namespace {

/// Attribute name used for labeling transfer ops during progressive lowering.
static const char kPassLabel[] = "__vector_to_scf_lowering__";

/// Patterns that inherit from this struct have access to
/// VectorTransferToSCFOptions.
template <typename OpTy>
struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
  explicit VectorToSCFPattern(MLIRContext *context,
                              VectorTransferToSCFOptions opt)
      : OpRewritePattern<OpTy>(context), options(opt) {}

  VectorTransferToSCFOptions options;
};

/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
/// A return value of None indicates a broadcast.
template <typename OpTy>
static Optional<int64_t> unpackedDim(OpTy xferOp) {
  // TODO: support 0-d corner case.
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto map = xferOp.permutation_map();
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Compute the permutation map for the new (N-1)-D vector transfer op. This
/// map is identical to the current permutation map, but the first result is
/// omitted.
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
  // TODO: support 0-d corner case.
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto map = xferOp.permutation_map();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                        b.getContext());
}

/// Calculate the indices for the new vector transfer op.
///
/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
///                                 ^^^^^^^
///       `iv` is the iteration variable of the (new) surrounding loop.
template <typename OpTy>
static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  // Corresponding memref dim of the vector dim that is unpacked.
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.indices();
  indices.append(prevIndices.begin(), prevIndices.end());

  Location loc = xferOp.getLoc();
  bool isBroadcast = !dim.hasValue();
  if (!isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = adaptor.indices()[dim.getValue()];
    indices[dim.getValue()] =
        makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
  }
}

static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (hasRetVal) {
    assert(value && "Expected non-empty value");
    b.create<scf::YieldOp>(loc, value);
  } else {
    b.create<scf::YieldOp>(loc);
  }
}

/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true. No such check is generated under the following
/// circumstances:
/// * xferOp does not have a mask.
/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
///   computed and attached to the new transfer op in the pattern.)
/// * The to-be-unpacked dim of xferOp is a broadcast.
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  if (!xferOp.mask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  Location loc = xferOp.getLoc();
  return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), iv);
}

/// Helper function for TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
/// ```
/// %vec = vector.transfer_read %A[%a, %b] %cst
///     : vector<5x4xf32>, memref<?x?xf32>
/// ```
/// An if check similar to this will be generated inside the loop:
/// ```
/// %d = memref.dim %A, %c0 : memref<?x?xf32>
/// if (%a + iv < %d) {
///   (in-bounds case)
/// } else {
///   (out-of-bounds case)
/// }
/// ```
///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked-out elements.
///
/// This function variant returns the value returned by `inBoundsCase` or
/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built...

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
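  // The dynamic in-bounds condition is `base index + iv < memref dim size`.
  // It is omitted for dims that are statically known to be in-bounds.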
  Location loc = xferOp.getLoc();
  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.source(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.indices()[dim.getValue()];
    Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
    cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
                                    memrefIdx);
  }

  // Condition check 2: Masked in?
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    if (cond)
      cond = lb.create<arith::AndIOp>(cond, maskCond);
    else
      cond = maskCond;
  }

  // If the condition is non-empty, generate an SCF::IfOp.
  if (cond) {
    auto check = lb.create<scf::IfOp>(
        resultTypes, cond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
          } else {
            b.create<scf::YieldOp>(loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an SCF::IfOp.
  return inBoundsCase(b, loc);
}

/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
/// a return value. Consequently, this function does not have a return value.
template <typename OpTy>
static void generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    function_ref<void(OpBuilder &, Location)> inBoundsCase,
    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  generateInBoundsCheck(
      b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
      /*inBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        inBoundsCase(b, loc);
        return Value();
      },
      /*outOfBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        if (outOfBoundsCase)
          outOfBoundsCase(b, loc);
        return Value();
      });
}

/// Given an ArrayAttr, return a copy where the first element is dropped.
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
  if (!attr)
    return attr;
  return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
}

/// Add the pass label to a vector transfer op if its rank is greater than the
/// target rank.
template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
                                unsigned targetRank) {
  if (newXferOp.getVectorType().getRank() > targetRank)
    newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}

/// Return true if this transfer op operates on a source tensor.
template <typename OpTy>
static bool isTensorOp(OpTy xferOp) {
  if (xferOp.getShapedType().template isa<RankedTensorType>()) {
    if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
      // TransferWriteOps on tensors have a result.
      assert(xferOp->getNumResults() > 0);
    }
    return true;
  }
  return false;
}

namespace lowering_n_d {

/// Helper data structure for data and mask buffers.
struct BufferAllocs {
  Value dataBuffer;
  Value maskBuffer;
};

// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
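/// Return the closest enclosing op that has the AutomaticAllocationScope
/// trait. Asserts that such an ancestor exists.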
static Operation *getAutomaticAllocationScope(Operation *op) {
  Operation *scope =
      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Allocate temporary buffers for data (vector) and mask (if present).
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope = getAutomaticAllocationScope(xferOp);
  assert(scope->getNumRegions() == 1 &&
         "AutomaticAllocationScope with >1 regions");
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);

  if (xferOp.mask()) {
    auto maskType = MemRefType::get({}, xferOp.mask().getType());
    auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
    b.setInsertionPoint(xferOp);
    b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
    result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
  }

  return result;
}

/// Given a MemRefType with VectorType element type, unpack one dimension from
/// the VectorType into the MemRefType.
///
/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
static MemRefType unpackOneDim(MemRefType type) {
  auto vectorType = type.getElementType().dyn_cast<VectorType>();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));
  return MemRefType::get(newMemrefShape,
                         VectorType::get(vectorType.getShape().drop_front(),
                                         vectorType.getElementType()));
}

/// Given a transfer op, find the memref from which the mask is loaded. This
/// is similar to Strategy<TransferWriteOp>::getBuffer.
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.mask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}

/// Codegen strategy, depending on the operation.
template <typename OpTy>
struct Strategy;

/// Codegen strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation. All labeled TransferReadOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
  /// memref.store %vec, %buf[...] ...
  /// ```
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
  ///    variable `iv`.
  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
  ///
  /// E.g.:
  /// ```
  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
  ///     : memref<?x?x?xf32>, vector<4x3xf32>
  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
  ///       : memref<?x?x?xf32>, vector<3xf32>
  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// }
  /// ```
  ///
  /// Note: The loop and type cast are generated in TransferOpConversion.
  ///       The original TransferReadOp and store op are deleted in `cleanup`.
  /// Note: The `mask` operand is set in TransferOpConversion.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto newXferOp = b.create<vector::TransferReadOp>(
        loc, vecType, xferOp.source(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
        Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// padding value to the temporary buffer.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.padding());
    b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);

    return Value();
  }

  /// Cleanup after rewriting the op.
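  /// Erases the original TransferReadOp and the memref.store that wrote its
  /// result into the temporary buffer.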
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};

/// Codegen strategy for vector TransferWriteOp.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  /// ```
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto source = loopState.empty() ? xferOp.source() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop.
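  /// For tensor sources this is the source tensor, which is threaded through
  /// the loop; for memref sources there is no loop state.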
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.source() : Value();
  }
};

template <typename OpTy>
LogicalResult checkPrepareXferOp(OpTy xferOp,
                                 VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return failure();
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return failure();
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return failure();
  // Transfer ops that modify the element type are not supported atm.
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return failure();
  return success();
}

/// Prepare a TransferReadOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
/// 3. Store the result of the TransferReadOp into the temporary buffer.
/// 4. Load the result from the temporary buffer and replace all uses of the
///    original TransferReadOp with this load.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
/// memref.store %1, %0[] : memref<vector<5x4xf32>>
/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    auto buffers = allocBuffers(rewriter, xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.mask()) {
      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
          buffers.maskBuffer);
    }

    Location loc = xferOp.getLoc();
    rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
                                     buffers.dataBuffer);
    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);

    return success();
  }
};

/// Prepare a TransferWriteOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Store the vector into the buffer.
/// 3. Load the vector from the buffer again.
/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
///    marking it eligible for progressive lowering via TransferOpConversion.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
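/// The mask, if present, is routed through a buffer of its own in the same
/// way, so that TransferOpConversion can later load (sub)masks from it.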
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
    auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.vectorMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.mask()) {
      rewriter.updateRootInPlace(
          xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
    }

    return success();
  }
};

/// Progressive lowering of vector transfer ops: Unpack one dimension.
///
/// 1. Unpack one dimension from the current buffer type and cast the buffer
///    to that new type. E.g.:
///    ```
///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
///    vector.transfer_write %vec ...
///    ```
///    The following cast is generated:
///    ```
///    %casted = vector.type_cast %0
///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
///    ```
/// 2. Generate a for loop and rewrite the transfer op according to the
///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
///    out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
///
/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
/// source (as opposed to a memref source), then each iteration of the
/// generated scf.for loop yields the new tensor value. E.g.:
/// ```
/// %result = scf.for i = 0 to 5 {
///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
///   %1 = vector.transfer_write %0, %source[...]
///       : vector<4x3xf32>, tensor<5x4x3xf32>
///   scf.yield %1 : tensor<5x4x3xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    this->setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
    auto castedDataType = unpackOneDim(dataBufferType);
    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);

    // If the xferOp has a mask: Find and cast mask buffer.
    Value castedMaskBuffer;
    if (xferOp.mask()) {
      auto maskBuffer = getMaskBuffer(xferOp);
      auto maskBufferType =
          maskBuffer.getType().template dyn_cast<MemRefType>();
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask, if:
        // * To-be-unpacked transfer op dimension is a broadcast.
        // * Mask is 1D, i.e., the mask cannot be further unpacked.
        //   (That means that all remaining dimensions of the transfer op must
        //   be broadcasted.)
        castedMaskBuffer = maskBuffer;
      } else {
        auto castedMaskType = unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step.
    auto lb = locB.create<arith::ConstantIndexOp>(0);
    auto ub = locB.create<arith::ConstantIndexOp>(
        castedDataType.getDimSize(castedDataType.getRank() - 1));
    auto step = locB.create<arith::ConstantIndexOp>(1);
    // TransferWriteOps that operate on tensors return the modified tensor and
    // require a loop state.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate for loop.
    auto result = locB.create<scf::ForOp>(
        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create new transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If the old transfer op has a mask: Set mask on new transfer
                // op. Special case: If the mask of the old transfer op is 1D
                // and the unpacked dim is not a broadcast, no mask is needed
                // on the new transfer op.
                if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
                                      xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
                  // In case of broadcast: Use same indices to load from memref
                  // as before.
                  if (!xferOp.isBroadcastDim(0))
                    loadIndices.push_back(iv);

                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  rewriter.updateRootInPlace(
                      newXfer, [&]() { newXfer.maskMutable().assign(mask); });
                }

                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};

} // namespace lowering_n_d

namespace lowering_n_d_unrolled {

/// If the original transfer op has a mask, compute the mask of the new
/// transfer op (for the current iteration `i`) and assign it.
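///
/// E.g. (illustrative shapes): a mask of type vector<5x4xi1> on the original
/// op becomes the vector<4xi1> slice extracted at position `i` on the new op.
/// A 1-D mask is not assigned here; it is evaluated by
/// `generateInBoundsCheck` instead.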
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.mask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // To-be-unpacked dimension is a broadcast, which does not have a
    // corresponding mask dimension. Mask attribute remains unchanged.
    newXferOp.maskMutable().assign(xferOp.mask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Unpack one dimension of the mask.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(newXferOp); // Insert load before newXfer.

    llvm::SmallVector<int64_t, 1> indices({i});
    Location loc = xferOp.getLoc();
    auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
    newXferOp.maskMutable().assign(newMask);
  }

  // If we end up here: The mask of the old transfer op is 1D and the unpacked
  // dim is not a broadcast, so no mask is needed on the new transfer op.
  // `generateInBoundsCheck` will have evaluated the mask already.
}

/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v_init = splat %padding : vector<5x4xf32>
/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
/// ...
/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
/// ```
///
/// Note: As an optimization, if the result of the original TransferReadOp
/// was directly inserted into another vector, no new %v_init vector is
/// created. Instead, the new TransferReadOp results are inserted into that
/// vector.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector into which the newly created TransferReadOp results
  /// are inserted.
  Value getResultVector(TransferReadOp xferOp,
                        PatternRewriter &rewriter) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.dest();
    Location loc = xferOp.getLoc();
    return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
                                            xferOp.padding());
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }

    return vector::InsertOp();
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation's indices.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVector<int64_t, 8> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      llvm::for_each(insertOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto insertOp = getInsertOp(xferOp);
    auto vec = getResultVector(xferOp, rewriter);
    auto vecType = vec.getType().dyn_cast<VectorType>();
    auto xferVecType = xferOp.getVectorType();
    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
                                          xferVecType.getElementType());
    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.insert op.
            SmallVector<int64_t, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(i);

            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferReadOp>(
                loc, newXferVecType, xferOp.source(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                xferOp.padding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);
            return b.create<vector::InsertOp>(loc, newXferOp, vec,
                                              insertionIndices);
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Out-of-bounds case: Keep the original (unmodified) vector.
            return vec;
          });
    }

    if (insertOp) {
      // Rewrite single user of the old TransferReadOp, which was an InsertOp.
      rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
    } else {
      rewriter.replaceOp(xferOp, vec);
    }

    return success();
  }
};

/// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v0 = vector.extract %vec[0] : vector<5x4xf32>
/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
/// %v1 = vector.extract %vec[1] : vector<5x4xf32>
/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
/// ...
/// %v4 = vector.extract %vec[4] : vector<5x4xf32>
/// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
/// ```
///
/// Note: As an optimization, if the vector of the original TransferWriteOp
/// was directly extracted from another vector via an ExtractOp `a`, extract
/// the vectors for the newly generated TransferWriteOps from `a`'s input. By
/// doing so, `a` may become dead, and the number of ExtractOps generated
/// during recursive application of this pattern will be minimal.
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector from which newly generated ExtractOps will extract.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.vector();
    return xferOp.vector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.vector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return its
  /// indices.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVector<int64_t, 8> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      llvm::for_each(extractOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    auto xferVecType = xferOp.getVectorType();
    int64_t dimSize = xferVecType.getShape()[0];
    auto source = xferOp.source(); // memref or tensor to be written to.
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

    // Generate fully unrolled loop of transfer ops.
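    // Each iteration extracts one (N-1)-D slice from the source vector and
    // writes it with its own (N-1)-D transfer_write. For tensor sources, the
    // tensor produced by each new transfer_write is threaded into the next
    // iteration via `source`.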
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.extract op.
            SmallVector<int64_t, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(i);

            auto extracted =
                b.create<vector::ExtractOp>(loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
            auto newXferOp = b.create<vector::TransferWriteOp>(
                loc, sourceType, extracted, source, xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                inBoundsAttr);

            maybeAssignMask(b, xferOp, newXferOp, i);

            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            return isTensorOp(xferOp) ? source : Value();
          });

      if (isTensorOp(xferOp))
        source = updatedSource;
    }

    if (isTensorOp(xferOp))
      rewriter.replaceOp(xferOp, source);
    else
      rewriter.eraseOp(xferOp);

    return success();
  }
};

} // namespace lowering_n_d_unrolled

namespace lowering_1_d {

/// Compute the indices into the memref for the LoadOp/StoreOp generated as
/// part of TransferOp1dConversion. Return the memref dimension on which
/// the transfer is operating. A return value of None indicates a broadcast.
template <typename OpTy>
static Optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVector<Value, 8> &memrefIndices) {
  auto indices = xferOp.indices();
  auto map = xferOp.permutation_map();
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");

  memrefIndices.append(indices.begin(), indices.end());
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    Location loc = xferOp.getLoc();
    auto dim = expr.getPosition();
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = memrefIndices[dim];
    memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
    return dim;
  }

  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Codegen strategy for TransferOp1dConversion, depending on the
/// operation.
template <typename OpTy>
struct Strategy1d;

/// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    auto vec = loopState[0];

    // In case of out-of-bounds access, leave `vec` as is (was initialized with
    // padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
          return b.create<vector::InsertElementOp>(loc, val, vec, iv);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    b.create<scf::YieldOp>(loc, nextVec);
  }

  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize the vector with the padding value.
    Location loc = xferOp.getLoc();
    return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
                                     xferOp.padding());
  }
};

/// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);

    // Nothing to do in case of out-of-bounds access.
    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val =
              b.create<vector::ExtractElementOp>(loc, xferOp.vector(), iv);
          b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
        });
    b.create<scf::YieldOp>(loc);
  }

  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};

/// Return true if the last dimension of the MemRefType has unit stride.
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(type, strides, offset);
  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}

/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
/// necessary in cases where a 1D vector transfer op cannot be lowered into
/// vector load/stores due to non-unit strides or broadcasts:
///
/// * Transfer dimension is not the last memref dimension
/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
/// * Memref has a layout map with non-unit stride on the last dimension
///
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
///    depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
///       can be generated instead of TransferOp1dConversion. Add such a
///       pattern to ConvertVectorToLLVM.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b]
///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
///    : vector<9xf32>, memref<?x?xf32>
/// ```
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
///   %t = vector.extractelement %vec[i] : vector<9xf32>
///   memref.store %t, %A[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();
    auto map = xferOp.permutation_map();
    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();

    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
      return failure(); // Handled by ConvertVectorToLLVM

    // Loop bounds, step, state...
    Location loc = xferOp.getLoc();
    auto vecType = xferOp.getVectorType();
    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto ub =
        rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate for loop.
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
        });

    return success();
  }
};

} // namespace lowering_1_d
} // namespace

namespace mlir {

void populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  if (options.unroll) {
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
        patterns.getContext(), options);
  } else {
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }

  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
}

} // namespace mlir

namespace {

struct ConvertVectorToSCFPass
    : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerPermutationMaps = options.lowerPermutationMaps;
    this->lowerTensors = options.lowerTensors;
  }

  void runOnOperation() override {
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerPermutationMaps = lowerPermutationMaps;
    options.lowerTensors = lowerTensors;

    // Lower permutation maps first.
    if (lowerPermutationMaps) {
      RewritePatternSet lowerTransferPatterns(getOperation().getContext());
      mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
          lowerTransferPatterns);
      (void)applyPatternsAndFoldGreedily(getOperation(),
                                         std::move(lowerTransferPatterns));
    }

    RewritePatternSet patterns(getOperation().getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

} // namespace

std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}
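
// Example usage (a sketch, not part of the lowering itself). The textual pass
// name and option spellings below assume the standard pass registration that
// accompanies this file; adjust them if the registration differs:
//
//   mlir-opt input.mlir -convert-vector-to-scf="full-unroll=true"
//
// From C++, the options struct used throughout this file can be populated
// directly and passed to the pass constructor. Here `pm` is assumed to be an
// existing mlir::PassManager and FuncOp the function-like op the pass runs on:
//
//   VectorTransferToSCFOptions options;
//   options.unroll = true; // Select the unrolled lowering patterns.
//   options.targetRank = 1;
//   pm.addNestedPass<FuncOp>(createConvertVectorToSCFPass(options));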