1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/IR/SCF.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BlockAndValueMapping.h"
17 #include "mlir/IR/FunctionInterfaces.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Support/MathExtras.h"
21 #include "mlir/Transforms/InliningUtils.h"
22 
23 using namespace mlir;
24 using namespace mlir::scf;
25 
26 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
27 
28 //===----------------------------------------------------------------------===//
29 // SCFDialect Dialect Interfaces
30 //===----------------------------------------------------------------------===//
31 
32 namespace {
33 struct SCFInlinerInterface : public DialectInlinerInterface {
34   using DialectInlinerInterface::DialectInlinerInterface;
35   // We don't have any special restrictions on what can be inlined into
36   // destination regions (e.g. while/conditional bodies). Always allow it.
37   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
38                        BlockAndValueMapping &valueMapping) const final {
39     return true;
40   }
41   // Operations in scf dialect are always legal to inline since they are
42   // pure.
43   bool isLegalToInline(Operation *, Region *, bool,
44                        BlockAndValueMapping &) const final {
45     return true;
46   }
47   // Handle the given inlined terminator by replacing it with a new operation
48   // as necessary. Required when the region has only one block.
49   void handleTerminator(Operation *op,
50                         ArrayRef<Value> valuesToRepl) const final {
51     auto retValOp = dyn_cast<scf::YieldOp>(op);
52     if (!retValOp)
53       return;
54 
55     for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
56       std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
57     }
58   }
59 };
60 } // namespace
61 
62 //===----------------------------------------------------------------------===//
63 // SCFDialect
64 //===----------------------------------------------------------------------===//
65 
/// Registers the SCF dialect's operations and dialect interfaces.
void SCFDialect::initialize() {
  // The operation list is produced by TableGen; GET_OP_LIST makes the
  // generated file expand to a comma-separated list of op classes.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
      >();
  // Attach the inliner hooks (SCFInlinerInterface) to this dialect.
  addInterfaces<SCFInlinerInterface>();
}
73 
74 /// Default callback for IfOp builders. Inserts a yield without arguments.
75 void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
76   builder.create<scf::YieldOp>(loc);
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // ExecuteRegionOp
81 //===----------------------------------------------------------------------===//
82 
83 /// Replaces the given op with the contents of the given single-block region,
84 /// using the operands of the block terminator to replace operation results.
85 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
86                                 Region &region, ValueRange blockArgs = {}) {
87   assert(llvm::hasSingleElement(region) && "expected single-region block");
88   Block *block = &region.front();
89   Operation *terminator = block->getTerminator();
90   ValueRange results = terminator->getOperands();
91   rewriter.mergeBlockBefore(block, op, blockArgs);
92   rewriter.replaceOp(op, results);
93   rewriter.eraseOp(terminator);
94 }
95 
96 ///
97 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
98 ///    block+
99 /// `}`
100 ///
101 /// Example:
102 ///   scf.execute_region -> i32 {
103 ///     %idx = load %rI[%i] : memref<128xi32>
104 ///     return %idx : i32
105 ///   }
106 ///
107 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
108                                    OperationState &result) {
109   if (parser.parseOptionalArrowTypeList(result.types))
110     return failure();
111 
112   // Introduce the body region and parse it.
113   Region *body = result.addRegion();
114   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
115       parser.parseOptionalAttrDict(result.attributes))
116     return failure();
117 
118   return success();
119 }
120 
121 void ExecuteRegionOp::print(OpAsmPrinter &p) {
122   p.printOptionalArrowTypeList(getResultTypes());
123 
124   p << ' ';
125   p.printRegion(getRegion(),
126                 /*printEntryBlockArgs=*/false,
127                 /*printBlockTerminators=*/true);
128 
129   p.printOptionalAttrDict((*this)->getAttrs());
130 }
131 
132 LogicalResult ExecuteRegionOp::verify() {
133   if (getRegion().empty())
134     return emitOpError("region needs to have at least one block");
135   if (getRegion().front().getNumArguments() > 0)
136     return emitOpError("region cannot have any arguments");
137   return success();
138 }
139 
140 // Inline an ExecuteRegionOp if it only contains one block.
141 //     "test.foo"() : () -> ()
142 //      %v = scf.execute_region -> i64 {
143 //        %x = "test.val"() : () -> i64
144 //        scf.yield %x : i64
145 //      }
146 //      "test.bar"(%v) : (i64) -> ()
147 //
148 //  becomes
149 //
150 //     "test.foo"() : () -> ()
151 //     %x = "test.val"() : () -> i64
152 //     "test.bar"(%x) : (i64) -> ()
153 //
154 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
155   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
156 
157   LogicalResult matchAndRewrite(ExecuteRegionOp op,
158                                 PatternRewriter &rewriter) const override {
159     if (!llvm::hasSingleElement(op.getRegion()))
160       return failure();
161     replaceOpWithRegion(rewriter, op, op.getRegion());
162     return success();
163   }
164 };
165 
// Inline an ExecuteRegionOp if its parent can contain multiple blocks.
// Currently this only fires when the parent op implements FunctionOpInterface
// or is itself an ExecuteRegionOp.
// TODO generalize the conditions for operations which can be inlined into.
// func @func_execute_region_elim() {
//     "test.foo"() : () -> ()
//     %v = scf.execute_region -> i64 {
//       %c = "test.cmp"() : () -> i1
//       cf.cond_br %c, ^bb2, ^bb3
//     ^bb2:
//       %x = "test.val1"() : () -> i64
//       cf.br ^bb4(%x : i64)
//     ^bb3:
//       %y = "test.val2"() : () -> i64
//       cf.br ^bb4(%y : i64)
//     ^bb4(%z : i64):
//       scf.yield %z : i64
//     }
//     "test.bar"(%v) : (i64) -> ()
//   return
// }
//
//  becomes
//
// func @func_execute_region_elim() {
//    "test.foo"() : () -> ()
//    %c = "test.cmp"() : () -> i1
//    cf.cond_br %c, ^bb1, ^bb2
//  ^bb1:  // pred: ^bb0
//    %x = "test.val1"() : () -> i64
//    cf.br ^bb3(%x : i64)
//  ^bb2:  // pred: ^bb0
//    %y = "test.val2"() : () -> i64
//    cf.br ^bb3(%y : i64)
//  ^bb3(%z: i64):  // 2 preds: ^bb1, ^bb2
//    "test.bar"(%z) : (i64) -> ()
//    return
//  }
//
struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExecuteRegionOp op,
                                PatternRewriter &rewriter) const override {
    // Only inline into parents on the allow-list (see TODO above the struct).
    if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
      return failure();

    // Split the enclosing block at `op`: `op` and everything after it move
    // into `postBlock`, leaving `prevBlock` without a terminator.
    Block *prevBlock = op->getBlock();
    Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
    rewriter.setInsertionPointToEnd(prevBlock);

    // Terminate `prevBlock` with a branch into the region's entry block.
    rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());

    // Rewrite each `scf.yield` terminator in the region into a branch to
    // `postBlock` that carries the yielded values.
    for (Block &blk : op.getRegion()) {
      if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
        rewriter.setInsertionPoint(yieldOp);
        rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
                                      yieldOp.getResults());
        rewriter.eraseOp(yieldOp);
      }
    }

    // Move all of the region's blocks into the parent region, placing them
    // just before `postBlock`.
    rewriter.inlineRegionBefore(op.getRegion(), postBlock);
    SmallVector<Value> blockArgs;

    // Give `postBlock` one argument per op result (same type and location).
    // These arguments receive the values passed by the branches created above.
    for (auto res : op.getResults())
      blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));

    // The new block arguments stand in for the op's results; `op` is erased.
    rewriter.replaceOp(op, blockArgs);
    return success();
  }
};
236 
237 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
238                                                   MLIRContext *context) {
239   results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
240 }
241 
242 /// Given the region at `index`, or the parent operation if `index` is None,
243 /// return the successor regions. These are the regions that may be selected
244 /// during the flow of control. `operands` is a set of optional attributes that
245 /// correspond to a constant value for each operand, or null if that operand is
246 /// not a constant.
247 void ExecuteRegionOp::getSuccessorRegions(
248     Optional<unsigned> index, ArrayRef<Attribute> operands,
249     SmallVectorImpl<RegionSuccessor> &regions) {
250   // If the predecessor is the ExecuteRegionOp, branch into the body.
251   if (!index) {
252     regions.push_back(RegionSuccessor(&getRegion()));
253     return;
254   }
255 
256   // Otherwise, the region branches back to the parent operation.
257   regions.push_back(RegionSuccessor(getResults()));
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // ConditionOp
262 //===----------------------------------------------------------------------===//
263 
264 MutableOperandRange
265 ConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
266   // Pass all operands except the condition to the successor region.
267   return getArgsMutable();
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // ForOp
272 //===----------------------------------------------------------------------===//
273 
274 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
275                   Value ub, Value step, ValueRange iterArgs,
276                   BodyBuilderFn bodyBuilder) {
277   result.addOperands({lb, ub, step});
278   result.addOperands(iterArgs);
279   for (Value v : iterArgs)
280     result.addTypes(v.getType());
281   Region *bodyRegion = result.addRegion();
282   bodyRegion->push_back(new Block);
283   Block &bodyBlock = bodyRegion->front();
284   bodyBlock.addArgument(builder.getIndexType(), result.location);
285   for (Value v : iterArgs)
286     bodyBlock.addArgument(v.getType(), v.getLoc());
287 
288   // Create the default terminator if the builder is not provided and if the
289   // iteration arguments are not provided. Otherwise, leave this to the caller
290   // because we don't know which values to return from the loop.
291   if (iterArgs.empty() && !bodyBuilder) {
292     ForOp::ensureTerminator(*bodyRegion, builder, result.location);
293   } else if (bodyBuilder) {
294     OpBuilder::InsertionGuard guard(builder);
295     builder.setInsertionPointToStart(&bodyBlock);
296     bodyBuilder(builder, result.location, bodyBlock.getArgument(0),
297                 bodyBlock.getArguments().drop_front());
298   }
299 }
300 
301 LogicalResult ForOp::verify() {
302   if (auto cst = getStep().getDefiningOp<arith::ConstantIndexOp>())
303     if (cst.value() <= 0)
304       return emitOpError("constant step operand must be positive");
305 
306   auto opNumResults = getNumResults();
307   if (opNumResults == 0)
308     return success();
309   // If ForOp defines values, check that the number and types of
310   // the defined values match ForOp initial iter operands and backedge
311   // basic block arguments.
312   if (getNumIterOperands() != opNumResults)
313     return emitOpError(
314         "mismatch in number of loop-carried values and defined values");
315   return success();
316 }
317 
318 LogicalResult ForOp::verifyRegions() {
319   // Check that the body defines as single block argument for the induction
320   // variable.
321   auto *body = getBody();
322   if (!body->getArgument(0).getType().isIndex())
323     return emitOpError(
324         "expected body first argument to be an index argument for "
325         "the induction variable");
326 
327   auto opNumResults = getNumResults();
328   if (opNumResults == 0)
329     return success();
330 
331   if (getNumRegionIterArgs() != opNumResults)
332     return emitOpError(
333         "mismatch in number of basic block args and defined values");
334 
335   auto iterOperands = getIterOperands();
336   auto iterArgs = getRegionIterArgs();
337   auto opResults = getResults();
338   unsigned i = 0;
339   for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
340     if (std::get<0>(e).getType() != std::get<2>(e).getType())
341       return emitOpError() << "types mismatch between " << i
342                            << "th iter operand and defined value";
343     if (std::get<1>(e).getType() != std::get<2>(e).getType())
344       return emitOpError() << "types mismatch between " << i
345                            << "th iter region arg and defined value";
346 
347     i++;
348   }
349   return success();
350 }
351 
352 Optional<Value> ForOp::getSingleInductionVar() { return getInductionVar(); }
353 
354 Optional<OpFoldResult> ForOp::getSingleLowerBound() {
355   return OpFoldResult(getLowerBound());
356 }
357 
358 Optional<OpFoldResult> ForOp::getSingleStep() {
359   return OpFoldResult(getStep());
360 }
361 
362 Optional<OpFoldResult> ForOp::getSingleUpperBound() {
363   return OpFoldResult(getUpperBound());
364 }
365 
366 /// Prints the initialization list in the form of
367 ///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
368 /// where 'inner' values are assumed to be region arguments and 'outer' values
369 /// are regular SSA values.
370 static void printInitializationList(OpAsmPrinter &p,
371                                     Block::BlockArgListType blocksArgs,
372                                     ValueRange initializers,
373                                     StringRef prefix = "") {
374   assert(blocksArgs.size() == initializers.size() &&
375          "expected same length of arguments and initializers");
376   if (initializers.empty())
377     return;
378 
379   p << prefix << '(';
380   llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
381     p << std::get<0>(it) << " = " << std::get<1>(it);
382   });
383   p << ")";
384 }
385 
386 void ForOp::print(OpAsmPrinter &p) {
387   p << " " << getInductionVar() << " = " << getLowerBound() << " to "
388     << getUpperBound() << " step " << getStep();
389 
390   printInitializationList(p, getRegionIterArgs(), getIterOperands(),
391                           " iter_args");
392   if (!getIterOperands().empty())
393     p << " -> (" << getIterOperands().getTypes() << ')';
394   p << ' ';
395   p.printRegion(getRegion(),
396                 /*printEntryBlockArgs=*/false,
397                 /*printBlockTerminators=*/hasIterOperands());
398   p.printOptionalAttrDict((*this)->getAttrs());
399 }
400 
/// Parses an `scf.for` operation of the form
///   %iv = %lb to %ub step %step (iter_args(%arg = %init, ...) -> (types))?
///   region attr-dict?
/// The bounds, step and induction variable are all typed `index`. Each
/// iter_args region argument and its init operand take the corresponding
/// type from the `-> (types)` list.
ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  Type indexType = builder.getIndexType();

  OpAsmParser::Argument inductionVariable;
  inductionVariable.type = indexType;
  OpAsmParser::UnresolvedOperand lb, ub, step;

  // Parse the induction variable followed by '='.
  if (parser.parseArgument(inductionVariable) || parser.parseEqual() ||
      // Parse loop bounds.
      parser.parseOperand(lb) ||
      parser.resolveOperand(lb, indexType, result.operands) ||
      parser.parseKeyword("to") || parser.parseOperand(ub) ||
      parser.resolveOperand(ub, indexType, result.operands) ||
      parser.parseKeyword("step") || parser.parseOperand(step) ||
      parser.resolveOperand(step, indexType, result.operands))
    return failure();

  // Parse the optional initial iteration arguments.
  // regionArgs[0] is the induction variable; iter_args region arguments are
  // appended after it.
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  regionArgs.push_back(inductionVariable);

  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
    // Parse assignment list and results type list.
    if (parser.parseAssignmentList(regionArgs, operands) ||
        parser.parseArrowTypeList(result.types))
      return failure();

    // Resolve input operands.
    // Each result type is assigned to the matching region argument and used to
    // resolve the matching init operand. zip stops at the shortest range, so
    // count mismatches are reported by the size check below.
    for (auto argOperandType :
         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
      Type type = std::get<2>(argOperandType);
      std::get<0>(argOperandType).type = type;
      if (parser.resolveOperand(std::get<1>(argOperandType), type,
                                result.operands))
        return failure();
    }
  }

  // One region argument for the induction variable plus one per result.
  if (regionArgs.size() != result.types.size() + 1)
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch in number of loop-carried values and defined values");

  // Parse the body region.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArgs))
    return failure();

  // The printer elides the terminator when there are no iter operands, so
  // make sure one exists after parsing.
  ForOp::ensureTerminator(*body, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
460 
461 Region &ForOp::getLoopBody() { return getRegion(); }
462 
463 ForOp mlir::scf::getForInductionVarOwner(Value val) {
464   auto ivArg = val.dyn_cast<BlockArgument>();
465   if (!ivArg)
466     return ForOp();
467   assert(ivArg.getOwner() && "unlinked block argument");
468   auto *containingOp = ivArg.getOwner()->getParentOp();
469   return dyn_cast_or_null<ForOp>(containingOp);
470 }
471 
472 /// Return operands used when entering the region at 'index'. These operands
473 /// correspond to the loop iterator operands, i.e., those excluding the
474 /// induction variable. LoopOp only has one region, so 0 is the only valid value
475 /// for `index`.
476 OperandRange ForOp::getSuccessorEntryOperands(Optional<unsigned> index) {
477   assert(index && *index == 0 && "invalid region index");
478 
479   // The initial operands map to the loop arguments after the induction
480   // variable.
481   return getInitArgs();
482 }
483 
484 /// Given the region at `index`, or the parent operation if `index` is None,
485 /// return the successor regions. These are the regions that may be selected
486 /// during the flow of control. `operands` is a set of optional attributes that
487 /// correspond to a constant value for each operand, or null if that operand is
488 /// not a constant.
489 void ForOp::getSuccessorRegions(Optional<unsigned> index,
490                                 ArrayRef<Attribute> operands,
491                                 SmallVectorImpl<RegionSuccessor> &regions) {
492   // If the predecessor is the ForOp, branch into the body using the iterator
493   // arguments.
494   if (!index) {
495     regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
496     return;
497   }
498 
499   // Otherwise, the loop may branch back to itself or the parent operation.
500   assert(*index == 0 && "expected loop region");
501   regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
502   regions.push_back(RegionSuccessor(getResults()));
503 }
504 
/// Builds a perfect nest of `scf.for` loops, one per (lbs[i], ubs[i],
/// steps[i]) triple, outermost first, threading `iterArgs` through every level
/// as loop-carried values.
///
/// Each non-innermost loop yields the results of the loop nested directly in
/// it. The innermost loop yields the values returned by `bodyBuilder`, which
/// must return as many values as there are `iterArgs` (asserted).
/// `bodyBuilder` receives the induction variables of all loops (outermost
/// first) and the innermost loop's region iter args. If no bounds are given,
/// `bodyBuilder` is invoked directly at the builder's current insertion point
/// and an empty LoopNest is returned. The builder's insertion point is
/// restored on return from the loop-building path (InsertionGuard).
LoopNest mlir::scf::buildLoopNest(
    OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ValueRange iterArgs,
    function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
        bodyBuilder) {
  assert(lbs.size() == ubs.size() &&
         "expected the same number of lower and upper bounds");
  assert(lbs.size() == steps.size() &&
         "expected the same number of lower bounds and steps");

  // If there are no bounds, call the body-building function and return early.
  // NOTE(review): the values returned by `bodyBuilder` here are only checked
  // by the assert and are not handed back to the caller; confirm callers do
  // not need them.
  if (lbs.empty()) {
    ValueVector results =
        bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
                    : ValueVector();
    assert(results.size() == iterArgs.size() &&
           "loop nest body must return as many values as loop has iteration "
           "arguments");
    return LoopNest();
  }

  // First, create the loop structure iteratively using the body-builder
  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
  // The callback only records the new loop's induction variable, region iter
  // args and location so the next (inner) loop can be built from them.
  OpBuilder::InsertionGuard guard(builder);
  SmallVector<scf::ForOp, 4> loops;
  SmallVector<Value, 4> ivs;
  loops.reserve(lbs.size());
  ivs.reserve(lbs.size());
  ValueRange currentIterArgs = iterArgs;
  Location currentLoc = loc;
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
    auto loop = builder.create<scf::ForOp>(
        currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
        [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
            ValueRange args) {
          ivs.push_back(iv);
          // It is safe to store ValueRange args because it points to block
          // arguments of a loop operation that we also own.
          currentIterArgs = args;
          currentLoc = nestedLoc;
        });
    // Set the builder to point to the body of the newly created loop. We don't
    // do this in the callback because the builder is reset when the callback
    // returns.
    builder.setInsertionPointToStart(loop.getBody());
    loops.push_back(loop);
  }

  // For all loops but the innermost, yield the results of the nested loop.
  // (`loops` is non-empty here because `lbs` is non-empty.)
  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
    builder.setInsertionPointToEnd(loops[i].getBody());
    builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
  }

  // In the body of the innermost loop, call the body building function if any
  // and yield its results.
  builder.setInsertionPointToStart(loops.back().getBody());
  ValueVector results = bodyBuilder
                            ? bodyBuilder(builder, currentLoc, ivs,
                                          loops.back().getRegionIterArgs())
                            : ValueVector();
  assert(results.size() == iterArgs.size() &&
         "loop nest body must return as many values as loop has iteration "
         "arguments");
  builder.setInsertionPointToEnd(loops.back().getBody());
  builder.create<scf::YieldOp>(loc, results);

  // Return the loops.
  LoopNest res;
  res.loops.assign(loops.begin(), loops.end());
  return res;
}
577 
578 LoopNest mlir::scf::buildLoopNest(
579     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
580     ValueRange steps,
581     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
582   // Delegate to the main function by wrapping the body builder.
583   return buildLoopNest(builder, loc, lbs, ubs, steps, llvm::None,
584                        [&bodyBuilder](OpBuilder &nestedBuilder,
585                                       Location nestedLoc, ValueRange ivs,
586                                       ValueRange) -> ValueVector {
587                          if (bodyBuilder)
588                            bodyBuilder(nestedBuilder, nestedLoc, ivs);
589                          return {};
590                        });
591 }
592 
593 namespace {
// Fold away ForOp iter arguments when:
// 1) The op yields the iter arguments.
// 2) The iter arguments have no use and the corresponding outer region
// iterators (inputs) are yielded.
// 3) The iter arguments have no use and the corresponding (operation) results
// have no use.
//
// In each case the replacement for the region argument and the op result is
// the loop's init operand, which is defined outside of the ForOp region. It
// can therefore be forwarded directly after simplifying the op inits, yields
// and returns.
//
// The implementation moves the original body into a new, narrower ForOp
// instead of cloning it. It uses `mergeBlockBefore` when no iter args remain
// and `mergeBlocks` otherwise.
struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const final {
    bool canonicalize = false;
    Block &block = forOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(block.getTerminator());

    // `newBlockTransferArgs` is a flat vector with one entry per original
    // block argument (induction variable first). Entry i is the value that
    // replaces original block argument i when the old body is merged into the
    // new loop. It plays the role of a BlockAndValueMapping for the calls to
    // `mergeBlockBefore` / `mergeBlocks`. `keepMask[i]` records whether
    // iter arg i survives in the new loop.
    SmallVector<bool, 4> keepMask;
    keepMask.reserve(yieldOp.getNumOperands());
    SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
        newResultValues;
    newBlockTransferArgs.reserve(1 + forOp.getNumIterOperands());
    newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
    newIterArgs.reserve(forOp.getNumIterOperands());
    newYieldValues.reserve(yieldOp.getNumOperands());
    newResultValues.reserve(forOp.getNumResults());
    for (auto it : llvm::zip(forOp.getIterOperands(),   // iter from outside
                             forOp.getRegionIterArgs(), // iter inside region
                             forOp.getResults(),        // op results
                             yieldOp.getOperands()      // iter yield
                             )) {
      // Forwarded is `true` when:
      // 1) The region `iter` argument is yielded.
      // 2) The region `iter` argument has no use, and the corresponding iter
      // operand (input) is yielded.
      // 3) The region `iter` argument has no use, and the corresponding op
      // result has no use.
      bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
                        (std::get<1>(it).use_empty() &&
                         (std::get<0>(it) == std::get<3>(it) ||
                          std::get<2>(it).use_empty())));
      keepMask.push_back(!forwarded);
      canonicalize |= forwarded;
      // A forwarded iter arg is replaced, inside and outside the loop, by its
      // init operand.
      if (forwarded) {
        newBlockTransferArgs.push_back(std::get<0>(it));
        newResultValues.push_back(std::get<0>(it));
        continue;
      }
      // A kept iter arg gets filled in once the new loop exists.
      newIterArgs.push_back(std::get<0>(it));
      newYieldValues.push_back(std::get<3>(it));
      newBlockTransferArgs.push_back(Value()); // placeholder with null value
      newResultValues.push_back(Value());      // placeholder with null value
    }

    if (!canonicalize)
      return failure();

    // No body builder is passed, so ForOp::build adds an empty terminator only
    // when `newIterArgs` is empty; otherwise the new body has no terminator.
    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newIterArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block &newBlock = newForOp.getRegion().front();

    // Replace the null placeholders with newly constructed values.
    // Kept iter args map, in order, onto the new loop's region iter args and
    // results (`collapsedIdx` counts only kept entries).
    newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
    for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
         idx != e; ++idx) {
      Value &blockTransferArg = newBlockTransferArgs[1 + idx];
      Value &newResultVal = newResultValues[idx];
      assert((blockTransferArg && newResultVal) ||
             (!blockTransferArg && !newResultVal));
      if (!blockTransferArg) {
        blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
        newResultVal = newForOp.getResult(collapsedIdx++);
      }
    }

    Block &oldBlock = forOp.getRegion().front();
    assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
           "unexpected argument size mismatch");

    // No results case: the scf::ForOp builder already created a zero
    // result terminator. Merge before this terminator and just get rid of the
    // original terminator that has been merged in.
    if (newIterArgs.empty()) {
      auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
      rewriter.mergeBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
      rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
      rewriter.replaceOp(forOp, newResultValues);
      return success();
    }

    // Remaining case: some iter args are kept, so the new body was created
    // without a terminator. Merge the old body (including its yield) into it,
    // then replace that yield with one that keeps only the operands selected
    // by `keepMask`.
    auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(mergedTerminator);
      SmallVector<Value, 4> filteredOperands;
      filteredOperands.reserve(newResultValues.size());
      for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
        if (keepMask[idx])
          filteredOperands.push_back(mergedTerminator.getOperand(idx));
      rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
                                    filteredOperands);
    };

    rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
    auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
    cloneFilteredTerminator(mergedYieldOp);
    rewriter.eraseOp(mergedYieldOp);
    rewriter.replaceOp(forOp, newResultValues);
    return success();
  }
};
717 
718 /// Rewriting pattern that erases loops that are known not to iterate, replaces
719 /// single-iteration loops with their bodies, and removes empty loops that
720 /// iterate at least once and only return values defined outside of the loop.
721 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
722   using OpRewritePattern<ForOp>::OpRewritePattern;
723 
724   LogicalResult matchAndRewrite(ForOp op,
725                                 PatternRewriter &rewriter) const override {
726     // If the upper bound is the same as the lower bound, the loop does not
727     // iterate, just remove it.
728     if (op.getLowerBound() == op.getUpperBound()) {
729       rewriter.replaceOp(op, op.getIterOperands());
730       return success();
731     }
732 
733     auto lb = op.getLowerBound().getDefiningOp<arith::ConstantOp>();
734     auto ub = op.getUpperBound().getDefiningOp<arith::ConstantOp>();
735     if (!lb || !ub)
736       return failure();
737 
738     // If the loop is known to have 0 iterations, remove it.
739     llvm::APInt lbValue = lb.getValue().cast<IntegerAttr>().getValue();
740     llvm::APInt ubValue = ub.getValue().cast<IntegerAttr>().getValue();
741     if (lbValue.sge(ubValue)) {
742       rewriter.replaceOp(op, op.getIterOperands());
743       return success();
744     }
745 
746     auto step = op.getStep().getDefiningOp<arith::ConstantOp>();
747     if (!step)
748       return failure();
749 
750     // If the loop is known to have 1 iteration, inline its body and remove the
751     // loop.
752     llvm::APInt stepValue = step.getValue().cast<IntegerAttr>().getValue();
753     if ((lbValue + stepValue).sge(ubValue)) {
754       SmallVector<Value, 4> blockArgs;
755       blockArgs.reserve(op.getNumIterOperands() + 1);
756       blockArgs.push_back(op.getLowerBound());
757       llvm::append_range(blockArgs, op.getIterOperands());
758       replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
759       return success();
760     }
761 
762     // Now we are left with loops that have more than 1 iterations.
763     Block &block = op.getRegion().front();
764     if (!llvm::hasSingleElement(block))
765       return failure();
766     // If the loop is empty, iterates at least once, and only returns values
767     // defined outside of the loop, remove it and replace it with yield values.
768     auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
769     auto yieldOperands = yieldOp.getOperands();
770     if (llvm::any_of(yieldOperands,
771                      [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
772       return failure();
773     rewriter.replaceOp(op, yieldOperands);
774     return success();
775   }
776 };
777 
778 /// Perform a replacement of one iter OpOperand of an scf.for to the
779 /// `replacement` value which is expected to be the source of a tensor.cast.
780 /// tensor.cast ops are inserted inside the block to account for the type cast.
781 static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
782                                            OpOperand &operand,
783                                            Value replacement) {
784   Type oldType = operand.get().getType(), newType = replacement.getType();
785   assert(oldType.isa<RankedTensorType>() && newType.isa<RankedTensorType>() &&
786          "expected ranked tensor types");
787 
788   // 1. Create new iter operands, exactly 1 is replaced.
789   ForOp forOp = cast<ForOp>(operand.getOwner());
790   assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
791          "expected an iter OpOperand");
792   if (operand.get().getType() == replacement.getType())
793     return forOp;
794   SmallVector<Value> newIterOperands;
795   for (OpOperand &opOperand : forOp.getIterOpOperands()) {
796     if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
797       newIterOperands.push_back(replacement);
798       continue;
799     }
800     newIterOperands.push_back(opOperand.get());
801   }
802 
803   // 2. Create the new forOp shell.
804   scf::ForOp newForOp = rewriter.create<scf::ForOp>(
805       forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
806       forOp.getStep(), newIterOperands);
807   newForOp->setAttrs(forOp->getAttrs());
808   Block &newBlock = newForOp.getRegion().front();
809   SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
810                                              newBlock.getArguments().end());
811 
812   // 3. Inject an incoming cast op at the beginning of the block for the bbArg
813   // corresponding to the `replacement` value.
814   OpBuilder::InsertionGuard g(rewriter);
815   rewriter.setInsertionPoint(&newBlock, newBlock.begin());
816   BlockArgument newRegionIterArg = newForOp.getRegionIterArgForOpOperand(
817       newForOp->getOpOperand(operand.getOperandNumber()));
818   Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
819                                                  newRegionIterArg);
820   newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
821 
822   // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
823   Block &oldBlock = forOp.getRegion().front();
824   rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
825 
826   // 5. Inject an outgoing cast op at the end of the block and yield it instead.
827   auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
828   rewriter.setInsertionPoint(clonedYieldOp);
829   unsigned yieldIdx =
830       newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
831   Value castOut = rewriter.create<tensor::CastOp>(
832       newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
833   SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
834   newYieldOperands[yieldIdx] = castOut;
835   rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
836   rewriter.eraseOp(clonedYieldOp);
837 
838   // 6. Inject an outgoing cast op after the forOp.
839   rewriter.setInsertionPointAfter(newForOp);
840   SmallVector<Value> newResults = newForOp.getResults();
841   newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
842       newForOp.getLoc(), oldType, newResults[yieldIdx]);
843 
844   return newForOp;
845 }
846 
847 /// Fold scf.for iter_arg/result pairs that go through incoming/ougoing
848 /// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
849 ///
850 /// ```
851 ///   %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
852 ///   %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
853 ///      -> (tensor<?x?xf32>) {
854 ///     %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
855 ///     scf.yield %2 : tensor<?x?xf32>
856 ///   }
857 ///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
858 ///   use_of(%2)
859 /// ```
860 ///
861 /// folds into:
862 ///
863 /// ```
864 ///   %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
865 ///       -> (tensor<32x1024xf32>) {
866 ///     %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
867 ///     %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
868 ///     %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
869 ///     scf.yield %4 : tensor<32x1024xf32>
870 ///   }
871 ///   use_of(%0)
872 /// ```
873 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
874   using OpRewritePattern<ForOp>::OpRewritePattern;
875 
876   LogicalResult matchAndRewrite(ForOp op,
877                                 PatternRewriter &rewriter) const override {
878     for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
879       OpOperand &iterOpOperand = std::get<0>(it);
880       auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
881       if (!incomingCast)
882         continue;
883       if (!std::get<1>(it).hasOneUse())
884         continue;
885       auto outgoingCastOp =
886           dyn_cast<tensor::CastOp>(*std::get<1>(it).user_begin());
887       if (!outgoingCastOp)
888         continue;
889 
890       // Must be a tensor.cast op pair with matching types.
891       if (outgoingCastOp.getResult().getType() !=
892           incomingCast.getSource().getType())
893         continue;
894 
895       // Create a new ForOp with that iter operand replaced.
896       auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
897                                                     incomingCast.getSource());
898 
899       // Insert outgoing cast and use it to replace the corresponding result.
900       rewriter.setInsertionPointAfter(newForOp);
901       SmallVector<Value> replacements = newForOp.getResults();
902       unsigned returnIdx =
903           iterOpOperand.getOperandNumber() - op.getNumControlOperands();
904       replacements[returnIdx] = rewriter.create<tensor::CastOp>(
905           op.getLoc(), incomingCast.getDest().getType(),
906           replacements[returnIdx]);
907       rewriter.replaceOp(op, replacements);
908       return success();
909     }
910     return failure();
911   }
912 };
913 
914 /// Canonicalize the iter_args of an scf::ForOp that involve a
915 /// `bufferization.to_tensor` and for which only the last loop iteration is
916 /// actually visible outside of the loop. The canonicalization looks for a
917 /// pattern such as:
918 /// ```
919 ///    %t0 = ... : tensor_type
920 ///    %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
921 ///      ...
922 ///      // %m is either buffer_cast(%bb00) or defined above the loop
923 ///      %m... : memref_type
924 ///      ... // uses of %m with potential inplace updates
925 ///      %new_tensor = bufferization.to_tensor %m : memref_type
926 ///      ...
927 ///      scf.yield %new_tensor : tensor_type
928 ///    }
929 /// ```
930 ///
931 /// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
932 /// `%m = buffer_cast %bb0` op that feeds into the yielded
933 /// `bufferization.to_tensor` op.
934 ///
935 /// If no aliasing write to the memref `%m`, from which `%new_tensor`is loaded,
936 /// occurs between `bufferization.to_tensor and yield then the value %0
937 /// visible outside of the loop is the last `bufferization.to_tensor`
938 /// produced in the loop.
939 ///
940 /// For now, we approximate the absence of aliasing by only supporting the case
941 /// when the bufferization.to_tensor is the operation immediately preceding
942 /// the yield.
943 //
944 /// The canonicalization rewrites the pattern as:
945 /// ```
946 ///    // %m is either a buffer_cast or defined above
947 ///    %m... : memref_type
948 ///    scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
949 ///      ... // uses of %m with potential inplace updates
950 ///      scf.yield %bb0: tensor_type
951 ///    }
952 ///    %0 = bufferization.to_tensor %m : memref_type
953 /// ```
954 ///
955 /// A later bbArg canonicalization will further rewrite as:
956 /// ```
957 ///    // %m is either a buffer_cast or defined above
958 ///    %m... : memref_type
959 ///    scf.for ... { // no iter_args
960 ///      ... // uses of %m with potential inplace updates
961 ///    }
962 ///    %0 = bufferization.to_tensor %m : memref_type
963 /// ```
964 struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
965   using OpRewritePattern<ForOp>::OpRewritePattern;
966 
967   LogicalResult matchAndRewrite(ForOp forOp,
968                                 PatternRewriter &rewriter) const override {
969     assert(std::next(forOp.getRegion().begin()) == forOp.getRegion().end() &&
970            "unexpected multiple blocks");
971 
972     Location loc = forOp.getLoc();
973     DenseMap<Value, Value> replacements;
974     for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
975       unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
976       auto yieldOp =
977           cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
978       Value yieldVal = yieldOp->getOperand(idx);
979       auto tensorLoadOp = yieldVal.getDefiningOp<bufferization::ToTensorOp>();
980       bool isTensor = bbArg.getType().isa<TensorType>();
981 
982       bufferization::ToMemrefOp tensorToMemref;
983       // Either bbArg has no use or it has a single buffer_cast use.
984       if (bbArg.hasOneUse())
985         tensorToMemref =
986             dyn_cast<bufferization::ToMemrefOp>(*bbArg.getUsers().begin());
987       if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !tensorToMemref))
988         continue;
989       // If tensorToMemref is present, it must feed into the `ToTensorOp`.
990       if (tensorToMemref && tensorLoadOp.getMemref() != tensorToMemref)
991         continue;
992       // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
993       // must be before `ToTensorOp` in the block so that the lastWrite
994       // property is not subject to additional side-effects.
995       // For now, we only support the case when ToTensorOp appears
996       // immediately before the terminator.
997       if (tensorLoadOp->getNextNode() != yieldOp)
998         continue;
999 
1000       // Clone the optional tensorToMemref before forOp.
1001       if (tensorToMemref) {
1002         rewriter.setInsertionPoint(forOp);
1003         rewriter.replaceOpWithNewOp<bufferization::ToMemrefOp>(
1004             tensorToMemref, tensorToMemref.getMemref().getType(),
1005             tensorToMemref.getTensor());
1006       }
1007 
1008       // Clone the tensorLoad after forOp.
1009       rewriter.setInsertionPointAfter(forOp);
1010       Value newTensorLoad = rewriter.create<bufferization::ToTensorOp>(
1011           loc, tensorLoadOp.getMemref());
1012       Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
1013       replacements.insert(std::make_pair(forOpResult, newTensorLoad));
1014 
1015       // Make the terminator just yield the bbArg, the old tensorLoadOp + the
1016       // old bbArg (that is now directly yielded) will canonicalize away.
1017       rewriter.startRootUpdate(yieldOp);
1018       yieldOp.setOperand(idx, bbArg);
1019       rewriter.finalizeRootUpdate(yieldOp);
1020     }
1021     if (replacements.empty())
1022       return failure();
1023 
1024     // We want to replace a subset of the results of `forOp`. rewriter.replaceOp
1025     // replaces the whole op and erase it unconditionally. This is wrong for
1026     // `forOp` as it generally contains ops with side effects.
1027     // Instead, use `rewriter.replaceOpWithIf`.
1028     SmallVector<Value> newResults;
1029     newResults.reserve(forOp.getNumResults());
1030     for (Value v : forOp.getResults()) {
1031       auto it = replacements.find(v);
1032       newResults.push_back((it != replacements.end()) ? it->second : v);
1033     }
1034     unsigned idx = 0;
1035     rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &op) {
1036       return op.get() != newResults[idx++];
1037     });
1038     return success();
1039   }
1040 };
1041 } // namespace
1042 
1043 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1044                                         MLIRContext *context) {
1045   results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
1046               LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
1047 }
1048 
1049 //===----------------------------------------------------------------------===//
1050 // ForeachThreadOp
1051 //===----------------------------------------------------------------------===//
1052 
1053 LogicalResult ForeachThreadOp::verify() {
1054   // Call terminator's verify to produce most informative error messages.
1055   if (failed(getTerminator().verify()))
1056     return failure();
1057 
1058   // Check that the body defines as single block argument for the thread index.
1059   auto *body = getBody();
1060   if (body->getNumArguments() != getRank())
1061     return emitOpError("region expects ") << getRank() << " arguments";
1062 
1063   // Verify consistency between the result types and the terminator.
1064   auto terminatorTypes = getTerminator().yieldedTypes();
1065   auto opResults = getResults();
1066   if (opResults.size() != terminatorTypes.size())
1067     return emitOpError("produces ")
1068            << opResults.size() << " results, but its terminator yields "
1069            << terminatorTypes.size() << " value(s)";
1070   unsigned i = 0;
1071   for (auto e : llvm::zip(terminatorTypes, opResults)) {
1072     if (std::get<0>(e) != std::get<1>(e).getType())
1073       return emitOpError() << "type mismatch between result " << i << " ("
1074                            << std::get<1>(e).getType() << ") and terminator ("
1075                            << std::get<0>(e) << ")";
1076     i++;
1077   }
1078   return success();
1079 }
1080 
1081 void ForeachThreadOp::print(OpAsmPrinter &p) {
1082   p << " (";
1083   llvm::interleaveComma(getThreadIndices(), p);
1084   p << ") in (";
1085   llvm::interleaveComma(getNumThreads(), p);
1086   p << ") -> (" << getResultTypes() << ") ";
1087   p.printRegion(getRegion(),
1088                 /*printEntryBlockArgs=*/false,
1089                 /*printBlockTerminators=*/getNumResults() > 0);
1090   p.printOptionalAttrDict(getOperation()->getAttrs());
1091 }
1092 
1093 ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
1094                                    OperationState &result) {
1095   auto &builder = parser.getBuilder();
1096   // Parse an opening `(` followed by thread index variables followed by `)`
1097   // TODO: when we can refer to such "induction variable"-like handles from the
1098   // declarative assembly format, we can implement the parser as a custom hook.
1099   SmallVector<OpAsmParser::Argument, 4> threadIndices;
1100   if (parser.parseArgumentList(threadIndices, OpAsmParser::Delimiter::Paren))
1101     return failure();
1102 
1103   // Parse `in` threadNums.
1104   SmallVector<OpAsmParser::UnresolvedOperand, 4> threadNums;
1105   if (parser.parseKeyword("in") ||
1106       parser.parseOperandList(threadNums, threadIndices.size(),
1107                               OpAsmParser::Delimiter::Paren) ||
1108       parser.resolveOperands(threadNums, builder.getIndexType(),
1109                              result.operands))
1110     return failure();
1111 
1112   // Parse optional results.
1113   if (parser.parseOptionalArrowTypeList(result.types))
1114     return failure();
1115 
1116   // Parse region.
1117   std::unique_ptr<Region> region = std::make_unique<Region>();
1118   for (auto &idx : threadIndices)
1119     idx.type = builder.getIndexType();
1120   if (parser.parseRegion(*region, threadIndices))
1121     return failure();
1122 
1123   // Ensure terminator and move region.
1124   OpBuilder b(builder.getContext());
1125   ForeachThreadOp::ensureTerminator(*region, b, result.location);
1126   result.addRegion(std::move(region));
1127 
1128   // Parse the optional attribute list.
1129   if (parser.parseOptionalAttrDict(result.attributes))
1130     return failure();
1131 
1132   return success();
1133 }
1134 
1135 // Bodyless builder, result types must be specified.
1136 void ForeachThreadOp::build(mlir::OpBuilder &builder,
1137                             mlir::OperationState &result, TypeRange resultTypes,
1138                             ValueRange numThreads,
1139                             ArrayRef<int64_t> threadDimMapping) {
1140   result.addOperands(numThreads);
1141   result.addAttribute(
1142       // TODO: getThreadDimMappingAttrName() but it is not a static member.
1143       "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
1144 
1145   Region *bodyRegion = result.addRegion();
1146   OpBuilder::InsertionGuard g(builder);
1147   // createBlock sets the IP inside the block.
1148   // Generally we would guard against that but the default ensureTerminator impl
1149   // expects it ..
1150   builder.createBlock(bodyRegion);
1151   Block &bodyBlock = bodyRegion->front();
1152   bodyBlock.addArguments(
1153       SmallVector<Type>(numThreads.size(), builder.getIndexType()),
1154       SmallVector<Location>(numThreads.size(), result.location));
1155   ForeachThreadOp::ensureTerminator(*bodyRegion, builder, result.location);
1156   result.addTypes(resultTypes);
1157 }
1158 
1159 // Builder that takes a bodyBuilder lambda, result types are inferred from
1160 // the terminator.
1161 void ForeachThreadOp::build(
1162     mlir::OpBuilder &builder, mlir::OperationState &result,
1163     ValueRange numThreads, ArrayRef<int64_t> threadDimMapping,
1164     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1165   result.addOperands(numThreads);
1166   result.addAttribute(
1167       // TODO: getThreadDimMappingAttrName() but it is not a static member.
1168       "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
1169 
1170   OpBuilder::InsertionGuard g(builder);
1171   Region *bodyRegion = result.addRegion();
1172   builder.createBlock(bodyRegion);
1173   Block &bodyBlock = bodyRegion->front();
1174   bodyBlock.addArguments(
1175       SmallVector<Type>(numThreads.size(), builder.getIndexType()),
1176       SmallVector<Location>(numThreads.size(), result.location));
1177 
1178   OpBuilder::InsertionGuard guard(builder);
1179   builder.setInsertionPointToStart(&bodyBlock);
1180   bodyBuilder(builder, result.location, bodyBlock.getArguments());
1181   auto terminator =
1182       llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
1183   assert(terminator &&
1184          "expected bodyBuilder to create PerformConcurrentlyOp terminator");
1185   result.addTypes(terminator.yieldedTypes());
1186 }
1187 
1188 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
1189 // unaware of the fact that our terminator also needs a region to be
1190 // well-formed. We override it here to ensure that we do the right thing.
1191 void ForeachThreadOp::ensureTerminator(Region &region, OpBuilder &builder,
1192                                        Location loc) {
1193   OpTrait::SingleBlockImplicitTerminator<PerformConcurrentlyOp>::Impl<
1194       ForeachThreadOp>::ensureTerminator(region, builder, loc);
1195   auto terminator =
1196       llvm::dyn_cast<PerformConcurrentlyOp>(region.front().getTerminator());
1197   if (terminator.getRegion().empty())
1198     builder.createBlock(&terminator.getRegion());
1199 }
1200 
1201 PerformConcurrentlyOp ForeachThreadOp::getTerminator() {
1202   return cast<PerformConcurrentlyOp>(getBody()->getTerminator());
1203 }
1204 
1205 ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
1206   auto tidxArg = val.dyn_cast<BlockArgument>();
1207   if (!tidxArg)
1208     return ForeachThreadOp();
1209   assert(tidxArg.getOwner() && "unlinked block argument");
1210   auto *containingOp = tidxArg.getOwner()->getParentOp();
1211   return dyn_cast<ForeachThreadOp>(containingOp);
1212 }
1213 
1214 //===----------------------------------------------------------------------===//
1215 // ParallelInsertSliceOp
1216 //===----------------------------------------------------------------------===//
1217 
1218 OpResult ParallelInsertSliceOp::getTiedOpResult() {
1219   auto foreachThreadOp = getOperation()->getParentOfType<ForeachThreadOp>();
1220   assert(foreachThreadOp && "unlinked ParallelInsertSliceOp");
1221   PerformConcurrentlyOp performConcurrentlyOp = foreachThreadOp.getTerminator();
1222   for (const auto &it : llvm::enumerate(performConcurrentlyOp.yieldingOps())) {
1223     Operation &nextOp = it.value();
1224     if (&nextOp == getOperation())
1225       return foreachThreadOp->getResult(it.index());
1226   }
1227   llvm_unreachable("ParallelInsertSliceOp not found");
1228 }
1229 
1230 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
1231 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
1232                                   Value source, Value dest,
1233                                   ArrayRef<OpFoldResult> offsets,
1234                                   ArrayRef<OpFoldResult> sizes,
1235                                   ArrayRef<OpFoldResult> strides,
1236                                   ArrayRef<NamedAttribute> attrs) {
1237   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1238   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1239   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1240                              ShapedType::kDynamicStrideOrOffset);
1241   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1242                              ShapedType::kDynamicSize);
1243   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1244                              ShapedType::kDynamicStrideOrOffset);
1245   build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
1246         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1247         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1248   result.addAttributes(attrs);
1249 }
1250 
1251 // Build a ParallelInsertSliceOp with dynamic entries.
1252 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
1253                                   Value source, Value dest, ValueRange offsets,
1254                                   ValueRange sizes, ValueRange strides,
1255                                   ArrayRef<NamedAttribute> attrs) {
1256   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1257       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1258   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1259       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1260   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1261       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1262   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
1263 }
1264 
1265 namespace {
1266 /// Pattern to rewrite a parallel_insert_slice op with constant arguments.
1267 class ParallelInsertSliceOpConstantArgumentFolder final
1268     : public OpRewritePattern<ParallelInsertSliceOp> {
1269 public:
1270   using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
1271 
1272   LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
1273                                 PatternRewriter &rewriter) const override {
1274     // No constant operand, just return.
1275     if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
1276           return matchPattern(operand, matchConstantIndex());
1277         }))
1278       return failure();
1279 
1280     // At least one of offsets/sizes/strides is a new constant.
1281     // Form the new list of operands and constant attributes from the
1282     // existing.
1283     SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
1284     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
1285     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
1286     canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
1287     canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
1288     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
1289 
1290     // Create the new op in canonical form.
1291     rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
1292         insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
1293         mixedOffsets, mixedSizes, mixedStrides);
1294     return success();
1295   }
1296 };
1297 } // namespace
1298 
1299 /// Fold a parallel_insert_slice source coming from a tensor.cast op.
1300 ///
1301 /// Example:
1302 /// ```
1303 /// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
1304 ///   %1 = compute_some_tensor() : tensor<64xf32>
1305 ///   %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
1306 ///   scf.foreach_thread.perform_concurrently {
1307 ///     scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
1308 ///        tensor<?xf32> into tensor<128xf32>
1309 ///   }
1310 /// }
1311 /// ```
1312 ///
1313 /// is folded into:
1314 /// ```
1315 /// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
1316 ///   %1 = compute_some_tensor() : tensor<64xf32>
1317 ///   scf.foreach_thread.perform_concurrently {
1318 ///     scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
1319 ///        tensor<64xf32> into tensor<128xf32>
1320 ///   }
1321 /// }
1322 /// ```
1323 LogicalResult
1324 ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
1325                             SmallVectorImpl<OpFoldResult> &results) {
1326   auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
1327   if (!sourceCast)
1328     return failure();
1329   getSourceMutable().assign(sourceCast.getSource());
1330   return success();
1331 }
1332 
1333 void ParallelInsertSliceOp::getCanonicalizationPatterns(
1334     RewritePatternSet &results, MLIRContext *context) {
1335   results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
1336 }
1337 
1338 //===----------------------------------------------------------------------===//
1339 // PerformConcurrentlyOp
1340 //===----------------------------------------------------------------------===//
1341 
1342 // Build a PerformConcurrentlyOp with mixed static and dynamic entries.
1343 void PerformConcurrentlyOp::build(OpBuilder &b, OperationState &result) {
1344   OpBuilder::InsertionGuard g(b);
1345   Region *bodyRegion = result.addRegion();
1346   b.createBlock(bodyRegion);
1347 }
1348 
1349 LogicalResult PerformConcurrentlyOp::verify() {
1350   // TODO: PerformConcurrentlyOpInterface.
1351   for (const Operation &op : getRegion().front().getOperations())
1352     if (!isa<ParallelInsertSliceOp>(op))
1353       return emitOpError(
1354           "expected only scf.foreach_thread.parallel_insert_slice ops");
1355   return success();
1356 }
1357 
1358 void PerformConcurrentlyOp::print(OpAsmPrinter &p) {
1359   p << " ";
1360   p.printRegion(getRegion(),
1361                 /*printEntryBlockArgs=*/false,
1362                 /*printBlockTerminators=*/false);
1363   p.printOptionalAttrDict(getOperation()->getAttrs());
1364 }
1365 
1366 ParseResult PerformConcurrentlyOp::parse(OpAsmParser &parser,
1367                                          OperationState &result) {
1368   auto &builder = parser.getBuilder();
1369 
1370   SmallVector<OpAsmParser::Argument, 8> regionOperands;
1371   std::unique_ptr<Region> region = std::make_unique<Region>();
1372   if (parser.parseRegion(*region, regionOperands))
1373     return failure();
1374 
1375   if (region->empty())
1376     OpBuilder(builder.getContext()).createBlock(region.get());
1377   result.addRegion(std::move(region));
1378 
1379   // Parse the optional attribute list.
1380   if (parser.parseOptionalAttrDict(result.attributes))
1381     return failure();
1382   return success();
1383 }
1384 
1385 SmallVector<Type> PerformConcurrentlyOp::yieldedTypes() {
1386   return llvm::to_vector<4>(
1387       llvm::map_range(this->yieldingOps(), [](Operation &op) {
1388         auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(&op);
1389         return insertSliceOp ? insertSliceOp.yieldedType() : Type();
1390       }));
1391 }
1392 
1393 llvm::iterator_range<Block::iterator> PerformConcurrentlyOp::yieldingOps() {
1394   return getRegion().front().getOperations();
1395 }
1396 
1397 //===----------------------------------------------------------------------===//
1398 // IfOp
1399 //===----------------------------------------------------------------------===//
1400 
1401 bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
1402   assert(a && "expected non-empty operation");
1403   assert(b && "expected non-empty operation");
1404 
1405   IfOp ifOp = a->getParentOfType<IfOp>();
1406   while (ifOp) {
1407     // Check if b is inside ifOp. (We already know that a is.)
1408     if (ifOp->isProperAncestor(b))
1409       // b is contained in ifOp. a and b are in mutually exclusive branches if
1410       // they are in different blocks of ifOp.
1411       return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1412              static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1413     // Check next enclosing IfOp.
1414     ifOp = ifOp->getParentOfType<IfOp>();
1415   }
1416 
1417   // Could not find a common IfOp among a's and b's ancestors.
1418   return false;
1419 }
1420 
1421 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1422                  bool withElseRegion) {
1423   build(builder, result, /*resultTypes=*/llvm::None, cond, withElseRegion);
1424 }
1425 
1426 void IfOp::build(OpBuilder &builder, OperationState &result,
1427                  TypeRange resultTypes, Value cond, bool withElseRegion) {
1428   auto addTerminator = [&](OpBuilder &nested, Location loc) {
1429     if (resultTypes.empty())
1430       IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested,
1431                              loc);
1432   };
1433 
1434   build(builder, result, resultTypes, cond, addTerminator,
1435         withElseRegion ? addTerminator
1436                        : function_ref<void(OpBuilder &, Location)>());
1437 }
1438 
1439 void IfOp::build(OpBuilder &builder, OperationState &result,
1440                  TypeRange resultTypes, Value cond,
1441                  function_ref<void(OpBuilder &, Location)> thenBuilder,
1442                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1443   assert(thenBuilder && "the builder callback for 'then' must be present");
1444 
1445   result.addOperands(cond);
1446   result.addTypes(resultTypes);
1447 
1448   OpBuilder::InsertionGuard guard(builder);
1449   Region *thenRegion = result.addRegion();
1450   builder.createBlock(thenRegion);
1451   thenBuilder(builder, result.location);
1452 
1453   Region *elseRegion = result.addRegion();
1454   if (!elseBuilder)
1455     return;
1456 
1457   builder.createBlock(elseRegion);
1458   elseBuilder(builder, result.location);
1459 }
1460 
1461 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1462                  function_ref<void(OpBuilder &, Location)> thenBuilder,
1463                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1464   build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder);
1465 }
1466 
1467 LogicalResult IfOp::verify() {
1468   if (getNumResults() != 0 && getElseRegion().empty())
1469     return emitOpError("must have an else block if defining values");
1470   return success();
1471 }
1472 
1473 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
1474   // Create the regions for 'then'.
1475   result.regions.reserve(2);
1476   Region *thenRegion = result.addRegion();
1477   Region *elseRegion = result.addRegion();
1478 
1479   auto &builder = parser.getBuilder();
1480   OpAsmParser::UnresolvedOperand cond;
1481   Type i1Type = builder.getIntegerType(1);
1482   if (parser.parseOperand(cond) ||
1483       parser.resolveOperand(cond, i1Type, result.operands))
1484     return failure();
1485   // Parse optional results type list.
1486   if (parser.parseOptionalArrowTypeList(result.types))
1487     return failure();
1488   // Parse the 'then' region.
1489   if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1490     return failure();
1491   IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
1492 
1493   // If we find an 'else' keyword then parse the 'else' region.
1494   if (!parser.parseOptionalKeyword("else")) {
1495     if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1496       return failure();
1497     IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
1498   }
1499 
1500   // Parse the optional attribute list.
1501   if (parser.parseOptionalAttrDict(result.attributes))
1502     return failure();
1503   return success();
1504 }
1505 
1506 void IfOp::print(OpAsmPrinter &p) {
1507   bool printBlockTerminators = false;
1508 
1509   p << " " << getCondition();
1510   if (!getResults().empty()) {
1511     p << " -> (" << getResultTypes() << ")";
1512     // Print yield explicitly if the op defines values.
1513     printBlockTerminators = true;
1514   }
1515   p << ' ';
1516   p.printRegion(getThenRegion(),
1517                 /*printEntryBlockArgs=*/false,
1518                 /*printBlockTerminators=*/printBlockTerminators);
1519 
1520   // Print the 'else' regions if it exists and has a block.
1521   auto &elseRegion = getElseRegion();
1522   if (!elseRegion.empty()) {
1523     p << " else ";
1524     p.printRegion(elseRegion,
1525                   /*printEntryBlockArgs=*/false,
1526                   /*printBlockTerminators=*/printBlockTerminators);
1527   }
1528 
1529   p.printOptionalAttrDict((*this)->getAttrs());
1530 }
1531 
1532 /// Given the region at `index`, or the parent operation if `index` is None,
1533 /// return the successor regions. These are the regions that may be selected
1534 /// during the flow of control. `operands` is a set of optional attributes that
1535 /// correspond to a constant value for each operand, or null if that operand is
1536 /// not a constant.
1537 void IfOp::getSuccessorRegions(Optional<unsigned> index,
1538                                ArrayRef<Attribute> operands,
1539                                SmallVectorImpl<RegionSuccessor> &regions) {
1540   // The `then` and the `else` region branch back to the parent operation.
1541   if (index) {
1542     regions.push_back(RegionSuccessor(getResults()));
1543     return;
1544   }
1545 
1546   // Don't consider the else region if it is empty.
1547   Region *elseRegion = &this->getElseRegion();
1548   if (elseRegion->empty())
1549     elseRegion = nullptr;
1550 
1551   // Otherwise, the successor is dependent on the condition.
1552   bool condition;
1553   if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
1554     condition = condAttr.getValue().isOneValue();
1555   } else {
1556     // If the condition isn't constant, both regions may be executed.
1557     regions.push_back(RegionSuccessor(&getThenRegion()));
1558     // If the else region does not exist, it is not a viable successor.
1559     if (elseRegion)
1560       regions.push_back(RegionSuccessor(elseRegion));
1561     return;
1562   }
1563 
1564   // Add the successor regions using the condition.
1565   regions.push_back(RegionSuccessor(condition ? &getThenRegion() : elseRegion));
1566 }
1567 
1568 LogicalResult IfOp::fold(ArrayRef<Attribute> operands,
1569                          SmallVectorImpl<OpFoldResult> &results) {
1570   // if (!c) then A() else B() -> if c then B() else A()
1571   if (getElseRegion().empty())
1572     return failure();
1573 
1574   arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
1575   if (!xorStmt)
1576     return failure();
1577 
1578   if (!matchPattern(xorStmt.getRhs(), m_One()))
1579     return failure();
1580 
1581   getConditionMutable().assign(xorStmt.getLhs());
1582   Block *thenBlock = &getThenRegion().front();
1583   // It would be nicer to use iplist::swap, but that has no implemented
1584   // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
1585   getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
1586                                      getElseRegion().getBlocks());
1587   getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
1588                                      getThenRegion().getBlocks(), thenBlock);
1589   return success();
1590 }
1591 
1592 void IfOp::getRegionInvocationBounds(
1593     ArrayRef<Attribute> operands,
1594     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1595   if (auto cond = operands[0].dyn_cast_or_null<BoolAttr>()) {
1596     // If the condition is known, then one region is known to be executed once
1597     // and the other zero times.
1598     invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
1599     invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
1600   } else {
1601     // Non-constant condition. Each region may be executed 0 or 1 times.
1602     invocationBounds.assign(2, {0, 1});
1603   }
1604 }
1605 
1606 namespace {
1607 // Pattern to remove unused IfOp results.
1608 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
1609   using OpRewritePattern<IfOp>::OpRewritePattern;
1610 
1611   void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
1612                     PatternRewriter &rewriter) const {
1613     // Move all operations to the destination block.
1614     rewriter.mergeBlocks(source, dest);
1615     // Replace the yield op by one that returns only the used values.
1616     auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
1617     SmallVector<Value, 4> usedOperands;
1618     llvm::transform(usedResults, std::back_inserter(usedOperands),
1619                     [&](OpResult result) {
1620                       return yieldOp.getOperand(result.getResultNumber());
1621                     });
1622     rewriter.updateRootInPlace(yieldOp,
1623                                [&]() { yieldOp->setOperands(usedOperands); });
1624   }
1625 
1626   LogicalResult matchAndRewrite(IfOp op,
1627                                 PatternRewriter &rewriter) const override {
1628     // Compute the list of used results.
1629     SmallVector<OpResult, 4> usedResults;
1630     llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
1631                   [](OpResult result) { return !result.use_empty(); });
1632 
1633     // Replace the operation if only a subset of its results have uses.
1634     if (usedResults.size() == op.getNumResults())
1635       return failure();
1636 
1637     // Compute the result types of the replacement operation.
1638     SmallVector<Type, 4> newTypes;
1639     llvm::transform(usedResults, std::back_inserter(newTypes),
1640                     [](OpResult result) { return result.getType(); });
1641 
1642     // Create a replacement operation with empty then and else regions.
1643     auto emptyBuilder = [](OpBuilder &, Location) {};
1644     auto newOp = rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition(),
1645                                        emptyBuilder, emptyBuilder);
1646 
1647     // Move the bodies and replace the terminators (note there is a then and
1648     // an else region since the operation returns results).
1649     transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
1650     transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
1651 
1652     // Replace the operation by the new one.
1653     SmallVector<Value, 4> repResults(op.getNumResults());
1654     for (const auto &en : llvm::enumerate(usedResults))
1655       repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
1656     rewriter.replaceOp(op, repResults);
1657     return success();
1658   }
1659 };
1660 
1661 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
1662   using OpRewritePattern<IfOp>::OpRewritePattern;
1663 
1664   LogicalResult matchAndRewrite(IfOp op,
1665                                 PatternRewriter &rewriter) const override {
1666     auto constant = op.getCondition().getDefiningOp<arith::ConstantOp>();
1667     if (!constant)
1668       return failure();
1669 
1670     if (constant.getValue().cast<BoolAttr>().getValue())
1671       replaceOpWithRegion(rewriter, op, op.getThenRegion());
1672     else if (!op.getElseRegion().empty())
1673       replaceOpWithRegion(rewriter, op, op.getElseRegion());
1674     else
1675       rewriter.eraseOp(op);
1676 
1677     return success();
1678   }
1679 };
1680 
/// Hoist any yielded results whose operands are defined outside
/// the if, to a select instruction.
///
/// A result is kept on an scf.if ("non-hoistable") when the 'then' branch
/// yields a value defined directly in the 'then' region, or the 'else' branch
/// yields a value defined directly in the 'else' region. Every other result is
/// replaced by the yielded value itself when both branches yield the same
/// value, or otherwise by an arith.select on the condition. The bodies are
/// moved to a new scf.if that only carries the non-hoistable results.
struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    // Nothing to hoist without results.
    if (op->getNumResults() == 0)
      return failure();

    // The op has results, so (per IfOp::verify) both branch yields exist.
    auto cond = op.getCondition();
    auto thenYieldArgs = op.thenYield().getOperands();
    auto elseYieldArgs = op.elseYield().getOperands();

    // Collect the result types that must stay on an scf.if because a branch
    // yields a value defined in its own region.
    SmallVector<Type> nonHoistable;
    for (const auto &it :
         llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
      Value trueVal = std::get<0>(it.value());
      Value falseVal = std::get<1>(it.value());
      if (&op.getThenRegion() == trueVal.getParentRegion() ||
          &op.getElseRegion() == falseVal.getParentRegion())
        nonHoistable.push_back(trueVal.getType());
    }
    // Early exit if there aren't any yielded values we can
    // hoist outside the if.
    if (nonHoistable.size() == op->getNumResults())
      return failure();

    // Create the replacement carrying only the non-hoistable results, drop
    // the 'then' block the builder may have created, and move both original
    // bodies over wholesale.
    // NOTE(review): thenBlock() returns &getThenRegion().back(), so this null
    // check does not by itself guard against an empty 'then' region.
    IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond);
    if (replacement.thenBlock())
      rewriter.eraseBlock(replacement.thenBlock());
    replacement.getThenRegion().takeBody(op.getThenRegion());
    replacement.getElseRegion().takeBody(op.getElseRegion());

    // Replacement values for every result of the original op.
    SmallVector<Value> results(op->getNumResults());
    assert(thenYieldArgs.size() == results.size());
    assert(elseYieldArgs.size() == results.size());

    // Operands for the trimmed terminators of `replacement`.
    SmallVector<Value> trueYields;
    SmallVector<Value> falseYields;
    // Any selects are created right before `replacement`.
    rewriter.setInsertionPoint(replacement);
    // Same classification as above, now checked against `replacement`'s
    // regions since the bodies were moved there. Non-hoistable results map, in
    // order, onto `replacement`'s results; the rest are forwarded directly
    // (identical values) or selected on the condition.
    for (const auto &it :
         llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
      Value trueVal = std::get<0>(it.value());
      Value falseVal = std::get<1>(it.value());
      if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
          &replacement.getElseRegion() == falseVal.getParentRegion()) {
        results[it.index()] = replacement.getResult(trueYields.size());
        trueYields.push_back(trueVal);
        falseYields.push_back(falseVal);
      } else if (trueVal == falseVal)
        results[it.index()] = trueVal;
      else
        results[it.index()] = rewriter.create<arith::SelectOp>(
            op.getLoc(), cond, trueVal, falseVal);
    }

    // Rewrite both terminators to yield only the non-hoistable values.
    rewriter.setInsertionPointToEnd(replacement.thenBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);

    rewriter.setInsertionPointToEnd(replacement.elseBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);

    rewriter.replaceOp(op, results);
    return success();
  }
};
1748 
1749 /// Allow the true region of an if to assume the condition is true
1750 /// and vice versa. For example:
1751 ///
1752 ///   scf.if %cmp {
1753 ///      print(%cmp)
1754 ///   }
1755 ///
1756 ///  becomes
1757 ///
1758 ///   scf.if %cmp {
1759 ///      print(true)
1760 ///   }
1761 ///
1762 struct ConditionPropagation : public OpRewritePattern<IfOp> {
1763   using OpRewritePattern<IfOp>::OpRewritePattern;
1764 
1765   LogicalResult matchAndRewrite(IfOp op,
1766                                 PatternRewriter &rewriter) const override {
1767     // Early exit if the condition is constant since replacing a constant
1768     // in the body with another constant isn't a simplification.
1769     if (op.getCondition().getDefiningOp<arith::ConstantOp>())
1770       return failure();
1771 
1772     bool changed = false;
1773     mlir::Type i1Ty = rewriter.getI1Type();
1774 
1775     // These variables serve to prevent creating duplicate constants
1776     // and hold constant true or false values.
1777     Value constantTrue = nullptr;
1778     Value constantFalse = nullptr;
1779 
1780     for (OpOperand &use :
1781          llvm::make_early_inc_range(op.getCondition().getUses())) {
1782       if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
1783         changed = true;
1784 
1785         if (!constantTrue)
1786           constantTrue = rewriter.create<arith::ConstantOp>(
1787               op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
1788 
1789         rewriter.updateRootInPlace(use.getOwner(),
1790                                    [&]() { use.set(constantTrue); });
1791       } else if (op.getElseRegion().isAncestor(
1792                      use.getOwner()->getParentRegion())) {
1793         changed = true;
1794 
1795         if (!constantFalse)
1796           constantFalse = rewriter.create<arith::ConstantOp>(
1797               op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
1798 
1799         rewriter.updateRootInPlace(use.getOwner(),
1800                                    [&]() { use.set(constantFalse); });
1801       }
1802     }
1803 
1804     return success(changed);
1805   }
1806 };
1807 
1808 /// Remove any statements from an if that are equivalent to the condition
1809 /// or its negation. For example:
1810 ///
1811 ///    %res:2 = scf.if %cmp {
1812 ///       yield something(), true
1813 ///    } else {
1814 ///       yield something2(), false
1815 ///    }
1816 ///    print(%res#1)
1817 ///
1818 ///  becomes
1819 ///    %res = scf.if %cmp {
1820 ///       yield something()
1821 ///    } else {
1822 ///       yield something2()
1823 ///    }
1824 ///    print(%cmp)
1825 ///
1826 /// Additionally if both branches yield the same value, replace all uses
1827 /// of the result with the yielded value.
1828 ///
1829 ///    %res:2 = scf.if %cmp {
1830 ///       yield something(), %arg1
1831 ///    } else {
1832 ///       yield something2(), %arg1
1833 ///    }
1834 ///    print(%res#1)
1835 ///
1836 ///  becomes
1837 ///    %res = scf.if %cmp {
1838 ///       yield something()
1839 ///    } else {
1840 ///       yield something2()
1841 ///    }
1842 ///    print(%arg1)
1843 ///
1844 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
1845   using OpRewritePattern<IfOp>::OpRewritePattern;
1846 
1847   LogicalResult matchAndRewrite(IfOp op,
1848                                 PatternRewriter &rewriter) const override {
1849     // Early exit if there are no results that could be replaced.
1850     if (op.getNumResults() == 0)
1851       return failure();
1852 
1853     auto trueYield =
1854         cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
1855     auto falseYield =
1856         cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
1857 
1858     rewriter.setInsertionPoint(op->getBlock(),
1859                                op.getOperation()->getIterator());
1860     bool changed = false;
1861     Type i1Ty = rewriter.getI1Type();
1862     for (auto tup : llvm::zip(trueYield.getResults(), falseYield.getResults(),
1863                               op.getResults())) {
1864       Value trueResult, falseResult, opResult;
1865       std::tie(trueResult, falseResult, opResult) = tup;
1866 
1867       if (trueResult == falseResult) {
1868         if (!opResult.use_empty()) {
1869           opResult.replaceAllUsesWith(trueResult);
1870           changed = true;
1871         }
1872         continue;
1873       }
1874 
1875       auto trueYield = trueResult.getDefiningOp<arith::ConstantOp>();
1876       if (!trueYield)
1877         continue;
1878 
1879       if (!trueYield.getType().isInteger(1))
1880         continue;
1881 
1882       auto falseYield = falseResult.getDefiningOp<arith::ConstantOp>();
1883       if (!falseYield)
1884         continue;
1885 
1886       bool trueVal = trueYield.getValue().cast<BoolAttr>().getValue();
1887       bool falseVal = falseYield.getValue().cast<BoolAttr>().getValue();
1888       if (!trueVal && falseVal) {
1889         if (!opResult.use_empty()) {
1890           Value notCond = rewriter.create<arith::XOrIOp>(
1891               op.getLoc(), op.getCondition(),
1892               rewriter.create<arith::ConstantOp>(
1893                   op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)));
1894           opResult.replaceAllUsesWith(notCond);
1895           changed = true;
1896         }
1897       }
1898       if (trueVal && !falseVal) {
1899         if (!opResult.use_empty()) {
1900           opResult.replaceAllUsesWith(op.getCondition());
1901           changed = true;
1902         }
1903       }
1904     }
1905     return success(changed);
1906   }
1907 };
1908 
/// Merge any consecutive scf.if's with the same condition.
///
///    scf.if %cond {
///       firstCodeTrue();...
///    } else {
///       firstCodeFalse();...
///    }
///    %res = scf.if %cond {
///       secondCodeTrue();...
///    } else {
///       secondCodeFalse();...
///    }
///
///  becomes
///    %res = scf.if %cond {
///       firstCodeTrue();...
///       secondCodeTrue();...
///    } else {
///       firstCodeFalse();...
///       secondCodeFalse();...
///    }
///
/// Also handles the case where one condition is the other xor'ed with one
/// (a logical negation); nextIf's branches are then swapped when merging.
/// The merged op yields prevIf's results followed by nextIf's results.
struct CombineIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp nextIf,
                                PatternRewriter &rewriter) const override {
    // The pattern is anchored on the second op; it needs an immediately
    // preceding scf.if in the same block.
    Block *parent = nextIf->getBlock();
    if (nextIf == &parent->front())
      return failure();

    auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
    if (!prevIf)
      return failure();

    // Determine the logical then/else blocks when prevIf's
    // condition is used. Null means the block does not exist
    // in that case (e.g. empty else). If neither of these
    // are set, the two conditions cannot be compared.
    Block *nextThen = nullptr;
    Block *nextElse = nullptr;
    // Identical conditions: nextIf's branches keep their roles.
    if (nextIf.getCondition() == prevIf.getCondition()) {
      nextThen = nextIf.thenBlock();
      if (!nextIf.getElseRegion().empty())
        nextElse = nextIf.elseBlock();
    }
    // nextIf's condition negates prevIf's: the roles are swapped.
    if (arith::XOrIOp notv =
            nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == prevIf.getCondition() &&
          matchPattern(notv.getRhs(), m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }
    // prevIf's condition negates nextIf's: the roles are swapped as well.
    if (arith::XOrIOp notv =
            prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == nextIf.getCondition() &&
          matchPattern(notv.getRhs(), m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }

    if (!nextThen && !nextElse)
      return failure();

    // Values yielded by prevIf's 'else' branch (none if it has no 'else').
    SmallVector<Value> prevElseYielded;
    if (!prevIf.getElseRegion().empty())
      prevElseYielded = prevIf.elseYield().getOperands();
    // Replace all uses of return values of op within nextIf with the
    // corresponding yields
    // (uses in nextIf's logical 'then' get prevIf's 'then' yield, uses in its
    // logical 'else' get prevIf's 'else' yield; after merging, those values
    // are defined earlier in the same merged block).
    for (auto it : llvm::zip(prevIf.getResults(),
                             prevIf.thenYield().getOperands(), prevElseYielded))
      for (OpOperand &use :
           llvm::make_early_inc_range(std::get<0>(it).getUses())) {
        if (nextThen && nextThen->getParent()->isAncestor(
                            use.getOwner()->getParentRegion())) {
          rewriter.startRootUpdate(use.getOwner());
          use.set(std::get<1>(it));
          rewriter.finalizeRootUpdate(use.getOwner());
        } else if (nextElse && nextElse->getParent()->isAncestor(
                                   use.getOwner()->getParentRegion())) {
          rewriter.startRootUpdate(use.getOwner());
          use.set(std::get<2>(it));
          rewriter.finalizeRootUpdate(use.getOwner());
        }
      }

    // The merged op yields prevIf's results first, then nextIf's.
    SmallVector<Type> mergedTypes(prevIf.getResultTypes());
    llvm::append_range(mergedTypes, nextIf.getResultTypes());

    // Create the merged op on prevIf's condition and drop the 'then' block
    // the builder created, so prevIf's 'then' region can be moved in.
    IfOp combinedIf = rewriter.create<IfOp>(
        nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
    rewriter.eraseBlock(&combinedIf.getThenRegion().back());

    rewriter.inlineRegionBefore(prevIf.getThenRegion(),
                                combinedIf.getThenRegion(),
                                combinedIf.getThenRegion().begin());

    // Append nextIf's logical 'then' body after prevIf's and replace the two
    // terminators by one yield of the concatenated operands.
    if (nextThen) {
      YieldOp thenYield = combinedIf.thenYield();
      YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
      rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
      rewriter.setInsertionPointToEnd(combinedIf.thenBlock());

      SmallVector<Value> mergedYields(thenYield.getOperands());
      llvm::append_range(mergedYields, thenYield2.getOperands());
      rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
      rewriter.eraseOp(thenYield);
      rewriter.eraseOp(thenYield2);
    }

    rewriter.inlineRegionBefore(prevIf.getElseRegion(),
                                combinedIf.getElseRegion(),
                                combinedIf.getElseRegion().begin());

    if (nextElse) {
      // prevIf had no 'else': move nextIf's logical 'else' region in as is.
      if (combinedIf.getElseRegion().empty()) {
        rewriter.inlineRegionBefore(*nextElse->getParent(),
                                    combinedIf.getElseRegion(),
                                    combinedIf.getElseRegion().begin());
      } else {
        // Otherwise append it after prevIf's 'else' body and merge yields.
        YieldOp elseYield = combinedIf.elseYield();
        YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
        rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());

        rewriter.setInsertionPointToEnd(combinedIf.elseBlock());

        SmallVector<Value> mergedElseYields(elseYield.getOperands());
        llvm::append_range(mergedElseYields, elseYield2.getOperands());

        rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
        rewriter.eraseOp(elseYield);
        rewriter.eraseOp(elseYield2);
      }
    }

    // Split the merged results back: the first prevIf.getNumResults() belong
    // to prevIf, the remainder to nextIf.
    SmallVector<Value> prevValues;
    SmallVector<Value> nextValues;
    for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
      if (pair.index() < prevIf.getNumResults())
        prevValues.push_back(pair.value());
      else
        nextValues.push_back(pair.value());
    }
    rewriter.replaceOp(prevIf, prevValues);
    rewriter.replaceOp(nextIf, nextValues);
    return success();
  }
};
2060 
2061 /// Pattern to remove an empty else branch.
2062 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2063   using OpRewritePattern<IfOp>::OpRewritePattern;
2064 
2065   LogicalResult matchAndRewrite(IfOp ifOp,
2066                                 PatternRewriter &rewriter) const override {
2067     // Cannot remove else region when there are operation results.
2068     if (ifOp.getNumResults())
2069       return failure();
2070     Block *elseBlock = ifOp.elseBlock();
2071     if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2072       return failure();
2073     auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2074     rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2075                                 newIfOp.getThenRegion().begin());
2076     rewriter.eraseOp(ifOp);
2077     return success();
2078   }
2079 };
2080 
/// Convert nested `if`s into `arith.andi` + single `if`.
///
///    scf.if %arg0 {
///      scf.if %arg1 {
///        ...
///        scf.yield
///      }
///      scf.yield
///    }
///  becomes
///
///    %0 = arith.andi %arg0, %arg1
///    scf.if %0 {
///      ...
///      scf.yield
///    }
///
/// Applies only when the inner scf.if is the sole non-terminator op of the
/// outer 'then' block and every existing 'else' block (outer or inner) holds a
/// single op. Outer results that are not produced by the inner scf.if are
/// rebuilt as an arith.select on the outer condition.
struct CombineNestedIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    auto nestedOps = op.thenBlock()->without_terminator();
    // Nested `if` must be the only op in block.
    if (!llvm::hasSingleElement(nestedOps))
      return failure();

    // If there is an else block, it can only yield
    if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
      return failure();

    auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
    if (!nestedIf)
      return failure();

    // Same restriction for the inner 'else' block, if any.
    if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
      return failure();

    // Working copies of the outer yields; entries of thenYield are rewritten
    // below to the values the combined op should yield.
    SmallVector<Value> thenYield(op.thenYield().getOperands());
    SmallVector<Value> elseYield;
    if (op.elseBlock())
      llvm::append_range(elseYield, op.elseYield().getOperands());

    // A list of indices for which we should upgrade the value yielded
    // in the else to a select.
    SmallVector<unsigned> elseYieldsToUpgradeToSelect;

    // If the outer scf.if yields a value produced by the inner scf.if,
    // only permit combining if the value yielded when the condition
    // is false in the outer scf.if is the same value yielded when the
    // inner scf.if condition is false.
    // Note that the array access to elseYield will not go out of bounds
    // since it must have the same length as thenYield, since they both
    // come from the same scf.if.
    for (const auto &tup : llvm::enumerate(thenYield)) {
      if (tup.value().getDefiningOp() == nestedIf) {
        auto nestedIdx = tup.value().cast<OpResult>().getResultNumber();
        if (nestedIf.elseYield().getOperand(nestedIdx) !=
            elseYield[tup.index()]) {
          return failure();
        }
        // If the correctness test passes, we will yield
        // corresponding value from the inner scf.if
        thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
        continue;
      }

      // Otherwise, we need to ensure the else block of the combined
      // condition still returns the same value when the outer condition is
      // true and the inner condition is false. This can be accomplished if
      // the then value is defined outside the outer scf.if and we replace the
      // value with a select that considers just the outer condition. Since
      // the else region contains just the yield, its yielded value is
      // defined outside the scf.if, by definition.

      // If the then value is defined within the scf.if, bail.
      if (tup.value().getParentRegion() == &op.getThenRegion()) {
        return failure();
      }
      elseYieldsToUpgradeToSelect.push_back(tup.index());
    }

    // Combined condition, created at the rewriter's current insertion point,
    // followed by the combined scf.if with the outer op's result types.
    Location loc = op.getLoc();
    Value newCondition = rewriter.create<arith::AndIOp>(
        loc, op.getCondition(), nestedIf.getCondition());
    auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);

    // By default each outer result is replaced by the matching newIf result;
    // the selects built below (placed right before newIf) override some.
    SmallVector<Value> results;
    llvm::append_range(results, newIf.getResults());
    rewriter.setInsertionPoint(newIf);

    for (auto idx : elseYieldsToUpgradeToSelect)
      results[idx] = rewriter.create<arith::SelectOp>(
          op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);

    // Reuse newIf's 'then' block (minus its terminator) or create one, then
    // move the inner 'then' body into it.
    // NOTE(review): thenBlock() returns &getThenRegion().back(), so the null
    // check does not by itself guard against an empty 'then' region.
    Block *newIfBlock = newIf.thenBlock();
    if (newIfBlock)
      rewriter.eraseOp(newIfBlock->getTerminator());
    else
      newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
    rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
    // Replace the moved-in inner terminator with one yielding the combined
    // 'then' values.
    rewriter.setInsertionPointToEnd(newIf.thenBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
    // The 'else' branch of newIf (taken when the conjunction is false) yields
    // the outer 'else' values unchanged.
    if (!elseYield.empty()) {
      rewriter.createBlock(&newIf.getElseRegion());
      rewriter.setInsertionPointToEnd(newIf.elseBlock());
      rewriter.create<YieldOp>(loc, elseYield);
    }
    rewriter.replaceOp(op, results);
    return success();
  }
};
2192 
2193 } // namespace
2194 
2195 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2196                                        MLIRContext *context) {
2197   results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2198               ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2199               RemoveStaticCondition, RemoveUnusedResults,
2200               ReplaceIfYieldWithConditionOrValue>(context);
2201 }
2202 
2203 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2204 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2205 Block *IfOp::elseBlock() {
2206   Region &r = getElseRegion();
2207   if (r.empty())
2208     return nullptr;
2209   return &r.back();
2210 }
2211 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2212 
2213 //===----------------------------------------------------------------------===//
2214 // ParallelOp
2215 //===----------------------------------------------------------------------===//
2216 
2217 void ParallelOp::build(
2218     OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2219     ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2220     function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2221         bodyBuilderFn) {
2222   result.addOperands(lowerBounds);
2223   result.addOperands(upperBounds);
2224   result.addOperands(steps);
2225   result.addOperands(initVals);
2226   result.addAttribute(
2227       ParallelOp::getOperandSegmentSizeAttr(),
2228       builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()),
2229                                 static_cast<int32_t>(upperBounds.size()),
2230                                 static_cast<int32_t>(steps.size()),
2231                                 static_cast<int32_t>(initVals.size())}));
2232   result.addTypes(initVals.getTypes());
2233 
2234   OpBuilder::InsertionGuard guard(builder);
2235   unsigned numIVs = steps.size();
2236   SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2237   SmallVector<Location, 8> argLocs(numIVs, result.location);
2238   Region *bodyRegion = result.addRegion();
2239   Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2240 
2241   if (bodyBuilderFn) {
2242     builder.setInsertionPointToStart(bodyBlock);
2243     bodyBuilderFn(builder, result.location,
2244                   bodyBlock->getArguments().take_front(numIVs),
2245                   bodyBlock->getArguments().drop_front(numIVs));
2246   }
2247   ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2248 }
2249 
2250 void ParallelOp::build(
2251     OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2252     ValueRange upperBounds, ValueRange steps,
2253     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2254   // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2255   // we don't capture a reference to a temporary by constructing the lambda at
2256   // function level.
2257   auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2258                                            Location nestedLoc, ValueRange ivs,
2259                                            ValueRange) {
2260     bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2261   };
2262   function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2263   if (bodyBuilderFn)
2264     wrapper = wrappedBuilderFn;
2265 
2266   build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2267         wrapper);
2268 }
2269 
2270 LogicalResult ParallelOp::verify() {
2271   // Check that there is at least one value in lowerBound, upperBound and step.
2272   // It is sufficient to test only step, because it is ensured already that the
2273   // number of elements in lowerBound, upperBound and step are the same.
2274   Operation::operand_range stepValues = getStep();
2275   if (stepValues.empty())
2276     return emitOpError(
2277         "needs at least one tuple element for lowerBound, upperBound and step");
2278 
2279   // Check whether all constant step values are positive.
2280   for (Value stepValue : stepValues)
2281     if (auto cst = stepValue.getDefiningOp<arith::ConstantIndexOp>())
2282       if (cst.value() <= 0)
2283         return emitOpError("constant step operand must be positive");
2284 
2285   // Check that the body defines the same number of block arguments as the
2286   // number of tuple elements in step.
2287   Block *body = getBody();
2288   if (body->getNumArguments() != stepValues.size())
2289     return emitOpError() << "expects the same number of induction variables: "
2290                          << body->getNumArguments()
2291                          << " as bound and step values: " << stepValues.size();
2292   for (auto arg : body->getArguments())
2293     if (!arg.getType().isIndex())
2294       return emitOpError(
2295           "expects arguments for the induction variable to be of index type");
2296 
2297   // Check that the yield has no results
2298   Operation *yield = body->getTerminator();
2299   if (yield->getNumOperands() != 0)
2300     return yield->emitOpError() << "not allowed to have operands inside '"
2301                                 << ParallelOp::getOperationName() << "'";
2302 
2303   // Check that the number of results is the same as the number of ReduceOps.
2304   SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
2305   auto resultsSize = getResults().size();
2306   auto reductionsSize = reductions.size();
2307   auto initValsSize = getInitVals().size();
2308   if (resultsSize != reductionsSize)
2309     return emitOpError() << "expects number of results: " << resultsSize
2310                          << " to be the same as number of reductions: "
2311                          << reductionsSize;
2312   if (resultsSize != initValsSize)
2313     return emitOpError() << "expects number of results: " << resultsSize
2314                          << " to be the same as number of initial values: "
2315                          << initValsSize;
2316 
2317   // Check that the types of the results and reductions are the same.
2318   for (auto resultAndReduce : llvm::zip(getResults(), reductions)) {
2319     auto resultType = std::get<0>(resultAndReduce).getType();
2320     auto reduceOp = std::get<1>(resultAndReduce);
2321     auto reduceType = reduceOp.getOperand().getType();
2322     if (resultType != reduceType)
2323       return reduceOp.emitOpError()
2324              << "expects type of reduce: " << reduceType
2325              << " to be the same as result type: " << resultType;
2326   }
2327   return success();
2328 }
2329 
2330 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2331   auto &builder = parser.getBuilder();
2332   // Parse an opening `(` followed by induction variables followed by `)`
2333   SmallVector<OpAsmParser::Argument, 4> ivs;
2334   if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
2335     return failure();
2336 
2337   // Parse loop bounds.
2338   SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2339   if (parser.parseEqual() ||
2340       parser.parseOperandList(lower, ivs.size(),
2341                               OpAsmParser::Delimiter::Paren) ||
2342       parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2343     return failure();
2344 
2345   SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2346   if (parser.parseKeyword("to") ||
2347       parser.parseOperandList(upper, ivs.size(),
2348                               OpAsmParser::Delimiter::Paren) ||
2349       parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2350     return failure();
2351 
2352   // Parse step values.
2353   SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2354   if (parser.parseKeyword("step") ||
2355       parser.parseOperandList(steps, ivs.size(),
2356                               OpAsmParser::Delimiter::Paren) ||
2357       parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2358     return failure();
2359 
2360   // Parse init values.
2361   SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2362   if (succeeded(parser.parseOptionalKeyword("init"))) {
2363     if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2364       return failure();
2365   }
2366 
2367   // Parse optional results in case there is a reduce.
2368   if (parser.parseOptionalArrowTypeList(result.types))
2369     return failure();
2370 
2371   // Now parse the body.
2372   Region *body = result.addRegion();
2373   for (auto &iv : ivs)
2374     iv.type = builder.getIndexType();
2375   if (parser.parseRegion(*body, ivs))
2376     return failure();
2377 
2378   // Set `operand_segment_sizes` attribute.
2379   result.addAttribute(
2380       ParallelOp::getOperandSegmentSizeAttr(),
2381       builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
2382                                 static_cast<int32_t>(upper.size()),
2383                                 static_cast<int32_t>(steps.size()),
2384                                 static_cast<int32_t>(initVals.size())}));
2385 
2386   // Parse attributes.
2387   if (parser.parseOptionalAttrDict(result.attributes) ||
2388       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2389                              result.operands))
2390     return failure();
2391 
2392   // Add a terminator if none was parsed.
2393   ForOp::ensureTerminator(*body, builder, result.location);
2394   return success();
2395 }
2396 
2397 void ParallelOp::print(OpAsmPrinter &p) {
2398   p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
2399     << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
2400   if (!getInitVals().empty())
2401     p << " init (" << getInitVals() << ")";
2402   p.printOptionalArrowTypeList(getResultTypes());
2403   p << ' ';
2404   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
2405   p.printOptionalAttrDict(
2406       (*this)->getAttrs(),
2407       /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
2408 }
2409 
2410 Region &ParallelOp::getLoopBody() { return getRegion(); }
2411 
2412 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
2413   auto ivArg = val.dyn_cast<BlockArgument>();
2414   if (!ivArg)
2415     return ParallelOp();
2416   assert(ivArg.getOwner() && "unlinked block argument");
2417   auto *containingOp = ivArg.getOwner()->getParentOp();
2418   return dyn_cast<ParallelOp>(containingOp);
2419 }
2420 
2421 namespace {
2422 // Collapse loop dimensions that perform a single iteration.
2423 struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
2424   using OpRewritePattern<ParallelOp>::OpRewritePattern;
2425 
2426   LogicalResult matchAndRewrite(ParallelOp op,
2427                                 PatternRewriter &rewriter) const override {
2428     BlockAndValueMapping mapping;
2429     // Compute new loop bounds that omit all single-iteration loop dimensions.
2430     SmallVector<Value, 2> newLowerBounds;
2431     SmallVector<Value, 2> newUpperBounds;
2432     SmallVector<Value, 2> newSteps;
2433     newLowerBounds.reserve(op.getLowerBound().size());
2434     newUpperBounds.reserve(op.getUpperBound().size());
2435     newSteps.reserve(op.getStep().size());
2436     for (auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound(),
2437                               op.getStep(), op.getInductionVars())) {
2438       Value lowerBound, upperBound, step, iv;
2439       std::tie(lowerBound, upperBound, step, iv) = dim;
2440       // Collect the statically known loop bounds.
2441       auto lowerBoundConstant =
2442           dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
2443       auto upperBoundConstant =
2444           dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
2445       auto stepConstant =
2446           dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
2447       // Replace the loop induction variable by the lower bound if the loop
2448       // performs a single iteration. Otherwise, copy the loop bounds.
2449       if (lowerBoundConstant && upperBoundConstant && stepConstant &&
2450           (upperBoundConstant.value() - lowerBoundConstant.value()) > 0 &&
2451           (upperBoundConstant.value() - lowerBoundConstant.value()) <=
2452               stepConstant.value()) {
2453         mapping.map(iv, lowerBound);
2454       } else {
2455         newLowerBounds.push_back(lowerBound);
2456         newUpperBounds.push_back(upperBound);
2457         newSteps.push_back(step);
2458       }
2459     }
2460     // Exit if none of the loop dimensions perform a single iteration.
2461     if (newLowerBounds.size() == op.getLowerBound().size())
2462       return failure();
2463 
2464     if (newLowerBounds.empty()) {
2465       // All of the loop dimensions perform a single iteration. Inline
2466       // loop body and nested ReduceOp's
2467       SmallVector<Value> results;
2468       results.reserve(op.getInitVals().size());
2469       for (auto &bodyOp : op.getLoopBody().front().without_terminator()) {
2470         auto reduce = dyn_cast<ReduceOp>(bodyOp);
2471         if (!reduce) {
2472           rewriter.clone(bodyOp, mapping);
2473           continue;
2474         }
2475         Block &reduceBlock = reduce.getReductionOperator().front();
2476         auto initValIndex = results.size();
2477         mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
2478         mapping.map(reduceBlock.getArgument(1),
2479                     mapping.lookupOrDefault(reduce.getOperand()));
2480         for (auto &reduceBodyOp : reduceBlock.without_terminator())
2481           rewriter.clone(reduceBodyOp, mapping);
2482 
2483         auto result = mapping.lookupOrDefault(
2484             cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
2485         results.push_back(result);
2486       }
2487       rewriter.replaceOp(op, results);
2488       return success();
2489     }
2490     // Replace the parallel loop by lower-dimensional parallel loop.
2491     auto newOp =
2492         rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
2493                                     newSteps, op.getInitVals(), nullptr);
2494     // Clone the loop body and remap the block arguments of the collapsed loops
2495     // (inlining does not support a cancellable block argument mapping).
2496     rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
2497                                newOp.getRegion().begin(), mapping);
2498     rewriter.replaceOp(op, newOp.getResults());
2499     return success();
2500   }
2501 };
2502 
2503 /// Removes parallel loops in which at least one lower/upper bound pair consists
2504 /// of the same values - such loops have an empty iteration domain.
2505 struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
2506   using OpRewritePattern<ParallelOp>::OpRewritePattern;
2507 
2508   LogicalResult matchAndRewrite(ParallelOp op,
2509                                 PatternRewriter &rewriter) const override {
2510     for (auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound())) {
2511       if (std::get<0>(dim) == std::get<1>(dim)) {
2512         rewriter.replaceOp(op, op.getInitVals());
2513         return success();
2514       }
2515     }
2516     return failure();
2517   }
2518 };
2519 
2520 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
2521   using OpRewritePattern<ParallelOp>::OpRewritePattern;
2522 
2523   LogicalResult matchAndRewrite(ParallelOp op,
2524                                 PatternRewriter &rewriter) const override {
2525     Block &outerBody = op.getLoopBody().front();
2526     if (!llvm::hasSingleElement(outerBody.without_terminator()))
2527       return failure();
2528 
2529     auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
2530     if (!innerOp)
2531       return failure();
2532 
2533     for (auto val : outerBody.getArguments())
2534       if (llvm::is_contained(innerOp.getLowerBound(), val) ||
2535           llvm::is_contained(innerOp.getUpperBound(), val) ||
2536           llvm::is_contained(innerOp.getStep(), val))
2537         return failure();
2538 
2539     // Reductions are not supported yet.
2540     if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
2541       return failure();
2542 
2543     auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
2544                            ValueRange iterVals, ValueRange) {
2545       Block &innerBody = innerOp.getLoopBody().front();
2546       assert(iterVals.size() ==
2547              (outerBody.getNumArguments() + innerBody.getNumArguments()));
2548       BlockAndValueMapping mapping;
2549       mapping.map(outerBody.getArguments(),
2550                   iterVals.take_front(outerBody.getNumArguments()));
2551       mapping.map(innerBody.getArguments(),
2552                   iterVals.take_back(innerBody.getNumArguments()));
2553       for (Operation &op : innerBody.without_terminator())
2554         builder.clone(op, mapping);
2555     };
2556 
2557     auto concatValues = [](const auto &first, const auto &second) {
2558       SmallVector<Value> ret;
2559       ret.reserve(first.size() + second.size());
2560       ret.assign(first.begin(), first.end());
2561       ret.append(second.begin(), second.end());
2562       return ret;
2563     };
2564 
2565     auto newLowerBounds =
2566         concatValues(op.getLowerBound(), innerOp.getLowerBound());
2567     auto newUpperBounds =
2568         concatValues(op.getUpperBound(), innerOp.getUpperBound());
2569     auto newSteps = concatValues(op.getStep(), innerOp.getStep());
2570 
2571     rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
2572                                             newSteps, llvm::None, bodyBuilder);
2573     return success();
2574   }
2575 };
2576 
2577 } // namespace
2578 
2579 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
2580                                              MLIRContext *context) {
2581   results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops,
2582               MergeNestedParallelLoops>(context);
2583 }
2584 
2585 //===----------------------------------------------------------------------===//
2586 // ReduceOp
2587 //===----------------------------------------------------------------------===//
2588 
2589 void ReduceOp::build(
2590     OpBuilder &builder, OperationState &result, Value operand,
2591     function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuilderFn) {
2592   auto type = operand.getType();
2593   result.addOperands(operand);
2594 
2595   OpBuilder::InsertionGuard guard(builder);
2596   Region *bodyRegion = result.addRegion();
2597   Block *body = builder.createBlock(bodyRegion, {}, ArrayRef<Type>{type, type},
2598                                     {result.location, result.location});
2599   if (bodyBuilderFn)
2600     bodyBuilderFn(builder, result.location, body->getArgument(0),
2601                   body->getArgument(1));
2602 }
2603 
2604 LogicalResult ReduceOp::verifyRegions() {
2605   // The region of a ReduceOp has two arguments of the same type as its operand.
2606   auto type = getOperand().getType();
2607   Block &block = getReductionOperator().front();
2608   if (block.empty())
2609     return emitOpError("the block inside reduce should not be empty");
2610   if (block.getNumArguments() != 2 ||
2611       llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
2612         return arg.getType() != type;
2613       }))
2614     return emitOpError() << "expects two arguments to reduce block of type "
2615                          << type;
2616 
2617   // Check that the block is terminated by a ReduceReturnOp.
2618   if (!isa<ReduceReturnOp>(block.getTerminator()))
2619     return emitOpError("the block inside reduce should be terminated with a "
2620                        "'scf.reduce.return' op");
2621 
2622   return success();
2623 }
2624 
2625 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
2626   // Parse an opening `(` followed by the reduced value followed by `)`
2627   OpAsmParser::UnresolvedOperand operand;
2628   if (parser.parseLParen() || parser.parseOperand(operand) ||
2629       parser.parseRParen())
2630     return failure();
2631 
2632   Type resultType;
2633   // Parse the type of the operand (and also what reduce computes on).
2634   if (parser.parseColonType(resultType) ||
2635       parser.resolveOperand(operand, resultType, result.operands))
2636     return failure();
2637 
2638   // Now parse the body.
2639   Region *body = result.addRegion();
2640   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
2641     return failure();
2642 
2643   return success();
2644 }
2645 
2646 void ReduceOp::print(OpAsmPrinter &p) {
2647   p << "(" << getOperand() << ") ";
2648   p << " : " << getOperand().getType() << ' ';
2649   p.printRegion(getReductionOperator());
2650 }
2651 
2652 //===----------------------------------------------------------------------===//
2653 // ReduceReturnOp
2654 //===----------------------------------------------------------------------===//
2655 
2656 LogicalResult ReduceReturnOp::verify() {
2657   // The type of the return value should be the same type as the type of the
2658   // operand of the enclosing ReduceOp.
2659   auto reduceOp = cast<ReduceOp>((*this)->getParentOp());
2660   Type reduceType = reduceOp.getOperand().getType();
2661   if (reduceType != getResult().getType())
2662     return emitOpError() << "needs to have type " << reduceType
2663                          << " (the type of the enclosing ReduceOp)";
2664   return success();
2665 }
2666 
2667 //===----------------------------------------------------------------------===//
2668 // WhileOp
2669 //===----------------------------------------------------------------------===//
2670 
2671 OperandRange WhileOp::getSuccessorEntryOperands(Optional<unsigned> index) {
2672   assert(index && *index == 0 &&
2673          "WhileOp is expected to branch only to the first region");
2674 
2675   return getInits();
2676 }
2677 
2678 ConditionOp WhileOp::getConditionOp() {
2679   return cast<ConditionOp>(getBefore().front().getTerminator());
2680 }
2681 
2682 YieldOp WhileOp::getYieldOp() {
2683   return cast<YieldOp>(getAfter().front().getTerminator());
2684 }
2685 
2686 Block::BlockArgListType WhileOp::getBeforeArguments() {
2687   return getBefore().front().getArguments();
2688 }
2689 
2690 Block::BlockArgListType WhileOp::getAfterArguments() {
2691   return getAfter().front().getArguments();
2692 }
2693 
2694 void WhileOp::getSuccessorRegions(Optional<unsigned> index,
2695                                   ArrayRef<Attribute> operands,
2696                                   SmallVectorImpl<RegionSuccessor> &regions) {
2697   // The parent op always branches to the condition region.
2698   if (!index) {
2699     regions.emplace_back(&getBefore(), getBefore().getArguments());
2700     return;
2701   }
2702 
2703   assert(*index < 2 && "there are only two regions in a WhileOp");
2704   // The body region always branches back to the condition region.
2705   if (*index == 1) {
2706     regions.emplace_back(&getBefore(), getBefore().getArguments());
2707     return;
2708   }
2709 
2710   // Try to narrow the successor to the condition region.
2711   assert(!operands.empty() && "expected at least one operand");
2712   auto cond = operands[0].dyn_cast_or_null<BoolAttr>();
2713   if (!cond || !cond.getValue())
2714     regions.emplace_back(getResults());
2715   if (!cond || cond.getValue())
2716     regions.emplace_back(&getAfter(), getAfter().getArguments());
2717 }
2718 
2719 /// Parses a `while` op.
2720 ///
2721 /// op ::= `scf.while` assignments `:` function-type region `do` region
2722 ///         `attributes` attribute-dict
2723 /// initializer ::= /* empty */ | `(` assignment-list `)`
2724 /// assignment-list ::= assignment | assignment `,` assignment-list
2725 /// assignment ::= ssa-value `=` ssa-value
2726 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
2727   SmallVector<OpAsmParser::Argument, 4> regionArgs;
2728   SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2729   Region *before = result.addRegion();
2730   Region *after = result.addRegion();
2731 
2732   OptionalParseResult listResult =
2733       parser.parseOptionalAssignmentList(regionArgs, operands);
2734   if (listResult.hasValue() && failed(listResult.getValue()))
2735     return failure();
2736 
2737   FunctionType functionType;
2738   SMLoc typeLoc = parser.getCurrentLocation();
2739   if (failed(parser.parseColonType(functionType)))
2740     return failure();
2741 
2742   result.addTypes(functionType.getResults());
2743 
2744   if (functionType.getNumInputs() != operands.size()) {
2745     return parser.emitError(typeLoc)
2746            << "expected as many input types as operands "
2747            << "(expected " << operands.size() << " got "
2748            << functionType.getNumInputs() << ")";
2749   }
2750 
2751   // Resolve input operands.
2752   if (failed(parser.resolveOperands(operands, functionType.getInputs(),
2753                                     parser.getCurrentLocation(),
2754                                     result.operands)))
2755     return failure();
2756 
2757   // Propagate the types into the region arguments.
2758   for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
2759     regionArgs[i].type = functionType.getInput(i);
2760 
2761   return failure(parser.parseRegion(*before, regionArgs) ||
2762                  parser.parseKeyword("do") || parser.parseRegion(*after) ||
2763                  parser.parseOptionalAttrDictWithKeyword(result.attributes));
2764 }
2765 
2766 /// Prints a `while` op.
2767 void scf::WhileOp::print(OpAsmPrinter &p) {
2768   printInitializationList(p, getBefore().front().getArguments(), getInits(),
2769                           " ");
2770   p << " : ";
2771   p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
2772   p << ' ';
2773   p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
2774   p << " do ";
2775   p.printRegion(getAfter());
2776   p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
2777 }
2778 
2779 /// Verifies that two ranges of types match, i.e. have the same number of
2780 /// entries and that types are pairwise equals. Reports errors on the given
2781 /// operation in case of mismatch.
2782 template <typename OpTy>
2783 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
2784                                            TypeRange right, StringRef message) {
2785   if (left.size() != right.size())
2786     return op.emitOpError("expects the same number of ") << message;
2787 
2788   for (unsigned i = 0, e = left.size(); i < e; ++i) {
2789     if (left[i] != right[i]) {
2790       InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
2791                                 << message;
2792       diag.attachNote() << "for argument " << i << ", found " << left[i]
2793                         << " and " << right[i];
2794       return diag;
2795     }
2796   }
2797 
2798   return success();
2799 }
2800 
2801 /// Verifies that the first block of the given `region` is terminated by a
2802 /// YieldOp. Reports errors on the given operation if it is not the case.
2803 template <typename TerminatorTy>
2804 static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region &region,
2805                                            StringRef errorMessage) {
2806   Operation *terminatorOperation = region.front().getTerminator();
2807   if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
2808     return yield;
2809 
2810   auto diag = op.emitOpError(errorMessage);
2811   if (terminatorOperation)
2812     diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
2813   return nullptr;
2814 }
2815 
2816 LogicalResult scf::WhileOp::verify() {
2817   auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
2818       *this, getBefore(),
2819       "expects the 'before' region to terminate with 'scf.condition'");
2820   if (!beforeTerminator)
2821     return failure();
2822 
2823   auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
2824       *this, getAfter(),
2825       "expects the 'after' region to terminate with 'scf.yield'");
2826   return success(afterTerminator != nullptr);
2827 }
2828 
2829 namespace {
2830 /// Replace uses of the condition within the do block with true, since otherwise
2831 /// the block would not be evaluated.
2832 ///
2833 /// scf.while (..) : (i1, ...) -> ... {
2834 ///  %condition = call @evaluate_condition() : () -> i1
2835 ///  scf.condition(%condition) %condition : i1, ...
2836 /// } do {
2837 /// ^bb0(%arg0: i1, ...):
2838 ///    use(%arg0)
2839 ///    ...
2840 ///
2841 /// becomes
2842 /// scf.while (..) : (i1, ...) -> ... {
2843 ///  %condition = call @evaluate_condition() : () -> i1
2844 ///  scf.condition(%condition) %condition : i1, ...
2845 /// } do {
2846 /// ^bb0(%arg0: i1, ...):
2847 ///    use(%true)
2848 ///    ...
2849 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
2850   using OpRewritePattern<WhileOp>::OpRewritePattern;
2851 
2852   LogicalResult matchAndRewrite(WhileOp op,
2853                                 PatternRewriter &rewriter) const override {
2854     auto term = op.getConditionOp();
2855 
2856     // These variables serve to prevent creating duplicate constants
2857     // and hold constant true or false values.
2858     Value constantTrue = nullptr;
2859 
2860     bool replaced = false;
2861     for (auto yieldedAndBlockArgs :
2862          llvm::zip(term.getArgs(), op.getAfterArguments())) {
2863       if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
2864         if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
2865           if (!constantTrue)
2866             constantTrue = rewriter.create<arith::ConstantOp>(
2867                 op.getLoc(), term.getCondition().getType(),
2868                 rewriter.getBoolAttr(true));
2869 
2870           std::get<1>(yieldedAndBlockArgs).replaceAllUsesWith(constantTrue);
2871           replaced = true;
2872         }
2873       }
2874     }
2875     return success(replaced);
2876   }
2877 };
2878 
2879 /// Remove loop invariant arguments from `before` block of scf.while.
2880 /// A before block argument is considered loop invariant if :-
2881 ///   1. i-th yield operand is equal to the i-th while operand.
2882 ///   2. i-th yield operand is k-th after block argument which is (k+1)-th
2883 ///      condition operand AND this (k+1)-th condition operand is equal to i-th
2884 ///      iter argument/while operand.
2885 /// For the arguments which are removed, their uses inside scf.while
2886 /// are replaced with their corresponding initial value.
2887 ///
2888 /// Eg:
2889 ///    INPUT :-
2890 ///    %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
2891 ///                                     ..., %argN_before = %N)
2892 ///           {
2893 ///                ...
2894 ///                scf.condition(%cond) %arg1_before, %arg0_before,
2895 ///                                     %arg2_before, %arg0_before, ...
2896 ///           } do {
2897 ///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
2898 ///                  ..., %argK_after):
2899 ///                ...
2900 ///                scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
2901 ///           }
2902 ///
2903 ///    OUTPUT :-
2904 ///    %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
2905 ///                                     %N)
2906 ///           {
2907 ///                ...
2908 ///                scf.condition(%cond) %b, %a, %arg2_before, %a, ...
2909 ///           } do {
2910 ///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
2911 ///                  ..., %argK_after):
2912 ///                ...
2913 ///                scf.yield %arg1_after, ..., %argN
2914 ///           }
2915 ///
2916 ///    EXPLANATION:
2917 ///      We iterate over each yield operand.
2918 ///        1. 0-th yield operand %arg0_after_2 is 4-th condition operand
2919 ///           %arg0_before, which in turn is the 0-th iter argument. So we
2920 ///           remove 0-th before block argument and yield operand, and replace
2921 ///           all uses of the 0-th before block argument with its initial value
2922 ///           %a.
2923 ///        2. 1-th yield operand %b is equal to the 1-th iter arg's initial
2924 ///           value. So we remove this operand and the corresponding before
2925 ///           block argument and replace all uses of 1-th before block argument
2926 ///           with %b.
2927 struct RemoveLoopInvariantArgsFromBeforeBlock
2928     : public OpRewritePattern<WhileOp> {
2929   using OpRewritePattern<WhileOp>::OpRewritePattern;
2930 
2931   LogicalResult matchAndRewrite(WhileOp op,
2932                                 PatternRewriter &rewriter) const override {
2933     Block &afterBlock = op.getAfter().front();
2934     Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
2935     ConditionOp condOp = op.getConditionOp();
2936     OperandRange condOpArgs = condOp.getArgs();
2937     Operation *yieldOp = afterBlock.getTerminator();
2938     ValueRange yieldOpArgs = yieldOp->getOperands();
2939 
2940     bool canSimplify = false;
2941     for (const auto &it :
2942          llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
2943       auto index = static_cast<unsigned>(it.index());
2944       Value initVal, yieldOpArg;
2945       std::tie(initVal, yieldOpArg) = it.value();
2946       // If i-th yield operand is equal to the i-th operand of the scf.while,
2947       // the i-th before block argument is a loop invariant.
2948       if (yieldOpArg == initVal) {
2949         canSimplify = true;
2950         break;
2951       }
2952       // If the i-th yield operand is k-th after block argument, then we check
2953       // if the (k+1)-th condition op operand is equal to either the i-th before
2954       // block argument or the initial value of i-th before block argument. If
2955       // the comparison results `true`, i-th before block argument is a loop
2956       // invariant.
2957       auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
2958       if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
2959         Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
2960         if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
2961           canSimplify = true;
2962           break;
2963         }
2964       }
2965     }
2966 
2967     if (!canSimplify)
2968       return failure();
2969 
2970     SmallVector<Value> newInitArgs, newYieldOpArgs;
2971     DenseMap<unsigned, Value> beforeBlockInitValMap;
2972     SmallVector<Location> newBeforeBlockArgLocs;
2973     for (const auto &it :
2974          llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
2975       auto index = static_cast<unsigned>(it.index());
2976       Value initVal, yieldOpArg;
2977       std::tie(initVal, yieldOpArg) = it.value();
2978 
2979       // If i-th yield operand is equal to the i-th operand of the scf.while,
2980       // the i-th before block argument is a loop invariant.
2981       if (yieldOpArg == initVal) {
2982         beforeBlockInitValMap.insert({index, initVal});
2983         continue;
2984       } else {
2985         // If the i-th yield operand is k-th after block argument, then we check
2986         // if the (k+1)-th condition op operand is equal to either the i-th
2987         // before block argument or the initial value of i-th before block
2988         // argument. If the comparison results `true`, i-th before block
2989         // argument is a loop invariant.
2990         auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
2991         if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
2992           Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
2993           if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
2994             beforeBlockInitValMap.insert({index, initVal});
2995             continue;
2996           }
2997         }
2998       }
2999       newInitArgs.emplace_back(initVal);
3000       newYieldOpArgs.emplace_back(yieldOpArg);
3001       newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3002     }
3003 
3004     {
3005       OpBuilder::InsertionGuard g(rewriter);
3006       rewriter.setInsertionPoint(yieldOp);
3007       rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3008     }
3009 
3010     auto newWhile =
3011         rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3012 
3013     Block &newBeforeBlock = *rewriter.createBlock(
3014         &newWhile.getBefore(), /*insertPt*/ {},
3015         ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3016 
3017     Block &beforeBlock = op.getBefore().front();
3018     SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3019     // For each i-th before block argument we find it's replacement value as :-
3020     //   1. If i-th before block argument is a loop invariant, we fetch it's
3021     //      initial value from `beforeBlockInitValMap` by querying for key `i`.
3022     //   2. Else we fetch j-th new before block argument as the replacement
3023     //      value of i-th before block argument.
3024     for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3025       // If the index 'i' argument was a loop invariant we fetch it's initial
3026       // value from `beforeBlockInitValMap`.
3027       if (beforeBlockInitValMap.count(i) != 0)
3028         newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3029       else
3030         newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3031     }
3032 
3033     rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3034     rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3035                                 newWhile.getAfter().begin());
3036 
3037     rewriter.replaceOp(op, newWhile.getResults());
3038     return success();
3039   }
3040 };
3041 
3042 /// Remove loop invariant value from result (condition op) of scf.while.
3043 /// A value is considered loop invariant if the final value yielded by
3044 /// scf.condition is defined outside of the `before` block. We remove the
3045 /// corresponding argument in `after` block and replace the use with the value.
3046 /// We also replace the use of the corresponding result of scf.while with the
3047 /// value.
3048 ///
3049 /// Eg:
3050 ///    INPUT :-
3051 ///    %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3052 ///                                             %argN_before = %N) {
3053 ///                ...
3054 ///                scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3055 ///           } do {
3056 ///             ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3057 ///                ...
3058 ///                some_func(%arg1_after)
3059 ///                ...
3060 ///                scf.yield %arg0_after, %arg2_after, ..., %argN_after
3061 ///           }
3062 ///
3063 ///    OUTPUT :-
3064 ///    %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3065 ///                ...
3066 ///                scf.condition(%cond) %arg0, %arg1, ..., %argM
3067 ///           } do {
3068 ///             ^bb0(%arg0, %arg3, ..., %argM):
3069 ///                ...
3070 ///                some_func(%a)
3071 ///                ...
3072 ///                scf.yield %arg0, %b, ..., %argN
3073 ///           }
3074 ///
3075 ///     EXPLANATION:
3076 ///       1. The 1-th and 2-th operand of scf.condition are defined outside the
3077 ///          before block of scf.while, so they get removed.
3078 ///       2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3079 ///          replaced by %b.
3080 ///       3. The corresponding after block argument %arg1_after's uses are
3081 ///          replaced by %a and %arg2_after's uses are replaced by %b.
3082 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3083   using OpRewritePattern<WhileOp>::OpRewritePattern;
3084 
3085   LogicalResult matchAndRewrite(WhileOp op,
3086                                 PatternRewriter &rewriter) const override {
3087     Block &beforeBlock = op.getBefore().front();
3088     ConditionOp condOp = op.getConditionOp();
3089     OperandRange condOpArgs = condOp.getArgs();
3090 
3091     bool canSimplify = false;
3092     for (Value condOpArg : condOpArgs) {
3093       // Those values not defined within `before` block will be considered as
3094       // loop invariant values. We map the corresponding `index` with their
3095       // value.
3096       if (condOpArg.getParentBlock() != &beforeBlock) {
3097         canSimplify = true;
3098         break;
3099       }
3100     }
3101 
3102     if (!canSimplify)
3103       return failure();
3104 
3105     Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3106 
3107     SmallVector<Value> newCondOpArgs;
3108     SmallVector<Type> newAfterBlockType;
3109     DenseMap<unsigned, Value> condOpInitValMap;
3110     SmallVector<Location> newAfterBlockArgLocs;
3111     for (const auto &it : llvm::enumerate(condOpArgs)) {
3112       auto index = static_cast<unsigned>(it.index());
3113       Value condOpArg = it.value();
3114       // Those values not defined within `before` block will be considered as
3115       // loop invariant values. We map the corresponding `index` with their
3116       // value.
3117       if (condOpArg.getParentBlock() != &beforeBlock) {
3118         condOpInitValMap.insert({index, condOpArg});
3119       } else {
3120         newCondOpArgs.emplace_back(condOpArg);
3121         newAfterBlockType.emplace_back(condOpArg.getType());
3122         newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3123       }
3124     }
3125 
3126     {
3127       OpBuilder::InsertionGuard g(rewriter);
3128       rewriter.setInsertionPoint(condOp);
3129       rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3130                                                newCondOpArgs);
3131     }
3132 
3133     auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3134                                              op.getOperands());
3135 
3136     Block &newAfterBlock =
3137         *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3138                               newAfterBlockType, newAfterBlockArgLocs);
3139 
3140     Block &afterBlock = op.getAfter().front();
3141     // Since a new scf.condition op was created, we need to fetch the new
3142     // `after` block arguments which will be used while replacing operations of
3143     // previous scf.while's `after` blocks. We'd also be fetching new result
3144     // values too.
3145     SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3146     SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3147     for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3148       Value afterBlockArg, result;
3149       // If index 'i' argument was loop invariant we fetch it's value from the
3150       // `condOpInitMap` map.
3151       if (condOpInitValMap.count(i) != 0) {
3152         afterBlockArg = condOpInitValMap[i];
3153         result = afterBlockArg;
3154       } else {
3155         afterBlockArg = newAfterBlock.getArgument(j);
3156         result = newWhile.getResult(j);
3157         j++;
3158       }
3159       newAfterBlockArgs[i] = afterBlockArg;
3160       newWhileResults[i] = result;
3161     }
3162 
3163     rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3164     rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3165                                 newWhile.getBefore().begin());
3166 
3167     rewriter.replaceOp(op, newWhileResults);
3168     return success();
3169   }
3170 };
3171 
3172 /// Remove WhileOp results that are also unused in 'after' block.
3173 ///
3174 ///  %0:2 = scf.while () : () -> (i32, i64) {
3175 ///    %condition = "test.condition"() : () -> i1
3176 ///    %v1 = "test.get_some_value"() : () -> i32
3177 ///    %v2 = "test.get_some_value"() : () -> i64
3178 ///    scf.condition(%condition) %v1, %v2 : i32, i64
3179 ///  } do {
3180 ///  ^bb0(%arg0: i32, %arg1: i64):
3181 ///    "test.use"(%arg0) : (i32) -> ()
3182 ///    scf.yield
3183 ///  }
3184 ///  return %0#0 : i32
3185 ///
3186 /// becomes
3187 ///  %0 = scf.while () : () -> (i32) {
3188 ///    %condition = "test.condition"() : () -> i1
3189 ///    %v1 = "test.get_some_value"() : () -> i32
3190 ///    %v2 = "test.get_some_value"() : () -> i64
3191 ///    scf.condition(%condition) %v1 : i32
3192 ///  } do {
3193 ///  ^bb0(%arg0: i32):
3194 ///    "test.use"(%arg0) : (i32) -> ()
3195 ///    scf.yield
3196 ///  }
3197 ///  return %0 : i32
3198 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3199   using OpRewritePattern<WhileOp>::OpRewritePattern;
3200 
3201   LogicalResult matchAndRewrite(WhileOp op,
3202                                 PatternRewriter &rewriter) const override {
3203     auto term = op.getConditionOp();
3204     auto afterArgs = op.getAfterArguments();
3205     auto termArgs = term.getArgs();
3206 
3207     // Collect results mapping, new terminator args and new result types.
3208     SmallVector<unsigned> newResultsIndices;
3209     SmallVector<Type> newResultTypes;
3210     SmallVector<Value> newTermArgs;
3211     SmallVector<Location> newArgLocs;
3212     bool needUpdate = false;
3213     for (const auto &it :
3214          llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3215       auto i = static_cast<unsigned>(it.index());
3216       Value result = std::get<0>(it.value());
3217       Value afterArg = std::get<1>(it.value());
3218       Value termArg = std::get<2>(it.value());
3219       if (result.use_empty() && afterArg.use_empty()) {
3220         needUpdate = true;
3221       } else {
3222         newResultsIndices.emplace_back(i);
3223         newTermArgs.emplace_back(termArg);
3224         newResultTypes.emplace_back(result.getType());
3225         newArgLocs.emplace_back(result.getLoc());
3226       }
3227     }
3228 
3229     if (!needUpdate)
3230       return failure();
3231 
3232     {
3233       OpBuilder::InsertionGuard g(rewriter);
3234       rewriter.setInsertionPoint(term);
3235       rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3236                                                newTermArgs);
3237     }
3238 
3239     auto newWhile =
3240         rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3241 
3242     Block &newAfterBlock = *rewriter.createBlock(
3243         &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3244 
3245     // Build new results list and new after block args (unused entries will be
3246     // null).
3247     SmallVector<Value> newResults(op.getNumResults());
3248     SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3249     for (const auto &it : llvm::enumerate(newResultsIndices)) {
3250       newResults[it.value()] = newWhile.getResult(it.index());
3251       newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3252     }
3253 
3254     rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3255                                 newWhile.getBefore().begin());
3256 
3257     Block &afterBlock = op.getAfter().front();
3258     rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3259 
3260     rewriter.replaceOp(op, newResults);
3261     return success();
3262   }
3263 };
3264 
3265 /// Replace operations equivalent to the condition in the do block with true,
3266 /// since otherwise the block would not be evaluated.
3267 ///
3268 /// scf.while (..) : (i32, ...) -> ... {
3269 ///  %z = ... : i32
3270 ///  %condition = cmpi pred %z, %a
3271 ///  scf.condition(%condition) %z : i32, ...
3272 /// } do {
3273 /// ^bb0(%arg0: i32, ...):
3274 ///    %condition2 = cmpi pred %arg0, %a
3275 ///    use(%condition2)
3276 ///    ...
3277 ///
3278 /// becomes
3279 /// scf.while (..) : (i32, ...) -> ... {
3280 ///  %z = ... : i32
3281 ///  %condition = cmpi pred %z, %a
3282 ///  scf.condition(%condition) %z : i32, ...
3283 /// } do {
3284 /// ^bb0(%arg0: i32, ...):
3285 ///    use(%true)
3286 ///    ...
3287 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3288   using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3289 
3290   LogicalResult matchAndRewrite(scf::WhileOp op,
3291                                 PatternRewriter &rewriter) const override {
3292     using namespace scf;
3293     auto cond = op.getConditionOp();
3294     auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3295     if (!cmp)
3296       return failure();
3297     bool changed = false;
3298     for (auto tup :
3299          llvm::zip(cond.getArgs(), op.getAfter().front().getArguments())) {
3300       for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3301         if (std::get<0>(tup) != cmp.getOperand(opIdx))
3302           continue;
3303         for (OpOperand &u :
3304              llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3305           auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3306           if (!cmp2)
3307             continue;
3308           // For a binary operator 1-opIdx gets the other side.
3309           if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3310             continue;
3311           bool samePredicate;
3312           if (cmp2.getPredicate() == cmp.getPredicate())
3313             samePredicate = true;
3314           else if (cmp2.getPredicate() ==
3315                    arith::invertPredicate(cmp.getPredicate()))
3316             samePredicate = false;
3317           else
3318             continue;
3319 
3320           rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3321                                                             1);
3322           changed = true;
3323         }
3324       }
3325     }
3326     return success(changed);
3327   }
3328 };
3329 
3330 struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
3331   using OpRewritePattern<WhileOp>::OpRewritePattern;
3332 
3333   LogicalResult matchAndRewrite(WhileOp op,
3334                                 PatternRewriter &rewriter) const override {
3335 
3336     if (!llvm::any_of(op.getBeforeArguments(),
3337                       [](Value arg) { return arg.use_empty(); }))
3338       return failure();
3339 
3340     YieldOp yield = op.getYieldOp();
3341 
3342     // Collect results mapping, new terminator args and new result types.
3343     SmallVector<Value> newYields;
3344     SmallVector<Value> newInits;
3345     SmallVector<unsigned> argsToErase;
3346     for (const auto &it : llvm::enumerate(llvm::zip(
3347              op.getBeforeArguments(), yield.getOperands(), op.getInits()))) {
3348       Value beforeArg = std::get<0>(it.value());
3349       Value yieldValue = std::get<1>(it.value());
3350       Value initValue = std::get<2>(it.value());
3351       if (beforeArg.use_empty()) {
3352         argsToErase.push_back(it.index());
3353       } else {
3354         newYields.emplace_back(yieldValue);
3355         newInits.emplace_back(initValue);
3356       }
3357     }
3358 
3359     if (argsToErase.empty())
3360       return failure();
3361 
3362     rewriter.startRootUpdate(op);
3363     op.getBefore().front().eraseArguments(argsToErase);
3364     rewriter.finalizeRootUpdate(op);
3365 
3366     WhileOp replacement =
3367         rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInits);
3368     replacement.getBefore().takeBody(op.getBefore());
3369     replacement.getAfter().takeBody(op.getAfter());
3370     rewriter.replaceOp(op, replacement.getResults());
3371 
3372     rewriter.setInsertionPoint(yield);
3373     rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
3374     return success();
3375   }
3376 };
3377 } // namespace
3378 
3379 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
3380                                           MLIRContext *context) {
3381   results.add<RemoveLoopInvariantArgsFromBeforeBlock,
3382               RemoveLoopInvariantValueYielded, WhileConditionTruth,
3383               WhileCmpCond, WhileUnusedResult>(context);
3384 }
3385 
3386 //===----------------------------------------------------------------------===//
3387 // TableGen'd op method definitions
3388 //===----------------------------------------------------------------------===//
3389 
3390 #define GET_OP_CLASSES
3391 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
3392