//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using llvm::dbgs;

#define DEBUG_TYPE "linalg-vectorization"

/// Returns true iff the single-block region of `op` consists of exactly three
/// operations forming `linalg.yield(addf(mulf(a, b), c))` — i.e. the scalar
/// multiply-accumulate body of a matmul — accepting any commutation of the
/// addf/mulf operands.
static bool hasMultiplyAddBody(linalg::GenericOp op) {
  auto &r = op.region();
  // Require a single block containing exactly 3 ops (mulf, addf, yield).
  if (!llvm::hasSingleElement(r))
    return false;
  if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
    return false;

  using mlir::matchers::m_Val;
  // a, b are the two input scalars, c is the output/accumulator scalar.
  auto a = m_Val(r.front().getArgument(0));
  auto b = m_Val(r.front().getArgument(1));
  auto c = m_Val(r.front().getArgument(2));
  // TODO: Update this detection once we have matcher support for specifying
  // that any permutation of operands matches.
  auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
  auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
  auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
  auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
  // Match against the terminator (the yield, last op of the block).
  return pattern1.match(&r.front().back()) ||
         pattern2.match(&r.front().back()) ||
         pattern3.match(&r.front().back()) || pattern4.match(&r.front().back());
}

// TODO: Should be Tablegen'd from a single source that generates the op itself.
/// Returns true iff `genericOp` structurally implements a row-major matmul:
/// 2 inputs, 1 output, row-major matmul indexing maps (checked by the
/// ArrayAttr overload of isRowMajorMatmul) and a multiply-add scalar body.
static bool isRowMajorMatmul(linalg::GenericOp genericOp) {
  return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
         isRowMajorMatmul(genericOp.indexing_maps()) &&
         hasMultiplyAddBody(genericOp);
}

// TODO: This is in fact much more general than just vectorization for matmul
// and fill ops.
/// Returns success iff `op` (which must be a LinalgOp) can be handled by
/// vectorizeLinalgOp: every input/output shape is static, and the op is
/// either a linalg.matmul, a linalg.fill, or a linalg.generic recognized as
/// a row-major matmul over static memrefs with identity layout.
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  // All types must be static shape to go to vector.
  for (Value operand : linalgOp.getInputsAndOutputBuffers())
    if (!operand.getType().cast<ShapedType>().hasStaticShape())
      return failure();
  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
      return failure();
  // Named matmul/fill ops are accepted as-is once shapes are static.
  if (isa<linalg::MatmulOp, linalg::FillOp>(op))
    return success();

  // Otherwise only generic ops that look like a row-major matmul qualify.
  // `::` disambiguates the file-local GenericOp overload above.
  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp || !::isRowMajorMatmul(genericOp))
    return failure();

  // TODO(ntv): non-identity layout.
  // An empty affine-map list on a MemRefType means identity layout.
  auto isStaticMemRefWithIdentityLayout = [](Value v) {
    auto m = v.getType().dyn_cast<MemRefType>();
    if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
      return false;
    return true;
  };
  return success(llvm::all_of(genericOp.getInputsAndOutputBuffers(),
                              isStaticMemRefWithIdentityLayout));
}

/// Emits a vector form of `op` next to it (the original op is not erased
/// here): linalg.fill becomes a vector.broadcast stored through a
/// vector.type_cast of the output buffer; matmul-like ops become full-view
/// vector.transfer_read ops feeding a vector.contract whose result is written
/// back with vector.transfer_write. `op` must satisfy
/// vectorizeLinalgOpPrecondition.
void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
  assert(succeeded(vectorizeLinalgOpPrecondition(op)));

  if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
    // TODO: add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  // The ScopedContext makes `builder` and the location implicitly available
  // to the EDSC intrinsics (vector_*, std_*) used below.
  edsc::ScopedContext scope(builder, op->getLoc());
  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
    // Vectorize fill as a vector.broadcast.
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
    // vector.type_cast reinterprets the memref as a memref of vector; the
    // load only serves to obtain that vector type for the broadcast.
    Value memref = vector_type_cast(fillOp.getOutputBuffer(0));
    Value dst = std_load(memref);
    Value res = vector_broadcast(dst.getType(), fillOp.value());
    std_store(res, memref);
    return;
  }

  // Vectorize other ops as vector contraction (currently only matmul).
  LLVM_DEBUG(dbgs() << dbgPref
                    << "Rewrite linalg op as vector.contract: " << *op);
  // Vector type spanning the full (static) shape of a memref view.
  auto extractVectorTypeFromScalarView = [](Value v) {
    MemRefType mt = v.getType().cast<MemRefType>();
    return VectorType::get(mt.getShape(), mt.getElementType());
  };
  auto linalgOp = cast<linalg::LinalgOp>(op);
  Value viewA = linalgOp.getInput(0);
  Value viewB = linalgOp.getInput(1);
  Value viewC = linalgOp.getOutputBuffer(0);
  // Transfers read/write the whole view: all-zero start indices per rank.
  Value zero = std_constant_index(0);
  SmallVector<Value, 4> indicesA(linalgOp.getInputShapedType(0).getRank(),
                                 zero);
  SmallVector<Value, 4> indicesB(linalgOp.getInputShapedType(1).getRank(),
                                 zero);
  SmallVector<Value, 4> indicesC(linalgOp.getOutputShapedType(0).getRank(),
                                 zero);
  Value a = vector_transfer_read(extractVectorTypeFromScalarView(viewA), viewA,
                                 indicesA);
  Value b = vector_transfer_read(extractVectorTypeFromScalarView(viewB), viewB,
                                 indicesB);
  Value c = vector_transfer_read(extractVectorTypeFromScalarView(viewC), viewC,
                                 indicesC);
  // The contraction reuses the linalg op's indexing maps and iterator types.
  Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                              linalgOp.iterator_types());
  vector_transfer_write(res, viewC, indicesC);
}

