1 //===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements rewrite patterns for the permutation_map attribute of 10 // vector.transfer operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 18 #include "mlir/Interfaces/VectorInterfaces.h" 19 20 using namespace mlir; 21 using namespace mlir::vector; 22 23 /// Transpose a vector transfer op's `in_bounds` attribute according to given 24 /// indices. 25 static ArrayAttr 26 transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, 27 const SmallVector<unsigned> &permutation) { 28 SmallVector<bool> newInBoundsValues; 29 for (unsigned pos : permutation) 30 newInBoundsValues.push_back( 31 attr.getValue()[pos].cast<BoolAttr>().getValue()); 32 return builder.getBoolArrayAttr(newInBoundsValues); 33 } 34 35 /// Lower transfer_read op with permutation into a transfer_read with a 36 /// permutation map composed of leading zeros followed by a minor identiy + 37 /// vector.transpose op. 38 /// Ex: 39 /// vector.transfer_read ... 40 /// permutation_map: (d0, d1, d2) -> (0, d1) 41 /// into: 42 /// %v = vector.transfer_read ... 43 /// permutation_map: (d0, d1, d2) -> (d1, 0) 44 /// vector.transpose %v, [1, 0] 45 /// 46 /// vector.transfer_read ... 47 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3) 48 /// into: 49 /// %v = vector.transfer_read ... 50 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3) 51 /// vector.transpose %v, [0, 1, 3, 2, 4] 52 /// Note that an alternative is to transform it to linalg.transpose + 53 /// vector.transfer_read to do the transpose in memory instead. 54 struct TransferReadPermutationLowering 55 : public OpRewritePattern<vector::TransferReadOp> { 56 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 57 58 LogicalResult matchAndRewrite(vector::TransferReadOp op, 59 PatternRewriter &rewriter) const override { 60 // TODO: support 0-d corner case. 61 if (op.getTransferRank() == 0) 62 return failure(); 63 64 SmallVector<unsigned> permutation; 65 AffineMap map = op.getPermutationMap(); 66 if (map.getNumResults() == 0) 67 return failure(); 68 if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) 69 return failure(); 70 AffineMap permutationMap = 71 map.getPermutationMap(permutation, op.getContext()); 72 if (permutationMap.isIdentity()) 73 return failure(); 74 75 permutationMap = map.getPermutationMap(permutation, op.getContext()); 76 // Caluclate the map of the new read by applying the inverse permutation. 77 permutationMap = inversePermutation(permutationMap); 78 AffineMap newMap = permutationMap.compose(map); 79 // Apply the reverse transpose to deduce the type of the transfer_read. 80 ArrayRef<int64_t> originalShape = op.getVectorType().getShape(); 81 SmallVector<int64_t> newVectorShape(originalShape.size()); 82 for (const auto &pos : llvm::enumerate(permutation)) { 83 newVectorShape[pos.value()] = originalShape[pos.index()]; 84 } 85 86 // Transpose mask operand. 87 Value newMask; 88 if (op.getMask()) { 89 // Remove unused dims from the permutation map. E.g.: 90 // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2) 91 // comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0) 92 auto comp = compressUnusedDims(map); 93 // Get positions of remaining result dims. 94 // E.g.: (d0, d1, d2) -> (d2, 0, d1, 0 d0) 95 // maskTransposeIndices = [ 2, 1, 0] 96 SmallVector<int64_t> maskTransposeIndices; 97 for (unsigned i = 0; i < comp.getNumResults(); ++i) { 98 if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>()) 99 maskTransposeIndices.push_back(expr.getPosition()); 100 } 101 102 newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.getMask(), 103 maskTransposeIndices); 104 } 105 106 // Transpose in_bounds attribute. 107 ArrayAttr newInBoundsAttr = 108 op.getInBounds() 109 ? transposeInBoundsAttr(rewriter, op.getInBounds().getValue(), 110 permutation) 111 : ArrayAttr(); 112 113 // Generate new transfer_read operation. 114 VectorType newReadType = 115 VectorType::get(newVectorShape, op.getVectorType().getElementType()); 116 Value newRead = rewriter.create<vector::TransferReadOp>( 117 op.getLoc(), newReadType, op.getSource(), op.getIndices(), 118 AffineMapAttr::get(newMap), op.getPadding(), newMask, newInBoundsAttr); 119 120 // Transpose result of transfer_read. 121 SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end()); 122 rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead, 123 transposePerm); 124 return success(); 125 } 126 }; 127 128 /// Lower transfer_write op with permutation into a transfer_write with a 129 /// minor identity permutation map. (transfer_write ops cannot have broadcasts.) 130 /// Ex: 131 /// vector.transfer_write %v ... 132 /// permutation_map: (d0, d1, d2) -> (d2, d0, d1) 133 /// into: 134 /// %tmp = vector.transpose %v, [2, 0, 1] 135 /// vector.transfer_write %tmp ... 136 /// permutation_map: (d0, d1, d2) -> (d0, d1, d2) 137 /// 138 /// vector.transfer_write %v ... 139 /// permutation_map: (d0, d1, d2, d3) -> (d3, d2) 140 /// into: 141 /// %tmp = vector.transpose %v, [1, 0] 142 /// %v = vector.transfer_write %tmp ... 143 /// permutation_map: (d0, d1, d2, d3) -> (d2, d3) 144 struct TransferWritePermutationLowering 145 : public OpRewritePattern<vector::TransferWriteOp> { 146 using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern; 147 148 LogicalResult matchAndRewrite(vector::TransferWriteOp op, 149 PatternRewriter &rewriter) const override { 150 // TODO: support 0-d corner case. 151 if (op.getTransferRank() == 0) 152 return failure(); 153 154 SmallVector<unsigned> permutation; 155 AffineMap map = op.getPermutationMap(); 156 if (map.isMinorIdentity()) 157 return failure(); 158 if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) 159 return failure(); 160 161 // Remove unused dims from the permutation map. E.g.: 162 // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4) 163 // comp = (d0, d1, d2) -> (d2, d0, d1) 164 auto comp = compressUnusedDims(map); 165 // Get positions of remaining result dims. 166 SmallVector<int64_t> indices; 167 llvm::transform(comp.getResults(), std::back_inserter(indices), 168 [](AffineExpr expr) { 169 return expr.dyn_cast<AffineDimExpr>().getPosition(); 170 }); 171 172 // Transpose mask operand. 173 Value newMask = op.getMask() ? rewriter.create<vector::TransposeOp>( 174 op.getLoc(), op.getMask(), indices) 175 : Value(); 176 177 // Transpose in_bounds attribute. 178 ArrayAttr newInBoundsAttr = 179 op.getInBounds() 180 ? transposeInBoundsAttr(rewriter, op.getInBounds().getValue(), 181 permutation) 182 : ArrayAttr(); 183 184 // Generate new transfer_write operation. 185 Value newVec = rewriter.create<vector::TransposeOp>( 186 op.getLoc(), op.getVector(), indices); 187 auto newMap = AffineMap::getMinorIdentityMap( 188 map.getNumDims(), map.getNumResults(), rewriter.getContext()); 189 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 190 op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap), 191 newMask, newInBoundsAttr); 192 193 return success(); 194 } 195 }; 196 197 /// Lower transfer_read op with broadcast in the leading dimensions into 198 /// transfer_read of lower rank + vector.broadcast. 199 /// Ex: vector.transfer_read ... 200 /// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3) 201 /// into: 202 /// %v = vector.transfer_read ... 203 /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3) 204 /// vector.broadcast %v 205 struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> { 206 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 207 208 LogicalResult matchAndRewrite(vector::TransferReadOp op, 209 PatternRewriter &rewriter) const override { 210 // TODO: support 0-d corner case. 211 if (op.getTransferRank() == 0) 212 return failure(); 213 214 AffineMap map = op.getPermutationMap(); 215 unsigned numLeadingBroadcast = 0; 216 for (auto expr : map.getResults()) { 217 auto dimExpr = expr.dyn_cast<AffineConstantExpr>(); 218 if (!dimExpr || dimExpr.getValue() != 0) 219 break; 220 numLeadingBroadcast++; 221 } 222 // If there are no leading zeros in the map there is nothing to do. 223 if (numLeadingBroadcast == 0) 224 return failure(); 225 VectorType originalVecType = op.getVectorType(); 226 unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast; 227 // Calculate new map, vector type and masks without the leading zeros. 228 AffineMap newMap = AffineMap::get( 229 map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank), 230 op.getContext()); 231 // Only remove the leading zeros if the rest of the map is a minor identity 232 // with broadasting. Otherwise we first want to permute the map. 233 if (!newMap.isMinorIdentityWithBroadcasting()) 234 return failure(); 235 236 // TODO: support zero-dimension vectors natively. See: 237 // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097. 238 // In the meantime, lower these to a scalar load when they pop up. 239 if (reducedShapeRank == 0) { 240 Value newRead; 241 if (op.getShapedType().isa<TensorType>()) { 242 newRead = rewriter.create<tensor::ExtractOp>( 243 op.getLoc(), op.getSource(), op.getIndices()); 244 } else { 245 newRead = rewriter.create<memref::LoadOp>( 246 op.getLoc(), originalVecType.getElementType(), op.getSource(), 247 op.getIndices()); 248 } 249 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType, 250 newRead); 251 return success(); 252 } 253 SmallVector<int64_t> newShape = llvm::to_vector<4>( 254 originalVecType.getShape().take_back(reducedShapeRank)); 255 // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering. 256 if (newShape.empty()) 257 return failure(); 258 VectorType newReadType = 259 VectorType::get(newShape, originalVecType.getElementType()); 260 ArrayAttr newInBoundsAttr = 261 op.getInBounds() 262 ? rewriter.getArrayAttr( 263 op.getInBoundsAttr().getValue().take_back(reducedShapeRank)) 264 : ArrayAttr(); 265 Value newRead = rewriter.create<vector::TransferReadOp>( 266 op.getLoc(), newReadType, op.getSource(), op.getIndices(), 267 AffineMapAttr::get(newMap), op.getPadding(), op.getMask(), 268 newInBoundsAttr); 269 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType, 270 newRead); 271 return success(); 272 } 273 }; 274 275 void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( 276 RewritePatternSet &patterns) { 277 patterns.add<TransferReadPermutationLowering, 278 TransferWritePermutationLowering, TransferOpReduceRank>( 279 patterns.getContext()); 280 } 281