1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/IR/SCF.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BlockAndValueMapping.h"
17 #include "mlir/IR/FunctionInterfaces.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Support/MathExtras.h"
21 #include "mlir/Transforms/InliningUtils.h"
22 
23 using namespace mlir;
24 using namespace mlir::scf;
25 
26 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
27 
28 //===----------------------------------------------------------------------===//
29 // SCFDialect Dialect Interfaces
30 //===----------------------------------------------------------------------===//
31 
32 namespace {
33 struct SCFInlinerInterface : public DialectInlinerInterface {
34   using DialectInlinerInterface::DialectInlinerInterface;
35   // We don't have any special restrictions on what can be inlined into
36   // destination regions (e.g. while/conditional bodies). Always allow it.
isLegalToInline__anoncfc0330f0111::SCFInlinerInterface37   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
38                        BlockAndValueMapping &valueMapping) const final {
39     return true;
40   }
41   // Operations in scf dialect are always legal to inline since they are
42   // pure.
isLegalToInline__anoncfc0330f0111::SCFInlinerInterface43   bool isLegalToInline(Operation *, Region *, bool,
44                        BlockAndValueMapping &) const final {
45     return true;
46   }
47   // Handle the given inlined terminator by replacing it with a new operation
48   // as necessary. Required when the region has only one block.
handleTerminator__anoncfc0330f0111::SCFInlinerInterface49   void handleTerminator(Operation *op,
50                         ArrayRef<Value> valuesToRepl) const final {
51     auto retValOp = dyn_cast<scf::YieldOp>(op);
52     if (!retValOp)
53       return;
54 
55     for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
56       std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
57     }
58   }
59 };
60 } // namespace
61 
62 //===----------------------------------------------------------------------===//
63 // SCFDialect
64 //===----------------------------------------------------------------------===//
65 
/// Registers every SCF operation (the list is expanded from the ODS-generated
/// `SCFOps.cpp.inc` under `GET_OP_LIST`) and attaches the dialect's inliner
/// interface.
initialize()66 void SCFDialect::initialize() {
67   addOperations<
68 #define GET_OP_LIST
69 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
70       >();
71   addInterfaces<SCFInlinerInterface>();
72 }
73 
74 /// Default callback for IfOp builders. Inserts a yield without arguments.
buildTerminatedBody(OpBuilder & builder,Location loc)75 void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
76   builder.create<scf::YieldOp>(loc);
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // ExecuteRegionOp
81 //===----------------------------------------------------------------------===//
82 
83 /// Replaces the given op with the contents of the given single-block region,
84 /// using the operands of the block terminator to replace operation results.
replaceOpWithRegion(PatternRewriter & rewriter,Operation * op,Region & region,ValueRange blockArgs={})85 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
86                                 Region &region, ValueRange blockArgs = {}) {
87   assert(llvm::hasSingleElement(region) && "expected single-region block");
88   Block *block = &region.front();
89   Operation *terminator = block->getTerminator();
90   ValueRange results = terminator->getOperands();
91   rewriter.mergeBlockBefore(block, op, blockArgs);
92   rewriter.replaceOp(op, results);
93   rewriter.eraseOp(terminator);
94 }
95 
96 ///
97 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
98 ///    block+
99 /// `}`
100 ///
101 /// Example:
102 ///   scf.execute_region -> i32 {
103 ///     %idx = memref.load %rI[%i] : memref<128xi32>
104 ///     scf.yield %idx : i32
105 ///   }
106 ///
parse(OpAsmParser & parser,OperationState & result)107 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
108                                    OperationState &result) {
109   if (parser.parseOptionalArrowTypeList(result.types))
110     return failure();
111 
112   // Introduce the body region and parse it.
113   Region *body = result.addRegion();
114   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
115       parser.parseOptionalAttrDict(result.attributes))
116     return failure();
117 
118   return success();
119 }
120 
print(OpAsmPrinter & p)121 void ExecuteRegionOp::print(OpAsmPrinter &p) {
122   p.printOptionalArrowTypeList(getResultTypes());
123 
124   p << ' ';
125   p.printRegion(getRegion(),
126                 /*printEntryBlockArgs=*/false,
127                 /*printBlockTerminators=*/true);
128 
129   p.printOptionalAttrDict((*this)->getAttrs());
130 }
131 
verify()132 LogicalResult ExecuteRegionOp::verify() {
133   if (getRegion().empty())
134     return emitOpError("region needs to have at least one block");
135   if (getRegion().front().getNumArguments() > 0)
136     return emitOpError("region cannot have any arguments");
137   return success();
138 }
139 
140 // Inline an ExecuteRegionOp if it only contains one block.
141 //     "test.foo"() : () -> ()
142 //      %v = scf.execute_region -> i64 {
143 //        %x = "test.val"() : () -> i64
144 //        scf.yield %x : i64
145 //      }
146 //      "test.bar"(%v) : (i64) -> ()
147 //
148 //  becomes
149 //
150 //     "test.foo"() : () -> ()
151 //     %x = "test.val"() : () -> i64
152 //     "test.bar"(%x) : (i64) -> ()
153 //
154 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
155   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
156 
matchAndRewriteSingleBlockExecuteInliner157   LogicalResult matchAndRewrite(ExecuteRegionOp op,
158                                 PatternRewriter &rewriter) const override {
159     if (!llvm::hasSingleElement(op.getRegion()))
160       return failure();
161     replaceOpWithRegion(rewriter, op, op.getRegion());
162     return success();
163   }
164 };
165 
166 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
167 // TODO generalize the conditions for operations which can be inlined into.
168 // func @func_execute_region_elim() {
169 //     "test.foo"() : () -> ()
170 //     %v = scf.execute_region -> i64 {
171 //       %c = "test.cmp"() : () -> i1
172 //       cf.cond_br %c, ^bb2, ^bb3
173 //     ^bb2:
174 //       %x = "test.val1"() : () -> i64
175 //       cf.br ^bb4(%x : i64)
176 //     ^bb3:
177 //       %y = "test.val2"() : () -> i64
178 //       cf.br ^bb4(%y : i64)
179 //     ^bb4(%z : i64):
180 //       scf.yield %z : i64
181 //     }
182 //     "test.bar"(%v) : (i64) -> ()
183 //   return
184 // }
185 //
186 //  becomes
187 //
188 // func @func_execute_region_elim() {
189 //    "test.foo"() : () -> ()
190 //    %c = "test.cmp"() : () -> i1
191 //    cf.cond_br %c, ^bb1, ^bb2
192 //  ^bb1:  // pred: ^bb0
193 //    %x = "test.val1"() : () -> i64
194 //    cf.br ^bb3(%x : i64)
195 //  ^bb2:  // pred: ^bb0
196 //    %y = "test.val2"() : () -> i64
197 //    cf.br ^bb3(%y : i64)
198 //  ^bb3(%z: i64):  // 2 preds: ^bb1, ^bb2
199 //    "test.bar"(%z) : (i64) -> ()
200 //    return
201 //  }
202 //
203 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
204   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
205 
matchAndRewriteMultiBlockExecuteInliner206   LogicalResult matchAndRewrite(ExecuteRegionOp op,
207                                 PatternRewriter &rewriter) const override {
208     if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
209       return failure();
210 
211     Block *prevBlock = op->getBlock();
212     Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
213     rewriter.setInsertionPointToEnd(prevBlock);
214 
215     rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
216 
217     for (Block &blk : op.getRegion()) {
218       if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
219         rewriter.setInsertionPoint(yieldOp);
220         rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
221                                       yieldOp.getResults());
222         rewriter.eraseOp(yieldOp);
223       }
224     }
225 
226     rewriter.inlineRegionBefore(op.getRegion(), postBlock);
227     SmallVector<Value> blockArgs;
228 
229     for (auto res : op.getResults())
230       blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
231 
232     rewriter.replaceOp(op, blockArgs);
233     return success();
234   }
235 };
236 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)237 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
238                                                   MLIRContext *context) {
239   results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
240 }
241 
242 /// Given the region at `index`, or the parent operation if `index` is None,
243 /// return the successor regions. These are the regions that may be selected
244 /// during the flow of control. `operands` is a set of optional attributes that
245 /// correspond to a constant value for each operand, or null if that operand is
246 /// not a constant.
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)247 void ExecuteRegionOp::getSuccessorRegions(
248     Optional<unsigned> index, ArrayRef<Attribute> operands,
249     SmallVectorImpl<RegionSuccessor> &regions) {
250   // If the predecessor is the ExecuteRegionOp, branch into the body.
251   if (!index) {
252     regions.push_back(RegionSuccessor(&getRegion()));
253     return;
254   }
255 
256   // Otherwise, the region branches back to the parent operation.
257   regions.push_back(RegionSuccessor(getResults()));
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // ConditionOp
262 //===----------------------------------------------------------------------===//
263 
264 MutableOperandRange
getMutableSuccessorOperands(Optional<unsigned> index)265 ConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
266   // Pass all operands except the condition to the successor region.
267   return getArgsMutable();
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // ForOp
272 //===----------------------------------------------------------------------===//
273 
build(OpBuilder & builder,OperationState & result,Value lb,Value ub,Value step,ValueRange iterArgs,BodyBuilderFn bodyBuilder)274 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
275                   Value ub, Value step, ValueRange iterArgs,
276                   BodyBuilderFn bodyBuilder) {
277   result.addOperands({lb, ub, step});
278   result.addOperands(iterArgs);
279   for (Value v : iterArgs)
280     result.addTypes(v.getType());
281   Region *bodyRegion = result.addRegion();
282   bodyRegion->push_back(new Block);
283   Block &bodyBlock = bodyRegion->front();
284   bodyBlock.addArgument(builder.getIndexType(), result.location);
285   for (Value v : iterArgs)
286     bodyBlock.addArgument(v.getType(), v.getLoc());
287 
288   // Create the default terminator if the builder is not provided and if the
289   // iteration arguments are not provided. Otherwise, leave this to the caller
290   // because we don't know which values to return from the loop.
291   if (iterArgs.empty() && !bodyBuilder) {
292     ForOp::ensureTerminator(*bodyRegion, builder, result.location);
293   } else if (bodyBuilder) {
294     OpBuilder::InsertionGuard guard(builder);
295     builder.setInsertionPointToStart(&bodyBlock);
296     bodyBuilder(builder, result.location, bodyBlock.getArgument(0),
297                 bodyBlock.getArguments().drop_front());
298   }
299 }
300 
verify()301 LogicalResult ForOp::verify() {
302   if (auto cst = getStep().getDefiningOp<arith::ConstantIndexOp>())
303     if (cst.value() <= 0)
304       return emitOpError("constant step operand must be positive");
305 
306   auto opNumResults = getNumResults();
307   if (opNumResults == 0)
308     return success();
309   // If ForOp defines values, check that the number and types of
310   // the defined values match ForOp initial iter operands and backedge
311   // basic block arguments.
312   if (getNumIterOperands() != opNumResults)
313     return emitOpError(
314         "mismatch in number of loop-carried values and defined values");
315   return success();
316 }
317 
verifyRegions()318 LogicalResult ForOp::verifyRegions() {
319   // Check that the body defines as single block argument for the induction
320   // variable.
321   auto *body = getBody();
322   if (!body->getArgument(0).getType().isIndex())
323     return emitOpError(
324         "expected body first argument to be an index argument for "
325         "the induction variable");
326 
327   auto opNumResults = getNumResults();
328   if (opNumResults == 0)
329     return success();
330 
331   if (getNumRegionIterArgs() != opNumResults)
332     return emitOpError(
333         "mismatch in number of basic block args and defined values");
334 
335   auto iterOperands = getIterOperands();
336   auto iterArgs = getRegionIterArgs();
337   auto opResults = getResults();
338   unsigned i = 0;
339   for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
340     if (std::get<0>(e).getType() != std::get<2>(e).getType())
341       return emitOpError() << "types mismatch between " << i
342                            << "th iter operand and defined value";
343     if (std::get<1>(e).getType() != std::get<2>(e).getType())
344       return emitOpError() << "types mismatch between " << i
345                            << "th iter region arg and defined value";
346 
347     i++;
348   }
349   return success();
350 }
351 
getSingleInductionVar()352 Optional<Value> ForOp::getSingleInductionVar() { return getInductionVar(); }
353 
getSingleLowerBound()354 Optional<OpFoldResult> ForOp::getSingleLowerBound() {
355   return OpFoldResult(getLowerBound());
356 }
357 
getSingleStep()358 Optional<OpFoldResult> ForOp::getSingleStep() {
359   return OpFoldResult(getStep());
360 }
361 
getSingleUpperBound()362 Optional<OpFoldResult> ForOp::getSingleUpperBound() {
363   return OpFoldResult(getUpperBound());
364 }
365 
366 /// Prints the initialization list in the form of
367 ///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
368 /// where 'inner' values are assumed to be region arguments and 'outer' values
369 /// are regular SSA values.
printInitializationList(OpAsmPrinter & p,Block::BlockArgListType blocksArgs,ValueRange initializers,StringRef prefix="")370 static void printInitializationList(OpAsmPrinter &p,
371                                     Block::BlockArgListType blocksArgs,
372                                     ValueRange initializers,
373                                     StringRef prefix = "") {
374   assert(blocksArgs.size() == initializers.size() &&
375          "expected same length of arguments and initializers");
376   if (initializers.empty())
377     return;
378 
379   p << prefix << '(';
380   llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
381     p << std::get<0>(it) << " = " << std::get<1>(it);
382   });
383   p << ")";
384 }
385 
print(OpAsmPrinter & p)386 void ForOp::print(OpAsmPrinter &p) {
387   p << " " << getInductionVar() << " = " << getLowerBound() << " to "
388     << getUpperBound() << " step " << getStep();
389 
390   printInitializationList(p, getRegionIterArgs(), getIterOperands(),
391                           " iter_args");
392   if (!getIterOperands().empty())
393     p << " -> (" << getIterOperands().getTypes() << ')';
394   p << ' ';
395   p.printRegion(getRegion(),
396                 /*printEntryBlockArgs=*/false,
397                 /*printBlockTerminators=*/hasIterOperands());
398   p.printOptionalAttrDict((*this)->getAttrs());
399 }
400 
parse(OpAsmParser & parser,OperationState & result)401 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
402   auto &builder = parser.getBuilder();
403   Type indexType = builder.getIndexType();
404 
405   OpAsmParser::Argument inductionVariable;
406   inductionVariable.type = indexType;
407   OpAsmParser::UnresolvedOperand lb, ub, step;
408 
409   // Parse the induction variable followed by '='.
410   if (parser.parseArgument(inductionVariable) || parser.parseEqual() ||
411       // Parse loop bounds.
412       parser.parseOperand(lb) ||
413       parser.resolveOperand(lb, indexType, result.operands) ||
414       parser.parseKeyword("to") || parser.parseOperand(ub) ||
415       parser.resolveOperand(ub, indexType, result.operands) ||
416       parser.parseKeyword("step") || parser.parseOperand(step) ||
417       parser.resolveOperand(step, indexType, result.operands))
418     return failure();
419 
420   // Parse the optional initial iteration arguments.
421   SmallVector<OpAsmParser::Argument, 4> regionArgs;
422   SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
423   regionArgs.push_back(inductionVariable);
424 
425   if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
426     // Parse assignment list and results type list.
427     if (parser.parseAssignmentList(regionArgs, operands) ||
428         parser.parseArrowTypeList(result.types))
429       return failure();
430 
431     // Resolve input operands.
432     for (auto argOperandType :
433          llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
434       Type type = std::get<2>(argOperandType);
435       std::get<0>(argOperandType).type = type;
436       if (parser.resolveOperand(std::get<1>(argOperandType), type,
437                                 result.operands))
438         return failure();
439     }
440   }
441 
442   if (regionArgs.size() != result.types.size() + 1)
443     return parser.emitError(
444         parser.getNameLoc(),
445         "mismatch in number of loop-carried values and defined values");
446 
447   // Parse the body region.
448   Region *body = result.addRegion();
449   if (parser.parseRegion(*body, regionArgs))
450     return failure();
451 
452   ForOp::ensureTerminator(*body, builder, result.location);
453 
454   // Parse the optional attribute list.
455   if (parser.parseOptionalAttrDict(result.attributes))
456     return failure();
457 
458   return success();
459 }
460 
getLoopBody()461 Region &ForOp::getLoopBody() { return getRegion(); }
462 
getForInductionVarOwner(Value val)463 ForOp mlir::scf::getForInductionVarOwner(Value val) {
464   auto ivArg = val.dyn_cast<BlockArgument>();
465   if (!ivArg)
466     return ForOp();
467   assert(ivArg.getOwner() && "unlinked block argument");
468   auto *containingOp = ivArg.getOwner()->getParentOp();
469   return dyn_cast_or_null<ForOp>(containingOp);
470 }
471 
472 /// Return operands used when entering the region at 'index'. These operands
473 /// correspond to the loop iterator operands, i.e., those excluding the
474 /// induction variable. ForOp only has one region, so 0 is the only valid value
475 /// for `index`.
getSuccessorEntryOperands(Optional<unsigned> index)476 OperandRange ForOp::getSuccessorEntryOperands(Optional<unsigned> index) {
477   assert(index && *index == 0 && "invalid region index");
478 
479   // The initial operands map to the loop arguments after the induction
480   // variable.
481   return getInitArgs();
482 }
483 
484 /// Given the region at `index`, or the parent operation if `index` is None,
485 /// return the successor regions. These are the regions that may be selected
486 /// during the flow of control. `operands` is a set of optional attributes that
487 /// correspond to a constant value for each operand, or null if that operand is
488 /// not a constant.
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)489 void ForOp::getSuccessorRegions(Optional<unsigned> index,
490                                 ArrayRef<Attribute> operands,
491                                 SmallVectorImpl<RegionSuccessor> &regions) {
492   // If the predecessor is the ForOp, branch into the body using the iterator
493   // arguments.
494   if (!index) {
495     regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
496     return;
497   }
498 
499   // Otherwise, the loop may branch back to itself or the parent operation.
500   assert(*index == 0 && "expected loop region");
501   regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
502   regions.push_back(RegionSuccessor(getResults()));
503 }
504 
buildLoopNest(OpBuilder & builder,Location loc,ValueRange lbs,ValueRange ubs,ValueRange steps,ValueRange iterArgs,function_ref<ValueVector (OpBuilder &,Location,ValueRange,ValueRange)> bodyBuilder)505 LoopNest mlir::scf::buildLoopNest(
506     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
507     ValueRange steps, ValueRange iterArgs,
508     function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
509         bodyBuilder) {
510   assert(lbs.size() == ubs.size() &&
511          "expected the same number of lower and upper bounds");
512   assert(lbs.size() == steps.size() &&
513          "expected the same number of lower bounds and steps");
514 
515   // If there are no bounds, call the body-building function and return early.
516   if (lbs.empty()) {
517     ValueVector results =
518         bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
519                     : ValueVector();
520     assert(results.size() == iterArgs.size() &&
521            "loop nest body must return as many values as loop has iteration "
522            "arguments");
523     return LoopNest();
524   }
525 
526   // First, create the loop structure iteratively using the body-builder
527   // callback of `ForOp::build`. Do not create `YieldOp`s yet.
528   OpBuilder::InsertionGuard guard(builder);
529   SmallVector<scf::ForOp, 4> loops;
530   SmallVector<Value, 4> ivs;
531   loops.reserve(lbs.size());
532   ivs.reserve(lbs.size());
533   ValueRange currentIterArgs = iterArgs;
534   Location currentLoc = loc;
535   for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
536     auto loop = builder.create<scf::ForOp>(
537         currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
538         [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
539             ValueRange args) {
540           ivs.push_back(iv);
541           // It is safe to store ValueRange args because it points to block
542           // arguments of a loop operation that we also own.
543           currentIterArgs = args;
544           currentLoc = nestedLoc;
545         });
546     // Set the builder to point to the body of the newly created loop. We don't
547     // do this in the callback because the builder is reset when the callback
548     // returns.
549     builder.setInsertionPointToStart(loop.getBody());
550     loops.push_back(loop);
551   }
552 
553   // For all loops but the innermost, yield the results of the nested loop.
554   for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
555     builder.setInsertionPointToEnd(loops[i].getBody());
556     builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
557   }
558 
559   // In the body of the innermost loop, call the body building function if any
560   // and yield its results.
561   builder.setInsertionPointToStart(loops.back().getBody());
562   ValueVector results = bodyBuilder
563                             ? bodyBuilder(builder, currentLoc, ivs,
564                                           loops.back().getRegionIterArgs())
565                             : ValueVector();
566   assert(results.size() == iterArgs.size() &&
567          "loop nest body must return as many values as loop has iteration "
568          "arguments");
569   builder.setInsertionPointToEnd(loops.back().getBody());
570   builder.create<scf::YieldOp>(loc, results);
571 
572   // Return the loops.
573   LoopNest res;
574   res.loops.assign(loops.begin(), loops.end());
575   return res;
576 }
577 
buildLoopNest(OpBuilder & builder,Location loc,ValueRange lbs,ValueRange ubs,ValueRange steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilder)578 LoopNest mlir::scf::buildLoopNest(
579     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
580     ValueRange steps,
581     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
582   // Delegate to the main function by wrapping the body builder.
583   return buildLoopNest(builder, loc, lbs, ubs, steps, llvm::None,
584                        [&bodyBuilder](OpBuilder &nestedBuilder,
585                                       Location nestedLoc, ValueRange ivs,
586                                       ValueRange) -> ValueVector {
587                          if (bodyBuilder)
588                            bodyBuilder(nestedBuilder, nestedLoc, ivs);
589                          return {};
590                        });
591 }
592 
593 namespace {
594 // Fold away ForOp iter arguments when:
595 // 1) The op yields the iter arguments.
596 // 2) The iter arguments have no use and the corresponding outer region
597 // iterators (inputs) are yielded.
598 // 3) The iter arguments have no use and the corresponding (operation) results
599 // have no use.
600 //
601 // These arguments must be defined outside of
602 // the ForOp region and can just be forwarded after simplifying the op inits,
603 // yields and returns.
604 //
605 // The implementation uses `mergeBlockBefore` to steal the content of the
606 // original ForOp and avoid cloning.
607 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
608   using OpRewritePattern<scf::ForOp>::OpRewritePattern;
609 
matchAndRewrite__anoncfc0330f0511::ForOpIterArgsFolder610   LogicalResult matchAndRewrite(scf::ForOp forOp,
611                                 PatternRewriter &rewriter) const final {
612     bool canonicalize = false;
613     Block &block = forOp.getRegion().front();
614     auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
615 
616     // An internal flat vector of block transfer
617     // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
618     // transformed block argument mappings. This plays the role of a
619     // BlockAndValueMapping for the particular use case of calling into
620     // `mergeBlockBefore`.
621     SmallVector<bool, 4> keepMask;
622     keepMask.reserve(yieldOp.getNumOperands());
623     SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
624         newResultValues;
625     newBlockTransferArgs.reserve(1 + forOp.getNumIterOperands());
626     newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
627     newIterArgs.reserve(forOp.getNumIterOperands());
628     newYieldValues.reserve(yieldOp.getNumOperands());
629     newResultValues.reserve(forOp.getNumResults());
630     for (auto it : llvm::zip(forOp.getIterOperands(),   // iter from outside
631                              forOp.getRegionIterArgs(), // iter inside region
632                              forOp.getResults(),        // op results
633                              yieldOp.getOperands()      // iter yield
634                              )) {
635       // Forwarded is `true` when:
636       // 1) The region `iter` argument is yielded.
637       // 2) The region `iter` argument has no use, and the corresponding iter
638       // operand (input) is yielded.
639       // 3) The region `iter` argument has no use, and the corresponding op
640       // result has no use.
641       bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
642                         (std::get<1>(it).use_empty() &&
643                          (std::get<0>(it) == std::get<3>(it) ||
644                           std::get<2>(it).use_empty())));
645       keepMask.push_back(!forwarded);
646       canonicalize |= forwarded;
647       if (forwarded) {
648         newBlockTransferArgs.push_back(std::get<0>(it));
649         newResultValues.push_back(std::get<0>(it));
650         continue;
651       }
652       newIterArgs.push_back(std::get<0>(it));
653       newYieldValues.push_back(std::get<3>(it));
654       newBlockTransferArgs.push_back(Value()); // placeholder with null value
655       newResultValues.push_back(Value());      // placeholder with null value
656     }
657 
658     if (!canonicalize)
659       return failure();
660 
661     scf::ForOp newForOp = rewriter.create<scf::ForOp>(
662         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
663         forOp.getStep(), newIterArgs);
664     newForOp->setAttrs(forOp->getAttrs());
665     Block &newBlock = newForOp.getRegion().front();
666 
667     // Replace the null placeholders with newly constructed values.
668     newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
669     for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
670          idx != e; ++idx) {
671       Value &blockTransferArg = newBlockTransferArgs[1 + idx];
672       Value &newResultVal = newResultValues[idx];
673       assert((blockTransferArg && newResultVal) ||
674              (!blockTransferArg && !newResultVal));
675       if (!blockTransferArg) {
676         blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
677         newResultVal = newForOp.getResult(collapsedIdx++);
678       }
679     }
680 
681     Block &oldBlock = forOp.getRegion().front();
682     assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
683            "unexpected argument size mismatch");
684 
685     // No results case: the scf::ForOp builder already created a zero
686     // result terminator. Merge before this terminator and just get rid of the
687     // original terminator that has been merged in.
688     if (newIterArgs.empty()) {
689       auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
690       rewriter.mergeBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
691       rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
692       rewriter.replaceOp(forOp, newResultValues);
693       return success();
694     }
695 
696     // No terminator case: merge and rewrite the merged terminator.
697     auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
698       OpBuilder::InsertionGuard g(rewriter);
699       rewriter.setInsertionPoint(mergedTerminator);
700       SmallVector<Value, 4> filteredOperands;
701       filteredOperands.reserve(newResultValues.size());
702       for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
703         if (keepMask[idx])
704           filteredOperands.push_back(mergedTerminator.getOperand(idx));
705       rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
706                                     filteredOperands);
707     };
708 
709     rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
710     auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
711     cloneFilteredTerminator(mergedYieldOp);
712     rewriter.eraseOp(mergedYieldOp);
713     rewriter.replaceOp(forOp, newResultValues);
714     return success();
715   }
716 };
717 
718 /// Rewriting pattern that erases loops that are known not to iterate, replaces
719 /// single-iteration loops with their bodies, and removes empty loops that
720 /// iterate at least once and only return values defined outside of the loop.
721 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
722   using OpRewritePattern<ForOp>::OpRewritePattern;
723 
matchAndRewrite__anoncfc0330f0511::SimplifyTrivialLoops724   LogicalResult matchAndRewrite(ForOp op,
725                                 PatternRewriter &rewriter) const override {
726     // If the upper bound is the same as the lower bound, the loop does not
727     // iterate, just remove it.
728     if (op.getLowerBound() == op.getUpperBound()) {
729       rewriter.replaceOp(op, op.getIterOperands());
730       return success();
731     }
732 
733     auto lb = op.getLowerBound().getDefiningOp<arith::ConstantOp>();
734     auto ub = op.getUpperBound().getDefiningOp<arith::ConstantOp>();
735     if (!lb || !ub)
736       return failure();
737 
738     // If the loop is known to have 0 iterations, remove it.
739     llvm::APInt lbValue = lb.getValue().cast<IntegerAttr>().getValue();
740     llvm::APInt ubValue = ub.getValue().cast<IntegerAttr>().getValue();
741     if (lbValue.sge(ubValue)) {
742       rewriter.replaceOp(op, op.getIterOperands());
743       return success();
744     }
745 
746     auto step = op.getStep().getDefiningOp<arith::ConstantOp>();
747     if (!step)
748       return failure();
749 
750     // If the loop is known to have 1 iteration, inline its body and remove the
751     // loop.
752     llvm::APInt stepValue = step.getValue().cast<IntegerAttr>().getValue();
753     if ((lbValue + stepValue).sge(ubValue)) {
754       SmallVector<Value, 4> blockArgs;
755       blockArgs.reserve(op.getNumIterOperands() + 1);
756       blockArgs.push_back(op.getLowerBound());
757       llvm::append_range(blockArgs, op.getIterOperands());
758       replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
759       return success();
760     }
761 
762     // Now we are left with loops that have more than 1 iterations.
763     Block &block = op.getRegion().front();
764     if (!llvm::hasSingleElement(block))
765       return failure();
766     // If the loop is empty, iterates at least once, and only returns values
767     // defined outside of the loop, remove it and replace it with yield values.
768     auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
769     auto yieldOperands = yieldOp.getOperands();
770     if (llvm::any_of(yieldOperands,
771                      [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
772       return failure();
773     rewriter.replaceOp(op, yieldOperands);
774     return success();
775   }
776 };
777 
778 /// Perform a replacement of one iter OpOperand of an scf.for to the
779 /// `replacement` value which is expected to be the source of a tensor.cast.
780 /// tensor.cast ops are inserted inside the block to account for the type cast.
replaceTensorCastForOpIterArg(PatternRewriter & rewriter,OpOperand & operand,Value replacement)781 static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
782                                            OpOperand &operand,
783                                            Value replacement) {
784   Type oldType = operand.get().getType(), newType = replacement.getType();
785   assert(oldType.isa<RankedTensorType>() && newType.isa<RankedTensorType>() &&
786          "expected ranked tensor types");
787 
788   // 1. Create new iter operands, exactly 1 is replaced.
789   ForOp forOp = cast<ForOp>(operand.getOwner());
790   assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
791          "expected an iter OpOperand");
792   if (operand.get().getType() == replacement.getType())
793     return forOp;
794   SmallVector<Value> newIterOperands;
795   for (OpOperand &opOperand : forOp.getIterOpOperands()) {
796     if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
797       newIterOperands.push_back(replacement);
798       continue;
799     }
800     newIterOperands.push_back(opOperand.get());
801   }
802 
803   // 2. Create the new forOp shell.
804   scf::ForOp newForOp = rewriter.create<scf::ForOp>(
805       forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
806       forOp.getStep(), newIterOperands);
807   newForOp->setAttrs(forOp->getAttrs());
808   Block &newBlock = newForOp.getRegion().front();
809   SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
810                                              newBlock.getArguments().end());
811 
812   // 3. Inject an incoming cast op at the beginning of the block for the bbArg
813   // corresponding to the `replacement` value.
814   OpBuilder::InsertionGuard g(rewriter);
815   rewriter.setInsertionPoint(&newBlock, newBlock.begin());
816   BlockArgument newRegionIterArg = newForOp.getRegionIterArgForOpOperand(
817       newForOp->getOpOperand(operand.getOperandNumber()));
818   Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
819                                                  newRegionIterArg);
820   newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
821 
822   // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
823   Block &oldBlock = forOp.getRegion().front();
824   rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
825 
826   // 5. Inject an outgoing cast op at the end of the block and yield it instead.
827   auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
828   rewriter.setInsertionPoint(clonedYieldOp);
829   unsigned yieldIdx =
830       newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
831   Value castOut = rewriter.create<tensor::CastOp>(
832       newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
833   SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
834   newYieldOperands[yieldIdx] = castOut;
835   rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
836   rewriter.eraseOp(clonedYieldOp);
837 
838   // 6. Inject an outgoing cast op after the forOp.
839   rewriter.setInsertionPointAfter(newForOp);
840   SmallVector<Value> newResults = newForOp.getResults();
841   newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
842       newForOp.getLoc(), oldType, newResults[yieldIdx]);
843 
844   return newForOp;
845 }
846 
847 /// Fold scf.for iter_arg/result pairs that go through incoming/ougoing
848 /// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
849 ///
850 /// ```
851 ///   %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
852 ///   %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
853 ///      -> (tensor<?x?xf32>) {
854 ///     %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
855 ///     scf.yield %2 : tensor<?x?xf32>
856 ///   }
857 ///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
858 ///   use_of(%2)
859 /// ```
860 ///
861 /// folds into:
862 ///
863 /// ```
864 ///   %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
865 ///       -> (tensor<32x1024xf32>) {
866 ///     %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
867 ///     %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
868 ///     %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
869 ///     scf.yield %4 : tensor<32x1024xf32>
870 ///   }
871 ///   use_of(%0)
872 /// ```
873 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
874   using OpRewritePattern<ForOp>::OpRewritePattern;
875 
matchAndRewrite__anoncfc0330f0511::ForOpTensorCastFolder876   LogicalResult matchAndRewrite(ForOp op,
877                                 PatternRewriter &rewriter) const override {
878     for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
879       OpOperand &iterOpOperand = std::get<0>(it);
880       auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
881       if (!incomingCast)
882         continue;
883       if (!std::get<1>(it).hasOneUse())
884         continue;
885       auto outgoingCastOp =
886           dyn_cast<tensor::CastOp>(*std::get<1>(it).user_begin());
887       if (!outgoingCastOp)
888         continue;
889 
890       // Must be a tensor.cast op pair with matching types.
891       if (outgoingCastOp.getResult().getType() !=
892           incomingCast.getSource().getType())
893         continue;
894 
895       // Create a new ForOp with that iter operand replaced.
896       auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
897                                                     incomingCast.getSource());
898 
899       // Insert outgoing cast and use it to replace the corresponding result.
900       rewriter.setInsertionPointAfter(newForOp);
901       SmallVector<Value> replacements = newForOp.getResults();
902       unsigned returnIdx =
903           iterOpOperand.getOperandNumber() - op.getNumControlOperands();
904       replacements[returnIdx] = rewriter.create<tensor::CastOp>(
905           op.getLoc(), incomingCast.getDest().getType(),
906           replacements[returnIdx]);
907       rewriter.replaceOp(op, replacements);
908       return success();
909     }
910     return failure();
911   }
912 };
913 
914 /// Canonicalize the iter_args of an scf::ForOp that involve a
915 /// `bufferization.to_tensor` and for which only the last loop iteration is
916 /// actually visible outside of the loop. The canonicalization looks for a
917 /// pattern such as:
918 /// ```
919 ///    %t0 = ... : tensor_type
920 ///    %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
921 ///      ...
922 ///      // %m is either buffer_cast(%bb00) or defined above the loop
923 ///      %m... : memref_type
924 ///      ... // uses of %m with potential inplace updates
925 ///      %new_tensor = bufferization.to_tensor %m : memref_type
926 ///      ...
927 ///      scf.yield %new_tensor : tensor_type
928 ///    }
929 /// ```
930 ///
931 /// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
932 /// `%m = buffer_cast %bb0` op that feeds into the yielded
933 /// `bufferization.to_tensor` op.
934 ///
935 /// If no aliasing write to the memref `%m`, from which `%new_tensor`is loaded,
936 /// occurs between `bufferization.to_tensor and yield then the value %0
937 /// visible outside of the loop is the last `bufferization.to_tensor`
938 /// produced in the loop.
939 ///
940 /// For now, we approximate the absence of aliasing by only supporting the case
941 /// when the bufferization.to_tensor is the operation immediately preceding
942 /// the yield.
943 //
944 /// The canonicalization rewrites the pattern as:
945 /// ```
946 ///    // %m is either a buffer_cast or defined above
947 ///    %m... : memref_type
948 ///    scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
949 ///      ... // uses of %m with potential inplace updates
950 ///      scf.yield %bb0: tensor_type
951 ///    }
952 ///    %0 = bufferization.to_tensor %m : memref_type
953 /// ```
954 ///
955 /// A later bbArg canonicalization will further rewrite as:
956 /// ```
957 ///    // %m is either a buffer_cast or defined above
958 ///    %m... : memref_type
959 ///    scf.for ... { // no iter_args
960 ///      ... // uses of %m with potential inplace updates
961 ///    }
962 ///    %0 = bufferization.to_tensor %m : memref_type
963 /// ```
964 struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
965   using OpRewritePattern<ForOp>::OpRewritePattern;
966 
matchAndRewrite__anoncfc0330f0511::LastTensorLoadCanonicalization967   LogicalResult matchAndRewrite(ForOp forOp,
968                                 PatternRewriter &rewriter) const override {
969     assert(std::next(forOp.getRegion().begin()) == forOp.getRegion().end() &&
970            "unexpected multiple blocks");
971 
972     Location loc = forOp.getLoc();
973     DenseMap<Value, Value> replacements;
974     for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
975       unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
976       auto yieldOp =
977           cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
978       Value yieldVal = yieldOp->getOperand(idx);
979       auto tensorLoadOp = yieldVal.getDefiningOp<bufferization::ToTensorOp>();
980       bool isTensor = bbArg.getType().isa<TensorType>();
981 
982       bufferization::ToMemrefOp tensorToMemref;
983       // Either bbArg has no use or it has a single buffer_cast use.
984       if (bbArg.hasOneUse())
985         tensorToMemref =
986             dyn_cast<bufferization::ToMemrefOp>(*bbArg.getUsers().begin());
987       if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !tensorToMemref))
988         continue;
989       // If tensorToMemref is present, it must feed into the `ToTensorOp`.
990       if (tensorToMemref && tensorLoadOp.getMemref() != tensorToMemref)
991         continue;
992       // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
993       // must be before `ToTensorOp` in the block so that the lastWrite
994       // property is not subject to additional side-effects.
995       // For now, we only support the case when ToTensorOp appears
996       // immediately before the terminator.
997       if (tensorLoadOp->getNextNode() != yieldOp)
998         continue;
999 
1000       // Clone the optional tensorToMemref before forOp.
1001       if (tensorToMemref) {
1002         rewriter.setInsertionPoint(forOp);
1003         rewriter.replaceOpWithNewOp<bufferization::ToMemrefOp>(
1004             tensorToMemref, tensorToMemref.getMemref().getType(),
1005             tensorToMemref.getTensor());
1006       }
1007 
1008       // Clone the tensorLoad after forOp.
1009       rewriter.setInsertionPointAfter(forOp);
1010       Value newTensorLoad = rewriter.create<bufferization::ToTensorOp>(
1011           loc, tensorLoadOp.getMemref());
1012       Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
1013       replacements.insert(std::make_pair(forOpResult, newTensorLoad));
1014 
1015       // Make the terminator just yield the bbArg, the old tensorLoadOp + the
1016       // old bbArg (that is now directly yielded) will canonicalize away.
1017       rewriter.startRootUpdate(yieldOp);
1018       yieldOp.setOperand(idx, bbArg);
1019       rewriter.finalizeRootUpdate(yieldOp);
1020     }
1021     if (replacements.empty())
1022       return failure();
1023 
1024     // We want to replace a subset of the results of `forOp`. rewriter.replaceOp
1025     // replaces the whole op and erase it unconditionally. This is wrong for
1026     // `forOp` as it generally contains ops with side effects.
1027     // Instead, use `rewriter.replaceOpWithIf`.
1028     SmallVector<Value> newResults;
1029     newResults.reserve(forOp.getNumResults());
1030     for (Value v : forOp.getResults()) {
1031       auto it = replacements.find(v);
1032       newResults.push_back((it != replacements.end()) ? it->second : v);
1033     }
1034     unsigned idx = 0;
1035     rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &op) {
1036       return op.get() != newResults[idx++];
1037     });
1038     return success();
1039   }
1040 };
1041 } // namespace
1042 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1043 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1044                                         MLIRContext *context) {
1045   results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
1046               LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
1047 }
1048 
1049 //===----------------------------------------------------------------------===//
1050 // ForeachThreadOp
1051 //===----------------------------------------------------------------------===//
1052 
verify()1053 LogicalResult ForeachThreadOp::verify() {
1054   // Call terminator's verify to produce most informative error messages.
1055   if (failed(getTerminator().verify()))
1056     return failure();
1057 
1058   // Check that the body defines as single block argument for the thread index.
1059   auto *body = getBody();
1060   if (body->getNumArguments() != getRank())
1061     return emitOpError("region expects ") << getRank() << " arguments";
1062 
1063   // Verify consistency between the result types and the terminator.
1064   auto terminatorTypes = getTerminator().getYieldedTypes();
1065   auto opResults = getResults();
1066   if (opResults.size() != terminatorTypes.size())
1067     return emitOpError("produces ")
1068            << opResults.size() << " results, but its terminator yields "
1069            << terminatorTypes.size() << " value(s)";
1070   unsigned i = 0;
1071   for (auto e : llvm::zip(terminatorTypes, opResults)) {
1072     if (std::get<0>(e) != std::get<1>(e).getType())
1073       return emitOpError() << "type mismatch between result " << i << " ("
1074                            << std::get<1>(e).getType() << ") and terminator ("
1075                            << std::get<0>(e) << ")";
1076     i++;
1077   }
1078   return success();
1079 }
1080 
print(OpAsmPrinter & p)1081 void ForeachThreadOp::print(OpAsmPrinter &p) {
1082   p << " (";
1083   llvm::interleaveComma(getThreadIndices(), p);
1084   p << ") in (";
1085   llvm::interleaveComma(getNumThreads(), p);
1086   p << ") -> (" << getResultTypes() << ") ";
1087   p.printRegion(getRegion(),
1088                 /*printEntryBlockArgs=*/false,
1089                 /*printBlockTerminators=*/getNumResults() > 0);
1090   p.printOptionalAttrDict(getOperation()->getAttrs());
1091 }
1092 
parse(OpAsmParser & parser,OperationState & result)1093 ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
1094                                    OperationState &result) {
1095   auto &builder = parser.getBuilder();
1096   // Parse an opening `(` followed by thread index variables followed by `)`
1097   // TODO: when we can refer to such "induction variable"-like handles from the
1098   // declarative assembly format, we can implement the parser as a custom hook.
1099   SmallVector<OpAsmParser::Argument, 4> threadIndices;
1100   if (parser.parseArgumentList(threadIndices, OpAsmParser::Delimiter::Paren))
1101     return failure();
1102 
1103   // Parse `in` threadNums.
1104   SmallVector<OpAsmParser::UnresolvedOperand, 4> threadNums;
1105   if (parser.parseKeyword("in") ||
1106       parser.parseOperandList(threadNums, threadIndices.size(),
1107                               OpAsmParser::Delimiter::Paren) ||
1108       parser.resolveOperands(threadNums, builder.getIndexType(),
1109                              result.operands))
1110     return failure();
1111 
1112   // Parse optional results.
1113   if (parser.parseOptionalArrowTypeList(result.types))
1114     return failure();
1115 
1116   // Parse region.
1117   std::unique_ptr<Region> region = std::make_unique<Region>();
1118   for (auto &idx : threadIndices)
1119     idx.type = builder.getIndexType();
1120   if (parser.parseRegion(*region, threadIndices))
1121     return failure();
1122 
1123   // Ensure terminator and move region.
1124   OpBuilder b(builder.getContext());
1125   ForeachThreadOp::ensureTerminator(*region, b, result.location);
1126   result.addRegion(std::move(region));
1127 
1128   // Parse the optional attribute list.
1129   if (parser.parseOptionalAttrDict(result.attributes))
1130     return failure();
1131 
1132   return success();
1133 }
1134 
1135 // Bodyless builder, result types must be specified.
build(mlir::OpBuilder & builder,mlir::OperationState & result,TypeRange resultTypes,ValueRange numThreads,ArrayRef<int64_t> threadDimMapping)1136 void ForeachThreadOp::build(mlir::OpBuilder &builder,
1137                             mlir::OperationState &result, TypeRange resultTypes,
1138                             ValueRange numThreads,
1139                             ArrayRef<int64_t> threadDimMapping) {
1140   result.addOperands(numThreads);
1141   result.addAttribute(
1142       // TODO: getThreadDimMappingAttrName() but it is not a static member.
1143       "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
1144 
1145   Region *bodyRegion = result.addRegion();
1146   OpBuilder::InsertionGuard g(builder);
1147   // createBlock sets the IP inside the block.
1148   // Generally we would guard against that but the default ensureTerminator impl
1149   // expects it ..
1150   builder.createBlock(bodyRegion);
1151   Block &bodyBlock = bodyRegion->front();
1152   bodyBlock.addArguments(
1153       SmallVector<Type>(numThreads.size(), builder.getIndexType()),
1154       SmallVector<Location>(numThreads.size(), result.location));
1155   ForeachThreadOp::ensureTerminator(*bodyRegion, builder, result.location);
1156   result.addTypes(resultTypes);
1157 }
1158 
1159 // Builder that takes a bodyBuilder lambda, result types are inferred from
1160 // the terminator.
build(mlir::OpBuilder & builder,mlir::OperationState & result,ValueRange numThreads,ArrayRef<int64_t> threadDimMapping,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilder)1161 void ForeachThreadOp::build(
1162     mlir::OpBuilder &builder, mlir::OperationState &result,
1163     ValueRange numThreads, ArrayRef<int64_t> threadDimMapping,
1164     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1165   result.addOperands(numThreads);
1166   result.addAttribute(
1167       // TODO: getThreadDimMappingAttrName() but it is not a static member.
1168       "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
1169 
1170   OpBuilder::InsertionGuard g(builder);
1171   Region *bodyRegion = result.addRegion();
1172   builder.createBlock(bodyRegion);
1173   Block &bodyBlock = bodyRegion->front();
1174   bodyBlock.addArguments(
1175       SmallVector<Type>(numThreads.size(), builder.getIndexType()),
1176       SmallVector<Location>(numThreads.size(), result.location));
1177 
1178   OpBuilder::InsertionGuard guard(builder);
1179   builder.setInsertionPointToStart(&bodyBlock);
1180   bodyBuilder(builder, result.location, bodyBlock.getArguments());
1181   auto terminator =
1182       llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
1183   assert(terminator &&
1184          "expected bodyBuilder to create PerformConcurrentlyOp terminator");
1185   result.addTypes(terminator.getYieldedTypes());
1186 }
1187 
1188 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
1189 // unaware of the fact that our terminator also needs a region to be
1190 // well-formed. We override it here to ensure that we do the right thing.
ensureTerminator(Region & region,OpBuilder & builder,Location loc)1191 void ForeachThreadOp::ensureTerminator(Region &region, OpBuilder &builder,
1192                                        Location loc) {
1193   OpTrait::SingleBlockImplicitTerminator<PerformConcurrentlyOp>::Impl<
1194       ForeachThreadOp>::ensureTerminator(region, builder, loc);
1195   auto terminator =
1196       llvm::dyn_cast<PerformConcurrentlyOp>(region.front().getTerminator());
1197   if (terminator.getRegion().empty())
1198     builder.createBlock(&terminator.getRegion());
1199 }
1200 
getTerminator()1201 PerformConcurrentlyOp ForeachThreadOp::getTerminator() {
1202   return cast<PerformConcurrentlyOp>(getBody()->getTerminator());
1203 }
1204 
getForeachThreadOpThreadIndexOwner(Value val)1205 ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
1206   auto tidxArg = val.dyn_cast<BlockArgument>();
1207   if (!tidxArg)
1208     return ForeachThreadOp();
1209   assert(tidxArg.getOwner() && "unlinked block argument");
1210   auto *containingOp = tidxArg.getOwner()->getParentOp();
1211   return dyn_cast<ForeachThreadOp>(containingOp);
1212 }
1213 
1214 //===----------------------------------------------------------------------===//
1215 // PerformConcurrentlyOp
1216 //===----------------------------------------------------------------------===//
1217 
1218 // Build a PerformConcurrentlyOp with mixed static and dynamic entries.
build(OpBuilder & b,OperationState & result)1219 void PerformConcurrentlyOp::build(OpBuilder &b, OperationState &result) {
1220   OpBuilder::InsertionGuard g(b);
1221   Region *bodyRegion = result.addRegion();
1222   b.createBlock(bodyRegion);
1223 }
1224 
verify()1225 LogicalResult PerformConcurrentlyOp::verify() {
1226   // TODO: PerformConcurrentlyOpInterface.
1227   for (const Operation &op : getRegion().front().getOperations()) {
1228     if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1229       return this->emitOpError("expected only ")
1230              << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1231     }
1232   }
1233   return success();
1234 }
1235 
print(OpAsmPrinter & p)1236 void PerformConcurrentlyOp::print(OpAsmPrinter &p) {
1237   p << " ";
1238   p.printRegion(getRegion(),
1239                 /*printEntryBlockArgs=*/false,
1240                 /*printBlockTerminators=*/false);
1241   p.printOptionalAttrDict(getOperation()->getAttrs());
1242 }
1243 
parse(OpAsmParser & parser,OperationState & result)1244 ParseResult PerformConcurrentlyOp::parse(OpAsmParser &parser,
1245                                          OperationState &result) {
1246   auto &builder = parser.getBuilder();
1247 
1248   SmallVector<OpAsmParser::Argument, 8> regionOperands;
1249   std::unique_ptr<Region> region = std::make_unique<Region>();
1250   if (parser.parseRegion(*region, regionOperands))
1251     return failure();
1252 
1253   if (region->empty())
1254     OpBuilder(builder.getContext()).createBlock(region.get());
1255   result.addRegion(std::move(region));
1256 
1257   // Parse the optional attribute list.
1258   if (parser.parseOptionalAttrDict(result.attributes))
1259     return failure();
1260   return success();
1261 }
1262 
getParentResult(int64_t idx)1263 OpResult PerformConcurrentlyOp::getParentResult(int64_t idx) {
1264   return getOperation()->getParentOp()->getResult(idx);
1265 }
1266 
getYieldedTypes()1267 SmallVector<Type> PerformConcurrentlyOp::getYieldedTypes() {
1268   return llvm::to_vector<4>(
1269       llvm::map_range(getYieldingOps(), [](Operation &op) {
1270         auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
1271         return insertSliceOp ? insertSliceOp.yieldedType() : Type();
1272       }));
1273 }
1274 
getYieldingOps()1275 llvm::iterator_range<Block::iterator> PerformConcurrentlyOp::getYieldingOps() {
1276   return getRegion().front().getOperations();
1277 }
1278 
1279 //===----------------------------------------------------------------------===//
1280 // IfOp
1281 //===----------------------------------------------------------------------===//
1282 
insideMutuallyExclusiveBranches(Operation * a,Operation * b)1283 bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
1284   assert(a && "expected non-empty operation");
1285   assert(b && "expected non-empty operation");
1286 
1287   IfOp ifOp = a->getParentOfType<IfOp>();
1288   while (ifOp) {
1289     // Check if b is inside ifOp. (We already know that a is.)
1290     if (ifOp->isProperAncestor(b))
1291       // b is contained in ifOp. a and b are in mutually exclusive branches if
1292       // they are in different blocks of ifOp.
1293       return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1294              static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1295     // Check next enclosing IfOp.
1296     ifOp = ifOp->getParentOfType<IfOp>();
1297   }
1298 
1299   // Could not find a common IfOp among a's and b's ancestors.
1300   return false;
1301 }
1302 
build(OpBuilder & builder,OperationState & result,Value cond,bool withElseRegion)1303 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1304                  bool withElseRegion) {
1305   build(builder, result, /*resultTypes=*/llvm::None, cond, withElseRegion);
1306 }
1307 
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,Value cond,bool withElseRegion)1308 void IfOp::build(OpBuilder &builder, OperationState &result,
1309                  TypeRange resultTypes, Value cond, bool withElseRegion) {
1310   auto addTerminator = [&](OpBuilder &nested, Location loc) {
1311     if (resultTypes.empty())
1312       IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested,
1313                              loc);
1314   };
1315 
1316   build(builder, result, resultTypes, cond, addTerminator,
1317         withElseRegion ? addTerminator
1318                        : function_ref<void(OpBuilder &, Location)>());
1319 }
1320 
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,Value cond,function_ref<void (OpBuilder &,Location)> thenBuilder,function_ref<void (OpBuilder &,Location)> elseBuilder)1321 void IfOp::build(OpBuilder &builder, OperationState &result,
1322                  TypeRange resultTypes, Value cond,
1323                  function_ref<void(OpBuilder &, Location)> thenBuilder,
1324                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1325   assert(thenBuilder && "the builder callback for 'then' must be present");
1326 
1327   result.addOperands(cond);
1328   result.addTypes(resultTypes);
1329 
1330   OpBuilder::InsertionGuard guard(builder);
1331   Region *thenRegion = result.addRegion();
1332   builder.createBlock(thenRegion);
1333   thenBuilder(builder, result.location);
1334 
1335   Region *elseRegion = result.addRegion();
1336   if (!elseBuilder)
1337     return;
1338 
1339   builder.createBlock(elseRegion);
1340   elseBuilder(builder, result.location);
1341 }
1342 
build(OpBuilder & builder,OperationState & result,Value cond,function_ref<void (OpBuilder &,Location)> thenBuilder,function_ref<void (OpBuilder &,Location)> elseBuilder)1343 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1344                  function_ref<void(OpBuilder &, Location)> thenBuilder,
1345                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1346   build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder);
1347 }
1348 
verify()1349 LogicalResult IfOp::verify() {
1350   if (getNumResults() != 0 && getElseRegion().empty())
1351     return emitOpError("must have an else block if defining values");
1352   return success();
1353 }
1354 
parse(OpAsmParser & parser,OperationState & result)1355 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
1356   // Create the regions for 'then'.
1357   result.regions.reserve(2);
1358   Region *thenRegion = result.addRegion();
1359   Region *elseRegion = result.addRegion();
1360 
1361   auto &builder = parser.getBuilder();
1362   OpAsmParser::UnresolvedOperand cond;
1363   Type i1Type = builder.getIntegerType(1);
1364   if (parser.parseOperand(cond) ||
1365       parser.resolveOperand(cond, i1Type, result.operands))
1366     return failure();
1367   // Parse optional results type list.
1368   if (parser.parseOptionalArrowTypeList(result.types))
1369     return failure();
1370   // Parse the 'then' region.
1371   if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1372     return failure();
1373   IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
1374 
1375   // If we find an 'else' keyword then parse the 'else' region.
1376   if (!parser.parseOptionalKeyword("else")) {
1377     if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1378       return failure();
1379     IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
1380   }
1381 
1382   // Parse the optional attribute list.
1383   if (parser.parseOptionalAttrDict(result.attributes))
1384     return failure();
1385   return success();
1386 }
1387 
print(OpAsmPrinter & p)1388 void IfOp::print(OpAsmPrinter &p) {
1389   bool printBlockTerminators = false;
1390 
1391   p << " " << getCondition();
1392   if (!getResults().empty()) {
1393     p << " -> (" << getResultTypes() << ")";
1394     // Print yield explicitly if the op defines values.
1395     printBlockTerminators = true;
1396   }
1397   p << ' ';
1398   p.printRegion(getThenRegion(),
1399                 /*printEntryBlockArgs=*/false,
1400                 /*printBlockTerminators=*/printBlockTerminators);
1401 
1402   // Print the 'else' regions if it exists and has a block.
1403   auto &elseRegion = getElseRegion();
1404   if (!elseRegion.empty()) {
1405     p << " else ";
1406     p.printRegion(elseRegion,
1407                   /*printEntryBlockArgs=*/false,
1408                   /*printBlockTerminators=*/printBlockTerminators);
1409   }
1410 
1411   p.printOptionalAttrDict((*this)->getAttrs());
1412 }
1413 
1414 /// Given the region at `index`, or the parent operation if `index` is None,
1415 /// return the successor regions. These are the regions that may be selected
1416 /// during the flow of control. `operands` is a set of optional attributes that
1417 /// correspond to a constant value for each operand, or null if that operand is
1418 /// not a constant.
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1419 void IfOp::getSuccessorRegions(Optional<unsigned> index,
1420                                ArrayRef<Attribute> operands,
1421                                SmallVectorImpl<RegionSuccessor> &regions) {
1422   // The `then` and the `else` region branch back to the parent operation.
1423   if (index) {
1424     regions.push_back(RegionSuccessor(getResults()));
1425     return;
1426   }
1427 
1428   // Don't consider the else region if it is empty.
1429   Region *elseRegion = &this->getElseRegion();
1430   if (elseRegion->empty())
1431     elseRegion = nullptr;
1432 
1433   // Otherwise, the successor is dependent on the condition.
1434   bool condition;
1435   if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
1436     condition = condAttr.getValue().isOneValue();
1437   } else {
1438     // If the condition isn't constant, both regions may be executed.
1439     regions.push_back(RegionSuccessor(&getThenRegion()));
1440     // If the else region does not exist, it is not a viable successor.
1441     if (elseRegion)
1442       regions.push_back(RegionSuccessor(elseRegion));
1443     return;
1444   }
1445 
1446   // Add the successor regions using the condition.
1447   regions.push_back(RegionSuccessor(condition ? &getThenRegion() : elseRegion));
1448 }
1449 
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1450 LogicalResult IfOp::fold(ArrayRef<Attribute> operands,
1451                          SmallVectorImpl<OpFoldResult> &results) {
1452   // if (!c) then A() else B() -> if c then B() else A()
1453   if (getElseRegion().empty())
1454     return failure();
1455 
1456   arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
1457   if (!xorStmt)
1458     return failure();
1459 
1460   if (!matchPattern(xorStmt.getRhs(), m_One()))
1461     return failure();
1462 
1463   getConditionMutable().assign(xorStmt.getLhs());
1464   Block *thenBlock = &getThenRegion().front();
1465   // It would be nicer to use iplist::swap, but that has no implemented
1466   // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
1467   getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
1468                                      getElseRegion().getBlocks());
1469   getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
1470                                      getThenRegion().getBlocks(), thenBlock);
1471   return success();
1472 }
1473 
getRegionInvocationBounds(ArrayRef<Attribute> operands,SmallVectorImpl<InvocationBounds> & invocationBounds)1474 void IfOp::getRegionInvocationBounds(
1475     ArrayRef<Attribute> operands,
1476     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1477   if (auto cond = operands[0].dyn_cast_or_null<BoolAttr>()) {
1478     // If the condition is known, then one region is known to be executed once
1479     // and the other zero times.
1480     invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
1481     invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
1482   } else {
1483     // Non-constant condition. Each region may be executed 0 or 1 times.
1484     invocationBounds.assign(2, {0, 1});
1485   }
1486 }
1487 
1488 namespace {
1489 // Pattern to remove unused IfOp results.
1490 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
1491   using OpRewritePattern<IfOp>::OpRewritePattern;
1492 
transferBody__anoncfc0330f0b11::RemoveUnusedResults1493   void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
1494                     PatternRewriter &rewriter) const {
1495     // Move all operations to the destination block.
1496     rewriter.mergeBlocks(source, dest);
1497     // Replace the yield op by one that returns only the used values.
1498     auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
1499     SmallVector<Value, 4> usedOperands;
1500     llvm::transform(usedResults, std::back_inserter(usedOperands),
1501                     [&](OpResult result) {
1502                       return yieldOp.getOperand(result.getResultNumber());
1503                     });
1504     rewriter.updateRootInPlace(yieldOp,
1505                                [&]() { yieldOp->setOperands(usedOperands); });
1506   }
1507 
matchAndRewrite__anoncfc0330f0b11::RemoveUnusedResults1508   LogicalResult matchAndRewrite(IfOp op,
1509                                 PatternRewriter &rewriter) const override {
1510     // Compute the list of used results.
1511     SmallVector<OpResult, 4> usedResults;
1512     llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
1513                   [](OpResult result) { return !result.use_empty(); });
1514 
1515     // Replace the operation if only a subset of its results have uses.
1516     if (usedResults.size() == op.getNumResults())
1517       return failure();
1518 
1519     // Compute the result types of the replacement operation.
1520     SmallVector<Type, 4> newTypes;
1521     llvm::transform(usedResults, std::back_inserter(newTypes),
1522                     [](OpResult result) { return result.getType(); });
1523 
1524     // Create a replacement operation with empty then and else regions.
1525     auto emptyBuilder = [](OpBuilder &, Location) {};
1526     auto newOp = rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition(),
1527                                        emptyBuilder, emptyBuilder);
1528 
1529     // Move the bodies and replace the terminators (note there is a then and
1530     // an else region since the operation returns results).
1531     transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
1532     transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
1533 
1534     // Replace the operation by the new one.
1535     SmallVector<Value, 4> repResults(op.getNumResults());
1536     for (const auto &en : llvm::enumerate(usedResults))
1537       repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
1538     rewriter.replaceOp(op, repResults);
1539     return success();
1540   }
1541 };
1542 
1543 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
1544   using OpRewritePattern<IfOp>::OpRewritePattern;
1545 
matchAndRewrite__anoncfc0330f0b11::RemoveStaticCondition1546   LogicalResult matchAndRewrite(IfOp op,
1547                                 PatternRewriter &rewriter) const override {
1548     auto constant = op.getCondition().getDefiningOp<arith::ConstantOp>();
1549     if (!constant)
1550       return failure();
1551 
1552     if (constant.getValue().cast<BoolAttr>().getValue())
1553       replaceOpWithRegion(rewriter, op, op.getThenRegion());
1554     else if (!op.getElseRegion().empty())
1555       replaceOpWithRegion(rewriter, op, op.getElseRegion());
1556     else
1557       rewriter.eraseOp(op);
1558 
1559     return success();
1560   }
1561 };
1562 
1563 /// Hoist any yielded results whose operands are defined outside
1564 /// the if, to a select instruction.
1565 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
1566   using OpRewritePattern<IfOp>::OpRewritePattern;
1567 
matchAndRewrite__anoncfc0330f0b11::ConvertTrivialIfToSelect1568   LogicalResult matchAndRewrite(IfOp op,
1569                                 PatternRewriter &rewriter) const override {
1570     if (op->getNumResults() == 0)
1571       return failure();
1572 
1573     auto cond = op.getCondition();
1574     auto thenYieldArgs = op.thenYield().getOperands();
1575     auto elseYieldArgs = op.elseYield().getOperands();
1576 
1577     SmallVector<Type> nonHoistable;
1578     for (const auto &it :
1579          llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
1580       Value trueVal = std::get<0>(it.value());
1581       Value falseVal = std::get<1>(it.value());
1582       if (&op.getThenRegion() == trueVal.getParentRegion() ||
1583           &op.getElseRegion() == falseVal.getParentRegion())
1584         nonHoistable.push_back(trueVal.getType());
1585     }
1586     // Early exit if there aren't any yielded values we can
1587     // hoist outside the if.
1588     if (nonHoistable.size() == op->getNumResults())
1589       return failure();
1590 
1591     IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond);
1592     if (replacement.thenBlock())
1593       rewriter.eraseBlock(replacement.thenBlock());
1594     replacement.getThenRegion().takeBody(op.getThenRegion());
1595     replacement.getElseRegion().takeBody(op.getElseRegion());
1596 
1597     SmallVector<Value> results(op->getNumResults());
1598     assert(thenYieldArgs.size() == results.size());
1599     assert(elseYieldArgs.size() == results.size());
1600 
1601     SmallVector<Value> trueYields;
1602     SmallVector<Value> falseYields;
1603     rewriter.setInsertionPoint(replacement);
1604     for (const auto &it :
1605          llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
1606       Value trueVal = std::get<0>(it.value());
1607       Value falseVal = std::get<1>(it.value());
1608       if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
1609           &replacement.getElseRegion() == falseVal.getParentRegion()) {
1610         results[it.index()] = replacement.getResult(trueYields.size());
1611         trueYields.push_back(trueVal);
1612         falseYields.push_back(falseVal);
1613       } else if (trueVal == falseVal)
1614         results[it.index()] = trueVal;
1615       else
1616         results[it.index()] = rewriter.create<arith::SelectOp>(
1617             op.getLoc(), cond, trueVal, falseVal);
1618     }
1619 
1620     rewriter.setInsertionPointToEnd(replacement.thenBlock());
1621     rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
1622 
1623     rewriter.setInsertionPointToEnd(replacement.elseBlock());
1624     rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
1625 
1626     rewriter.replaceOp(op, results);
1627     return success();
1628   }
1629 };
1630 
1631 /// Allow the true region of an if to assume the condition is true
1632 /// and vice versa. For example:
1633 ///
1634 ///   scf.if %cmp {
1635 ///      print(%cmp)
1636 ///   }
1637 ///
1638 ///  becomes
1639 ///
1640 ///   scf.if %cmp {
1641 ///      print(true)
1642 ///   }
1643 ///
1644 struct ConditionPropagation : public OpRewritePattern<IfOp> {
1645   using OpRewritePattern<IfOp>::OpRewritePattern;
1646 
matchAndRewrite__anoncfc0330f0b11::ConditionPropagation1647   LogicalResult matchAndRewrite(IfOp op,
1648                                 PatternRewriter &rewriter) const override {
1649     // Early exit if the condition is constant since replacing a constant
1650     // in the body with another constant isn't a simplification.
1651     if (op.getCondition().getDefiningOp<arith::ConstantOp>())
1652       return failure();
1653 
1654     bool changed = false;
1655     mlir::Type i1Ty = rewriter.getI1Type();
1656 
1657     // These variables serve to prevent creating duplicate constants
1658     // and hold constant true or false values.
1659     Value constantTrue = nullptr;
1660     Value constantFalse = nullptr;
1661 
1662     for (OpOperand &use :
1663          llvm::make_early_inc_range(op.getCondition().getUses())) {
1664       if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
1665         changed = true;
1666 
1667         if (!constantTrue)
1668           constantTrue = rewriter.create<arith::ConstantOp>(
1669               op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
1670 
1671         rewriter.updateRootInPlace(use.getOwner(),
1672                                    [&]() { use.set(constantTrue); });
1673       } else if (op.getElseRegion().isAncestor(
1674                      use.getOwner()->getParentRegion())) {
1675         changed = true;
1676 
1677         if (!constantFalse)
1678           constantFalse = rewriter.create<arith::ConstantOp>(
1679               op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
1680 
1681         rewriter.updateRootInPlace(use.getOwner(),
1682                                    [&]() { use.set(constantFalse); });
1683       }
1684     }
1685 
1686     return success(changed);
1687   }
1688 };
1689 
1690 /// Remove any statements from an if that are equivalent to the condition
1691 /// or its negation. For example:
1692 ///
1693 ///    %res:2 = scf.if %cmp {
1694 ///       yield something(), true
1695 ///    } else {
1696 ///       yield something2(), false
1697 ///    }
1698 ///    print(%res#1)
1699 ///
1700 ///  becomes
1701 ///    %res = scf.if %cmp {
1702 ///       yield something()
1703 ///    } else {
1704 ///       yield something2()
1705 ///    }
1706 ///    print(%cmp)
1707 ///
1708 /// Additionally if both branches yield the same value, replace all uses
1709 /// of the result with the yielded value.
1710 ///
1711 ///    %res:2 = scf.if %cmp {
1712 ///       yield something(), %arg1
1713 ///    } else {
1714 ///       yield something2(), %arg1
1715 ///    }
1716 ///    print(%res#1)
1717 ///
1718 ///  becomes
1719 ///    %res = scf.if %cmp {
1720 ///       yield something()
1721 ///    } else {
1722 ///       yield something2()
1723 ///    }
1724 ///    print(%arg1)
1725 ///
1726 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
1727   using OpRewritePattern<IfOp>::OpRewritePattern;
1728 
matchAndRewrite__anoncfc0330f0b11::ReplaceIfYieldWithConditionOrValue1729   LogicalResult matchAndRewrite(IfOp op,
1730                                 PatternRewriter &rewriter) const override {
1731     // Early exit if there are no results that could be replaced.
1732     if (op.getNumResults() == 0)
1733       return failure();
1734 
1735     auto trueYield =
1736         cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
1737     auto falseYield =
1738         cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
1739 
1740     rewriter.setInsertionPoint(op->getBlock(),
1741                                op.getOperation()->getIterator());
1742     bool changed = false;
1743     Type i1Ty = rewriter.getI1Type();
1744     for (auto tup : llvm::zip(trueYield.getResults(), falseYield.getResults(),
1745                               op.getResults())) {
1746       Value trueResult, falseResult, opResult;
1747       std::tie(trueResult, falseResult, opResult) = tup;
1748 
1749       if (trueResult == falseResult) {
1750         if (!opResult.use_empty()) {
1751           opResult.replaceAllUsesWith(trueResult);
1752           changed = true;
1753         }
1754         continue;
1755       }
1756 
1757       auto trueYield = trueResult.getDefiningOp<arith::ConstantOp>();
1758       if (!trueYield)
1759         continue;
1760 
1761       if (!trueYield.getType().isInteger(1))
1762         continue;
1763 
1764       auto falseYield = falseResult.getDefiningOp<arith::ConstantOp>();
1765       if (!falseYield)
1766         continue;
1767 
1768       bool trueVal = trueYield.getValue().cast<BoolAttr>().getValue();
1769       bool falseVal = falseYield.getValue().cast<BoolAttr>().getValue();
1770       if (!trueVal && falseVal) {
1771         if (!opResult.use_empty()) {
1772           Value notCond = rewriter.create<arith::XOrIOp>(
1773               op.getLoc(), op.getCondition(),
1774               rewriter.create<arith::ConstantOp>(
1775                   op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)));
1776           opResult.replaceAllUsesWith(notCond);
1777           changed = true;
1778         }
1779       }
1780       if (trueVal && !falseVal) {
1781         if (!opResult.use_empty()) {
1782           opResult.replaceAllUsesWith(op.getCondition());
1783           changed = true;
1784         }
1785       }
1786     }
1787     return success(changed);
1788   }
1789 };
1790 
/// Merge any consecutive scf.if's with the same condition.
///
///    scf.if %cond {
///       firstCodeTrue();...
///    } else {
///       firstCodeFalse();...
///    }
///    %res = scf.if %cond {
///       secondCodeTrue();...
///    } else {
///       secondCodeFalse();...
///    }
///
///  becomes
///    %res = scf.if %cond {
///       firstCodeTrue();...
///       secondCodeTrue();...
///    } else {
///       firstCodeFalse();...
///       secondCodeFalse();...
///    }
///
/// The second if may also branch on the negation of the first condition, or
/// vice versa (matched as `arith.xori %c, 1`). In that case its 'then' and
/// 'else' bodies are merged crosswise. The merged op yields the first op's
/// results followed by the second op's results.
struct CombineIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp nextIf,
                                PatternRewriter &rewriter) const override {
    // The pattern is anchored on the second if, which must be immediately
    // preceded by another scf.if in the same block.
    Block *parent = nextIf->getBlock();
    if (nextIf == &parent->front())
      return failure();

    auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
    if (!prevIf)
      return failure();

    // Determine the logical then/else blocks when prevIf's
    // condition is used. Null means the block does not exist
    // in that case (e.g. empty else). If neither of these
    // are set, the two conditions cannot be compared.
    Block *nextThen = nullptr;
    Block *nextElse = nullptr;
    if (nextIf.getCondition() == prevIf.getCondition()) {
      nextThen = nextIf.thenBlock();
      if (!nextIf.getElseRegion().empty())
        nextElse = nextIf.elseBlock();
    }
    // nextIf branches on `xori(prevCond, 1)`: its 'then' body runs when
    // prevIf's condition is false, so the roles of its regions are swapped.
    if (arith::XOrIOp notv =
            nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == prevIf.getCondition() &&
          matchPattern(notv.getRhs(), m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }
    // prevIf branches on `xori(nextCond, 1)`: the same swap applies.
    if (arith::XOrIOp notv =
            prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == nextIf.getCondition() &&
          matchPattern(notv.getRhs(), m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }

    if (!nextThen && !nextElse)
      return failure();

    SmallVector<Value> prevElseYielded;
    if (!prevIf.getElseRegion().empty())
      prevElseYielded = prevIf.elseYield().getOperands();
    // Replace all uses of return values of op within nextIf with the
    // corresponding yields: uses in the logical 'then' region get the value
    // prevIf yields from 'then', uses in the logical 'else' region the value
    // from 'else'. `zip` stops at its shortest input, so nothing is rewritten
    // when prevIf has no 'else' region (such an op has no results anyway).
    for (auto it : llvm::zip(prevIf.getResults(),
                             prevIf.thenYield().getOperands(), prevElseYielded))
      for (OpOperand &use :
           llvm::make_early_inc_range(std::get<0>(it).getUses())) {
        if (nextThen && nextThen->getParent()->isAncestor(
                            use.getOwner()->getParentRegion())) {
          rewriter.startRootUpdate(use.getOwner());
          use.set(std::get<1>(it));
          rewriter.finalizeRootUpdate(use.getOwner());
        } else if (nextElse && nextElse->getParent()->isAncestor(
                                   use.getOwner()->getParentRegion())) {
          rewriter.startRootUpdate(use.getOwner());
          use.set(std::get<2>(it));
          rewriter.finalizeRootUpdate(use.getOwner());
        }
      }

    // The combined op returns prevIf's results followed by nextIf's.
    SmallVector<Type> mergedTypes(prevIf.getResultTypes());
    llvm::append_range(mergedTypes, nextIf.getResultTypes());

    IfOp combinedIf = rewriter.create<IfOp>(
        nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
    // Drop the builder-created 'then' block; prevIf's 'then' body is moved
    // in instead.
    rewriter.eraseBlock(&combinedIf.getThenRegion().back());

    rewriter.inlineRegionBefore(prevIf.getThenRegion(),
                                combinedIf.getThenRegion(),
                                combinedIf.getThenRegion().begin());

    // Append nextIf's logical 'then' body and concatenate the two yields.
    if (nextThen) {
      YieldOp thenYield = combinedIf.thenYield();
      YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
      rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
      rewriter.setInsertionPointToEnd(combinedIf.thenBlock());

      SmallVector<Value> mergedYields(thenYield.getOperands());
      llvm::append_range(mergedYields, thenYield2.getOperands());
      rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
      rewriter.eraseOp(thenYield);
      rewriter.eraseOp(thenYield2);
    }

    rewriter.inlineRegionBefore(prevIf.getElseRegion(),
                                combinedIf.getElseRegion(),
                                combinedIf.getElseRegion().begin());

    // Merge nextIf's logical 'else' body. If prevIf had no 'else', the whole
    // region is moved over as-is; otherwise the bodies are concatenated and
    // the yields merged, as for 'then'.
    if (nextElse) {
      if (combinedIf.getElseRegion().empty()) {
        rewriter.inlineRegionBefore(*nextElse->getParent(),
                                    combinedIf.getElseRegion(),
                                    combinedIf.getElseRegion().begin());
      } else {
        YieldOp elseYield = combinedIf.elseYield();
        YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
        rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());

        rewriter.setInsertionPointToEnd(combinedIf.elseBlock());

        SmallVector<Value> mergedElseYields(elseYield.getOperands());
        llvm::append_range(mergedElseYields, elseYield2.getOperands());

        rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
        rewriter.eraseOp(elseYield);
        rewriter.eraseOp(elseYield2);
      }
    }

    // Split the combined results back into the prevIf and nextIf parts.
    SmallVector<Value> prevValues;
    SmallVector<Value> nextValues;
    for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
      if (pair.index() < prevIf.getNumResults())
        prevValues.push_back(pair.value());
      else
        nextValues.push_back(pair.value());
    }
    rewriter.replaceOp(prevIf, prevValues);
    rewriter.replaceOp(nextIf, nextValues);
    return success();
  }
};
1942 
1943 /// Pattern to remove an empty else branch.
1944 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
1945   using OpRewritePattern<IfOp>::OpRewritePattern;
1946 
matchAndRewrite__anoncfc0330f0b11::RemoveEmptyElseBranch1947   LogicalResult matchAndRewrite(IfOp ifOp,
1948                                 PatternRewriter &rewriter) const override {
1949     // Cannot remove else region when there are operation results.
1950     if (ifOp.getNumResults())
1951       return failure();
1952     Block *elseBlock = ifOp.elseBlock();
1953     if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
1954       return failure();
1955     auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
1956     rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
1957                                 newIfOp.getThenRegion().begin());
1958     rewriter.eraseOp(ifOp);
1959     return success();
1960   }
1961 };
1962 
1963 /// Convert nested `if`s into `arith.andi` + single `if`.
1964 ///
1965 ///    scf.if %arg0 {
1966 ///      scf.if %arg1 {
1967 ///        ...
1968 ///        scf.yield
1969 ///      }
1970 ///      scf.yield
1971 ///    }
1972 ///  becomes
1973 ///
1974 ///    %0 = arith.andi %arg0, %arg1
1975 ///    scf.if %0 {
1976 ///      ...
1977 ///      scf.yield
1978 ///    }
1979 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
1980   using OpRewritePattern<IfOp>::OpRewritePattern;
1981 
matchAndRewrite__anoncfc0330f0b11::CombineNestedIfs1982   LogicalResult matchAndRewrite(IfOp op,
1983                                 PatternRewriter &rewriter) const override {
1984     auto nestedOps = op.thenBlock()->without_terminator();
1985     // Nested `if` must be the only op in block.
1986     if (!llvm::hasSingleElement(nestedOps))
1987       return failure();
1988 
1989     // If there is an else block, it can only yield
1990     if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
1991       return failure();
1992 
1993     auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
1994     if (!nestedIf)
1995       return failure();
1996 
1997     if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
1998       return failure();
1999 
2000     SmallVector<Value> thenYield(op.thenYield().getOperands());
2001     SmallVector<Value> elseYield;
2002     if (op.elseBlock())
2003       llvm::append_range(elseYield, op.elseYield().getOperands());
2004 
2005     // A list of indices for which we should upgrade the value yielded
2006     // in the else to a select.
2007     SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2008 
2009     // If the outer scf.if yields a value produced by the inner scf.if,
2010     // only permit combining if the value yielded when the condition
2011     // is false in the outer scf.if is the same value yielded when the
2012     // inner scf.if condition is false.
2013     // Note that the array access to elseYield will not go out of bounds
2014     // since it must have the same length as thenYield, since they both
2015     // come from the same scf.if.
2016     for (const auto &tup : llvm::enumerate(thenYield)) {
2017       if (tup.value().getDefiningOp() == nestedIf) {
2018         auto nestedIdx = tup.value().cast<OpResult>().getResultNumber();
2019         if (nestedIf.elseYield().getOperand(nestedIdx) !=
2020             elseYield[tup.index()]) {
2021           return failure();
2022         }
2023         // If the correctness test passes, we will yield
2024         // corresponding value from the inner scf.if
2025         thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2026         continue;
2027       }
2028 
2029       // Otherwise, we need to ensure the else block of the combined
2030       // condition still returns the same value when the outer condition is
2031       // true and the inner condition is false. This can be accomplished if
2032       // the then value is defined outside the outer scf.if and we replace the
2033       // value with a select that considers just the outer condition. Since
2034       // the else region contains just the yield, its yielded value is
2035       // defined outside the scf.if, by definition.
2036 
2037       // If the then value is defined within the scf.if, bail.
2038       if (tup.value().getParentRegion() == &op.getThenRegion()) {
2039         return failure();
2040       }
2041       elseYieldsToUpgradeToSelect.push_back(tup.index());
2042     }
2043 
2044     Location loc = op.getLoc();
2045     Value newCondition = rewriter.create<arith::AndIOp>(
2046         loc, op.getCondition(), nestedIf.getCondition());
2047     auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2048 
2049     SmallVector<Value> results;
2050     llvm::append_range(results, newIf.getResults());
2051     rewriter.setInsertionPoint(newIf);
2052 
2053     for (auto idx : elseYieldsToUpgradeToSelect)
2054       results[idx] = rewriter.create<arith::SelectOp>(
2055           op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2056 
2057     Block *newIfBlock = newIf.thenBlock();
2058     if (newIfBlock)
2059       rewriter.eraseOp(newIfBlock->getTerminator());
2060     else
2061       newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2062     rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2063     rewriter.setInsertionPointToEnd(newIf.thenBlock());
2064     rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2065     if (!elseYield.empty()) {
2066       rewriter.createBlock(&newIf.getElseRegion());
2067       rewriter.setInsertionPointToEnd(newIf.elseBlock());
2068       rewriter.create<YieldOp>(loc, elseYield);
2069     }
2070     rewriter.replaceOp(op, results);
2071     return success();
2072   }
2073 };
2074 
2075 } // namespace
2076 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2077 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2078                                        MLIRContext *context) {
2079   results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2080               ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2081               RemoveStaticCondition, RemoveUnusedResults,
2082               ReplaceIfYieldWithConditionOrValue>(context);
2083 }
2084 
thenBlock()2085 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
thenYield()2086 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
elseBlock()2087 Block *IfOp::elseBlock() {
2088   Region &r = getElseRegion();
2089   if (r.empty())
2090     return nullptr;
2091   return &r.back();
2092 }
elseYield()2093 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2094 
2095 //===----------------------------------------------------------------------===//
2096 // ParallelOp
2097 //===----------------------------------------------------------------------===//
2098 
build(OpBuilder & builder,OperationState & result,ValueRange lowerBounds,ValueRange upperBounds,ValueRange steps,ValueRange initVals,function_ref<void (OpBuilder &,Location,ValueRange,ValueRange)> bodyBuilderFn)2099 void ParallelOp::build(
2100     OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2101     ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2102     function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2103         bodyBuilderFn) {
2104   result.addOperands(lowerBounds);
2105   result.addOperands(upperBounds);
2106   result.addOperands(steps);
2107   result.addOperands(initVals);
2108   result.addAttribute(
2109       ParallelOp::getOperandSegmentSizeAttr(),
2110       builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()),
2111                                 static_cast<int32_t>(upperBounds.size()),
2112                                 static_cast<int32_t>(steps.size()),
2113                                 static_cast<int32_t>(initVals.size())}));
2114   result.addTypes(initVals.getTypes());
2115 
2116   OpBuilder::InsertionGuard guard(builder);
2117   unsigned numIVs = steps.size();
2118   SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2119   SmallVector<Location, 8> argLocs(numIVs, result.location);
2120   Region *bodyRegion = result.addRegion();
2121   Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2122 
2123   if (bodyBuilderFn) {
2124     builder.setInsertionPointToStart(bodyBlock);
2125     bodyBuilderFn(builder, result.location,
2126                   bodyBlock->getArguments().take_front(numIVs),
2127                   bodyBlock->getArguments().drop_front(numIVs));
2128   }
2129   ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2130 }
2131 
build(OpBuilder & builder,OperationState & result,ValueRange lowerBounds,ValueRange upperBounds,ValueRange steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilderFn)2132 void ParallelOp::build(
2133     OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2134     ValueRange upperBounds, ValueRange steps,
2135     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2136   // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2137   // we don't capture a reference to a temporary by constructing the lambda at
2138   // function level.
2139   auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2140                                            Location nestedLoc, ValueRange ivs,
2141                                            ValueRange) {
2142     bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2143   };
2144   function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2145   if (bodyBuilderFn)
2146     wrapper = wrappedBuilderFn;
2147 
2148   build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2149         wrapper);
2150 }
2151 
verify()2152 LogicalResult ParallelOp::verify() {
2153   // Check that there is at least one value in lowerBound, upperBound and step.
2154   // It is sufficient to test only step, because it is ensured already that the
2155   // number of elements in lowerBound, upperBound and step are the same.
2156   Operation::operand_range stepValues = getStep();
2157   if (stepValues.empty())
2158     return emitOpError(
2159         "needs at least one tuple element for lowerBound, upperBound and step");
2160 
2161   // Check whether all constant step values are positive.
2162   for (Value stepValue : stepValues)
2163     if (auto cst = stepValue.getDefiningOp<arith::ConstantIndexOp>())
2164       if (cst.value() <= 0)
2165         return emitOpError("constant step operand must be positive");
2166 
2167   // Check that the body defines the same number of block arguments as the
2168   // number of tuple elements in step.
2169   Block *body = getBody();
2170   if (body->getNumArguments() != stepValues.size())
2171     return emitOpError() << "expects the same number of induction variables: "
2172                          << body->getNumArguments()
2173                          << " as bound and step values: " << stepValues.size();
2174   for (auto arg : body->getArguments())
2175     if (!arg.getType().isIndex())
2176       return emitOpError(
2177           "expects arguments for the induction variable to be of index type");
2178 
2179   // Check that the yield has no results
2180   Operation *yield = body->getTerminator();
2181   if (yield->getNumOperands() != 0)
2182     return yield->emitOpError() << "not allowed to have operands inside '"
2183                                 << ParallelOp::getOperationName() << "'";
2184 
2185   // Check that the number of results is the same as the number of ReduceOps.
2186   SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
2187   auto resultsSize = getResults().size();
2188   auto reductionsSize = reductions.size();
2189   auto initValsSize = getInitVals().size();
2190   if (resultsSize != reductionsSize)
2191     return emitOpError() << "expects number of results: " << resultsSize
2192                          << " to be the same as number of reductions: "
2193                          << reductionsSize;
2194   if (resultsSize != initValsSize)
2195     return emitOpError() << "expects number of results: " << resultsSize
2196                          << " to be the same as number of initial values: "
2197                          << initValsSize;
2198 
2199   // Check that the types of the results and reductions are the same.
2200   for (auto resultAndReduce : llvm::zip(getResults(), reductions)) {
2201     auto resultType = std::get<0>(resultAndReduce).getType();
2202     auto reduceOp = std::get<1>(resultAndReduce);
2203     auto reduceType = reduceOp.getOperand().getType();
2204     if (resultType != reduceType)
2205       return reduceOp.emitOpError()
2206              << "expects type of reduce: " << reduceType
2207              << " to be the same as result type: " << resultType;
2208   }
2209   return success();
2210 }
2211 
parse(OpAsmParser & parser,OperationState & result)2212 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2213   auto &builder = parser.getBuilder();
2214   // Parse an opening `(` followed by induction variables followed by `)`
2215   SmallVector<OpAsmParser::Argument, 4> ivs;
2216   if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
2217     return failure();
2218 
2219   // Parse loop bounds.
2220   SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2221   if (parser.parseEqual() ||
2222       parser.parseOperandList(lower, ivs.size(),
2223                               OpAsmParser::Delimiter::Paren) ||
2224       parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2225     return failure();
2226 
2227   SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2228   if (parser.parseKeyword("to") ||
2229       parser.parseOperandList(upper, ivs.size(),
2230                               OpAsmParser::Delimiter::Paren) ||
2231       parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2232     return failure();
2233 
2234   // Parse step values.
2235   SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2236   if (parser.parseKeyword("step") ||
2237       parser.parseOperandList(steps, ivs.size(),
2238                               OpAsmParser::Delimiter::Paren) ||
2239       parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2240     return failure();
2241 
2242   // Parse init values.
2243   SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2244   if (succeeded(parser.parseOptionalKeyword("init"))) {
2245     if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2246       return failure();
2247   }
2248 
2249   // Parse optional results in case there is a reduce.
2250   if (parser.parseOptionalArrowTypeList(result.types))
2251     return failure();
2252 
2253   // Now parse the body.
2254   Region *body = result.addRegion();
2255   for (auto &iv : ivs)
2256     iv.type = builder.getIndexType();
2257   if (parser.parseRegion(*body, ivs))
2258     return failure();
2259 
2260   // Set `operand_segment_sizes` attribute.
2261   result.addAttribute(
2262       ParallelOp::getOperandSegmentSizeAttr(),
2263       builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
2264                                 static_cast<int32_t>(upper.size()),
2265                                 static_cast<int32_t>(steps.size()),
2266                                 static_cast<int32_t>(initVals.size())}));
2267 
2268   // Parse attributes.
2269   if (parser.parseOptionalAttrDict(result.attributes) ||
2270       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2271                              result.operands))
2272     return failure();
2273 
2274   // Add a terminator if none was parsed.
2275   ForOp::ensureTerminator(*body, builder, result.location);
2276   return success();
2277 }
2278 
print(OpAsmPrinter & p)2279 void ParallelOp::print(OpAsmPrinter &p) {
2280   p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
2281     << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
2282   if (!getInitVals().empty())
2283     p << " init (" << getInitVals() << ")";
2284   p.printOptionalArrowTypeList(getResultTypes());
2285   p << ' ';
2286   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
2287   p.printOptionalAttrDict(
2288       (*this)->getAttrs(),
2289       /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
2290 }
2291 
getLoopBody()2292 Region &ParallelOp::getLoopBody() { return getRegion(); }
2293 
getParallelForInductionVarOwner(Value val)2294 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
2295   auto ivArg = val.dyn_cast<BlockArgument>();
2296   if (!ivArg)
2297     return ParallelOp();
2298   assert(ivArg.getOwner() && "unlinked block argument");
2299   auto *containingOp = ivArg.getOwner()->getParentOp();
2300   return dyn_cast<ParallelOp>(containingOp);
2301 }
2302 
2303 namespace {
2304 // Collapse loop dimensions that perform a single iteration.
2305 struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
2306   using OpRewritePattern<ParallelOp>::OpRewritePattern;
2307 
matchAndRewrite__anoncfc0330f1411::CollapseSingleIterationLoops2308   LogicalResult matchAndRewrite(ParallelOp op,
2309                                 PatternRewriter &rewriter) const override {
2310     BlockAndValueMapping mapping;
2311     // Compute new loop bounds that omit all single-iteration loop dimensions.
2312     SmallVector<Value, 2> newLowerBounds;
2313     SmallVector<Value, 2> newUpperBounds;
2314     SmallVector<Value, 2> newSteps;
2315     newLowerBounds.reserve(op.getLowerBound().size());
2316     newUpperBounds.reserve(op.getUpperBound().size());
2317     newSteps.reserve(op.getStep().size());
2318     for (auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound(),
2319                               op.getStep(), op.getInductionVars())) {
2320       Value lowerBound, upperBound, step, iv;
2321       std::tie(lowerBound, upperBound, step, iv) = dim;
2322       // Collect the statically known loop bounds.
2323       auto lowerBoundConstant =
2324           dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
2325       auto upperBoundConstant =
2326           dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
2327       auto stepConstant =
2328           dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
2329       // Replace the loop induction variable by the lower bound if the loop
2330       // performs a single iteration. Otherwise, copy the loop bounds.
2331       if (lowerBoundConstant && upperBoundConstant && stepConstant &&
2332           (upperBoundConstant.value() - lowerBoundConstant.value()) > 0 &&
2333           (upperBoundConstant.value() - lowerBoundConstant.value()) <=
2334               stepConstant.value()) {
2335         mapping.map(iv, lowerBound);
2336       } else {
2337         newLowerBounds.push_back(lowerBound);
2338         newUpperBounds.push_back(upperBound);
2339         newSteps.push_back(step);
2340       }
2341     }
2342     // Exit if none of the loop dimensions perform a single iteration.
2343     if (newLowerBounds.size() == op.getLowerBound().size())
2344       return failure();
2345 
2346     if (newLowerBounds.empty()) {
2347       // All of the loop dimensions perform a single iteration. Inline
2348       // loop body and nested ReduceOp's
2349       SmallVector<Value> results;
2350       results.reserve(op.getInitVals().size());
2351       for (auto &bodyOp : op.getLoopBody().front().without_terminator()) {
2352         auto reduce = dyn_cast<ReduceOp>(bodyOp);
2353         if (!reduce) {
2354           rewriter.clone(bodyOp, mapping);
2355           continue;
2356         }
2357         Block &reduceBlock = reduce.getReductionOperator().front();
2358         auto initValIndex = results.size();
2359         mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
2360         mapping.map(reduceBlock.getArgument(1),
2361                     mapping.lookupOrDefault(reduce.getOperand()));
2362         for (auto &reduceBodyOp : reduceBlock.without_terminator())
2363           rewriter.clone(reduceBodyOp, mapping);
2364 
2365         auto result = mapping.lookupOrDefault(
2366             cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
2367         results.push_back(result);
2368       }
2369       rewriter.replaceOp(op, results);
2370       return success();
2371     }
2372     // Replace the parallel loop by lower-dimensional parallel loop.
2373     auto newOp =
2374         rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
2375                                     newSteps, op.getInitVals(), nullptr);
2376     // Clone the loop body and remap the block arguments of the collapsed loops
2377     // (inlining does not support a cancellable block argument mapping).
2378     rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
2379                                newOp.getRegion().begin(), mapping);
2380     rewriter.replaceOp(op, newOp.getResults());
2381     return success();
2382   }
2383 };
2384 
2385 /// Removes parallel loops in which at least one lower/upper bound pair consists
2386 /// of the same values - such loops have an empty iteration domain.
2387 struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
2388   using OpRewritePattern<ParallelOp>::OpRewritePattern;
2389 
matchAndRewrite__anoncfc0330f1411::RemoveEmptyParallelLoops2390   LogicalResult matchAndRewrite(ParallelOp op,
2391                                 PatternRewriter &rewriter) const override {
2392     for (auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound())) {
2393       if (std::get<0>(dim) == std::get<1>(dim)) {
2394         rewriter.replaceOp(op, op.getInitVals());
2395         return success();
2396       }
2397     }
2398     return failure();
2399   }
2400 };
2401 
2402 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
2403   using OpRewritePattern<ParallelOp>::OpRewritePattern;
2404 
matchAndRewrite__anoncfc0330f1411::MergeNestedParallelLoops2405   LogicalResult matchAndRewrite(ParallelOp op,
2406                                 PatternRewriter &rewriter) const override {
2407     Block &outerBody = op.getLoopBody().front();
2408     if (!llvm::hasSingleElement(outerBody.without_terminator()))
2409       return failure();
2410 
2411     auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
2412     if (!innerOp)
2413       return failure();
2414 
2415     for (auto val : outerBody.getArguments())
2416       if (llvm::is_contained(innerOp.getLowerBound(), val) ||
2417           llvm::is_contained(innerOp.getUpperBound(), val) ||
2418           llvm::is_contained(innerOp.getStep(), val))
2419         return failure();
2420 
2421     // Reductions are not supported yet.
2422     if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
2423       return failure();
2424 
2425     auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
2426                            ValueRange iterVals, ValueRange) {
2427       Block &innerBody = innerOp.getLoopBody().front();
2428       assert(iterVals.size() ==
2429              (outerBody.getNumArguments() + innerBody.getNumArguments()));
2430       BlockAndValueMapping mapping;
2431       mapping.map(outerBody.getArguments(),
2432                   iterVals.take_front(outerBody.getNumArguments()));
2433       mapping.map(innerBody.getArguments(),
2434                   iterVals.take_back(innerBody.getNumArguments()));
2435       for (Operation &op : innerBody.without_terminator())
2436         builder.clone(op, mapping);
2437     };
2438 
2439     auto concatValues = [](const auto &first, const auto &second) {
2440       SmallVector<Value> ret;
2441       ret.reserve(first.size() + second.size());
2442       ret.assign(first.begin(), first.end());
2443       ret.append(second.begin(), second.end());
2444       return ret;
2445     };
2446 
2447     auto newLowerBounds =
2448         concatValues(op.getLowerBound(), innerOp.getLowerBound());
2449     auto newUpperBounds =
2450         concatValues(op.getUpperBound(), innerOp.getUpperBound());
2451     auto newSteps = concatValues(op.getStep(), innerOp.getStep());
2452 
2453     rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
2454                                             newSteps, llvm::None, bodyBuilder);
2455     return success();
2456   }
2457 };
2458 
2459 } // namespace
2460 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2461 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
2462                                              MLIRContext *context) {
2463   results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops,
2464               MergeNestedParallelLoops>(context);
2465 }
2466 
2467 //===----------------------------------------------------------------------===//
2468 // ReduceOp
2469 //===----------------------------------------------------------------------===//
2470 
build(OpBuilder & builder,OperationState & result,Value operand,function_ref<void (OpBuilder &,Location,Value,Value)> bodyBuilderFn)2471 void ReduceOp::build(
2472     OpBuilder &builder, OperationState &result, Value operand,
2473     function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuilderFn) {
2474   auto type = operand.getType();
2475   result.addOperands(operand);
2476 
2477   OpBuilder::InsertionGuard guard(builder);
2478   Region *bodyRegion = result.addRegion();
2479   Block *body = builder.createBlock(bodyRegion, {}, ArrayRef<Type>{type, type},
2480                                     {result.location, result.location});
2481   if (bodyBuilderFn)
2482     bodyBuilderFn(builder, result.location, body->getArgument(0),
2483                   body->getArgument(1));
2484 }
2485 
verifyRegions()2486 LogicalResult ReduceOp::verifyRegions() {
2487   // The region of a ReduceOp has two arguments of the same type as its operand.
2488   auto type = getOperand().getType();
2489   Block &block = getReductionOperator().front();
2490   if (block.empty())
2491     return emitOpError("the block inside reduce should not be empty");
2492   if (block.getNumArguments() != 2 ||
2493       llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
2494         return arg.getType() != type;
2495       }))
2496     return emitOpError() << "expects two arguments to reduce block of type "
2497                          << type;
2498 
2499   // Check that the block is terminated by a ReduceReturnOp.
2500   if (!isa<ReduceReturnOp>(block.getTerminator()))
2501     return emitOpError("the block inside reduce should be terminated with a "
2502                        "'scf.reduce.return' op");
2503 
2504   return success();
2505 }
2506 
parse(OpAsmParser & parser,OperationState & result)2507 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
2508   // Parse an opening `(` followed by the reduced value followed by `)`
2509   OpAsmParser::UnresolvedOperand operand;
2510   if (parser.parseLParen() || parser.parseOperand(operand) ||
2511       parser.parseRParen())
2512     return failure();
2513 
2514   Type resultType;
2515   // Parse the type of the operand (and also what reduce computes on).
2516   if (parser.parseColonType(resultType) ||
2517       parser.resolveOperand(operand, resultType, result.operands))
2518     return failure();
2519 
2520   // Now parse the body.
2521   Region *body = result.addRegion();
2522   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
2523     return failure();
2524 
2525   return success();
2526 }
2527 
print(OpAsmPrinter & p)2528 void ReduceOp::print(OpAsmPrinter &p) {
2529   p << "(" << getOperand() << ") ";
2530   p << " : " << getOperand().getType() << ' ';
2531   p.printRegion(getReductionOperator());
2532 }
2533 
2534 //===----------------------------------------------------------------------===//
2535 // ReduceReturnOp
2536 //===----------------------------------------------------------------------===//
2537 
verify()2538 LogicalResult ReduceReturnOp::verify() {
2539   // The type of the return value should be the same type as the type of the
2540   // operand of the enclosing ReduceOp.
2541   auto reduceOp = cast<ReduceOp>((*this)->getParentOp());
2542   Type reduceType = reduceOp.getOperand().getType();
2543   if (reduceType != getResult().getType())
2544     return emitOpError() << "needs to have type " << reduceType
2545                          << " (the type of the enclosing ReduceOp)";
2546   return success();
2547 }
2548 
2549 //===----------------------------------------------------------------------===//
2550 // WhileOp
2551 //===----------------------------------------------------------------------===//
2552 
getSuccessorEntryOperands(Optional<unsigned> index)2553 OperandRange WhileOp::getSuccessorEntryOperands(Optional<unsigned> index) {
2554   assert(index && *index == 0 &&
2555          "WhileOp is expected to branch only to the first region");
2556 
2557   return getInits();
2558 }
2559 
getConditionOp()2560 ConditionOp WhileOp::getConditionOp() {
2561   return cast<ConditionOp>(getBefore().front().getTerminator());
2562 }
2563 
getYieldOp()2564 YieldOp WhileOp::getYieldOp() {
2565   return cast<YieldOp>(getAfter().front().getTerminator());
2566 }
2567 
getBeforeArguments()2568 Block::BlockArgListType WhileOp::getBeforeArguments() {
2569   return getBefore().front().getArguments();
2570 }
2571 
getAfterArguments()2572 Block::BlockArgListType WhileOp::getAfterArguments() {
2573   return getAfter().front().getArguments();
2574 }
2575 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)2576 void WhileOp::getSuccessorRegions(Optional<unsigned> index,
2577                                   ArrayRef<Attribute> operands,
2578                                   SmallVectorImpl<RegionSuccessor> &regions) {
2579   // The parent op always branches to the condition region.
2580   if (!index) {
2581     regions.emplace_back(&getBefore(), getBefore().getArguments());
2582     return;
2583   }
2584 
2585   assert(*index < 2 && "there are only two regions in a WhileOp");
2586   // The body region always branches back to the condition region.
2587   if (*index == 1) {
2588     regions.emplace_back(&getBefore(), getBefore().getArguments());
2589     return;
2590   }
2591 
2592   // Try to narrow the successor to the condition region.
2593   assert(!operands.empty() && "expected at least one operand");
2594   auto cond = operands[0].dyn_cast_or_null<BoolAttr>();
2595   if (!cond || !cond.getValue())
2596     regions.emplace_back(getResults());
2597   if (!cond || cond.getValue())
2598     regions.emplace_back(&getAfter(), getAfter().getArguments());
2599 }
2600 
2601 /// Parses a `while` op.
2602 ///
2603 /// op ::= `scf.while` assignments `:` function-type region `do` region
2604 ///         `attributes` attribute-dict
2605 /// initializer ::= /* empty */ | `(` assignment-list `)`
2606 /// assignment-list ::= assignment | assignment `,` assignment-list
2607 /// assignment ::= ssa-value `=` ssa-value
parse(OpAsmParser & parser,OperationState & result)2608 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
2609   SmallVector<OpAsmParser::Argument, 4> regionArgs;
2610   SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2611   Region *before = result.addRegion();
2612   Region *after = result.addRegion();
2613 
2614   OptionalParseResult listResult =
2615       parser.parseOptionalAssignmentList(regionArgs, operands);
2616   if (listResult.hasValue() && failed(listResult.getValue()))
2617     return failure();
2618 
2619   FunctionType functionType;
2620   SMLoc typeLoc = parser.getCurrentLocation();
2621   if (failed(parser.parseColonType(functionType)))
2622     return failure();
2623 
2624   result.addTypes(functionType.getResults());
2625 
2626   if (functionType.getNumInputs() != operands.size()) {
2627     return parser.emitError(typeLoc)
2628            << "expected as many input types as operands "
2629            << "(expected " << operands.size() << " got "
2630            << functionType.getNumInputs() << ")";
2631   }
2632 
2633   // Resolve input operands.
2634   if (failed(parser.resolveOperands(operands, functionType.getInputs(),
2635                                     parser.getCurrentLocation(),
2636                                     result.operands)))
2637     return failure();
2638 
2639   // Propagate the types into the region arguments.
2640   for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
2641     regionArgs[i].type = functionType.getInput(i);
2642 
2643   return failure(parser.parseRegion(*before, regionArgs) ||
2644                  parser.parseKeyword("do") || parser.parseRegion(*after) ||
2645                  parser.parseOptionalAttrDictWithKeyword(result.attributes));
2646 }
2647 
2648 /// Prints a `while` op.
print(OpAsmPrinter & p)2649 void scf::WhileOp::print(OpAsmPrinter &p) {
2650   printInitializationList(p, getBefore().front().getArguments(), getInits(),
2651                           " ");
2652   p << " : ";
2653   p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
2654   p << ' ';
2655   p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
2656   p << " do ";
2657   p.printRegion(getAfter());
2658   p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
2659 }
2660 
2661 /// Verifies that two ranges of types match, i.e. have the same number of
2662 /// entries and that types are pairwise equals. Reports errors on the given
2663 /// operation in case of mismatch.
2664 template <typename OpTy>
verifyTypeRangesMatch(OpTy op,TypeRange left,TypeRange right,StringRef message)2665 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
2666                                            TypeRange right, StringRef message) {
2667   if (left.size() != right.size())
2668     return op.emitOpError("expects the same number of ") << message;
2669 
2670   for (unsigned i = 0, e = left.size(); i < e; ++i) {
2671     if (left[i] != right[i]) {
2672       InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
2673                                 << message;
2674       diag.attachNote() << "for argument " << i << ", found " << left[i]
2675                         << " and " << right[i];
2676       return diag;
2677     }
2678   }
2679 
2680   return success();
2681 }
2682 
2683 /// Verifies that the first block of the given `region` is terminated by a
2684 /// YieldOp. Reports errors on the given operation if it is not the case.
2685 template <typename TerminatorTy>
verifyAndGetTerminator(scf::WhileOp op,Region & region,StringRef errorMessage)2686 static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region &region,
2687                                            StringRef errorMessage) {
2688   Operation *terminatorOperation = region.front().getTerminator();
2689   if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
2690     return yield;
2691 
2692   auto diag = op.emitOpError(errorMessage);
2693   if (terminatorOperation)
2694     diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
2695   return nullptr;
2696 }
2697 
verify()2698 LogicalResult scf::WhileOp::verify() {
2699   auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
2700       *this, getBefore(),
2701       "expects the 'before' region to terminate with 'scf.condition'");
2702   if (!beforeTerminator)
2703     return failure();
2704 
2705   auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
2706       *this, getAfter(),
2707       "expects the 'after' region to terminate with 'scf.yield'");
2708   return success(afterTerminator != nullptr);
2709 }
2710 
2711 namespace {
2712 /// Replace uses of the condition within the do block with true, since otherwise
2713 /// the block would not be evaluated.
2714 ///
2715 /// scf.while (..) : (i1, ...) -> ... {
2716 ///  %condition = call @evaluate_condition() : () -> i1
2717 ///  scf.condition(%condition) %condition : i1, ...
2718 /// } do {
2719 /// ^bb0(%arg0: i1, ...):
2720 ///    use(%arg0)
2721 ///    ...
2722 ///
2723 /// becomes
2724 /// scf.while (..) : (i1, ...) -> ... {
2725 ///  %condition = call @evaluate_condition() : () -> i1
2726 ///  scf.condition(%condition) %condition : i1, ...
2727 /// } do {
2728 /// ^bb0(%arg0: i1, ...):
2729 ///    use(%true)
2730 ///    ...
2731 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
2732   using OpRewritePattern<WhileOp>::OpRewritePattern;
2733 
matchAndRewrite__anoncfc0330f1811::WhileConditionTruth2734   LogicalResult matchAndRewrite(WhileOp op,
2735                                 PatternRewriter &rewriter) const override {
2736     auto term = op.getConditionOp();
2737 
2738     // These variables serve to prevent creating duplicate constants
2739     // and hold constant true or false values.
2740     Value constantTrue = nullptr;
2741 
2742     bool replaced = false;
2743     for (auto yieldedAndBlockArgs :
2744          llvm::zip(term.getArgs(), op.getAfterArguments())) {
2745       if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
2746         if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
2747           if (!constantTrue)
2748             constantTrue = rewriter.create<arith::ConstantOp>(
2749                 op.getLoc(), term.getCondition().getType(),
2750                 rewriter.getBoolAttr(true));
2751 
2752           std::get<1>(yieldedAndBlockArgs).replaceAllUsesWith(constantTrue);
2753           replaced = true;
2754         }
2755       }
2756     }
2757     return success(replaced);
2758   }
2759 };
2760 
2761 /// Remove loop invariant arguments from `before` block of scf.while.
2762 /// A before block argument is considered loop invariant if :-
2763 ///   1. i-th yield operand is equal to the i-th while operand.
2764 ///   2. i-th yield operand is k-th after block argument which is (k+1)-th
2765 ///      condition operand AND this (k+1)-th condition operand is equal to i-th
2766 ///      iter argument/while operand.
2767 /// For the arguments which are removed, their uses inside scf.while
2768 /// are replaced with their corresponding initial value.
2769 ///
2770 /// Eg:
2771 ///    INPUT :-
2772 ///    %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
2773 ///                                     ..., %argN_before = %N)
2774 ///           {
2775 ///                ...
2776 ///                scf.condition(%cond) %arg1_before, %arg0_before,
2777 ///                                     %arg2_before, %arg0_before, ...
2778 ///           } do {
2779 ///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
2780 ///                  ..., %argK_after):
2781 ///                ...
2782 ///                scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
2783 ///           }
2784 ///
2785 ///    OUTPUT :-
2786 ///    %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
2787 ///                                     %N)
2788 ///           {
2789 ///                ...
2790 ///                scf.condition(%cond) %b, %a, %arg2_before, %a, ...
2791 ///           } do {
2792 ///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
2793 ///                  ..., %argK_after):
2794 ///                ...
2795 ///                scf.yield %arg1_after, ..., %argN
2796 ///           }
2797 ///
2798 ///    EXPLANATION:
2799 ///      We iterate over each yield operand.
2800 ///        1. 0-th yield operand %arg0_after_2 is 4-th condition operand
2801 ///           %arg0_before, which in turn is the 0-th iter argument. So we
2802 ///           remove 0-th before block argument and yield operand, and replace
2803 ///           all uses of the 0-th before block argument with its initial value
2804 ///           %a.
2805 ///        2. 1-th yield operand %b is equal to the 1-th iter arg's initial
2806 ///           value. So we remove this operand and the corresponding before
2807 ///           block argument and replace all uses of 1-th before block argument
2808 ///           with %b.
2809 struct RemoveLoopInvariantArgsFromBeforeBlock
2810     : public OpRewritePattern<WhileOp> {
2811   using OpRewritePattern<WhileOp>::OpRewritePattern;
2812 
matchAndRewrite__anoncfc0330f1811::RemoveLoopInvariantArgsFromBeforeBlock2813   LogicalResult matchAndRewrite(WhileOp op,
2814                                 PatternRewriter &rewriter) const override {
2815     Block &afterBlock = op.getAfter().front();
2816     Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
2817     ConditionOp condOp = op.getConditionOp();
2818     OperandRange condOpArgs = condOp.getArgs();
2819     Operation *yieldOp = afterBlock.getTerminator();
2820     ValueRange yieldOpArgs = yieldOp->getOperands();
2821 
2822     bool canSimplify = false;
2823     for (const auto &it :
2824          llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
2825       auto index = static_cast<unsigned>(it.index());
2826       Value initVal, yieldOpArg;
2827       std::tie(initVal, yieldOpArg) = it.value();
2828       // If i-th yield operand is equal to the i-th operand of the scf.while,
2829       // the i-th before block argument is a loop invariant.
2830       if (yieldOpArg == initVal) {
2831         canSimplify = true;
2832         break;
2833       }
2834       // If the i-th yield operand is k-th after block argument, then we check
2835       // if the (k+1)-th condition op operand is equal to either the i-th before
2836       // block argument or the initial value of i-th before block argument. If
2837       // the comparison results `true`, i-th before block argument is a loop
2838       // invariant.
2839       auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
2840       if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
2841         Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
2842         if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
2843           canSimplify = true;
2844           break;
2845         }
2846       }
2847     }
2848 
2849     if (!canSimplify)
2850       return failure();
2851 
2852     SmallVector<Value> newInitArgs, newYieldOpArgs;
2853     DenseMap<unsigned, Value> beforeBlockInitValMap;
2854     SmallVector<Location> newBeforeBlockArgLocs;
2855     for (const auto &it :
2856          llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
2857       auto index = static_cast<unsigned>(it.index());
2858       Value initVal, yieldOpArg;
2859       std::tie(initVal, yieldOpArg) = it.value();
2860 
2861       // If i-th yield operand is equal to the i-th operand of the scf.while,
2862       // the i-th before block argument is a loop invariant.
2863       if (yieldOpArg == initVal) {
2864         beforeBlockInitValMap.insert({index, initVal});
2865         continue;
2866       } else {
2867         // If the i-th yield operand is k-th after block argument, then we check
2868         // if the (k+1)-th condition op operand is equal to either the i-th
2869         // before block argument or the initial value of i-th before block
2870         // argument. If the comparison results `true`, i-th before block
2871         // argument is a loop invariant.
2872         auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
2873         if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
2874           Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
2875           if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
2876             beforeBlockInitValMap.insert({index, initVal});
2877             continue;
2878           }
2879         }
2880       }
2881       newInitArgs.emplace_back(initVal);
2882       newYieldOpArgs.emplace_back(yieldOpArg);
2883       newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
2884     }
2885 
2886     {
2887       OpBuilder::InsertionGuard g(rewriter);
2888       rewriter.setInsertionPoint(yieldOp);
2889       rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
2890     }
2891 
2892     auto newWhile =
2893         rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
2894 
2895     Block &newBeforeBlock = *rewriter.createBlock(
2896         &newWhile.getBefore(), /*insertPt*/ {},
2897         ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
2898 
2899     Block &beforeBlock = op.getBefore().front();
2900     SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
2901     // For each i-th before block argument we find it's replacement value as :-
2902     //   1. If i-th before block argument is a loop invariant, we fetch it's
2903     //      initial value from `beforeBlockInitValMap` by querying for key `i`.
2904     //   2. Else we fetch j-th new before block argument as the replacement
2905     //      value of i-th before block argument.
2906     for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
2907       // If the index 'i' argument was a loop invariant we fetch it's initial
2908       // value from `beforeBlockInitValMap`.
2909       if (beforeBlockInitValMap.count(i) != 0)
2910         newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
2911       else
2912         newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
2913     }
2914 
2915     rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
2916     rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
2917                                 newWhile.getAfter().begin());
2918 
2919     rewriter.replaceOp(op, newWhile.getResults());
2920     return success();
2921   }
2922 };
2923 
2924 /// Remove loop invariant value from result (condition op) of scf.while.
2925 /// A value is considered loop invariant if the final value yielded by
2926 /// scf.condition is defined outside of the `before` block. We remove the
2927 /// corresponding argument in `after` block and replace the use with the value.
2928 /// We also replace the use of the corresponding result of scf.while with the
2929 /// value.
2930 ///
2931 /// Eg:
2932 ///    INPUT :-
2933 ///    %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
2934 ///                                             %argN_before = %N) {
2935 ///                ...
2936 ///                scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
2937 ///           } do {
2938 ///             ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
2939 ///                ...
2940 ///                some_func(%arg1_after)
2941 ///                ...
2942 ///                scf.yield %arg0_after, %arg2_after, ..., %argN_after
2943 ///           }
2944 ///
2945 ///    OUTPUT :-
2946 ///    %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
2947 ///                ...
2948 ///                scf.condition(%cond) %arg0, %arg1, ..., %argM
2949 ///           } do {
2950 ///             ^bb0(%arg0, %arg3, ..., %argM):
2951 ///                ...
2952 ///                some_func(%a)
2953 ///                ...
2954 ///                scf.yield %arg0, %b, ..., %argN
2955 ///           }
2956 ///
2957 ///     EXPLANATION:
2958 ///       1. The 1-th and 2-th operand of scf.condition are defined outside the
2959 ///          before block of scf.while, so they get removed.
2960 ///       2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
2961 ///          replaced by %b.
2962 ///       3. The corresponding after block argument %arg1_after's uses are
2963 ///          replaced by %a and %arg2_after's uses are replaced by %b.
2964 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
2965   using OpRewritePattern<WhileOp>::OpRewritePattern;
2966 
matchAndRewrite__anoncfc0330f1811::RemoveLoopInvariantValueYielded2967   LogicalResult matchAndRewrite(WhileOp op,
2968                                 PatternRewriter &rewriter) const override {
2969     Block &beforeBlock = op.getBefore().front();
2970     ConditionOp condOp = op.getConditionOp();
2971     OperandRange condOpArgs = condOp.getArgs();
2972 
2973     bool canSimplify = false;
2974     for (Value condOpArg : condOpArgs) {
2975       // Those values not defined within `before` block will be considered as
2976       // loop invariant values. We map the corresponding `index` with their
2977       // value.
2978       if (condOpArg.getParentBlock() != &beforeBlock) {
2979         canSimplify = true;
2980         break;
2981       }
2982     }
2983 
2984     if (!canSimplify)
2985       return failure();
2986 
2987     Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
2988 
2989     SmallVector<Value> newCondOpArgs;
2990     SmallVector<Type> newAfterBlockType;
2991     DenseMap<unsigned, Value> condOpInitValMap;
2992     SmallVector<Location> newAfterBlockArgLocs;
2993     for (const auto &it : llvm::enumerate(condOpArgs)) {
2994       auto index = static_cast<unsigned>(it.index());
2995       Value condOpArg = it.value();
2996       // Those values not defined within `before` block will be considered as
2997       // loop invariant values. We map the corresponding `index` with their
2998       // value.
2999       if (condOpArg.getParentBlock() != &beforeBlock) {
3000         condOpInitValMap.insert({index, condOpArg});
3001       } else {
3002         newCondOpArgs.emplace_back(condOpArg);
3003         newAfterBlockType.emplace_back(condOpArg.getType());
3004         newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3005       }
3006     }
3007 
3008     {
3009       OpBuilder::InsertionGuard g(rewriter);
3010       rewriter.setInsertionPoint(condOp);
3011       rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3012                                                newCondOpArgs);
3013     }
3014 
3015     auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3016                                              op.getOperands());
3017 
3018     Block &newAfterBlock =
3019         *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3020                               newAfterBlockType, newAfterBlockArgLocs);
3021 
3022     Block &afterBlock = op.getAfter().front();
3023     // Since a new scf.condition op was created, we need to fetch the new
3024     // `after` block arguments which will be used while replacing operations of
3025     // previous scf.while's `after` blocks. We'd also be fetching new result
3026     // values too.
3027     SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3028     SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3029     for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3030       Value afterBlockArg, result;
3031       // If index 'i' argument was loop invariant we fetch it's value from the
3032       // `condOpInitMap` map.
3033       if (condOpInitValMap.count(i) != 0) {
3034         afterBlockArg = condOpInitValMap[i];
3035         result = afterBlockArg;
3036       } else {
3037         afterBlockArg = newAfterBlock.getArgument(j);
3038         result = newWhile.getResult(j);
3039         j++;
3040       }
3041       newAfterBlockArgs[i] = afterBlockArg;
3042       newWhileResults[i] = result;
3043     }
3044 
3045     rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3046     rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3047                                 newWhile.getBefore().begin());
3048 
3049     rewriter.replaceOp(op, newWhileResults);
3050     return success();
3051   }
3052 };
3053 
3054 /// Remove WhileOp results that are also unused in 'after' block.
3055 ///
3056 ///  %0:2 = scf.while () : () -> (i32, i64) {
3057 ///    %condition = "test.condition"() : () -> i1
3058 ///    %v1 = "test.get_some_value"() : () -> i32
3059 ///    %v2 = "test.get_some_value"() : () -> i64
3060 ///    scf.condition(%condition) %v1, %v2 : i32, i64
3061 ///  } do {
3062 ///  ^bb0(%arg0: i32, %arg1: i64):
3063 ///    "test.use"(%arg0) : (i32) -> ()
3064 ///    scf.yield
3065 ///  }
3066 ///  return %0#0 : i32
3067 ///
3068 /// becomes
3069 ///  %0 = scf.while () : () -> (i32) {
3070 ///    %condition = "test.condition"() : () -> i1
3071 ///    %v1 = "test.get_some_value"() : () -> i32
3072 ///    %v2 = "test.get_some_value"() : () -> i64
3073 ///    scf.condition(%condition) %v1 : i32
3074 ///  } do {
3075 ///  ^bb0(%arg0: i32):
3076 ///    "test.use"(%arg0) : (i32) -> ()
3077 ///    scf.yield
3078 ///  }
3079 ///  return %0 : i32
3080 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3081   using OpRewritePattern<WhileOp>::OpRewritePattern;
3082 
matchAndRewrite__anoncfc0330f1811::WhileUnusedResult3083   LogicalResult matchAndRewrite(WhileOp op,
3084                                 PatternRewriter &rewriter) const override {
3085     auto term = op.getConditionOp();
3086     auto afterArgs = op.getAfterArguments();
3087     auto termArgs = term.getArgs();
3088 
3089     // Collect results mapping, new terminator args and new result types.
3090     SmallVector<unsigned> newResultsIndices;
3091     SmallVector<Type> newResultTypes;
3092     SmallVector<Value> newTermArgs;
3093     SmallVector<Location> newArgLocs;
3094     bool needUpdate = false;
3095     for (const auto &it :
3096          llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3097       auto i = static_cast<unsigned>(it.index());
3098       Value result = std::get<0>(it.value());
3099       Value afterArg = std::get<1>(it.value());
3100       Value termArg = std::get<2>(it.value());
3101       if (result.use_empty() && afterArg.use_empty()) {
3102         needUpdate = true;
3103       } else {
3104         newResultsIndices.emplace_back(i);
3105         newTermArgs.emplace_back(termArg);
3106         newResultTypes.emplace_back(result.getType());
3107         newArgLocs.emplace_back(result.getLoc());
3108       }
3109     }
3110 
3111     if (!needUpdate)
3112       return failure();
3113 
3114     {
3115       OpBuilder::InsertionGuard g(rewriter);
3116       rewriter.setInsertionPoint(term);
3117       rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3118                                                newTermArgs);
3119     }
3120 
3121     auto newWhile =
3122         rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3123 
3124     Block &newAfterBlock = *rewriter.createBlock(
3125         &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3126 
3127     // Build new results list and new after block args (unused entries will be
3128     // null).
3129     SmallVector<Value> newResults(op.getNumResults());
3130     SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3131     for (const auto &it : llvm::enumerate(newResultsIndices)) {
3132       newResults[it.value()] = newWhile.getResult(it.index());
3133       newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3134     }
3135 
3136     rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3137                                 newWhile.getBefore().begin());
3138 
3139     Block &afterBlock = op.getAfter().front();
3140     rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3141 
3142     rewriter.replaceOp(op, newResults);
3143     return success();
3144   }
3145 };
3146 
3147 /// Replace operations equivalent to the condition in the do block with true,
3148 /// since otherwise the block would not be evaluated.
3149 ///
3150 /// scf.while (..) : (i32, ...) -> ... {
3151 ///  %z = ... : i32
3152 ///  %condition = cmpi pred %z, %a
3153 ///  scf.condition(%condition) %z : i32, ...
3154 /// } do {
3155 /// ^bb0(%arg0: i32, ...):
3156 ///    %condition2 = cmpi pred %arg0, %a
3157 ///    use(%condition2)
3158 ///    ...
3159 ///
3160 /// becomes
3161 /// scf.while (..) : (i32, ...) -> ... {
3162 ///  %z = ... : i32
3163 ///  %condition = cmpi pred %z, %a
3164 ///  scf.condition(%condition) %z : i32, ...
3165 /// } do {
3166 /// ^bb0(%arg0: i32, ...):
3167 ///    use(%true)
3168 ///    ...
3169 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3170   using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3171 
matchAndRewrite__anoncfc0330f1811::WhileCmpCond3172   LogicalResult matchAndRewrite(scf::WhileOp op,
3173                                 PatternRewriter &rewriter) const override {
3174     using namespace scf;
3175     auto cond = op.getConditionOp();
3176     auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3177     if (!cmp)
3178       return failure();
3179     bool changed = false;
3180     for (auto tup :
3181          llvm::zip(cond.getArgs(), op.getAfter().front().getArguments())) {
3182       for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3183         if (std::get<0>(tup) != cmp.getOperand(opIdx))
3184           continue;
3185         for (OpOperand &u :
3186              llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3187           auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3188           if (!cmp2)
3189             continue;
3190           // For a binary operator 1-opIdx gets the other side.
3191           if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3192             continue;
3193           bool samePredicate;
3194           if (cmp2.getPredicate() == cmp.getPredicate())
3195             samePredicate = true;
3196           else if (cmp2.getPredicate() ==
3197                    arith::invertPredicate(cmp.getPredicate()))
3198             samePredicate = false;
3199           else
3200             continue;
3201 
3202           rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3203                                                             1);
3204           changed = true;
3205         }
3206       }
3207     }
3208     return success(changed);
3209   }
3210 };
3211 
3212 struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
3213   using OpRewritePattern<WhileOp>::OpRewritePattern;
3214 
matchAndRewrite__anoncfc0330f1811::WhileUnusedArg3215   LogicalResult matchAndRewrite(WhileOp op,
3216                                 PatternRewriter &rewriter) const override {
3217 
3218     if (!llvm::any_of(op.getBeforeArguments(),
3219                       [](Value arg) { return arg.use_empty(); }))
3220       return failure();
3221 
3222     YieldOp yield = op.getYieldOp();
3223 
3224     // Collect results mapping, new terminator args and new result types.
3225     SmallVector<Value> newYields;
3226     SmallVector<Value> newInits;
3227     SmallVector<unsigned> argsToErase;
3228     for (const auto &it : llvm::enumerate(llvm::zip(
3229              op.getBeforeArguments(), yield.getOperands(), op.getInits()))) {
3230       Value beforeArg = std::get<0>(it.value());
3231       Value yieldValue = std::get<1>(it.value());
3232       Value initValue = std::get<2>(it.value());
3233       if (beforeArg.use_empty()) {
3234         argsToErase.push_back(it.index());
3235       } else {
3236         newYields.emplace_back(yieldValue);
3237         newInits.emplace_back(initValue);
3238       }
3239     }
3240 
3241     if (argsToErase.empty())
3242       return failure();
3243 
3244     rewriter.startRootUpdate(op);
3245     op.getBefore().front().eraseArguments(argsToErase);
3246     rewriter.finalizeRootUpdate(op);
3247 
3248     WhileOp replacement =
3249         rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInits);
3250     replacement.getBefore().takeBody(op.getBefore());
3251     replacement.getAfter().takeBody(op.getAfter());
3252     rewriter.replaceOp(op, replacement.getResults());
3253 
3254     rewriter.setInsertionPoint(yield);
3255     rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
3256     return success();
3257   }
3258 };
3259 } // namespace
3260 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)3261 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
3262                                           MLIRContext *context) {
3263   results.add<RemoveLoopInvariantArgsFromBeforeBlock,
3264               RemoveLoopInvariantValueYielded, WhileConditionTruth,
3265               WhileCmpCond, WhileUnusedResult>(context);
3266 }
3267 
3268 //===----------------------------------------------------------------------===//
3269 // TableGen'd op method definitions
3270 //===----------------------------------------------------------------------===//
3271 
3272 #define GET_OP_CLASSES
3273 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
3274