//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using llvm::dbgs;

#define DEBUG_TYPE "linalg-vectorization"

/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than one instance exists.
template <typename OpType> static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}

/// Helper data structure to represent the result of vectorization.
/// In certain specific cases, like terminators, we do not want to propagate
/// the vectorized results automatically; the status below encodes how the
/// replacement should be handled.
enum VectorizationStatus {
  /// Op failed to vectorize.
  Failure = 0,
  /// Op vectorized and custom function took care of replacement logic.
  NoReplace,
  /// Op vectorized into a new Op whose results will replace original Op's
  /// results.
  NewOp
  // TODO: support values if Op vectorized to Many-Ops whose results we need to
  // aggregate for replacement.
};
struct VectorizationResult {
  /// Return status from vectorizing the current op.
  enum VectorizationStatus status = VectorizationStatus::Failure;
  /// New vectorized operation to replace the current op.
  /// Replacement behavior is specified by `status`.
  Operation *newOp;
};

/// Return a vector type of the same shape and element type as the (assumed)
/// ShapedType of `v`.
static VectorType extractVectorTypeFromShapedValue(Value v) {
  auto st = v.getType().cast<ShapedType>();
  if (st.isa<MemRefType>() && st.getShape().empty())
    return VectorType();
  return VectorType::get(st.getShape(), st.getElementType());
}

/// Build a vector.transfer_read from `source` at indices set to all `0`.
/// If `source` has rank zero, build a memref.load.
/// Return the produced value.
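/// Illustrative example (hypothetical shapes, elided operands; a sketch for
/// exposition only, not IR guaranteed by this file): for a `source` of type
/// memref<4x8xf32> and a matching `vectorType`, this builds roughly
///   %c0 = constant 0 : index
///   %0 = vector.transfer_read %source[%c0, %c0], ... :
///          memref<4x8xf32>, vector<4x8xf32>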
static Value buildVectorRead(OpBuilder &builder, Value source,
                             VectorType vectorType, AffineMap map) {
  edsc::ScopedContext scope(builder);
  auto shapedType = source.getType().cast<ShapedType>();
  if (vectorType) {
    SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
    if (map)
      return vector_transfer_read(vectorType, source, indices, map);
    return vector_transfer_read(vectorType, source, indices);
  }
  return memref_load(source);
}

/// Build a vector.transfer_write of `value` into `dest` at indices set to all
/// `0`. If `dest` has rank zero, build a memref.store.
/// Return the produced value or null if no value is produced.
static Value buildVectorWrite(OpBuilder &builder, Value value, Value dest) {
  edsc::ScopedContext scope(builder);
  Operation *write;
  auto shapedType = dest.getType().cast<ShapedType>();
  if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
    SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
    if (vectorType != value.getType())
      value = vector_broadcast(vectorType, value);
    write = vector_transfer_write(value, dest, indices);
  } else {
    write = memref_store(value, dest);
  }
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}

/// If `value` of assumed VectorType has a shape different than `shape`, build
/// and return a new vector.broadcast to `shape`.
/// Otherwise, just return `value`.
static Value broadcastIfNeeded(OpBuilder &builder, Value value,
                               ArrayRef<int64_t> shape) {
  auto vecType = value.getType().dyn_cast<VectorType>();
  if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
    return value;
  auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
                                                   : value.getType());
  return builder.create<vector::BroadcastOp>(
      builder.getInsertionPoint()->getLoc(), newVecType, value);
}

// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the BlockAndValueMapping.
// Return nullptr if the Operation cannot be vectorized.
using CustomVectorizationHook = std::function<VectorizationResult(
    Operation *, const BlockAndValueMapping &)>;

/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`. Return
/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
/// should not try to map produced operations and instead return the results
/// using the `newResults` vector, making them available to the
/// vectorization algorithm for RAUW. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult
vectorizeLinalgYield(OpBuilder &builder, Operation *op,
                     const BlockAndValueMapping &bvm, LinalgOp linalgOp,
                     SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  for (auto outputs : llvm::enumerate(yieldOp.values())) {
    // TODO: Scan for an opportunity for reuse.
    // TODO: use a map.
    Value vectorValue = bvm.lookup(outputs.value());
    Value newResult = buildVectorWrite(builder, vectorValue,
                                       linalgOp.getOutput(outputs.index()));
    if (newResult)
      newResults.push_back(newResult);
  }
  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}

/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`. Vectorization occurs as follows:
///   1. Try to apply any of the `customVectorizationHooks` and return its
///      result on success.
///   2. Clone any constant in the current scope without vectorization: each
///      consumer of the constant will later determine the shape to which the
///      constant needs to be broadcast.
///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
///      of the `customVectorizationHooks` to cover such cases.
///   4. Clone `op` in vector form to a vector of shape prescribed by the first
///      operand of maximal rank. Other operands have smaller rank and are
///      broadcast accordingly. It is assumed this broadcast is always legal;
///      otherwise, it means one of the `customVectorizationHooks` is incorrect.
///
/// This function assumes all operands of `op` have been vectorized and are in
/// the `bvm` mapping. As a consequence, this function is meant to be called on
/// a topologically-sorted list of ops.
/// This function does not update `bvm` but returns a VectorizationStatus that
/// instructs the caller what `bvm` update needs to occur.
static VectorizationResult
vectorizeOneOp(OpBuilder &builder, Operation *op,
               const BlockAndValueMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationResult result = customFunc(op, bvm);
      if (result.status == VectorizationStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcast at their users.
  // Clone so that the constant is not confined to the linalgOp block.
  if (isa<ConstantOp>(op))
    return VectorizationResult{VectorizationStatus::NewOp, builder.clone(*op)};

  // 3. Only ElementwiseMappable ops are allowed in the generic vectorization.
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return VectorizationResult{VectorizationStatus::Failure, nullptr};

  // 4. Generic vectorization path for ElementwiseMappable ops.
  //   a. Get the maximal-rank shape among the vectorized operands.
  SmallVector<int64_t, 4> firstMaxRankedShape;
  for (Value operand : op->getOperands()) {
    auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
    if (vt && firstMaxRankedShape.size() < vt.getShape().size())
      firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
  }
  //   b. Broadcast each operand if needed.
  auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
    return firstMaxRankedShape.empty()
               ? bvm.lookup(v)
               : broadcastIfNeeded(builder, bvm.lookup(v), firstMaxRankedShape);
  });
  //   c. For elementwise ops, the result is a vector of firstMaxRankedShape.
  auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
    return firstMaxRankedShape.empty()
               ? t
               : VectorType::get(firstMaxRankedShape, t);
  });

  // Build and return the new op.
  OperationState state(op->getLoc(), op->getName());
  state.addAttributes(op->getAttrs());
  state.addOperands(llvm::to_vector<4>(vectorizedOperands));
  state.addTypes(llvm::to_vector<4>(returnTypes));
  return VectorizationResult{VectorizationStatus::NewOp,
                             builder.createOperation(state)};
}

/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
static bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

// Return true if the op is an element-wise linalg op.
static bool isElementwise(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return false;
  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
    return false;
  // TODO: relax the restrictions on indexing map.
  for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
    if (!linalgOp.getOutputIndexingMap(i).isIdentity())
      return false;
  }
  if (linalgOp->getNumRegions() != 1)
    return false;
  return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
}

// Calculate the map to apply to transfer_read to convert the input shape into
// the output shape.
static AffineMap getTransferReadMap(LinalgOp linalgOp, unsigned argIndex) {
  AffineMap linalgMap = linalgOp.getIndexingMap(argIndex);
  MLIRContext *context = linalgMap.getContext();
  AffineExpr zero = mlir::getAffineConstantExpr(0, context);
  SmallVector<AffineExpr, 4> exprs(linalgMap.getNumInputs(), zero);
  for (unsigned i : llvm::seq(unsigned(0), linalgMap.getNumResults())) {
    exprs[linalgMap.getDimPosition(i)] = getAffineDimExpr(i, context);
  }
  return AffineMap::get(linalgMap.getNumResults(), /*symbolCount=*/0, exprs,
                        context);
}

/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form. Generic vectorization proceeds as follows:
///   1. Verify the `linalgOp` has one non-empty region.
///   2. Values defined above the region are mapped to themselves and will be
///      broadcast on a per-need basis by their consumers.
///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
///      load).
///      TODO: Reuse opportunities for RAR dependencies.
///   4. Register CustomVectorizationHook for YieldOp to capture the results.
///   5. Iteratively call vectorizeOneOp on the region operations.
LogicalResult vectorizeAsLinalgGeneric(
    OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
    ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
  // 1. Fail to vectorize if the operation does not have one non-empty region.
  if (linalgOp->getNumRegions() != 1 || linalgOp->getRegion(0).empty())
    return failure();
  auto &block = linalgOp->getRegion(0).front();

  BlockAndValueMapping bvm;
  // 2. Values defined above the region can only be broadcast for now. Make
  // them map to themselves.
  SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  // 3. Turn all BBArgs into vector.transfer_read / load.
  SmallVector<AffineMap> indexings;
  for (auto bbarg : block.getArguments()) {
    Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
    AffineMap map;
    VectorType vectorType = extractVectorTypeFromShapedValue(vectorArg);
    if (isElementwise(linalgOp) &&
        !linalgOp.getIndexingMap(bbarg.getArgNumber()).isMinorIdentity()) {
      // We currently do not support output permutations.
      assert(linalgOp.getNumOutputs() > 0 &&
             linalgOp.getOutputIndexingMap(0).isIdentity());
      ArrayRef<int64_t> outputShape =
          linalgOp.getOutputShapedType(0).getShape();
      vectorType = VectorType::get(outputShape, vectorType.getElementType());
      map = getTransferReadMap(linalgOp, bbarg.getArgNumber());
    }
    Value vectorRead = buildVectorRead(builder, vectorArg, vectorType, map);
    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
                      << bbarg.getArgNumber() << "): " << vectorRead);
    bvm.map(bbarg, vectorRead);
    bvm.map(vectorArg, vectorRead);
  }

  // 4. Register CustomVectorizationHook for yieldOp.
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op,
          const BlockAndValueMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgYield(builder, op, bvm, linalgOp, newResults);
  };
  // Append the vectorizeYield hook.
  auto hooks = llvm::to_vector<4>(customVectorizationHooks);
  hooks.push_back(vectorizeYield);

  // 5. Iteratively call `vectorizeOneOp` on each op in the slice.
  for (Operation &op : block.getOperations()) {
    VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
    if (result.status == VectorizationStatus::Failure) {
      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
      return failure();
    }
    if (result.status == VectorizationStatus::NewOp) {
      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
                        << *result.newOp;);
      bvm.map(op.getResults(), result.newOp->getResults());
    }
  }

  return success();
}

static LogicalResult vectorizeContraction(OpBuilder &builder,
                                          LinalgOp linalgOp,
                                          SmallVectorImpl<Value> &newResults) {
  assert(isaContractionOpInterface(linalgOp) &&
         "expected vectorizeContraction preconditions to be met");
  Location loc = linalgOp.getLoc();
  // Vectorize other ops as vector contraction.
  // TODO: interface.
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
                    << "Rewrite linalg op as vector.contract: ";
             linalgOp.dump());
  // Special function that describes how to vectorize the multiplication op in
  // a linalg contraction.
  CustomVectorizationHook vectorizeContraction =
      [&](Operation *op,
          const BlockAndValueMapping &bvm) -> VectorizationResult {
    if (!isa<MulIOp, MulFOp>(op))
      return VectorizationResult{VectorizationStatus::Failure, nullptr};
    auto outShape = linalgOp.getOutputShapedType(0).getShape();
    auto vType = outShape.empty()
                     ? op->getResult(0).getType()
                     : VectorType::get(outShape, op->getResult(0).getType());
    auto zero =
        builder.create<ConstantOp>(loc, vType, builder.getZeroAttr(vType));
    Operation *contract = builder.create<vector::ContractionOp>(
        loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
        linalgOp.indexing_maps(), linalgOp.iterator_types());
    return VectorizationResult{VectorizationStatus::NewOp, contract};
  };
  return vectorizeAsLinalgGeneric(builder, linalgOp, newResults,
                                  {vectorizeContraction});
}

LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  // All types must have static shape to go to vector.
  for (Value operand : linalgOp.getShapedOperands())
    if (!operand.getType().cast<ShapedType>().hasStaticShape())
      return failure();
  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
      return failure();
  // TODO: remove once index ops are supported.
  if (linalgOp.hasIndexSemantics())
    return failure();
  if (isElementwise(op))
    return success();
  return success(isaContractionOpInterface(linalgOp));
}

LogicalResult
mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op,
                                SmallVectorImpl<Value> &newResults) {
  if (failed(vectorizeLinalgOpPrecondition(op)))
    return failure();

  edsc::ScopedContext scope(builder, op->getLoc());
  if (isElementwise(op)) {
    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
                      << "Vectorize linalg op as a generic: " << *op);
    return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op), newResults);
  }

  return vectorizeContraction(builder, cast<LinalgOp>(op), newResults);
}

//----------------------------------------------------------------------------//
// Misc. vectorization patterns.
//----------------------------------------------------------------------------//

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
/// TransferWriteOp. For now, this only applies when all low and high paddings
/// are determined to be zero.
LogicalResult PadTensorOpVectorizationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
  // Helper function to determine whether an OpFoldResult is not a zero Index.
  auto isNotZeroIndex = [](OpFoldResult ofr) {
    if (Attribute attr = ofr.dyn_cast<Attribute>())
      return attr.cast<IntegerAttr>().getInt() != 0;
    Value v = ofr.get<Value>();
    if (auto constOp = v.getDefiningOp<ConstantOp>())
      if (auto intAttr = constOp.getValue().dyn_cast<IntegerAttr>())
        return intAttr.getValue().getSExtValue() != 0;
    return true;
  };

  auto resultShapedType = padOp.result().getType().cast<ShapedType>();
  // Bail out on non-static shapes.
  if (!resultShapedType.hasStaticShape())
    return failure();

  // If any pad_low is not a static 0, a mask is needed. Bail out for now.
  if (llvm::any_of(padOp.getMixedLowPad(), isNotZeroIndex))
    return failure();
  VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
  if (!vectorType)
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // TODO: if any pad_high is not a static 0, a mask is needed. For now, just
  // bail out.
  if (llvm::any_of(padOp.getMixedHighPad(),
                   [&](OpFoldResult ofr) { return isNotZeroIndex(ofr); }))
    return failure();

  // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
  // TransferWriteOp@[0..0].
  SmallVector<Value> indices(
      resultShapedType.getRank(),
      rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
  Value read = rewriter.create<vector::TransferReadOp>(
      padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
  Value init =
      rewriter.create<InitTensorOp>(padOp.getLoc(), resultShapedType.getShape(),
                                    resultShapedType.getElementType());
  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
                                                       indices);

  return success();
}

// TODO: cleanup all the convolution vectorization patterns.
template <class ConvOp, int N>
LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
    ConvOp op, PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MLIRContext *context = op.getContext();
  edsc::ScopedContext scope(rewriter, loc);

  ShapedType inShapeType = op.getInputShapedType(0);
  ShapedType kShapeType = op.getInputShapedType(1);

  ArrayRef<int64_t> inShape = inShapeType.getShape();
  ArrayRef<int64_t> kShape = kShapeType.getShape();

  if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
    return failure();

  SmallVector<AffineExpr, 4> mapping;
  SmallVector<int64_t, 4> vectorDims;
  // Fail to apply when the size of a non-vectorized dimension is not 1.
  for (unsigned i = 0; i < N; i++) {
    if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
      return failure();

    if (mask[i] && inShape[i] != kShape[i])
      return failure();

    if (mask[i]) {
      mapping.push_back(getAffineDimExpr(i, context));
      vectorDims.push_back(inShape[i]);
    }
  }

  Value input = op.getInput(0);
  Value kernel = op.getInput(1);
  Value output = op.getOutputBuffer(0);

  unsigned rank = inShapeType.getRank();
  unsigned numDims = mapping.size();
  Type elemType = inShapeType.getElementType();

  auto map = AffineMap::get(rank, 0, mapping, context);
  SmallVector<Value, 4> zeros(rank, std_constant_index(0));
  auto vecType = VectorType::get(vectorDims, elemType);

  auto inputVec = vector_transfer_read(vecType, input, zeros, map);
  auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);

  auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType));

  std::array<AffineMap, 3> indexingMaps{
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::get(numDims, 0, {}, context)};

  std::vector<StringRef> iteratorTypes(numDims, "reduction");

  auto result = rewriter.create<vector::ContractionOp>(
      loc, inputVec, kernelVec, acc,
      rewriter.getAffineMapArrayAttr(indexingMaps),
      rewriter.getStrArrayAttr(iteratorTypes));

  rewriter.create<memref::StoreOp>(loc, result, output, ValueRange(zeros));
  rewriter.eraseOp(op);
  return success();
}

using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;

/// Inserts tiling, promotion and vectorization patterns for ConvOp
/// conversion into the corresponding pattern lists.
template <typename ConvOp, unsigned N>
static void populateVectorizationPatterns(
    RewritePatternSet &tilingPatterns, RewritePatternSet &promotionPatterns,
    RewritePatternSet &vectorizationPatterns, ArrayRef<int64_t> tileSizes) {
  auto *context = tilingPatterns.getContext();
  if (tileSizes.size() < N)
    return;

  constexpr static StringRef kTiledMarker = "TILED";
  constexpr static StringRef kPromotedMarker = "PROMOTED";
  tilingPatterns.add<LinalgTilingPattern<ConvOp>>(
      context, LinalgTilingOptions().setTileSizes(tileSizes),
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get(kTiledMarker, context)));

  promotionPatterns.add<LinalgPromotionPattern<ConvOp>>(
      context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
                                 Identifier::get(kPromotedMarker, context)));

  SmallVector<bool, 4> mask(N);
  int offset = tileSizes.size() - N;
  std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
                 [](int64_t i) -> bool { return i > 1; });

  vectorizationPatterns.add<ConvOpVectorization<ConvOp, N>>(context, mask);
}

void mlir::linalg::populateConvVectorizationPatterns(
    MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
    ArrayRef<int64_t> tileSizes) {
  RewritePatternSet tiling(context);
  RewritePatternSet promotion(context);
  RewritePatternSet vectorization(context);
  populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
                                            tileSizes);

  populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
                                              tileSizes);
  populateVectorizationPatterns<ConvInputNWCFilterWCFOp, 3>(
      tiling, promotion, vectorization, tileSizes);

  populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes);
  populateVectorizationPatterns<ConvInputNCWFilterWCFOp, 3>(
      tiling, promotion, vectorization, tileSizes);

  populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
                                             tileSizes);

  populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
                                               tileSizes);
  populateVectorizationPatterns<ConvInputNHWCFilterHWCFOp, 4>(
      tiling, promotion, vectorization, tileSizes);

  populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
                                               tileSizes);
  populateVectorizationPatterns<ConvInputNCHWFilterHWCFOp, 4>(
      tiling, promotion, vectorization, tileSizes);

  populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes);

  populateVectorizationPatterns<ConvNDHWCOp, 5>(tiling, promotion,
                                                vectorization, tileSizes);
  populateVectorizationPatterns<ConvInputNDHWCFilterDHWCFOp, 5>(
      tiling, promotion, vectorization, tileSizes);

  populateVectorizationPatterns<ConvNCDHWOp, 5>(tiling, promotion,
                                                vectorization, tileSizes);
  populateVectorizationPatterns<ConvInputNCDHWFilterDHWCFOp, 5>(
      tiling, promotion, vectorization, tileSizes);

  patterns.push_back(std::move(tiling));
  patterns.push_back(std::move(promotion));
  patterns.push_back(std::move(vectorization));
}

//----------------------------------------------------------------------------//
// Forwarding patterns
//----------------------------------------------------------------------------//

/// Check whether there is any interleaved use of any `values` between `firstOp`
/// and `secondOp`. Conservatively return `true` if any op or value is in a
/// different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                            << "interleavedUses precondition failed, firstOp: "
                            << *firstOp << ", second op: " << *secondOp);
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
        continue;
      LLVM_DEBUG(llvm::dbgs()
                 << "\n[" DEBUG_TYPE "]: "
                 << " found interleaved op " << *owner
                 << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
      return true;
    }
  }
  return false;
}

/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return memref::SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {

  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.source();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return failure();

  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " << viewOrAlloc);

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();
  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                          << "with subView " << subView);

  // Find the copy into `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getOutputBuffer(0) != subView)
        continue;
      LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                              << "copy candidate " << *newCopyOp);
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();
  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                          << "with copy " << *copyOp);

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
        continue;
      LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                              << "fill candidate " << *newFillOp);
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure padding matches.
  if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
    return failure();
  if (maybeFillOp)
    LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                            << "with maybeFillOp " << *maybeFillOp);

  // `in` is the subview that linalg.copy reads. Replace it.
  Value in = copyOp.getInput(0);

  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_read, the attribute must be reset
  // conservatively.
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
      xferOp.permutation_map(), xferOp.padding(), ArrayAttr());

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.source();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return failure();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getInput(0) != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  // `out` is the subview that the copy writes into; forward the transfer to it.
  Value out = copyOp.getOutputBuffer(0);

  // Forward vector.transfer into copy.
  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_write, the attribute must be reset
  // conservatively.
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
      xferOp.permutation_map(), ArrayAttr());

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}
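
// Usage sketch (illustrative only; the wrapper function below is hypothetical
// and not part of this file): the forwarding patterns above are typically
// collected into a RewritePatternSet and applied greedily from a pass.
//
//   void applyCopyForwardingPatterns(FuncOp funcOp) {
//     RewritePatternSet patterns(funcOp.getContext());
//     patterns.add<LinalgCopyVTRForwardingPattern,
//                  LinalgCopyVTWForwardingPattern>(funcOp.getContext());
//     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
//   }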