1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/IR/SCF.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BlockAndValueMapping.h"
17 #include "mlir/IR/FunctionInterfaces.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Support/MathExtras.h"
21 #include "mlir/Transforms/InliningUtils.h"
22 
23 using namespace mlir;
24 using namespace mlir::scf;
25 
26 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
27 
28 //===----------------------------------------------------------------------===//
29 // SCFDialect Dialect Interfaces
30 //===----------------------------------------------------------------------===//
31 
32 namespace {
33 struct SCFInlinerInterface : public DialectInlinerInterface {
34   using DialectInlinerInterface::DialectInlinerInterface;
35   // We don't have any special restrictions on what can be inlined into
36   // destination regions (e.g. while/conditional bodies). Always allow it.
37   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
38                        BlockAndValueMapping &valueMapping) const final {
39     return true;
40   }
41   // Operations in scf dialect are always legal to inline since they are
42   // pure.
43   bool isLegalToInline(Operation *, Region *, bool,
44                        BlockAndValueMapping &) const final {
45     return true;
46   }
47   // Handle the given inlined terminator by replacing it with a new operation
48   // as necessary. Required when the region has only one block.
49   void handleTerminator(Operation *op,
50                         ArrayRef<Value> valuesToRepl) const final {
51     auto retValOp = dyn_cast<scf::YieldOp>(op);
52     if (!retValOp)
53       return;
54 
55     for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
56       std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
57     }
58   }
59 };
60 } // namespace
61 
62 //===----------------------------------------------------------------------===//
63 // SCFDialect
64 //===----------------------------------------------------------------------===//
65 
void SCFDialect::initialize() {
  // Register all SCF operations generated from the ODS definitions, then the
  // inliner interface so SCF regions participate in inlining.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
      >();
  addInterfaces<SCFInlinerInterface>();
}
73 
/// Default callback for IfOp builders. Inserts a yield without arguments.
void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
  // An operand-less scf.yield is the canonical "empty" terminator.
  builder.create<scf::YieldOp>(loc);
}
78 
79 //===----------------------------------------------------------------------===//
80 // ExecuteRegionOp
81 //===----------------------------------------------------------------------===//
82 
83 /// Replaces the given op with the contents of the given single-block region,
84 /// using the operands of the block terminator to replace operation results.
85 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
86                                 Region &region, ValueRange blockArgs = {}) {
87   assert(llvm::hasSingleElement(region) && "expected single-region block");
88   Block *block = &region.front();
89   Operation *terminator = block->getTerminator();
90   ValueRange results = terminator->getOperands();
91   rewriter.mergeBlockBefore(block, op, blockArgs);
92   rewriter.replaceOp(op, results);
93   rewriter.eraseOp(terminator);
94 }
95 
96 ///
97 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
98 ///    block+
99 /// `}`
100 ///
101 /// Example:
102 ///   scf.execute_region -> i32 {
103 ///     %idx = load %rI[%i] : memref<128xi32>
104 ///     return %idx : i32
105 ///   }
106 ///
107 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
108                                    OperationState &result) {
109   if (parser.parseOptionalArrowTypeList(result.types))
110     return failure();
111 
112   // Introduce the body region and parse it.
113   Region *body = result.addRegion();
114   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
115       parser.parseOptionalAttrDict(result.attributes))
116     return failure();
117 
118   return success();
119 }
120 
121 void ExecuteRegionOp::print(OpAsmPrinter &p) {
122   p.printOptionalArrowTypeList(getResultTypes());
123 
124   p << ' ';
125   p.printRegion(getRegion(),
126                 /*printEntryBlockArgs=*/false,
127                 /*printBlockTerminators=*/true);
128 
129   p.printOptionalAttrDict((*this)->getAttrs());
130 }
131 
132 LogicalResult ExecuteRegionOp::verify() {
133   if (getRegion().empty())
134     return emitOpError("region needs to have at least one block");
135   if (getRegion().front().getNumArguments() > 0)
136     return emitOpError("region cannot have any arguments");
137   return success();
138 }
139 
140 // Inline an ExecuteRegionOp if it only contains one block.
141 //     "test.foo"() : () -> ()
142 //      %v = scf.execute_region -> i64 {
143 //        %x = "test.val"() : () -> i64
144 //        scf.yield %x : i64
145 //      }
146 //      "test.bar"(%v) : (i64) -> ()
147 //
148 //  becomes
149 //
150 //     "test.foo"() : () -> ()
151 //     %x = "test.val"() : () -> i64
152 //     "test.bar"(%x) : (i64) -> ()
153 //
154 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
155   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
156 
157   LogicalResult matchAndRewrite(ExecuteRegionOp op,
158                                 PatternRewriter &rewriter) const override {
159     if (!llvm::hasSingleElement(op.getRegion()))
160       return failure();
161     replaceOpWithRegion(rewriter, op, op.getRegion());
162     return success();
163   }
164 };
165 
166 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
167 // TODO generalize the conditions for operations which can be inlined into.
168 // func @func_execute_region_elim() {
169 //     "test.foo"() : () -> ()
170 //     %v = scf.execute_region -> i64 {
171 //       %c = "test.cmp"() : () -> i1
172 //       cf.cond_br %c, ^bb2, ^bb3
173 //     ^bb2:
174 //       %x = "test.val1"() : () -> i64
175 //       cf.br ^bb4(%x : i64)
176 //     ^bb3:
177 //       %y = "test.val2"() : () -> i64
178 //       cf.br ^bb4(%y : i64)
179 //     ^bb4(%z : i64):
180 //       scf.yield %z : i64
181 //     }
182 //     "test.bar"(%v) : (i64) -> ()
183 //   return
184 // }
185 //
186 //  becomes
187 //
188 // func @func_execute_region_elim() {
189 //    "test.foo"() : () -> ()
190 //    %c = "test.cmp"() : () -> i1
191 //    cf.cond_br %c, ^bb1, ^bb2
192 //  ^bb1:  // pred: ^bb0
193 //    %x = "test.val1"() : () -> i64
194 //    cf.br ^bb3(%x : i64)
195 //  ^bb2:  // pred: ^bb0
196 //    %y = "test.val2"() : () -> i64
197 //    cf.br ^bb3(%y : i64)
198 //  ^bb3(%z: i64):  // 2 preds: ^bb1, ^bb2
199 //    "test.bar"(%z) : (i64) -> ()
200 //    return
201 //  }
202 //
struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExecuteRegionOp op,
                                PatternRewriter &rewriter) const override {
    // Only inline when the parent op is known to tolerate multiple blocks in
    // its region: a function-like op or another scf.execute_region.
    if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
      return failure();

    // Split the containing block at the op: `prevBlock` ends just before the
    // op, `postBlock` starts with the op and everything after it.
    Block *prevBlock = op->getBlock();
    Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
    rewriter.setInsertionPointToEnd(prevBlock);

    // Branch from the code preceding the op into the region's entry block.
    rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());

    // Rewrite every scf.yield into a branch to `postBlock`, forwarding the
    // yielded values as branch operands.
    for (Block &blk : op.getRegion()) {
      if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
        rewriter.setInsertionPoint(yieldOp);
        rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
                                      yieldOp.getResults());
        rewriter.eraseOp(yieldOp);
      }
    }

    // Move the region's blocks into the parent region, between `prevBlock`
    // and `postBlock`.
    rewriter.inlineRegionBefore(op.getRegion(), postBlock);
    SmallVector<Value> blockArgs;

    // Each op result becomes a new block argument of `postBlock`, fed by the
    // branches created from the yields above.
    for (auto res : op.getResults())
      blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));

    rewriter.replaceOp(op, blockArgs);
    return success();
  }
};
236 
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  // Single-block bodies are inlined anywhere; multi-block bodies only into
  // parents that support unstructured control flow.
  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
}
241 
242 /// Given the region at `index`, or the parent operation if `index` is None,
243 /// return the successor regions. These are the regions that may be selected
244 /// during the flow of control. `operands` is a set of optional attributes that
245 /// correspond to a constant value for each operand, or null if that operand is
246 /// not a constant.
247 void ExecuteRegionOp::getSuccessorRegions(
248     Optional<unsigned> index, ArrayRef<Attribute> operands,
249     SmallVectorImpl<RegionSuccessor> &regions) {
250   // If the predecessor is the ExecuteRegionOp, branch into the body.
251   if (!index) {
252     regions.push_back(RegionSuccessor(&getRegion()));
253     return;
254   }
255 
256   // Otherwise, the region branches back to the parent operation.
257   regions.push_back(RegionSuccessor(getResults()));
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // ConditionOp
262 //===----------------------------------------------------------------------===//
263 
/// Returns the operands forwarded to the successor region, i.e. the `args`
/// operand group (everything but the condition itself).
MutableOperandRange
ConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
  // Pass all operands except the condition to the successor region.
  return getArgsMutable();
}
269 
270 //===----------------------------------------------------------------------===//
271 // ForOp
272 //===----------------------------------------------------------------------===//
273 
/// Builds an scf.for with the given bounds, step and initial loop-carried
/// values; the op has one result per iter arg. `bodyBuilder`, when provided,
/// populates the body with the induction variable and region iter args.
void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
                  Value ub, Value step, ValueRange iterArgs,
                  BodyBuilderFn bodyBuilder) {
  result.addOperands({lb, ub, step});
  result.addOperands(iterArgs);
  for (Value v : iterArgs)
    result.addTypes(v.getType());
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  // First block argument is the index-typed induction variable, followed by
  // one argument per loop-carried value.
  bodyBlock.addArgument(builder.getIndexType(), result.location);
  for (Value v : iterArgs)
    bodyBlock.addArgument(v.getType(), v.getLoc());

  // Create the default terminator if the builder is not provided and if the
  // iteration arguments are not provided. Otherwise, leave this to the caller
  // because we don't know which values to return from the loop.
  if (iterArgs.empty() && !bodyBuilder) {
    ForOp::ensureTerminator(*bodyRegion, builder, result.location);
  } else if (bodyBuilder) {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(&bodyBlock);
    bodyBuilder(builder, result.location, bodyBlock.getArgument(0),
                bodyBlock.getArguments().drop_front());
  }
}
300 
301 LogicalResult ForOp::verify() {
302   if (auto cst = getStep().getDefiningOp<arith::ConstantIndexOp>())
303     if (cst.value() <= 0)
304       return emitOpError("constant step operand must be positive");
305 
306   auto opNumResults = getNumResults();
307   if (opNumResults == 0)
308     return success();
309   // If ForOp defines values, check that the number and types of
310   // the defined values match ForOp initial iter operands and backedge
311   // basic block arguments.
312   if (getNumIterOperands() != opNumResults)
313     return emitOpError(
314         "mismatch in number of loop-carried values and defined values");
315   return success();
316 }
317 
318 LogicalResult ForOp::verifyRegions() {
319   // Check that the body defines as single block argument for the induction
320   // variable.
321   auto *body = getBody();
322   if (!body->getArgument(0).getType().isIndex())
323     return emitOpError(
324         "expected body first argument to be an index argument for "
325         "the induction variable");
326 
327   auto opNumResults = getNumResults();
328   if (opNumResults == 0)
329     return success();
330 
331   if (getNumRegionIterArgs() != opNumResults)
332     return emitOpError(
333         "mismatch in number of basic block args and defined values");
334 
335   auto iterOperands = getIterOperands();
336   auto iterArgs = getRegionIterArgs();
337   auto opResults = getResults();
338   unsigned i = 0;
339   for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
340     if (std::get<0>(e).getType() != std::get<2>(e).getType())
341       return emitOpError() << "types mismatch between " << i
342                            << "th iter operand and defined value";
343     if (std::get<1>(e).getType() != std::get<2>(e).getType())
344       return emitOpError() << "types mismatch between " << i
345                            << "th iter region arg and defined value";
346 
347     i++;
348   }
349   return success();
350 }
351 
/// LoopLikeOpInterface accessors: scf.for always has exactly one induction
/// variable, lower bound, step, and upper bound, so each accessor simply
/// wraps the corresponding value.
Optional<Value> ForOp::getSingleInductionVar() { return getInductionVar(); }

Optional<OpFoldResult> ForOp::getSingleLowerBound() {
  return OpFoldResult(getLowerBound());
}

Optional<OpFoldResult> ForOp::getSingleStep() {
  return OpFoldResult(getStep());
}

