1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/IR/SCF.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include "mlir/IR/BlockAndValueMapping.h"
17 #include "mlir/IR/FunctionInterfaces.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Support/MathExtras.h"
21 #include "mlir/Transforms/InliningUtils.h"
22 
23 using namespace mlir;
24 using namespace mlir::scf;
25 
26 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
27 
28 //===----------------------------------------------------------------------===//
29 // SCFDialect Dialect Interfaces
30 //===----------------------------------------------------------------------===//
31 
32 namespace {
33 struct SCFInlinerInterface : public DialectInlinerInterface {
34   using DialectInlinerInterface::DialectInlinerInterface;
35   // We don't have any special restrictions on what can be inlined into
36   // destination regions (e.g. while/conditional bodies). Always allow it.
37   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
38                        BlockAndValueMapping &valueMapping) const final {
39     return true;
40   }
41   // Operations in scf dialect are always legal to inline since they are
42   // pure.
43   bool isLegalToInline(Operation *, Region *, bool,
44                        BlockAndValueMapping &) const final {
45     return true;
46   }
47   // Handle the given inlined terminator by replacing it with a new operation
48   // as necessary. Required when the region has only one block.
49   void handleTerminator(Operation *op,
50                         ArrayRef<Value> valuesToRepl) const final {
51     auto retValOp = dyn_cast<scf::YieldOp>(op);
52     if (!retValOp)
53       return;
54 
55     for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
56       std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
57     }
58   }
59 };
60 } // namespace
61 
62 //===----------------------------------------------------------------------===//
63 // SCFDialect
64 //===----------------------------------------------------------------------===//
65 
/// Dialect initialization hook: registers every SCF operation listed in the
/// ODS-generated op list and attaches the dialect's inliner interface.
void SCFDialect::initialize() {
  // The generated file expands to a comma-separated list of op classes when
  // GET_OP_LIST is defined.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
      >();
  addInterfaces<SCFInlinerInterface>();
}
73 
74 /// Default callback for IfOp builders. Inserts a yield without arguments.
75 void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
76   builder.create<scf::YieldOp>(loc);
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // ExecuteRegionOp
81 //===----------------------------------------------------------------------===//
82 
83 /// Replaces the given op with the contents of the given single-block region,
84 /// using the operands of the block terminator to replace operation results.
85 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
86                                 Region &region, ValueRange blockArgs = {}) {
87   assert(llvm::hasSingleElement(region) && "expected single-region block");
88   Block *block = &region.front();
89   Operation *terminator = block->getTerminator();
90   ValueRange results = terminator->getOperands();
91   rewriter.mergeBlockBefore(block, op, blockArgs);
92   rewriter.replaceOp(op, results);
93   rewriter.eraseOp(terminator);
94 }
95 
96 ///
97 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
98 ///    block+
99 /// `}`
100 ///
101 /// Example:
102 ///   scf.execute_region -> i32 {
103 ///     %idx = load %rI[%i] : memref<128xi32>
104 ///     return %idx : i32
105 ///   }
106 ///
107 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
108                                    OperationState &result) {
109   if (parser.parseOptionalArrowTypeList(result.types))
110     return failure();
111 
112   // Introduce the body region and parse it.
113   Region *body = result.addRegion();
114   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
115       parser.parseOptionalAttrDict(result.attributes))
116     return failure();
117 
118   return success();
119 }
120 
121 void ExecuteRegionOp::print(OpAsmPrinter &p) {
122   p.printOptionalArrowTypeList(getResultTypes());
123 
124   p << ' ';
125   p.printRegion(getRegion(),
126                 /*printEntryBlockArgs=*/false,
127                 /*printBlockTerminators=*/true);
128 
129   p.printOptionalAttrDict((*this)->getAttrs());
130 }
131 
132 LogicalResult ExecuteRegionOp::verify() {
133   if (getRegion().empty())
134     return emitOpError("region needs to have at least one block");
135   if (getRegion().front().getNumArguments() > 0)
136     return emitOpError("region cannot have any arguments");
137   return success();
138 }
139 
140 // Inline an ExecuteRegionOp if it only contains one block.
141 //     "test.foo"() : () -> ()
142 //      %v = scf.execute_region -> i64 {
143 //        %x = "test.val"() : () -> i64
144 //        scf.yield %x : i64
145 //      }
146 //      "test.bar"(%v) : (i64) -> ()
147 //
148 //  becomes
149 //
150 //     "test.foo"() : () -> ()
151 //     %x = "test.val"() : () -> i64
152 //     "test.bar"(%x) : (i64) -> ()
153 //
154 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
155   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
156 
157   LogicalResult matchAndRewrite(ExecuteRegionOp op,
158                                 PatternRewriter &rewriter) const override {
159     if (!llvm::hasSingleElement(op.getRegion()))
160       return failure();
161     replaceOpWithRegion(rewriter, op, op.getRegion());
162     return success();
163   }
164 };
165 
166 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
167 // TODO generalize the conditions for operations which can be inlined into.
168 // func @func_execute_region_elim() {
169 //     "test.foo"() : () -> ()
170 //     %v = scf.execute_region -> i64 {
171 //       %c = "test.cmp"() : () -> i1
172 //       cf.cond_br %c, ^bb2, ^bb3
173 //     ^bb2:
174 //       %x = "test.val1"() : () -> i64
175 //       cf.br ^bb4(%x : i64)
176 //     ^bb3:
177 //       %y = "test.val2"() : () -> i64
178 //       cf.br ^bb4(%y : i64)
179 //     ^bb4(%z : i64):
180 //       scf.yield %z : i64
181 //     }
182 //     "test.bar"(%v) : (i64) -> ()
183 //   return
184 // }
185 //
186 //  becomes
187 //
188 // func @func_execute_region_elim() {
189 //    "test.foo"() : () -> ()
190 //    %c = "test.cmp"() : () -> i1
191 //    cf.cond_br %c, ^bb1, ^bb2
192 //  ^bb1:  // pred: ^bb0
193 //    %x = "test.val1"() : () -> i64
194 //    cf.br ^bb3(%x : i64)
195 //  ^bb2:  // pred: ^bb0
196 //    %y = "test.val2"() : () -> i64
197 //    cf.br ^bb3(%y : i64)
198 //  ^bb3(%z: i64):  // 2 preds: ^bb1, ^bb2
199 //    "test.bar"(%z) : (i64) -> ()
200 //    return
201 //  }
202 //
203 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
204   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
205 
206   LogicalResult matchAndRewrite(ExecuteRegionOp op,
207                                 PatternRewriter &rewriter) const override {
208     if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
209       return failure();
210 
211     Block *prevBlock = op->getBlock();
212     Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
213     rewriter.setInsertionPointToEnd(prevBlock);
214 
215     rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
216 
217     for (Block &blk : op.getRegion()) {
218       if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
219         rewriter.setInsertionPoint(yieldOp);
220         rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
221                                       yieldOp.getResults());
222         rewriter.eraseOp(yieldOp);
223       }
224     }
225 
226     rewriter.inlineRegionBefore(op.getRegion(), postBlock);
227     SmallVector<Value> blockArgs;
228 
229     for (auto res : op.getResults())
230       blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
231 
232     rewriter.replaceOp(op, blockArgs);
233     return success();
234   }
235 };
236 
237 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
238                                                   MLIRContext *context) {
239   results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
240 }
241 
242 /// Given the region at `index`, or the parent operation if `index` is None,
243 /// return the successor regions. These are the regions that may be selected
244 /// during the flow of control. `operands` is a set of optional attributes that
245 /// correspond to a constant value for each operand, or null if that operand is
246 /// not a constant.
247 void ExecuteRegionOp::getSuccessorRegions(
248     Optional<unsigned> index, ArrayRef<Attribute> operands,
249     SmallVectorImpl<RegionSuccessor> &regions) {
250   // If the predecessor is the ExecuteRegionOp, branch into the body.
251   if (!index) {
252     regions.push_back(RegionSuccessor(&getRegion()));
253     return;
254   }
255 
256   // Otherwise, the region branches back to the parent operation.
257   regions.push_back(RegionSuccessor(getResults()));
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // ConditionOp
262 //===----------------------------------------------------------------------===//
263 
264 MutableOperandRange
265 ConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
266   // Pass all operands except the condition to the successor region.
267   return getArgsMutable();
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // ForOp
272 //===----------------------------------------------------------------------===//
273 
274 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
275                   Value ub, Value step, ValueRange iterArgs,
276                   BodyBuilderFn bodyBuilder) {
277   result.addOperands({lb, ub, step});
278   result.addOperands(iterArgs);
279   for (Value v : iterArgs)
280     result.addTypes(v.getType());
281   Region *bodyRegion = result.addRegion();
282   bodyRegion->push_back(new Block);
283   Block &bodyBlock = bodyRegion->front();
284   bodyBlock.addArgument(builder.getIndexType(), result.location);
285   for (Value v : iterArgs)
286     bodyBlock.addArgument(v.getType(), v.getLoc());
287 
288   // Create the default terminator if the builder is not provided and if the
289   // iteration arguments are not provided. Otherwise, leave this to the caller
290   // because we don't know which values to return from the loop.
291   if (iterArgs.empty() && !bodyBuilder) {
292     ForOp::ensureTerminator(*bodyRegion, builder, result.location);
293   } else if (bodyBuilder) {
294     OpBuilder::InsertionGuard guard(builder);
295     builder.setInsertionPointToStart(&bodyBlock);
296     bodyBuilder(builder, result.location, bodyBlock.getArgument(0),
297                 bodyBlock.getArguments().drop_front());
298   }
299 }
300 
301 LogicalResult ForOp::verify() {
302   if (auto cst = getStep().getDefiningOp<arith::ConstantIndexOp>())
303     if (cst.value() <= 0)
304       return emitOpError("constant step operand must be positive");
305 
306   auto opNumResults = getNumResults();
307   if (opNumResults == 0)
308     return success();
309   // If ForOp defines values, check that the number and types of
310   // the defined values match ForOp initial iter operands and backedge
311   // basic block arguments.
312   if (getNumIterOperands() != opNumResults)
313     return emitOpError(
314         "mismatch in number of loop-carried values and defined values");
315   return success();
316 }
317 
318 LogicalResult ForOp::verifyRegions() {
319   // Check that the body defines as single block argument for the induction
320   // variable.
321   auto *body = getBody();
322   if (!body->getArgument(0).getType().isIndex())
323     return emitOpError(
324         "expected body first argument to be an index argument for "
325         "the induction variable");
326 
327   auto opNumResults = getNumResults();
328   if (opNumResults == 0)
329     return success();
330 
331   if (getNumRegionIterArgs() != opNumResults)
332     return emitOpError(
333         "mismatch in number of basic block args and defined values");
334 
335   auto iterOperands = getIterOperands();
336   auto iterArgs = getRegionIterArgs();
337   auto opResults = getResults();
338   unsigned i = 0;
339   for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
340     if (std::get<0>(e).getType() != std::get<2>(e).getType())
341       return emitOpError() << "types mismatch between " << i
342                            << "th iter operand and defined value";
343     if (std::get<1>(e).getType() != std::get<2>(e).getType())
344       return emitOpError() << "types mismatch between " << i
345                            << "th iter region arg and defined value";
346 
347     i++;
348   }
349   return success();
350 }
351 
352 Optional<Value> ForOp::getSingleInductionVar() { return getInductionVar(); }
353 
354 Optional<OpFoldResult> ForOp::getSingleLowerBound() {
355   return OpFoldResult(getLowerBound());
356 }
357 
358 Optional<OpFoldResult> ForOp::getSingleStep() {
359   return OpFoldResult(getStep());
360 }
361 
362 Optional<OpFoldResult> ForOp::getSingleUpperBound() {
363   return OpFoldResult(getUpperBound());
364 }
365 
366 /// Prints the initialization list in the form of
367 ///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
368 /// where 'inner' values are assumed to be region arguments and 'outer' values
369 /// are regular SSA values.
370 static void printInitializationList(OpAsmPrinter &p,
371                                     Block::BlockArgListType blocksArgs,
372                                     ValueRange initializers,
373                                     StringRef prefix = "") {
374   assert(blocksArgs.size() == initializers.size() &&
375          "expected same length of arguments and initializers");
376   if (initializers.empty())
377     return;
378 
379   p << prefix << '(';
380   llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
381     p << std::get<0>(it) << " = " << std::get<1>(it);
382   });
383   p << ")";
384 }
385 
386 void ForOp::print(OpAsmPrinter &p) {
387   p << " " << getInductionVar() << " = " << getLowerBound() << " to "
388     << getUpperBound() << " step " << getStep();
389 
390   printInitializationList(p, getRegionIterArgs(), getIterOperands(),
391                           " iter_args");
392   if (!getIterOperands().empty())
393     p << " -> (" << getIterOperands().getTypes() << ')';
394   p << ' ';
395   p.printRegion(getRegion(),
396                 /*printEntryBlockArgs=*/false,
397                 /*printBlockTerminators=*/hasIterOperands());
398   p.printOptionalAttrDict((*this)->getAttrs());
399 }
400 
/// Parses the custom form of `scf.for`:
///
///   %iv = %lb to %ub step %step
///       (iter_args(%arg = %init, ...) -> (type, ...))? region attr-dict?
///
/// The induction variable, bounds and step are all `index`-typed.
///
/// Each iter_args entry takes the type at the same position in the `->` list.
/// That type is used to resolve the initializer, is assigned to the region
/// argument, and, because the list is parsed into `result.types`, becomes the
/// type of the corresponding op result.
///
/// The default terminator is added to the body if it is missing.
ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  Type indexType = builder.getIndexType();

  // The induction variable is always `index`; its type is not spelled out.
  OpAsmParser::Argument inductionVariable;
  inductionVariable.type = indexType;
  OpAsmParser::UnresolvedOperand lb, ub, step;

  // Parse the induction variable followed by '='.
  if (parser.parseArgument(inductionVariable) || parser.parseEqual() ||
      // Parse loop bounds.
      parser.parseOperand(lb) ||
      parser.resolveOperand(lb, indexType, result.operands) ||
      parser.parseKeyword("to") || parser.parseOperand(ub) ||
      parser.resolveOperand(ub, indexType, result.operands) ||
      parser.parseKeyword("step") || parser.parseOperand(step) ||
      parser.resolveOperand(step, indexType, result.operands))
    return failure();

  // Parse the optional initial iteration arguments.
  // regionArgs[0] is the induction variable; iter_args region arguments follow.
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  regionArgs.push_back(inductionVariable);

  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
    // Parse assignment list and results type list.
    if (parser.parseAssignmentList(regionArgs, operands) ||
        parser.parseArrowTypeList(result.types))
      return failure();

    // Resolve input operands.
    // llvm::zip stops at the shortest range, so a length mismatch between the
    // assignment list and the type list is not caught here. The count check
    // below diagnoses it.
    for (auto argOperandType :
         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
      Type type = std::get<2>(argOperandType);
      std::get<0>(argOperandType).type = type;
      if (parser.resolveOperand(std::get<1>(argOperandType), type,
                                result.operands))
        return failure();
    }
  }

  // Expect exactly one result type per iter_args entry; regionArgs also holds
  // the induction variable, hence the `+ 1`.
  if (regionArgs.size() != result.types.size() + 1)
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch in number of loop-carried values and defined values");

  // Parse the body region.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArgs))
    return failure();

  // The printer omits the terminator for loops without iter_args; restore it.
  ForOp::ensureTerminator(*body, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
460 
461 Region &ForOp::getLoopBody() { return getRegion(); }
462 
463 ForOp mlir::scf::getForInductionVarOwner(Value val) {
464   auto ivArg = val.dyn_cast<BlockArgument>();
465   if (!ivArg)
466     return ForOp();
467   assert(ivArg.getOwner() && "unlinked block argument");
468   auto *containingOp = ivArg.getOwner()->getParentOp();
469   return dyn_cast_or_null<ForOp>(containingOp);
470 }
471 
472 /// Return operands used when entering the region at 'index'. These operands
473 /// correspond to the loop iterator operands, i.e., those excluding the
474 /// induction variable. LoopOp only has one region, so 0 is the only valid value
475 /// for `index`.
476 OperandRange ForOp::getSuccessorEntryOperands(Optional<unsigned> index) {
477   assert(index && *index == 0 && "invalid region index");
478 
479   // The initial operands map to the loop arguments after the induction
480   // variable.
481   return getInitArgs();
482 }
483 
484 /// Given the region at `index`, or the parent operation if `index` is None,
485 /// return the successor regions. These are the regions that may be selected
486 /// during the flow of control. `operands` is a set of optional attributes that
487 /// correspond to a constant value for each operand, or null if that operand is
488 /// not a constant.
489 void ForOp::getSuccessorRegions(Optional<unsigned> index,
490                                 ArrayRef<Attribute> operands,
491                                 SmallVectorImpl<RegionSuccessor> &regions) {
492   // If the predecessor is the ForOp, branch into the body using the iterator
493   // arguments.
494   if (!index) {
495     regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
496     return;
497   }
498 
499   // Otherwise, the loop may branch back to itself or the parent operation.
500   assert(*index == 0 && "expected loop region");
501   regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
502   regions.push_back(RegionSuccessor(getResults()));
503 }
504 
/// Builds a perfect nest of `scf.for` loops, outermost first, one loop per
/// (lbs[i], ubs[i], steps[i]) triple.
///
/// - `iterArgs` seed the outermost loop. Each loop's iter-arg block arguments
///   seed the next inner loop.
/// - Every loop except the innermost yields the results of its inner loop.
/// - `bodyBuilder`, if non-null, is called in the innermost body with all
///   induction variables and the innermost iter args. It must return one value
///   per iter arg; the innermost loop yields those values.
/// - With no bounds, `bodyBuilder` is called directly at the current insertion
///   point with no induction variables and `iterArgs`, and an empty LoopNest
///   is returned.
/// - In the non-empty case, the builder's insertion point is restored on
///   return, and the returned LoopNest lists the created loops outermost
///   first.
LoopNest mlir::scf::buildLoopNest(
    OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ValueRange iterArgs,
    function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
        bodyBuilder) {
  assert(lbs.size() == ubs.size() &&
         "expected the same number of lower and upper bounds");
  assert(lbs.size() == steps.size() &&
         "expected the same number of lower bounds and steps");

  // If there are no bounds, call the body-building function and return early.
  if (lbs.empty()) {
    ValueVector results =
        bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
                    : ValueVector();
    assert(results.size() == iterArgs.size() &&
           "loop nest body must return as many values as loop has iteration "
           "arguments");
    return LoopNest();
  }

  // First, create the loop structure iteratively using the body-builder
  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
  // The callback records each loop's induction variable, iter-arg block
  // arguments and location; the next iteration then reads them back.
  OpBuilder::InsertionGuard guard(builder);
  SmallVector<scf::ForOp, 4> loops;
  SmallVector<Value, 4> ivs;
  loops.reserve(lbs.size());
  ivs.reserve(lbs.size());
  ValueRange currentIterArgs = iterArgs;
  Location currentLoc = loc;
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
    auto loop = builder.create<scf::ForOp>(
        currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
        [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
            ValueRange args) {
          ivs.push_back(iv);
          // It is safe to store ValueRange args because it points to block
          // arguments of a loop operation that we also own.
          currentIterArgs = args;
          currentLoc = nestedLoc;
        });
    // Set the builder to point to the body of the newly created loop. We don't
    // do this in the callback because the builder is reset when the callback
    // returns.
    builder.setInsertionPointToStart(loop.getBody());
    loops.push_back(loop);
  }

  // For all loops but the innermost, yield the results of the nested loop.
  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
    builder.setInsertionPointToEnd(loops[i].getBody());
    builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
  }

  // In the body of the innermost loop, call the body building function if any
  // and yield its results.
  builder.setInsertionPointToStart(loops.back().getBody());
  ValueVector results = bodyBuilder
                            ? bodyBuilder(builder, currentLoc, ivs,
                                          loops.back().getRegionIterArgs())
                            : ValueVector();
  assert(results.size() == iterArgs.size() &&
         "loop nest body must return as many values as loop has iteration "
         "arguments");
  builder.setInsertionPointToEnd(loops.back().getBody());
  builder.create<scf::YieldOp>(loc, results);

  // Return the loops.
  LoopNest res;
  res.loops.assign(loops.begin(), loops.end());
  return res;
}
577 
578 LoopNest mlir::scf::buildLoopNest(
579     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
580     ValueRange steps,
581     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
582   // Delegate to the main function by wrapping the body builder.
583   return buildLoopNest(builder, loc, lbs, ubs, steps, llvm::None,
584                        [&bodyBuilder](OpBuilder &nestedBuilder,
585                                       Location nestedLoc, ValueRange ivs,
586                                       ValueRange) -> ValueVector {
587                          if (bodyBuilder)
588                            bodyBuilder(nestedBuilder, nestedLoc, ivs);
589                          return {};
590                        });
591 }
592 
593 namespace {
// Fold away ForOp iter arguments when:
// 1) The op yields the iter arguments.
// 2) The iter arguments have no use and the corresponding outer region
// iterators (inputs) are yielded.
// 3) The iter arguments have no use and the corresponding (operation) results
// have no use.
//
// In each case, the op result can be replaced by the corresponding init
// operand, which is defined outside the ForOp region. Inside the body, the
// dropped region iter arg is remapped to that same init operand. Inits,
// yields and results are then rebuilt without the folded positions.
//
// The implementation moves the body of the original ForOp into the new one,
// using `mergeBlockBefore` when no iter args remain and `mergeBlocks`
// otherwise, so nothing is cloned.
struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const final {
    bool canonicalize = false;
    Block &block = forOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(block.getTerminator());

    // An internal flat vector of block transfer
    // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
    // transformed block argument mappings. This plays the role of a
    // BlockAndValueMapping for the particular use case of calling into
    // `mergeBlockBefore`.
    // keepMask[i] is true when the i-th iter arg survives in the new loop.
    SmallVector<bool, 4> keepMask;
    keepMask.reserve(yieldOp.getNumOperands());
    SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
        newResultValues;
    newBlockTransferArgs.reserve(1 + forOp.getNumIterOperands());
    newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
    newIterArgs.reserve(forOp.getNumIterOperands());
    newYieldValues.reserve(yieldOp.getNumOperands());
    newResultValues.reserve(forOp.getNumResults());
    for (auto it : llvm::zip(forOp.getIterOperands(),   // iter from outside
                             forOp.getRegionIterArgs(), // iter inside region
                             forOp.getResults(),        // op results
                             yieldOp.getOperands()      // iter yield
                             )) {
      // Forwarded is `true` when:
      // 1) The region `iter` argument is yielded.
      // 2) The region `iter` argument has no use, and the corresponding iter
      // operand (input) is yielded.
      // 3) The region `iter` argument has no use, and the corresponding op
      // result has no use.
      bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
                        (std::get<1>(it).use_empty() &&
                         (std::get<0>(it) == std::get<3>(it) ||
                          std::get<2>(it).use_empty())));
      keepMask.push_back(!forwarded);
      canonicalize |= forwarded;
      if (forwarded) {
        // Both the old region arg and the old result map to the init value.
        newBlockTransferArgs.push_back(std::get<0>(it));
        newResultValues.push_back(std::get<0>(it));
        continue;
      }
      newIterArgs.push_back(std::get<0>(it));
      newYieldValues.push_back(std::get<3>(it));
      newBlockTransferArgs.push_back(Value()); // placeholder with null value
      newResultValues.push_back(Value());      // placeholder with null value
    }

    if (!canonicalize)
      return failure();

    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newIterArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block &newBlock = newForOp.getRegion().front();

    // Replace the null placeholders with newly constructed values.
    // `collapsedIdx` walks the surviving iter args of the new loop in order.
    newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
    for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
         idx != e; ++idx) {
      Value &blockTransferArg = newBlockTransferArgs[1 + idx];
      Value &newResultVal = newResultValues[idx];
      assert((blockTransferArg && newResultVal) ||
             (!blockTransferArg && !newResultVal));
      if (!blockTransferArg) {
        blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
        newResultVal = newForOp.getResult(collapsedIdx++);
      }
    }

    Block &oldBlock = forOp.getRegion().front();
    assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
           "unexpected argument size mismatch");

    // No results case: the scf::ForOp builder already created a zero
    // result terminator. Merge before this terminator and just get rid of the
    // original terminator that has been merged in.
    if (newIterArgs.empty()) {
      auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
      rewriter.mergeBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
      // The merged-in old yield now sits right before the builder's yield.
      rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
      rewriter.replaceOp(forOp, newResultValues);
      return success();
    }

    // No terminator case: merge and rewrite the merged terminator.
    // With iter args and no body builder, ForOp::build creates no terminator,
    // so the old block is merged in whole and its yield is rebuilt with only
    // the operands selected by keepMask.
    auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(mergedTerminator);
      SmallVector<Value, 4> filteredOperands;
      filteredOperands.reserve(newResultValues.size());
      for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
        if (keepMask[idx])
          filteredOperands.push_back(mergedTerminator.getOperand(idx));
      rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
                                    filteredOperands);
    };

    rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
    auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
    cloneFilteredTerminator(mergedYieldOp);
    rewriter.eraseOp(mergedYieldOp);
    rewriter.replaceOp(forOp, newResultValues);
    return success();
  }
};
717 
718 /// Rewriting pattern that erases loops that are known not to iterate, replaces
719 /// single-iteration loops with their bodies, and removes empty loops that
720 /// iterate at least once and only return values defined outside of the loop.
721 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
722   using OpRewritePattern<ForOp>::OpRewritePattern;
723 
724   LogicalResult matchAndRewrite(ForOp op,
725                                 PatternRewriter &rewriter) const override {
726     // If the upper bound is the same as the lower bound, the loop does not
727     // iterate, just remove it.
728     if (op.getLowerBound() == op.getUpperBound()) {
729       rewriter.replaceOp(op, op.getIterOperands());
730       return success();
731     }
732 
733     auto lb = op.getLowerBound().getDefiningOp<arith::ConstantOp>();
734     auto ub = op.getUpperBound().getDefiningOp<arith::ConstantOp>();
735     if (!lb || !ub)
736       return failure();
737 
738     // If the loop is known to have 0 iterations, remove it.
739     llvm::APInt lbValue = lb.getValue().cast<IntegerAttr>().getValue();
740     llvm::APInt ubValue = ub.getValue().cast<IntegerAttr>().getValue();
741     if (lbValue.sge(ubValue)) {
742       rewriter.replaceOp(op, op.getIterOperands());
743       return success();
744     }
745 
746     auto step = op.getStep().getDefiningOp<arith::ConstantOp>();
747     if (!step)
748       return failure();
749 
750     // If the loop is known to have 1 iteration, inline its body and remove the
751     // loop.
752     llvm::APInt stepValue = step.getValue().cast<IntegerAttr>().getValue();
753     if ((lbValue + stepValue).sge(ubValue)) {
754       SmallVector<Value, 4> blockArgs;
755       blockArgs.reserve(op.getNumIterOperands() + 1);
756       blockArgs.push_back(op.getLowerBound());
757       llvm::append_range(blockArgs, op.getIterOperands());
758       replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
759       return success();
760     }
761 
762     // Now we are left with loops that have more than 1 iterations.
763     Block &block = op.getRegion().front();
764     if (!llvm::hasSingleElement(block))
765       return failure();
766     // If the loop is empty, iterates at least once, and only returns values
767     // defined outside of the loop, remove it and replace it with yield values.
768     auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
769     auto yieldOperands = yieldOp.getOperands();
770     if (llvm::any_of(yieldOperands,
771                      [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
772       return failure();
773     rewriter.replaceOp(op, yieldOperands);
774     return success();
775   }
776 };
777 
778 /// Perform a replacement of one iter OpOperand of an scf.for to the
779 /// `replacement` value which is expected to be the source of a tensor.cast.
780 /// tensor.cast ops are inserted inside the block to account for the type cast.
781 static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
782                                            OpOperand &operand,
783                                            Value replacement) {
784   Type oldType = operand.get().getType(), newType = replacement.getType();
785   assert(oldType.isa<RankedTensorType>() && newType.isa<RankedTensorType>() &&
786          "expected ranked tensor types");
787 
788   // 1. Create new iter operands, exactly 1 is replaced.
789   ForOp forOp = cast<ForOp>(operand.getOwner());
790   assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
791          "expected an iter OpOperand");
792   if (operand.get().getType() == replacement.getType())
793     return forOp;
794   SmallVector<Value> newIterOperands;
795   for (OpOperand &opOperand : forOp.getIterOpOperands()) {
796     if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
797       newIterOperands.push_back(replacement);
798       continue;
799     }
800     newIterOperands.push_back(opOperand.get());
801   }
802 
803   // 2. Create the new forOp shell.
804   scf::ForOp newForOp = rewriter.create<scf::ForOp>(
805       forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
806       forOp.getStep(), newIterOperands);
807   newForOp->setAttrs(forOp->getAttrs());
808   Block &newBlock = newForOp.getRegion().front();
809   SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
810                                              newBlock.getArguments().end());
811 
812   // 3. Inject an incoming cast op at the beginning of the block for the bbArg
813   // corresponding to the `replacement` value.
814   OpBuilder::InsertionGuard g(rewriter);
815   rewriter.setInsertionPoint(&newBlock, newBlock.begin());
816   BlockArgument newRegionIterArg = newForOp.getRegionIterArgForOpOperand(
817       newForOp->getOpOperand(operand.getOperandNumber()));
818   Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
819                                                  newRegionIterArg);
820   newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
821 
822   // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
823   Block &oldBlock = forOp.getRegion().front();
824   rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
825 
826   // 5. Inject an outgoing cast op at the end of the block and yield it instead.
827   auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
828   rewriter.setInsertionPoint(clonedYieldOp);
829   unsigned yieldIdx =
830       newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
831   Value castOut = rewriter.create<tensor::CastOp>(
832       newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
833   SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
834   newYieldOperands[yieldIdx] = castOut;
835   rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
836   rewriter.eraseOp(clonedYieldOp);
837 
838   // 6. Inject an outgoing cast op after the forOp.
839   rewriter.setInsertionPointAfter(newForOp);
840   SmallVector<Value> newResults = newForOp.getResults();
841   newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
842       newForOp.getLoc(), oldType, newResults[yieldIdx]);
843 
844   return newForOp;
845 }
846 
847 /// Fold scf.for iter_arg/result pairs that go through incoming/ougoing
848 /// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
849 ///
850 /// ```
851 ///   %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
852 ///   %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
853 ///      -> (tensor<?x?xf32>) {
854 ///     %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
855 ///     scf.yield %2 : tensor<?x?xf32>
856 ///   }
857 ///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
858 ///   use_of(%2)
859 /// ```
860 ///
861 /// folds into:
862 ///
863 /// ```
864 ///   %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
865 ///       -> (tensor<32x1024xf32>) {
866 ///     %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
867 ///     %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
868 ///     %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
869 ///     scf.yield %4 : tensor<32x1024xf32>
870 ///   }
871 ///   use_of(%0)
872 /// ```
873 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
874   using OpRewritePattern<ForOp>::OpRewritePattern;
875 
876   LogicalResult matchAndRewrite(ForOp op,
877                                 PatternRewriter &rewriter) const override {
878     for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
879       OpOperand &iterOpOperand = std::get<0>(it);
880       auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
881       if (!incomingCast)
882         continue;
883       if (!std::get<1>(it).hasOneUse())
884         continue;
885       auto outgoingCastOp =
886           dyn_cast<tensor::CastOp>(*std::get<1>(it).user_begin());
887       if (!outgoingCastOp)
888         continue;
889 
890       // Must be a tensor.cast op pair with matching types.
891       if (outgoingCastOp.getResult().getType() !=
892           incomingCast.source().getType())
893         continue;
894 
895       // Create a new ForOp with that iter operand replaced.
896       auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
897                                                     incomingCast.source());
898 
899       // Insert outgoing cast and use it to replace the corresponding result.
900       rewriter.setInsertionPointAfter(newForOp);
901       SmallVector<Value> replacements = newForOp.getResults();
902       unsigned returnIdx =
903           iterOpOperand.getOperandNumber() - op.getNumControlOperands();
904       replacements[returnIdx] = rewriter.create<tensor::CastOp>(
905           op.getLoc(), incomingCast.dest().getType(), replacements[returnIdx]);
906       rewriter.replaceOp(op, replacements);
907       return success();
908     }
909     return failure();
910   }
911 };
912 
913 /// Canonicalize the iter_args of an scf::ForOp that involve a
914 /// `bufferization.to_tensor` and for which only the last loop iteration is
915 /// actually visible outside of the loop. The canonicalization looks for a
916 /// pattern such as:
917 /// ```
918 ///    %t0 = ... : tensor_type
919 ///    %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
920 ///      ...
921 ///      // %m is either buffer_cast(%bb00) or defined above the loop
922 ///      %m... : memref_type
923 ///      ... // uses of %m with potential inplace updates
924 ///      %new_tensor = bufferization.to_tensor %m : memref_type
925 ///      ...
926 ///      scf.yield %new_tensor : tensor_type
927 ///    }
928 /// ```
929 ///
930 /// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
931 /// `%m = buffer_cast %bb0` op that feeds into the yielded
932 /// `bufferization.to_tensor` op.
933 ///
934 /// If no aliasing write to the memref `%m`, from which `%new_tensor`is loaded,
935 /// occurs between `bufferization.to_tensor and yield then the value %0
936 /// visible outside of the loop is the last `bufferization.to_tensor`
937 /// produced in the loop.
938 ///
939 /// For now, we approximate the absence of aliasing by only supporting the case
940 /// when the bufferization.to_tensor is the operation immediately preceding
941 /// the yield.
942 //
943 /// The canonicalization rewrites the pattern as:
944 /// ```
945 ///    // %m is either a buffer_cast or defined above
946 ///    %m... : memref_type
947 ///    scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
948 ///      ... // uses of %m with potential inplace updates
949 ///      scf.yield %bb0: tensor_type
950 ///    }
951 ///    %0 = bufferization.to_tensor %m : memref_type
952 /// ```
953 ///
954 /// A later bbArg canonicalization will further rewrite as:
955 /// ```
956 ///    // %m is either a buffer_cast or defined above
957 ///    %m... : memref_type
958 ///    scf.for ... { // no iter_args
959 ///      ... // uses of %m with potential inplace updates
960 ///    }
961 ///    %0 = bufferization.to_tensor %m : memref_type
962 /// ```
963 struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
964   using OpRewritePattern<ForOp>::OpRewritePattern;
965 
966   LogicalResult matchAndRewrite(ForOp forOp,
967                                 PatternRewriter &rewriter) const override {
968     assert(std::next(forOp.getRegion().begin()) == forOp.getRegion().end() &&
969            "unexpected multiple blocks");
970 
971     Location loc = forOp.getLoc();
972     DenseMap<Value, Value> replacements;
973     for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
974       unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
975       auto yieldOp =
976           cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
977       Value yieldVal = yieldOp->getOperand(idx);
978       auto tensorLoadOp = yieldVal.getDefiningOp<bufferization::ToTensorOp>();
979       bool isTensor = bbArg.getType().isa<TensorType>();
980 
981       bufferization::ToMemrefOp tensorToMemref;
982       // Either bbArg has no use or it has a single buffer_cast use.
983       if (bbArg.hasOneUse())
984         tensorToMemref =
985             dyn_cast<bufferization::ToMemrefOp>(*bbArg.getUsers().begin());
986       if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !tensorToMemref))
987         continue;
988       // If tensorToMemref is present, it must feed into the `ToTensorOp`.
989       if (tensorToMemref && tensorLoadOp.getMemref() != tensorToMemref)
990         continue;
991       // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
992       // must be before `ToTensorOp` in the block so that the lastWrite
993       // property is not subject to additional side-effects.
994       // For now, we only support the case when ToTensorOp appears
995       // immediately before the terminator.
996       if (tensorLoadOp->getNextNode() != yieldOp)
997         continue;
998 
999       // Clone the optional tensorToMemref before forOp.
1000       if (tensorToMemref) {
1001         rewriter.setInsertionPoint(forOp);
1002         rewriter.replaceOpWithNewOp<bufferization::ToMemrefOp>(
1003             tensorToMemref, tensorToMemref.getMemref().getType(),
1004             tensorToMemref.getTensor());
1005       }
1006 
1007       // Clone the tensorLoad after forOp.
1008       rewriter.setInsertionPointAfter(forOp);
1009       Value newTensorLoad = rewriter.create<bufferization::ToTensorOp>(
1010           loc, tensorLoadOp.getMemref());
1011       Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
1012       replacements.insert(std::make_pair(forOpResult, newTensorLoad));
1013 
1014       // Make the terminator just yield the bbArg, the old tensorLoadOp + the
1015       // old bbArg (that is now directly yielded) will canonicalize away.
1016       rewriter.startRootUpdate(yieldOp);
1017       yieldOp.setOperand(idx, bbArg);
1018       rewriter.finalizeRootUpdate(yieldOp);
1019     }
1020     if (replacements.empty())
1021       return failure();
1022 
1023     // We want to replace a subset of the results of `forOp`. rewriter.replaceOp
1024     // replaces the whole op and erase it unconditionally. This is wrong for
1025     // `forOp` as it generally contains ops with side effects.
1026     // Instead, use `rewriter.replaceOpWithIf`.
1027     SmallVector<Value> newResults;
1028     newResults.reserve(forOp.getNumResults());
1029     for (Value v : forOp.getResults()) {
1030       auto it = replacements.find(v);
1031       newResults.push_back((it != replacements.end()) ? it->second : v);
1032     }
1033     unsigned idx = 0;
1034     rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &op) {
1035       return op.get() != newResults[idx++];
1036     });
1037     return success();
1038   }
1039 };
1040 } // namespace
1041 
1042 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1043                                         MLIRContext *context) {
1044   results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
1045               LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
1046 }
1047 
1048 //===----------------------------------------------------------------------===//
1049 // ForeachThreadOp
1050 //===----------------------------------------------------------------------===//
1051 
1052 LogicalResult ForeachThreadOp::verify() {
1053   // Call terminator's verify to produce most informative error messages.
1054   if (failed(getTerminator().verify()))
1055     return failure();
1056 
1057   // Check that the body defines as single block argument for the thread index.
1058   auto *body = getBody();
1059   if (body->getNumArguments() != getRank())
1060     return emitOpError("region expects ") << getRank() << " arguments";
1061 
1062   // Verify consistency between the result types and the terminator.
1063   auto terminatorTypes = getTerminator().yieldedTypes();
1064   auto opResults = getResults();
1065   if (opResults.size() != terminatorTypes.size())
1066     return emitOpError("produces ")
1067            << opResults.size() << " results, but its terminator yields "
1068            << terminatorTypes.size() << " value(s)";
1069   unsigned i = 0;
1070   for (auto e : llvm::zip(terminatorTypes, opResults)) {
1071     if (std::get<0>(e) != std::get<1>(e).getType())
1072       return emitOpError() << "type mismatch between result " << i << " ("
1073                            << std::get<1>(e).getType() << ") and terminator ("
1074                            << std::get<0>(e) << ")";
1075     i++;
1076   }
1077   return success();
1078 }
1079 
1080 void ForeachThreadOp::print(OpAsmPrinter &p) {
1081   p << " (";
1082   llvm::interleaveComma(getThreadIndices(), p);
1083   p << ") in (";
1084   llvm::interleaveComma(getNumThreads(), p);
1085   p << ") -> (" << getResultTypes() << ") ";
1086   p.printRegion(getRegion(),
1087                 /*printEntryBlockArgs=*/false,
1088                 /*printBlockTerminators=*/getNumResults() > 0);
1089   p.printOptionalAttrDict(getOperation()->getAttrs());
1090 }
1091 
1092 ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
1093                                    OperationState &result) {
1094   auto &builder = parser.getBuilder();
1095   // Parse an opening `(` followed by thread index variables followed by `)`
1096   // TODO: when we can refer to such "induction variable"-like handles from the
1097   // declarative assembly format, we can implement the parser as a custom hook.
1098   SmallVector<OpAsmParser::Argument, 4> threadIndices;
1099   if (parser.parseArgumentList(threadIndices, OpAsmParser::Delimiter::Paren))
1100     return failure();
1101 
1102   // Parse `in` threadNums.
1103   SmallVector<OpAsmParser::UnresolvedOperand, 4> threadNums;
1104   if (parser.parseKeyword("in") ||
1105       parser.parseOperandList(threadNums, threadIndices.size(),
1106                               OpAsmParser::Delimiter::Paren) ||
1107       parser.resolveOperands(threadNums, builder.getIndexType(),
1108                              result.operands))
1109     return failure();
1110 
1111   // Parse optional results.
1112   if (parser.parseOptionalArrowTypeList(result.types))
1113     return failure();
1114 
1115   // Parse region.
1116   std::unique_ptr<Region> region = std::make_unique<Region>();
1117   for (auto &idx : threadIndices)
1118     idx.type = builder.getIndexType();
1119   if (parser.parseRegion(*region, threadIndices))
1120     return failure();
1121 
1122   // Ensure terminator and move region.
1123   OpBuilder b(builder.getContext());
1124   ForeachThreadOp::ensureTerminator(*region, b, result.location);
1125   result.addRegion(std::move(region));
1126 
1127   // Parse the optional attribute list.
1128   if (parser.parseOptionalAttrDict(result.attributes))
1129     return failure();
1130 
1131   return success();
1132 }
1133 
1134 // Bodyless builder, result types must be specified.
1135 void ForeachThreadOp::build(mlir::OpBuilder &builder,
1136                             mlir::OperationState &result, TypeRange resultTypes,
1137                             ValueRange numThreads) {
1138   result.addOperands(numThreads);
1139 
1140   Region *bodyRegion = result.addRegion();
1141   OpBuilder::InsertionGuard g(builder);
1142   // createBlock sets the IP inside the block.
1143   // Generally we would guard against that but the default ensureTerminator impl
1144   // expects it ..
1145   builder.createBlock(bodyRegion);
1146   Block &bodyBlock = bodyRegion->front();
1147   bodyBlock.addArguments(
1148       SmallVector<Type>(numThreads.size(), builder.getIndexType()),
1149       SmallVector<Location>(numThreads.size(), result.location));
1150   ForeachThreadOp::ensureTerminator(*bodyRegion, builder, result.location);
1151   result.addTypes(resultTypes);
1152 }
1153 
1154 // Builder that takes a bodyBuilder lambda, result types are inferred from
1155 // the terminator.
1156 void ForeachThreadOp::build(
1157     mlir::OpBuilder &builder, mlir::OperationState &result,
1158     ValueRange numThreads,
1159     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1160   result.addOperands(numThreads);
1161 
1162   OpBuilder::InsertionGuard g(builder);
1163   Region *bodyRegion = result.addRegion();
1164   builder.createBlock(bodyRegion);
1165   Block &bodyBlock = bodyRegion->front();
1166   bodyBlock.addArguments(
1167       SmallVector<Type>(numThreads.size(), builder.getIndexType()),
1168       SmallVector<Location>(numThreads.size(), result.location));
1169 
1170   OpBuilder::InsertionGuard guard(builder);
1171   builder.setInsertionPointToStart(&bodyBlock);
1172   bodyBuilder(builder, result.location, bodyBlock.getArguments());
1173   auto terminator =
1174       llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
1175   assert(terminator &&
1176          "expected bodyBuilder to create PerformConcurrentlyOp terminator");
1177   result.addTypes(terminator.yieldedTypes());
1178 }
1179 
1180 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
1181 // unaware of the fact that our terminator also needs a region to be
1182 // well-formed. We override it here to ensure that we do the right thing.
1183 void ForeachThreadOp::ensureTerminator(Region &region, OpBuilder &builder,
1184                                        Location loc) {
1185   OpTrait::SingleBlockImplicitTerminator<PerformConcurrentlyOp>::Impl<
1186       ForeachThreadOp>::ensureTerminator(region, builder, loc);
1187   auto terminator =
1188       llvm::dyn_cast<PerformConcurrentlyOp>(region.front().getTerminator());
1189   if (terminator.getRegion().empty())
1190     builder.createBlock(&terminator.getRegion());
1191 }
1192 
1193 PerformConcurrentlyOp ForeachThreadOp::getTerminator() {
1194   return cast<PerformConcurrentlyOp>(getBody()->getTerminator());
1195 }
1196 
1197 //===----------------------------------------------------------------------===//
1198 // ParallelInsertSliceOp
1199 //===----------------------------------------------------------------------===//
1200 
1201 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
1202 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
1203                                   Value source, Value dest,
1204                                   ArrayRef<OpFoldResult> offsets,
1205                                   ArrayRef<OpFoldResult> sizes,
1206                                   ArrayRef<OpFoldResult> strides,
1207                                   ArrayRef<NamedAttribute> attrs) {
1208   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1209   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1210   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1211                              ShapedType::kDynamicStrideOrOffset);
1212   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1213                              ShapedType::kDynamicSize);
1214   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1215                              ShapedType::kDynamicStrideOrOffset);
1216   build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
1217         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1218         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1219   result.addAttributes(attrs);
1220 }
1221 
1222 // Build a ParallelInsertSliceOp with dynamic entries.
1223 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
1224                                   Value source, Value dest, ValueRange offsets,
1225                                   ValueRange sizes, ValueRange strides,
1226                                   ArrayRef<NamedAttribute> attrs) {
1227   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1228       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1229   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1230       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1231   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1232       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1233   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
1234 }
1235 
1236 namespace {
1237 /// Pattern to rewrite a parallel_insert_slice op with constant arguments.
1238 class ParallelInsertSliceOpConstantArgumentFolder final
1239     : public OpRewritePattern<ParallelInsertSliceOp> {
1240 public:
1241   using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
1242 
1243   LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
1244                                 PatternRewriter &rewriter) const override {
1245     // No constant operand, just return.
1246     if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
1247           return matchPattern(operand, matchConstantIndex());
1248         }))
1249       return failure();
1250 
1251     // At least one of offsets/sizes/strides is a new constant.
1252     // Form the new list of operands and constant attributes from the
1253     // existing.
1254     SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
1255     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
1256     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
1257     canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
1258     canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
1259     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
1260 
1261     // Create the new op in canonical form.
1262     rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
1263         insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
1264         mixedOffsets, mixedSizes, mixedStrides);
1265     return success();
1266   }
1267 };
1268 } // namespace
1269 
1270 void ParallelInsertSliceOp::getCanonicalizationPatterns(
1271     RewritePatternSet &results, MLIRContext *context) {
1272   results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
1273 }
1274 
1275 //===----------------------------------------------------------------------===//
1276 // PerformConcurrentlyOp
1277 //===----------------------------------------------------------------------===//
1278 
1279 // Build a PerformConcurrentlyOp with mixed static and dynamic entries.
1280 void PerformConcurrentlyOp::build(OpBuilder &b, OperationState &result) {
1281   OpBuilder::InsertionGuard g(b);
1282   Region *bodyRegion = result.addRegion();
1283   b.createBlock(bodyRegion);
1284 }
1285 
1286 LogicalResult PerformConcurrentlyOp::verify() {
1287   // TODO: PerformConcurrentlyOpInterface.
1288   for (const Operation &op : getRegion().front().getOperations())
1289     if (!isa<ParallelInsertSliceOp>(op))
1290       return emitOpError(
1291           "expected only scf.foreach_thread.parallel_insert_slice ops");
1292   return success();
1293 }
1294 
1295 void PerformConcurrentlyOp::print(OpAsmPrinter &p) {
1296   p << " ";
1297   p.printRegion(getRegion(),
1298                 /*printEntryBlockArgs=*/false,
1299                 /*printBlockTerminators=*/false);
1300   p.printOptionalAttrDict(getOperation()->getAttrs());
1301 }
1302 
1303 ParseResult PerformConcurrentlyOp::parse(OpAsmParser &parser,
1304                                          OperationState &result) {
1305   auto &builder = parser.getBuilder();
1306 
1307   SmallVector<OpAsmParser::Argument, 8> regionOperands;
1308   std::unique_ptr<Region> region = std::make_unique<Region>();
1309   if (parser.parseRegion(*region, regionOperands))
1310     return failure();
1311 
1312   if (region->empty())
1313     OpBuilder(builder.getContext()).createBlock(region.get());
1314   result.addRegion(std::move(region));
1315 
1316   // Parse the optional attribute list.
1317   if (parser.parseOptionalAttrDict(result.attributes))
1318     return failure();
1319   return success();
1320 }
1321 
1322 SmallVector<Type> PerformConcurrentlyOp::yieldedTypes() {
1323   return llvm::to_vector<4>(
1324       llvm::map_range(this->yieldingOps(), [](Operation &op) {
1325         auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(&op);
1326         return insertSliceOp ? insertSliceOp.yieldedType() : Type();
1327       }));
1328 }
1329 
1330 llvm::iterator_range<Block::iterator> PerformConcurrentlyOp::yieldingOps() {
1331   return getRegion().front().getOperations();
1332 }
1333 
1334 //===----------------------------------------------------------------------===//
1335 // IfOp
1336 //===----------------------------------------------------------------------===//
1337 
1338 bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
1339   assert(a && "expected non-empty operation");
1340   assert(b && "expected non-empty operation");
1341 
1342   IfOp ifOp = a->getParentOfType<IfOp>();
1343   while (ifOp) {
1344     // Check if b is inside ifOp. (We already know that a is.)
1345     if (ifOp->isProperAncestor(b))
1346       // b is contained in ifOp. a and b are in mutually exclusive branches if
1347       // they are in different blocks of ifOp.
1348       return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1349              static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1350     // Check next enclosing IfOp.
1351     ifOp = ifOp->getParentOfType<IfOp>();
1352   }
1353 
1354   // Could not find a common IfOp among a's and b's ancestors.
1355   return false;
1356 }
1357 
1358 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1359                  bool withElseRegion) {
1360   build(builder, result, /*resultTypes=*/llvm::None, cond, withElseRegion);
1361 }
1362 
1363 void IfOp::build(OpBuilder &builder, OperationState &result,
1364                  TypeRange resultTypes, Value cond, bool withElseRegion) {
1365   auto addTerminator = [&](OpBuilder &nested, Location loc) {
1366     if (resultTypes.empty())
1367       IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested,
1368                              loc);
1369   };
1370 
1371   build(builder, result, resultTypes, cond, addTerminator,
1372         withElseRegion ? addTerminator
1373                        : function_ref<void(OpBuilder &, Location)>());
1374 }
1375 
1376 void IfOp::build(OpBuilder &builder, OperationState &result,
1377                  TypeRange resultTypes, Value cond,
1378                  function_ref<void(OpBuilder &, Location)> thenBuilder,
1379                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1380   assert(thenBuilder && "the builder callback for 'then' must be present");
1381 
1382   result.addOperands(cond);
1383   result.addTypes(resultTypes);
1384 
1385   OpBuilder::InsertionGuard guard(builder);
1386   Region *thenRegion = result.addRegion();
1387   builder.createBlock(thenRegion);
1388   thenBuilder(builder, result.location);
1389 
1390   Region *elseRegion = result.addRegion();
1391   if (!elseBuilder)
1392     return;
1393 
1394   builder.createBlock(elseRegion);
1395   elseBuilder(builder, result.location);
1396 }
1397 
1398 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1399                  function_ref<void(OpBuilder &, Location)> thenBuilder,
1400                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1401   build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder);
1402 }
1403 
1404 LogicalResult IfOp::verify() {
1405   if (getNumResults() != 0 && getElseRegion().empty())
1406     return emitOpError("must have an else block if defining values");
1407   return success();
1408 }
1409 
1410 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
1411   // Create the regions for 'then'.
1412   result.regions.reserve(2);
1413   Region *thenRegion = result.addRegion();
1414   Region *elseRegion = result.addRegion();
1415 
1416   auto &builder = parser.getBuilder();
1417   OpAsmParser::UnresolvedOperand cond;
1418   Type i1Type = builder.getIntegerType(1);
1419   if (parser.parseOperand(cond) ||
1420       parser.resolveOperand(cond, i1Type, result.operands))
1421     return failure();
1422   // Parse optional results type list.
1423   if (parser.parseOptionalArrowTypeList(result.types))
1424     return failure();
1425   // Parse the 'then' region.
1426   if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1427     return failure();
1428   IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
1429 
1430   // If we find an 'else' keyword then parse the 'else' region.
1431   if (!parser.parseOptionalKeyword("else")) {
1432     if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1433       return failure();
1434     IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
1435   }
1436 
1437   // Parse the optional attribute list.
1438   if (parser.parseOptionalAttrDict(result.attributes))
1439     return failure();
1440   return success();
1441 }
1442 
1443 void IfOp::print(OpAsmPrinter &p) {
1444   bool printBlockTerminators = false;
1445 
1446   p << " " << getCondition();
1447   if (!getResults().empty()) {
1448     p << " -> (" << getResultTypes() << ")";
1449     // Print yield explicitly if the op defines values.
1450     printBlockTerminators = true;
1451   }
1452   p << ' ';
1453   p.printRegion(getThenRegion(),
1454                 /*printEntryBlockArgs=*/false,
1455                 /*printBlockTerminators=*/printBlockTerminators);
1456 
1457   // Print the 'else' regions if it exists and has a block.
1458   auto &elseRegion = getElseRegion();
1459   if (!elseRegion.empty()) {
1460     p << " else ";
1461     p.printRegion(elseRegion,
1462                   /*printEntryBlockArgs=*/false,
1463                   /*printBlockTerminators=*/printBlockTerminators);
1464   }
1465 
1466   p.printOptionalAttrDict((*this)->getAttrs());
1467 }
1468 
1469 /// Given the region at `index`, or the parent operation if `index` is None,
1470 /// return the successor regions. These are the regions that may be selected
1471 /// during the flow of control. `operands` is a set of optional attributes that
1472 /// correspond to a constant value for each operand, or null if that operand is
1473 /// not a constant.
1474 void IfOp::getSuccessorRegions(Optional<unsigned> index,
1475                                ArrayRef<Attribute> operands,
1476                                SmallVectorImpl<RegionSuccessor> &regions) {
1477   // The `then` and the `else` region branch back to the parent operation.
1478   if (index) {
1479     regions.push_back(RegionSuccessor(getResults()));
1480     return;
1481   }
1482 
1483   // Don't consider the else region if it is empty.
1484   Region *elseRegion = &this->getElseRegion();
1485   if (elseRegion->empty())
1486     elseRegion = nullptr;
1487 
1488   // Otherwise, the successor is dependent on the condition.
1489   bool condition;
1490   if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
1491     condition = condAttr.getValue().isOneValue();
1492   } else {
1493     // If the condition isn't constant, both regions may be executed.
1494     regions.push_back(RegionSuccessor(&getThenRegion()));
1495     // If the else region does not exist, it is not a viable successor.
1496     if (elseRegion)
1497       regions.push_back(RegionSuccessor(elseRegion));
1498     return;
1499   }
1500 
1501   // Add the successor regions using the condition.
1502   regions.push_back(RegionSuccessor(condition ? &getThenRegion() : elseRegion));
1503 }
1504 
/// Folds `scf.if (xor %c, 1)` into `scf.if %c` with the two regions swapped,
/// i.e. `if (!c) then A() else B()` becomes `if c then B() else A()`.
/// The op is modified in place: no entries are added to `results`, and
/// success() signals that the op was updated. Fails if there is no 'else'
/// region (nothing to swap into the 'then' slot) or the condition is not an
/// `arith.xori` whose RHS matches the constant one.
LogicalResult IfOp::fold(ArrayRef<Attribute> operands,
                         SmallVectorImpl<OpFoldResult> &results) {
  // if (!c) then A() else B() -> if c then B() else A()
  if (getElseRegion().empty())
    return failure();

  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
  if (!xorStmt)
    return failure();

  // Only `x xor 1` (logical not for i1) is handled.
  if (!matchPattern(xorStmt.getRhs(), m_One()))
    return failure();

  // Use the un-negated value as the new condition.
  getConditionMutable().assign(xorStmt.getLhs());
  // Remember the original 'then' block before the block lists are shuffled.
  Block *thenBlock = &getThenRegion().front();
  // It would be nicer to use iplist::swap, but that has no implemented
  // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
  // Step 1: move all 'else' blocks to the front of the 'then' region, leaving
  // the original 'then' block after them.
  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
                                     getElseRegion().getBlocks());
  // Step 2: move the original 'then' block into the now-empty 'else' region.
  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
                                     getThenRegion().getBlocks(), thenBlock);
  return success();
}
1528 
1529 void IfOp::getRegionInvocationBounds(
1530     ArrayRef<Attribute> operands,
1531     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1532   if (auto cond = operands[0].dyn_cast_or_null<BoolAttr>()) {
1533     // If the condition is known, then one region is known to be executed once
1534     // and the other zero times.
1535     invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
1536     invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
1537   } else {
1538     // Non-constant condition. Each region may be executed 0 or 1 times.
1539     invocationBounds.assign(2, {0, 1});
1540   }
1541 }
1542 
1543 namespace {
1544 // Pattern to remove unused IfOp results.
1545 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
1546   using OpRewritePattern<IfOp>::OpRewritePattern;
1547 
1548   void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
1549                     PatternRewriter &rewriter) const {
1550     // Move all operations to the destination block.
1551     rewriter.mergeBlocks(source, dest);
1552     // Replace the yield op by one that returns only the used values.
1553     auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
1554     SmallVector<Value, 4> usedOperands;
1555     llvm::transform(usedResults, std::back_inserter(usedOperands),
1556                     [&](OpResult result) {
1557                       return yieldOp.getOperand(result.getResultNumber());
1558                     });
1559     rewriter.updateRootInPlace(yieldOp,
1560                                [&]() { yieldOp->setOperands(usedOperands); });
1561   }
1562 
1563   LogicalResult matchAndRewrite(IfOp op,
1564                                 PatternRewriter &rewriter) const override {
1565     // Compute the list of used results.
1566     SmallVector<OpResult, 4> usedResults;
1567     llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
1568                   [](OpResult result) { return !result.use_empty(); });
1569 
1570     // Replace the operation if only a subset of its results have uses.
1571     if (usedResults.size() == op.getNumResults())
1572       return failure();
1573 
1574     // Compute the result types of the replacement operation.
1575     SmallVector<Type, 4> newTypes;
1576     llvm::transform(usedResults, std::back_inserter(newTypes),
1577                     [](OpResult result) { return result.getType(); });
1578 
1579     // Create a replacement operation with empty then and else regions.
1580     auto emptyBuilder = [](OpBuilder &, Location) {};
1581     auto newOp = rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition(),
1582                                        emptyBuilder, emptyBuilder);
1583 
1584     // Move the bodies and replace the terminators (note there is a then and
1585     // an else region since the operation returns results).
1586     transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
1587     transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
1588 
1589     // Replace the operation by the new one.
1590     SmallVector<Value, 4> repResults(op.getNumResults());
1591     for (const auto &en : llvm::enumerate(usedResults))
1592       repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
1593     rewriter.replaceOp(op, repResults);
1594     return success();
1595   }
1596 };
1597 
1598 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
1599   using OpRewritePattern<IfOp>::OpRewritePattern;
1600 
1601   LogicalResult matchAndRewrite(IfOp op,
1602                                 PatternRewriter &rewriter) const override {
1603     auto constant = op.getCondition().getDefiningOp<arith::ConstantOp>();
1604     if (!constant)
1605       return failure();
1606 
1607     if (constant.getValue().cast<BoolAttr>().getValue())
1608       replaceOpWithRegion(rewriter, op, op.getThenRegion());
1609     else if (!op.getElseRegion().empty())
1610       replaceOpWithRegion(rewriter, op, op.getElseRegion());
1611     else
1612       rewriter.eraseOp(op);
1613 
1614     return success();
1615   }
1616 };
1617 
/// Hoist any yielded results whose operands are defined outside
/// the if, to a select instruction.
///
/// A result is "non-hoistable" when its 'then' value is defined directly in
/// the 'then' region or its 'else' value is defined directly in the 'else'
/// region. Only those results remain on a new scf.if that takes over both
/// bodies. All other results become either the shared yielded value (when
/// both branches yield the same value) or an `arith.select` on the
/// condition, created right before the new scf.if.
struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() == 0)
      return failure();

    // An scf.if with results has both yields; these operand ranges stay valid
    // after the regions are moved below since the yield ops themselves move.
    auto cond = op.getCondition();
    auto thenYieldArgs = op.thenYield().getOperands();
    auto elseYieldArgs = op.elseYield().getOperands();

    // First pass: collect the types of results that must stay on the scf.if.
    SmallVector<Type> nonHoistable;
    for (const auto &it :
         llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
      Value trueVal = std::get<0>(it.value());
      Value falseVal = std::get<1>(it.value());
      if (&op.getThenRegion() == trueVal.getParentRegion() ||
          &op.getElseRegion() == falseVal.getParentRegion())
        nonHoistable.push_back(trueVal.getType());
    }
    // Early exit if there aren't any yielded values we can
    // hoist outside the if.
    if (nonHoistable.size() == op->getNumResults())
      return failure();

    // Build the narrower scf.if and move both bodies into it. thenBlock() is
    // `&back()` and so never returns null; the check is effectively always
    // true. NOTE(review): `takeBody` moves the regions without notifying the
    // rewriter.
    IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond);
    if (replacement.thenBlock())
      rewriter.eraseBlock(replacement.thenBlock());
    replacement.getThenRegion().takeBody(op.getThenRegion());
    replacement.getElseRegion().takeBody(op.getElseRegion());

    SmallVector<Value> results(op->getNumResults());
    assert(thenYieldArgs.size() == results.size());
    assert(elseYieldArgs.size() == results.size());

    // Second pass: the same classification, now against the replacement's
    // regions (the bodies were moved there). Non-hoistable values are
    // re-yielded; others become the common value or a select.
    SmallVector<Value> trueYields;
    SmallVector<Value> falseYields;
    rewriter.setInsertionPoint(replacement);
    for (const auto &it :
         llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
      Value trueVal = std::get<0>(it.value());
      Value falseVal = std::get<1>(it.value());
      if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
          &replacement.getElseRegion() == falseVal.getParentRegion()) {
        results[it.index()] = replacement.getResult(trueYields.size());
        trueYields.push_back(trueVal);
        falseYields.push_back(falseVal);
      } else if (trueVal == falseVal)
        results[it.index()] = trueVal;
      else
        results[it.index()] = rewriter.create<arith::SelectOp>(
            op.getLoc(), cond, trueVal, falseVal);
    }

    // Shrink both terminators to the non-hoistable values only.
    rewriter.setInsertionPointToEnd(replacement.thenBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);

    rewriter.setInsertionPointToEnd(replacement.elseBlock());
    rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);

    rewriter.replaceOp(op, results);
    return success();
  }
};
1685 
1686 /// Allow the true region of an if to assume the condition is true
1687 /// and vice versa. For example:
1688 ///
1689 ///   scf.if %cmp {
1690 ///      print(%cmp)
1691 ///   }
1692 ///
1693 ///  becomes
1694 ///
1695 ///   scf.if %cmp {
1696 ///      print(true)
1697 ///   }
1698 ///
1699 struct ConditionPropagation : public OpRewritePattern<IfOp> {
1700   using OpRewritePattern<IfOp>::OpRewritePattern;
1701 
1702   LogicalResult matchAndRewrite(IfOp op,
1703                                 PatternRewriter &rewriter) const override {
1704     // Early exit if the condition is constant since replacing a constant
1705     // in the body with another constant isn't a simplification.
1706     if (op.getCondition().getDefiningOp<arith::ConstantOp>())
1707       return failure();
1708 
1709     bool changed = false;
1710     mlir::Type i1Ty = rewriter.getI1Type();
1711 
1712     // These variables serve to prevent creating duplicate constants
1713     // and hold constant true or false values.
1714     Value constantTrue = nullptr;
1715     Value constantFalse = nullptr;
1716 
1717     for (OpOperand &use :
1718          llvm::make_early_inc_range(op.getCondition().getUses())) {
1719       if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
1720         changed = true;
1721 
1722         if (!constantTrue)
1723           constantTrue = rewriter.create<arith::ConstantOp>(
1724               op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
1725 
1726         rewriter.updateRootInPlace(use.getOwner(),
1727                                    [&]() { use.set(constantTrue); });
1728       } else if (op.getElseRegion().isAncestor(
1729                      use.getOwner()->getParentRegion())) {
1730         changed = true;
1731 
1732         if (!constantFalse)
1733           constantFalse = rewriter.create<arith::ConstantOp>(
1734               op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
1735 
1736         rewriter.updateRootInPlace(use.getOwner(),
1737                                    [&]() { use.set(constantFalse); });
1738       }
1739     }
1740 
1741     return success(changed);
1742   }
1743 };
1744 
1745 /// Remove any statements from an if that are equivalent to the condition
1746 /// or its negation. For example:
1747 ///
1748 ///    %res:2 = scf.if %cmp {
1749 ///       yield something(), true
1750 ///    } else {
1751 ///       yield something2(), false
1752 ///    }
1753 ///    print(%res#1)
1754 ///
1755 ///  becomes
1756 ///    %res = scf.if %cmp {
1757 ///       yield something()
1758 ///    } else {
1759 ///       yield something2()
1760 ///    }
1761 ///    print(%cmp)
1762 ///
1763 /// Additionally if both branches yield the same value, replace all uses
1764 /// of the result with the yielded value.
1765 ///
1766 ///    %res:2 = scf.if %cmp {
1767 ///       yield something(), %arg1
1768 ///    } else {
1769 ///       yield something2(), %arg1
1770 ///    }
1771 ///    print(%res#1)
1772 ///
1773 ///  becomes
1774 ///    %res = scf.if %cmp {
1775 ///       yield something()
1776 ///    } else {
1777 ///       yield something2()
1778 ///    }
1779 ///    print(%arg1)
1780 ///
1781 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
1782   using OpRewritePattern<IfOp>::OpRewritePattern;
1783 
1784   LogicalResult matchAndRewrite(IfOp op,
1785                                 PatternRewriter &rewriter) const override {
1786     // Early exit if there are no results that could be replaced.
1787     if (op.getNumResults() == 0)
1788       return failure();
1789 
1790     auto trueYield =
1791         cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
1792     auto falseYield =
1793         cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
1794 
1795     rewriter.setInsertionPoint(op->getBlock(),
1796                                op.getOperation()->getIterator());
1797     bool changed = false;
1798     Type i1Ty = rewriter.getI1Type();
1799     for (auto tup : llvm::zip(trueYield.getResults(), falseYield.getResults(),
1800                               op.getResults())) {
1801       Value trueResult, falseResult, opResult;
1802       std::tie(trueResult, falseResult, opResult) = tup;
1803 
1804       if (trueResult == falseResult) {
1805         if (!opResult.use_empty()) {
1806           opResult.replaceAllUsesWith(trueResult);
1807           changed = true;
1808         }
1809         continue;
1810       }
1811 
1812       auto trueYield = trueResult.getDefiningOp<arith::ConstantOp>();
1813       if (!trueYield)
1814         continue;
1815 
1816       if (!trueYield.getType().isInteger(1))
1817         continue;
1818 
1819       auto falseYield = falseResult.getDefiningOp<arith::ConstantOp>();
1820       if (!falseYield)
1821         continue;
1822 
1823       bool trueVal = trueYield.getValue().cast<BoolAttr>().getValue();
1824       bool falseVal = falseYield.getValue().cast<BoolAttr>().getValue();
1825       if (!trueVal && falseVal) {
1826         if (!opResult.use_empty()) {
1827           Value notCond = rewriter.create<arith::XOrIOp>(
1828               op.getLoc(), op.getCondition(),
1829               rewriter.create<arith::ConstantOp>(
1830                   op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)));
1831           opResult.replaceAllUsesWith(notCond);
1832           changed = true;
1833         }
1834       }
1835       if (trueVal && !falseVal) {
1836         if (!opResult.use_empty()) {
1837           opResult.replaceAllUsesWith(op.getCondition());
1838           changed = true;
1839         }
1840       }
1841     }
1842     return success(changed);
1843   }
1844 };
1845 
/// Merge any consecutive scf.if's with the same condition.
///
///    scf.if %cond {
///       firstCodeTrue();...
///    } else {
///       firstCodeFalse();...
///    }
///    %res = scf.if %cond {
///       secondCodeTrue();...
///    } else {
///       secondCodeFalse();...
///    }
///
///  becomes
///    %res = scf.if %cond {
///       firstCodeTrue();...
///       secondCodeTrue();...
///    } else {
///       firstCodeFalse();...
///       secondCodeFalse();...
///    }
///
/// Also handles the case where one condition is the `arith.xori ..., 1`
/// negation of the other; the second op's blocks are then merged into the
/// opposite branch. The combined op yields prevIf's results followed by
/// nextIf's results.
struct CombineIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp nextIf,
                                PatternRewriter &rewriter) const override {
    // Only fires when the op immediately preceding nextIf is also an scf.if.
    Block *parent = nextIf->getBlock();
    if (nextIf == &parent->front())
      return failure();

    auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
    if (!prevIf)
      return failure();

    // Determine the logical then/else blocks when prevIf's
    // condition is used. Null means the block does not exist
    // in that case (e.g. empty else). If neither of these
    // are set, the two conditions cannot be compared.
    Block *nextThen = nullptr;
    Block *nextElse = nullptr;
    // Same condition: blocks map one-to-one.
    if (nextIf.getCondition() == prevIf.getCondition()) {
      nextThen = nextIf.thenBlock();
      if (!nextIf.getElseRegion().empty())
        nextElse = nextIf.elseBlock();
    }
    // nextIf's condition is `prevCond xor 1`: its branches swap roles.
    if (arith::XOrIOp notv =
            nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == prevIf.getCondition() &&
          matchPattern(notv.getRhs(), m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }
    // prevIf's condition is `nextCond xor 1`: its branches swap roles too.
    if (arith::XOrIOp notv =
            prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
      if (notv.getLhs() == nextIf.getCondition() &&
          matchPattern(notv.getRhs(), m_One())) {
        nextElse = nextIf.thenBlock();
        if (!nextIf.getElseRegion().empty())
          nextThen = nextIf.elseBlock();
      }
    }

    if (!nextThen && !nextElse)
      return failure();

    // prevIf without an 'else' has no results, so an empty list here leaves
    // the zip below with nothing to iterate over.
    SmallVector<Value> prevElseYielded;
    if (!prevIf.getElseRegion().empty())
      prevElseYielded = prevIf.elseYield().getOperands();
    // Replace all uses of return values of op within nextIf with the
    // corresponding yields
    for (auto it : llvm::zip(prevIf.getResults(),
                             prevIf.thenYield().getOperands(), prevElseYielded))
      for (OpOperand &use :
           llvm::make_early_inc_range(std::get<0>(it).getUses())) {
        if (nextThen && nextThen->getParent()->isAncestor(
                            use.getOwner()->getParentRegion())) {
          rewriter.startRootUpdate(use.getOwner());
          use.set(std::get<1>(it));
          rewriter.finalizeRootUpdate(use.getOwner());
        } else if (nextElse && nextElse->getParent()->isAncestor(
                                   use.getOwner()->getParentRegion())) {
          rewriter.startRootUpdate(use.getOwner());
          use.set(std::get<2>(it));
          rewriter.finalizeRootUpdate(use.getOwner());
        }
      }

    // Results of the combined op: prevIf's first, then nextIf's.
    SmallVector<Type> mergedTypes(prevIf.getResultTypes());
    llvm::append_range(mergedTypes, nextIf.getResultTypes());

    // Build an empty combined op; its default 'then' block is dropped and
    // replaced by prevIf's 'then' region.
    IfOp combinedIf = rewriter.create<IfOp>(
        nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
    rewriter.eraseBlock(&combinedIf.getThenRegion().back());

    rewriter.inlineRegionBefore(prevIf.getThenRegion(),
                                combinedIf.getThenRegion(),
                                combinedIf.getThenRegion().begin());

    // Append nextIf's true-path body and concatenate the two yields.
    if (nextThen) {
      YieldOp thenYield = combinedIf.thenYield();
      YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
      rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
      rewriter.setInsertionPointToEnd(combinedIf.thenBlock());

      SmallVector<Value> mergedYields(thenYield.getOperands());
      llvm::append_range(mergedYields, thenYield2.getOperands());
      rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
      rewriter.eraseOp(thenYield);
      rewriter.eraseOp(thenYield2);
    }

    rewriter.inlineRegionBefore(prevIf.getElseRegion(),
                                combinedIf.getElseRegion(),
                                combinedIf.getElseRegion().begin());

    // Append nextIf's false-path body. If prevIf had no 'else', the whole
    // region is moved over as-is; otherwise the blocks and yields are merged.
    if (nextElse) {
      if (combinedIf.getElseRegion().empty()) {
        rewriter.inlineRegionBefore(*nextElse->getParent(),
                                    combinedIf.getElseRegion(),
                                    combinedIf.getElseRegion().begin());
      } else {
        YieldOp elseYield = combinedIf.elseYield();
        YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
        rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());

        rewriter.setInsertionPointToEnd(combinedIf.elseBlock());

        SmallVector<Value> mergedElseYields(elseYield.getOperands());
        llvm::append_range(mergedElseYields, elseYield2.getOperands());

        rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
        rewriter.eraseOp(elseYield);
        rewriter.eraseOp(elseYield2);
      }
    }

    // Split the combined results back into prevIf's and nextIf's portions.
    SmallVector<Value> prevValues;
    SmallVector<Value> nextValues;
    for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
      if (pair.index() < prevIf.getNumResults())
        prevValues.push_back(pair.value());
      else
        nextValues.push_back(pair.value());
    }
    rewriter.replaceOp(prevIf, prevValues);
    rewriter.replaceOp(nextIf, nextValues);
    return success();
  }
};
1997 
1998 /// Pattern to remove an empty else branch.
1999 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2000   using OpRewritePattern<IfOp>::OpRewritePattern;
2001 
2002   LogicalResult matchAndRewrite(IfOp ifOp,
2003                                 PatternRewriter &rewriter) const override {
2004     // Cannot remove else region when there are operation results.
2005     if (ifOp.getNumResults())
2006       return failure();
2007     Block *elseBlock = ifOp.elseBlock();
2008     if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2009       return failure();
2010     auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2011     rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2012                                 newIfOp.getThenRegion().begin());
2013     rewriter.eraseOp(ifOp);
2014     return success();
2015   }
2016 };
2017 
2018 /// Convert nested `if`s into `arith.andi` + single `if`.
2019 ///
2020 ///    scf.if %arg0 {
2021 ///      scf.if %arg1 {
2022 ///        ...
2023 ///        scf.yield
2024 ///      }
2025 ///      scf.yield
2026 ///    }
2027 ///  becomes
2028 ///
2029 ///    %0 = arith.andi %arg0, %arg1
2030 ///    scf.if %0 {
2031 ///      ...
2032 ///      scf.yield
2033 ///    }
2034 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2035   using OpRewritePattern<IfOp>::OpRewritePattern;
2036 
2037   LogicalResult matchAndRewrite(IfOp op,
2038                                 PatternRewriter &rewriter) const override {
2039     auto nestedOps = op.thenBlock()->without_terminator();
2040     // Nested `if` must be the only op in block.
2041     if (!llvm::hasSingleElement(nestedOps))
2042       return failure();
2043 
2044     // If there is an else block, it can only yield
2045     if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2046       return failure();
2047 
2048     auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2049     if (!nestedIf)
2050       return failure();
2051 
2052     if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2053       return failure();
2054 
2055     SmallVector<Value> thenYield(op.thenYield().getOperands());
2056     SmallVector<Value> elseYield;
2057     if (op.elseBlock())
2058       llvm::append_range(elseYield, op.elseYield().getOperands());
2059 
2060     // A list of indices for which we should upgrade the value yielded
2061     // in the else to a select.
2062     SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2063 
2064     // If the outer scf.if yields a value produced by the inner scf.if,
2065     // only permit combining if the value yielded when the condition
2066     // is false in the outer scf.if is the same value yielded when the
2067     // inner scf.if condition is false.
2068     // Note that the array access to elseYield will not go out of bounds
2069     // since it must have the same length as thenYield, since they both
2070     // come from the same scf.if.
2071     for (const auto &tup : llvm::enumerate(thenYield)) {
2072       if (tup.value().getDefiningOp() == nestedIf) {
2073         auto nestedIdx = tup.value().cast<OpResult>().getResultNumber();
2074         if (nestedIf.elseYield().getOperand(nestedIdx) !=
2075             elseYield[tup.index()]) {
2076           return failure();
2077         }
2078         // If the correctness test passes, we will yield
2079         // corresponding value from the inner scf.if
2080         thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2081         continue;
2082       }
2083 
2084       // Otherwise, we need to ensure the else block of the combined
2085       // condition still returns the same value when the outer condition is
2086       // true and the inner condition is false. This can be accomplished if
2087       // the then value is defined outside the outer scf.if and we replace the
2088       // value with a select that considers just the outer condition. Since
2089       // the else region contains just the yield, its yielded value is
2090       // defined outside the scf.if, by definition.
2091 
2092       // If the then value is defined within the scf.if, bail.
2093       if (tup.value().getParentRegion() == &op.getThenRegion()) {
2094         return failure();
2095       }
2096       elseYieldsToUpgradeToSelect.push_back(tup.index());
2097     }
2098 
2099     Location loc = op.getLoc();
2100     Value newCondition = rewriter.create<arith::AndIOp>(
2101         loc, op.getCondition(), nestedIf.getCondition());
2102     auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2103 
2104     SmallVector<Value> results;
2105     llvm::append_range(results, newIf.getResults());
2106     rewriter.setInsertionPoint(newIf);
2107 
2108     for (auto idx : elseYieldsToUpgradeToSelect)
2109       results[idx] = rewriter.create<arith::SelectOp>(
2110           op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2111 
2112     Block *newIfBlock = newIf.thenBlock();
2113     if (newIfBlock)
2114       rewriter.eraseOp(newIfBlock->getTerminator());
2115     else
2116       newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2117     rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2118     rewriter.setInsertionPointToEnd(newIf.thenBlock());
2119     rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2120     if (!elseYield.empty()) {
2121       rewriter.createBlock(&newIf.getElseRegion());
2122       rewriter.setInsertionPointToEnd(newIf.elseBlock());
2123       rewriter.create<YieldOp>(loc, elseYield);
2124     }
2125     rewriter.replaceOp(op, results);
2126     return success();
2127   }
2128 };
2129 
2130 } // namespace
2131 
2132 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2133                                        MLIRContext *context) {
2134   results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2135               ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2136               RemoveStaticCondition, RemoveUnusedResults,
2137               ReplaceIfYieldWithConditionOrValue>(context);
2138 }
2139 
2140 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2141 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2142 Block *IfOp::elseBlock() {
2143   Region &r = getElseRegion();
2144   if (r.empty())
2145     return nullptr;
2146   return &r.back();
2147 }
2148 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2149 
2150 //===----------------------------------------------------------------------===//
2151 // ParallelOp
2152 //===----------------------------------------------------------------------===//
2153 
2154 void ParallelOp::build(
2155     OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2156     ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2157     function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2158         bodyBuilderFn) {
2159   result.addOperands(lowerBounds);
2160   result.addOperands(upperBounds);
2161   result.addOperands(steps);
2162   result.addOperands(initVals);
2163   result.addAttribute(
2164       ParallelOp::getOperandSegmentSizeAttr(),
2165       builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()),
2166                                 static_cast<int32_t>(upperBounds.size()),
2167                                 static_cast<int32_t>(steps.size()),
2168                                 static_cast<int32_t>(initVals.size())}));
2169   result.addTypes(initVals.getTypes());
2170 
2171   OpBuilder::InsertionGuard guard(builder);
2172   unsigned numIVs = steps.size();
2173   SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2174   SmallVector<Location, 8> argLocs(numIVs, result.location);
2175   Region *bodyRegion = result.addRegion();
2176   Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2177 
2178   if (bodyBuilderFn) {
2179     builder.setInsertionPointToStart(bodyBlock);
2180     bodyBuilderFn(builder, result.location,
2181                   bodyBlock->getArguments().take_front(numIVs),
2182                   bodyBlock->getArguments().drop_front(numIVs));
2183   }
2184   ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2185 }
2186 
2187 void ParallelOp::build(
2188     OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2189     ValueRange upperBounds, ValueRange steps,
2190     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2191   // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2192   // we don't capture a reference to a temporary by constructing the lambda at
2193   // function level.
2194   auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2195                                            Location nestedLoc, ValueRange ivs,
2196                                            ValueRange) {
2197     bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2198   };
2199   function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2200   if (bodyBuilderFn)
2201     wrapper = wrappedBuilderFn;
2202 
2203   build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2204         wrapper);
2205 }
2206 
2207 LogicalResult ParallelOp::verify() {
2208   // Check that there is at least one value in lowerBound, upperBound and step.
2209   // It is sufficient to test only step, because it is ensured already that the
2210   // number of elements in lowerBound, upperBound and step are the same.
2211   Operation::operand_range stepValues = getStep();
2212   if (stepValues.empty())
2213     return emitOpError(
2214         "needs at least one tuple element for lowerBound, upperBound and step");
2215 
2216   // Check whether all constant step values are positive.
2217   for (Value stepValue : stepValues)
2218     if (auto cst = stepValue.getDefiningOp<arith::ConstantIndexOp>())
2219       if (cst.value() <= 0)
2220         return emitOpError("constant step operand must be positive");
2221 
2222   // Check that the body defines the same number of block arguments as the
2223   // number of tuple elements in step.
2224   Block *body = getBody();
2225   if (body->getNumArguments() != stepValues.size())
2226     return emitOpError() << "expects the same number of induction variables: "
2227                          << body->getNumArguments()
2228                          << " as bound and step values: " << stepValues.size();
2229   for (auto arg : body->getArguments())
2230     if (!arg.getType().isIndex())
2231       return emitOpError(
2232           "expects arguments for the induction variable to be of index type");
2233 
2234   // Check that the yield has no results
2235   Operation *yield = body->getTerminator();
2236   if (yield->getNumOperands() != 0)
2237     return yield->emitOpError() << "not allowed to have operands inside '"
2238                                 << ParallelOp::getOperationName() << "'";
2239 
2240   // Check that the number of results is the same as the number of ReduceOps.
2241   SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
2242   auto resultsSize = getResults().size();
2243   auto reductionsSize = reductions.size();
2244   auto initValsSize = getInitVals().size();
2245   if (resultsSize != reductionsSize)
2246     return emitOpError() << "expects number of results: " << resultsSize
2247                          << " to be the same as number of reductions: "
2248                          << reductionsSize;
2249   if (resultsSize != initValsSize)
2250     return emitOpError() << "expects number of results: " << resultsSize
2251                          << " to be the same as number of initial values: "
2252                          << initValsSize;
2253 
2254   // Check that the types of the results and reductions are the same.
2255   for (auto resultAndReduce : llvm::zip(getResults(), reductions)) {
2256     auto resultType = std::get<0>(resultAndReduce).getType();
2257     auto reduceOp = std::get<1>(resultAndReduce);
2258     auto reduceType = reduceOp.getOperand().getType();
2259     if (resultType != reduceType)
2260       return reduceOp.emitOpError()
2261              << "expects type of reduce: " << reduceType
2262              << " to be the same as result type: " << resultType;
2263   }
2264   return success();
2265 }
2266 
2267 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2268   auto &builder = parser.getBuilder();
2269   // Parse an opening `(` followed by induction variables followed by `)`
2270   SmallVector<OpAsmParser::Argument, 4> ivs;
2271   if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
2272     return failure();
2273 
2274   // Parse loop bounds.
2275   SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2276   if (parser.parseEqual() ||
2277       parser.parseOperandList(lower, ivs.size(),
2278                               OpAsmParser::Delimiter::Paren) ||
2279       parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2280     return failure();
2281 
2282   SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2283   if (parser.parseKeyword("to") ||
2284       parser.parseOperandList(upper, ivs.size(),
2285                               OpAsmParser::Delimiter::Paren) ||
2286       parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2287     return failure();
2288 
2289   // Parse step values.
2290   SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2291   if (parser.parseKeyword("step") ||
2292       parser.parseOperandList(steps, ivs.size(),
2293                               OpAsmParser::Delimiter::Paren) ||
2294       parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2295     return failure();
2296 
2297   // Parse init values.
2298   SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2299   if (succeeded(parser.parseOptionalKeyword("init"))) {
2300     if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2301       return failure();
2302   }
2303 
2304   // Parse optional results in case there is a reduce.
2305   if (parser.parseOptionalArrowTypeList(result.types))
2306     return failure();
2307 
2308   // Now parse the body.
2309   Region *body = result.addRegion();
2310   for (auto &iv : ivs)
2311     iv.type = builder.getIndexType();
2312   if (parser.parseRegion(*body, ivs))
2313     return failure();
2314 
2315   // Set `operand_segment_sizes` attribute.
2316   result.addAttribute(
2317       ParallelOp::getOperandSegmentSizeAttr(),
2318       builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
2319                                 static_cast<int32_t>(upper.size()),
2320                                 static_cast<int32_t>(steps.size()),
2321                                 static_cast<int32_t>(initVals.size())}));
2322 
2323   // Parse attributes.
2324   if (parser.parseOptionalAttrDict(result.attributes) ||
2325       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2326                              result.operands))
2327     return failure();
2328 
2329   // Add a terminator if none was parsed.
2330   ForOp::ensureTerminator(*body, builder, result.location);
2331   return success();
2332 }
2333 
2334 void ParallelOp::print(OpAsmPrinter &p) {
2335   p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
2336     << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
2337   if (!getInitVals().empty())
2338     p << " init (" << getInitVals() << ")";
2339   p.printOptionalArrowTypeList(getResultTypes());
2340   p << ' ';
2341   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
2342   p.printOptionalAttrDict(
2343       (*this)->getAttrs(),
2344       /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
2345 }
2346 
2347 Region &ParallelOp::getLoopBody() { return getRegion(); }
2348 
2349 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
2350   auto ivArg = val.dyn_cast<BlockArgument>();
2351   if (!ivArg)
2352     return ParallelOp();
2353   assert(ivArg.getOwner() && "unlinked block argument");
2354   auto *containingOp = ivArg.getOwner()->getParentOp();
2355   return dyn_cast<ParallelOp>(containingOp);
2356 }
2357 
2358 namespace {
2359 // Collapse loop dimensions that perform a single iteration.
2360 struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
2361   using OpRewritePattern<ParallelOp>::OpRewritePattern;
2362 
2363   LogicalResult matchAndRewrite(ParallelOp op,
2364                                 PatternRewriter &rewriter) const override {
2365     BlockAndValueMapping mapping;
2366     // Compute new loop bounds that omit all single-iteration loop dimensions.
2367     SmallVector<Value, 2> newLowerBounds;
2368     SmallVector<Value, 2> newUpperBounds;
2369     SmallVector<Value, 2> newSteps;
2370     newLowerBounds.reserve(op.getLowerBound().size());
2371     newUpperBounds.reserve(op.getUpperBound().size());
2372     newSteps.reserve(op.getStep().size());
2373     for (auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound(),
2374                               op.getStep(), op.getInductionVars())) {
2375       Value lowerBound, upperBound, step, iv;
2376       std::tie(lowerBound, upperBound, step, iv) = dim;
2377       // Collect the statically known loop bounds.
2378       auto lowerBoundConstant =
2379           dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
2380       auto upperBoundConstant =
2381           dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
2382       auto stepConstant =
2383           dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
2384       // Replace the loop induction variable by the lower bound if the loop
2385       // performs a single iteration. Otherwise, copy the loop bounds.
2386       if (lowerBoundConstant && upperBoundConstant && stepConstant &&
2387           (upperBoundConstant.value() - lowerBoundConstant.value()) > 0 &&
2388           (upperBoundConstant.value() - lowerBoundConstant.value()) <=
2389               stepConstant.value()) {
2390         mapping.map(iv, lowerBound);
2391       } else {
2392         newLowerBounds.push_back(lowerBound);
2393         newUpperBounds.push_back(upperBound);
2394         newSteps.push_back(step);
2395       }
2396     }
2397     // Exit if none of the loop dimensions perform a single iteration.
2398     if (newLowerBounds.size() == op.getLowerBound().size())
2399       return failure();
2400 
2401     if (newLowerBounds.empty()) {
2402       // All of the loop dimensions perform a single iteration. Inline
2403       // loop body and nested ReduceOp's
2404       SmallVector<Value> results;
2405       results.reserve(op.getInitVals().size());
2406       for (auto &bodyOp : op.getLoopBody().front().without_terminator()) {
2407         auto reduce = dyn_cast<ReduceOp>(bodyOp);
2408         if (!reduce) {
2409           rewriter.clone(bodyOp, mapping);
2410           continue;
2411         }
2412         Block &reduceBlock = reduce.getReductionOperator().front();
2413         auto initValIndex = results.size();
2414         mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
2415         mapping.map(reduceBlock.getArgument(1),
2416                     mapping.lookupOrDefault(reduce.getOperand()));
2417         for (auto &reduceBodyOp : reduceBlock.without_terminator())
2418           rewriter.clone(reduceBodyOp, mapping);
2419 
2420         auto result = mapping.lookupOrDefault(
2421             cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
2422         results.push_back(result);
2423       }
2424       rewriter.replaceOp(op, results);
2425       return success();
2426     }
2427     // Replace the parallel loop by lower-dimensional parallel loop.
2428     auto newOp =
2429         rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
2430                                     newSteps, op.getInitVals(), nullptr);
2431     // Clone the loop body and remap the block arguments of the collapsed loops
2432     // (inlining does not support a cancellable block argument mapping).
2433     rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
2434                                newOp.getRegion().begin(), mapping);
2435     rewriter.replaceOp(op, newOp.getResults());
2436     return success();
2437   }
2438 };
2439 
2440 /// Removes parallel loops in which at least one lower/upper bound pair consists
2441 /// of the same values - such loops have an empty iteration domain.
2442 struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
2443   using OpRewritePattern<ParallelOp>::OpRewritePattern;
2444 
2445   LogicalResult matchAndRewrite(ParallelOp op,
2446                                 PatternRewriter &rewriter) const override {
2447     for (auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound())) {
2448       if (std::get<0>(dim) == std::get<1>(dim)) {
2449         rewriter.replaceOp(op, op.getInitVals());
2450         return success();
2451       }
2452     }
2453     return failure();
2454   }
2455 };
2456 
2457 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
2458   using OpRewritePattern<ParallelOp>::OpRewritePattern;
2459 
2460   LogicalResult matchAndRewrite(ParallelOp op,
2461                                 PatternRewriter &rewriter) const override {
2462     Block &outerBody = op.getLoopBody().front();
2463     if (!llvm::hasSingleElement(outerBody.without_terminator()))
2464       return failure();
2465 
2466     auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
2467     if (!innerOp)
2468       return failure();
2469 
2470     for (auto val : outerBody.getArguments())
2471       if (llvm::is_contained(innerOp.getLowerBound(), val) ||
2472           llvm::is_contained(innerOp.getUpperBound(), val) ||
2473           llvm::is_contained(innerOp.getStep(), val))
2474         return failure();
2475 
2476     // Reductions are not supported yet.
2477     if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
2478       return failure();
2479 
2480     auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
2481                            ValueRange iterVals, ValueRange) {
2482       Block &innerBody = innerOp.getLoopBody().front();
2483       assert(iterVals.size() ==
2484              (outerBody.getNumArguments() + innerBody.getNumArguments()));
2485       BlockAndValueMapping mapping;
2486       mapping.map(outerBody.getArguments(),
2487                   iterVals.take_front(outerBody.getNumArguments()));
2488       mapping.map(innerBody.getArguments(),
2489                   iterVals.take_back(innerBody.getNumArguments()));
2490       for (Operation &op : innerBody.without_terminator())
2491         builder.clone(op, mapping);
2492     };
2493 
2494     auto concatValues = [](const auto &first, const auto &second) {
2495       SmallVector<Value> ret;
2496       ret.reserve(first.size() + second.size());
2497       ret.assign(first.begin(), first.end());
2498       ret.append(second.begin(), second.end());
2499       return ret;
2500     };
2501 
2502     auto newLowerBounds =
2503         concatValues(op.getLowerBound(), innerOp.getLowerBound());
2504     auto newUpperBounds =
2505         concatValues(op.getUpperBound(), innerOp.getUpperBound());
2506     auto newSteps = concatValues(op.getStep(), innerOp.getStep());
2507 
2508     rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
2509                                             newSteps, llvm::None, bodyBuilder);
2510     return success();
2511   }
2512 };
2513 
2514 } // namespace
2515 
2516 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
2517                                              MLIRContext *context) {
2518   results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops,
2519               MergeNestedParallelLoops>(context);
2520 }
2521 
2522 //===----------------------------------------------------------------------===//
2523 // ReduceOp
2524 //===----------------------------------------------------------------------===//
2525 
2526 void ReduceOp::build(
2527     OpBuilder &builder, OperationState &result, Value operand,
2528     function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuilderFn) {
2529   auto type = operand.getType();
2530   result.addOperands(operand);
2531 
2532   OpBuilder::InsertionGuard guard(builder);
2533   Region *bodyRegion = result.addRegion();
2534   Block *body = builder.createBlock(bodyRegion, {}, ArrayRef<Type>{type, type},
2535                                     {result.location, result.location});
2536   if (bodyBuilderFn)
2537     bodyBuilderFn(builder, result.location, body->getArgument(0),
2538                   body->getArgument(1));
2539 }
2540 
2541 LogicalResult ReduceOp::verifyRegions() {
2542   // The region of a ReduceOp has two arguments of the same type as its operand.
2543   auto type = getOperand().getType();
2544   Block &block = getReductionOperator().front();
2545   if (block.empty())
2546     return emitOpError("the block inside reduce should not be empty");
2547   if (block.getNumArguments() != 2 ||
2548       llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
2549         return arg.getType() != type;
2550       }))
2551     return emitOpError() << "expects two arguments to reduce block of type "
2552                          << type;
2553 
2554   // Check that the block is terminated by a ReduceReturnOp.
2555   if (!isa<ReduceReturnOp>(block.getTerminator()))
2556     return emitOpError("the block inside reduce should be terminated with a "
2557                        "'scf.reduce.return' op");
2558 
2559   return success();
2560 }
2561 
2562 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
2563   // Parse an opening `(` followed by the reduced value followed by `)`
2564   OpAsmParser::UnresolvedOperand operand;
2565   if (parser.parseLParen() || parser.parseOperand(operand) ||
2566       parser.parseRParen())
2567     return failure();
2568 
2569   Type resultType;
2570   // Parse the type of the operand (and also what reduce computes on).
2571   if (parser.parseColonType(resultType) ||
2572       parser.resolveOperand(operand, resultType, result.operands))
2573     return failure();
2574 
2575   // Now parse the body.
2576   Region *body = result.addRegion();
2577   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
2578     return failure();
2579 
2580   return success();
2581 }
2582 
2583 void ReduceOp::print(OpAsmPrinter &p) {
2584   p << "(" << getOperand() << ") ";
2585   p << " : " << getOperand().getType() << ' ';
2586   p.printRegion(getReductionOperator());
2587 }
2588 
2589 //===----------------------------------------------------------------------===//
2590 // ReduceReturnOp
2591 //===----------------------------------------------------------------------===//
2592 
2593 LogicalResult ReduceReturnOp::verify() {
2594   // The type of the return value should be the same type as the type of the
2595   // operand of the enclosing ReduceOp.
2596   auto reduceOp = cast<ReduceOp>((*this)->getParentOp());
2597   Type reduceType = reduceOp.getOperand().getType();
2598   if (reduceType != getResult().getType())
2599     return emitOpError() << "needs to have type " << reduceType
2600                          << " (the type of the enclosing ReduceOp)";
2601   return success();
2602 }
2603 
2604 //===----------------------------------------------------------------------===//
2605 // WhileOp
2606 //===----------------------------------------------------------------------===//
2607 
2608 OperandRange WhileOp::getSuccessorEntryOperands(Optional<unsigned> index) {
2609   assert(index && *index == 0 &&
2610          "WhileOp is expected to branch only to the first region");
2611 
2612   return getInits();
2613 }
2614 
2615 ConditionOp WhileOp::getConditionOp() {
2616   return cast<ConditionOp>(getBefore().front().getTerminator());
2617 }
2618 
2619 YieldOp WhileOp::getYieldOp() {
2620   return cast<YieldOp>(getAfter().front().getTerminator());
2621 }
2622 
2623 Block::BlockArgListType WhileOp::getBeforeArguments() {
2624   return getBefore().front().getArguments();
2625 }
2626 
2627 Block::BlockArgListType WhileOp::getAfterArguments() {
2628   return getAfter().front().getArguments();
2629 }
2630 
2631 void WhileOp::getSuccessorRegions(Optional<unsigned> index,
2632                                   ArrayRef<Attribute> operands,
2633                                   SmallVectorImpl<RegionSuccessor> &regions) {
2634   // The parent op always branches to the condition region.
2635   if (!index) {
2636     regions.emplace_back(&getBefore(), getBefore().getArguments());
2637     return;
2638   }
2639 
2640   assert(*index < 2 && "there are only two regions in a WhileOp");
2641   // The body region always branches back to the condition region.
2642   if (*index == 1) {
2643     regions.emplace_back(&getBefore(), getBefore().getArguments());
2644     return;
2645   }
2646 
2647   // Try to narrow the successor to the condition region.
2648   assert(!operands.empty() && "expected at least one operand");
2649   auto cond = operands[0].dyn_cast_or_null<BoolAttr>();
2650   if (!cond || !cond.getValue())
2651     regions.emplace_back(getResults());
2652   if (!cond || cond.getValue())
2653     regions.emplace_back(&getAfter(), getAfter().getArguments());
2654 }
2655 
2656 /// Parses a `while` op.
2657 ///
2658 /// op ::= `scf.while` assignments `:` function-type region `do` region
2659 ///         `attributes` attribute-dict
2660 /// initializer ::= /* empty */ | `(` assignment-list `)`
2661 /// assignment-list ::= assignment | assignment `,` assignment-list
2662 /// assignment ::= ssa-value `=` ssa-value
2663 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
2664   SmallVector<OpAsmParser::Argument, 4> regionArgs;
2665   SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2666   Region *before = result.addRegion();
2667   Region *after = result.addRegion();
2668 
2669   OptionalParseResult listResult =
2670       parser.parseOptionalAssignmentList(regionArgs, operands);
2671   if (listResult.hasValue() && failed(listResult.getValue()))
2672     return failure();
2673 
2674   FunctionType functionType;
2675   SMLoc typeLoc = parser.getCurrentLocation();
2676   if (failed(parser.parseColonType(functionType)))
2677     return failure();
2678 
2679   result.addTypes(functionType.getResults());
2680 
2681   if (functionType.getNumInputs() != operands.size()) {
2682     return parser.emitError(typeLoc)
2683            << "expected as many input types as operands "
2684            << "(expected " << operands.size() << " got "
2685            << functionType.getNumInputs() << ")";
2686   }
2687 
2688   // Resolve input operands.
2689   if (failed(parser.resolveOperands(operands, functionType.getInputs(),
2690                                     parser.getCurrentLocation(),
2691                                     result.operands)))
2692     return failure();
2693 
2694   // Propagate the types into the region arguments.
2695   for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
2696     regionArgs[i].type = functionType.getInput(i);
2697 
2698   return failure(parser.parseRegion(*before, regionArgs) ||
2699                  parser.parseKeyword("do") || parser.parseRegion(*after) ||
2700                  parser.parseOptionalAttrDictWithKeyword(result.attributes));
2701 }
2702 
2703 /// Prints a `while` op.
2704 void scf::WhileOp::print(OpAsmPrinter &p) {
2705   printInitializationList(p, getBefore().front().getArguments(), getInits(),
2706                           " ");
2707   p << " : ";
2708   p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
2709   p << ' ';
2710   p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
2711   p << " do ";
2712   p.printRegion(getAfter());
2713   p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
2714 }
2715 
2716 /// Verifies that two ranges of types match, i.e. have the same number of
2717 /// entries and that types are pairwise equals. Reports errors on the given
2718 /// operation in case of mismatch.
2719 template <typename OpTy>
2720 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
2721                                            TypeRange right, StringRef message) {
2722   if (left.size() != right.size())
2723     return op.emitOpError("expects the same number of ") << message;
2724 
2725   for (unsigned i = 0, e = left.size(); i < e; ++i) {
2726     if (left[i] != right[i]) {
2727       InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
2728                                 << message;
2729       diag.attachNote() << "for argument " << i << ", found " << left[i]
2730                         << " and " << right[i];
2731       return diag;
2732     }
2733   }
2734 
2735   return success();
2736 }
2737 
2738 /// Verifies that the first block of the given `region` is terminated by a
2739 /// YieldOp. Reports errors on the given operation if it is not the case.
2740 template <typename TerminatorTy>
2741 static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region &region,
2742                                            StringRef errorMessage) {
2743   Operation *terminatorOperation = region.front().getTerminator();
2744   if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
2745     return yield;
2746 
2747   auto diag = op.emitOpError(errorMessage);
2748   if (terminatorOperation)
2749     diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
2750   return nullptr;
2751 }
2752 
2753 LogicalResult scf::WhileOp::verify() {
2754   auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
2755       *this, getBefore(),
2756       "expects the 'before' region to terminate with 'scf.condition'");
2757   if (!beforeTerminator)
2758     return failure();
2759 
2760   auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
2761       *this, getAfter(),
2762       "expects the 'after' region to terminate with 'scf.yield'");
2763   return success(afterTerminator != nullptr);
2764 }
2765 
2766 namespace {
2767 /// Replace uses of the condition within the do block with true, since otherwise
2768 /// the block would not be evaluated.
2769 ///
2770 /// scf.while (..) : (i1, ...) -> ... {
2771 ///  %condition = call @evaluate_condition() : () -> i1
2772 ///  scf.condition(%condition) %condition : i1, ...
2773 /// } do {
2774 /// ^bb0(%arg0: i1, ...):
2775 ///    use(%arg0)
2776 ///    ...
2777 ///
2778 /// becomes
2779 /// scf.while (..) : (i1, ...) -> ... {
2780 ///  %condition = call @evaluate_condition() : () -> i1
2781 ///  scf.condition(%condition) %condition : i1, ...
2782 /// } do {
2783 /// ^bb0(%arg0: i1, ...):
2784 ///    use(%true)
2785 ///    ...
2786 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
2787   using OpRewritePattern<WhileOp>::OpRewritePattern;
2788 
2789   LogicalResult matchAndRewrite(WhileOp op,
2790                                 PatternRewriter &rewriter) const override {
2791     auto term = op.getConditionOp();
2792 
2793     // These variables serve to prevent creating duplicate constants
2794     // and hold constant true or false values.
2795     Value constantTrue = nullptr;
2796 
2797     bool replaced = false;
2798     for (auto yieldedAndBlockArgs :
2799          llvm::zip(term.getArgs(), op.getAfterArguments())) {
2800       if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
2801         if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
2802           if (!constantTrue)
2803             constantTrue = rewriter.create<arith::ConstantOp>(
2804                 op.getLoc(), term.getCondition().getType(),
2805                 rewriter.getBoolAttr(true));
2806 
2807           std::get<1>(yieldedAndBlockArgs).replaceAllUsesWith(constantTrue);
2808           replaced = true;
2809         }
2810       }
2811     }
2812     return success(replaced);
2813   }
2814 };
2815 
2816 /// Remove loop invariant arguments from `before` block of scf.while.
2817 /// A before block argument is considered loop invariant if :-
2818 ///   1. i-th yield operand is equal to the i-th while operand.
2819 ///   2. i-th yield operand is k-th after block argument which is (k+1)-th
2820 ///      condition operand AND this (k+1)-th condition operand is equal to i-th
2821 ///      iter argument/while operand.
2822 /// For the arguments which are removed, their uses inside scf.while
2823 /// are replaced with their corresponding initial value.
2824 ///
2825 /// Eg:
2826 ///    INPUT :-
2827 ///    %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
2828 ///                                     ..., %argN_before = %N)
2829 ///           {
2830 ///                ...
2831 ///                scf.condition(%cond) %arg1_before, %arg0_before,
2832 ///                                     %arg2_before, %arg0_before, ...
2833 ///           } do {
2834 ///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
2835 ///                  ..., %argK_after):
2836 ///                ...
2837 ///                scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
2838 ///           }
2839 ///
2840 ///    OUTPUT :-
2841 ///    %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
2842 ///                                     %N)
2843 ///           {
2844 ///                ...
2845 ///                scf.condition(%cond) %b, %a, %arg2_before, %a, ...
2846 ///           } do {
2847 ///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
2848 ///                  ..., %argK_after):
2849 ///                ...
2850 ///                scf.yield %arg1_after, ..., %argN
2851 ///           }
2852 ///
2853 ///    EXPLANATION:
2854 ///      We iterate over each yield operand.
2855 ///        1. 0-th yield operand %arg0_after_2 is 4-th condition operand
2856 ///           %arg0_before, which in turn is the 0-th iter argument. So we
2857 ///           remove 0-th before block argument and yield operand, and replace
2858 ///           all uses of the 0-th before block argument with its initial value
2859 ///           %a.
2860 ///        2. 1-th yield operand %b is equal to the 1-th iter arg's initial
2861 ///           value. So we remove this operand and the corresponding before
2862 ///           block argument and replace all uses of 1-th before block argument
2863 ///           with %b.
2864 struct RemoveLoopInvariantArgsFromBeforeBlock
2865     : public OpRewritePattern<WhileOp> {
2866   using OpRewritePattern<WhileOp>::OpRewritePattern;
2867 
2868   LogicalResult matchAndRewrite(WhileOp op,
2869                                 PatternRewriter &rewriter) const override {
2870     Block &afterBlock = op.getAfter().front();
2871     Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
2872     ConditionOp condOp = op.getConditionOp();
2873     OperandRange condOpArgs = condOp.getArgs();
2874     Operation *yieldOp = afterBlock.getTerminator();
2875     ValueRange yieldOpArgs = yieldOp->getOperands();
2876 
2877     bool canSimplify = false;
2878     for (const auto &it :
2879          llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
2880       auto index = static_cast<unsigned>(it.index());
2881       Value initVal, yieldOpArg;
2882       std::tie(initVal, yieldOpArg) = it.value();
2883       // If i-th yield operand is equal to the i-th operand of the scf.while,
2884       // the i-th before block argument is a loop invariant.
2885       if (yieldOpArg == initVal) {
2886         canSimplify = true;
2887         break;
2888       }
2889       // If the i-th yield operand is k-th after block argument, then we check
2890       // if the (k+1)-th condition op operand is equal to either the i-th before
2891       // block argument or the initial value of i-th before block argument. If
2892       // the comparison results `true`, i-th before block argument is a loop
2893       // invariant.
2894       auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
2895       if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
2896         Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
2897         if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
2898           canSimplify = true;
2899           break;
2900         }
2901       }
2902     }
2903 
2904     if (!canSimplify)
2905       return failure();
2906 
2907     SmallVector<Value> newInitArgs, newYieldOpArgs;
2908     DenseMap<unsigned, Value> beforeBlockInitValMap;
2909     SmallVector<Location> newBeforeBlockArgLocs;
2910     for (const auto &it :
2911          llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
2912       auto index = static_cast<unsigned>(it.index());
2913       Value initVal, yieldOpArg;
2914       std::tie(initVal, yieldOpArg) = it.value();
2915 
2916       // If i-th yield operand is equal to the i-th operand of the scf.while,
2917       // the i-th before block argument is a loop invariant.
2918       if (yieldOpArg == initVal) {
2919         beforeBlockInitValMap.insert({index, initVal});
2920         continue;
2921       } else {
2922         // If the i-th yield operand is k-th after block argument, then we check
2923         // if the (k+1)-th condition op operand is equal to either the i-th
2924         // before block argument or the initial value of i-th before block
2925         // argument. If the comparison results `true`, i-th before block
2926         // argument is a loop invariant.
2927         auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
2928         if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
2929           Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
2930           if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
2931             beforeBlockInitValMap.insert({index, initVal});
2932             continue;
2933           }
2934         }
2935       }
2936       newInitArgs.emplace_back(initVal);
2937       newYieldOpArgs.emplace_back(yieldOpArg);
2938       newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
2939     }
2940 
2941     {
2942       OpBuilder::InsertionGuard g(rewriter);
2943       rewriter.setInsertionPoint(yieldOp);
2944       rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
2945     }
2946 
2947     auto newWhile =
2948         rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
2949 
2950     Block &newBeforeBlock = *rewriter.createBlock(
2951         &newWhile.getBefore(), /*insertPt*/ {},
2952         ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
2953 
2954     Block &beforeBlock = op.getBefore().front();
2955     SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
2956     // For each i-th before block argument we find it's replacement value as :-
2957     //   1. If i-th before block argument is a loop invariant, we fetch it's
2958     //      initial value from `beforeBlockInitValMap` by querying for key `i`.
2959     //   2. Else we fetch j-th new before block argument as the replacement
2960     //      value of i-th before block argument.
2961     for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
2962       // If the index 'i' argument was a loop invariant we fetch it's initial
2963       // value from `beforeBlockInitValMap`.
2964       if (beforeBlockInitValMap.count(i) != 0)
2965         newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
2966       else
2967         newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
2968     }
2969 
2970     rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
2971     rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
2972                                 newWhile.getAfter().begin());
2973 
2974     rewriter.replaceOp(op, newWhile.getResults());
2975     return success();
2976   }
2977 };
2978 
/// Remove loop invariant values from the results (scf.condition operands) of
/// scf.while. A value is considered loop invariant if the value forwarded by
/// scf.condition is defined outside of the `before` block. We remove the
/// corresponding argument of the `after` block and replace its uses with that
/// value. We also replace the uses of the corresponding result of scf.while
/// with the value. The init operands and the `before` region are left
/// unchanged.
///
/// Eg:
///    INPUT :-
///    %res_input:K = scf.while <...> iter_args(%arg0_before = %0, ...,
///                                             %argN_before = %N) {
///                ...
///                scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
///           } do {
///             ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
///                ...
///                some_func(%arg1_after)
///                ...
///                scf.yield %arg0_after, %arg2_after, ..., %argN_after
///           }
///
///    OUTPUT :-
///    %res_output:M = scf.while <...> iter_args(%arg0 = %0, ..., %argN = %N) {
///                ...
///                scf.condition(%cond) %arg0, %arg1, ..., %argM
///           } do {
///             ^bb0(%arg0, %arg3, ..., %argM):
///                ...
///                some_func(%a)
///                ...
///                scf.yield %arg0, %b, ..., %argN
///           }
///
///     EXPLANATION:
///       1. The scf.condition operands at positions 1 and 2 (%a and %b) are
///          defined outside the `before` block of scf.while, so they get
///          removed.
///       2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
///          replaced by %b.
///       3. The corresponding after block argument %arg1_after's uses are
///          replaced by %a and %arg2_after's uses are replaced by %b.
struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
  using OpRewritePattern<WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileOp op,
                                PatternRewriter &rewriter) const override {
    Block &beforeBlock = op.getBefore().front();
    ConditionOp condOp = op.getConditionOp();
    OperandRange condOpArgs = condOp.getArgs();

    // Match phase: succeed only if at least one scf.condition operand is
    // defined outside the `before` block. No IR is modified here.
    bool canSimplify = false;
    for (Value condOpArg : condOpArgs) {
      // Values not defined within the `before` block are treated as loop
      // invariant. This loop only detects them; the index -> value mapping is
      // built in the rewrite phase below.
      if (condOpArg.getParentBlock() != &beforeBlock) {
        canSimplify = true;
        break;
      }
    }

    if (!canSimplify)
      return failure();

    Block::BlockArgListType afterBlockArgs = op.getAfterArguments();

    // Split the scf.condition operands into the ones that stay (they become
    // the new results / `after` block arguments) and the loop-invariant ones,
    // recorded by their original position.
    SmallVector<Value> newCondOpArgs;
    SmallVector<Type> newAfterBlockType;
    DenseMap<unsigned, Value> condOpInitValMap;
    SmallVector<Location> newAfterBlockArgLocs;
    for (const auto &it : llvm::enumerate(condOpArgs)) {
      auto index = static_cast<unsigned>(it.index());
      Value condOpArg = it.value();
      // Those values not defined within `before` block will be considered as
      // loop invariant values. We map the corresponding `index` with their
      // value.
      if (condOpArg.getParentBlock() != &beforeBlock) {
        condOpInitValMap.insert({index, condOpArg});
      } else {
        newCondOpArgs.emplace_back(condOpArg);
        newAfterBlockType.emplace_back(condOpArg.getType());
        newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
      }
    }

    // Rebuild the terminator so it forwards only the non-invariant values.
    // The guard keeps the insertion point before `op` for the create below.
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(condOp);
      rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
                                               newCondOpArgs);
    }

    // The new loop keeps the original init operands; only its result list
    // (and thus the `after` block signature) shrinks.
    auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
                                             op.getOperands());

    Block &newAfterBlock =
        *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
                              newAfterBlockType, newAfterBlockArgLocs);

    Block &afterBlock = op.getAfter().front();
    // Since a new scf.condition op was created, we need to fetch the new
    // `after` block arguments which will be used while replacing operations of
    // previous scf.while's `after` blocks. We'd also be fetching new result
    // values too.
    SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
    SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
    for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
      Value afterBlockArg, result;
      // If the argument at index 'i' was loop invariant, fetch its value from
      // `condOpInitValMap`; it replaces both the old `after` block argument
      // and the old result.
      if (condOpInitValMap.count(i) != 0) {
        afterBlockArg = condOpInitValMap[i];
        result = afterBlockArg;
      } else {
        afterBlockArg = newAfterBlock.getArgument(j);
        result = newWhile.getResult(j);
        j++;
      }
      newAfterBlockArgs[i] = afterBlockArg;
      newWhileResults[i] = result;
    }

    // Move the old `after` body into the new block (remapping its arguments)
    // and move the `before` region over unchanged.
    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
    rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
                                newWhile.getBefore().begin());

    rewriter.replaceOp(op, newWhileResults);
    return success();
  }
};
3108 
3109 /// Remove WhileOp results that are also unused in 'after' block.
3110 ///
3111 ///  %0:2 = scf.while () : () -> (i32, i64) {
3112 ///    %condition = "test.condition"() : () -> i1
3113 ///    %v1 = "test.get_some_value"() : () -> i32
3114 ///    %v2 = "test.get_some_value"() : () -> i64
3115 ///    scf.condition(%condition) %v1, %v2 : i32, i64
3116 ///  } do {
3117 ///  ^bb0(%arg0: i32, %arg1: i64):
3118 ///    "test.use"(%arg0) : (i32) -> ()
3119 ///    scf.yield
3120 ///  }
3121 ///  return %0#0 : i32
3122 ///
3123 /// becomes
3124 ///  %0 = scf.while () : () -> (i32) {
3125 ///    %condition = "test.condition"() : () -> i1
3126 ///    %v1 = "test.get_some_value"() : () -> i32
3127 ///    %v2 = "test.get_some_value"() : () -> i64
3128 ///    scf.condition(%condition) %v1 : i32
3129 ///  } do {
3130 ///  ^bb0(%arg0: i32):
3131 ///    "test.use"(%arg0) : (i32) -> ()
3132 ///    scf.yield
3133 ///  }
3134 ///  return %0 : i32
3135 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3136   using OpRewritePattern<WhileOp>::OpRewritePattern;
3137 
3138   LogicalResult matchAndRewrite(WhileOp op,
3139                                 PatternRewriter &rewriter) const override {
3140     auto term = op.getConditionOp();
3141     auto afterArgs = op.getAfterArguments();
3142     auto termArgs = term.getArgs();
3143 
3144     // Collect results mapping, new terminator args and new result types.
3145     SmallVector<unsigned> newResultsIndices;
3146     SmallVector<Type> newResultTypes;
3147     SmallVector<Value> newTermArgs;
3148     SmallVector<Location> newArgLocs;
3149     bool needUpdate = false;
3150     for (const auto &it :
3151          llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3152       auto i = static_cast<unsigned>(it.index());
3153       Value result = std::get<0>(it.value());
3154       Value afterArg = std::get<1>(it.value());
3155       Value termArg = std::get<2>(it.value());
3156       if (result.use_empty() && afterArg.use_empty()) {
3157         needUpdate = true;
3158       } else {
3159         newResultsIndices.emplace_back(i);
3160         newTermArgs.emplace_back(termArg);
3161         newResultTypes.emplace_back(result.getType());
3162         newArgLocs.emplace_back(result.getLoc());
3163       }
3164     }
3165 
3166     if (!needUpdate)
3167       return failure();
3168 
3169     {
3170       OpBuilder::InsertionGuard g(rewriter);
3171       rewriter.setInsertionPoint(term);
3172       rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3173                                                newTermArgs);
3174     }
3175 
3176     auto newWhile =
3177         rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3178 
3179     Block &newAfterBlock = *rewriter.createBlock(
3180         &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3181 
3182     // Build new results list and new after block args (unused entries will be
3183     // null).
3184     SmallVector<Value> newResults(op.getNumResults());
3185     SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3186     for (const auto &it : llvm::enumerate(newResultsIndices)) {
3187       newResults[it.value()] = newWhile.getResult(it.index());
3188       newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3189     }
3190 
3191     rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3192                                 newWhile.getBefore().begin());
3193 
3194     Block &afterBlock = op.getAfter().front();
3195     rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3196 
3197     rewriter.replaceOp(op, newResults);
3198     return success();
3199   }
3200 };
3201 
3202 /// Replace operations equivalent to the condition in the do block with true,
3203 /// since otherwise the block would not be evaluated.
3204 ///
3205 /// scf.while (..) : (i32, ...) -> ... {
3206 ///  %z = ... : i32
3207 ///  %condition = cmpi pred %z, %a
3208 ///  scf.condition(%condition) %z : i32, ...
3209 /// } do {
3210 /// ^bb0(%arg0: i32, ...):
3211 ///    %condition2 = cmpi pred %arg0, %a
3212 ///    use(%condition2)
3213 ///    ...
3214 ///
3215 /// becomes
3216 /// scf.while (..) : (i32, ...) -> ... {
3217 ///  %z = ... : i32
3218 ///  %condition = cmpi pred %z, %a
3219 ///  scf.condition(%condition) %z : i32, ...
3220 /// } do {
3221 /// ^bb0(%arg0: i32, ...):
3222 ///    use(%true)
3223 ///    ...
3224 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3225   using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3226 
3227   LogicalResult matchAndRewrite(scf::WhileOp op,
3228                                 PatternRewriter &rewriter) const override {
3229     using namespace scf;
3230     auto cond = op.getConditionOp();
3231     auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3232     if (!cmp)
3233       return failure();
3234     bool changed = false;
3235     for (auto tup :
3236          llvm::zip(cond.getArgs(), op.getAfter().front().getArguments())) {
3237       for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3238         if (std::get<0>(tup) != cmp.getOperand(opIdx))
3239           continue;
3240         for (OpOperand &u :
3241              llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3242           auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3243           if (!cmp2)
3244             continue;
3245           // For a binary operator 1-opIdx gets the other side.
3246           if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3247             continue;
3248           bool samePredicate;
3249           if (cmp2.getPredicate() == cmp.getPredicate())
3250             samePredicate = true;
3251           else if (cmp2.getPredicate() ==
3252                    arith::invertPredicate(cmp.getPredicate()))
3253             samePredicate = false;
3254           else
3255             continue;
3256 
3257           rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3258                                                             1);
3259           changed = true;
3260         }
3261       }
3262     }
3263     return success(changed);
3264   }
3265 };
3266 
3267 struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
3268   using OpRewritePattern<WhileOp>::OpRewritePattern;
3269 
3270   LogicalResult matchAndRewrite(WhileOp op,
3271                                 PatternRewriter &rewriter) const override {
3272 
3273     if (!llvm::any_of(op.getBeforeArguments(),
3274                       [](Value arg) { return arg.use_empty(); }))
3275       return failure();
3276 
3277     YieldOp yield = op.getYieldOp();
3278 
3279     // Collect results mapping, new terminator args and new result types.
3280     SmallVector<Value> newYields;
3281     SmallVector<Value> newInits;
3282     SmallVector<unsigned> argsToErase;
3283     for (const auto &it : llvm::enumerate(llvm::zip(
3284              op.getBeforeArguments(), yield.getOperands(), op.getInits()))) {
3285       Value beforeArg = std::get<0>(it.value());
3286       Value yieldValue = std::get<1>(it.value());
3287       Value initValue = std::get<2>(it.value());
3288       if (beforeArg.use_empty()) {
3289         argsToErase.push_back(it.index());
3290       } else {
3291         newYields.emplace_back(yieldValue);
3292         newInits.emplace_back(initValue);
3293       }
3294     }
3295 
3296     if (argsToErase.empty())
3297       return failure();
3298 
3299     rewriter.startRootUpdate(op);
3300     op.getBefore().front().eraseArguments(argsToErase);
3301     rewriter.finalizeRootUpdate(op);
3302 
3303     WhileOp replacement =
3304         rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInits);
3305     replacement.getBefore().takeBody(op.getBefore());
3306     replacement.getAfter().takeBody(op.getAfter());
3307     rewriter.replaceOp(op, replacement.getResults());
3308 
3309     rewriter.setInsertionPoint(yield);
3310     rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
3311     return success();
3312   }
3313 };
3314 } // namespace
3315 
3316 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
3317                                           MLIRContext *context) {
3318   results.add<RemoveLoopInvariantArgsFromBeforeBlock,
3319               RemoveLoopInvariantValueYielded, WhileConditionTruth,
3320               WhileCmpCond, WhileUnusedResult>(context);
3321 }
3322 
3323 //===----------------------------------------------------------------------===//
3324 // TableGen'd op method definitions
3325 //===----------------------------------------------------------------------===//
3326 
3327 #define GET_OP_CLASSES
3328 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
3329