/// Check whether there is any interleaved use of any `values` between `firstOp`
/// and `secondOp`. Conservatively return `true` if any op or value is in a
/// different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  // Only the single-block, firstOp-strictly-before-secondOp case is analyzed;
  // anything else conservatively reports "uses may exist".
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LLVM_DEBUG(llvm::dbgs()
               << dbgPref << "interleavedUses precondition failed, firstOp: "
               << *firstOp << ", second op: " << *secondOp);
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      // The two delimiting ops themselves do not count as interleaved.
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      // A use in the same block but strictly before firstOp or strictly
      // after secondOp lies outside the interval and is not interleaved.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
        continue;
      LLVM_DEBUG(llvm::dbgs()
                 << dbgPref << " found interleaved op " << *owner
                 << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
      return true;
    }
  }
  return false;
}

/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
static SubViewOp getSubViewUseIfUnique(Value v) {
  SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
      // A second subview user makes the answer ambiguous: return null op.
      if (subViewOp)
        return SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
/// Forwards a vector.transfer_read through a linalg.copy staged via a
/// view/alloc + unique subview:
///   [optional] linalg.fill(%viewOrAlloc, %pad)
///   %subView = subview %viewOrAlloc
///   linalg.copy(%in, %subView)
///   ... = vector.transfer_read %viewOrAlloc, %pad
/// becomes a vector.transfer_read directly from %in, erasing the copy (and
/// the fill, if present). Requires that no other use of the buffer or the
/// subview is interleaved, and that the fill value equals the transfer
/// padding.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {

  // Transfer into `view`.
  Value viewOrAlloc = xferOp.memref();
  // Only consider transfers whose source is produced by a view or an alloc.
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: ";
  (void)dbgPref;
  LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc);

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView);

  // Find the copy into `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      // The copy must write into the subview (not merely read it).
      if (newCopyOp.getOutputBuffer(0) != subView)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp);
      // No other use of the buffer/subview may occur between copy and read.
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp);

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp);
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure padding matches.
  // Dropping the fill is only sound if the transfer's padding value would
  // reproduce exactly what the fill wrote.
  if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
    return failure();
  if (maybeFillOp)
    LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp);

  // `in` is the subview that linalg.copy reads. Replace it.
  Value in = copyOp.getInput(0);

  // Re-emit the transfer_read against the copy's source, preserving indices,
  // permutation map, padding and masked attribute.
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
      xferOp.permutation_map(), xferOp.padding(),
      xferOp.masked() ? *xferOp.masked() : ArrayAttr());

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
/// Write-side dual of the pattern above: forwards a vector.transfer_write
/// through a linalg.copy staged via a view/alloc + unique subview:
///   %subView = subview %viewOrAlloc
///   vector.transfer_write %vec, %viewOrAlloc
///   linalg.copy(%subView, %out)
/// becomes a vector.transfer_write of %vec directly into %out, erasing the
/// copy, provided no other use of the buffer or subview is interleaved.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.memref();
  // Only consider transfers whose destination is produced by a view or alloc.
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      // The copy must read from the subview (the write side of the staging).
      if (newCopyOp.getInput(0) != subView)
        continue;
      // No other use of the buffer/subview may occur between write and copy.
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  // `out` is the subview copied into that we replace.
  Value out = copyOp.getOutputBuffer(0);

  // Forward vector.transfer into copy.
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
      xferOp.permutation_map(),
      xferOp.masked() ? *xferOp.masked() : ArrayAttr());

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}