Optional<OpFoldResult> ForOp::getSingleUpperBound() {
  return OpFoldResult(getUpperBound());
}
365 
366 /// Prints the initialization list in the form of
367 ///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
368 /// where 'inner' values are assumed to be region arguments and 'outer' values
369 /// are regular SSA values.
370 static void printInitializationList(OpAsmPrinter &p,
371                                     Block::BlockArgListType blocksArgs,
372                                     ValueRange initializers,
373                                     StringRef prefix = "") {
374   assert(blocksArgs.size() == initializers.size() &&
375          "expected same length of arguments and initializers");
376   if (initializers.empty())
377     return;
378 
379   p << prefix << '(';
380   llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
381     p << std::get<0>(it) << " = " << std::get<1>(it);
382   });
383   p << ")";
384 }
385 
void ForOp::print(OpAsmPrinter &p) {
  // Print `%iv = %lb to %ub step %step`.
  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
    << getUpperBound() << " step " << getStep();

  // Print ` iter_args(%arg = %init, ...)` and the result types, if any.
  printInitializationList(p, getRegionIterArgs(), getIterOperands(),
                          " iter_args");
  if (!getIterOperands().empty())
    p << " -> (" << getIterOperands().getTypes() << ')';
  p << ' ';
  // Terminators are only printed for loops carrying values; an operand-less
  // scf.yield is implicit otherwise.
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/hasIterOperands());
  p.printOptionalAttrDict((*this)->getAttrs());
}
400 
/// Parses `%iv = %lb to %ub step %step (iter_args(...) -> (types))? region
/// attr-dict?`. Bounds and step are always index-typed; iter arg types are
/// taken from the result type list.
ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  Type indexType = builder.getIndexType();

  OpAsmParser::Argument inductionVariable;
  inductionVariable.type = indexType;
  OpAsmParser::UnresolvedOperand lb, ub, step;

  // Parse the induction variable followed by '='.
  if (parser.parseArgument(inductionVariable) || parser.parseEqual() ||
      // Parse loop bounds.
      parser.parseOperand(lb) ||
      parser.resolveOperand(lb, indexType, result.operands) ||
      parser.parseKeyword("to") || parser.parseOperand(ub) ||
      parser.resolveOperand(ub, indexType, result.operands) ||
      parser.parseKeyword("step") || parser.parseOperand(step) ||
      parser.resolveOperand(step, indexType, result.operands))
    return failure();

  // Parse the optional initial iteration arguments.
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  // The induction variable is the body's first region argument.
  regionArgs.push_back(inductionVariable);

  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
    // Parse assignment list and results type list.
    if (parser.parseAssignmentList(regionArgs, operands) ||
        parser.parseArrowTypeList(result.types))
      return failure();

    // Resolve input operands. Each iter operand's type is taken from the
    // corresponding result type; the region arg gets the same type.
    for (auto argOperandType :
         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
      Type type = std::get<2>(argOperandType);
      std::get<0>(argOperandType).type = type;
      if (parser.resolveOperand(std::get<1>(argOperandType), type,
                                result.operands))
        return failure();
    }
  }

  // Region args = induction variable + one per result.
  if (regionArgs.size() != result.types.size() + 1)
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch in number of loop-carried values and defined values");

  // Parse the body region.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArgs))
    return failure();

  // Insert the implicit terminator if the body did not spell one out.
  ForOp::ensureTerminator(*body, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
460 
/// LoopLikeOpInterface: the loop body is the op's only region.
Region &ForOp::getLoopBody() { return getRegion(); }
462 
463 ForOp mlir::scf::getForInductionVarOwner(Value val) {
464   auto ivArg = val.dyn_cast<BlockArgument>();
465   if (!ivArg)
466     return ForOp();
467   assert(ivArg.getOwner() && "unlinked block argument");
468   auto *containingOp = ivArg.getOwner()->getParentOp();
469   return dyn_cast_or_null<ForOp>(containingOp);
470 }
471 
472 /// Return operands used when entering the region at 'index'. These operands
473 /// correspond to the loop iterator operands, i.e., those excluding the
474 /// induction variable. LoopOp only has one region, so 0 is the only valid value
475 /// for `index`.
OperandRange ForOp::getSuccessorEntryOperands(Optional<unsigned> index) {
  // scf.for has a single region, so only region index 0 is valid.
  assert(index && *index == 0 && "invalid region index");

  // The initial operands map to the loop arguments after the induction
  // variable.
  return getInitArgs();
}
483 
484 /// Given the region at `index`, or the parent operation if `index` is None,
485 /// return the successor regions. These are the regions that may be selected
486 /// during the flow of control. `operands` is a set of optional attributes that
487 /// correspond to a constant value for each operand, or null if that operand is
488 /// not a constant.
489 void ForOp::getSuccessorRegions(Optional<unsigned> index,
490                                 ArrayRef<Attribute> operands,
491                                 SmallVectorImpl<RegionSuccessor> &regions) {
492   // If the predecessor is the ForOp, branch into the body using the iterator
493   // arguments.
494   if (!index) {
495     regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
496     return;
497   }
498 
499   // Otherwise, the loop may branch back to itself or the parent operation.
500   assert(*index == 0 && "expected loop region");
501   regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
502   regions.push_back(RegionSuccessor(getResults()));
503 }
504 
/// Builds a perfect nest of scf.for loops, one per (lb, ub, step) triple,
/// threading `iterArgs` through the nest, and calls `bodyBuilder` in the
/// innermost body. Returns the created loops, outermost first.
LoopNest mlir::scf::buildLoopNest(
    OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ValueRange iterArgs,
    function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
        bodyBuilder) {
  assert(lbs.size() == ubs.size() &&
         "expected the same number of lower and upper bounds");
  assert(lbs.size() == steps.size() &&
         "expected the same number of lower bounds and steps");

  // If there are no bounds, call the body-building function and return early.
  if (lbs.empty()) {
    ValueVector results =
        bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
                    : ValueVector();
    assert(results.size() == iterArgs.size() &&
           "loop nest body must return as many values as loop has iteration "
           "arguments");
    return LoopNest();
  }

  // First, create the loop structure iteratively using the body-builder
  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
  OpBuilder::InsertionGuard guard(builder);
  SmallVector<scf::ForOp, 4> loops;
  SmallVector<Value, 4> ivs;
  loops.reserve(lbs.size());
  ivs.reserve(lbs.size());
  ValueRange currentIterArgs = iterArgs;
  Location currentLoc = loc;
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
    auto loop = builder.create<scf::ForOp>(
        currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
        [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
            ValueRange args) {
          ivs.push_back(iv);
          // It is safe to store ValueRange args because it points to block
          // arguments of a loop operation that we also own.
          currentIterArgs = args;
          currentLoc = nestedLoc;
        });
    // Set the builder to point to the body of the newly created loop. We don't
    // do this in the callback because the builder is reset when the callback
    // returns.
    builder.setInsertionPointToStart(loop.getBody());
    loops.push_back(loop);
  }

  // For all loops but the innermost, yield the results of the nested loop.
  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
    builder.setInsertionPointToEnd(loops[i].getBody());
    builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
  }

  // In the body of the innermost loop, call the body building function if any
  // and yield its results.
  builder.setInsertionPointToStart(loops.back().getBody());
  ValueVector results = bodyBuilder
                            ? bodyBuilder(builder, currentLoc, ivs,
                                          loops.back().getRegionIterArgs())
                            : ValueVector();
  assert(results.size() == iterArgs.size() &&
         "loop nest body must return as many values as loop has iteration "
         "arguments");
  builder.setInsertionPointToEnd(loops.back().getBody());
  builder.create<scf::YieldOp>(loc, results);

  // Return the loops.
  LoopNest res;
  res.loops.assign(loops.begin(), loops.end());
  return res;
}
577 
578 LoopNest mlir::scf::buildLoopNest(
579     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
580     ValueRange steps,
581     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
582   // Delegate to the main function by wrapping the body builder.
583   return buildLoopNest(builder, loc, lbs, ubs, steps, llvm::None,
584                        [&bodyBuilder](OpBuilder &nestedBuilder,
585                                       Location nestedLoc, ValueRange ivs,
586                                       ValueRange) -> ValueVector {
587                          if (bodyBuilder)
588                            bodyBuilder(nestedBuilder, nestedLoc, ivs);
589                          return {};
590                        });
591 }
592 
593 namespace {
594 // Fold away ForOp iter arguments when:
595 // 1) The op yields the iter arguments.
596 // 2) The iter arguments have no use and the corresponding outer region
597 // iterators (inputs) are yielded.
598 // 3) The iter arguments have no use and the corresponding (operation) results
599 // have no use.
600 //
601 // These arguments must be defined outside of
602 // the ForOp region and can just be forwarded after simplifying the op inits,
603 // yields and returns.
604 //
605 // The implementation uses `mergeBlockBefore` to steal the content of the
606 // original ForOp and avoid cloning.
struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const final {
    bool canonicalize = false;
    Block &block = forOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(block.getTerminator());

    // An internal flat vector of block transfer
    // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
    // transformed block argument mappings. This plays the role of a
    // BlockAndValueMapping for the particular use case of calling into
    // `mergeBlockBefore`.
    SmallVector<bool, 4> keepMask;
    keepMask.reserve(yieldOp.getNumOperands());
    SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
        newResultValues;
    newBlockTransferArgs.reserve(1 + forOp.getNumIterOperands());
    newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
    newIterArgs.reserve(forOp.getNumIterOperands());
    newYieldValues.reserve(yieldOp.getNumOperands());
    newResultValues.reserve(forOp.getNumResults());
    for (auto it : llvm::zip(forOp.getIterOperands(),   // iter from outside
                             forOp.getRegionIterArgs(), // iter inside region
                             forOp.getResults(),        // op results
                             yieldOp.getOperands()      // iter yield
                             )) {
      // Forwarded is `true` when:
      // 1) The region `iter` argument is yielded.
      // 2) The region `iter` argument has no use, and the corresponding iter
      // operand (input) is yielded.
      // 3) The region `iter` argument has no use, and the corresponding op
      // result has no use.
      bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
                        (std::get<1>(it).use_empty() &&
                         (std::get<0>(it) == std::get<3>(it) ||
                          std::get<2>(it).use_empty())));
      keepMask.push_back(!forwarded);
      canonicalize |= forwarded;
      if (forwarded) {
        // A forwarded iter arg disappears from the new loop: uses of the old
        // block argument and of the old op result are rewired to the outer
        // iter operand directly.
        newBlockTransferArgs.push_back(std::get<0>(it));
        newResultValues.push_back(std::get<0>(it));
        continue;
      }
      // Kept iter args get null placeholders that are patched up below, once
      // the new loop (and thus its block args/results) exists.
      newIterArgs.push_back(std::get<0>(it));
      newYieldValues.push_back(std::get<3>(it));
      newBlockTransferArgs.push_back(Value()); // placeholder with null value
      newResultValues.push_back(Value());      // placeholder with null value
    }

    // Nothing was forwarded: the pattern does not apply.
    if (!canonicalize)
      return failure();

    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newIterArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block &newBlock = newForOp.getRegion().front();

    // Replace the null placeholders with newly constructed values.
    newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
    for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
         idx != e; ++idx) {
      Value &blockTransferArg = newBlockTransferArgs[1 + idx];
      Value &newResultVal = newResultValues[idx];
      assert((blockTransferArg && newResultVal) ||
             (!blockTransferArg && !newResultVal));
      if (!blockTransferArg) {
        // `collapsedIdx` indexes the new loop's (smaller) iter arg and result
        // lists, which only contain the non-forwarded entries.
        blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
        newResultVal = newForOp.getResult(collapsedIdx++);
      }
    }

    Block &oldBlock = forOp.getRegion().front();
    assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
           "unexpected argument size mismatch");

    // No results case: the scf::ForOp builder already created a zero
    // result terminator. Merge before this terminator and just get rid of the
    // original terminator that has been merged in.
    if (newIterArgs.empty()) {
      auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
      rewriter.mergeBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
      rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
      rewriter.replaceOp(forOp, newResultValues);
      return success();
    }

    // No terminator case: merge and rewrite the merged terminator.
    auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(mergedTerminator);
      SmallVector<Value, 4> filteredOperands;
      filteredOperands.reserve(newResultValues.size());
      // Keep only the yield operands whose iter args were not forwarded.
      for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
        if (keepMask[idx])
          filteredOperands.push_back(mergedTerminator.getOperand(idx));
      rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
                                    filteredOperands);
    };

    rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
    auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
    cloneFilteredTerminator(mergedYieldOp);
    rewriter.eraseOp(mergedYieldOp);
    rewriter.replaceOp(forOp, newResultValues);
    return success();
  }
};
717 
718 /// Rewriting pattern that erases loops that are known not to iterate, replaces
719 /// single-iteration loops with their bodies, and removes empty loops that
720 /// iterate at least once and only return values defined outside of the loop.
struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
  using OpRewritePattern<ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForOp op,
                                PatternRewriter &rewriter) const override {
    // If the upper bound is the same as the lower bound, the loop does not
    // iterate, just remove it.
    if (op.getLowerBound() == op.getUpperBound()) {
      // Results are replaced by the initial iter operands, which is what a
      // zero-trip loop produces.
      rewriter.replaceOp(op, op.getIterOperands());
      return success();
    }

    // The remaining cases need constant bounds.
    auto lb = op.getLowerBound().getDefiningOp<arith::ConstantOp>();
    auto ub = op.getUpperBound().getDefiningOp<arith::ConstantOp>();
    if (!lb || !ub)
      return failure();

    // If the loop is known to have 0 iterations, remove it.
    llvm::APInt lbValue = lb.getValue().cast<IntegerAttr>().getValue();
    llvm::APInt ubValue = ub.getValue().cast<IntegerAttr>().getValue();
    if (lbValue.sge(ubValue)) {
      rewriter.replaceOp(op, op.getIterOperands());
      return success();
    }

    auto step = op.getStep().getDefiningOp<arith::ConstantOp>();
    if (!step)
      return failure();

    // If the loop is known to have 1 iteration, inline its body and remove the
    // loop.
    llvm::APInt stepValue = step.getValue().cast<IntegerAttr>().getValue();
    if ((lbValue + stepValue).sge(ubValue)) {
      // Block args are (iv, iter args...): the iv is replaced by the lower
      // bound, the iter args by the initial operands.
      SmallVector<Value, 4> blockArgs;
      blockArgs.reserve(op.getNumIterOperands() + 1);
      blockArgs.push_back(op.getLowerBound());
      llvm::append_range(blockArgs, op.getIterOperands());
      replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
      return success();
    }

    // Now we are left with loops that have more than 1 iterations.
    Block &block = op.getRegion().front();
    if (!llvm::hasSingleElement(block))
      return failure();
    // If the loop is empty, iterates at least once, and only returns values
    // defined outside of the loop, remove it and replace it with yield values.
    auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
    auto yieldOperands = yieldOp.getOperands();
    if (llvm::any_of(yieldOperands,
                     [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
      return failure();
    rewriter.replaceOp(op, yieldOperands);
    return success();
  }
};
777 
/// Perform a replacement of one iter OpOperand of an scf.for to the
/// `replacement` value which is expected to be the source of a tensor.cast.
/// tensor.cast ops are inserted inside the block to account for the type cast.
/// Returns the newly created scf.for (or the original one if the types
/// already match and no rewrite is needed).
static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
                                           OpOperand &operand,
                                           Value replacement) {
  Type oldType = operand.get().getType(), newType = replacement.getType();
  assert(oldType.isa<RankedTensorType>() && newType.isa<RankedTensorType>() &&
         "expected ranked tensor types");

  // 1. Create new iter operands, exactly 1 is replaced.
  ForOp forOp = cast<ForOp>(operand.getOwner());
  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
         "expected an iter OpOperand");
  // Nothing to do when the types already agree.
  if (operand.get().getType() == replacement.getType())
    return forOp;
  SmallVector<Value> newIterOperands;
  for (OpOperand &opOperand : forOp.getIterOpOperands()) {
    if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
      newIterOperands.push_back(replacement);
      continue;
    }
    newIterOperands.push_back(opOperand.get());
  }

  // 2. Create the new forOp shell (empty body, same control operands and
  // attributes, but with the replaced iter operand).
  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
      forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
      forOp.getStep(), newIterOperands);
  newForOp->setAttrs(forOp->getAttrs());
  Block &newBlock = newForOp.getRegion().front();
  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
                                             newBlock.getArguments().end());

  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
  // corresponding to the `replacement` value.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(&newBlock, newBlock.begin());
  BlockArgument newRegionIterArg = newForOp.getRegionIterArgForOpOperand(
      newForOp->getOpOperand(operand.getOperandNumber()));
  Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
                                                 newRegionIterArg);
  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;

  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
  // The replaced bbArg is remapped to the freshly created cast.
  Block &oldBlock = forOp.getRegion().front();
  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);

  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
  rewriter.setInsertionPoint(clonedYieldOp);
  // Index of the replaced value in the yield operand list / results.
  unsigned yieldIdx =
      newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
  Value castOut = rewriter.create<tensor::CastOp>(
      newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
  newYieldOperands[yieldIdx] = castOut;
  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
  rewriter.eraseOp(clonedYieldOp);

  // 6. Inject an outgoing cast op after the forOp.
  rewriter.setInsertionPointAfter(newForOp);
  SmallVector<Value> newResults = newForOp.getResults();
  newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
      newForOp.getLoc(), oldType, newResults[yieldIdx]);

  return newForOp;
}
846 
/// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
/// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
///
/// ```
///   %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
///   %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
///      -> (tensor<?x?xf32>) {
///     %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
///     scf.yield %2 : tensor<?x?xf32>
///   }
///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
///   use_of(%2)
/// ```
///
/// folds into:
///
/// ```
///   %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
///       -> (tensor<32x1024xf32>) {
///     %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
///     %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
///     %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
///     scf.yield %4 : tensor<32x1024xf32>
///   }
///   use_of(%0)
/// ```
struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
  using OpRewritePattern<ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForOp op,
                                PatternRewriter &rewriter) const override {
    // Look for the first iter_arg/result pair wrapped in a matching
    // incoming/outgoing tensor.cast pair; rewrite at most one per invocation.
    for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
      OpOperand &iterOpOperand = std::get<0>(it);
      auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
      if (!incomingCast)
        continue;
      // The result must be used exclusively by the outgoing cast, otherwise
      // other users would observe a type change.
      if (!std::get<1>(it).hasOneUse())
        continue;
      auto outgoingCastOp =
          dyn_cast<tensor::CastOp>(*std::get<1>(it).user_begin());
      if (!outgoingCastOp)
        continue;

      // Must be a tensor.cast op pair with matching types.
      if (outgoingCastOp.getResult().getType() !=
          incomingCast.getSource().getType())
        continue;

      // Create a new ForOp with that iter operand replaced.
      auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
                                                    incomingCast.getSource());

      // Insert outgoing cast and use it to replace the corresponding result.
      rewriter.setInsertionPointAfter(newForOp);
      SmallVector<Value> replacements = newForOp.getResults();
      unsigned returnIdx =
          iterOpOperand.getOperandNumber() - op.getNumControlOperands();
      replacements[returnIdx] = rewriter.create<tensor::CastOp>(
          op.getLoc(), incomingCast.getDest().getType(),
          replacements[returnIdx]);
      rewriter.replaceOp(op, replacements);
      return success();
    }
    return failure();
  }
};
913 
/// Canonicalize the iter_args of an scf::ForOp that involve a
/// `bufferization.to_tensor` and for which only the last loop iteration is
/// actually visible outside of the loop. The canonicalization looks for a
/// pattern such as:
/// ```
///    %t0 = ... : tensor_type
///    %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
///      ...
///      // %m is either buffer_cast(%bb00) or defined above the loop
///      %m... : memref_type
///      ... // uses of %m with potential inplace updates
///      %new_tensor = bufferization.to_tensor %m : memref_type
///      ...
///      scf.yield %new_tensor : tensor_type
///    }
/// ```
///
/// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
/// `%m = buffer_cast %bb0` op that feeds into the yielded
/// `bufferization.to_tensor` op.
///
/// If no aliasing write to the memref `%m`, from which `%new_tensor` is
/// loaded, occurs between `bufferization.to_tensor` and yield then the value
/// %0 visible outside of the loop is the last `bufferization.to_tensor`
/// produced in the loop.
///
/// For now, we approximate the absence of aliasing by only supporting the case
/// when the bufferization.to_tensor is the operation immediately preceding
/// the yield.
///
/// The canonicalization rewrites the pattern as:
/// ```
///    // %m is either a buffer_cast or defined above
///    %m... : memref_type
///    scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
///      ... // uses of %m with potential inplace updates
///      scf.yield %bb0: tensor_type
///    }
///    %0 = bufferization.to_tensor %m : memref_type
/// ```
///
/// A later bbArg canonicalization will further rewrite as:
/// ```
///    // %m is either a buffer_cast or defined above
///    %m... : memref_type
///    scf.for ... { // no iter_args
///      ... // uses of %m with potential inplace updates
///    }
///    %0 = bufferization.to_tensor %m : memref_type
/// ```
struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
  using OpRewritePattern<ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForOp forOp,
                                PatternRewriter &rewriter) const override {
    assert(std::next(forOp.getRegion().begin()) == forOp.getRegion().end() &&
           "unexpected multiple blocks");

    Location loc = forOp.getLoc();
    // Maps a forOp result to the bufferization.to_tensor cloned after the
    // loop that will replace it.
    DenseMap<Value, Value> replacements;
    for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
      // Position of this iter_arg among the yielded operands / results.
      unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
      auto yieldOp =
          cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
      Value yieldVal = yieldOp->getOperand(idx);
      auto tensorLoadOp = yieldVal.getDefiningOp<bufferization::ToTensorOp>();
      bool isTensor = bbArg.getType().isa<TensorType>();

      bufferization::ToMemrefOp tensorToMemref;
      // Either bbArg has no use or it has a single buffer_cast use.
      if (bbArg.hasOneUse())
        tensorToMemref =
            dyn_cast<bufferization::ToMemrefOp>(*bbArg.getUsers().begin());
      if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !tensorToMemref))
        continue;
      // If tensorToMemref is present, it must feed into the `ToTensorOp`.
      if (tensorToMemref && tensorLoadOp.getMemref() != tensorToMemref)
        continue;
      // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
      // must be before `ToTensorOp` in the block so that the lastWrite
      // property is not subject to additional side-effects.
      // For now, we only support the case when ToTensorOp appears
      // immediately before the terminator.
      if (tensorLoadOp->getNextNode() != yieldOp)
        continue;

      // Clone the optional tensorToMemref before forOp; it now consumes the
      // original init tensor instead of the bbArg.
      if (tensorToMemref) {
        rewriter.setInsertionPoint(forOp);
        rewriter.replaceOpWithNewOp<bufferization::ToMemrefOp>(
            tensorToMemref, tensorToMemref.getMemref().getType(),
            tensorToMemref.getTensor());
      }

      // Clone the tensorLoad after forOp.
      rewriter.setInsertionPointAfter(forOp);
      Value newTensorLoad = rewriter.create<bufferization::ToTensorOp>(
          loc, tensorLoadOp.getMemref());
      Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
      replacements.insert(std::make_pair(forOpResult, newTensorLoad));

      // Make the terminator just yield the bbArg, the old tensorLoadOp + the
      // old bbArg (that is now directly yielded) will canonicalize away.
      rewriter.startRootUpdate(yieldOp);
      yieldOp.setOperand(idx, bbArg);
      rewriter.finalizeRootUpdate(yieldOp);
    }
    if (replacements.empty())
      return failure();

    // We want to replace a subset of the results of `forOp`. rewriter.replaceOp
    // replaces the whole op and erase it unconditionally. This is wrong for
    // `forOp` as it generally contains ops with side effects.
    // Instead, use `rewriter.replaceOpWithIf`.
    SmallVector<Value> newResults;
    newResults.reserve(forOp.getNumResults());
    for (Value v : forOp.getResults()) {
      auto it = replacements.find(v);
      newResults.push_back((it != replacements.end()) ? it->second : v);
    }
    // Only replace uses whose value actually changed; `idx` walks the results
    // in lock-step with the uses visited by replaceOpWithIf.
    unsigned idx = 0;
    rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &op) {
      return op.get() != newResults[idx++];
    });
    return success();
  }
};
1041 } // namespace
1042 
/// Register the canonicalization patterns for scf.for.
void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
              LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
}
1048 
1049 //===----------------------------------------------------------------------===//
1050 // ForeachThreadOp
1051 //===----------------------------------------------------------------------===//
1052 
LogicalResult ForeachThreadOp::verify() {
  // Call terminator's verify to produce most informative error messages.
  if (failed(getTerminator().verify()))
    return failure();

  // Check that the body defines as many block arguments as the op's rank
  // (one per thread index).
  auto *body = getBody();
  if (body->getNumArguments() != getRank())
    return emitOpError("region expects ") << getRank() << " arguments";

  // Verify consistency between the result types and the terminator.
  auto terminatorTypes = getTerminator().yieldedTypes();
  auto opResults = getResults();
  if (opResults.size() != terminatorTypes.size())
    return emitOpError("produces ")
           << opResults.size() << " results, but its terminator yields "
           << terminatorTypes.size() << " value(s)";
  // Each result type must match the corresponding yielded type, pairwise.
  unsigned i = 0;
  for (auto e : llvm::zip(terminatorTypes, opResults)) {
    if (std::get<0>(e) != std::get<1>(e).getType())
      return emitOpError() << "type mismatch between result " << i << " ("
                           << std::get<1>(e).getType() << ") and terminator ("
                           << std::get<0>(e) << ")";
    i++;
  }
  return success();
}
1080 
/// Print in the form:
///   (%tidx...) in (%numThreads...) -> (result-types) { region } attr-dict
void ForeachThreadOp::print(OpAsmPrinter &p) {
  p << " (";
  llvm::interleaveComma(getThreadIndices(), p);
  p << ") in (";
  llvm::interleaveComma(getNumThreads(), p);
  p << ") -> (" << getResultTypes() << ") ";
  // The terminator is elided when there are no results (parse re-creates it).
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/getNumResults() > 0);
  p.printOptionalAttrDict(getOperation()->getAttrs());
}
1092 
/// Parse the custom form printed by ForeachThreadOp::print.
ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  auto &builder = parser.getBuilder();
  // Parse an opening `(` followed by thread index variables followed by `)`
  // TODO: when we can refer to such "induction variable"-like handles from the
  // declarative assembly format, we can implement the parser as a custom hook.
  SmallVector<OpAsmParser::Argument, 4> threadIndices;
  if (parser.parseArgumentList(threadIndices, OpAsmParser::Delimiter::Paren))
    return failure();

  // Parse `in` threadNums. One index-typed operand per thread index.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> threadNums;
  if (parser.parseKeyword("in") ||
      parser.parseOperandList(threadNums, threadIndices.size(),
                              OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(threadNums, builder.getIndexType(),
                             result.operands))
    return failure();

  // Parse optional results.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse region. Thread indices are index-typed entry block arguments.
  std::unique_ptr<Region> region = std::make_unique<Region>();
  for (auto &idx : threadIndices)
    idx.type = builder.getIndexType();
  if (parser.parseRegion(*region, threadIndices))
    return failure();

  // Ensure terminator and move region. The printer may have elided the
  // terminator, so re-create it if needed.
  OpBuilder b(builder.getContext());
  ForeachThreadOp::ensureTerminator(*region, b, result.location);
  result.addRegion(std::move(region));

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
1134 
// Bodyless builder, result types must be specified.
// Creates an empty single-block body (one index-typed argument per thread)
// terminated by an (empty) PerformConcurrentlyOp.
void ForeachThreadOp::build(mlir::OpBuilder &builder,
                            mlir::OperationState &result, TypeRange resultTypes,
                            ValueRange numThreads,
                            ArrayRef<int64_t> threadDimMapping) {
  result.addOperands(numThreads);
  result.addAttribute(
      // TODO: getThreadDimMappingAttrName() but it is not a static member.
      "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));

  Region *bodyRegion = result.addRegion();
  OpBuilder::InsertionGuard g(builder);
  // createBlock sets the IP inside the block.
  // Generally we would guard against that but the default ensureTerminator impl
  // expects it ..
  builder.createBlock(bodyRegion);
  Block &bodyBlock = bodyRegion->front();
  // One index-typed block argument per thread dimension.
  bodyBlock.addArguments(
      SmallVector<Type>(numThreads.size(), builder.getIndexType()),
      SmallVector<Location>(numThreads.size(), result.location));
  ForeachThreadOp::ensureTerminator(*bodyRegion, builder, result.location);
  result.addTypes(resultTypes);
}
1158 
1159 // Builder that takes a bodyBuilder lambda, result types are inferred from
1160 // the terminator.
1161 void ForeachThreadOp::build(
1162     mlir::OpBuilder &builder, mlir::OperationState &result,
1163     ValueRange numThreads, ArrayRef<int64_t> threadDimMapping,
1164     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1165   result.addOperands(numThreads);
1166   result.addAttribute(
1167       // TODO: getThreadDimMappingAttrName() but it is not a static member.
1168       "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
1169 
1170   OpBuilder::InsertionGuard g(builder);
1171   Region *bodyRegion = result.addRegion();
1172   builder.createBlock(bodyRegion);
1173   Block &bodyBlock = bodyRegion->front();
1174   bodyBlock.addArguments(
1175       SmallVector<Type>(numThreads.size(), builder.getIndexType()),
1176       SmallVector<Location>(numThreads.size(), result.location));
1177 
1178   OpBuilder::InsertionGuard guard(builder);
1179   builder.setInsertionPointToStart(&bodyBlock);
1180   bodyBuilder(builder, result.location, bodyBlock.getArguments());
1181   auto terminator =
1182       llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
1183   assert(terminator &&
1184          "expected bodyBuilder to create PerformConcurrentlyOp terminator");
1185   result.addTypes(terminator.yieldedTypes());
1186 }
1187 
1188 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
1189 // unaware of the fact that our terminator also needs a region to be
1190 // well-formed. We override it here to ensure that we do the right thing.
1191 void ForeachThreadOp::ensureTerminator(Region &region, OpBuilder &builder,
1192                                        Location loc) {
1193   OpTrait::SingleBlockImplicitTerminator<PerformConcurrentlyOp>::Impl<
1194       ForeachThreadOp>::ensureTerminator(region, builder, loc);
1195   auto terminator =
1196       llvm::dyn_cast<PerformConcurrentlyOp>(region.front().getTerminator());
1197   if (terminator.getRegion().empty())
1198     builder.createBlock(&terminator.getRegion());
1199 }
1200 
/// Return the body terminator; ensureTerminator guarantees it is a
/// PerformConcurrentlyOp for well-formed ops (the cast asserts otherwise).
PerformConcurrentlyOp ForeachThreadOp::getTerminator() {
  return cast<PerformConcurrentlyOp>(getBody()->getTerminator());
}
1204 
1205 ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
1206   auto tidxArg = val.dyn_cast<BlockArgument>();
1207   if (!tidxArg)
1208     return ForeachThreadOp();
1209   assert(tidxArg.getOwner() && "unlinked block argument");
1210   auto *containingOp = tidxArg.getOwner()->getParentOp();
1211   return dyn_cast<ForeachThreadOp>(containingOp);
1212 }
1213 
1214 //===----------------------------------------------------------------------===//
1215 // ParallelInsertSliceOp
1216 //===----------------------------------------------------------------------===//
1217 
// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
// Each OpFoldResult list is split into a static i64 attribute array (using
// the given sentinel for dynamic entries) plus the dynamic SSA values.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes,
                                  ArrayRef<OpFoldResult> strides,
                                  ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  // The op has no results, hence the empty result-type list.
  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}
1238 
1239 // Build a ParallelInsertSliceOp with dynamic entries.
1240 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
1241                                   Value source, Value dest, ValueRange offsets,
1242                                   ValueRange sizes, ValueRange strides,
1243                                   ArrayRef<NamedAttribute> attrs) {
1244   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1245       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1246   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1247       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1248   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1249       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1250   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
1251 }
1252 
namespace {
/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
/// Any offset/size/stride operand defined by a constant is folded into the
/// corresponding static attribute, shrinking the dynamic operand lists.
class ParallelInsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<ParallelInsertSliceOp> {
public:
  using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // No constant operand, just return.
    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // At least one of offsets/sizes/strides is a new constant.
    // Form the new list of operands and constant attributes from the
    // existing.
    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);

