1 //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" 10 #include "mlir/Dialect/Affine/IR/AffineOps.h" 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/SCF/Patterns.h" 13 #include "mlir/Dialect/SCF/SCF.h" 14 #include "mlir/Dialect/SCF/Transforms.h" 15 #include "mlir/Dialect/SCF/Utils/Utils.h" 16 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 17 #include "mlir/Dialect/Vector/IR/VectorOps.h" 18 19 using namespace mlir; 20 21 namespace { 22 /// A simple pattern rewriter that implements no special logic. 23 class SimpleRewriter : public PatternRewriter { 24 public: 25 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} 26 }; 27 } // namespace 28 29 //===----------------------------------------------------------------------===// 30 // GetParentForOp 31 //===----------------------------------------------------------------------===// 32 33 LogicalResult 34 transform::GetParentForOp::apply(transform::TransformResults &results, 35 transform::TransformState &state) { 36 SetVector<Operation *> parents; 37 for (Operation *target : state.getPayloadOps(getTarget())) { 38 scf::ForOp loop; 39 Operation *current = target; 40 for (unsigned i = 0, e = getNumLoops(); i < e; ++i) { 41 loop = current->getParentOfType<scf::ForOp>(); 42 if (!loop) { 43 InFlightDiagnostic diag = emitError() << "could not find an '" 44 << scf::ForOp::getOperationName() 45 << "' parent"; 46 diag.attachNote(target->getLoc()) << "target op"; 47 return diag; 48 } 49 current = loop; 50 } 51 parents.insert(loop); 52 } 53 results.set(getResult().cast<OpResult>(), parents.getArrayRef()); 54 return success(); 55 } 56 57 //===----------------------------------------------------------------------===// 58 // LoopOutlineOp 59 //===----------------------------------------------------------------------===// 60 61 /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses 62 /// the provided rewriter for all operations to remain compatible with the 63 /// rewriting infra, as opposed to just splicing the op in place. 64 static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, 65 Operation *op) { 66 if (op->getNumRegions() != 1) 67 return nullptr; 68 OpBuilder::InsertionGuard g(b); 69 b.setInsertionPoint(op); 70 scf::ExecuteRegionOp executeRegionOp = 71 b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes()); 72 { 73 OpBuilder::InsertionGuard g(b); 74 b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock()); 75 Operation *clonedOp = b.cloneWithoutRegions(*op); 76 Region &clonedRegion = clonedOp->getRegions().front(); 77 assert(clonedRegion.empty() && "expected empty region"); 78 b.inlineRegionBefore(op->getRegions().front(), clonedRegion, 79 clonedRegion.end()); 80 b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults()); 81 } 82 b.replaceOp(op, executeRegionOp.getResults()); 83 return executeRegionOp; 84 } 85 86 LogicalResult 87 transform::LoopOutlineOp::apply(transform::TransformResults &results, 88 transform::TransformState &state) { 89 SmallVector<Operation *> transformed; 90 DenseMap<Operation *, SymbolTable> symbolTables; 91 for (Operation *target : state.getPayloadOps(getTarget())) { 92 Location location = target->getLoc(); 93 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target); 94 SimpleRewriter rewriter(getContext()); 95 scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target); 96 if (!exec) { 97 InFlightDiagnostic diag = emitError() << "failed to outline"; 98 diag.attachNote(target->getLoc()) << "target op"; 99 return diag; 100 } 101 func::CallOp call; 102 FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion( 103 rewriter, location, exec.getRegion(), getFuncName(), &call); 104 105 if (failed(outlined)) 106 return reportUnknownTransformError(target); 107 108 if (symbolTableOp) { 109 SymbolTable &symbolTable = 110 symbolTables.try_emplace(symbolTableOp, symbolTableOp) 111 .first->getSecond(); 112 symbolTable.insert(*outlined); 113 call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined)); 114 } 115 transformed.push_back(*outlined); 116 } 117 results.set(getTransformed().cast<OpResult>(), transformed); 118 return success(); 119 } 120 121 //===----------------------------------------------------------------------===// 122 // LoopPeelOp 123 //===----------------------------------------------------------------------===// 124 125 FailureOr<scf::ForOp> transform::LoopPeelOp::applyToOne(scf::ForOp loop) { 126 scf::ForOp result; 127 IRRewriter rewriter(loop->getContext()); 128 LogicalResult status = 129 scf::peelAndCanonicalizeForLoop(rewriter, loop, result); 130 if (failed(status)) { 131 if (getFailIfAlreadyDivisible()) 132 return reportUnknownTransformError(loop); 133 return loop; 134 } 135 return result; 136 } 137 138 //===----------------------------------------------------------------------===// 139 // LoopPipelineOp 140 //===----------------------------------------------------------------------===// 141 142 /// Callback for PipeliningOption. Populates `schedule` with the mapping from an 143 /// operation to its logical time position given the iteration interval and the 144 /// read latency. The latter is only relevant for vector transfers. 145 static void 146 loopScheduling(scf::ForOp forOp, 147 std::vector<std::pair<Operation *, unsigned>> &schedule, 148 unsigned iterationInterval, unsigned readLatency) { 149 auto getLatency = [&](Operation *op) -> unsigned { 150 if (isa<vector::TransferReadOp>(op)) 151 return readLatency; 152 return 1; 153 }; 154 155 DenseMap<Operation *, unsigned> opCycles; 156 std::map<unsigned, std::vector<Operation *>> wrappedSchedule; 157 for (Operation &op : forOp.getBody()->getOperations()) { 158 if (isa<scf::YieldOp>(op)) 159 continue; 160 unsigned earlyCycle = 0; 161 for (Value operand : op.getOperands()) { 162 Operation *def = operand.getDefiningOp(); 163 if (!def) 164 continue; 165 earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def)); 166 } 167 opCycles[&op] = earlyCycle; 168 wrappedSchedule[earlyCycle % iterationInterval].push_back(&op); 169 } 170 for (const auto &it : wrappedSchedule) { 171 for (Operation *op : it.second) { 172 unsigned cycle = opCycles[op]; 173 schedule.push_back(std::make_pair(op, cycle / iterationInterval)); 174 } 175 } 176 } 177 178 FailureOr<scf::ForOp> transform::LoopPipelineOp::applyToOne(scf::ForOp loop) { 179 scf::PipeliningOption options; 180 options.getScheduleFn = 181 [this](scf::ForOp forOp, 182 std::vector<std::pair<Operation *, unsigned>> &schedule) mutable { 183 loopScheduling(forOp, schedule, getIterationInterval(), 184 getReadLatency()); 185 }; 186 187 scf::ForLoopPipeliningPattern pattern(options, loop->getContext()); 188 SimpleRewriter rewriter(getContext()); 189 rewriter.setInsertionPoint(loop); 190 FailureOr<scf::ForOp> patternResult = 191 pattern.returningMatchAndRewrite(loop, rewriter); 192 if (failed(patternResult)) 193 return reportUnknownTransformError(loop); 194 return patternResult; 195 } 196 197 //===----------------------------------------------------------------------===// 198 // LoopUnrollOp 199 //===----------------------------------------------------------------------===// 200 201 LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop) { 202 if (failed(loopUnrollByFactor(loop, getFactor()))) 203 return reportUnknownTransformError(loop); 204 return success(); 205 } 206 207 //===----------------------------------------------------------------------===// 208 // Transform op registration 209 //===----------------------------------------------------------------------===// 210 211 namespace { 212 class SCFTransformDialectExtension 213 : public transform::TransformDialectExtension< 214 SCFTransformDialectExtension> { 215 public: 216 SCFTransformDialectExtension() { 217 declareDependentDialect<AffineDialect>(); 218 declareDependentDialect<func::FuncDialect>(); 219 registerTransformOps< 220 #define GET_OP_LIST 221 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" 222 >(); 223 } 224 }; 225 } // namespace 226 227 #define GET_OP_CLASSES 228 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" 229 230 void mlir::scf::registerTransformDialectExtension(DialectRegistry ®istry) { 231 registry.addExtensions<SCFTransformDialectExtension>(); 232 } 233