1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Vectorization transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Dialect/Linalg/Utils/Utils.h" 17 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 18 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 19 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/IR/AffineExpr.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Pass/Pass.h" 25 #include "mlir/Support/LLVM.h" 26 #include "mlir/Transforms/RegionUtils.h" 27 #include "llvm/ADT/ScopeExit.h" 28 #include "llvm/Support/Debug.h" 29 #include "llvm/Support/raw_ostream.h" 30 #include <type_traits> 31 32 using namespace mlir; 33 using namespace mlir::edsc; 34 using namespace mlir::edsc::intrinsics; 35 using namespace mlir::linalg; 36 37 using llvm::dbgs; 38 39 #define DEBUG_TYPE "linalg-vectorization" 40 41 /// Return true if the use-def chain from `v` to `from` consists of 0 or more 42 /// unary single-operand operations. 43 // TODO: relax to multi-operands with constants, which are technically unary ops 44 // as needed (e.g. add5). 
45 static bool isChainOfUnaryOpsFrom(Value v, Value from) { 46 while (v != from) { 47 Operation *op = v.getDefiningOp(); 48 if (!op || op->getNumOperands() != 1) 49 return false; 50 v = op->getOperand(0); 51 }; 52 return true; 53 } 54 55 /// Return the unique instance of OpType in `block` if it is indeed unique. 56 /// Return null if none or more than 1 instances exist. 57 template <typename OpType> 58 static OpType getSingleOpOfType(Block &block) { 59 OpType res; 60 block.walk([&](OpType op) { 61 if (res) { 62 res = nullptr; 63 return WalkResult::interrupt(); 64 } 65 res = op; 66 return WalkResult::advance(); 67 }); 68 return res; 69 } 70 71 /// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))` 72 /// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent 73 /// unary operations that may change the type. 74 template <typename AddOpType, typename MulOpType> 75 static bool isAddMul(Block &block) { 76 if (block.getNumArguments() != 3) 77 return false; 78 Operation *yieldOp = block.getTerminator(); 79 if (yieldOp->getNumOperands() != 1) 80 return false; 81 82 LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: isAddMul: "; block.dump()); 83 AddOpType addOp = getSingleOpOfType<AddOpType>(block); 84 MulOpType mulOp = getSingleOpOfType<MulOpType>(block); 85 if (!addOp || !mulOp) 86 return false; 87 88 Value argA = block.getArgument(0), argB = block.getArgument(1); 89 Value a = mulOp->getOperand(0), b = mulOp->getOperand(1); 90 Value mul = mulOp->getResult(0); 91 Value argC = block.getArgument(2); 92 Value c1 = addOp->getOperand(0), c2 = addOp->getOperand(1); 93 Value add = addOp->getResult(0); 94 Value res = yieldOp->getOperand(0); 95 // Result traces back to add. 96 auto un = isChainOfUnaryOpsFrom; 97 bool success = un(res, add); 98 // One of the operands of add traces back to argC, the other to the mul. 
99 success |= (un(c1, argC) && un(c2, mul)) || ((un(c1, mul)) && un(c2, argC)); 100 // One of the operands of mul traces back to argA, the other to argB. 101 success |= (un(a, argA) && un(b, argB)) || ((un(a, argB)) && un(b, argA)); 102 return success; 103 } 104 105 /// Helper data structure to represent the result of vectorization. 106 /// In certain specific cases, like terminators, we do not want to propagate/ 107 enum VectorizationStatus { 108 /// Op failed to vectorize. 109 Failure = 0, 110 /// Op vectorized and custom function took care of replacement logic 111 NoReplace, 112 /// Op vectorized into a new Op whose results will replace original Op's 113 /// results. 114 NewOp 115 // TODO: support values if Op vectorized to Many-Ops whose results we need to 116 // aggregate for replacement. 117 }; 118 struct VectorizationResult { 119 /// Return status from vectorizing the current op. 120 enum VectorizationStatus status = VectorizationStatus::Failure; 121 /// New vectorized operation to replace the current op. 122 /// Replacement behavior is specified by `status`. 123 Operation *newOp; 124 }; 125 126 /// Return a vector type of the same shape and element type as the (assumed) 127 /// ShapedType of `v`. 128 static VectorType extractVectorTypeFromShapedValue(Value v) { 129 auto st = v.getType().cast<ShapedType>(); 130 if (st.isa<MemRefType>() && st.getShape().empty()) 131 return VectorType(); 132 return VectorType::get(st.getShape(), st.getElementType()); 133 } 134 135 /// Build a vector.transfer_read from `source` at indices set to all `0`. 136 /// If source has rank zero, build an std.load. 137 /// Return the produced value. 
static Value buildVectorRead(OpBuilder &builder, Value source) {
  edsc::ScopedContext scope(builder);
  auto shapedType = source.getType().cast<ShapedType>();
  if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
    // Read the full shape starting at the all-zero index.
    SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
    return vector_transfer_read(vectorType, source, indices);
  }
  // Rank-0 case: no vector type was produced, read the scalar directly.
  return std_load(source);
}