    // Create the new op in canonical form.
    rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
        insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
        mixedOffsets, mixedSizes, mixedStrides);
    return success();
  }
};
} // namespace
1286 
1287 /// Fold a parallel_insert_slice source coming from a tensor.cast op.
1288 ///
1289 /// Example:
1290 /// ```
1291 /// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
1292 ///   %1 = compute_some_tensor() : tensor<64xf32>
1293 ///   %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
1294 ///   scf.foreach_thread.perform_concurrently {
1295 ///     scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
1296 ///        tensor<?xf32> into tensor<128xf32>
1297 ///   }
1298 /// }
1299 /// ```
1300 ///
1301 /// is folded into:
1302 /// ```
1303 /// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
1304 ///   %1 = compute_some_tensor() : tensor<64xf32>
1305 ///   scf.foreach_thread.perform_concurrently {
1306 ///     scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
1307 ///        tensor<64xf32> into tensor<128xf32>
1308 ///   }
1309 /// }
1310 /// ```
1311 LogicalResult
1312 ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
1313                             SmallVectorImpl<OpFoldResult> &results) {
1314   auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
1315   if (!sourceCast)
1316     return failure();
1317   getSourceMutable().assign(sourceCast.getSource());
1318   return success();
1319 }
1320 
/// Register the canonicalization patterns for parallel_insert_slice.
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
}
1325 
1326 //===----------------------------------------------------------------------===//
1327 // PerformConcurrentlyOp
1328 //===----------------------------------------------------------------------===//
1329 
// Build an empty PerformConcurrentlyOp whose region contains a single empty
// block; the guard restores the caller's insertion point on exit.
void PerformConcurrentlyOp::build(OpBuilder &b, OperationState &result) {
  OpBuilder::InsertionGuard g(b);
  Region *bodyRegion = result.addRegion();
  b.createBlock(bodyRegion);
}
1336 
1337 LogicalResult PerformConcurrentlyOp::verify() {
1338   // TODO: PerformConcurrentlyOpInterface.
1339   for (const Operation &op : getRegion().front().getOperations())
1340     if (!isa<ParallelInsertSliceOp>(op))
1341       return emitOpError(
1342           "expected only scf.foreach_thread.parallel_insert_slice ops");
1343   return success();
1344 }
1345 
/// Print the op as a bare region followed by the optional attribute dict.
void PerformConcurrentlyOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p.printOptionalAttrDict(getOperation()->getAttrs());
}
1353 
/// Parse a region followed by an optional attribute dict.
ParseResult PerformConcurrentlyOp::parse(OpAsmParser &parser,
                                         OperationState &result) {
  auto &builder = parser.getBuilder();

  SmallVector<OpAsmParser::Argument, 8> regionOperands;
  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, regionOperands))
    return failure();

  // An elided region still needs a block to be well-formed.
  if (region->empty())
    OpBuilder(builder.getContext()).createBlock(region.get());
  result.addRegion(std::move(region));

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
1372 
1373 SmallVector<Type> PerformConcurrentlyOp::yieldedTypes() {
1374   return llvm::to_vector<4>(
1375       llvm::map_range(this->yieldingOps(), [](Operation &op) {
1376         auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(&op);
1377         return insertSliceOp ? insertSliceOp.yieldedType() : Type();
1378       }));
1379 }
1380 
/// Return the ops in the single body block; verify() restricts these to
/// scf.foreach_thread.parallel_insert_slice ops.
llvm::iterator_range<Block::iterator> PerformConcurrentlyOp::yieldingOps() {
  return getRegion().front().getOperations();
}
1384 
1385 //===----------------------------------------------------------------------===//
1386 // IfOp
1387 //===----------------------------------------------------------------------===//
1388 
1389 bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
1390   assert(a && "expected non-empty operation");
1391   assert(b && "expected non-empty operation");
1392 
1393   IfOp ifOp = a->getParentOfType<IfOp>();
1394   while (ifOp) {
1395     // Check if b is inside ifOp. (We already know that a is.)
1396     if (ifOp->isProperAncestor(b))
1397       // b is contained in ifOp. a and b are in mutually exclusive branches if
1398       // they are in different blocks of ifOp.
1399       return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1400              static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1401     // Check next enclosing IfOp.
1402     ifOp = ifOp->getParentOfType<IfOp>();
1403   }
1404 
1405   // Could not find a common IfOp among a's and b's ancestors.
1406   return false;
1407 }
1408 
/// Build a result-less scf.if; delegates to the TypeRange overload.
void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
                 bool withElseRegion) {
  build(builder, result, /*resultTypes=*/llvm::None, cond, withElseRegion);
}
1413 
/// Build an scf.if with empty then (and optionally else) blocks. An scf.yield
/// terminator is auto-inserted only when the op produces no results; callers
/// of a value-producing scf.if must populate the yields themselves.
void IfOp::build(OpBuilder &builder, OperationState &result,
                 TypeRange resultTypes, Value cond, bool withElseRegion) {
  auto addTerminator = [&](OpBuilder &nested, Location loc) {
    if (resultTypes.empty())
      IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested,
                             loc);
  };

  // Passing a null else-callback skips creating the else block entirely.
  build(builder, result, resultTypes, cond, addTerminator,
        withElseRegion ? addTerminator
                       : function_ref<void(OpBuilder &, Location)>());
}
1426 
1427 void IfOp::build(OpBuilder &builder, OperationState &result,
1428                  TypeRange resultTypes, Value cond,
1429                  function_ref<void(OpBuilder &, Location)> thenBuilder,
1430                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1431   assert(thenBuilder && "the builder callback for 'then' must be present");
1432 
1433   result.addOperands(cond);
1434   result.addTypes(resultTypes);
1435 
1436   OpBuilder::InsertionGuard guard(builder);
1437   Region *thenRegion = result.addRegion();
1438   builder.createBlock(thenRegion);
1439   thenBuilder(builder, result.location);
1440 
1441   Region *elseRegion = result.addRegion();
1442   if (!elseBuilder)
1443     return;
1444 
1445   builder.createBlock(elseRegion);
1446   elseBuilder(builder, result.location);
1447 }
1448 
1449 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1450                  function_ref<void(OpBuilder &, Location)> thenBuilder,
1451                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1452   build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder);
1453 }
1454 
1455 LogicalResult IfOp::verify() {
1456   if (getNumResults() != 0 && getElseRegion().empty())
1457     return emitOpError("must have an else block if defining values");
1458   return success();
1459 }
1460 
/// Parses the custom form:
///   scf.if %cond [-> (result-types)] then-region [`else` else-region]
///   [attr-dict]
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the regions for 'then'. Both regions are always attached to the
  // op, even when the 'else' keyword is absent, so the region count is fixed.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand cond;
  // The condition type is not spelled in the assembly; it is always i1.
  Type i1Type = builder.getIntegerType(1);
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, i1Type, result.operands))
    return failure();
  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();
  // Parse the 'then' region.
  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  // Materialize the terminator if it was elided in the input.
  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword("else")) {
    if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
      return failure();
    IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
1493 
/// Prints the custom form accepted by IfOp::parse.
void IfOp::print(OpAsmPrinter &p) {
  // Terminators are elided in the shorthand (zero-result) form; they must be
  // printed when the op defines values so the yields are round-trippable.
  bool printBlockTerminators = false;

  p << " " << getCondition();
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    // Print yield explicitly if the op defines values.
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(getThenRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);

  // Print the 'else' regions if it exists and has a block.
  auto &elseRegion = getElseRegion();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/printBlockTerminators);
  }

  p.printOptionalAttrDict((*this)->getAttrs());
}
1519 
/// Given the region at `index`, or the parent operation if `index` is None,
/// return the successor regions. These are the regions that may be selected
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void IfOp::getSuccessorRegions(Optional<unsigned> index,
                               ArrayRef<Attribute> operands,
                               SmallVectorImpl<RegionSuccessor> &regions) {
  // The `then` and the `else` region branch back to the parent operation.
  if (index) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  // Don't consider the else region if it is empty.
  Region *elseRegion = &this->getElseRegion();
  if (elseRegion->empty())
    elseRegion = nullptr;

  // Otherwise, the successor is dependent on the condition.
  bool condition;
  if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
    condition = condAttr.getValue().isOneValue();
  } else {
    // If the condition isn't constant, both regions may be executed.
    regions.push_back(RegionSuccessor(&getThenRegion()));
    // If the else region does not exist, it is not a viable successor.
    if (elseRegion)
      regions.push_back(RegionSuccessor(elseRegion));
    return;
  }

  // Add the successor regions using the condition.
  // NOTE(review): when the condition is constant-false and there is no else
  // region, this pushes a RegionSuccessor wrapping a null region — verify
  // this matches the RegionBranchOpInterface contract for "no region runs".
  regions.push_back(RegionSuccessor(condition ? &getThenRegion() : elseRegion));
}
1555 
/// Folds `if (xor(%c, true)) A else B` into `if (%c) B else A` by rewriting
/// the condition operand in place and swapping the two region bodies.
LogicalResult IfOp::fold(ArrayRef<Attribute> operands,
                         SmallVectorImpl<OpFoldResult> &results) {
  // if (!c) then A() else B() -> if c then B() else A()
  if (getElseRegion().empty())
    return failure();

  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
  if (!xorStmt)
    return failure();

  // Only the canonical negation form xor(%c, 1) is matched.
  if (!matchPattern(xorStmt.getRhs(), m_One()))
    return failure();

  // Drop the negation and use its input as the condition directly.
  getConditionMutable().assign(xorStmt.getLhs());
  Block *thenBlock = &getThenRegion().front();
  // It would be nicer to use iplist::swap, but that has no implemented
  // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
  // The pair of splices swaps the region bodies: first move the else blocks
  // in front of the then block, then move the saved (original) then block
  // over to the else region. Assumes single-block regions — TODO confirm
  // against the scf.if region constraints.
  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
                                     getElseRegion().getBlocks());
  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
                                     getThenRegion().getBlocks(), thenBlock);
  return success();
}
1579 
1580 void IfOp::getRegionInvocationBounds(
1581     ArrayRef<Attribute> operands,
1582     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1583   if (auto cond = operands[0].dyn_cast_or_null<BoolAttr>()) {
1584     // If the condition is known, then one region is known to be executed once
1585     // and the other zero times.
1586     invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
1587     invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
1588   } else {
1589     // Non-constant condition. Each region may be executed 0 or 1 times.
1590     invocationBounds.assign(2, {0, 1});
1591   }
1592 }
1593 
1594 namespace {
/// Pattern to remove unused IfOp results: rebuilds the op with only the
/// result types that have uses and trims both yields accordingly.
struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  /// Moves the body of `source` into `dest` and shrinks the terminating
  /// yield so it forwards only the operands matching `usedResults`.
  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
                    PatternRewriter &rewriter) const {
    // Move all operations to the destination block.
    rewriter.mergeBlocks(source, dest);
    // Replace the yield op by one that returns only the used values.
    auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
    SmallVector<Value, 4> usedOperands;
    llvm::transform(usedResults, std::back_inserter(usedOperands),
                    [&](OpResult result) {
                      return yieldOp.getOperand(result.getResultNumber());
                    });
    rewriter.updateRootInPlace(yieldOp,
                               [&]() { yieldOp->setOperands(usedOperands); });
  }

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    // Compute the list of used results.
    SmallVector<OpResult, 4> usedResults;
    llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
                  [](OpResult result) { return !result.use_empty(); });

    // Replace the operation if only a subset of its results have uses.
    if (usedResults.size() == op.getNumResults())
      return failure();

    // Compute the result types of the replacement operation.
    SmallVector<Type, 4> newTypes;
    llvm::transform(usedResults, std::back_inserter(newTypes),
                    [](OpResult result) { return result.getType(); });

    // Create a replacement operation with empty then and else regions.
    auto emptyBuilder = [](OpBuilder &, Location) {};
    auto newOp = rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition(),
                                       emptyBuilder, emptyBuilder);

    // Move the bodies and replace the terminators (note there is a then and
    // an else region since the operation returns results).
    transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
    transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);

    // Replace the operation by the new one. Unused result slots stay null,
    // which is fine since they have no uses to rewrite.
    SmallVector<Value, 4> repResults(op.getNumResults());
    for (const auto &en : llvm::enumerate(usedResults))
      repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
    rewriter.replaceOp(op, repResults);
    return success();
  }
};
1648 
1649 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
1650   using OpRewritePattern<IfOp>::OpRewritePattern;
1651 
1652   LogicalResult matchAndRewrite(IfOp op,
1653                                 PatternRewriter &rewriter) const override {
1654     auto constant = op.getCondition().getDefiningOp<arith::ConstantOp>();
1655     if (!constant)
1656       return failure();
1657 
1658     if (constant.getValue().cast<BoolAttr>().getValue())
1659       replaceOpWithRegion(rewriter, op, op.getThenRegion());
1660     else if (!op.getElseRegion().empty())
1661       replaceOpWithRegion(rewriter, op, op.getElseRegion());
1662     else
1663       rewriter.eraseOp(op);
1664 
1665     return success();
1666   }
1667 };
1668 
/// Hoist any yielded results whose operands are defined outside
/// the if, to a select instruction.
struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() == 0)
      return failure();

    auto cond = op.getCondition();
    auto thenYieldArgs = op.thenYield().getOperands();
    auto elseYieldArgs = op.elseYield().getOperands();

    // First pass: collect the types of results that must stay on the if
    // because at least one of the yielded values is computed inside it.
    SmallVector<Type> nonHoistable;
    for (const auto &it :
         llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
      Value trueVal = std::get<0>(it.value());
      Value falseVal = std::get<1>(it.value());
      if (&op.getThenRegion() == trueVal.getParentRegion() ||
          &op.getElseRegion() == falseVal.getParentRegion())
        nonHoistable.push_back(trueVal.getType());
    }
    // Early exit if there aren't any yielded values we can
    // hoist outside the if.
    if (nonHoistable.size() == op->getNumResults())
      return failure();

    // Rebuild the if with only the non-hoistable result types, stealing the
    // original region bodies (the stale yields are rewritten below).
    IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond);
    if (replacement.thenBlock())
      rewriter.eraseBlock(replacement.thenBlock());
    replacement.getThenRegion().takeBody(op.getThenRegion());
    replacement.getElseRegion().takeBody(op.getElseRegion());

    SmallVector<Value> results(op->getNumResults());
    assert(thenYieldArgs.size() == results.size());
    assert(elseYieldArgs.size() == results.size());

    // Second pass: route each original result either to the new if (when it
    // cannot be hoisted), directly to the common value, or to a select.
    SmallVector<Value> trueYields;
    SmallVector<Value> falseYields;
    rewriter.setInsertionPoint(replacement);
    for (const auto &it :
         llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
      Value trueVal = std::get<0>(it.value());
      Value falseVal = std::get<1>(it.value());
      if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
          &replacement.getElseRegion() == falseVal.getParentRegion()) {
        results[it.index()] = replacement.getResult(trueYields.size());
        trueYields.push_back(trueVal);
        falseYields.push_back(falseVal);
      } else if (trueVal == falseVal)
        // Same value on both paths: no select needed.
        results[it.index()] = trueVal;
      else
        results[it.index()] = rewriter.create<arith::SelectOp>(
            op.getLoc(), cond, trueVal, falseVal);
    }

    // Rewrite both terminators to yield only the retained values.
    rewriter.setInsertionPointToEnd(replacement.thenBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);

    rewriter.setInsertionPointToEnd(replacement.elseBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);

    rewriter.replaceOp(op, results);
    return success();
  }
};
1736 
/// Allow the true region of an if to assume the condition is true
/// and vice versa. For example:
///
///   scf.if %cmp {
///      print(%cmp)
///   }
///
///  becomes
///
///   scf.if %cmp {
///      print(true)
///   }
///
struct ConditionPropagation : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    // Early exit if the condition is constant since replacing a constant
    // in the body with another constant isn't a simplification.
    if (op.getCondition().getDefiningOp<arith::ConstantOp>())
      return failure();

    bool changed = false;
    mlir::Type i1Ty = rewriter.getI1Type();

    // These variables serve to prevent creating duplicate constants
    // and hold constant true or false values.
    Value constantTrue = nullptr;
    Value constantFalse = nullptr;

    // Walk every use of the condition; early-inc iteration keeps the loop
    // valid while `use.set` rewrites the use list underneath us.
    for (OpOperand &use :
         llvm::make_early_inc_range(op.getCondition().getUses())) {
      // Uses nested (at any depth) inside the then-region can assume true.
      if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
        changed = true;

        if (!constantTrue)
          constantTrue = rewriter.create<arith::ConstantOp>(
              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));

        rewriter.updateRootInPlace(use.getOwner(),
                                   [&]() { use.set(constantTrue); });
      } else if (op.getElseRegion().isAncestor(
                     use.getOwner()->getParentRegion())) {
        // Uses inside the else-region can assume false.
        changed = true;

        if (!constantFalse)
          constantFalse = rewriter.create<arith::ConstantOp>(
              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));

        rewriter.updateRootInPlace(use.getOwner(),
                                   [&]() { use.set(constantFalse); });
      }
    }

    return success(changed);
  }
};
1795 
1796 /// Remove any statements from an if that are equivalent to the condition
1797 /// or its negation. For example:
1798 ///
1799 ///    %res:2 = scf.if %cmp {
1800 ///       yield something(), true
1801 ///    } else {
1802 ///       yield something2(), false
1803 ///    }
1804 ///    print(%res#1)
1805 ///
1806 ///  becomes
1807 ///    %res = scf.if %cmp {
1808 ///       yield something()
1809 ///    } else {
1810 ///       yield something2()
1811 ///    }
1812 ///    print(%cmp)
1813 ///
1814 /// Additionally if both branches yield the same value, replace all uses
1815 /// of the result with the yielded value.
1816 ///
1817 ///    %res:2 = scf.if %cmp {
1818 ///       yield something(), %arg1
1819 ///    } else {
1820 ///       yield something2(), %arg1
1821 ///    }
1822 ///    print(%res#1)
1823 ///
1824 ///  becomes
1825 ///    %res = scf.if %cmp {
1826 ///       yield something()
1827 ///    } else {
1828 ///       yield something2()
1829 ///    }
1830 ///    print(%arg1)
1831 ///
1832 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
1833   using OpRewritePattern<IfOp>::OpRewritePattern;
1834 
1835   LogicalResult matchAndRewrite(IfOp op,
1836                                 PatternRewriter &rewriter) const override {
1837     // Early exit if there are no results that could be replaced.
1838     if (op.getNumResults() == 0)
1839       return failure();
1840 
1841     auto trueYield =
1842         cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
1843     auto falseYield =
1844         cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
1845 
1846     rewriter.setInsertionPoint(op->getBlock(),
1847                                op.getOperation()->getIterator());
1848     bool changed = false;
1849     Type i1Ty = rewriter.getI1Type();
1850     for (auto tup : llvm::zip(trueYield.getResults(), falseYield.getResults(),
1851                               op.getResults())) {
1852       Value trueResult, falseResult, opResult;
1853       std::tie(trueResult, falseResult, opResult) = tup;
1854 
1855       if (trueResult == falseResult) {
1856         if (!opResult.use_empty()) {
1857           opResult.replaceAllUsesWith(trueResult);
1858           changed = true;
1859         }
1860         continue;
1861       }
1862 
1863       auto trueYield = trueResult.getDefiningOp<arith::ConstantOp>();
1864       if (!trueYield)
1865         continue;
1866 
1867       if (!trueYield.getType().isInteger(1))
1868         continue;
1869 
1870       auto falseYield = falseResult.getDefiningOp<arith::ConstantOp>();
1871       if (!falseYield)
1872         continue;
1873 
1874       bool trueVal = trueYield.getValue().cast<BoolAttr>().getValue();
1875       bool falseVal = falseYield.getValue().cast<BoolAttr>().getValue();
1876       if (!trueVal && falseVal) {
1877         if (!opResult.use_empty()) {
1878           Value notCond = rewriter.create<arith::XOrIOp>(
1879               op.getLoc(), op.getCondition(),
1880               rewriter.create<arith::ConstantOp>(
1881                   op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)));
1882           opResult.replaceAllUsesWith(notCond);
1883           changed = true;
1884         }
1885       }
1886       if (trueVal && !falseVal) {
1887         if (!opResult.use_empty()) {
1888           opResult.replaceAllUsesWith(op.getCondition());
1889           changed = true;
1890         }
1891       }
1892     }
1893     return success(changed);
1894   }
1895 };
1896 
/// Merge any consecutive scf.if's with the same condition.
///
///    scf.if %cond {
///       firstCodeTrue();...
///    } else {
///       firstCodeFalse();...
///    }
///    %res = scf.if %cond {
///       secondCodeTrue();...
///    } else {
///       secondCodeFalse();...
///    }
///
///  becomes
///    %res = scf.if %cmp {
///       firstCodeTrue();...
///       secondCodeTrue();...
///    } else {
///       firstCodeFalse();...
///       secondCodeFalse();...
///    }
struct CombineIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp nextIf,
                                PatternRewriter &rewriter) const override {
    Block *parent = nextIf->getBlock();
    if (nextIf == &parent->front())
      return failure();

    // Only fires when the immediately preceding op is also an scf.if.
    auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
    if (!prevIf)
      return failure();

    // Determine the logical then/else blocks when prevIf's
    // condition is used. Null means the block does not exist
    // in that case (e.g. empty else). If neither of these
    // are set, the two conditions cannot be compared.
    Block *nextThen = nullptr;
    Block *nextElse = nullptr;
    if (nextIf.getCondition() == prevIf.getCondition()) {
      nextThen = nextIf.thenBlock();
      if (!nextIf.getElseRegion().empty())
        nextElse = nextIf.elseBlock();
    }
    // nextIf's condition is xor(prevIf's condition, true): the roles of
    // nextIf's blocks are swapped relative to prevIf's condition.
    if (arith::XOrIOp notv =
            nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == prevIf.getCondition() &&
          matchPattern(notv.getRhs(), m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }
    // Symmetric case: prevIf's condition is xor(nextIf's condition, true).
    if (arith::XOrIOp notv =
            prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == nextIf.getCondition() &&
          matchPattern(notv.getRhs(), m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }

    if (!nextThen && !nextElse)
      return failure();

    SmallVector<Value> prevElseYielded;
    if (!prevIf.getElseRegion().empty())
      prevElseYielded = prevIf.elseYield().getOperands();
    // Replace all uses of return values of op within nextIf with the
    // corresponding yields
    for (auto it : llvm::zip(prevIf.getResults(),
                             prevIf.thenYield().getOperands(), prevElseYielded))
      for (OpOperand &use :
           llvm::make_early_inc_range(std::get<0>(it).getUses())) {
        if (nextThen && nextThen->getParent()->isAncestor(
                            use.getOwner()->getParentRegion())) {
          rewriter.startRootUpdate(use.getOwner());
          use.set(std::get<1>(it));
          rewriter.finalizeRootUpdate(use.getOwner());
        } else if (nextElse && nextElse->getParent()->isAncestor(
                                   use.getOwner()->getParentRegion())) {
          rewriter.startRootUpdate(use.getOwner());
          use.set(std::get<2>(it));
          rewriter.finalizeRootUpdate(use.getOwner());
        }
      }

    // The combined if yields prevIf's results followed by nextIf's.
    SmallVector<Type> mergedTypes(prevIf.getResultTypes());
    llvm::append_range(mergedTypes, nextIf.getResultTypes());

    IfOp combinedIf = rewriter.create<IfOp>(
        nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
    // Drop the auto-created then block; prevIf's then region is moved in.
    rewriter.eraseBlock(&combinedIf.getThenRegion().back());

    rewriter.inlineRegionBefore(prevIf.getThenRegion(),
                                combinedIf.getThenRegion(),
                                combinedIf.getThenRegion().begin());

    if (nextThen) {
      // Append nextThen's ops after prevIf's then body and build a single
      // yield carrying both operand lists.
      YieldOp thenYield = combinedIf.thenYield();
      YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
      rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
      rewriter.setInsertionPointToEnd(combinedIf.thenBlock());

      SmallVector<Value> mergedYields(thenYield.getOperands());
      llvm::append_range(mergedYields, thenYield2.getOperands());
      rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
      rewriter.eraseOp(thenYield);
      rewriter.eraseOp(thenYield2);
    }

    rewriter.inlineRegionBefore(prevIf.getElseRegion(),
                                combinedIf.getElseRegion(),
                                combinedIf.getElseRegion().begin());

    if (nextElse) {
      if (combinedIf.getElseRegion().empty()) {
        // prevIf had no else region: adopt nextElse's region wholesale.
        rewriter.inlineRegionBefore(*nextElse->getParent(),
                                    combinedIf.getElseRegion(),
                                    combinedIf.getElseRegion().begin());
      } else {
        // Both have else bodies: concatenate them and merge the yields.
        YieldOp elseYield = combinedIf.elseYield();
        YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
        rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());

        rewriter.setInsertionPointToEnd(combinedIf.elseBlock());

        SmallVector<Value> mergedElseYields(elseYield.getOperands());
        llvm::append_range(mergedElseYields, elseYield2.getOperands());

        rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
        rewriter.eraseOp(elseYield);
        rewriter.eraseOp(elseYield2);
      }
    }

    // Split the combined results back into replacements for the two ops.
    SmallVector<Value> prevValues;
    SmallVector<Value> nextValues;
    for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
      if (pair.index() < prevIf.getNumResults())
        prevValues.push_back(pair.value());
      else
        nextValues.push_back(pair.value());
    }
    rewriter.replaceOp(prevIf, prevValues);
    rewriter.replaceOp(nextIf, nextValues);
    return success();
  }
};
2048 
2049 /// Pattern to remove an empty else branch.
2050 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2051   using OpRewritePattern<IfOp>::OpRewritePattern;
2052 
2053   LogicalResult matchAndRewrite(IfOp ifOp,
2054                                 PatternRewriter &rewriter) const override {
2055     // Cannot remove else region when there are operation results.
2056     if (ifOp.getNumResults())
2057       return failure();
2058     Block *elseBlock = ifOp.elseBlock();
2059     if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2060       return failure();
2061     auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2062     rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2063                                 newIfOp.getThenRegion().begin());
2064     rewriter.eraseOp(ifOp);
2065     return success();
2066   }
2067 };
2068 
/// Convert nested `if`s into `arith.andi` + single `if`.
///
///    scf.if %arg0 {
///      scf.if %arg1 {
///        ...
///        scf.yield
///      }
///      scf.yield
///    }
///  becomes
///
///    %0 = arith.andi %arg0, %arg1
///    scf.if %0 {
///      ...
///      scf.yield
///    }
struct CombineNestedIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    auto nestedOps = op.thenBlock()->without_terminator();
    // Nested `if` must be the only op in block.
    if (!llvm::hasSingleElement(nestedOps))
      return failure();

    // If there is an else block, it can only yield
    if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
      return failure();

    auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
    if (!nestedIf)
      return failure();

    // Same restriction on the inner if's else block: yield only.
    if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
      return failure();

    SmallVector<Value> thenYield(op.thenYield().getOperands());
    SmallVector<Value> elseYield;
    if (op.elseBlock())
      llvm::append_range(elseYield, op.elseYield().getOperands());

    // A list of indices for which we should upgrade the value yielded
    // in the else to a select.
    SmallVector<unsigned> elseYieldsToUpgradeToSelect;

    // If the outer scf.if yields a value produced by the inner scf.if,
    // only permit combining if the value yielded when the condition
    // is false in the outer scf.if is the same value yielded when the
    // inner scf.if condition is false.
    // Note that the array access to elseYield will not go out of bounds
    // since it must have the same length as thenYield, since they both
    // come from the same scf.if.
    for (const auto &tup : llvm::enumerate(thenYield)) {
      if (tup.value().getDefiningOp() == nestedIf) {
        auto nestedIdx = tup.value().cast<OpResult>().getResultNumber();
        if (nestedIf.elseYield().getOperand(nestedIdx) !=
            elseYield[tup.index()]) {
          return failure();
        }
        // If the correctness test passes, we will yield
        // corresponding value from the inner scf.if
        thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
        continue;
      }

      // Otherwise, we need to ensure the else block of the combined
      // condition still returns the same value when the outer condition is
      // true and the inner condition is false. This can be accomplished if
      // the then value is defined outside the outer scf.if and we replace the
      // value with a select that considers just the outer condition. Since
      // the else region contains just the yield, its yielded value is
      // defined outside the scf.if, by definition.

      // If the then value is defined within the scf.if, bail.
      if (tup.value().getParentRegion() == &op.getThenRegion()) {
        return failure();
      }
      elseYieldsToUpgradeToSelect.push_back(tup.index());
    }

    // Build the conjoined condition and the replacement if.
    Location loc = op.getLoc();
    Value newCondition = rewriter.create<arith::AndIOp>(
        loc, op.getCondition(), nestedIf.getCondition());
    auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);

    SmallVector<Value> results;
    llvm::append_range(results, newIf.getResults());
    // Selects are created before the new if so they dominate its users.
    rewriter.setInsertionPoint(newIf);

    for (auto idx : elseYieldsToUpgradeToSelect)
      results[idx] = rewriter.create<arith::SelectOp>(
          op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);

    // Move the inner if's then body into the new if and rebuild the yield.
    Block *newIfBlock = newIf.thenBlock();
    if (newIfBlock)
      rewriter.eraseOp(newIfBlock->getTerminator());
    else
      newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
    rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
    rewriter.setInsertionPointToEnd(newIf.thenBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
    if (!elseYield.empty()) {
      rewriter.createBlock(&newIf.getElseRegion());
      rewriter.setInsertionPointToEnd(newIf.elseBlock());
      rewriter.create<YieldOp>(loc, elseYield);
    }
    rewriter.replaceOp(op, results);
    return success();
  }
};
2180 
2181 } // namespace
2182 
2183 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2184                                        MLIRContext *context) {
2185   results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2186               ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2187               RemoveStaticCondition, RemoveUnusedResults,
2188               ReplaceIfYieldWithConditionOrValue>(context);
2189 }
2190 
/// Returns the trailing block of the 'then' region.
Block *IfOp::thenBlock() { return &getThenRegion().back(); }
/// Returns the scf.yield terminating the 'then' block.
YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2193 Block *IfOp::elseBlock() {
2194   Region &r = getElseRegion();
2195   if (r.empty())
2196     return nullptr;
2197   return &r.back();
2198 }
/// Returns the scf.yield terminating the 'else' block. Dereferences
/// elseBlock(): callers must ensure the else region is non-empty.
YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2200 
2201 //===----------------------------------------------------------------------===//
2202 // ParallelOp
2203 //===----------------------------------------------------------------------===//
2204 
/// Builds an scf.parallel op from bounds, steps and reduction init values;
/// `bodyBuilderFn` receives the induction variables and the remaining block
/// arguments to populate the body.
void ParallelOp::build(
    OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
    ValueRange upperBounds, ValueRange steps, ValueRange initVals,
    function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
        bodyBuilderFn) {
  result.addOperands(lowerBounds);
  result.addOperands(upperBounds);
  result.addOperands(steps);
  result.addOperands(initVals);
  // Record how the flat operand list is segmented into the four groups.
  result.addAttribute(
      ParallelOp::getOperandSegmentSizeAttr(),
      builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()),
                                static_cast<int32_t>(upperBounds.size()),
                                static_cast<int32_t>(steps.size()),
                                static_cast<int32_t>(initVals.size())}));
  // One result per reduction init value.
  result.addTypes(initVals.getTypes());

  OpBuilder::InsertionGuard guard(builder);
  // One index-typed block argument per induction variable (one per step).
  unsigned numIVs = steps.size();
  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
  SmallVector<Location, 8> argLocs(numIVs, result.location);
  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);

  if (bodyBuilderFn) {
    builder.setInsertionPointToStart(bodyBlock);
    // Split the block arguments: the leading numIVs are the induction
    // variables, anything after them is forwarded as the second range.
    bodyBuilderFn(builder, result.location,
                  bodyBlock->getArguments().take_front(numIVs),
                  bodyBlock->getArguments().drop_front(numIVs));
  }
  // Guarantee the region is properly terminated even without a body builder.
  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
}
2237 
2238 void ParallelOp::build(
2239     OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2240     ValueRange upperBounds, ValueRange steps,
2241     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2242   // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2243   // we don't capture a reference to a temporary by constructing the lambda at
2244   // function level.
2245   auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2246                                            Location nestedLoc, ValueRange ivs,
2247                                            ValueRange) {
2248     bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2249   };
2250   function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2251   if (bodyBuilderFn)
2252     wrapper = wrappedBuilderFn;
2253 
2254   build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2255         wrapper);
2256 }
2257 
/// Verifies scf.parallel invariants: non-empty bound/step lists, positive
/// constant steps, index-typed induction variables, an operand-less
/// terminator, and matching counts/types between results, reductions, and
/// init values.
LogicalResult ParallelOp::verify() {
  // Check that there is at least one value in lowerBound, upperBound and step.
  // It is sufficient to test only step, because it is ensured already that the
  // number of elements in lowerBound, upperBound and step are the same.
  Operation::operand_range stepValues = getStep();
  if (stepValues.empty())
    return emitOpError(
        "needs at least one tuple element for lowerBound, upperBound and step");

  // Check whether all constant step values are positive.
  // Non-constant steps are not checked here; they may be negative at runtime.
  for (Value stepValue : stepValues)
    if (auto cst = stepValue.getDefiningOp<arith::ConstantIndexOp>())
      if (cst.value() <= 0)
        return emitOpError("constant step operand must be positive");

  // Check that the body defines the same number of block arguments as the
  // number of tuple elements in step.
  Block *body = getBody();
  if (body->getNumArguments() != stepValues.size())
    return emitOpError() << "expects the same number of induction variables: "
                         << body->getNumArguments()
                         << " as bound and step values: " << stepValues.size();
  for (auto arg : body->getArguments())
    if (!arg.getType().isIndex())
      return emitOpError(
          "expects arguments for the induction variable to be of index type");

  // Check that the yield has no results
  Operation *yield = body->getTerminator();
  if (yield->getNumOperands() != 0)
    return yield->emitOpError() << "not allowed to have operands inside '"
                                << ParallelOp::getOperationName() << "'";

  // Check that the number of results is the same as the number of ReduceOps.
  // Each scf.reduce in the body contributes exactly one loop result.
  SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
  auto resultsSize = getResults().size();
  auto reductionsSize = reductions.size();
  auto initValsSize = getInitVals().size();
  if (resultsSize != reductionsSize)
    return emitOpError() << "expects number of results: " << resultsSize
                         << " to be the same as number of reductions: "
                         << reductionsSize;
  if (resultsSize != initValsSize)
    return emitOpError() << "expects number of results: " << resultsSize
                         << " to be the same as number of initial values: "
                         << initValsSize;

  // Check that the types of the results and reductions are the same.
  // Reductions are matched to results in body order.
  for (auto resultAndReduce : llvm::zip(getResults(), reductions)) {
    auto resultType = std::get<0>(resultAndReduce).getType();
    auto reduceOp = std::get<1>(resultAndReduce);
    auto reduceType = reduceOp.getOperand().getType();
    if (resultType != reduceType)
      return reduceOp.emitOpError()
             << "expects type of reduce: " << reduceType
             << " to be the same as result type: " << resultType;
  }
  return success();
}
2317 
2318 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2319   auto &builder = parser.getBuilder();
2320   // Parse an opening `(` followed by induction variables followed by `)`
2321   SmallVector<OpAsmParser::Argument, 4> ivs;
2322   if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
2323     return failure();
2324 
2325   // Parse loop bounds.
2326   SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2327   if (parser.parseEqual() ||
2328       parser.parseOperandList(lower, ivs.size(),
2329                               OpAsmParser::Delimiter::Paren) ||
2330       parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2331     return failure();
2332 
2333   SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2334   if (parser.parseKeyword("to") ||
2335       parser.parseOperandList(upper, ivs.size(),
2336                               OpAsmParser::Delimiter::Paren) ||
2337       parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2338     return failure();
2339 
2340   // Parse step values.
2341   SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2342   if (parser.parseKeyword("step") ||
2343       parser.parseOperandList(steps, ivs.size(),
2344                               OpAsmParser::Delimiter::Paren) ||
2345       parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2346     return failure();
2347 
2348   // Parse init values.
2349   SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2350   if (succeeded(parser.parseOptionalKeyword("init"))) {
2351     if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2352       return failure();
2353   }
2354 
2355   // Parse optional results in case there is a reduce.
2356   if (parser.parseOptionalArrowTypeList(result.types))
2357     return failure();
2358 
2359   // Now parse the body.
2360   Region *body = result.addRegion();
2361   for (auto &iv : ivs)
2362     iv.type = builder.getIndexType();
2363   if (parser.parseRegion(*body, ivs))
2364     return failure();
2365 
2366   // Set `operand_segment_sizes` attribute.
2367   result.addAttribute(
2368       ParallelOp::getOperandSegmentSizeAttr(),
2369       builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
2370                                 static_cast<int32_t>(upper.size()),
2371                                 static_cast<int32_t>(steps.size()),
2372                                 static_cast<int32_t>(initVals.size())}));
2373 
2374   // Parse attributes.
2375   if (parser.parseOptionalAttrDict(result.attributes) ||
2376       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2377                              result.operands))
2378     return failure();
2379 
2380   // Add a terminator if none was parsed.
2381   ForOp::ensureTerminator(*body, builder, result.location);
2382   return success();
2383 }
2384 
/// Prints scf.parallel as:
///   (%ivs) = (%lbs) to (%ubs) step (%steps) [init (%inits)] [-> types]
///   region [attr-dict]
void ParallelOp::print(OpAsmPrinter &p) {
  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
    << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
  if (!getInitVals().empty())
    p << " init (" << getInitVals() << ")";
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  // Entry block args are elided: they already appear in the header above.
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
  // The segment sizes attribute is implied by the printed operand groups.
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
}
2397 
2398 Region &ParallelOp::getLoopBody() { return getRegion(); }
2399 
2400 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
2401   auto ivArg = val.dyn_cast<BlockArgument>();
2402   if (!ivArg)
2403     return ParallelOp();
2404   assert(ivArg.getOwner() && "unlinked block argument");
2405   auto *containingOp = ivArg.getOwner()->getParentOp();
2406   return dyn_cast<ParallelOp>(containingOp);
2407 }
2408 
2409 namespace {
/// Collapses scf.parallel dimensions that are statically known to execute
/// exactly one iteration (0 < ub - lb <= step). If every dimension collapses,
/// the loop body, together with any nested scf.reduce bodies, is inlined in
/// place of the loop.
struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
  using OpRewritePattern<ParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelOp op,
                                PatternRewriter &rewriter) const override {
    BlockAndValueMapping mapping;
    // Compute new loop bounds that omit all single-iteration loop dimensions.
    SmallVector<Value, 2> newLowerBounds;
    SmallVector<Value, 2> newUpperBounds;
    SmallVector<Value, 2> newSteps;
    newLowerBounds.reserve(op.getLowerBound().size());
    newUpperBounds.reserve(op.getUpperBound().size());
    newSteps.reserve(op.getStep().size());
    for (auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound(),
                              op.getStep(), op.getInductionVars())) {
      Value lowerBound, upperBound, step, iv;
      std::tie(lowerBound, upperBound, step, iv) = dim;
      // Collect the statically known loop bounds.
      auto lowerBoundConstant =
          dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
      auto upperBoundConstant =
          dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
      auto stepConstant =
          dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
      // Replace the loop induction variable by the lower bound if the loop
      // performs a single iteration. Otherwise, copy the loop bounds.
      if (lowerBoundConstant && upperBoundConstant && stepConstant &&
          (upperBoundConstant.value() - lowerBoundConstant.value()) > 0 &&
          (upperBoundConstant.value() - lowerBoundConstant.value()) <=
              stepConstant.value()) {
        mapping.map(iv, lowerBound);
      } else {
        newLowerBounds.push_back(lowerBound);
        newUpperBounds.push_back(upperBound);
        newSteps.push_back(step);
      }
    }
    // Exit if none of the loop dimensions perform a single iteration.
    if (newLowerBounds.size() == op.getLowerBound().size())
      return failure();

    if (newLowerBounds.empty()) {
      // All of the loop dimensions perform a single iteration. Inline
      // loop body and nested ReduceOp's
      SmallVector<Value> results;
      results.reserve(op.getInitVals().size());
      for (auto &bodyOp : op.getLoopBody().front().without_terminator()) {
        auto reduce = dyn_cast<ReduceOp>(bodyOp);
        if (!reduce) {
          rewriter.clone(bodyOp, mapping);
          continue;
        }
        // Inline the reduction body: block argument 0 is the accumulator
        // (which, for a single iteration, is the corresponding init value) and
        // argument 1 is the value being reduced.
        Block &reduceBlock = reduce.getReductionOperator().front();
        auto initValIndex = results.size();
        mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
        mapping.map(reduceBlock.getArgument(1),
                    mapping.lookupOrDefault(reduce.getOperand()));
        for (auto &reduceBodyOp : reduceBlock.without_terminator())
          rewriter.clone(reduceBodyOp, mapping);

        // The value fed to scf.reduce.return becomes the loop result.
        auto result = mapping.lookupOrDefault(
            cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
        results.push_back(result);
      }
      rewriter.replaceOp(op, results);
      return success();
    }
    // Replace the parallel loop by lower-dimensional parallel loop.
    auto newOp =
        rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
                                    newSteps, op.getInitVals(), nullptr);
    // Clone the loop body and remap the block arguments of the collapsed loops
    // (inlining does not support a cancellable block argument mapping).
    rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
                               newOp.getRegion().begin(), mapping);
    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
};
2490 
2491 /// Removes parallel loops in which at least one lower/upper bound pair consists
2492 /// of the same values - such loops have an empty iteration domain.
2493 struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
2494   using OpRewritePattern<ParallelOp>::OpRewritePattern;
2495 
2496   LogicalResult matchAndRewrite(ParallelOp op,
2497                                 PatternRewriter &rewriter) const override {
2498     for (auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound())) {
2499       if (std::get<0>(dim) == std::get<1>(dim)) {
2500         rewriter.replaceOp(op, op.getInitVals());
2501         return success();
2502       }
2503     }
2504     return failure();
2505   }
2506 };
2507 
/// Fuses two perfectly nested scf.parallel loops into one loop over the
/// concatenated iteration space. Applies only when the outer body contains
/// nothing but the inner loop (plus terminator), the inner bounds/steps do not
/// depend on the outer induction variables, and neither loop has reductions.
struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
  using OpRewritePattern<ParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelOp op,
                                PatternRewriter &rewriter) const override {
    Block &outerBody = op.getLoopBody().front();
    // The inner loop must be the only non-terminator op in the outer body.
    if (!llvm::hasSingleElement(outerBody.without_terminator()))
      return failure();

    auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
    if (!innerOp)
      return failure();

    // The inner bounds/steps may not use the outer induction variables; they
    // must be hoistable into the fused loop header.
    for (auto val : outerBody.getArguments())
      if (llvm::is_contained(innerOp.getLowerBound(), val) ||
          llvm::is_contained(innerOp.getUpperBound(), val) ||
          llvm::is_contained(innerOp.getStep(), val))
        return failure();

    // Reductions are not supported yet.
    if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
      return failure();

    // Body builder for the fused loop: clones the inner body while remapping
    // both loops' induction variables onto the fused loop's arguments (outer
    // IVs first, then inner IVs).
    auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
                           ValueRange iterVals, ValueRange) {
      Block &innerBody = innerOp.getLoopBody().front();
      assert(iterVals.size() ==
             (outerBody.getNumArguments() + innerBody.getNumArguments()));
      BlockAndValueMapping mapping;
      mapping.map(outerBody.getArguments(),
                  iterVals.take_front(outerBody.getNumArguments()));
      mapping.map(innerBody.getArguments(),
                  iterVals.take_back(innerBody.getNumArguments()));
      for (Operation &op : innerBody.without_terminator())
        builder.clone(op, mapping);
    };

    // Concatenates two value ranges into a single vector.
    auto concatValues = [](const auto &first, const auto &second) {
      SmallVector<Value> ret;
      ret.reserve(first.size() + second.size());
      ret.assign(first.begin(), first.end());
      ret.append(second.begin(), second.end());
      return ret;
    };

    auto newLowerBounds =
        concatValues(op.getLowerBound(), innerOp.getLowerBound());
    auto newUpperBounds =
        concatValues(op.getUpperBound(), innerOp.getUpperBound());
    auto newSteps = concatValues(op.getStep(), innerOp.getStep());

    rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
                                            newSteps, llvm::None, bodyBuilder);
    return success();
  }
};
2564 
2565 } // namespace
2566 
2567 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
2568                                              MLIRContext *context) {
2569   results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops,
2570               MergeNestedParallelLoops>(context);
2571 }
2572 
2573 //===----------------------------------------------------------------------===//
2574 // ReduceOp
2575 //===----------------------------------------------------------------------===//
2576 
2577 void ReduceOp::build(
2578     OpBuilder &builder, OperationState &result, Value operand,
2579     function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuilderFn) {
2580   auto type = operand.getType();
2581   result.addOperands(operand);
2582 
2583   OpBuilder::InsertionGuard guard(builder);
2584   Region *bodyRegion = result.addRegion();
2585   Block *body = builder.createBlock(bodyRegion, {}, ArrayRef<Type>{type, type},
2586                                     {result.location, result.location});
2587   if (bodyBuilderFn)
2588     bodyBuilderFn(builder, result.location, body->getArgument(0),
2589                   body->getArgument(1));
2590 }
2591 
2592 LogicalResult ReduceOp::verifyRegions() {
2593   // The region of a ReduceOp has two arguments of the same type as its operand.
2594   auto type = getOperand().getType();
2595   Block &block = getReductionOperator().front();
2596   if (block.empty())
2597     return emitOpError("the block inside reduce should not be empty");
2598   if (block.getNumArguments() != 2 ||
2599       llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
2600         return arg.getType() != type;
2601       }))
2602     return emitOpError() << "expects two arguments to reduce block of type "
2603                          << type;
2604 
2605   // Check that the block is terminated by a ReduceReturnOp.
2606   if (!isa<ReduceReturnOp>(block.getTerminator()))
2607     return emitOpError("the block inside reduce should be terminated with a "
2608                        "'scf.reduce.return' op");
2609 
2610   return success();
2611 }
2612 
2613 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
2614   // Parse an opening `(` followed by the reduced value followed by `)`
2615   OpAsmParser::UnresolvedOperand operand;
2616   if (parser.parseLParen() || parser.parseOperand(operand) ||
2617       parser.parseRParen())
2618     return failure();
2619 
2620   Type resultType;
2621   // Parse the type of the operand (and also what reduce computes on).
2622   if (parser.parseColonType(resultType) ||
2623       parser.resolveOperand(operand, resultType, result.operands))
2624     return failure();
2625 
2626   // Now parse the body.
2627   Region *body = result.addRegion();
2628   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
2629     return failure();
2630 
2631   return success();
2632 }
2633 
2634 void ReduceOp::print(OpAsmPrinter &p) {
2635   p << "(" << getOperand() << ") ";
2636   p << " : " << getOperand().getType() << ' ';
2637   p.printRegion(getReductionOperator());
2638 }
2639 
2640 //===----------------------------------------------------------------------===//
2641 // ReduceReturnOp
2642 //===----------------------------------------------------------------------===//
2643 
2644 LogicalResult ReduceReturnOp::verify() {
2645   // The type of the return value should be the same type as the type of the
2646   // operand of the enclosing ReduceOp.
2647   auto reduceOp = cast<ReduceOp>((*this)->getParentOp());
2648   Type reduceType = reduceOp.getOperand().getType();
2649   if (reduceType != getResult().getType())
2650     return emitOpError() << "needs to have type " << reduceType
2651                          << " (the type of the enclosing ReduceOp)";
2652   return success();
2653 }
2654 
2655 //===----------------------------------------------------------------------===//
2656 // WhileOp
2657 //===----------------------------------------------------------------------===//
2658 
2659 OperandRange WhileOp::getSuccessorEntryOperands(Optional<unsigned> index) {
2660   assert(index && *index == 0 &&
2661          "WhileOp is expected to branch only to the first region");
2662 
2663   return getInits();
2664 }
2665 
2666 ConditionOp WhileOp::getConditionOp() {
2667   return cast<ConditionOp>(getBefore().front().getTerminator());
2668 }
2669 
2670 YieldOp WhileOp::getYieldOp() {
2671   return cast<YieldOp>(getAfter().front().getTerminator());
2672 }
2673 
2674 Block::BlockArgListType WhileOp::getBeforeArguments() {
2675   return getBefore().front().getArguments();
2676 }
2677 
2678 Block::BlockArgListType WhileOp::getAfterArguments() {
2679   return getAfter().front().getArguments();
2680 }
2681 
/// RegionBranchOpInterface hook: reports which regions (or the parent's
/// results) may be reached from the region identified by `index`, narrowing
/// the successor set when the condition operand is a known constant.
void WhileOp::getSuccessorRegions(Optional<unsigned> index,
                                  ArrayRef<Attribute> operands,
                                  SmallVectorImpl<RegionSuccessor> &regions) {
  // The parent op always branches to the condition region.
  if (!index) {
    regions.emplace_back(&getBefore(), getBefore().getArguments());
    return;
  }

  assert(*index < 2 && "there are only two regions in a WhileOp");
  // The body region always branches back to the condition region.
  if (*index == 1) {
    regions.emplace_back(&getBefore(), getBefore().getArguments());
    return;
  }

  // Try to narrow the successor to the condition region.
  assert(!operands.empty() && "expected at least one operand");
  auto cond = operands[0].dyn_cast_or_null<BoolAttr>();
  // With an unknown or false condition the loop may exit to the parent's
  // results; with an unknown or true condition it may enter the body.
  if (!cond || !cond.getValue())
    regions.emplace_back(getResults());
  if (!cond || cond.getValue())
    regions.emplace_back(&getAfter(), getAfter().getArguments());
}
2706 
/// Parses a `while` op.
///
/// op ::= `scf.while` initializer `:` function-type region `do` region
///         (`attributes` attribute-dict)?
/// initializer ::= /* empty */ | `(` assignment-list `)`
/// assignment-list ::= assignment | assignment `,` assignment-list
/// assignment ::= ssa-value `=` ssa-value
ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  Region *before = result.addRegion();
  Region *after = result.addRegion();

  // The assignment list is optional; a loop without iter args omits it.
  OptionalParseResult listResult =
      parser.parseOptionalAssignmentList(regionArgs, operands);
  if (listResult.hasValue() && failed(listResult.getValue()))
    return failure();

  // The function type provides both the init operand types (inputs) and the
  // op result types (results).
  FunctionType functionType;
  SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseColonType(functionType)))
    return failure();

  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(typeLoc)
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  }

  // Resolve input operands.
  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
                                    parser.getCurrentLocation(),
                                    result.operands)))
    return failure();

  // Propagate the types into the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  // Only the "before" region reuses the parsed assignment targets as entry
  // block arguments; the "after" region declares its own.
  return failure(parser.parseRegion(*before, regionArgs) ||
                 parser.parseKeyword("do") || parser.parseRegion(*after) ||
                 parser.parseOptionalAttrDictWithKeyword(result.attributes));
}
2753 
/// Prints a `while` op. The "before" region's entry block arguments are
/// elided because they already appear in the initialization list.
void scf::WhileOp::print(OpAsmPrinter &p) {
  printInitializationList(p, getBefore().front().getArguments(), getInits(),
                          " ");
  p << " : ";
  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
  p << ' ';
  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
  p << " do ";
  p.printRegion(getAfter());
  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
}
2766 
2767 /// Verifies that two ranges of types match, i.e. have the same number of
2768 /// entries and that types are pairwise equals. Reports errors on the given
2769 /// operation in case of mismatch.
2770 template <typename OpTy>
2771 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
2772                                            TypeRange right, StringRef message) {
2773   if (left.size() != right.size())
2774     return op.emitOpError("expects the same number of ") << message;
2775 
2776   for (unsigned i = 0, e = left.size(); i < e; ++i) {
2777     if (left[i] != right[i]) {
2778       InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
2779                                 << message;
2780       diag.attachNote() << "for argument " << i << ", found " << left[i]
2781                         << " and " << right[i];
2782       return diag;
2783     }
2784   }
2785 
2786   return success();
2787 }
2788 
2789 /// Verifies that the first block of the given `region` is terminated by a
2790 /// YieldOp. Reports errors on the given operation if it is not the case.
2791 template <typename TerminatorTy>
2792 static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region &region,
2793                                            StringRef errorMessage) {
2794   Operation *terminatorOperation = region.front().getTerminator();
2795   if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
2796     return yield;
2797 
2798   auto diag = op.emitOpError(errorMessage);
2799   if (terminatorOperation)
2800     diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
2801   return nullptr;
2802 }
2803 
2804 LogicalResult scf::WhileOp::verify() {
2805   auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
2806       *this, getBefore(),
2807       "expects the 'before' region to terminate with 'scf.condition'");
2808   if (!beforeTerminator)
2809     return failure();
2810 
2811   auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
2812       *this, getAfter(),
2813       "expects the 'after' region to terminate with 'scf.yield'");
2814   return success(afterTerminator != nullptr);
2815 }
2816 
2817 namespace {
2818 /// Replace uses of the condition within the do block with true, since otherwise
2819 /// the block would not be evaluated.
2820 ///
2821 /// scf.while (..) : (i1, ...) -> ... {
2822 ///  %condition = call @evaluate_condition() : () -> i1
2823 ///  scf.condition(%condition) %condition : i1, ...
2824 /// } do {
2825 /// ^bb0(%arg0: i1, ...):
2826 ///    use(%arg0)
2827 ///    ...
2828 ///
2829 /// becomes
2830 /// scf.while (..) : (i1, ...) -> ... {
2831 ///  %condition = call @evaluate_condition() : () -> i1
2832 ///  scf.condition(%condition) %condition : i1, ...
2833 /// } do {
2834 /// ^bb0(%arg0: i1, ...):
2835 ///    use(%true)
2836 ///    ...
2837 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
2838   using OpRewritePattern<WhileOp>::OpRewritePattern;
2839 
2840   LogicalResult matchAndRewrite(WhileOp op,
2841                                 PatternRewriter &rewriter) const override {
2842     auto term = op.getConditionOp();
2843 
2844     // These variables serve to prevent creating duplicate constants
2845     // and hold constant true or false values.
2846     Value constantTrue = nullptr;
2847 
2848     bool replaced = false;
2849     for (auto yieldedAndBlockArgs :
2850          llvm::zip(term.getArgs(), op.getAfterArguments())) {
2851       if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
2852         if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
2853           if (!constantTrue)
2854             constantTrue = rewriter.create<arith::ConstantOp>(
2855                 op.getLoc(), term.getCondition().getType(),
2856                 rewriter.getBoolAttr(true));
2857 
2858           std::get<1>(yieldedAndBlockArgs).replaceAllUsesWith(constantTrue);
2859           replaced = true;
2860         }
2861       }
2862     }
2863     return success(replaced);
2864   }
2865 };
2866 
2867 /// Remove loop invariant arguments from `before` block of scf.while.
2868 /// A before block argument is considered loop invariant if :-
2869 ///   1. i-th yield operand is equal to the i-th while operand.
2870 ///   2. i-th yield operand is k-th after block argument which is (k+1)-th
2871 ///      condition operand AND this (k+1)-th condition operand is equal to i-th
2872 ///      iter argument/while operand.
2873 /// For the arguments which are removed, their uses inside scf.while
2874 /// are replaced with their corresponding initial value.
2875 ///
2876 /// Eg:
2877 ///    INPUT :-
2878 ///    %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
2879 ///                                     ..., %argN_before = %N)
2880 ///           {
2881 ///                ...
2882 ///                scf.condition(%cond) %arg1_before, %arg0_before,
2883 ///                                     %arg2_before, %arg0_before, ...
2884 ///           } do {
2885 ///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
2886 ///                  ..., %argK_after):
2887 ///                ...
2888 ///                scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
2889 ///           }
2890 ///
2891 ///    OUTPUT :-
2892 ///    %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
2893 ///                                     %N)
2894 ///           {
2895 ///                ...
2896 ///                scf.condition(%cond) %b, %a, %arg2_before, %a, ...
2897 ///           } do {
2898 ///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
2899 ///                  ..., %argK_after):
2900 ///                ...
2901 ///                scf.yield %arg1_after, ..., %argN
2902 ///           }
2903 ///
2904 ///    EXPLANATION:
2905 ///      We iterate over each yield operand.
2906 ///        1. 0-th yield operand %arg0_after_2 is 4-th condition operand
2907 ///           %arg0_before, which in turn is the 0-th iter argument. So we
2908 ///           remove 0-th before block argument and yield operand, and replace
2909 ///           all uses of the 0-th before block argument with its initial value
2910 ///           %a.
2911 ///        2. 1-th yield operand %b is equal to the 1-th iter arg's initial
2912 ///           value. So we remove this operand and the corresponding before
2913 ///           block argument and replace all uses of 1-th before block argument
2914 ///           with %b.
struct RemoveLoopInvariantArgsFromBeforeBlock
    : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    Block &afterBlock = op.getAfter().front();
    Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
    ConditionOp condOp = op.getConditionOp();
    OperandRange condOpArgs = condOp.getArgs();
    Operation *yieldOp = afterBlock.getTerminator();
    ValueRange yieldOpArgs = yieldOp->getOperands();

    // First pass: cheap scan that only decides whether at least one before
    // block argument is loop invariant, so we can bail out early without
    // allocating any rewrite state.
    bool canSimplify = false;
    for (const auto &it :
         llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
      auto index = static_cast<unsigned>(it.index());
      Value initVal, yieldOpArg;
      std::tie(initVal, yieldOpArg) = it.value();
      // If i-th yield operand is equal to the i-th operand of the scf.while,
      // the i-th before block argument is a loop invariant.
      if (yieldOpArg == initVal) {
        canSimplify = true;
        break;
      }
      // If the i-th yield operand is k-th after block argument, then we check
      // if the (k+1)-th condition op operand (operand 0 of scf.condition is
      // the i1 condition itself) is equal to either the i-th before block
      // argument or the initial value of i-th before block argument. If the
      // comparison results `true`, i-th before block argument is a loop
      // invariant.
      auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
      if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
        Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
        if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
          canSimplify = true;
          break;
        }
      }
    }

    if (!canSimplify)
      return failure();

    // Second pass: partition the loop-carried values. Invariant entries are
    // recorded in `beforeBlockInitValMap` (before-arg index -> replacement
    // value); the remaining entries become the init/yield operands of the
    // rewritten loop.
    SmallVector<Value> newInitArgs, newYieldOpArgs;
    DenseMap<unsigned, Value> beforeBlockInitValMap;
    SmallVector<Location> newBeforeBlockArgLocs;
    for (const auto &it :
         llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
      auto index = static_cast<unsigned>(it.index());
      Value initVal, yieldOpArg;
      std::tie(initVal, yieldOpArg) = it.value();

      // If i-th yield operand is equal to the i-th operand of the scf.while,
      // the i-th before block argument is a loop invariant.
      if (yieldOpArg == initVal) {
        beforeBlockInitValMap.insert({index, initVal});
        continue;
      } else {
        // If the i-th yield operand is k-th after block argument, then we check
        // if the (k+1)-th condition op operand is equal to either the i-th
        // before block argument or the initial value of i-th before block
        // argument. If the comparison results `true`, i-th before block
        // argument is a loop invariant.
        auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
        if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
          Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
          if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
            beforeBlockInitValMap.insert({index, initVal});
            continue;
          }
        }
      }
      newInitArgs.emplace_back(initVal);
      newYieldOpArgs.emplace_back(yieldOpArg);
      newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
    }

    // Replace the after-block terminator first so it only yields the
    // non-invariant values.
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(yieldOp);
      rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
    }

    // Result types are unchanged: only the iter_args list shrinks, not the
    // values forwarded by scf.condition.
    auto newWhile =
        rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);

    Block &newBeforeBlock = *rewriter.createBlock(
        &newWhile.getBefore(), /*insertPt*/ {},
        ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);

    Block &beforeBlock = op.getBefore().front();
    SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
    // For each i-th before block argument we find its replacement value as :-
    //   1. If i-th before block argument is a loop invariant, we fetch its
    //      initial value from `beforeBlockInitValMap` by querying for key `i`.
    //   2. Else we fetch j-th new before block argument as the replacement
    //      value of i-th before block argument.
    for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
      // If the index 'i' argument was a loop invariant we fetch its initial
      // value from `beforeBlockInitValMap`.
      if (beforeBlockInitValMap.count(i) != 0)
        newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
      else
        newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
    }

    // Move the old bodies into the new op; `mergeBlocks` rewires uses of the
    // old before-block arguments to the replacement values computed above.
    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
    rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
                                newWhile.getAfter().begin());

    rewriter.replaceOp(op, newWhile.getResults());
    return success();
  }
};
3029 
3030 /// Remove loop invariant value from result (condition op) of scf.while.
3031 /// A value is considered loop invariant if the final value yielded by
3032 /// scf.condition is defined outside of the `before` block. We remove the
3033 /// corresponding argument in `after` block and replace the use with the value.
3034 /// We also replace the use of the corresponding result of scf.while with the
3035 /// value.
3036 ///
3037 /// Eg:
3038 ///    INPUT :-
///    %res_input:K = scf.while <...> iter_args(%arg0_before = %init0, ...,
///                                             %argN_before = %N) {
3041 ///                ...
3042 ///                scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3043 ///           } do {
3044 ///             ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3045 ///                ...
3046 ///                some_func(%arg1_after)
3047 ///                ...
3048 ///                scf.yield %arg0_after, %arg2_after, ..., %argN_after
3049 ///           }
3050 ///
3051 ///    OUTPUT :-
///    %res_output:M = scf.while <...> iter_args(%arg0 = %init0, ...,
///                                              %argN = %N) {
3053 ///                ...
3054 ///                scf.condition(%cond) %arg0, %arg1, ..., %argM
3055 ///           } do {
3056 ///             ^bb0(%arg0, %arg3, ..., %argM):
3057 ///                ...
3058 ///                some_func(%a)
3059 ///                ...
3060 ///                scf.yield %arg0, %b, ..., %argN
3061 ///           }
3062 ///
3063 ///     EXPLANATION:
3064 ///       1. The 1-th and 2-th operand of scf.condition are defined outside the
3065 ///          before block of scf.while, so they get removed.
3066 ///       2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3067 ///          replaced by %b.
3068 ///       3. The corresponding after block argument %arg1_after's uses are
3069 ///          replaced by %a and %arg2_after's uses are replaced by %b.
struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    Block &beforeBlock = op.getBefore().front();
    ConditionOp condOp = op.getConditionOp();
    OperandRange condOpArgs = condOp.getArgs();

    // First pass: cheap scan that only decides whether any scf.condition
    // operand is defined outside the `before` block, so we can bail out
    // early without allocating any rewrite state.
    bool canSimplify = false;
    for (Value condOpArg : condOpArgs) {
      // Those values not defined within `before` block will be considered as
      // loop invariant values. We map the corresponding `index` with their
      // value.
      if (condOpArg.getParentBlock() != &beforeBlock) {
        canSimplify = true;
        break;
      }
    }

    if (!canSimplify)
      return failure();

    Block::BlockArgListType afterBlockArgs = op.getAfterArguments();

    // Second pass: partition the condition operands. Invariant entries are
    // recorded in `condOpInitValMap` (after-arg index -> invariant value);
    // the remaining entries become the operands of the new scf.condition and
    // the arguments/results of the rewritten loop.
    SmallVector<Value> newCondOpArgs;
    SmallVector<Type> newAfterBlockType;
    DenseMap<unsigned, Value> condOpInitValMap;
    SmallVector<Location> newAfterBlockArgLocs;
    for (const auto &it : llvm::enumerate(condOpArgs)) {
      auto index = static_cast<unsigned>(it.index());
      Value condOpArg = it.value();
      // Those values not defined within `before` block will be considered as
      // loop invariant values. We map the corresponding `index` with their
      // value.
      if (condOpArg.getParentBlock() != &beforeBlock) {
        condOpInitValMap.insert({index, condOpArg});
      } else {
        newCondOpArgs.emplace_back(condOpArg);
        newAfterBlockType.emplace_back(condOpArg.getType());
        newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
      }
    }

    // Replace the before-block terminator first so it only forwards the
    // non-invariant values.
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(condOp);
      rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
                                               newCondOpArgs);
    }

    // The loop keeps its inits but produces fewer results (one per remaining
    // condition operand).
    auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
                                             op.getOperands());

    Block &newAfterBlock =
        *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
                              newAfterBlockType, newAfterBlockArgLocs);

    Block &afterBlock = op.getAfter().front();
    // Since a new scf.condition op was created, we need to fetch the new
    // `after` block arguments which will be used while replacing operations of
    // previous scf.while's `after` blocks. We'd also be fetching new result
    // values too.
    SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
    SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
    for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
      Value afterBlockArg, result;
      // If index 'i' argument was loop invariant we fetch its value from the
      // `condOpInitMap` map.
      if (condOpInitValMap.count(i) != 0) {
        afterBlockArg = condOpInitValMap[i];
        result = afterBlockArg;
      } else {
        afterBlockArg = newAfterBlock.getArgument(j);
        result = newWhile.getResult(j);
        j++;
      }
      newAfterBlockArgs[i] = afterBlockArg;
      newWhileResults[i] = result;
    }

    // `mergeBlocks` rewires uses of the old after-block arguments to either
    // the invariant value or the corresponding new block argument.
    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
    rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
                                newWhile.getBefore().begin());

    rewriter.replaceOp(op, newWhileResults);
    return success();
  }
};
3159 
3160 /// Remove WhileOp results that are also unused in 'after' block.
3161 ///
3162 ///  %0:2 = scf.while () : () -> (i32, i64) {
3163 ///    %condition = "test.condition"() : () -> i1
3164 ///    %v1 = "test.get_some_value"() : () -> i32
3165 ///    %v2 = "test.get_some_value"() : () -> i64
3166 ///    scf.condition(%condition) %v1, %v2 : i32, i64
3167 ///  } do {
3168 ///  ^bb0(%arg0: i32, %arg1: i64):
3169 ///    "test.use"(%arg0) : (i32) -> ()
3170 ///    scf.yield
3171 ///  }
3172 ///  return %0#0 : i32
3173 ///
3174 /// becomes
3175 ///  %0 = scf.while () : () -> (i32) {
3176 ///    %condition = "test.condition"() : () -> i1
3177 ///    %v1 = "test.get_some_value"() : () -> i32
3178 ///    %v2 = "test.get_some_value"() : () -> i64
3179 ///    scf.condition(%condition) %v1 : i32
3180 ///  } do {
3181 ///  ^bb0(%arg0: i32):
3182 ///    "test.use"(%arg0) : (i32) -> ()
3183 ///    scf.yield
3184 ///  }
3185 ///  return %0 : i32
struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    auto term = op.getConditionOp();
    auto afterArgs = op.getAfterArguments();
    auto termArgs = term.getArgs();

    // Collect results mapping, new terminator args and new result types.
    // A result can be dropped only if both the scf.while result and the
    // corresponding after-block argument are unused.
    SmallVector<unsigned> newResultsIndices;
    SmallVector<Type> newResultTypes;
    SmallVector<Value> newTermArgs;
    SmallVector<Location> newArgLocs;
    bool needUpdate = false;
    for (const auto &it :
         llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
      auto i = static_cast<unsigned>(it.index());
      Value result = std::get<0>(it.value());
      Value afterArg = std::get<1>(it.value());
      Value termArg = std::get<2>(it.value());
      if (result.use_empty() && afterArg.use_empty()) {
        needUpdate = true;
      } else {
        // Remember the original position `i` of each kept result so the
        // replacement lists below can be rebuilt in the old indexing.
        newResultsIndices.emplace_back(i);
        newTermArgs.emplace_back(termArg);
        newResultTypes.emplace_back(result.getType());
        newArgLocs.emplace_back(result.getLoc());
      }
    }

    if (!needUpdate)
      return failure();

    // Replace the before-block terminator first so it only forwards the
    // kept values.
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(term);
      rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
                                               newTermArgs);
    }

    auto newWhile =
        rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());

    Block &newAfterBlock = *rewriter.createBlock(
        &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);

    // Build new results list and new after block args (unused entries will be
    // null).
    SmallVector<Value> newResults(op.getNumResults());
    SmallVector<Value> newAfterBlockArgs(op.getNumResults());
    for (const auto &it : llvm::enumerate(newResultsIndices)) {
      newResults[it.value()] = newWhile.getResult(it.index());
      newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
    }

    rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
                                newWhile.getBefore().begin());

    // Null entries in `newAfterBlockArgs` correspond to arguments that are
    // known to be unused, so `mergeBlocks` never dereferences them.
    Block &afterBlock = op.getAfter().front();
    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);

    rewriter.replaceOp(op, newResults);
    return success();
  }
};
3252 
3253 /// Replace operations equivalent to the condition in the do block with true,
3254 /// since otherwise the block would not be evaluated.
3255 ///
3256 /// scf.while (..) : (i32, ...) -> ... {
3257 ///  %z = ... : i32
3258 ///  %condition = cmpi pred %z, %a
3259 ///  scf.condition(%condition) %z : i32, ...
3260 /// } do {
3261 /// ^bb0(%arg0: i32, ...):
3262 ///    %condition2 = cmpi pred %arg0, %a
3263 ///    use(%condition2)
3264 ///    ...
3265 ///
3266 /// becomes
3267 /// scf.while (..) : (i32, ...) -> ... {
3268 ///  %z = ... : i32
3269 ///  %condition = cmpi pred %z, %a
3270 ///  scf.condition(%condition) %z : i32, ...
3271 /// } do {
3272 /// ^bb0(%arg0: i32, ...):
3273 ///    use(%true)
3274 ///    ...
3275 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3276   using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3277 
3278   LogicalResult matchAndRewrite(scf::WhileOp op,
3279                                 PatternRewriter &rewriter) const override {
3280     using namespace scf;
3281     auto cond = op.getConditionOp();
3282     auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3283     if (!cmp)
3284       return failure();
3285     bool changed = false;
3286     for (auto tup :
3287          llvm::zip(cond.getArgs(), op.getAfter().front().getArguments())) {
3288       for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3289         if (std::get<0>(tup) != cmp.getOperand(opIdx))
3290           continue;
3291         for (OpOperand &u :
3292              llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3293           auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3294           if (!cmp2)
3295             continue;
3296           // For a binary operator 1-opIdx gets the other side.
3297           if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3298             continue;
3299           bool samePredicate;
3300           if (cmp2.getPredicate() == cmp.getPredicate())
3301             samePredicate = true;
3302           else if (cmp2.getPredicate() ==
3303                    arith::invertPredicate(cmp.getPredicate()))
3304             samePredicate = false;
3305           else
3306             continue;
3307 
3308           rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3309                                                             1);
3310           changed = true;
3311         }
3312       }
3313     }
3314     return success(changed);
3315   }
3316 };
3317 
3318 struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
3319   using OpRewritePattern<WhileOp>::OpRewritePattern;
3320 
3321   LogicalResult matchAndRewrite(WhileOp op,
3322                                 PatternRewriter &rewriter) const override {
3323 
3324     if (!llvm::any_of(op.getBeforeArguments(),
3325                       [](Value arg) { return arg.use_empty(); }))
3326       return failure();
3327 
3328     YieldOp yield = op.getYieldOp();
3329 
3330     // Collect results mapping, new terminator args and new result types.
3331     SmallVector<Value> newYields;
3332     SmallVector<Value> newInits;
3333     SmallVector<unsigned> argsToErase;
3334     for (const auto &it : llvm::enumerate(llvm::zip(
3335              op.getBeforeArguments(), yield.getOperands(), op.getInits()))) {
3336       Value beforeArg = std::get<0>(it.value());
3337       Value yieldValue = std::get<1>(it.value());
3338       Value initValue = std::get<2>(it.value());
3339       if (beforeArg.use_empty()) {
3340         argsToErase.push_back(it.index());
3341       } else {
3342         newYields.emplace_back(yieldValue);
3343         newInits.emplace_back(initValue);
3344       }
3345     }
3346 
3347     if (argsToErase.empty())
3348       return failure();
3349 
3350     rewriter.startRootUpdate(op);
3351     op.getBefore().front().eraseArguments(argsToErase);
3352     rewriter.finalizeRootUpdate(op);
3353 
3354     WhileOp replacement =
3355         rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInits);
3356     replacement.getBefore().takeBody(op.getBefore());
3357     replacement.getAfter().takeBody(op.getAfter());
3358     rewriter.replaceOp(op, replacement.getResults());
3359 
3360     rewriter.setInsertionPoint(yield);
3361     rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
3362     return success();
3363   }
3364 };
3365 } // namespace
3366 
3367 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
3368                                           MLIRContext *context) {
3369   results.add<RemoveLoopInvariantArgsFromBeforeBlock,
3370               RemoveLoopInvariantValueYielded, WhileConditionTruth,
3371               WhileCmpCond, WhileUnusedResult>(context);
3372 }
3373 
3374 //===----------------------------------------------------------------------===//
3375 // TableGen'd op method definitions
3376 //===----------------------------------------------------------------------===//
3377 
3378 #define GET_OP_CLASSES
3379 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
3380