1 //===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements rewrite patterns for the permutation_map attribute of 10 // vector.transfer operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 18 #include "mlir/Interfaces/VectorInterfaces.h" 19 20 using namespace mlir; 21 using namespace mlir::vector; 22 23 /// Transpose a vector transfer op's `in_bounds` attribute according to given 24 /// indices. 25 static ArrayAttr 26 transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, 27 const SmallVector<unsigned> &permutation) { 28 SmallVector<bool> newInBoundsValues; 29 for (unsigned pos : permutation) 30 newInBoundsValues.push_back( 31 attr.getValue()[pos].cast<BoolAttr>().getValue()); 32 return builder.getBoolArrayAttr(newInBoundsValues); 33 } 34 35 /// Lower transfer_read op with permutation into a transfer_read with a 36 /// permutation map composed of leading zeros followed by a minor identiy + 37 /// vector.transpose op. 38 /// Ex: 39 /// vector.transfer_read ... 40 /// permutation_map: (d0, d1, d2) -> (0, d1) 41 /// into: 42 /// %v = vector.transfer_read ... 43 /// permutation_map: (d0, d1, d2) -> (d1, 0) 44 /// vector.transpose %v, [1, 0] 45 /// 46 /// vector.transfer_read ... 47 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3) 48 /// into: 49 /// %v = vector.transfer_read ... 50 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3) 51 /// vector.transpose %v, [0, 1, 3, 2, 4] 52 /// Note that an alternative is to transform it to linalg.transpose + 53 /// vector.transfer_read to do the transpose in memory instead. 54 struct TransferReadPermutationLowering 55 : public OpRewritePattern<vector::TransferReadOp> { 56 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 57 58 LogicalResult matchAndRewrite(vector::TransferReadOp op, 59 PatternRewriter &rewriter) const override { 60 // TODO: support 0-d corner case. 61 if (op.getTransferRank() == 0) 62 return failure(); 63 64 SmallVector<unsigned> permutation; 65 AffineMap map = op.getPermutationMap(); 66 if (map.getNumResults() == 0) 67 return failure(); 68 if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) 69 return failure(); 70 AffineMap permutationMap = 71 map.getPermutationMap(permutation, op.getContext()); 72 if (permutationMap.isIdentity()) 73 return failure(); 74 75 permutationMap = map.getPermutationMap(permutation, op.getContext()); 76 // Caluclate the map of the new read by applying the inverse permutation. 77 permutationMap = inversePermutation(permutationMap); 78 AffineMap newMap = permutationMap.compose(map); 79 // Apply the reverse transpose to deduce the type of the transfer_read. 80 ArrayRef<int64_t> originalShape = op.getVectorType().getShape(); 81 SmallVector<int64_t> newVectorShape(originalShape.size()); 82 for (const auto &pos : llvm::enumerate(permutation)) { 83 newVectorShape[pos.value()] = originalShape[pos.index()]; 84 } 85 86 // Transpose mask operand. 87 Value newMask; 88 if (op.getMask()) { 89 // Remove unused dims from the permutation map. E.g.: 90 // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2) 91 // comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0) 92 auto comp = compressUnusedDims(map); 93 // Get positions of remaining result dims. 94 // E.g.: (d0, d1, d2) -> (d2, 0, d1, 0 d0) 95 // maskTransposeIndices = [ 2, 1, 0] 96 SmallVector<int64_t> maskTransposeIndices; 97 for (unsigned i = 0; i < comp.getNumResults(); ++i) { 98 if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>()) 99 maskTransposeIndices.push_back(expr.getPosition()); 100 } 101 102 newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.getMask(), 103 maskTransposeIndices); 104 } 105 106 // Transpose in_bounds attribute. 107 ArrayAttr newInBoundsAttr = 108 op.getInBounds() ? transposeInBoundsAttr( 109 rewriter, op.getInBounds().value(), permutation) 110 : ArrayAttr(); 111 112 // Generate new transfer_read operation. 113 VectorType newReadType = 114 VectorType::get(newVectorShape, op.getVectorType().getElementType()); 115 Value newRead = rewriter.create<vector::TransferReadOp>( 116 op.getLoc(), newReadType, op.getSource(), op.getIndices(), 117 AffineMapAttr::get(newMap), op.getPadding(), newMask, newInBoundsAttr); 118 119 // Transpose result of transfer_read. 120 SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end()); 121 rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead, 122 transposePerm); 123 return success(); 124 } 125 }; 126 127 /// Lower transfer_write op with permutation into a transfer_write with a 128 /// minor identity permutation map. (transfer_write ops cannot have broadcasts.) 129 /// Ex: 130 /// vector.transfer_write %v ... 131 /// permutation_map: (d0, d1, d2) -> (d2, d0, d1) 132 /// into: 133 /// %tmp = vector.transpose %v, [2, 0, 1] 134 /// vector.transfer_write %tmp ... 135 /// permutation_map: (d0, d1, d2) -> (d0, d1, d2) 136 /// 137 /// vector.transfer_write %v ... 138 /// permutation_map: (d0, d1, d2, d3) -> (d3, d2) 139 /// into: 140 /// %tmp = vector.transpose %v, [1, 0] 141 /// %v = vector.transfer_write %tmp ... 142 /// permutation_map: (d0, d1, d2, d3) -> (d2, d3) 143 struct TransferWritePermutationLowering 144 : public OpRewritePattern<vector::TransferWriteOp> { 145 using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern; 146 147 LogicalResult matchAndRewrite(vector::TransferWriteOp op, 148 PatternRewriter &rewriter) const override { 149 // TODO: support 0-d corner case. 150 if (op.getTransferRank() == 0) 151 return failure(); 152 153 SmallVector<unsigned> permutation; 154 AffineMap map = op.getPermutationMap(); 155 if (map.isMinorIdentity()) 156 return failure(); 157 if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) 158 return failure(); 159 160 // Remove unused dims from the permutation map. E.g.: 161 // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4) 162 // comp = (d0, d1, d2) -> (d2, d0, d1) 163 auto comp = compressUnusedDims(map); 164 // Get positions of remaining result dims. 165 SmallVector<int64_t> indices; 166 llvm::transform(comp.getResults(), std::back_inserter(indices), 167 [](AffineExpr expr) { 168 return expr.dyn_cast<AffineDimExpr>().getPosition(); 169 }); 170 171 // Transpose mask operand. 172 Value newMask = op.getMask() ? rewriter.create<vector::TransposeOp>( 173 op.getLoc(), op.getMask(), indices) 174 : Value(); 175 176 // Transpose in_bounds attribute. 177 ArrayAttr newInBoundsAttr = 178 op.getInBounds() ? transposeInBoundsAttr( 179 rewriter, op.getInBounds().value(), permutation) 180 : ArrayAttr(); 181 182 // Generate new transfer_write operation. 183 Value newVec = rewriter.create<vector::TransposeOp>( 184 op.getLoc(), op.getVector(), indices); 185 auto newMap = AffineMap::getMinorIdentityMap( 186 map.getNumDims(), map.getNumResults(), rewriter.getContext()); 187 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 188 op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap), 189 newMask, newInBoundsAttr); 190 191 return success(); 192 } 193 }; 194 195 /// Lower transfer_read op with broadcast in the leading dimensions into 196 /// transfer_read of lower rank + vector.broadcast. 197 /// Ex: vector.transfer_read ... 198 /// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3) 199 /// into: 200 /// %v = vector.transfer_read ... 201 /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3) 202 /// vector.broadcast %v 203 struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> { 204 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 205 206 LogicalResult matchAndRewrite(vector::TransferReadOp op, 207 PatternRewriter &rewriter) const override { 208 // TODO: support 0-d corner case. 209 if (op.getTransferRank() == 0) 210 return failure(); 211 212 AffineMap map = op.getPermutationMap(); 213 unsigned numLeadingBroadcast = 0; 214 for (auto expr : map.getResults()) { 215 auto dimExpr = expr.dyn_cast<AffineConstantExpr>(); 216 if (!dimExpr || dimExpr.getValue() != 0) 217 break; 218 numLeadingBroadcast++; 219 } 220 // If there are no leading zeros in the map there is nothing to do. 221 if (numLeadingBroadcast == 0) 222 return failure(); 223 VectorType originalVecType = op.getVectorType(); 224 unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast; 225 // Calculate new map, vector type and masks without the leading zeros. 226 AffineMap newMap = AffineMap::get( 227 map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank), 228 op.getContext()); 229 // Only remove the leading zeros if the rest of the map is a minor identity 230 // with broadasting. Otherwise we first want to permute the map. 231 if (!newMap.isMinorIdentityWithBroadcasting()) 232 return failure(); 233 234 // TODO: support zero-dimension vectors natively. See: 235 // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097. 236 // In the meantime, lower these to a scalar load when they pop up. 237 if (reducedShapeRank == 0) { 238 Value newRead; 239 if (op.getShapedType().isa<TensorType>()) { 240 newRead = rewriter.create<tensor::ExtractOp>( 241 op.getLoc(), op.getSource(), op.getIndices()); 242 } else { 243 newRead = rewriter.create<memref::LoadOp>( 244 op.getLoc(), originalVecType.getElementType(), op.getSource(), 245 op.getIndices()); 246 } 247 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType, 248 newRead); 249 return success(); 250 } 251 SmallVector<int64_t> newShape = llvm::to_vector<4>( 252 originalVecType.getShape().take_back(reducedShapeRank)); 253 // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering. 254 if (newShape.empty()) 255 return failure(); 256 VectorType newReadType = 257 VectorType::get(newShape, originalVecType.getElementType()); 258 ArrayAttr newInBoundsAttr = 259 op.getInBounds() 260 ? rewriter.getArrayAttr( 261 op.getInBoundsAttr().getValue().take_back(reducedShapeRank)) 262 : ArrayAttr(); 263 Value newRead = rewriter.create<vector::TransferReadOp>( 264 op.getLoc(), newReadType, op.getSource(), op.getIndices(), 265 AffineMapAttr::get(newMap), op.getPadding(), op.getMask(), 266 newInBoundsAttr); 267 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType, 268 newRead); 269 return success(); 270 } 271 }; 272 273 void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( 274 RewritePatternSet &patterns) { 275 patterns.add<TransferReadPermutationLowering, 276 TransferWritePermutationLowering, TransferOpReduceRank>( 277 patterns.getContext()); 278 } 279