1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Vectorization transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Dialect/Linalg/Utils/Utils.h" 17 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 18 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 19 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/IR/AffineExpr.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Pass/Pass.h" 25 #include "mlir/Support/LLVM.h" 26 #include "llvm/Support/Debug.h" 27 #include "llvm/Support/raw_ostream.h" 28 #include <type_traits> 29 30 using namespace mlir; 31 using namespace mlir::edsc; 32 using namespace mlir::edsc::intrinsics; 33 using namespace mlir::linalg; 34 35 using llvm::dbgs; 36 37 #define DEBUG_TYPE "linalg-vectorization" 38 39 static bool hasMultiplyAddBody(Region &r) { 40 if (!llvm::hasSingleElement(r)) 41 return false; 42 if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3)) 43 return false; 44 45 using mlir::matchers::m_Val; 46 auto a = m_Val(r.front().getArgument(0)); 47 auto b = m_Val(r.front().getArgument(1)); 48 auto c = m_Val(r.front().getArgument(2)); 49 // TODO: Update this detection once we have matcher support for specifying 50 // that any permutation of operands matches. 51 auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c)); 52 auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b))); 53 auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c)); 54 auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a))); 55 return pattern1.match(&r.front().back()) || 56 pattern2.match(&r.front().back()) || 57 pattern3.match(&r.front().back()) || pattern4.match(&r.front().back()); 58 } 59 60 // TODO: Should be Tablegen'd from a single source that generates the op itself. 61 static LogicalResult isContraction(Operation *op) { 62 // TODO: interface for named ops. 63 if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatvecOp, 64 linalg::DotOp>(op)) 65 return success(); 66 67 auto genericOp = dyn_cast<linalg::GenericOp>(op); 68 if (!genericOp) 69 return failure(); 70 71 auto mapRange = 72 genericOp.indexing_maps().getAsRange<AffineMapAttr, AffineMap>(); 73 74 return success( 75 genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 && 76 llvm::all_of(mapRange, 77 [](AffineMap m) { return m.isProjectedPermutation(); }) && 78 hasMultiplyAddBody(genericOp.region())); 79 } 80 81 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { 82 auto linalgOp = cast<linalg::LinalgOp>(op); 83 // All types must be static shape to go to vector. 84 for (Value operand : linalgOp.getInputsAndOutputBuffers()) 85 if (!operand.getType().cast<ShapedType>().hasStaticShape()) 86 return failure(); 87 for (Type outputTensorType : linalgOp.getOutputTensorTypes()) 88 if (!outputTensorType.cast<ShapedType>().hasStaticShape()) 89 return failure(); 90 91 if (isa<linalg::FillOp>(op)) 92 return success(); 93 94 return isContraction(op); 95 } 96 97 void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) { 98 assert(succeeded(vectorizeLinalgOpPrecondition(op))); 99 100 StringRef dbgPref = "\n[" DEBUG_TYPE "]: "; 101 (void)dbgPref; 102 edsc::ScopedContext scope(builder, op->getLoc()); 103 if (auto fillOp = dyn_cast<linalg::FillOp>(op)) { 104 // Vectorize fill as a vector.broadcast. 105 LLVM_DEBUG(dbgs() << dbgPref 106 << "Rewrite linalg.fill as vector.broadcast: " << *op); 107 Value memref = vector_type_cast(fillOp.getOutputBuffer(0)); 108 Value dst = std_load(memref); 109 Value res = vector_broadcast(dst.getType(), fillOp.value()); 110 std_store(res, memref); 111 return; 112 } 113 114 assert(succeeded(isContraction(op)) && "Expected contraction"); 115 116 // Vectorize other ops as vector contraction. 117 // TODO: interface. 118 LLVM_DEBUG(dbgs() << dbgPref 119 << "Rewrite linalg op as vector.contract: " << *op); 120 // In the case of 0-D memrefs, return null and special case to scalar load or 121 // store later. 122 auto extractVectorTypeFromScalarView = [](Value v) { 123 MemRefType mt = v.getType().cast<MemRefType>(); 124 return mt.getShape().empty() 125 ? VectorType() 126 : VectorType::get(mt.getShape(), mt.getElementType()); 127 }; 128 auto linalgOp = cast<linalg::LinalgOp>(op); 129 Value viewA = linalgOp.getInput(0); 130 Value viewB = linalgOp.getInput(1); 131 Value viewC = linalgOp.getOutputBuffer(0); 132 VectorType vtA = extractVectorTypeFromScalarView(viewA); 133 VectorType vtB = extractVectorTypeFromScalarView(viewB); 134 VectorType vtC = extractVectorTypeFromScalarView(viewC); 135 Value zero = std_constant_index(0); 136 SmallVector<Value, 4> indicesA, indicesB, indicesC; 137 if (vtA) 138 indicesA = SmallVector<Value, 4>(vtA.getRank(), zero); 139 if (vtB) 140 indicesB = SmallVector<Value, 4>(vtB.getRank(), zero); 141 if (vtC) 142 indicesC = SmallVector<Value, 4>(vtC.getRank(), zero); 143 Value a = vtA ? vector_transfer_read(vtA, viewA, indicesA).value 144 : std_load(viewA, indicesA).value; 145 Value b = vtB ? vector_transfer_read(vtB, viewB, indicesB).value 146 : std_load(viewB, indicesB).value; 147 Value c = vtC ? vector_transfer_read(vtC, viewC, indicesC).value 148 : std_load(viewC, indicesC).value; 149 Value res = vector_contract(a, b, c, linalgOp.indexing_maps(), 150 linalgOp.iterator_types()); 151 if (vtC) 152 vector_transfer_write(res, viewC, indicesC); 153 else 154 std_store(res, viewC, indicesC); 155 } 156 157 /// Check whether there is any interleaved use of any `values` between `firstOp` 158 /// and `secondOp`. Conservatively return `true` if any op or value is in a 159 /// different block. 160 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, 161 ValueRange values) { 162 StringRef dbgPref = "\n[" DEBUG_TYPE "]: "; 163 (void)dbgPref; 164 if (firstOp->getBlock() != secondOp->getBlock() || 165 !firstOp->isBeforeInBlock(secondOp)) { 166 LLVM_DEBUG(llvm::dbgs() 167 << dbgPref << "interleavedUses precondition failed, firstOp: " 168 << *firstOp << ", second op: " << *secondOp); 169 return true; 170 } 171 for (auto v : values) { 172 for (auto &u : v.getUses()) { 173 Operation *owner = u.getOwner(); 174 if (owner == firstOp || owner == secondOp) 175 continue; 176 // TODO: this is too conservative, use dominance info in the future. 177 if (owner->getBlock() == firstOp->getBlock() && 178 (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner))) 179 continue; 180 LLVM_DEBUG(llvm::dbgs() 181 << dbgPref << " found interleaved op " << *owner 182 << ", firstOp: " << *firstOp << ", second op: " << *secondOp); 183 return true; 184 } 185 } 186 return false; 187 } 188 189 /// Return the unique subview use of `v` if it is indeed unique, null otherwise. 190 static SubViewOp getSubViewUseIfUnique(Value v) { 191 SubViewOp subViewOp; 192 for (auto &u : v.getUses()) { 193 if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) { 194 if (subViewOp) 195 return SubViewOp(); 196 subViewOp = newSubViewOp; 197 } 198 } 199 return subViewOp; 200 } 201 202 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate, 203 /// when available. 204 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite( 205 vector::TransferReadOp xferOp, PatternRewriter &rewriter) const { 206 207 // Transfer into `view`. 208 Value viewOrAlloc = xferOp.memref(); 209 if (!viewOrAlloc.getDefiningOp<ViewOp>() && 210 !viewOrAlloc.getDefiningOp<AllocOp>()) 211 return failure(); 212 213 StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: "; 214 (void)dbgPref; 215 LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc); 216 217 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`. 218 SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc); 219 if (!subViewOp) 220 return failure(); 221 Value subView = subViewOp.getResult(); 222 LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView); 223 224 // Find the copy into `subView` without interleaved uses. 225 CopyOp copyOp; 226 for (auto &u : subView.getUses()) { 227 if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) { 228 if (newCopyOp.getOutputBuffer(0) != subView) 229 continue; 230 LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp); 231 if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView})) 232 continue; 233 copyOp = newCopyOp; 234 break; 235 } 236 } 237 if (!copyOp) 238 return failure(); 239 LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp); 240 241 // Find the fill into `viewOrAlloc` without interleaved uses before the copy. 242 FillOp maybeFillOp; 243 for (auto &u : viewOrAlloc.getUses()) { 244 if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) { 245 if (newFillOp.getOutputBuffer(0) != viewOrAlloc) 246 continue; 247 LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp); 248 if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView})) 249 continue; 250 maybeFillOp = newFillOp; 251 break; 252 } 253 } 254 // Ensure padding matches. 255 if (maybeFillOp && xferOp.padding() != maybeFillOp.value()) 256 return failure(); 257 if (maybeFillOp) 258 LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp); 259 260 // `in` is the subview that linalg.copy reads. Replace it. 261 Value in = copyOp.getInput(0); 262 263 // linalg.copy + linalg.fill can be used to create a padded local buffer. 264 // The `masked` attribute is only valid on this padded buffer. 265 // When forwarding to vector.transfer_read, the attribute must be reset 266 // conservatively. 267 Value res = rewriter.create<vector::TransferReadOp>( 268 xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(), 269 xferOp.permutation_map(), xferOp.padding(), ArrayAttr()); 270 271 if (maybeFillOp) 272 rewriter.eraseOp(maybeFillOp); 273 rewriter.eraseOp(copyOp); 274 rewriter.replaceOp(xferOp, res); 275 276 return success(); 277 } 278 279 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate, 280 /// when available. 281 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite( 282 vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const { 283 // Transfer into `viewOrAlloc`. 284 Value viewOrAlloc = xferOp.memref(); 285 if (!viewOrAlloc.getDefiningOp<ViewOp>() && 286 !viewOrAlloc.getDefiningOp<AllocOp>()) 287 return failure(); 288 289 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`. 290 SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc); 291 if (!subViewOp) 292 return failure(); 293 Value subView = subViewOp.getResult(); 294 295 // Find the copy from `subView` without interleaved uses. 296 CopyOp copyOp; 297 for (auto &u : subViewOp.getResult().getUses()) { 298 if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) { 299 if (newCopyOp.getInput(0) != subView) 300 continue; 301 if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView})) 302 continue; 303 copyOp = newCopyOp; 304 break; 305 } 306 } 307 if (!copyOp) 308 return failure(); 309 310 // `out` is the subview copied into that we replace. 311 Value out = copyOp.getOutputBuffer(0); 312 313 // Forward vector.transfer into copy. 314 // linalg.copy + linalg.fill can be used to create a padded local buffer. 315 // The `masked` attribute is only valid on this padded buffer. 316 // When forwarding to vector.transfer_write, the attribute must be reset 317 // conservatively. 318 rewriter.create<vector::TransferWriteOp>( 319 xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(), 320 xferOp.permutation_map(), ArrayAttr()); 321 322 rewriter.eraseOp(copyOp); 323 rewriter.eraseOp(xferOp); 324 325 return success(); 326 } 327