//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using llvm::dbgs;

#define DEBUG_TYPE "linalg-vectorization"

static bool hasMultiplyAddBody(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
    return false;

  using mlir::matchers::m_Val;
  auto a = m_Val(r.getArgument(0));
  auto b = m_Val(r.getArgument(1));
  auto c = m_Val(r.getArgument(2));
  // TODO: Update this detection once we have matcher support for specifying
  // that any permutation of operands matches.
  auto pattern1 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
  auto pattern2 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
  auto pattern3 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
  auto pattern4 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
  auto pattern5 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
  auto pattern6 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
  auto pattern7 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
  auto pattern8 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
  return pattern1.match(&r.front().back()) ||
         pattern2.match(&r.front().back()) ||
         pattern3.match(&r.front().back()) ||
         pattern4.match(&r.front().back()) ||
         pattern5.match(&r.front().back()) ||
         pattern6.match(&r.front().back()) ||
         pattern7.match(&r.front().back()) || pattern8.match(&r.front().back());
}

// TODO: Should be Tablegen'd from a single source that generates the op itself.
static LogicalResult isContraction(Operation *op) {
  // TODO: interface for named ops.
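  // Named contraction ops are recognized directly; a linalg.generic is only
  // accepted below if it has 2 inputs, 1 output, projected-permutation
  // indexing maps and a multiply-add body.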
  if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatvecOp,
          linalg::VecmatOp, linalg::DotOp>(op))
    return success();

  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp)
    return failure();

  auto mapRange = genericOp.indexing_maps().getAsValueRange<AffineMapAttr>();
  return success(
      genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
      llvm::all_of(mapRange,
                   [](AffineMap m) { return m.isProjectedPermutation(); }) &&
      hasMultiplyAddBody(genericOp.region()));
}

static bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
          op.hasTrait<OpTrait::ElementwiseMappable>()) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

// Return true if the op is an element-wise linalg op.
static bool isElementwise(Operation *op) {
  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp)
    return false;
  if (genericOp.getNumLoops() != genericOp.getNumParallelLoops())
    return false;
  // TODO: relax the restrictions on indexing map.
  for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) {
    if (!genericOp.getOutputIndexingMap(i).isIdentity())
      return false;
  }
  // Currently limit the input indexing map to minor identity as other
  // permutations might require adding transpose ops to convert the vector read
  // to the right shape.
  for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) {
    if (!genericOp.getInputIndexingMap(i).isMinorIdentity())
      return false;
  }
  return hasOnlyScalarElementwiseOp(genericOp.getRegion());
}

static VectorType extractVectorTypeFromScalarView(Value v) {
  MemRefType mt = v.getType().cast<MemRefType>();
  return mt.getShape().empty()
             ? VectorType()
             : VectorType::get(mt.getShape(), mt.getElementType());
}

static Value transferReadVector(OpBuilder &builder, Value memref) {
  edsc::ScopedContext scope(builder);
  auto memrefType = memref.getType().cast<MemRefType>();
  if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
    SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
    return vector_transfer_read(vectorType, memref, indices);
  }
  return std_load(memref);
}

static void transferWriteVector(OpBuilder &builder, Value value, Value memref) {
  edsc::ScopedContext scope(builder);
  auto memrefType = memref.getType().cast<MemRefType>();
  if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
    SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
    if (vectorType != value.getType())
      value = vector_broadcast(vectorType, value);
    vector_transfer_write(value, memref, indices);
  } else {
    std_store(value, memref);
  }
}

namespace {
// Transforms scalar operations into their vectorized counterparts,
// while using the provided generic op to map:
//   * Its arguments to transfer reads from the views of the generic op.
//   * linalg.yield ops to transfer writes to the views of the generic op.
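// Vectorization results are cached per scalar value so that each value in the
// body is vectorized at most once and later uses reuse the cached vector.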
class GenericVectorizer {
public:
  GenericVectorizer(OpBuilder &builder, linalg::GenericOp generic)
      : builder(builder), generic(generic) {}

  // Takes a scalar operation and builds its vectorized counterpart or
  // counterparts using the underlying builder.
  // If operands of the scalar operation refer to previously vectorized
  // operations, the vectorized operands refer to the corresponding
  // vectorization results.
  void vectorize(Operation &scalarOp) {
    auto yieldOp = dyn_cast<linalg::YieldOp>(scalarOp);
    if (yieldOp) {
      for (auto outputAndMemref :
           llvm::zip(yieldOp.values(), generic.getOutputBuffers())) {
        Value vectorValue = vectorize(std::get<0>(outputAndMemref));
        transferWriteVector(builder, vectorValue, std::get<1>(outputAndMemref));
      }
      return;
    }
    Operation *vectorOp = uncachedVectorize(scalarOp);
    assert(scalarOp.getNumResults() == vectorOp->getNumResults());
    for (auto result :
         llvm::zip(scalarOp.getResults(), vectorOp->getResults())) {
      valueCache[std::get<0>(result)] = std::get<1>(result);
    }
  }

private:
  // Transforms a scalar value into its vectorized counterpart, recursively
  // vectorizing operations as necessary using the underlying builder.
  // Keeps track of previously vectorized values and reuses vectorization
  // results if these values come up again.
  Value vectorize(Value scalarValue) {
    // Don't vectorize values coming from outside the region.
    if (scalarValue.getParentRegion() != &generic.region())
      return scalarValue;
    auto vectorValueIt = valueCache.find(scalarValue);
    if (vectorValueIt != valueCache.end())
      return vectorValueIt->second;

    // If the value is from the region but not in the cache it means it is a
    // block argument.
    auto scalarArg = scalarValue.cast<BlockArgument>();
    assert(scalarArg.getOwner() == &generic.region().front());
    Value vectorArg = generic.getShapedOperand(scalarArg.getArgNumber());
    Value vectorResult = transferReadVector(builder, vectorArg);
    valueCache[scalarArg] = vectorResult;
    return vectorResult;
  }

  // Return the largest shape of all the given values. Return an empty
  // SmallVector if there are no vector values.
  static SmallVector<int64_t, 4> getLargestShape(ArrayRef<Value> values) {
    SmallVector<int64_t, 4> largestShape;
    int64_t maxSize = 1;
    for (Value value : values) {
      auto vecType = value.getType().dyn_cast<VectorType>();
      if (!vecType)
        continue;
      if (maxSize < vecType.getNumElements()) {
        maxSize = vecType.getNumElements();
        largestShape.assign(vecType.getShape().begin(),
                            vecType.getShape().end());
      }
    }
    return largestShape;
  }

  // If the value's type doesn't have the given shape, broadcast it.
  Value broadcastIfNeeded(Value value, ArrayRef<int64_t> shape) {
    auto vecType = value.getType().dyn_cast<VectorType>();
    if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
      return value;
    auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
                                                     : value.getType());
    return builder.create<vector::BroadcastOp>(
        builder.getInsertionPoint()->getLoc(), newVecType, value);
  }

  // Takes a scalar operation and builds its vectorized counterpart or
  // counterparts using the underlying builder without involving any caches.
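  // All operands are vectorized first, broadcast to the largest operand shape,
  // and the result types are rewrapped as vectors of that shape.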
  Operation *uncachedVectorize(Operation &base_scalarOp) {
    SmallVector<Value, 4> vectorizedOperands;
    for (Value operand : base_scalarOp.getOperands()) {
      vectorizedOperands.push_back(vectorize(operand));
    }
    SmallVector<int64_t, 4> shape = getLargestShape(vectorizedOperands);
    for (Value &operand : vectorizedOperands)
      operand = broadcastIfNeeded(operand, shape);
    OperationState state(base_scalarOp.getLoc(), base_scalarOp.getName());
    state.addAttributes(base_scalarOp.getAttrs());
    state.addOperands(vectorizedOperands);
    if (shape.empty()) {
      state.addTypes(base_scalarOp.getResultTypes());
    } else {
      SmallVector<VectorType, 4> vectorizedTypes;
      for (Type resultType : base_scalarOp.getResultTypes())
        vectorizedTypes.push_back(VectorType::get(shape, resultType));
      state.addTypes(vectorizedTypes);
    }
    return builder.createOperation(state);
  }

  OpBuilder &builder;
  linalg::GenericOp generic;
  llvm::DenseMap<Value, Value> valueCache;
};
} // namespace

// Replaces an elementwise linalg.generic op by the scalar operations of its
// body, promoted to vector operations.
static void vectorizeElementwise(linalg::GenericOp op, OpBuilder &builder) {
  GenericVectorizer vectorizer(builder, op);
  for (Operation &scalarOp : op.region().front()) {
    vectorizer.vectorize(scalarOp);
  }
}

LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  // All types must be static shape to go to vector.
  for (Value operand : linalgOp.getShapedOperands())
    if (!operand.getType().cast<ShapedType>().hasStaticShape())
      return failure();
  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
      return failure();

  if (isa<linalg::FillOp, linalg::CopyOp>(op))
    return success();
  if (isElementwise(op))
    return success();
  return isContraction(op);
}

void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
  assert(succeeded(vectorizeLinalgOpPrecondition(op)));

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  edsc::ScopedContext scope(builder, op->getLoc());
  // In the case of 0-D memrefs, return null and special case to scalar load or
  // store later.
  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
    // Vectorize fill as a vector.broadcast.
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
    transferWriteVector(builder, fillOp.value(), fillOp.output());
    return;
  }
  if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
    // Vectorize copy as a vector.transfer_read + vector.transfer_write.
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.copy as vector.transfer_read + "
                         "vector.transfer_write: "
                      << *op);
    Value vector = transferReadVector(builder, copyOp.input());
    transferWriteVector(builder, vector, copyOp.output());
    return;
  }

  if (isElementwise(op)) {
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg op as vector.transfer_read + "
                         "vector_op + vector.transfer_write: "
                      << *op);
    return vectorizeElementwise(cast<linalg::GenericOp>(op), builder);
  }

  assert(succeeded(isContraction(op)) && "Expected contraction");

  // Vectorize other ops as vector contraction.
  // TODO: interface.
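  // The two inputs and the output buffer are read in full (vector transfers
  // for ranked views, scalar loads for 0-D views), combined with a single
  // vector.contract that reuses the linalg op's indexing maps and iterator
  // types, and the result is written back to the output view.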
  LLVM_DEBUG(dbgs() << dbgPref
                    << "Rewrite linalg op as vector.contract: " << *op);
  auto linalgOp = cast<linalg::LinalgOp>(op);
  Value viewA = linalgOp.getInput(0);
  Value viewB = linalgOp.getInput(1);
  Value viewC = linalgOp.getOutputBuffer(0);
  VectorType vtA = extractVectorTypeFromScalarView(viewA);
  VectorType vtB = extractVectorTypeFromScalarView(viewB);
  VectorType vtC = extractVectorTypeFromScalarView(viewC);
  Value zero = std_constant_index(0);
  SmallVector<Value, 4> indicesA, indicesB, indicesC;
  if (vtA)
    indicesA = SmallVector<Value, 4>(vtA.getRank(), zero);
  if (vtB)
    indicesB = SmallVector<Value, 4>(vtB.getRank(), zero);
  if (vtC)
    indicesC = SmallVector<Value, 4>(vtC.getRank(), zero);
  Value a = vtA ? vector_transfer_read(vtA, viewA, indicesA).value
                : std_load(viewA, indicesA).value;
  Value b = vtB ? vector_transfer_read(vtB, viewB, indicesB).value
                : std_load(viewB, indicesB).value;
  Value c = vtC ? vector_transfer_read(vtC, viewC, indicesC).value
                : std_load(viewC, indicesC).value;
  Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                              linalgOp.iterator_types());
  if (vtC)
    vector_transfer_write(res, viewC, indicesC);
  else
    std_store(res, viewC, indicesC);
}

/// Check whether there is any interleaved use of any `values` between `firstOp`
/// and `secondOp`. Conservatively return `true` if any op or value is in a
/// different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LLVM_DEBUG(llvm::dbgs()
               << dbgPref << "interleavedUses precondition failed, firstOp: "
               << *firstOp << ", second op: " << *secondOp);
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
        continue;
      LLVM_DEBUG(llvm::dbgs()
                 << dbgPref << " found interleaved op " << *owner
                 << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
      return true;
    }
  }
  return false;
}

/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
static SubViewOp getSubViewUseIfUnique(Value v) {
  SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {

  // Transfer read from `viewOrAlloc`.
  Value viewOrAlloc = xferOp.source();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: ";
  (void)dbgPref;
  LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc);

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
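  // If there is no subview, or more than one, conservatively bail out.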
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView);

  // Find the copy into `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getOutputBuffer(0) != subView)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp);
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp);

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp);
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure padding matches.
  if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
    return failure();
  if (maybeFillOp)
    LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp);

  // `in` is the value that linalg.copy reads; forward the transfer_read to it.
  Value in = copyOp.getInput(0);

  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_read, the attribute must be reset
  // conservatively.
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
      xferOp.permutation_map(), xferOp.padding(), ArrayAttr());

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.source();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getInput(0) != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  // `out` is the buffer that the copy writes into; the transfer_write is
  // forwarded to it.
  Value out = copyOp.getOutputBuffer(0);

  // Forward vector.transfer into copy.
  // linalg.copy + linalg.fill can be used to create a padded local buffer.
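  // (The fill initializes the buffer with the transfer padding value and the
  // copy then moves the data through a subview of that buffer.)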
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_write, the attribute must be reset
  // conservatively.
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
      xferOp.permutation_map(), ArrayAttr());

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}

template <class ConvOp, int N>
LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
    ConvOp op, PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MLIRContext *context = op.getContext();
  edsc::ScopedContext scope(rewriter, loc);

  ShapedType inShapeType = op.getInputShapedType(0);
  ShapedType kShapeType = op.getInputShapedType(1);

  ArrayRef<int64_t> inShape = inShapeType.getShape();
  ArrayRef<int64_t> kShape = kShapeType.getShape();

  if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
    return failure();

  SmallVector<AffineExpr, 4> mapping;
  SmallVector<int64_t, 4> vectorDims;
  // Fail to apply when the size of a non-vectorized dimension is not 1.
  for (unsigned i = 0; i < N; i++) {
    if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
      return failure();

    if (mask[i] && inShape[i] != kShape[i])
      return failure();

    if (mask[i]) {
      mapping.push_back(getAffineDimExpr(i, context));
      vectorDims.push_back(inShape[i]);
    }
  }

  Value input = op.getInput(0);
  Value kernel = op.getInput(1);
  Value output = op.getOutputBuffer(0);

  unsigned rank = inShapeType.getRank();
  unsigned numDims = mapping.size();
  Type elemType = inShapeType.getElementType();

  auto map = AffineMap::get(rank, 0, mapping, context);
  SmallVector<Value, 4> zeros(rank, std_constant_index(0));
  auto vecType = VectorType::get(vectorDims, elemType);

  auto inputVec = vector_transfer_read(vecType, input, zeros, map);
  auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);

  auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType));

  std::array<AffineMap, 3> indexingMaps{
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::get(numDims, 0, {}, context)};

  std::vector<StringRef> iteratorTypes(numDims, "reduction");

  auto result = rewriter.create<vector::ContractionOp>(
      loc, inputVec, kernelVec, acc,
      rewriter.getAffineMapArrayAttr(indexingMaps),
      rewriter.getStrArrayAttr(iteratorTypes));

  rewriter.create<StoreOp>(loc, result, output, ValueRange(zeros));
  rewriter.eraseOp(op);
  return success();
}

using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;

/// Inserts tiling, promotion and vectorization patterns for ConvOp
/// conversion into the corresponding pattern lists.
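/// Tiling, promotion and vectorization are chained through the TILED and
/// PROMOTED markers, and only the trailing dimensions whose tile size is
/// greater than 1 are selected for vectorization via the mask.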
template <typename ConvOp, unsigned N>
static void
populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
                              OwningRewritePatternList &promotionPatterns,
                              OwningRewritePatternList &vectorizationPatterns,
                              ArrayRef<int64_t> tileSizes,
                              MLIRContext *context) {
  if (tileSizes.size() < N)
    return;

  constexpr static StringRef kTiledMarker = "TILED";
  constexpr static StringRef kPromotedMarker = "PROMOTED";
  tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
      context, LinalgTilingOptions().setTileSizes(tileSizes),
      LinalgMarker({}, Identifier::get(kTiledMarker, context)));

  promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>(
      context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgMarker(Identifier::get(kTiledMarker, context),
                   Identifier::get(kPromotedMarker, context)));

  SmallVector<bool, 4> mask(N);
  int offset = tileSizes.size() - N;
  std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
                 [](int64_t i) -> bool { return i > 1; });

  vectorizationPatterns.insert<ConvOpVectorization<ConvOp, N>>(context, mask);
}

void mlir::linalg::populateConvVectorizationPatterns(
    MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
    ArrayRef<int64_t> tileSizes) {
  OwningRewritePatternList tiling, promotion, vectorization;
  populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
                                            tileSizes, context);

  populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
                                             tileSizes, context);

  populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
                                               tileSizes, context);

  populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
                                               tileSizes, context);

  populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvNDHWCOp, 5>(
      tiling, promotion, vectorization, tileSizes, context);

  populateVectorizationPatterns<ConvNCDHWOp, 5>(
      tiling, promotion, vectorization, tileSizes, context);

  patterns.push_back(std::move(tiling));
  patterns.push_back(std::move(promotion));
  patterns.push_back(std::move(vectorization));
}