1 //===- VectorDistribute.cpp - patterns to do vector distribution ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/MemRef/IR/MemRef.h" 11 #include "mlir/Dialect/SCF/SCF.h" 12 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" 13 14 using namespace mlir; 15 using namespace mlir::vector; 16 17 static LogicalResult 18 rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, 19 const WarpExecuteOnLane0LoweringOptions &options) { 20 assert(warpOp.getBodyRegion().hasOneBlock() && 21 "expected WarpOp with single block"); 22 Block *warpOpBody = &warpOp.getBodyRegion().front(); 23 Location loc = warpOp.getLoc(); 24 25 // Passed all checks. Start rewriting. 26 OpBuilder::InsertionGuard g(rewriter); 27 rewriter.setInsertionPoint(warpOp); 28 29 // Create scf.if op. 30 Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0); 31 Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 32 warpOp.getLaneid(), c0); 33 auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0, 34 /*withElseRegion=*/false); 35 rewriter.eraseOp(ifOp.thenBlock()->getTerminator()); 36 37 // Store vectors that are defined outside of warpOp into the scratch pad 38 // buffer. 39 SmallVector<Value> bbArgReplacements; 40 for (const auto &it : llvm::enumerate(warpOp.getArgs())) { 41 Value val = it.value(); 42 Value bbArg = warpOpBody->getArgument(it.index()); 43 44 rewriter.setInsertionPoint(ifOp); 45 Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp, 46 bbArg.getType()); 47 48 // Store arg vector into buffer. 49 rewriter.setInsertionPoint(ifOp); 50 auto vectorType = val.getType().cast<VectorType>(); 51 int64_t storeSize = vectorType.getShape()[0]; 52 Value storeOffset = rewriter.create<arith::MulIOp>( 53 loc, warpOp.getLaneid(), 54 rewriter.create<arith::ConstantIndexOp>(loc, storeSize)); 55 rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset); 56 57 // Load bbArg vector from buffer. 58 rewriter.setInsertionPointToStart(ifOp.thenBlock()); 59 auto bbArgType = bbArg.getType().cast<VectorType>(); 60 Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0); 61 bbArgReplacements.push_back(loadOp); 62 } 63 64 // Insert sync after all the stores and before all the loads. 65 if (!warpOp.getArgs().empty()) { 66 rewriter.setInsertionPoint(ifOp); 67 options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp); 68 } 69 70 // Move body of warpOp to ifOp. 71 rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements); 72 73 // Rewrite terminator and compute replacements of WarpOp results. 74 SmallVector<Value> replacements; 75 auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator()); 76 Location yieldLoc = yieldOp.getLoc(); 77 for (const auto &it : llvm::enumerate(yieldOp.operands())) { 78 Value val = it.value(); 79 Type resultType = warpOp->getResultTypes()[it.index()]; 80 rewriter.setInsertionPoint(ifOp); 81 Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp, 82 val.getType()); 83 84 // Store yielded value into buffer. 85 rewriter.setInsertionPoint(yieldOp); 86 if (val.getType().isa<VectorType>()) 87 rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0); 88 else 89 rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0); 90 91 // Load value from buffer (after warpOp). 92 rewriter.setInsertionPointAfter(ifOp); 93 if (resultType == val.getType()) { 94 // Result type and yielded value type are the same. This is a broadcast. 95 // E.g.: 96 // %r = vector_ext.warp_execute_on_lane_0(...) -> (f32) { 97 // vector_ext.yield %cst : f32 98 // } 99 // Both types are f32. The constant %cst is broadcasted to all lanes. 100 // This is described in more detail in the documentation of the op. 101 Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0); 102 replacements.push_back(loadOp); 103 } else { 104 auto loadedVectorType = resultType.cast<VectorType>(); 105 int64_t loadSize = loadedVectorType.getShape()[0]; 106 107 // loadOffset = laneid * loadSize 108 Value loadOffset = rewriter.create<arith::MulIOp>( 109 loc, warpOp.getLaneid(), 110 rewriter.create<arith::ConstantIndexOp>(loc, loadSize)); 111 Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType, 112 buffer, loadOffset); 113 replacements.push_back(loadOp); 114 } 115 } 116 117 // Insert sync after all the stores and before all the loads. 118 if (!yieldOp.operands().empty()) { 119 rewriter.setInsertionPointAfter(ifOp); 120 options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp); 121 } 122 123 // Delete terminator and add empty scf.yield. 124 rewriter.eraseOp(yieldOp); 125 rewriter.setInsertionPointToEnd(ifOp.thenBlock()); 126 rewriter.create<scf::YieldOp>(yieldLoc); 127 128 // Compute replacements for WarpOp results. 129 rewriter.replaceOp(warpOp, replacements); 130 131 return success(); 132 } 133 134 namespace { 135 136 struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> { 137 WarpOpToScfForPattern(MLIRContext *context, 138 const WarpExecuteOnLane0LoweringOptions &options, 139 PatternBenefit benefit = 1) 140 : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit), 141 options(options) {} 142 143 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 144 PatternRewriter &rewriter) const override { 145 return rewriteWarpOpToScfFor(rewriter, warpOp, options); 146 } 147 148 private: 149 const WarpExecuteOnLane0LoweringOptions &options; 150 }; 151 152 } // namespace 153 154 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern( 155 RewritePatternSet &patterns, 156 const WarpExecuteOnLane0LoweringOptions &options) { 157 patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options); 158 } 159