/// Build a vector.transfer_write of `value` into `dest` at indices set to all
/// `0`. If `dest` has null rank, build an std.store.
/// Return the produced value or null if no value is produced.
static Value buildVectorWrite(OpBuilder &builder, Value value, Value dest) {
  edsc::ScopedContext scope(builder);
  Operation *write;
  auto shapedType = dest.getType().cast<ShapedType>();
  if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
    SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
    // Broadcast scalars (or smaller vectors) up to the destination shape
    // before writing.
    if (vectorType != value.getType())
      value = vector_broadcast(vectorType, value);
    write = vector_transfer_write(value, dest, indices);
  } else {
    // Rank-0 case: store the scalar directly.
    write = std_store(value, dest);
  }
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
  // Only some write forms (e.g. into tensors) produce a result; memref writes
  // produce none, in which case return a null Value.
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}

/// If value of assumed VectorType has a shape different than `shape`, build and
/// return a new vector.broadcast to `shape`.
/// Otherwise, just return value.
static Value broadcastIfNeeded(OpBuilder &builder, Value value,
                               ArrayRef<int64_t> shape) {
  auto vecType = value.getType().dyn_cast<VectorType>();
  // Nothing to do for scalar targets or when the shape already matches.
  if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
    return value;
  // Scalars broadcast from their own type; vectors keep their element type.
  auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
                                                   : value.getType());
  return builder.create<vector::BroadcastOp>(
      builder.getInsertionPoint()->getLoc(), newVecType, value);
}

// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the BlockAndValueMapping.
// Return nullptr if the Operation cannot be vectorized.
using CustomVectorizationHook = std::function<VectorizationResult(
    Operation *, const BlockAndValueMapping &)>;

/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `results`.
/// Return VectorizationStatus::NoReplace to signal the vectorization algorithm
/// that it should not try to map produced operations: this is the purpose of
/// the `results` argument to capture such values and make them available for
/// RAUW to the vectorization algorithm.
/// This function is meant to be used as a CustomVectorizationHook.
static VectorizationResult
vectorizeLinalgYield(OpBuilder &builder, Operation *op,
                     const BlockAndValueMapping &bvm, LinalgOp linalgOp,
                     SmallVectorImpl<Value> &results) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  // Each yielded value is written into the corresponding linalg output.
  for (auto outputs : llvm::enumerate(yieldOp.values())) {
    // TODO: Scan for an opportunity for reuse.
    // TODO: use a map.
    Value vectorValue = bvm.lookup(outputs.value());
    Value result = buildVectorWrite(builder, vectorValue,
                                    linalgOp.getOutput(outputs.index()));
    // A null result means the write produced no value (memref case).
    if (result)
      results.push_back(result);
  }
  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}

