1 //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" 10 #include "mlir/Dialect/Affine/IR/AffineOps.h" 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/SCF/IR/SCF.h" 13 #include "mlir/Dialect/SCF/Transforms/Patterns.h" 14 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 15 #include "mlir/Dialect/SCF/Utils/Utils.h" 16 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 17 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 18 #include "mlir/Dialect/Vector/IR/VectorOps.h" 19 20 using namespace mlir; 21 22 namespace { 23 /// A simple pattern rewriter that implements no special logic. 24 class SimpleRewriter : public PatternRewriter { 25 public: 26 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} 27 }; 28 } // namespace 29 30 //===----------------------------------------------------------------------===// 31 // GetParentForOp 32 //===----------------------------------------------------------------------===// 33 34 DiagnosedSilenceableFailure 35 transform::GetParentForOp::apply(transform::TransformResults &results, 36 transform::TransformState &state) { 37 SetVector<Operation *> parents; 38 for (Operation *target : state.getPayloadOps(getTarget())) { 39 scf::ForOp loop; 40 Operation *current = target; 41 for (unsigned i = 0, e = getNumLoops(); i < e; ++i) { 42 loop = current->getParentOfType<scf::ForOp>(); 43 if (!loop) { 44 DiagnosedSilenceableFailure diag = emitSilenceableError() 45 << "could not find an '" 46 << scf::ForOp::getOperationName() 47 << "' parent"; 48 diag.attachNote(target->getLoc()) << "target op"; 49 return diag; 50 } 51 current = loop; 52 } 53 parents.insert(loop); 54 } 55 results.set(getResult().cast<OpResult>(), parents.getArrayRef()); 56 return DiagnosedSilenceableFailure::success(); 57 } 58 59 //===----------------------------------------------------------------------===// 60 // LoopOutlineOp 61 //===----------------------------------------------------------------------===// 62 63 /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses 64 /// the provided rewriter for all operations to remain compatible with the 65 /// rewriting infra, as opposed to just splicing the op in place. 66 static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, 67 Operation *op) { 68 if (op->getNumRegions() != 1) 69 return nullptr; 70 OpBuilder::InsertionGuard g(b); 71 b.setInsertionPoint(op); 72 scf::ExecuteRegionOp executeRegionOp = 73 b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes()); 74 { 75 OpBuilder::InsertionGuard g(b); 76 b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock()); 77 Operation *clonedOp = b.cloneWithoutRegions(*op); 78 Region &clonedRegion = clonedOp->getRegions().front(); 79 assert(clonedRegion.empty() && "expected empty region"); 80 b.inlineRegionBefore(op->getRegions().front(), clonedRegion, 81 clonedRegion.end()); 82 b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults()); 83 } 84 b.replaceOp(op, executeRegionOp.getResults()); 85 return executeRegionOp; 86 } 87 88 DiagnosedSilenceableFailure 89 transform::LoopOutlineOp::apply(transform::TransformResults &results, 90 transform::TransformState &state) { 91 SmallVector<Operation *> transformed; 92 DenseMap<Operation *, SymbolTable> symbolTables; 93 for (Operation *target : state.getPayloadOps(getTarget())) { 94 Location location = target->getLoc(); 95 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target); 96 SimpleRewriter rewriter(getContext()); 97 scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target); 98 if (!exec) { 99 DiagnosedSilenceableFailure diag = emitSilenceableError() 100 << "failed to outline"; 101 diag.attachNote(target->getLoc()) << "target op"; 102 return diag; 103 } 104 func::CallOp call; 105 FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion( 106 rewriter, location, exec.getRegion(), getFuncName(), &call); 107 108 if (failed(outlined)) { 109 (void)reportUnknownTransformError(target); 110 return DiagnosedSilenceableFailure::definiteFailure(); 111 } 112 113 if (symbolTableOp) { 114 SymbolTable &symbolTable = 115 symbolTables.try_emplace(symbolTableOp, symbolTableOp) 116 .first->getSecond(); 117 symbolTable.insert(*outlined); 118 call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined)); 119 } 120 transformed.push_back(*outlined); 121 } 122 results.set(getTransformed().cast<OpResult>(), transformed); 123 return DiagnosedSilenceableFailure::success(); 124 } 125 126 //===----------------------------------------------------------------------===// 127 // LoopPeelOp 128 //===----------------------------------------------------------------------===// 129 130 DiagnosedSilenceableFailure 131 transform::LoopPeelOp::applyToOne(scf::ForOp target, 132 SmallVector<Operation *> &results, 133 transform::TransformState &state) { 134 scf::ForOp result; 135 IRRewriter rewriter(target->getContext()); 136 // This helper returns failure when peeling does not occur (i.e. when the IR 137 // is not modified). This is not a failure for the op as the postcondition: 138 // "the loop trip count is divisible by the step" 139 // is valid. 140 LogicalResult status = 141 scf::peelAndCanonicalizeForLoop(rewriter, target, result); 142 // TODO: Return both the peeled loop and the remainder loop. 143 results.push_back(failed(status) ? target : result); 144 return DiagnosedSilenceableFailure(success()); 145 } 146 147 //===----------------------------------------------------------------------===// 148 // LoopPipelineOp 149 //===----------------------------------------------------------------------===// 150 151 /// Callback for PipeliningOption. Populates `schedule` with the mapping from an 152 /// operation to its logical time position given the iteration interval and the 153 /// read latency. The latter is only relevant for vector transfers. 154 static void 155 loopScheduling(scf::ForOp forOp, 156 std::vector<std::pair<Operation *, unsigned>> &schedule, 157 unsigned iterationInterval, unsigned readLatency) { 158 auto getLatency = [&](Operation *op) -> unsigned { 159 if (isa<vector::TransferReadOp>(op)) 160 return readLatency; 161 return 1; 162 }; 163 164 DenseMap<Operation *, unsigned> opCycles; 165 std::map<unsigned, std::vector<Operation *>> wrappedSchedule; 166 for (Operation &op : forOp.getBody()->getOperations()) { 167 if (isa<scf::YieldOp>(op)) 168 continue; 169 unsigned earlyCycle = 0; 170 for (Value operand : op.getOperands()) { 171 Operation *def = operand.getDefiningOp(); 172 if (!def) 173 continue; 174 earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def)); 175 } 176 opCycles[&op] = earlyCycle; 177 wrappedSchedule[earlyCycle % iterationInterval].push_back(&op); 178 } 179 for (const auto &it : wrappedSchedule) { 180 for (Operation *op : it.second) { 181 unsigned cycle = opCycles[op]; 182 schedule.emplace_back(op, cycle / iterationInterval); 183 } 184 } 185 } 186 187 DiagnosedSilenceableFailure 188 transform::LoopPipelineOp::applyToOne(scf::ForOp target, 189 SmallVector<Operation *> &results, 190 transform::TransformState &state) { 191 scf::PipeliningOption options; 192 options.getScheduleFn = 193 [this](scf::ForOp forOp, 194 std::vector<std::pair<Operation *, unsigned>> &schedule) mutable { 195 loopScheduling(forOp, schedule, getIterationInterval(), 196 getReadLatency()); 197 }; 198 scf::ForLoopPipeliningPattern pattern(options, target->getContext()); 199 SimpleRewriter rewriter(getContext()); 200 rewriter.setInsertionPoint(target); 201 FailureOr<scf::ForOp> patternResult = 202 pattern.returningMatchAndRewrite(target, rewriter); 203 if (succeeded(patternResult)) { 204 results.push_back(*patternResult); 205 return DiagnosedSilenceableFailure(success()); 206 } 207 results.assign(1, nullptr); 208 return emitDefaultSilenceableFailure(target); 209 } 210 211 //===----------------------------------------------------------------------===// 212 // LoopUnrollOp 213 //===----------------------------------------------------------------------===// 214 215 DiagnosedSilenceableFailure 216 transform::LoopUnrollOp::applyToOne(scf::ForOp target, 217 SmallVector<Operation *> &results, 218 transform::TransformState &state) { 219 if (failed(loopUnrollByFactor(target, getFactor()))) { 220 Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note); 221 diag << "op failed to unroll"; 222 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); 223 } 224 return DiagnosedSilenceableFailure(success()); 225 } 226 227 //===----------------------------------------------------------------------===// 228 // Transform op registration 229 //===----------------------------------------------------------------------===// 230 231 namespace { 232 class SCFTransformDialectExtension 233 : public transform::TransformDialectExtension< 234 SCFTransformDialectExtension> { 235 public: 236 SCFTransformDialectExtension() { 237 declareDependentDialect<AffineDialect>(); 238 declareDependentDialect<func::FuncDialect>(); 239 registerTransformOps< 240 #define GET_OP_LIST 241 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" 242 >(); 243 } 244 }; 245 } // namespace 246 247 #define GET_OP_CLASSES 248 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" 249 250 void mlir::scf::registerTransformDialectExtension(DialectRegistry ®istry) { 251 registry.addExtensions<SCFTransformDialectExtension>(); 252 } 253