//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using llvm::dbgs;

#define DEBUG_TYPE "linalg-vectorization"

static bool hasMultiplyAddBody(linalg::GenericOp op) {
  auto &r = op.region();
  if (!llvm::hasSingleElement(r))
    return false;
  if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
    return false;

  using mlir::matchers::m_Val;
  auto a = m_Val(r.front().getArgument(0));
  auto b = m_Val(r.front().getArgument(1));
  auto c = m_Val(r.front().getArgument(2));
  // TODO: Update this detection once we have matcher support for specifying
  // that any permutation of operands matches.
  auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
  auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
  auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
  auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
  return pattern1.match(&r.front().back()) ||
         pattern2.match(&r.front().back()) ||
         pattern3.match(&r.front().back()) || pattern4.match(&r.front().back());
}

// TODO: Should be Tablegen'd from a single source that generates the op itself.
static bool isRowMajorMatmul(linalg::GenericOp genericOp) {
  return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
         isRowMajorMatmul(genericOp.indexing_maps()) &&
         hasMultiplyAddBody(genericOp);
}

// TODO: This is in fact much more general than just vectorization for matmul
// and fill ops.
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  // All types must be static shape to go to vector.
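  // For example (shapes illustrative): memref<4x8xf32> has a static shape and
  // can be mapped to vector<4x8xf32>, whereas the dynamically shaped
  // memref<?x8xf32> cannot.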
  for (Value operand : linalgOp.getInputsAndOutputBuffers())
    if (!operand.getType().cast<ShapedType>().hasStaticShape())
      return failure();
  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
      return failure();
  if (isa<linalg::MatmulOp>(op) || isa<linalg::FillOp>(op))
    return success();

  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp || !::isRowMajorMatmul(genericOp))
    return failure();

  // TODO(ntv): non-identity layout.
  auto isStaticMemRefWithIdentityLayout = [](Value v) {
    auto m = v.getType().dyn_cast<MemRefType>();
    if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
      return false;
    return true;
  };
  return success(llvm::all_of(genericOp.getInputsAndOutputBuffers(),
                              isStaticMemRefWithIdentityLayout));
}

void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
  assert(succeeded(vectorizeLinalgOpPrecondition(op)));

  if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
    // TODO: add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }

  edsc::ScopedContext scope(builder, op->getLoc());
  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
    // Vectorize fill as a vector.broadcast.
    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
                         "]: Rewrite linalg.fill as vector.broadcast: "
                      << *op << ":\n");
    Value memref = vector_type_cast(fillOp.getOutputBuffer(0));
    Value dst = std_load(memref);
    Value res = vector_broadcast(dst.getType(), fillOp.value());
    std_store(res, memref);
    return;
  }

  // Vectorize other ops as vector contraction (currently only matmul).
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
                       "]: Rewrite linalg op as vector.contract: "
                    << *op << ":\n");
  auto linalgOp = cast<linalg::LinalgOp>(op);
  Value a = std_load(vector_type_cast(linalgOp.getInput(0)));
  Value b = std_load(vector_type_cast(linalgOp.getInput(1)));
  Value memref = vector_type_cast(linalgOp.getOutputBuffer(0));
  Value c = std_load(memref);
  Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                              linalgOp.iterator_types());
  std_store(res, memref);
}
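
// A sketch of how these two entry points are typically driven from a rewrite
// pattern (the pattern name below is illustrative and not defined in this
// file); a PatternRewriter is an OpBuilder, so it can be passed directly to
// vectorizeLinalgOp:
//
//   struct VectorizeLinalgOpSketch : public RewritePattern {
//     using RewritePattern::RewritePattern;
//     LogicalResult matchAndRewrite(Operation *op,
//                                   PatternRewriter &rewriter) const override {
//       if (failed(vectorizeLinalgOpPrecondition(op)))
//         return failure();
//       vectorizeLinalgOp(rewriter, op);
//       rewriter.eraseOp(op);
//       return success();
//     }
//   };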