//===- Vectorization.cpp - Implementation of linalg Vectorization --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using llvm::dbgs;

#define DEBUG_TYPE "linalg-vectorization"

static bool hasMultiplyAddBody(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
    return false;

  using mlir::matchers::m_Val;
  auto a = m_Val(r.getArgument(0));
  auto b = m_Val(r.getArgument(1));
  auto c = m_Val(r.getArgument(2));
  // TODO: Update this detection once we have matcher support for specifying
  // that any permutation of operands matches.
  auto pattern1 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
  auto pattern2 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
  auto pattern3 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
  auto pattern4 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
  auto pattern5 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
  auto pattern6 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
  auto pattern7 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
  auto pattern8 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
  return pattern1.match(&r.front().back()) ||
         pattern2.match(&r.front().back()) ||
         pattern3.match(&r.front().back()) ||
         pattern4.match(&r.front().back()) ||
         pattern5.match(&r.front().back()) ||
         pattern6.match(&r.front().back()) ||
         pattern7.match(&r.front().back()) || pattern8.match(&r.front().back());
}
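
// For illustration only, `hasMultiplyAddBody` above recognizes generic-op
// regions of roughly the following shape (f32 shown; names and types are just
// an example):
//   ^bb0(%a: f32, %b: f32, %c: f32):
//     %0 = mulf %a, %b : f32
//     %1 = addf %0, %c : f32
//     linalg.yield %1 : f32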

// TODO: Should be Tablegen'd from a single source that generates the op itself.
static LogicalResult isContraction(Operation *op) {
  // TODO: interface for named ops.
  if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatvecOp,
          linalg::VecmatOp, linalg::DotOp>(op))
    return success();

  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp)
    return failure();

  auto mapRange = genericOp.indexing_maps().getAsValueRange<AffineMapAttr>();
  return success(
      genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
      llvm::all_of(mapRange,
                   [](AffineMap m) { return m.isProjectedPermutation(); }) &&
      hasMultiplyAddBody(genericOp.region()));
}

static bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
          op.hasTrait<OpTrait::ElementwiseMappable>()) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

// Return true if the op is an element-wise linalg op.
static bool isElementwise(Operation *op) {
  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp)
    return false;
  if (genericOp.getNumLoops() != genericOp.getNumParallelLoops())
    return false;
  // TODO: relax the restrictions on indexing map.
  for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) {
    if (!genericOp.getOutputIndexingMap(i).isIdentity())
      return false;
  }
  // Currently limit the input indexing maps to minor identity as other
  // permutations might require adding transpose ops to convert the vector read
  // to the right shape.
  for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) {
    if (!genericOp.getInputIndexingMap(i).isMinorIdentity())
      return false;
  }
  return hasOnlyScalarElementwiseOp(genericOp.getRegion());
}

static VectorType extractVectorTypeFromScalarView(Value v) {
  MemRefType mt = v.getType().cast<MemRefType>();
  return mt.getShape().empty()
             ? VectorType()
             : VectorType::get(mt.getShape(), mt.getElementType());
}

static Value transferReadVector(OpBuilder &builder, Value memref) {
  edsc::ScopedContext scope(builder);
  auto memrefType = memref.getType().cast<MemRefType>();
  if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
    SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
    return vector_transfer_read(vectorType, memref, indices);
  }
  return std_load(memref);
}

static void transferWriteVector(OpBuilder &builder, Value value, Value memref) {
  edsc::ScopedContext scope(builder);
  auto memrefType = memref.getType().cast<MemRefType>();
  if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
    SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
    if (vectorType != value.getType())
      value = vector_broadcast(vectorType, value);
    vector_transfer_write(value, memref, indices);
  } else {
    std_store(value, memref);
  }
}
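
// As an illustration (assuming a memref<4x8xf32> view %A; padding value and
// attributes elided), transferReadVector above emits roughly:
//   %c0 = constant 0 : index
//   %v  = vector.transfer_read %A[%c0, %c0], ...
//           : memref<4x8xf32>, vector<4x8xf32>
// For a 0-D memref<f32>, extractVectorTypeFromScalarView returns a null
// VectorType and the helpers fall back to plain std.load / std.store.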

namespace {
// Transforms scalar operations into their vectorized counterparts,
// while using the provided generic op to map:
//   * Its arguments to transfer reads from the views of the generic op.
//   * linalg.yield ops to transfer writes to the views of the generic op.
class GenericVectorizer {
public:
  GenericVectorizer(OpBuilder &builder, linalg::GenericOp generic)
      : builder(builder), generic(generic) {}

  // Takes a scalar operation and builds its vectorized counterpart or
  // counterparts using the underlying builder.
  // If operands of the scalar operation refer to previously vectorized
  // operations, their vectorized counterparts refer to the corresponding
  // vectorization results.
  void vectorize(Operation &scalarOp) {
    auto yieldOp = dyn_cast<linalg::YieldOp>(scalarOp);
    if (yieldOp) {
      for (auto outputAndMemref :
           llvm::zip(yieldOp.values(), generic.getOutputBuffers())) {
        Value vectorValue = vectorize(std::get<0>(outputAndMemref));
        transferWriteVector(builder, vectorValue, std::get<1>(outputAndMemref));
      }
      return;
    }
    Operation *vectorOp = uncachedVectorize(scalarOp);
    assert(scalarOp.getNumResults() == vectorOp->getNumResults());
    for (auto result :
         llvm::zip(scalarOp.getResults(), vectorOp->getResults())) {
      valueCache[std::get<0>(result)] = std::get<1>(result);
    }
  }

private:
  // Transforms a scalar value into its vectorized counterpart, recursively
  // vectorizing operations as necessary using the underlying builder.
  // Keeps track of previously vectorized values and reuses vectorization
  // results if these values come up again.
  Value vectorize(Value scalarValue) {
    // Don't vectorize values coming from outside the region.
    if (scalarValue.getParentRegion() != &generic.region())
      return scalarValue;
    auto vectorValueIt = valueCache.find(scalarValue);
    if (vectorValueIt != valueCache.end())
      return vectorValueIt->second;

    // If the value is from the region but not in the cache, it must be a
    // block argument.
    auto scalarArg = scalarValue.cast<BlockArgument>();
    assert(scalarArg.getOwner() == &generic.region().front());
    Value vectorArg =
        generic.getInputsAndOutputBuffers()[scalarArg.getArgNumber()];
    Value vectorResult = transferReadVector(builder, vectorArg);
    valueCache[scalarArg] = vectorResult;
    return vectorResult;
  }

  // Return the largest shape among all the given values. Return an empty
  // SmallVector if there are no vector values.
  static SmallVector<int64_t, 4> getLargestShape(ArrayRef<Value> values) {
    SmallVector<int64_t, 4> largestShape;
    int64_t maxSize = 1;
    for (Value value : values) {
      auto vecType = value.getType().dyn_cast<VectorType>();
      if (!vecType)
        continue;
      if (maxSize < vecType.getNumElements()) {
        maxSize = vecType.getNumElements();
        largestShape.assign(vecType.getShape().begin(),
                            vecType.getShape().end());
      }
    }
    return largestShape;
  }

  // If the value's type doesn't have the given shape, broadcast it.
  Value broadcastIfNeeded(Value value, ArrayRef<int64_t> shape) {
    auto vecType = value.getType().dyn_cast<VectorType>();
    if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
      return value;
    auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
                                                     : value.getType());
    return builder.create<vector::BroadcastOp>(
        builder.getInsertionPoint()->getLoc(), newVecType, value);
  }
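
  // For example (illustrative only): if one operand was already vectorized to
  // vector<4x8xf32> and another is still a scalar f32, getLargestShape returns
  // [4, 8] and broadcastIfNeeded turns the scalar into a vector<4x8xf32> via
  // vector.broadcast so the vectorized op has operands of matching shape.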

  // Takes a scalar operation and builds its vectorized counterpart or
  // counterparts using the underlying builder without involving any caches.
  Operation *uncachedVectorize(Operation &base_scalarOp) {
    SmallVector<Value, 4> vectorizedOperands;
    for (Value operand : base_scalarOp.getOperands()) {
      vectorizedOperands.push_back(vectorize(operand));
    }
    SmallVector<int64_t, 4> shape = getLargestShape(vectorizedOperands);
    for (Value &operand : vectorizedOperands)
      operand = broadcastIfNeeded(operand, shape);
    OperationState state(base_scalarOp.getLoc(), base_scalarOp.getName());
    state.addAttributes(base_scalarOp.getAttrs());
    state.addOperands(vectorizedOperands);
    if (shape.empty()) {
      state.addTypes(base_scalarOp.getResultTypes());
    } else {
      SmallVector<VectorType, 4> vectorizedTypes;
      for (auto resultType : base_scalarOp.getResultTypes())
        vectorizedTypes.push_back(VectorType::get(shape, resultType));
      state.addTypes(vectorizedTypes);
    }
    return builder.createOperation(state);
  }

  OpBuilder &builder;
  linalg::GenericOp generic;
  llvm::DenseMap<Value, Value> valueCache;
};
} // namespace

// Replaces the body of an elementwise linalg.generic op by the same scalar
// operations promoted to vector operations, reading the operands with vector
// transfer reads and writing the results back with vector transfer writes.
static void vectorizeElementwise(linalg::GenericOp op, OpBuilder &builder) {
  GenericVectorizer vectorizer(builder, op);
  for (Operation &scalarOp : op.region().front()) {
    vectorizer.vectorize(scalarOp);
  }
}

LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  // All types must have static shape to go to vector.
  for (Value operand : linalgOp.getInputsAndOutputBuffers())
    if (!operand.getType().cast<ShapedType>().hasStaticShape())
      return failure();
  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
      return failure();

  if (isa<linalg::FillOp, linalg::CopyOp>(op))
    return success();
  if (isElementwise(op))
    return success();
  return isContraction(op);
}

void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
  assert(succeeded(vectorizeLinalgOpPrecondition(op)));

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  edsc::ScopedContext scope(builder, op->getLoc());
  // For 0-D memrefs, extractVectorTypeFromScalarView returns a null vector
  // type and the helpers special-case to a scalar load or store.
  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
    // Vectorize fill as a vector.broadcast.
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
    transferWriteVector(builder, fillOp.value(), fillOp.output());
    return;
  }
  if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
    // Vectorize copy as a vector.transfer_read + vector.transfer_write.
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.copy as vector.transfer_read + "
                         "vector.transfer_write: "
                      << *op);
    Value vector = transferReadVector(builder, copyOp.input());
    transferWriteVector(builder, vector, copyOp.output());
    return;
  }

  if (isElementwise(op)) {
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg op as vector.transfer_read + "
                         "vector_op + vector.transfer_write: "
                      << *op);
    return vectorizeElementwise(cast<linalg::GenericOp>(op), builder);
  }

  assert(succeeded(isContraction(op)) && "Expected contraction");

  // Vectorize other ops as vector contraction.
  // TODO: interface.
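  // For illustration, a linalg.matmul on memref<MxKxf32>, memref<KxNxf32> and
  // memref<MxNxf32> views is rewritten roughly into (padding and attributes
  // elided, shapes symbolic for readability):
  //   %a = vector.transfer_read %A[%c0, %c0], ... : ..., vector<MxKxf32>
  //   %b = vector.transfer_read %B[%c0, %c0], ... : ..., vector<KxNxf32>
  //   %c = vector.transfer_read %C[%c0, %c0], ... : ..., vector<MxNxf32>
  //   %d = vector.contract {indexing_maps = ..., iterator_types = ...}
  //          %a, %b, %c
  //   vector.transfer_write %d, %C[%c0, %c0]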
  LLVM_DEBUG(dbgs() << dbgPref
                    << "Rewrite linalg op as vector.contract: " << *op);
  auto linalgOp = cast<linalg::LinalgOp>(op);
  Value viewA = linalgOp.getInput(0);
  Value viewB = linalgOp.getInput(1);
  Value viewC = linalgOp.getOutputBuffer(0);
  VectorType vtA = extractVectorTypeFromScalarView(viewA);
  VectorType vtB = extractVectorTypeFromScalarView(viewB);
  VectorType vtC = extractVectorTypeFromScalarView(viewC);
  Value zero = std_constant_index(0);
  SmallVector<Value, 4> indicesA, indicesB, indicesC;
  if (vtA)
    indicesA = SmallVector<Value, 4>(vtA.getRank(), zero);
  if (vtB)
    indicesB = SmallVector<Value, 4>(vtB.getRank(), zero);
  if (vtC)
    indicesC = SmallVector<Value, 4>(vtC.getRank(), zero);
  Value a = vtA ? vector_transfer_read(vtA, viewA, indicesA).value
                : std_load(viewA, indicesA).value;
  Value b = vtB ? vector_transfer_read(vtB, viewB, indicesB).value
                : std_load(viewB, indicesB).value;
  Value c = vtC ? vector_transfer_read(vtC, viewC, indicesC).value
                : std_load(viewC, indicesC).value;
  Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                              linalgOp.iterator_types());
  if (vtC)
    vector_transfer_write(res, viewC, indicesC);
  else
    std_store(res, viewC, indicesC);
}

/// Check whether there is any interleaved use of any `values` between `firstOp`
/// and `secondOp`. Conservatively return `true` if any op or value is in a
/// different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LLVM_DEBUG(llvm::dbgs()
               << dbgPref << "interleavedUses precondition failed, firstOp: "
               << *firstOp << ", second op: " << *secondOp);
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
        continue;
      LLVM_DEBUG(llvm::dbgs()
                 << dbgPref << " found interleaved op " << *owner
                 << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
      return true;
    }
  }
  return false;
}

/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
static SubViewOp getSubViewUseIfUnique(Value v) {
  SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}
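
/// As an illustration, the read-forwarding pattern below matches IR of roughly
/// this shape (value names are only an example):
///   %alloc = alloc() : memref<AxBxf32>
///   %sv = subview %alloc[...]
///   linalg.fill(%alloc, %pad)
///   linalg.copy(%in, %sv)
///   %v = vector.transfer_read %alloc[...], %pad
/// and forwards the read to the source of the copy:
///   %v = vector.transfer_read %in[...], %pad
/// erasing the fill and copy when no interleaved uses of the buffers remain.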

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {

  // The transfer reads from `viewOrAlloc`, which must come from a ViewOp or
  // an AllocOp.
  Value viewOrAlloc = xferOp.memref();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: ";
  (void)dbgPref;
  LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc);

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView);

  // Find the copy into `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getOutputBuffer(0) != subView)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp);
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp);

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp);
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure padding matches.
  if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
    return failure();
  if (maybeFillOp)
    LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp);

  // `in` is the source that linalg.copy reads; forward the transfer_read to it.
  Value in = copyOp.getInput(0);

  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_read, the attribute must be reset
  // conservatively.
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
      xferOp.permutation_map(), xferOp.padding(), ArrayAttr());

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}
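
/// Symmetrically to the read case, the write-forwarding pattern below matches
/// roughly (value names are only an example):
///   %sv = subview %alloc[...]
///   vector.transfer_write %v, %alloc[...]
///   linalg.copy(%sv, %out)
/// and forwards the write directly to %out, erasing the intermediate copy.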

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // The transfer writes into `viewOrAlloc`, which must come from a ViewOp or
  // an AllocOp.
  Value viewOrAlloc = xferOp.memref();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getInput(0) != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  // `out` is the buffer that the copy writes into; forward the transfer_write
  // to it.
  Value out = copyOp.getOutputBuffer(0);

  // Forward vector.transfer into copy.
  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_write, the attribute must be reset
  // conservatively.
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
      xferOp.permutation_map(), ArrayAttr());

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}

template <class ConvOp, int N>
LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
    ConvOp op, PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MLIRContext *context = op.getContext();
  edsc::ScopedContext scope(rewriter, loc);

  ShapedType inShapeType = op.getInputShapedType(0);
  ShapedType kShapeType = op.getInputShapedType(1);

  ArrayRef<int64_t> inShape = inShapeType.getShape();
  ArrayRef<int64_t> kShape = kShapeType.getShape();

  if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
    return failure();

  SmallVector<AffineExpr, 4> mapping;
  SmallVector<int64_t, 4> vectorDims;
  // Fail to apply when the size of a non-vectorized dimension is not 1.
  for (unsigned i = 0; i < N; i++) {
    if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
      return failure();

    if (mask[i] && inShape[i] != kShape[i])
      return failure();

    if (mask[i]) {
      mapping.push_back(getAffineDimExpr(i, context));
      vectorDims.push_back(inShape[i]);
    }
  }

  Value input = op.getInput(0);
  Value kernel = op.getInput(1);
  Value output = op.getOutputBuffer(0);

  unsigned rank = inShapeType.getRank();
  unsigned numDims = mapping.size();
  Type elemType = inShapeType.getElementType();

  auto map = AffineMap::get(rank, 0, mapping, context);
  SmallVector<Value, 4> zeros(rank, std_constant_index(0));
  auto vecType = VectorType::get(vectorDims, elemType);

  auto inputVec = vector_transfer_read(vecType, input, zeros, map);
  auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);

  auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType));

  std::array<AffineMap, 3> indexingMaps{
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::get(numDims, 0, {}, context)};

  std::vector<StringRef> iteratorTypes(numDims, "reduction");

  auto result = rewriter.create<vector::ContractionOp>(
      loc, inputVec, kernelVec, acc,
      rewriter.getAffineMapArrayAttr(indexingMaps),
      rewriter.getStrArrayAttr(iteratorTypes));

  rewriter.create<StoreOp>(loc, result, output, ValueRange(zeros));
  rewriter.eraseOp(op);
  return success();
}
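
// For example (illustrative only), ConvOpVectorization<ConvWOp, 1> with
// mask = {true} reduces a 1-D convolution whose input and kernel windows have
// the same static size W to roughly:
//   %in  = vector.transfer_read %input[%c0]  : memref<Wxf32>, vector<Wxf32>
//   %ker = vector.transfer_read %kernel[%c0] : memref<Wxf32>, vector<Wxf32>
//   %r   = vector.contract ... %in, %ker, %zero : ... into f32
//   store %r, %output[%c0]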

using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;

/// Inserts tiling, promotion and vectorization patterns for ConvOp
/// conversion into corresponding pattern lists.
template <typename ConvOp, unsigned N>
static void
populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
                              OwningRewritePatternList &promotionPatterns,
                              OwningRewritePatternList &vectorizationPatterns,
                              ArrayRef<int64_t> tileSizes,
                              MLIRContext *context) {
  if (tileSizes.size() < N)
    return;

  constexpr static StringRef kTiledMarker = "TILED";
  constexpr static StringRef kPromotedMarker = "PROMOTED";
  tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
      context, LinalgTilingOptions().setTileSizes(tileSizes),
      LinalgMarker({}, Identifier::get(kTiledMarker, context)));

  promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>(
      context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgMarker(Identifier::get(kTiledMarker, context),
                   Identifier::get(kPromotedMarker, context)));

  SmallVector<bool, 4> mask(N);
  int offset = tileSizes.size() - N;
  std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
                 [](int64_t i) -> bool { return i > 1; });

  vectorizationPatterns.insert<ConvOpVectorization<ConvOp, N>>(context, mask);
}

void mlir::linalg::populateConvVectorizationPatterns(
    MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
    ArrayRef<int64_t> tileSizes) {
  OwningRewritePatternList tiling, promotion, vectorization;
  populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
                                            tileSizes, context);

  populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
                                             tileSizes, context);

  populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
                                               tileSizes, context);

  populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
                                               tileSizes, context);

  populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvNDHWCOp, 5>(
      tiling, promotion, vectorization, tileSizes, context);

  populateVectorizationPatterns<ConvNCDHWOp, 5>(
      tiling, promotion, vectorization, tileSizes, context);

  patterns.push_back(std::move(tiling));
  patterns.push_back(std::move(promotion));
  patterns.push_back(std::move(vectorization));
}
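
// A minimal usage sketch (hypothetical driver code, not part of this file):
// the three staged pattern lists filled above are meant to be applied in
// order, e.g.:
//   SmallVector<OwningRewritePatternList, 4> stages;
//   linalg::populateConvVectorizationPatterns(ctx, stages, /*tileSizes=*/{32});
//   for (auto &stage : stages)
//     applyPatternsAndFoldGreedily(funcOp, stage);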