//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using llvm::dbgs;

#define DEBUG_TYPE "linalg-vectorization"

static bool hasMultiplyAddBody(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
    return false;

  using mlir::matchers::m_Val;
  auto a = m_Val(r.getArgument(0));
  auto b = m_Val(r.getArgument(1));
  auto c = m_Val(r.getArgument(2));
  // TODO: Update this detection once we have matcher support for specifying
  // that any permutation of operands matches.
  auto pattern1 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
  auto pattern2 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
  auto pattern3 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
  auto pattern4 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
  auto pattern5 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
  auto pattern6 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
  auto pattern7 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
  auto pattern8 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
  return pattern1.match(&r.front().back()) ||
         pattern2.match(&r.front().back()) ||
         pattern3.match(&r.front().back()) ||
         pattern4.match(&r.front().back()) ||
         pattern5.match(&r.front().back()) ||
         pattern6.match(&r.front().back()) ||
         pattern7.match(&r.front().back()) || pattern8.match(&r.front().back());
}

// TODO: Should be Tablegen'd from a single source that generates the op itself.
static LogicalResult isContraction(Operation *op) {
  // TODO: interface for named ops.
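  // A named contraction (matmul and friends) is recognized directly; a
  // generic op qualifies when its body is a multiply-add over the block
  // arguments, e.g. (illustrative region only; operand order may be
  // permuted, see hasMultiplyAddBody above):
  //   ^bb0(%a: f32, %b: f32, %c: f32):
  //     %0 = mulf %a, %b : f32
  //     %1 = addf %0, %c : f32
  //     linalg.yield %1 : f32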
  if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatvecOp,
          linalg::VecmatOp, linalg::DotOp>(op))
    return success();

  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp)
    return failure();

  auto mapRange = genericOp.indexing_maps().getAsValueRange<AffineMapAttr>();
  return success(
      genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
      llvm::all_of(mapRange,
                   [](AffineMap m) { return m.isProjectedPermutation(); }) &&
      hasMultiplyAddBody(genericOp.region()));
}

LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  // All shapes must be static to vectorize.
  for (Value operand : linalgOp.getInputsAndOutputBuffers())
    if (!operand.getType().cast<ShapedType>().hasStaticShape())
      return failure();
  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
      return failure();

  if (isa<linalg::FillOp, linalg::CopyOp>(op))
    return success();

  return isContraction(op);
}

void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
  assert(succeeded(vectorizeLinalgOpPrecondition(op)));

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  edsc::ScopedContext scope(builder, op->getLoc());
  // In the case of 0-D memrefs, return null and special case to scalar load or
  // store later.
  auto extractVectorTypeFromScalarView = [](Value v) {
    MemRefType mt = v.getType().cast<MemRefType>();
    return mt.getShape().empty()
               ? VectorType()
               : VectorType::get(mt.getShape(), mt.getElementType());
  };
  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
    // Vectorize fill as a vector.broadcast.
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
    Value viewOutput = fillOp.output();
    if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) {
      auto vecType =
          VectorType::get(fillOp.getOutputBufferType(0).getShape(),
                          fillOp.getOutputBufferType(0).getElementType());
      Value vector = vector_broadcast(vecType, fillOp.value());
      Value zero = std_constant_index(0);
      SmallVector<Value, 4> indicesOutput(outputType.getRank(), zero);
      vector_transfer_write(vector, viewOutput, indicesOutput);
    } else {
      std_store(fillOp.value(), viewOutput);
    }
    return;
  }
  if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
    // Vectorize copy as a vector.transfer_read + vector.transfer_write pair.
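    // Schematically (illustrative, for a statically shaped memref without
    // permutations):
    //   linalg.copy(%in, %out)
    // becomes:
    //   %v = vector.transfer_read %in[0, ..., 0]
    //   vector.transfer_write %v, %out[0, ..., 0]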
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.copy as vector.transfer_read + "
                         "vector.transfer_write: "
                      << *op);
    Value zero = std_constant_index(0);
    Value viewInput = copyOp.input();
    Value viewOutput = copyOp.output();
    Value vector;
    if (VectorType inputType = extractVectorTypeFromScalarView(viewInput)) {
      SmallVector<Value, 4> indicesInput(inputType.getRank(), zero);
      if (copyOp.inputPermutation())
        vector = vector_transfer_read(
            extractVectorTypeFromScalarView(viewInput), viewInput,
            indicesInput, copyOp.inputPermutation().getValue());
      else
        vector =
            vector_transfer_read(extractVectorTypeFromScalarView(viewInput),
                                 viewInput, indicesInput);
    } else {
      vector = std_load(viewInput).value;
    }
    if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) {
      SmallVector<Value, 4> indicesOutput(outputType.getRank(), zero);
      if (copyOp.outputPermutation())
        vector_transfer_write(vector, viewOutput, indicesOutput,
                              copyOp.outputPermutation().getValue());
      else
        vector_transfer_write(vector, viewOutput, indicesOutput);
    } else {
      std_store(vector, viewOutput);
    }
    return;
  }

  assert(succeeded(isContraction(op)) && "Expected contraction");

  // Vectorize other ops as vector contraction.
  // TODO: interface.
  LLVM_DEBUG(dbgs() << dbgPref
                    << "Rewrite linalg op as vector.contract: " << *op);
  auto linalgOp = cast<linalg::LinalgOp>(op);
  Value viewA = linalgOp.getInput(0);
  Value viewB = linalgOp.getInput(1);
  Value viewC = linalgOp.getOutputBuffer(0);
  VectorType vtA = extractVectorTypeFromScalarView(viewA);
  VectorType vtB = extractVectorTypeFromScalarView(viewB);
  VectorType vtC = extractVectorTypeFromScalarView(viewC);
  Value zero = std_constant_index(0);
  SmallVector<Value, 4> indicesA, indicesB, indicesC;
  if (vtA)
    indicesA = SmallVector<Value, 4>(vtA.getRank(), zero);
  if (vtB)
    indicesB = SmallVector<Value, 4>(vtB.getRank(), zero);
  if (vtC)
    indicesC = SmallVector<Value, 4>(vtC.getRank(), zero);
  Value a = vtA ? vector_transfer_read(vtA, viewA, indicesA).value
                : std_load(viewA, indicesA).value;
  Value b = vtB ? vector_transfer_read(vtB, viewB, indicesB).value
                : std_load(viewB, indicesB).value;
  Value c = vtC ? vector_transfer_read(vtC, viewC, indicesC).value
                : std_load(viewC, indicesC).value;
  Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                              linalgOp.iterator_types());
  if (vtC)
    vector_transfer_write(res, viewC, indicesC);
  else
    std_store(res, viewC, indicesC);
}

/// Check whether there is any interleaved use of any `values` between
/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
/// is in a different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LLVM_DEBUG(llvm::dbgs()
               << dbgPref << "interleavedUses precondition failed, firstOp: "
               << *firstOp << ", second op: " << *secondOp);
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
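      // A use in the same block that is strictly before `firstOp` or strictly
      // after `secondOp` cannot be interleaved between the two ops.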
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
        continue;
      LLVM_DEBUG(llvm::dbgs()
                 << dbgPref << " found interleaved op " << *owner
                 << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
      return true;
    }
  }
  return false;
}

/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
static SubViewOp getSubViewUseIfUnique(Value v) {
  SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {

  // The transfer must read from a `view` or `alloc`.
  Value viewOrAlloc = xferOp.memref();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: ";
  (void)dbgPref;
  LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc);

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView);

  // Find the copy into `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getOutputBuffer(0) != subView)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp);
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp);

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp);
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure padding matches.
  if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
    return failure();
  if (maybeFillOp)
    LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp);

  // `in` is the (sub)view that linalg.copy reads from; forward the transfer to
  // read from it directly.
  Value in = copyOp.getInput(0);

  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_read, the attribute must be reset
  // conservatively.
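  // Schematically (illustrative), the forwarding rewrites:
  //   %subView = subview %viewOrAlloc ...
  //   linalg.fill(%viewOrAlloc, %pad)
  //   linalg.copy(%in, %subView)
  //   %v = vector.transfer_read %viewOrAlloc[...], %pad {masked = [...]}
  // into a read straight from the original source, with `masked` dropped:
  //   %v = vector.transfer_read %in[...], %pad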
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
      xferOp.permutation_map(), xferOp.padding(), ArrayAttr());

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.memref();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getInput(0) != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  // `out` is the buffer that linalg.copy writes into; forward the transfer to
  // write into it directly.
  Value out = copyOp.getOutputBuffer(0);

  // Forward vector.transfer into copy.
  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_write, the attribute must be reset
  // conservatively.
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
      xferOp.permutation_map(), ArrayAttr());

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}

template <class ConvOp, int N>
LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
    ConvOp op, PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MLIRContext *context = op.getContext();
  edsc::ScopedContext scope(rewriter, loc);

  ShapedType inShapeType = op.getInputShapedType(0);
  ShapedType kShapeType = op.getInputShapedType(1);

  ArrayRef<int64_t> inShape = inShapeType.getShape();
  ArrayRef<int64_t> kShape = kShapeType.getShape();

  if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
    return failure();

  SmallVector<AffineExpr, 4> mapping;
  SmallVector<int64_t, 4> vectorDims;
  // Fail to apply when the size of a non-vectorized dimension is not 1.
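  // For example (illustrative): with N = 3 and mask = {false, true, true},
  // dimension 0 of both input and kernel must have size 1, while dimensions 1
  // and 2 are vectorized and must agree in size between input and kernel.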
  for (unsigned i = 0; i < N; i++) {
    if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
      return failure();

    if (mask[i] && inShape[i] != kShape[i])
      return failure();

    if (mask[i]) {
      mapping.push_back(getAffineDimExpr(i, context));
      vectorDims.push_back(inShape[i]);
    }
  }

  Value input = op.getInput(0);
  Value kernel = op.getInput(1);
  Value output = op.getOutputBuffer(0);

  unsigned rank = inShapeType.getRank();
  unsigned numDims = mapping.size();
  Type elemType = inShapeType.getElementType();

  auto map = AffineMap::get(rank, 0, mapping, context);
  SmallVector<Value, 4> zeros(rank, std_constant_index(0));
  auto vecType = VectorType::get(vectorDims, elemType);

  auto inputVec = vector_transfer_read(vecType, input, zeros, map);
  auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);

  auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType));

  std::array<AffineMap, 3> indexingMaps{
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::get(numDims, 0, {}, context)};

  std::vector<StringRef> iteratorTypes(numDims, "reduction");

  auto result = rewriter.create<vector::ContractionOp>(
      loc, inputVec, kernelVec, acc,
      rewriter.getAffineMapArrayAttr(indexingMaps),
      rewriter.getStrArrayAttr(iteratorTypes));

  rewriter.create<StoreOp>(loc, result, output, ValueRange(zeros));
  rewriter.eraseOp(op);
  return success();
}

using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;

/// Inserts tiling, promotion and vectorization patterns for ConvOp
/// conversion into corresponding pattern lists.
template <typename ConvOp, unsigned N>
static void
populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
                              OwningRewritePatternList &promotionPatterns,
                              OwningRewritePatternList &vectorizationPatterns,
                              ArrayRef<int64_t> tileSizes,
                              MLIRContext *context) {
  if (tileSizes.size() < N)
    return;

  constexpr static StringRef kTiledMarker = "TILED";
  constexpr static StringRef kPromotedMarker = "PROMOTED";
  tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
      context, LinalgTilingOptions().setTileSizes(tileSizes),
      LinalgMarker({}, Identifier::get(kTiledMarker, context)));

  promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>(
      context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgMarker(Identifier::get(kTiledMarker, context),
                   Identifier::get(kPromotedMarker, context)));

  // Vectorize along the innermost N tile sizes: a dimension is vectorized
  // (mask[i] == true) iff its tile size is greater than 1.
  SmallVector<bool, 4> mask(N);
  int offset = tileSizes.size() - N;
  std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
                 [](int64_t i) -> bool { return i > 1; });

  vectorizationPatterns.insert<ConvOpVectorization<ConvOp, N>>(context, mask);
}

void mlir::linalg::populateConvVectorizationPatterns(
    MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
    ArrayRef<int64_t> tileSizes) {
  OwningRewritePatternList tiling, promotion, vectorization;
  populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
                                            tileSizes, context);

  populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
                                             tileSizes, context);

  populateVectorizationPatterns<ConvNHWCOp, 4>(
      tiling, promotion, vectorization, tileSizes, context);

  populateVectorizationPatterns<ConvNCHWOp, 4>(
      tiling, promotion, vectorization, tileSizes, context);

  populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvNDHWCOp, 5>(
      tiling, promotion, vectorization, tileSizes, context);

  populateVectorizationPatterns<ConvNCDHWOp, 5>(
      tiling, promotion, vectorization, tileSizes, context);

  patterns.push_back(std::move(tiling));
  patterns.push_back(std::move(promotion));
  patterns.push_back(std::move(vectorization));
}
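// Usage sketch (illustrative; `funcOp` and the driver loop are assumptions,
// not part of this file): the three staged pattern lists filled in above are
// meant to be applied one after another, e.g. from a test pass:
//   SmallVector<OwningRewritePatternList, 4> stage;
//   linalg::populateConvVectorizationPatterns(ctx, stage, tileSizes);
//   for (OwningRewritePatternList &p : stage)
//     applyPatternsAndFoldGreedily(funcOp, p);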