1 //===- TensorTilingInterface.cpp - Tiling Interface models *- C++ ------*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" 10 #include "mlir/Dialect/Affine/IR/AffineOps.h" 11 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 12 #include "mlir/Dialect/Linalg/IR/Linalg.h" 13 #include "mlir/Dialect/SCF/SCF.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/Interfaces/TilingInterface.h" 16 17 using namespace mlir; 18 using namespace mlir::tensor; 19 20 namespace { 21 22 struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> { 23 24 SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const { 25 ReifiedRankedShapedTypeDims reifiedShapes; 26 ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = 27 dyn_cast<ReifyRankedShapedTypeOpInterface>(op); 28 (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes); 29 30 auto padOp = cast<PadOp>(op); 31 SmallVector<OpFoldResult> mixedSizes = getAsOpFoldResult(reifiedShapes[0]); 32 Value initTensor = b.create<linalg::InitTensorOp>( 33 op->getLoc(), mixedSizes, padOp.getResultType().getElementType()); 34 return {initTensor}; 35 } 36 37 SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const { 38 auto padOp = cast<PadOp>(op); 39 SmallVector<StringRef> iteratorTypes(padOp.getResultType().getRank(), 40 getParallelIteratorTypeName()); 41 return iteratorTypes; 42 } 43 44 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { 45 ReifiedRankedShapedTypeDims reifiedShapes; 46 ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = 47 dyn_cast<ReifyRankedShapedTypeOpInterface>(op); 48 (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes); 49 50 Location loc = op->getLoc(); 51 Value zero = b.create<arith::ConstantIndexOp>(loc, 0); 52 Value one = b.create<arith::ConstantIndexOp>(loc, 1); 53 // Initialize all the ranges to {zero, one, one}. All the `ub`s are 54 // overwritten. 55 SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one}); 56 for (const auto &ub : enumerate(reifiedShapes[0])) 57 loopRanges[ub.index()].size = ub.value(); 58 return loopRanges; 59 } 60 61 SmallVector<Operation *> 62 getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest, 63 ArrayRef<OpFoldResult> offsets, 64 ArrayRef<OpFoldResult> sizes, 65 bool /*tileDestOperands*/) const { 66 auto padOp = cast<PadOp>(op); 67 // Only constant padding value supported. 68 Value padValue = padOp.getConstantPaddingValue(); 69 if (!padValue) 70 return {}; 71 72 // Helper variables and functions for various arithmetic operations. These 73 // are used extensively for computing new offset/length and padding values. 74 Location loc = op->getLoc(); 75 AffineExpr dim0, dim1; 76 bindDims(b.getContext(), dim0, dim1); 77 // Add two integers. 78 auto addMap = AffineMap::get(2, 0, {dim0 + dim1}); 79 auto add = [&](Value v1, Value v2) { 80 return b.createOrFold<AffineApplyOp>(loc, addMap, ValueRange{v1, v2}); 81 }; 82 // Subtract two integers. 83 auto subMap = AffineMap::get(2, 0, {dim0 - dim1}); 84 auto sub = [&](Value v1, Value v2) { 85 return b.createOrFold<AffineApplyOp>(loc, subMap, ValueRange{v1, v2}); 86 }; 87 // Take the minimum of two integers. 88 auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext()); 89 auto min = [&](Value v1, Value v2) { 90 return b.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2}); 91 }; 92 // Take the maximum of two integers. 93 auto max = [&](Value v1, Value v2) { 94 return b.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2}); 95 }; 96 // Zero index-typed integer. 97 auto zero = b.create<arith::ConstantIndexOp>(loc, 0); 98 99 // Helper function for filling static/dynamic low/high padding indices 100 // vectors of PadOp. 101 auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices, 102 SmallVector<int64_t> &staticIndices) { 103 if (auto constInt = getConstantIntValue(val)) { 104 staticIndices.push_back(*constInt); 105 } else { 106 staticIndices.push_back(ShapedType::kDynamicSize); 107 dynIndices.push_back(val); 108 } 109 }; 110 111 // Compute new offsets, lengths, low padding, high padding. 112 SmallVector<OpFoldResult> newOffsets, newLengths, newStrides; 113 SmallVector<Value> newLows, newHighs; 114 SmallVector<int64_t> staticNewLows, staticNewHighs; 115 // Set to true if the original data source is not read at all. 116 bool hasZeroLen = false; 117 // Same as hasZeroLen, but for dynamic dimension sizes. This condition 118 // is true if the original data source turns out to be unused at runtime. 119 Value dynHasZeroLenCond; 120 121 int64_t rank = padOp.getSourceType().getRank(); 122 for (unsigned dim = 0; dim < rank; ++dim) { 123 auto low = 124 getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedLowPad()[dim]); 125 bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0); 126 auto high = 127 getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedHighPad()[dim]); 128 bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0); 129 auto offset = getValueOrCreateConstantIndexOp(b, loc, offsets[dim]); 130 auto length = getValueOrCreateConstantIndexOp(b, loc, sizes[dim]); 131 auto srcSize = b.createOrFold<tensor::DimOp>(loc, padOp.source(), dim); 132 133 // The new amount of low padding is `low - offset`. Except for the case 134 // where none of the low padding is read. In that case, the new amount of 135 // low padding is zero. 136 // 137 // Optimization: If low = 0, then newLow = 0. 138 Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero; 139 appendIndex(newLow, newLows, staticNewLows); 140 141 // Start reading the data from position `offset - low`. Since the original 142 // read may have started in the low padding zone, this value could be 143 // negative. Therefore, start reading from: 144 // 145 // max(offset - low, 0) 146 // 147 // The original read could also have started in the high padding zone. 148 // In that case, set the offset to the end of source tensor. The new 149 // ExtractSliceOp length will be zero in that case. (Effectively reading 150 // no data from the source.) 151 // 152 // Optimization: If low = 0, then the formula can be simplified. 153 Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize) 154 : min(offset, srcSize); 155 newOffsets.push_back(getAsOpFoldResult(newOffset)); 156 157 // The original ExtractSliceOp was reading until position `offset + 158 // length`. Therefore, the corresponding position within the source tensor 159 // is: 160 // 161 // offset + length - low 162 // 163 // In case the original ExtractSliceOp stopped reading within the low 164 // padding zone, this value can be negative. In that case, the end 165 // position of the read should be zero. (Similar to newOffset.) 166 // 167 // The original read could also have stopped in the high padding zone. 168 // In that case, set the end positition of the read should be the end of 169 // the source tensor. (Similar to newOffset.) 170 // 171 // endLoc = min(max(offset - low + length, 0), srcSize) 172 // 173 // The new ExtractSliceOp length is `endLoc - newOffset`. 174 // 175 // Optimization: If low = 0, then the formula can be simplified. 176 Value endLoc = 177 hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize) 178 : min(add(offset, length), srcSize); 179 Value newLength = sub(endLoc, newOffset); 180 newLengths.push_back(getAsOpFoldResult(newLength)); 181 182 // Check if newLength is zero. In that case, no SubTensorOp should be 183 // executed. 184 if (auto newLengthInt = getConstantIntValue(newLength)) { 185 hasZeroLen |= *newLengthInt == 0; 186 } else { 187 Value check = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 188 newLength, zero); 189 dynHasZeroLenCond = 190 dynHasZeroLenCond 191 ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond) 192 : check; 193 } 194 195 // The amount of high padding is simply the number of elements remaining, 196 // so that the result has the same length as the original ExtractSliceOp. 197 // As an optimization, if the original high padding is zero, then the new 198 // high padding must also be zero. 199 Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero; 200 appendIndex(newHigh, newHighs, staticNewHighs); 201 202 // Only unit stride supported. 203 newStrides.push_back(b.getIndexAttr(1)); 204 } 205 206 // The shape of the result can be obtained from the sizes passed in. 207 SmallVector<Value> dynDims; 208 SmallVector<int64_t> shape; 209 dispatchIndexOpFoldResults(sizes, dynDims, shape, ShapedType::kDynamicSize); 210 RankedTensorType resultType = 211 RankedTensorType::get(shape, padOp.getResultType().getElementType()); 212 213 // Insert cast to ensure that types match. (May be folded away.) 214 auto castResult = [&](Value val) -> Operation * { 215 auto castOp = b.create<tensor::CastOp>(loc, resultType, val); 216 return castOp; 217 }; 218 219 // In cases where the original data source is unused: Emit a GenerateOp and 220 // do not generate a SliceOp. (The result shape of the SliceOp would 221 // have a dimension of size 0, the semantics of which is unclear.) 222 auto createGenerateOp = [&]() { 223 // Create GenerateOp. 224 auto generateOp = b.create<tensor::GenerateOp>( 225 loc, resultType, dynDims, 226 [&](OpBuilder &builder, Location gLoc, ValueRange indices) { 227 builder.create<tensor::YieldOp>(gLoc, padValue); 228 }); 229 return castResult(generateOp); 230 }; 231 232 // Emit a SliceOp and a PadOp. Should not be used in cases where 233 // the result shape of the new SliceOp has a zero dimension. 234 auto createPadTensorOfSubTensor = [&]() { 235 // Create pad_tensor(subtensor(x)). 236 auto newSliceOp = b.create<tensor::ExtractSliceOp>( 237 loc, padOp.source(), newOffsets, newLengths, newStrides); 238 auto newPadOp = b.create<PadOp>(loc, newSliceOp, staticNewLows, 239 staticNewHighs, newLows, newHighs); 240 241 // Copy region to new PadOp. 242 BlockAndValueMapping bvm; 243 padOp.region().cloneInto(&newPadOp.getRegion(), bvm); 244 245 // Cast result and return. 246 return castResult(newPadOp); 247 }; 248 249 // Rewrite subtensor(pad_tensor(x)) into a GenerateOp it is statically known 250 // that the original data source x is not used. 251 if (hasZeroLen) 252 return {createGenerateOp()}; 253 254 // If there are dynamic dimensions: Generate an scf.if check to avoid 255 // creating SliceOps with result dimensions of size 0 at runtime. 256 if (dynHasZeroLenCond) { 257 auto result = b.create<scf::IfOp>( 258 loc, resultType, dynHasZeroLenCond, 259 /*thenBuilder=*/ 260 [&](OpBuilder &b, Location loc) { 261 b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0)); 262 }, 263 /*elseBuilder=*/ 264 [&](OpBuilder &b, Location loc) { 265 b.create<scf::YieldOp>(loc, 266 createPadTensorOfSubTensor()->getResult(0)); 267 }); 268 return {result}; 269 } 270 return {createPadTensorOfSubTensor()}; 271 } 272 }; 273 274 } // namespace 275 276 void mlir::tensor::registerTilingOpInterfaceExternalModels( 277 DialectRegistry ®istry) { 278 registry.addOpInterface<tensor::PadOp, PadOpTiling>(); 279 } 280