/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`. Vectorization occurs as follows:
///   1. Try to apply any of the `customVectorizationHooks` and return its
///   result on success.
///   2. Clone any constant in the current scope without vectorization: each
///   consumer of the constant will later determine the shape to which the
///   constant needs to be broadcast to.
///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
///   of the `customVectorizationHooks` to cover such cases.
///   4. Clone `op` in vector form to a vector of shape prescribed by the first
///   operand of maximal rank. Other operands have smaller rank and are
///   broadcast accordingly. It is assumed this broadcast is always legal,
///   otherwise, it means one of the `customVectorizationHooks` is incorrect.
///
/// This function assumes all operands of `op` have been vectorized and are in
/// the `bvm` mapping. As a consequence, this function is meant to be called on
/// a topologically-sorted list of ops.
/// This function does not update `bvm` but returns a VectorizationStatus that
/// instructs the caller what `bvm` update needs to occur.
static VectorizationResult
vectorizeOneOp(OpBuilder &builder, Operation *op,
               const BlockAndValueMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);

  // 1. Try to apply any CustomVectorizationHook; the first hook that does not
  // report Failure wins.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationResult result = customFunc(op, bvm);
      if (result.status == VectorizationStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcasted at their
  // users. Clone so that the constant is not confined to the linalgOp block.
  if (isa<ConstantOp>(op))
    return VectorizationResult{VectorizationStatus::NewOp, builder.clone(*op)};

  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
  if (!op->hasTrait<OpTrait::ElementwiseMappable>())
    return VectorizationResult{VectorizationStatus::Failure, nullptr};

  // 4. Generic vectorization path for ElementwiseMappable ops.
  //   a. First get the first max ranked shape among the vectorized operands.
  SmallVector<int64_t, 4> firstMaxRankedShape;
  for (Value operand : op->getOperands()) {
    auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
    if (vt && firstMaxRankedShape.size() < vt.getShape().size())
      firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
  }
  //   b. Broadcast each operand to that shape if needed (no-op if all operands
  //   are scalar, i.e. the shape is empty).
  auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
    return firstMaxRankedShape.empty()
               ? bvm.lookup(v)
               : broadcastIfNeeded(builder, bvm.lookup(v), firstMaxRankedShape);
  });
  //   c. For elementwise, the result is the vector with the firstMaxRankedShape
  auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
    return firstMaxRankedShape.empty()
               ? t
               : VectorType::get(firstMaxRankedShape, t);
  });

  // Build and return the new op, reusing the original name and attributes.
  OperationState state(op->getLoc(), op->getName());
  state.addAttributes(op->getAttrs());
  state.addOperands(llvm::to_vector<4>(vectorizedOperands));
  state.addTypes(llvm::to_vector<4>(returnTypes));
  return VectorizationResult{VectorizationStatus::NewOp,
                             builder.createOperation(state)};
}

/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form. Generic vectorization proceeds as follows:
///   1. The region for the linalg op is created if necessary.
///   2. Values defined above the region are mapped to themselves and will be
///   broadcasted on a per-need basis by their consumers.
///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
///   load).
/// TODO: Reuse opportunities for RAR dependencies.
///   4. Register CustomVectorizationHook for YieldOp to capture the results.
///   5. Iteratively call vectorizeOneOp on the region operations.
///   6. RAUW the linalg op by the results captured vectorizing the YieldOp.
static LogicalResult vectorizeAsLinalgGeneric(
    OpBuilder &builder, LinalgOp linalgOp,
    ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
  // 1. Certain Linalg ops do not have a region but only a region builder.
  // If so, build the region so we can vectorize.
  std::unique_ptr<Region> owningRegion;
  Region *region;
  if (linalgOp->getNumRegions() > 0) {
    region = &linalgOp->getRegion(0);
  } else {
    // RAII avoid remaining in block.
    OpBuilder::InsertionGuard g(builder);
    // The temporary region is owned locally; it only exists to materialize the
    // op's body for traversal below.
    owningRegion = std::make_unique<Region>();
    region = owningRegion.get();
    Block *block = builder.createBlock(region);
    auto elementTypes = llvm::to_vector<4>(
        llvm::map_range(linalgOp.getShapedOperandTypes(),
                        [](ShapedType t) { return t.getElementType(); }));
    block->addArguments(elementTypes);
    linalgOp.getRegionBuilder()(*block);
  }
  Block *block = &region->front();

  BlockAndValueMapping bvm;
  // 2. Values defined above the region can only be broadcast for now. Make them
  // map to themselves.
  llvm::SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(*region, valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  // 3. Turn all BBArgs into vector.transfer_read / load.
  SmallVector<AffineMap> indexings;
  for (auto bbarg : block->getArguments()) {
    Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
    Value vectorRead = buildVectorRead(builder, vectorArg);
    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
                      << bbarg.getArgNumber() << "): " << vectorRead);
    // Map both the block argument and the shaped operand to the read so later
    // lookups succeed regardless of which one an op references.
    bvm.map(bbarg, vectorRead);
    bvm.map(vectorArg, vectorRead);
  }

  // 4. Register CustomVectorizationHook for yieldOp.
  SmallVector<Value> results;
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op,
          const BlockAndValueMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgYield(builder, op, bvm, linalgOp, results);
  };
  // Append the vectorizeYield hook after the caller-provided hooks so custom
  // hooks get the first chance at every op.
  auto hooks = llvm::to_vector<4>(customVectorizationHooks);
  hooks.push_back(vectorizeYield);

  // 5. Iteratively call `vectorizeOneOp` to each op in the slice. Block ops
  // are visited in order, so operands are vectorized before their users.
  for (Operation &op : block->getOperations()) {
    VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
    if (result.status == VectorizationStatus::Failure) {
      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
      return failure();
    }
    if (result.status == VectorizationStatus::NewOp) {
      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
                        << *result.newOp;);
      bvm.map(op.getResults(), result.newOp->getResults());
    }
  }

  // 6. RAUW the linalg op by the results captured vectorizing the YieldOp.
  if (!results.empty())
    linalgOp->replaceAllUsesWith(results);
  return success();
}

/// Detect whether the LinalgOp `op` is a contraction.
/// A Linalg contraction is defined in general terms:
///   1. Has 2 input and 1 output shapes.
///   2. Has at least one reduction dimension.
///   3. Has only projected permutation indexing maps.
///   4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
///   (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
///   operations that may change the type (e.g. for mixed-precision).
/// As a consequence, when vectorization of such an op occurs, the only special
/// behavior is that the (unique) MulOpType is vectorized into a
/// `vector.contract`. All other ops are handled in a generic fashion.
/// In the future, we may wish to allow more input arguments and elementwise and
/// constant operations that do not involve the reduction dimension(s).
static LogicalResult isContraction(Operation *op) {
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: isContraction: "; op->dump());
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return failure();

  auto mapRange = linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>();
  return success(
      linalgOp.getNumInputs() == 2 && linalgOp.getNumOutputs() == 1 &&
      linalgOp.getNumReductionLoops() > 0 &&
      llvm::all_of(mapRange,
                   [](AffineMap m) { return m.isProjectedPermutation(); }) &&
      // TODO: more fields than add/mul.
      (isAddMul<AddFOp, MulFOp>(linalgOp->getRegion(0).front()) ||
       isAddMul<AddIOp, MulIOp>(linalgOp->getRegion(0).front())));
}

/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
/// All op results must additionally be scalar (int, index or float).
static bool hasOnlyScalarElementwiseOp(Region &r) {
  // Only single-block regions are supported.
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
          op.hasTrait<OpTrait::ElementwiseMappable>()) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

// Return true if the op is an element-wise linalg op: a linalg.generic with
// only parallel loops, identity/minor-identity indexing maps and a
// scalar-elementwise body.
static bool isElementwise(Operation *op) {
  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp)
    return false;
  // Reductions disqualify: every loop must be parallel.
  if (genericOp.getNumLoops() != genericOp.getNumParallelLoops())
    return false;
  // TODO: relax the restrictions on indexing map.
  for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) {
    if (!genericOp.getOutputIndexingMap(i).isIdentity())
      return false;
  }
  // Currently bound the input indexing map to minor identity as other
  // permutations might require adding transpose ops to convert the vector read
  // to the right shape.
  for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) {
    if (!genericOp.getInputIndexingMap(i).isMinorIdentity())
      return false;
  }
  return hasOnlyScalarElementwiseOp(genericOp.getRegion());
}

/// Check whether `op` can be vectorized by vectorizeLinalgOp: static shapes
/// and one of fill/copy, elementwise, or contraction.
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  // All types must be static shape to go to vector.
  for (Value operand : linalgOp.getShapedOperands())
    if (!operand.getType().cast<ShapedType>().hasStaticShape())
      return failure();
  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
      return failure();

  if (isa<linalg::FillOp, linalg::CopyOp>(op))
    return success();
  if (isElementwise(op))
    return success();
  return isContraction(op);
}

/// Rewrite `op` (which must satisfy vectorizeLinalgOpPrecondition) in vector
/// form: fill becomes a broadcast+write, copy becomes a transfer_read +
/// transfer_write, elementwise ops go through the generic path, and
/// contractions are lowered via a custom hook to vector.contract.
void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
  assert(succeeded(vectorizeLinalgOpPrecondition(op)));

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  edsc::ScopedContext scope(builder, op->getLoc());
  // In the case of 0-D memrefs, the read/write builders return null and
  // special case to scalar load or store later.
  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
    // Vectorize fill as a vector.broadcast.
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
    buildVectorWrite(builder, fillOp.value(), fillOp.output());
    return;
  }
  if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
    // Vectorize copy as a vector.transfer_read+vector.transfer_write.
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.copy as vector.transfer_read + "
                         "vector.transfer_write: "
                      << *op);
    Value vector = buildVectorRead(builder, copyOp.input());
    buildVectorWrite(builder, vector, copyOp.output());
    return;
  }

  auto linalgOp = cast<linalg::LinalgOp>(op);
  Location loc = linalgOp.getLoc();

  if (isElementwise(op)) {
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg op as vector.transfer_read + " << *op);
    auto status = vectorizeAsLinalgGeneric(builder, linalgOp);
    (void)status;
    assert(succeeded(status) &&
           "Unexpected vectorization failed despite preconditions");
    return;
  }

  // The precondition guarantees the only remaining case is a contraction.
  assert(succeeded(isContraction(op)) && "Expected contraction");

  // Vectorize other ops as vector contraction.
  // TODO: interface.
  LLVM_DEBUG(dbgs() << dbgPref
                    << "Rewrite linalg op as vector.contract: " << *op);
  // Special function that describes how to vectorize the multiplication op in a
  // linalg contraction: the (unique) mul becomes a vector.contract with a zero
  // accumulator; the surrounding add/unary ops go through the generic path.
  CustomVectorizationHook vectorizeContraction =
      [&](Operation *op,
          const BlockAndValueMapping &bvm) -> VectorizationResult {
    if (!isa<MulIOp, MulFOp>(op))
      return VectorizationResult{VectorizationStatus::Failure, nullptr};
    auto outShape = linalgOp.getOutputShapedType(0).getShape();
    auto vType = outShape.empty()
                     ? op->getResult(0).getType()
                     : VectorType::get(outShape, op->getResult(0).getType());
    auto zero =
        builder.create<ConstantOp>(loc, vType, builder.getZeroAttr(vType));
    Operation *contract = builder.create<vector::ContractionOp>(
        loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
        linalgOp.indexing_maps(), linalgOp.iterator_types());
    return VectorizationResult{VectorizationStatus::NewOp, contract};
  };
  auto status =
      vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
  (void)status;
  assert(succeeded(status) &&
         "Unexpected vectorization failed despite preconditions");
}

//----------------------------------------------------------------------------//
// Misc. conv vectorization patterns.
//----------------------------------------------------------------------------//
// TODO: cleanup all this.
template <class ConvOp, int N>
LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
    ConvOp op, PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MLIRContext *context = op.getContext();
  edsc::ScopedContext scope(rewriter, loc);

  ShapedType inShapeType = op.getInputShapedType(0);
  ShapedType kShapeType = op.getInputShapedType(1);

  ArrayRef<int64_t> inShape = inShapeType.getShape();
  ArrayRef<int64_t> kShape = kShapeType.getShape();

  if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
    return failure();

  SmallVector<AffineExpr, 4> mapping;
  SmallVector<int64_t, 4> vectorDims;
  // Fail to apply when the size of not vectorized dimension is not 1.
  for (unsigned i = 0; i < N; i++) {
    if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
      return failure();

    // Vectorized dimensions must agree between input and kernel.
    if (mask[i] && inShape[i] != kShape[i])
      return failure();

    if (mask[i]) {
      mapping.push_back(getAffineDimExpr(i, context));
      vectorDims.push_back(inShape[i]);
    }
  }

  Value input = op.getInput(0);
  Value kernel = op.getInput(1);
  Value output = op.getOutputBuffer(0);

  unsigned rank = inShapeType.getRank();
  unsigned numDims = mapping.size();
  Type elemType = inShapeType.getElementType();

  // `map` projects the full-rank reads onto the vectorized dimensions only.
  auto map = AffineMap::get(rank, 0, mapping, context);
  SmallVector<Value, 4> zeros(rank, std_constant_index(0));
  auto vecType = VectorType::get(vectorDims, elemType);

  auto inputVec = vector_transfer_read(vecType, input, zeros, map);
  auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);

  auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType));

  // Full reduction over all vectorized dims: identity maps for the two inputs
  // and a scalar (rank-0) result map.
  std::array<AffineMap, 3> indexingMaps{
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::get(numDims, 0, {}, context)};

  std::vector<StringRef> iteratorTypes(numDims, "reduction");

  auto result = rewriter.create<vector::ContractionOp>(
      loc, inputVec, kernelVec, acc,
      rewriter.getAffineMapArrayAttr(indexingMaps),
      rewriter.getStrArrayAttr(iteratorTypes));

  rewriter.create<StoreOp>(loc, result, output, ValueRange(zeros));
  rewriter.eraseOp(op);
  return success();
}

using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;

/// Inserts tiling, promotion and vectorization pattern for ConvOp
/// conversion into corresponding pattern lists.
template <typename ConvOp, unsigned N>
static void
populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
                              OwningRewritePatternList &promotionPatterns,
                              OwningRewritePatternList &vectorizationPatterns,
                              ArrayRef<int64_t> tileSizes,
                              MLIRContext *context) {
  // Not enough tile sizes for this conv rank: nothing to populate.
  if (tileSizes.size() < N)
    return;

  // Markers chain the three stages: tiling tags TILED, promotion consumes
  // TILED and tags PROMOTED.
  constexpr static StringRef kTiledMarker = "TILED";
  constexpr static StringRef kPromotedMarker = "PROMOTED";
  tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
      context, LinalgTilingOptions().setTileSizes(tileSizes),
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get(kTiledMarker, context)));

  promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>(
      context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
                                 Identifier::get(kPromotedMarker, context)));

  // Vectorize the trailing N dimensions whose tile size is > 1; a dimension
  // tiled by 1 is kept scalar.
  SmallVector<bool, 4> mask(N);
  int offset = tileSizes.size() - N;
  std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
                 [](int64_t i) -> bool { return i > 1; });

  vectorizationPatterns.insert<ConvOpVectorization<ConvOp, N>>(context, mask);
}

void mlir::linalg::populateConvVectorizationPatterns(
    MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
    ArrayRef<int64_t> tileSizes) {
  OwningRewritePatternList tiling, promotion, vectorization;
  populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
                                            tileSizes, context);

  populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
                                             tileSizes, context);

  populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
                                               tileSizes, context);

  populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
                                               tileSizes, context);

  populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvNDHWCOp, 5>(
      tiling, promotion, vectorization, tileSizes, context);

  populateVectorizationPatterns<ConvNCDHWOp, 5>(
      tiling, promotion, vectorization, tileSizes, context);

  // The caller applies these three lists as successive stages.
  patterns.push_back(std::move(tiling));
  patterns.push_back(std::move(promotion));
  patterns.push_back(std::move(vectorization));
}

//----------------------------------------------------------------------------//
// Forwarding patterns
//----------------------------------------------------------------------------//

/// Check whether there is any interleaved use of any `values` between `firstOp`
/// and `secondOp`. Conservatively return `true` if any op or value is in a
/// different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  // Only the same-block, firstOp-before-secondOp case is analyzable; anything
  // else is conservatively treated as interleaved.
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LLVM_DEBUG(llvm::dbgs()
               << dbgPref << "interleavedUses precondition failed, firstOp: "
               << *firstOp << ", second op: " << *secondOp);
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      // Uses strictly before firstOp or strictly after secondOp are fine.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
        continue;
      LLVM_DEBUG(llvm::dbgs()
                 << dbgPref << " found interleaved op " << *owner
                 << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
      return true;
    }
  }
  return false;
}

/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
static SubViewOp getSubViewUseIfUnique(Value v) {
  SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
      // A second subview use makes the result ambiguous: give up.
      if (subViewOp)
        return SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {

  // Transfer into `view`. Only locally created buffers (view/alloc) are safe
  // to forward through.
  Value viewOrAlloc = xferOp.source();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: ";
  (void)dbgPref;
  LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc);

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView);

  // Find the copy into `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      // The copy must write into the subview, not merely read it.
      if (newCopyOp.getOutputBuffer(0) != subView)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp);
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp);

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp);
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure padding matches: the fill value stands in for the transfer's
  // out-of-bounds padding.
  if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
    return failure();
  if (maybeFillOp)
    LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp);

  // `in` is the subview that linalg.copy reads. Replace it.
  Value in = copyOp.getInput(0);

  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_read, the attribute must be reset
  // conservatively.
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
      xferOp.permutation_map(), xferOp.padding(), ArrayAttr());

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // Transfer into `viewOrAlloc`. Only locally created buffers (view/alloc)
  // are safe to forward through.
  Value viewOrAlloc = xferOp.source();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      // The copy must read from the subview (write direction is the mirror of
      // the VTR pattern above).
      if (newCopyOp.getInput(0) != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  // `out` is the subview copied into that we replace.
  Value out = copyOp.getOutputBuffer(0);

  // Forward vector.transfer into copy.
  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_write, the attribute must be reset
  // conservatively.
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
      xferOp.permutation_map(), ArrayAttr());

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}