//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using llvm::dbgs;

#define DEBUG_TYPE "linalg-vectorization"

static bool hasMultiplyAddBody(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
    return false;

  using mlir::matchers::m_Val;
  auto a = m_Val(r.getArgument(0));
  auto b = m_Val(r.getArgument(1));
  auto c = m_Val(r.getArgument(2));
  // TODO: Update this detection once we have matcher support for specifying
  // that any permutation of operands matches.
  auto pattern1 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
  auto pattern2 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
  auto pattern3 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
  auto pattern4 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
  auto pattern5 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
  auto pattern6 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
  auto pattern7 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
  auto pattern8 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
  return pattern1.match(&r.front().back()) ||
         pattern2.match(&r.front().back()) ||
         pattern3.match(&r.front().back()) ||
         pattern4.match(&r.front().back()) ||
         pattern5.match(&r.front().back()) ||
         pattern6.match(&r.front().back()) ||
         pattern7.match(&r.front().back()) || pattern8.match(&r.front().back());
}

// TODO: Should be Tablegen'd from a single source that generates the op itself.
static LogicalResult isContraction(Operation *op) {
  // TODO: interface for named ops.
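  // Fast path: these named ops are contractions by construction. A generic op
  // is recognized structurally below: two inputs, one output, projected
  // permutation indexing maps, and a multiply-add region.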
  if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatvecOp,
          linalg::VecmatOp, linalg::DotOp>(op))
    return success();

  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp)
    return failure();

  auto mapRange = genericOp.indexing_maps().getAsValueRange<AffineMapAttr>();
  return success(
      genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
      llvm::all_of(mapRange,
                   [](AffineMap m) { return m.isProjectedPermutation(); }) &&
      hasMultiplyAddBody(genericOp.region()));
}

LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  // All types must be statically shaped to go to vector.
  for (Value operand : linalgOp.getInputsAndOutputBuffers())
    if (!operand.getType().cast<ShapedType>().hasStaticShape())
      return failure();
  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
      return failure();

  if (isa<linalg::FillOp, linalg::CopyOp>(op))
    return success();

  return isContraction(op);
}

void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
  assert(succeeded(vectorizeLinalgOpPrecondition(op)));

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  edsc::ScopedContext scope(builder, op->getLoc());
  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
    // Vectorize fill as a vector.broadcast.
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
    Value memref = vector_type_cast(fillOp.getOutputBuffer(0));
    Value dst = std_load(memref);
    Value res = vector_broadcast(dst.getType(), fillOp.value());
    std_store(res, memref);
    return;
  }

  // For 0-D memrefs, return a null VectorType; callers special-case these
  // views to scalar load/store.
  auto extractVectorTypeFromScalarView = [](Value v) {
    MemRefType mt = v.getType().cast<MemRefType>();
    return mt.getShape().empty()
               ? VectorType()
               : VectorType::get(mt.getShape(), mt.getElementType());
  };

  if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
    // Vectorize copy as a vector.transfer_read + vector.transfer_write.
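    // The whole input view is read into a vector and written back out,
    // honoring optional input/output permutation maps. 0-D views, for which
    // extractVectorTypeFromScalarView returns a null VectorType, fall back to
    // scalar load/store.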
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.copy as vector.transfer_read + "
                         "vector.transfer_write: "
                      << *op);
    Value zero = std_constant_index(0);
    Value viewInput = copyOp.input();
    Value viewOutput = copyOp.output();
    Value vector;
    if (VectorType inputType = extractVectorTypeFromScalarView(viewInput)) {
      SmallVector<Value, 4> indicesInput(inputType.getRank(), zero);
      if (copyOp.inputPermutation())
        vector = vector_transfer_read(
            extractVectorTypeFromScalarView(viewInput), viewInput,
            indicesInput, copyOp.inputPermutation().getValue());
      else
        vector =
            vector_transfer_read(extractVectorTypeFromScalarView(viewInput),
                                 viewInput, indicesInput);
    } else {
      vector = std_load(viewInput).value;
    }
    if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) {
      SmallVector<Value, 4> indicesOutput(outputType.getRank(), zero);
      if (copyOp.outputPermutation())
        vector_transfer_write(vector, viewOutput, indicesOutput,
                              copyOp.outputPermutation().getValue());
      else
        vector_transfer_write(vector, viewOutput, indicesOutput);
    } else {
      std_store(vector, viewOutput);
    }
    return;
  }

  assert(succeeded(isContraction(op)) && "Expected contraction");

  // Vectorize other ops as vector contraction.
  // TODO: interface.
  LLVM_DEBUG(dbgs() << dbgPref
                    << "Rewrite linalg op as vector.contract: " << *op);
  auto linalgOp = cast<linalg::LinalgOp>(op);
  Value viewA = linalgOp.getInput(0);
  Value viewB = linalgOp.getInput(1);
  Value viewC = linalgOp.getOutputBuffer(0);
  VectorType vtA = extractVectorTypeFromScalarView(viewA);
  VectorType vtB = extractVectorTypeFromScalarView(viewB);
  VectorType vtC = extractVectorTypeFromScalarView(viewC);
  Value zero = std_constant_index(0);
  SmallVector<Value, 4> indicesA, indicesB, indicesC;
  if (vtA)
    indicesA = SmallVector<Value, 4>(vtA.getRank(), zero);
  if (vtB)
    indicesB = SmallVector<Value, 4>(vtB.getRank(), zero);
  if (vtC)
    indicesC = SmallVector<Value, 4>(vtC.getRank(), zero);
  Value a = vtA ? vector_transfer_read(vtA, viewA, indicesA).value
                : std_load(viewA, indicesA).value;
  Value b = vtB ? vector_transfer_read(vtB, viewB, indicesB).value
                : std_load(viewB, indicesB).value;
  Value c = vtC ? vector_transfer_read(vtC, viewC, indicesC).value
                : std_load(viewC, indicesC).value;
  Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                              linalgOp.iterator_types());
  if (vtC)
    vector_transfer_write(res, viewC, indicesC);
  else
    std_store(res, viewC, indicesC);
}

/// Check whether there is any interleaved use of any `values` between `firstOp`
/// and `secondOp`. Conservatively return `true` if any op or value is in a
/// different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LLVM_DEBUG(llvm::dbgs()
               << dbgPref << "interleavedUses precondition failed, firstOp: "
               << *firstOp << ", second op: " << *secondOp);
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
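      // Within the same block, a use strictly before `firstOp` or strictly
      // after `secondOp` cannot interleave with the pair.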
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
        continue;
      LLVM_DEBUG(llvm::dbgs()
                 << dbgPref << " found interleaved op " << *owner
                 << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
      return true;
    }
  }
  return false;
}

/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
static SubViewOp getSubViewUseIfUnique(Value v) {
  SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.memref();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: ";
  (void)dbgPref;
  LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc);

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView);

  // Find the copy into `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getOutputBuffer(0) != subView)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp);
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp);

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp);
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure the transfer's padding value matches the fill value.
  if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
    return failure();
  if (maybeFillOp)
    LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp);

  // `in` is the subview that linalg.copy reads from; the forwarded
  // vector.transfer_read will read from it directly.
  Value in = copyOp.getInput(0);

  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_read, the attribute must be reset
  // conservatively.
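  // Rebuild the transfer_read directly on the original buffer: same indices,
  // permutation map and padding, but with an empty `masked` ArrayAttr.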
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
      xferOp.permutation_map(), xferOp.padding(), ArrayAttr());

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.memref();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getInput(0) != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  // `out` is the buffer that linalg.copy writes into; the forwarded
  // vector.transfer_write will write to it directly.
  Value out = copyOp.getOutputBuffer(0);

  // Forward vector.transfer into copy.
  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_write, the attribute must be reset
  // conservatively.
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
      xferOp.permutation_map(), ArrayAttr());

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}

template <class ConvOp, int N>
LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
    ConvOp op, PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MLIRContext *context = op.getContext();
  edsc::ScopedContext scope(rewriter, loc);

  ShapedType inShapeType = op.getInputShapedType(0);
  ShapedType kShapeType = op.getInputShapedType(1);

  ArrayRef<int64_t> inShape = inShapeType.getShape();
  ArrayRef<int64_t> kShape = kShapeType.getShape();

  if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
    return failure();

  SmallVector<AffineExpr, 4> mapping;
  // Fail to apply when the size of a non-vectorized dimension is not 1, or
  // when the size of a vectorized dimension is not `tileSize`.
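  // Each vectorized dimension contributes a dim expression to `mapping`,
  // which becomes the permutation map of the vector transfer reads below.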
  for (unsigned i = 0; i < N; i++) {
    if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
      return failure();
    if (mask[i] && (inShape[i] != tileSize || kShape[i] != tileSize))
      return failure();

    if (mask[i])
      mapping.push_back(getAffineDimExpr(i, context));
  }

  Value input = op.getInput(0);
  Value kernel = op.getInput(1);
  Value output = op.getOutputBuffer(0);

  unsigned rank = inShapeType.getRank();
  unsigned numDims = mapping.size();
  Type elemType = inShapeType.getElementType();

  auto map = AffineMap::get(rank, 0, mapping, context);
  SmallVector<Value, 4> zeros(rank, std_constant_index(0));
  auto vecType =
      VectorType::get(SmallVector<int64_t, 4>(numDims, tileSize), elemType);

  auto inputVec = vector_transfer_read(vecType, input, zeros, map);
  auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);

  auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType));

  std::array<AffineMap, 3> indexingMaps{
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::get(numDims, 0, {}, context)};

  std::vector<StringRef> iteratorTypes(numDims, "reduction");

  auto result = rewriter.create<vector::ContractionOp>(
      loc, inputVec, kernelVec, acc,
      rewriter.getAffineMapArrayAttr(indexingMaps),
      rewriter.getStrArrayAttr(iteratorTypes));

  rewriter.create<StoreOp>(loc, result, output, ValueRange(zeros));
  rewriter.eraseOp(op);
  return success();
}

using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;

/// Inserts tiling, promotion and vectorization patterns for a ConvOp
/// conversion into the corresponding pattern lists.
template <typename ConvOp, unsigned N>
static void
populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
                              OwningRewritePatternList &promotionPatterns,
                              OwningRewritePatternList &vectorizationPatterns,
                              ArrayRef<int64_t> tileSizes,
                              MLIRContext *context) {
  constexpr static StringRef kTiledMarker = "TILED";
  constexpr static StringRef kPromotedMarker = "PROMOTED";
  tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
      context, LinalgTilingOptions().setTileSizes(tileSizes),
      LinalgMarker({}, Identifier::get(kTiledMarker, context)));

  promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>(
      context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgMarker(Identifier::get(kTiledMarker, context),
                   Identifier::get(kPromotedMarker, context)));

  SmallVector<bool, 4> mask(N);
  int offset = tileSizes.size() - N;
  std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
                 [](int64_t i) -> bool { return i != ConvOpConst::noTile; });

  vectorizationPatterns.insert<ConvOpVectorization<ConvOp, N>>(context, mask);
}

void mlir::linalg::populateConvVectorizationPatterns(
    MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns) {
  const int64_t tileSize = ConvOpConst::tileSize;
  const int64_t noTile = ConvOpConst::noTile;
  auto makeTileSizes = [&](unsigned numNoTile, unsigned numTile) {
    SmallVector<int64_t, 10> result(numNoTile, noTile);
    result.append(numTile, tileSize);
    return result;
  };

  OwningRewritePatternList tiling, promotion, vectorization;
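  // One instantiation per named conv op. Within each tile-size vector, the
  // trailing entries that differ from `noTile` select the dimensions that
  // ConvOpVectorization will vectorize.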
  populateVectorizationPatterns<ConvWOp, 1>(
      tiling, promotion, vectorization,
      makeTileSizes(/*numNoTile=*/1, /*numTile=*/1), context);

  populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
                                              makeTileSizes(3, 2), context);

  populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
                                              makeTileSizes(3, 2), context);

  populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
                                             makeTileSizes(2, 2), context);

  populateVectorizationPatterns<ConvNHWCOp, 4>(
      tiling, promotion, vectorization, makeTileSizes(4, 3), context);

  populateVectorizationPatterns<ConvNCHWOp, 4>(
      tiling, promotion, vectorization, makeTileSizes(4, 3), context);

  populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
                                              makeTileSizes(3, 3), context);

  populateVectorizationPatterns<ConvNDHWCOp, 5>(
      tiling, promotion, vectorization, makeTileSizes(5, 4), context);

  populateVectorizationPatterns<ConvNCDHWOp, 5>(
      tiling, promotion, vectorization, makeTileSizes(5, 4), context);

  patterns.push_back(std::move(tiling));
  patterns.push_back(std::move(promotion));
  patterns.push_back(std::move(vectorization));
}
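
// Usage sketch (hypothetical; not part of this file): a driver pass might
// apply the three stages in order, e.g.:
//
//   SmallVector<OwningRewritePatternList, 4> stages;
//   populateConvVectorizationPatterns(context, stages);
//   for (OwningRewritePatternList &stage : stages)
//     applyPatternsAndFoldGreedily(funcOp, stage);
//
// where `funcOp` is the function being vectorized.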