//===- Vectorization.cpp - Implementation of linalg Vectorization --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using llvm::dbgs;

#define DEBUG_TYPE "linalg-vectorization"

static bool hasMultiplyAddBody(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
    return false;

  using mlir::matchers::m_Val;
  auto a = m_Val(r.getArgument(0));
  auto b = m_Val(r.getArgument(1));
  auto c = m_Val(r.getArgument(2));
  // TODO: Update this detection once we have matcher support for specifying
  // that any permutation of operands matches.
  auto pattern1 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
  auto pattern2 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
  auto pattern3 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
  auto pattern4 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
  auto pattern5 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
  auto pattern6 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
  auto pattern7 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
  auto pattern8 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
  return pattern1.match(&r.front().back()) ||
         pattern2.match(&r.front().back()) ||
         pattern3.match(&r.front().back()) ||
         pattern4.match(&r.front().back()) ||
         pattern5.match(&r.front().back()) ||
         pattern6.match(&r.front().back()) ||
         pattern7.match(&r.front().back()) ||
         pattern8.match(&r.front().back());
}

// TODO: Should be Tablegen'd from a single source that generates the op itself.
static LogicalResult isContraction(Operation *op) {
  // TODO: interface for named ops.
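  // A contraction is either one of the named contraction-like ops matched
  // below, or a linalg.generic with exactly two inputs and one output whose
  // indexing maps are all projected permutations and whose body is a
  // multiply-add (see hasMultiplyAddBody above).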
  if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatvecOp,
          linalg::VecmatOp, linalg::DotOp>(op))
    return success();

  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp)
    return failure();

  auto mapRange = genericOp.indexing_maps().getAsValueRange<AffineMapAttr>();
  return success(
      genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
      llvm::all_of(mapRange,
                   [](AffineMap m) { return m.isProjectedPermutation(); }) &&
      hasMultiplyAddBody(genericOp.region()));
}

static bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
          op.hasTrait<OpTrait::ElementwiseMappable>()) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

// Return true if the op is an element-wise linalg op.
static bool isElementwise(Operation *op) {
  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp)
    return false;
  if (genericOp.getNumLoops() != genericOp.getNumParallelLoops())
    return false;
  // TODO: relax the restrictions on indexing map.
  for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) {
    if (!genericOp.getOutputIndexingMap(i).isIdentity())
      return false;
  }
  // Currently limit the input indexing maps to minor identity, as other
  // permutations might require adding transpose ops to convert the vector read
  // to the right shape.
  for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) {
    if (!genericOp.getInputIndexingMap(i).isMinorIdentity())
      return false;
  }
  return hasOnlyScalarElementwiseOp(genericOp.getRegion());
}

// In the case of 0-D memrefs, return null so callers can special case to a
// scalar load or store later.
static VectorType extractVectorTypeFromShapedValue(Value v) {
  auto st = v.getType().cast<ShapedType>();
  if (st.isa<MemRefType>() && st.getShape().empty())
    return VectorType();
  return VectorType::get(st.getShape(), st.getElementType());
}

static Value transferReadVector(OpBuilder &builder, Value source) {
  edsc::ScopedContext scope(builder);
  auto shapedType = source.getType().cast<ShapedType>();
  if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
    SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
    return vector_transfer_read(vectorType, source, indices);
  }
  return std_load(source);
}

static Value transferWriteVector(OpBuilder &builder, Value value, Value dest) {
  edsc::ScopedContext scope(builder);
  Operation *write;
  auto shapedType = dest.getType().cast<ShapedType>();
  if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
    SmallVector<Value, 4> indices(shapedType.getRank(), std_constant_index(0));
    if (vectorType != value.getType())
      value = vector_broadcast(vectorType, value);
    write = vector_transfer_write(value, dest, indices);
  } else {
    write = std_store(value, dest);
  }
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}

namespace {
// Transforms scalar operations into their vectorized counterparts,
// while using the provided generic op to map:
// * Its arguments to transfer reads from the views of the generic op.
// * linalg.yield ops to transfer writes to the views of the generic op.
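//
// For example, an elementwise generic op whose body adds its two block
// arguments is rewritten roughly as follows (a sketch only; the exact IR
// depends on the operand shapes and attributes):
//   %va = vector.transfer_read %A[...] : memref<4x8xf32>, vector<4x8xf32>
//   %vb = vector.transfer_read %B[...] : memref<4x8xf32>, vector<4x8xf32>
//   %v  = addf %va, %vb : vector<4x8xf32>
//   vector.transfer_write %v, %C[...] : vector<4x8xf32>, memref<4x8xf32>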
class GenericVectorizer {
public:
  GenericVectorizer(OpBuilder &builder, linalg::GenericOp generic)
      : builder(builder), generic(generic) {}

  // Takes a scalar operation and builds its vectorized counterpart or
  // counterparts using the underlying builder.
  // If operands of the scalar operation refer to previously vectorized
  // operations, then in their vectorized form these operands will refer
  // to previous vectorization results.
  void vectorize(Operation &scalarOp) {
    auto yieldOp = dyn_cast<linalg::YieldOp>(scalarOp);
    if (yieldOp) {
      for (auto outputs : llvm::enumerate(yieldOp.values())) {
        Value vectorValue = vectorize(outputs.value());
        Value result = transferWriteVector(builder, vectorValue,
                                           generic.getOutput(outputs.index()));
        if (result)
          results.push_back(result);
      }
      return;
    }
    Operation *vectorOp = uncachedVectorize(scalarOp);
    assert(scalarOp.getNumResults() == vectorOp->getNumResults());
    for (auto result :
         llvm::zip(scalarOp.getResults(), vectorOp->getResults())) {
      valueCache[std::get<0>(result)] = std::get<1>(result);
    }
  }

  llvm::ArrayRef<Value> getResults() { return results; }

private:
  // Transforms a scalar value into its vectorized counterpart, recursively
  // vectorizing operations as necessary using the underlying builder.
  // Keeps track of previously vectorized values and reuses vectorization
  // results if these values come up again.
  Value vectorize(Value scalarValue) {
    // Don't vectorize values coming from outside the region.
    if (scalarValue.getParentRegion() != &generic.region())
      return scalarValue;
    auto vectorValueIt = valueCache.find(scalarValue);
    if (vectorValueIt != valueCache.end())
      return vectorValueIt->second;

    // If the value is from the region but not in the cache it means it is a
    // block argument.
    auto scalarArg = scalarValue.cast<BlockArgument>();
    assert(scalarArg.getOwner() == &generic.region().front());
    Value vectorArg = generic.getShapedOperand(scalarArg.getArgNumber());
    Value vectorResult = transferReadVector(builder, vectorArg);
    valueCache[scalarArg] = vectorResult;
    return vectorResult;
  }

  // Return the largest shape of all the given values. Return an empty
  // SmallVector if there are no vector values.
  static SmallVector<int64_t, 4> getLargestShape(ArrayRef<Value> values) {
    SmallVector<int64_t, 4> largestShape;
    int64_t maxSize = 1;
    for (Value value : values) {
      auto vecType = value.getType().dyn_cast<VectorType>();
      if (!vecType)
        continue;
      if (maxSize < vecType.getNumElements()) {
        maxSize = vecType.getNumElements();
        largestShape.assign(vecType.getShape().begin(),
                            vecType.getShape().end());
      }
    }
    return largestShape;
  }

  // If the value's type doesn't have the given shape, broadcast it.
  Value broadcastIfNeeded(Value value, ArrayRef<int64_t> shape) {
    auto vecType = value.getType().dyn_cast<VectorType>();
    if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
      return value;
    auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
                                                     : value.getType());
    return builder.create<vector::BroadcastOp>(
        builder.getInsertionPoint()->getLoc(), newVecType, value);
  }

  // Takes a scalar operation and builds its vectorized counterpart or
  // counterparts using the underlying builder without involving any caches.
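  // Operands are vectorized first, broadcast to the largest operand shape if
  // needed, and the operation is then recreated with vector result types of
  // that shape. For instance, a scalar mulf whose operands map to a scalar
  // and a vector<8xf32> value is rebuilt as a mulf on two vector<8xf32>
  // values, with the scalar operand broadcast first.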
  Operation *uncachedVectorize(Operation &base_scalarOp) {
    SmallVector<Value, 4> vectorizedOperands;
    for (Value operand : base_scalarOp.getOperands()) {
      vectorizedOperands.push_back(vectorize(operand));
    }
    SmallVector<int64_t, 4> shape = getLargestShape(vectorizedOperands);
    for (Value &operand : vectorizedOperands)
      operand = broadcastIfNeeded(operand, shape);
    OperationState state(base_scalarOp.getLoc(), base_scalarOp.getName());
    state.addAttributes(base_scalarOp.getAttrs());
    state.addOperands(vectorizedOperands);
    if (shape.empty()) {
      state.addTypes(base_scalarOp.getResultTypes());
    } else {
      SmallVector<VectorType, 4> vectorizedTypes;
      for (auto resultType : base_scalarOp.getResultTypes())
        vectorizedTypes.push_back(VectorType::get(shape, resultType));
      state.addTypes(vectorizedTypes);
    }
    return builder.createOperation(state);
  }

  OpBuilder &builder;
  linalg::GenericOp generic;
  llvm::DenseMap<Value, Value> valueCache;
  SmallVector<Value, 8> results;
};
} // namespace

// Replaces an elementwise linalg.generic op with the scalar operations from
// its body promoted to vector operations.
static void vectorizeElementwise(linalg::GenericOp op, OpBuilder &builder) {
  GenericVectorizer vectorizer(builder, op);
  for (Operation &scalarOp : op.region().front()) {
    vectorizer.vectorize(scalarOp);
  }
  if (!op->getResults().empty())
    op->replaceAllUsesWith(vectorizer.getResults());
}

LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  // All types must have static shape to go to vector.
  for (Value operand : linalgOp.getShapedOperands())
    if (!operand.getType().cast<ShapedType>().hasStaticShape())
      return failure();
  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
      return failure();

  if (isa<linalg::FillOp, linalg::CopyOp>(op))
    return success();
  if (isElementwise(op))
    return success();
  return isContraction(op);
}

void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
  assert(succeeded(vectorizeLinalgOpPrecondition(op)));

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  edsc::ScopedContext scope(builder, op->getLoc());
  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
    // Vectorize fill as a vector.broadcast.
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
    transferWriteVector(builder, fillOp.value(), fillOp.output());
    return;
  }
  if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
    // Vectorize copy as a vector.transfer_read + vector.transfer_write.
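    // Schematically, %v = vector.transfer_read %in[0, ...] followed by
    // vector.transfer_write %v, %out[0, ...]; the indices are all-zero since
    // the whole statically-shaped buffer is read and written.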
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.copy as vector.transfer_read + "
                         "vector.transfer_write: "
                      << *op);
    Value vector = transferReadVector(builder, copyOp.input());
    transferWriteVector(builder, vector, copyOp.output());
    return;
  }

  if (isElementwise(op)) {
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg op as vector.transfer_read + "
                         "vector_op + vector.transfer_write: "
                      << *op);
    return vectorizeElementwise(cast<linalg::GenericOp>(op), builder);
  }

  assert(succeeded(isContraction(op)) && "Expected contraction");

  // Vectorize other ops as vector contraction.
  // TODO: interface.
  LLVM_DEBUG(dbgs() << dbgPref
                    << "Rewrite linalg op as vector.contract: " << *op);
  auto linalgOp = cast<linalg::LinalgOp>(op);
  Value a = transferReadVector(builder, linalgOp.getInput(0));
  Value b = transferReadVector(builder, linalgOp.getInput(1));
  Value c = transferReadVector(builder, linalgOp.getOutput(0));
  Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                              linalgOp.iterator_types());
  Value writeResult = transferWriteVector(builder, res, linalgOp.getOutput(0));
  if (writeResult)
    linalgOp->replaceAllUsesWith(ArrayRef<Value>(writeResult));
}

/// Check whether there is any interleaved use of any `values` between
/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
/// is in a different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LLVM_DEBUG(llvm::dbgs()
               << dbgPref << "interleavedUses precondition failed, firstOp: "
               << *firstOp << ", second op: " << *secondOp);
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) ||
           secondOp->isBeforeInBlock(owner)))
        continue;
      LLVM_DEBUG(llvm::dbgs()
                 << dbgPref << " found interleaved op " << *owner
                 << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
      return true;
    }
  }
  return false;
}

/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
static SubViewOp getSubViewUseIfUnique(Value v) {
  SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {

  // Transfer into `view`.
  Value viewOrAlloc = xferOp.source();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: ";
  (void)dbgPref;
  LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc);

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView);

  // Find the copy into `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getOutputBuffer(0) != subView)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp);
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp);

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp);
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure padding matches.
  if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
    return failure();
  if (maybeFillOp)
    LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp);

  // `in` is the subview that linalg.copy reads. Replace it.
  Value in = copyOp.getInput(0);

  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_read, the attribute must be reset
  // conservatively.
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
      xferOp.permutation_map(), xferOp.padding(), ArrayAttr());

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.source();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getInput(0) != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  // `out` is the subview copied into; it is the value we replace.
  Value out = copyOp.getOutputBuffer(0);

  // Forward vector.transfer into copy.
  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_write, the attribute must be reset
  // conservatively.
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
      xferOp.permutation_map(), ArrayAttr());

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}

template <class ConvOp, int N>
LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
    ConvOp op, PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MLIRContext *context = op.getContext();
  edsc::ScopedContext scope(rewriter, loc);

  ShapedType inShapeType = op.getInputShapedType(0);
  ShapedType kShapeType = op.getInputShapedType(1);

  ArrayRef<int64_t> inShape = inShapeType.getShape();
  ArrayRef<int64_t> kShape = kShapeType.getShape();

  if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
    return failure();

  SmallVector<AffineExpr, 4> mapping;
  SmallVector<int64_t, 4> vectorDims;
  // Fail to apply when the size of a non-vectorized dimension is not 1.
  for (unsigned i = 0; i < N; i++) {
    if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
      return failure();

    if (mask[i] && inShape[i] != kShape[i])
      return failure();

    if (mask[i]) {
      mapping.push_back(getAffineDimExpr(i, context));
      vectorDims.push_back(inShape[i]);
    }
  }

  Value input = op.getInput(0);
  Value kernel = op.getInput(1);
  Value output = op.getOutputBuffer(0);

  unsigned rank = inShapeType.getRank();
  unsigned numDims = mapping.size();
  Type elemType = inShapeType.getElementType();

  auto map = AffineMap::get(rank, 0, mapping, context);
  SmallVector<Value, 4> zeros(rank, std_constant_index(0));
  auto vecType = VectorType::get(vectorDims, elemType);

  auto inputVec = vector_transfer_read(vecType, input, zeros, map);
  auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);

  auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType));

  std::array<AffineMap, 3> indexingMaps{
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::get(numDims, 0, {}, context)};

  std::vector<StringRef> iteratorTypes(numDims, "reduction");

  auto result = rewriter.create<vector::ContractionOp>(
      loc, inputVec, kernelVec, acc,
      rewriter.getAffineMapArrayAttr(indexingMaps),
      rewriter.getStrArrayAttr(iteratorTypes));

  rewriter.create<StoreOp>(loc, result, output, ValueRange(zeros));
  rewriter.eraseOp(op);
  return success();
}

using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;

/// Inserts tiling, promotion and vectorization patterns for ConvOp
/// conversion into the corresponding pattern lists.
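/// The three pattern lists are meant to be applied in order: tiling marks ops
/// as "TILED", promotion rewrites "TILED" ops into "PROMOTED" ones backed by
/// full tile buffers, and vectorization then applies ConvOpVectorization with
/// a mask that enables vectorization only for the dimensions whose tile size
/// is greater than 1.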
template <typename ConvOp, unsigned N>
static void
populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
                              OwningRewritePatternList &promotionPatterns,
                              OwningRewritePatternList &vectorizationPatterns,
                              ArrayRef<int64_t> tileSizes,
                              MLIRContext *context) {
  if (tileSizes.size() < N)
    return;

  constexpr static StringRef kTiledMarker = "TILED";
  constexpr static StringRef kPromotedMarker = "PROMOTED";
  tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
      context, LinalgTilingOptions().setTileSizes(tileSizes),
      LinalgMarker({}, Identifier::get(kTiledMarker, context)));

  promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>(
      context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgMarker(Identifier::get(kTiledMarker, context),
                   Identifier::get(kPromotedMarker, context)));

  SmallVector<bool, 4> mask(N);
  int offset = tileSizes.size() - N;
  std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
                 [](int64_t i) -> bool { return i > 1; });

  vectorizationPatterns.insert<ConvOpVectorization<ConvOp, N>>(context, mask);
}

void mlir::linalg::populateConvVectorizationPatterns(
    MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
    ArrayRef<int64_t> tileSizes) {
  OwningRewritePatternList tiling, promotion, vectorization;
  populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
                                            tileSizes, context);

  populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
                                             tileSizes, context);

  populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion,
                                               vectorization, tileSizes,
                                               context);

  populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion,
                                               vectorization, tileSizes,
                                               context);

  populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvNDHWCOp, 5>(
      tiling, promotion, vectorization, tileSizes, context);

  populateVectorizationPatterns<ConvNCDHWOp, 5>(
      tiling, promotion, vectorization, tileSizes, context);

  patterns.push_back(std::move(tiling));
  patterns.push_back(std::move(promotion));
  patterns.push_back(std::move(vectorization));
}