1 //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" 10 #include "mlir/Dialect/Affine/IR/AffineOps.h" 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/PDL/IR/PDL.h" 13 #include "mlir/Dialect/SCF/IR/SCF.h" 14 #include "mlir/Dialect/SCF/Transforms/Patterns.h" 15 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 16 #include "mlir/Dialect/SCF/Utils/Utils.h" 17 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 19 #include "mlir/Dialect/Vector/IR/VectorOps.h" 20 21 using namespace mlir; 22 23 namespace { 24 /// A simple pattern rewriter that implements no special logic. 25 class SimpleRewriter : public PatternRewriter { 26 public: 27 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} 28 }; 29 } // namespace 30 31 //===----------------------------------------------------------------------===// 32 // GetParentForOp 33 //===----------------------------------------------------------------------===// 34 35 DiagnosedSilenceableFailure 36 transform::GetParentForOp::apply(transform::TransformResults &results, 37 transform::TransformState &state) { 38 SetVector<Operation *> parents; 39 for (Operation *target : state.getPayloadOps(getTarget())) { 40 scf::ForOp loop; 41 Operation *current = target; 42 for (unsigned i = 0, e = getNumLoops(); i < e; ++i) { 43 loop = current->getParentOfType<scf::ForOp>(); 44 if (!loop) { 45 DiagnosedSilenceableFailure diag = emitSilenceableError() 46 << "could not find an '" 47 << scf::ForOp::getOperationName() 48 << "' parent"; 49 diag.attachNote(target->getLoc()) << "target op"; 50 return diag; 51 } 52 current = loop; 53 } 54 parents.insert(loop); 55 } 56 results.set(getResult().cast<OpResult>(), parents.getArrayRef()); 57 return DiagnosedSilenceableFailure::success(); 58 } 59 60 //===----------------------------------------------------------------------===// 61 // LoopOutlineOp 62 //===----------------------------------------------------------------------===// 63 64 /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses 65 /// the provided rewriter for all operations to remain compatible with the 66 /// rewriting infra, as opposed to just splicing the op in place. 67 static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, 68 Operation *op) { 69 if (op->getNumRegions() != 1) 70 return nullptr; 71 OpBuilder::InsertionGuard g(b); 72 b.setInsertionPoint(op); 73 scf::ExecuteRegionOp executeRegionOp = 74 b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes()); 75 { 76 OpBuilder::InsertionGuard g(b); 77 b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock()); 78 Operation *clonedOp = b.cloneWithoutRegions(*op); 79 Region &clonedRegion = clonedOp->getRegions().front(); 80 assert(clonedRegion.empty() && "expected empty region"); 81 b.inlineRegionBefore(op->getRegions().front(), clonedRegion, 82 clonedRegion.end()); 83 b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults()); 84 } 85 b.replaceOp(op, executeRegionOp.getResults()); 86 return executeRegionOp; 87 } 88 89 DiagnosedSilenceableFailure 90 transform::LoopOutlineOp::apply(transform::TransformResults &results, 91 transform::TransformState &state) { 92 SmallVector<Operation *> transformed; 93 DenseMap<Operation *, SymbolTable> symbolTables; 94 for (Operation *target : state.getPayloadOps(getTarget())) { 95 Location location = target->getLoc(); 96 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target); 97 SimpleRewriter rewriter(getContext()); 98 scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target); 99 if (!exec) { 100 DiagnosedSilenceableFailure diag = emitSilenceableError() 101 << "failed to outline"; 102 diag.attachNote(target->getLoc()) << "target op"; 103 return diag; 104 } 105 func::CallOp call; 106 FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion( 107 rewriter, location, exec.getRegion(), getFuncName(), &call); 108 109 if (failed(outlined)) { 110 (void)reportUnknownTransformError(target); 111 return DiagnosedSilenceableFailure::definiteFailure(); 112 } 113 114 if (symbolTableOp) { 115 SymbolTable &symbolTable = 116 symbolTables.try_emplace(symbolTableOp, symbolTableOp) 117 .first->getSecond(); 118 symbolTable.insert(*outlined); 119 call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined)); 120 } 121 transformed.push_back(*outlined); 122 } 123 results.set(getTransformed().cast<OpResult>(), transformed); 124 return DiagnosedSilenceableFailure::success(); 125 } 126 127 //===----------------------------------------------------------------------===// 128 // LoopPeelOp 129 //===----------------------------------------------------------------------===// 130 131 DiagnosedSilenceableFailure 132 transform::LoopPeelOp::applyToOne(scf::ForOp target, 133 SmallVector<Operation *> &results, 134 transform::TransformState &state) { 135 scf::ForOp result; 136 IRRewriter rewriter(target->getContext()); 137 // This helper returns failure when peeling does not occur (i.e. when the IR 138 // is not modified). This is not a failure for the op as the postcondition: 139 // "the loop trip count is divisible by the step" 140 // is valid. 141 LogicalResult status = 142 scf::peelAndCanonicalizeForLoop(rewriter, target, result); 143 // TODO: Return both the peeled loop and the remainder loop. 144 results.push_back(failed(status) ? target : result); 145 return DiagnosedSilenceableFailure(success()); 146 } 147 148 //===----------------------------------------------------------------------===// 149 // LoopPipelineOp 150 //===----------------------------------------------------------------------===// 151 152 /// Callback for PipeliningOption. Populates `schedule` with the mapping from an 153 /// operation to its logical time position given the iteration interval and the 154 /// read latency. The latter is only relevant for vector transfers. 155 static void 156 loopScheduling(scf::ForOp forOp, 157 std::vector<std::pair<Operation *, unsigned>> &schedule, 158 unsigned iterationInterval, unsigned readLatency) { 159 auto getLatency = [&](Operation *op) -> unsigned { 160 if (isa<vector::TransferReadOp>(op)) 161 return readLatency; 162 return 1; 163 }; 164 165 DenseMap<Operation *, unsigned> opCycles; 166 std::map<unsigned, std::vector<Operation *>> wrappedSchedule; 167 for (Operation &op : forOp.getBody()->getOperations()) { 168 if (isa<scf::YieldOp>(op)) 169 continue; 170 unsigned earlyCycle = 0; 171 for (Value operand : op.getOperands()) { 172 Operation *def = operand.getDefiningOp(); 173 if (!def) 174 continue; 175 earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def)); 176 } 177 opCycles[&op] = earlyCycle; 178 wrappedSchedule[earlyCycle % iterationInterval].push_back(&op); 179 } 180 for (const auto &it : wrappedSchedule) { 181 for (Operation *op : it.second) { 182 unsigned cycle = opCycles[op]; 183 schedule.emplace_back(op, cycle / iterationInterval); 184 } 185 } 186 } 187 188 DiagnosedSilenceableFailure 189 transform::LoopPipelineOp::applyToOne(scf::ForOp target, 190 SmallVector<Operation *> &results, 191 transform::TransformState &state) { 192 scf::PipeliningOption options; 193 options.getScheduleFn = 194 [this](scf::ForOp forOp, 195 std::vector<std::pair<Operation *, unsigned>> &schedule) mutable { 196 loopScheduling(forOp, schedule, getIterationInterval(), 197 getReadLatency()); 198 }; 199 scf::ForLoopPipeliningPattern pattern(options, target->getContext()); 200 SimpleRewriter rewriter(getContext()); 201 rewriter.setInsertionPoint(target); 202 FailureOr<scf::ForOp> patternResult = 203 pattern.returningMatchAndRewrite(target, rewriter); 204 if (succeeded(patternResult)) { 205 results.push_back(*patternResult); 206 return DiagnosedSilenceableFailure(success()); 207 } 208 results.assign(1, nullptr); 209 return emitDefaultSilenceableFailure(target); 210 } 211 212 //===----------------------------------------------------------------------===// 213 // LoopUnrollOp 214 //===----------------------------------------------------------------------===// 215 216 DiagnosedSilenceableFailure 217 transform::LoopUnrollOp::applyToOne(scf::ForOp target, 218 SmallVector<Operation *> &results, 219 transform::TransformState &state) { 220 if (failed(loopUnrollByFactor(target, getFactor()))) { 221 Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note); 222 diag << "op failed to unroll"; 223 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); 224 } 225 return DiagnosedSilenceableFailure(success()); 226 } 227 228 //===----------------------------------------------------------------------===// 229 // Transform op registration 230 //===----------------------------------------------------------------------===// 231 232 namespace { 233 class SCFTransformDialectExtension 234 : public transform::TransformDialectExtension< 235 SCFTransformDialectExtension> { 236 public: 237 using Base::Base; 238 239 void init() { 240 declareDependentDialect<pdl::PDLDialect>(); 241 242 declareGeneratedDialect<AffineDialect>(); 243 declareGeneratedDialect<func::FuncDialect>(); 244 245 registerTransformOps< 246 #define GET_OP_LIST 247 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" 248 >(); 249 } 250 }; 251 } // namespace 252 253 #define GET_OP_CLASSES 254 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" 255 256 void mlir::scf::registerTransformDialectExtension(DialectRegistry ®istry) { 257 registry.addExtensions<SCFTransformDialectExtension>(); 258 } 259