//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using llvm::dbgs;

#define DEBUG_TYPE "linalg-vectorization"

/// Return true if the region contains exactly one block whose three operations
/// form a multiply-accumulate: a float or integer multiply of the first two
/// block arguments, added (in either operand order) to the third argument, and
/// yielded.
static bool hasMultiplyAddBody(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
    return false;

  using mlir::matchers::m_Val;
  auto a = m_Val(r.getArgument(0));
  auto b = m_Val(r.getArgument(1));
  auto c = m_Val(r.getArgument(2));
  // TODO: Update this detection once we have matcher support for specifying
  // that any permutation of operands matches.
  auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
  auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
  auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
  auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
  auto pattern5 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
  auto pattern6 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
  auto pattern7 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
  auto pattern8 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
  Operation *terminator = &r.front().back();
  return pattern1.match(terminator) || pattern2.match(terminator) ||
         pattern3.match(terminator) || pattern4.match(terminator) ||
         pattern5.match(terminator) || pattern6.match(terminator) ||
         pattern7.match(terminator) || pattern8.match(terminator);
}
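
// For reference, `hasMultiplyAddBody` accepts region bodies of the following
// shape (schematic IR; the multiply and add operands may commute, and the
// muli/addi variants are matched for integer element types):
//   ^bb0(%a: f32, %b: f32, %c: f32):
//     %0 = mulf %a, %b : f32
//     %1 = addf %0, %c : f32
//     linalg.yield %1 : f32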

// TODO: Should be Tablegen'd from a single source that generates the op
// itself.
static LogicalResult isContraction(Operation *op) {
  // TODO: interface for named ops.
  if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatvecOp,
          linalg::DotOp>(op))
    return success();

  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp)
    return failure();

  auto mapRange =
      genericOp.indexing_maps().getAsRange<AffineMapAttr, AffineMap>();

  return success(
      genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
      llvm::all_of(mapRange,
                   [](AffineMap m) { return m.isProjectedPermutation(); }) &&
      hasMultiplyAddBody(genericOp.region()));
}

LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  // All types must have static shape to go to vector.
  for (Value operand : linalgOp.getInputsAndOutputBuffers())
    if (!operand.getType().cast<ShapedType>().hasStaticShape())
      return failure();
  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
      return failure();

  if (isa<linalg::FillOp, linalg::CopyOp>(op))
    return success();

  return isContraction(op);
}

void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
  assert(succeeded(vectorizeLinalgOpPrecondition(op)));

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  edsc::ScopedContext scope(builder, op->getLoc());
  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
    // Vectorize fill as a vector.broadcast of the fill value, stored through
    // a vector.type_cast of the output buffer.
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
    Value memref = vector_type_cast(fillOp.getOutputBuffer(0));
    Value dst = std_load(memref);
    Value res = vector_broadcast(dst.getType(), fillOp.value());
    std_store(res, memref);
    return;
  }

  // For 0-D memrefs, return a null type and special-case to scalar load or
  // store later.
  auto extractVectorTypeFromScalarView = [](Value v) {
    MemRefType mt = v.getType().cast<MemRefType>();
    return mt.getShape().empty()
               ? VectorType()
               : VectorType::get(mt.getShape(), mt.getElementType());
  };
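
  // The copy case below becomes a vector.transfer_read + vector.transfer_write
  // pair. Schematically, for a statically shaped rank-2 buffer (illustrative
  // IR; permutation maps are elided, and the 0-D case degenerates to a scalar
  // load/store instead):
  //   linalg.copy(%in, %out) : memref<4x8xf32>, memref<4x8xf32>
  // becomes:
  //   %v = vector.transfer_read %in[%c0, %c0], %cst
  //          : memref<4x8xf32>, vector<4x8xf32>
  //   vector.transfer_write %v, %out[%c0, %c0]
  //          : vector<4x8xf32>, memref<4x8xf32>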
  if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
    // Vectorize copy as a vector.transfer_read + vector.transfer_write.
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.copy as vector.transfer_read + "
                         "vector.transfer_write: "
                      << *op);
    Value zero = std_constant_index(0);
    Value viewInput = copyOp.input();
    Value viewOutput = copyOp.output();
    Value vector;
    if (VectorType inputType = extractVectorTypeFromScalarView(viewInput)) {
      SmallVector<Value, 4> indicesInput(inputType.getRank(), zero);
      if (copyOp.inputPermutation())
        vector = vector_transfer_read(inputType, viewInput, indicesInput,
                                      copyOp.inputPermutation().getValue());
      else
        vector = vector_transfer_read(inputType, viewInput, indicesInput);
    } else {
      vector = std_load(viewInput).value;
    }
    if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) {
      SmallVector<Value, 4> indicesOutput(outputType.getRank(), zero);
      if (copyOp.outputPermutation())
        vector_transfer_write(vector, viewOutput, indicesOutput,
                              copyOp.outputPermutation().getValue());
      else
        vector_transfer_write(vector, viewOutput, indicesOutput);
    } else {
      std_store(vector, viewOutput);
    }
    return;
  }

  assert(succeeded(isContraction(op)) && "Expected contraction");
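
  // Schematically, the contraction path below turns, e.g., a linalg.matmul on
  // static memrefs into a single vector.contract (illustrative IR; indexing
  // maps, iterator types and the padding value are elided):
  //   %a = vector.transfer_read %A[%c0, %c0], %cst
  //          : memref<4x8xf32>, vector<4x8xf32>
  //   %b = vector.transfer_read %B[%c0, %c0], %cst
  //          : memref<8x16xf32>, vector<8x16xf32>
  //   %c = vector.transfer_read %C[%c0, %c0], %cst
  //          : memref<4x16xf32>, vector<4x16xf32>
  //   %d = vector.contract {...} %a, %b, %c
  //          : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
  //   vector.transfer_write %d, %C[%c0, %c0]
  //          : vector<4x16xf32>, memref<4x16xf32>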
  // Vectorize other ops as vector contraction.
  // TODO: interface.
  LLVM_DEBUG(dbgs() << dbgPref
                    << "Rewrite linalg op as vector.contract: " << *op);
  auto linalgOp = cast<linalg::LinalgOp>(op);
  Value viewA = linalgOp.getInput(0);
  Value viewB = linalgOp.getInput(1);
  Value viewC = linalgOp.getOutputBuffer(0);
  VectorType vtA = extractVectorTypeFromScalarView(viewA);
  VectorType vtB = extractVectorTypeFromScalarView(viewB);
  VectorType vtC = extractVectorTypeFromScalarView(viewC);
  Value zero = std_constant_index(0);
  SmallVector<Value, 4> indicesA, indicesB, indicesC;
  if (vtA)
    indicesA = SmallVector<Value, 4>(vtA.getRank(), zero);
  if (vtB)
    indicesB = SmallVector<Value, 4>(vtB.getRank(), zero);
  if (vtC)
    indicesC = SmallVector<Value, 4>(vtC.getRank(), zero);
  Value a = vtA ? vector_transfer_read(vtA, viewA, indicesA).value
                : std_load(viewA, indicesA).value;
  Value b = vtB ? vector_transfer_read(vtB, viewB, indicesB).value
                : std_load(viewB, indicesB).value;
  Value c = vtC ? vector_transfer_read(vtC, viewC, indicesC).value
                : std_load(viewC, indicesC).value;
  Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                              linalgOp.iterator_types());
  if (vtC)
    vector_transfer_write(res, viewC, indicesC);
  else
    std_store(res, viewC, indicesC);
}

/// Check whether there is any interleaved use of any `values` between
/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
/// is in a different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LLVM_DEBUG(llvm::dbgs()
               << dbgPref << "interleavedUses precondition failed, firstOp: "
               << *firstOp << ", second op: " << *secondOp);
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) ||
           secondOp->isBeforeInBlock(owner)))
        continue;
      LLVM_DEBUG(llvm::dbgs()
                 << dbgPref << " found interleaved op " << *owner
                 << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
      return true;
    }
  }
  return false;
}

/// Return the unique subview use of `v` if it is indeed unique, null
/// otherwise.
static SubViewOp getSubViewUseIfUnique(Value v) {
  SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}
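
// The two forwarding patterns below fold a staging linalg.copy (plus the
// optional linalg.fill that pads the staging buffer) into the adjacent vector
// transfer, provided no interleaved uses exist. Schematically, for the read
// case (an illustrative sketch; shapes, offsets and the fill are elided):
//   %alloc = alloc() : memref<32xf32>
//   %sv = subview %alloc[...]
//   linalg.copy(%in, %sv)
//   %v = vector.transfer_read %alloc[...], %pad
// becomes:
//   %v = vector.transfer_read %in[...], %pad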
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
  // Transfer from `viewOrAlloc`.
  Value viewOrAlloc = xferOp.memref();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: ";
  (void)dbgPref;
  LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc);

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView);

  // Find the copy into `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getOutputBuffer(0) != subView)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp);
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp);

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp);
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure the padding value written by the fill matches the transfer padding.
  if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
    return failure();
  if (maybeFillOp)
    LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp);

  // `in` is the source that linalg.copy reads from; forward the transfer_read
  // to it.
  Value in = copyOp.getInput(0);

  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer; when
  // forwarding to vector.transfer_read, the attribute must be reset
  // conservatively.
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
      xferOp.permutation_map(), xferOp.padding(), ArrayAttr());

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.memref();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getInput(0) != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  // `out` is the buffer that linalg.copy writes into; forward the
  // transfer_write directly to it.
  Value out = copyOp.getOutputBuffer(0);

  // Forward vector.transfer into copy.
  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer; when
  // forwarding to vector.transfer_write, the attribute must be reset
  // conservatively.
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
      xferOp.permutation_map(), ArrayAttr());

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}
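
// An illustrative sketch of how a client might exercise the two forwarding
// patterns above (assuming a `FuncOp func`; this usage is not part of this
// file):
//   OwningRewritePatternList patterns;
//   patterns.insert<LinalgCopyVTRForwardingPattern,
//                   LinalgCopyVTWForwardingPattern>(func.getContext());
//   applyPatternsAndFoldGreedily(func, patterns);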