1fec6c5acSUday Bondhugula //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2fec6c5acSUday Bondhugula //
3fec6c5acSUday Bondhugula // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fec6c5acSUday Bondhugula // See https://llvm.org/LICENSE.txt for license information.
5fec6c5acSUday Bondhugula // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fec6c5acSUday Bondhugula //
7fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
8fec6c5acSUday Bondhugula 
9fec6c5acSUday Bondhugula #include "TestDialect.h"
102e2cdd0aSRiver Riddle #include "TestTypes.h"
11fec6c5acSUday Bondhugula #include "mlir/Dialect/StandardOps/IR/Ops.h"
1273ca690dSRiver Riddle #include "mlir/IR/BuiltinDialect.h"
132e2cdd0aSRiver Riddle #include "mlir/IR/DialectImplementation.h"
14fec6c5acSUday Bondhugula #include "mlir/IR/PatternMatch.h"
15fec6c5acSUday Bondhugula #include "mlir/IR/TypeUtilities.h"
16fec6c5acSUday Bondhugula #include "mlir/Transforms/FoldUtils.h"
17fec6c5acSUday Bondhugula #include "mlir/Transforms/InliningUtils.h"
18a5182991SAlex Zinenko #include "llvm/ADT/SetVector.h"
19fec6c5acSUday Bondhugula #include "llvm/ADT/StringSwitch.h"
20fec6c5acSUday Bondhugula 
21fec6c5acSUday Bondhugula using namespace mlir;
2272c65b69SAlexander Belyaev using namespace mlir::test;
23fec6c5acSUday Bondhugula 
2472c65b69SAlexander Belyaev void mlir::test::registerTestDialect(DialectRegistry &registry) {
25f9dc2b70SMehdi Amini   registry.insert<TestDialect>();
26f9dc2b70SMehdi Amini }
27f9dc2b70SMehdi Amini 
28fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
29fec6c5acSUday Bondhugula // TestDialect Interfaces
30fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
31fec6c5acSUday Bondhugula 
32fec6c5acSUday Bondhugula namespace {
33fec6c5acSUday Bondhugula 
34fec6c5acSUday Bondhugula // Test support for interacting with the AsmPrinter.
35fec6c5acSUday Bondhugula struct TestOpAsmInterface : public OpAsmDialectInterface {
36fec6c5acSUday Bondhugula   using OpAsmDialectInterface::OpAsmDialectInterface;
37fec6c5acSUday Bondhugula 
38a463ea50SRiver Riddle   LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
39a463ea50SRiver Riddle     StringAttr strAttr = attr.dyn_cast<StringAttr>();
40a463ea50SRiver Riddle     if (!strAttr)
41a463ea50SRiver Riddle       return failure();
42a463ea50SRiver Riddle 
43a463ea50SRiver Riddle     // Check the contents of the string attribute to see what the test alias
44a463ea50SRiver Riddle     // should be named.
45a463ea50SRiver Riddle     Optional<StringRef> aliasName =
46a463ea50SRiver Riddle         StringSwitch<Optional<StringRef>>(strAttr.getValue())
47a463ea50SRiver Riddle             .Case("alias_test:dot_in_name", StringRef("test.alias"))
48a463ea50SRiver Riddle             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
49a463ea50SRiver Riddle             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
50a463ea50SRiver Riddle             .Case("alias_test:sanitize_conflict_a",
51a463ea50SRiver Riddle                   StringRef("test_alias_conflict0"))
52a463ea50SRiver Riddle             .Case("alias_test:sanitize_conflict_b",
53a463ea50SRiver Riddle                   StringRef("test_alias_conflict0_"))
54a463ea50SRiver Riddle             .Default(llvm::None);
55a463ea50SRiver Riddle     if (!aliasName)
56a463ea50SRiver Riddle       return failure();
57a463ea50SRiver Riddle 
58a463ea50SRiver Riddle     os << *aliasName;
59a463ea50SRiver Riddle     return success();
60a463ea50SRiver Riddle   }
61a463ea50SRiver Riddle 
62fec6c5acSUday Bondhugula   void getAsmResultNames(Operation *op,
63fec6c5acSUday Bondhugula                          OpAsmSetValueNameFn setNameFn) const final {
64fec6c5acSUday Bondhugula     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
65fec6c5acSUday Bondhugula       setNameFn(asmOp, "result");
66fec6c5acSUday Bondhugula   }
67fec6c5acSUday Bondhugula 
68fec6c5acSUday Bondhugula   void getAsmBlockArgumentNames(Block *block,
69fec6c5acSUday Bondhugula                                 OpAsmSetValueNameFn setNameFn) const final {
70fec6c5acSUday Bondhugula     auto op = block->getParentOp();
71fec6c5acSUday Bondhugula     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
72fec6c5acSUday Bondhugula     if (!arrayAttr)
73fec6c5acSUday Bondhugula       return;
74fec6c5acSUday Bondhugula     auto args = block->getArguments();
75fec6c5acSUday Bondhugula     auto e = std::min(arrayAttr.size(), args.size());
76fec6c5acSUday Bondhugula     for (unsigned i = 0; i < e; ++i) {
77fec6c5acSUday Bondhugula       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
78fec6c5acSUday Bondhugula         setNameFn(args[i], strAttr.getValue());
79fec6c5acSUday Bondhugula     }
80fec6c5acSUday Bondhugula   }
81fec6c5acSUday Bondhugula };
82fec6c5acSUday Bondhugula 
83b28e3db8SMehdi Amini struct TestDialectFoldInterface : public DialectFoldInterface {
84b28e3db8SMehdi Amini   using DialectFoldInterface::DialectFoldInterface;
85fec6c5acSUday Bondhugula 
86fec6c5acSUday Bondhugula   /// Registered hook to check if the given region, which is attached to an
87fec6c5acSUday Bondhugula   /// operation that is *not* isolated from above, should be used when
88fec6c5acSUday Bondhugula   /// materializing constants.
89fec6c5acSUday Bondhugula   bool shouldMaterializeInto(Region *region) const final {
90fec6c5acSUday Bondhugula     // If this is a one region operation, then insert into it.
91fec6c5acSUday Bondhugula     return isa<OneRegionOp>(region->getParentOp());
92fec6c5acSUday Bondhugula   }
93fec6c5acSUday Bondhugula };
94fec6c5acSUday Bondhugula 
95fec6c5acSUday Bondhugula /// This class defines the interface for handling inlining with standard
96fec6c5acSUday Bondhugula /// operations.
97fec6c5acSUday Bondhugula struct TestInlinerInterface : public DialectInlinerInterface {
98fec6c5acSUday Bondhugula   using DialectInlinerInterface::DialectInlinerInterface;
99fec6c5acSUday Bondhugula 
100fec6c5acSUday Bondhugula   //===--------------------------------------------------------------------===//
101fec6c5acSUday Bondhugula   // Analysis Hooks
102fec6c5acSUday Bondhugula   //===--------------------------------------------------------------------===//
103fec6c5acSUday Bondhugula 
104fa417479SRiver Riddle   bool isLegalToInline(Operation *call, Operation *callable,
105fa417479SRiver Riddle                        bool wouldBeCloned) const final {
106501fda01SRiver Riddle     // Don't allow inlining calls that are marked `noinline`.
107501fda01SRiver Riddle     return !call->hasAttr("noinline");
108501fda01SRiver Riddle   }
109fa417479SRiver Riddle   bool isLegalToInline(Region *, Region *, bool,
110fa417479SRiver Riddle                        BlockAndValueMapping &) const final {
111fec6c5acSUday Bondhugula     // Inlining into test dialect regions is legal.
112fec6c5acSUday Bondhugula     return true;
113fec6c5acSUday Bondhugula   }
114fa417479SRiver Riddle   bool isLegalToInline(Operation *, Region *, bool,
115fec6c5acSUday Bondhugula                        BlockAndValueMapping &) const final {
116fec6c5acSUday Bondhugula     return true;
117fec6c5acSUday Bondhugula   }
118fec6c5acSUday Bondhugula 
119fec6c5acSUday Bondhugula   bool shouldAnalyzeRecursively(Operation *op) const final {
120fec6c5acSUday Bondhugula     // Analyze recursively if this is not a functional region operation, it
121fec6c5acSUday Bondhugula     // froms a separate functional scope.
122fec6c5acSUday Bondhugula     return !isa<FunctionalRegionOp>(op);
123fec6c5acSUday Bondhugula   }
124fec6c5acSUday Bondhugula 
125fec6c5acSUday Bondhugula   //===--------------------------------------------------------------------===//
126fec6c5acSUday Bondhugula   // Transformation Hooks
127fec6c5acSUday Bondhugula   //===--------------------------------------------------------------------===//
128fec6c5acSUday Bondhugula 
129fec6c5acSUday Bondhugula   /// Handle the given inlined terminator by replacing it with a new operation
130fec6c5acSUday Bondhugula   /// as necessary.
131fec6c5acSUday Bondhugula   void handleTerminator(Operation *op,
132fec6c5acSUday Bondhugula                         ArrayRef<Value> valuesToRepl) const final {
133fec6c5acSUday Bondhugula     // Only handle "test.return" here.
134fec6c5acSUday Bondhugula     auto returnOp = dyn_cast<TestReturnOp>(op);
135fec6c5acSUday Bondhugula     if (!returnOp)
136fec6c5acSUday Bondhugula       return;
137fec6c5acSUday Bondhugula 
138fec6c5acSUday Bondhugula     // Replace the values directly with the return operands.
139fec6c5acSUday Bondhugula     assert(returnOp.getNumOperands() == valuesToRepl.size());
140fec6c5acSUday Bondhugula     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
141fec6c5acSUday Bondhugula       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
142fec6c5acSUday Bondhugula   }
143fec6c5acSUday Bondhugula 
144fec6c5acSUday Bondhugula   /// Attempt to materialize a conversion for a type mismatch between a call
145fec6c5acSUday Bondhugula   /// from this dialect, and a callable region. This method should generate an
146fec6c5acSUday Bondhugula   /// operation that takes 'input' as the only operand, and produces a single
147fec6c5acSUday Bondhugula   /// result of 'resultType'. If a conversion can not be generated, nullptr
148fec6c5acSUday Bondhugula   /// should be returned.
149fec6c5acSUday Bondhugula   Operation *materializeCallConversion(OpBuilder &builder, Value input,
150fec6c5acSUday Bondhugula                                        Type resultType,
151fec6c5acSUday Bondhugula                                        Location conversionLoc) const final {
152fec6c5acSUday Bondhugula     // Only allow conversion for i16/i32 types.
153fec6c5acSUday Bondhugula     if (!(resultType.isSignlessInteger(16) ||
154fec6c5acSUday Bondhugula           resultType.isSignlessInteger(32)) ||
155fec6c5acSUday Bondhugula         !(input.getType().isSignlessInteger(16) ||
156fec6c5acSUday Bondhugula           input.getType().isSignlessInteger(32)))
157fec6c5acSUday Bondhugula       return nullptr;
158fec6c5acSUday Bondhugula     return builder.create<TestCastOp>(conversionLoc, resultType, input);
159fec6c5acSUday Bondhugula   }
160fec6c5acSUday Bondhugula };
161fec6c5acSUday Bondhugula } // end anonymous namespace
162fec6c5acSUday Bondhugula 
163fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
164fec6c5acSUday Bondhugula // TestDialect
165fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
166fec6c5acSUday Bondhugula 
167575b22b5SMehdi Amini void TestDialect::initialize() {
168fec6c5acSUday Bondhugula   addOperations<
169fec6c5acSUday Bondhugula #define GET_OP_LIST
170fec6c5acSUday Bondhugula #include "TestOps.cpp.inc"
171fec6c5acSUday Bondhugula       >();
172b28e3db8SMehdi Amini   addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
173fec6c5acSUday Bondhugula                 TestInlinerInterface>();
1745fe53c41SJohn Demme   addTypes<TestType, TestRecursiveType,
1755fe53c41SJohn Demme #define GET_TYPEDEF_LIST
1765fe53c41SJohn Demme #include "TestTypeDefs.cpp.inc"
1775fe53c41SJohn Demme            >();
178fec6c5acSUday Bondhugula   allowUnknownOperations();
179fec6c5acSUday Bondhugula }
180fec6c5acSUday Bondhugula 
1815fe53c41SJohn Demme static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
182a5182991SAlex Zinenko                           llvm::SetVector<Type> &stack) {
183a5182991SAlex Zinenko   StringRef typeTag;
184a5182991SAlex Zinenko   if (failed(parser.parseKeyword(&typeTag)))
1852e2cdd0aSRiver Riddle     return Type();
186a5182991SAlex Zinenko 
1875fe53c41SJohn Demme   auto genType = generatedTypeParser(ctxt, parser, typeTag);
1885fe53c41SJohn Demme   if (genType != Type())
1895fe53c41SJohn Demme     return genType;
1905fe53c41SJohn Demme 
191a5182991SAlex Zinenko   if (typeTag == "test_type")
192a5182991SAlex Zinenko     return TestType::get(parser.getBuilder().getContext());
193a5182991SAlex Zinenko 
194a5182991SAlex Zinenko   if (typeTag != "test_rec")
195a5182991SAlex Zinenko     return Type();
196a5182991SAlex Zinenko 
197a5182991SAlex Zinenko   StringRef name;
198a5182991SAlex Zinenko   if (parser.parseLess() || parser.parseKeyword(&name))
199a5182991SAlex Zinenko     return Type();
200250f43d3SRiver Riddle   auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
201a5182991SAlex Zinenko 
202a5182991SAlex Zinenko   // If this type already has been parsed above in the stack, expect just the
203a5182991SAlex Zinenko   // name.
204a5182991SAlex Zinenko   if (stack.contains(rec)) {
205a5182991SAlex Zinenko     if (failed(parser.parseGreater()))
206a5182991SAlex Zinenko       return Type();
207a5182991SAlex Zinenko     return rec;
208a5182991SAlex Zinenko   }
209a5182991SAlex Zinenko 
210a5182991SAlex Zinenko   // Otherwise, parse the body and update the type.
211a5182991SAlex Zinenko   if (failed(parser.parseComma()))
212a5182991SAlex Zinenko     return Type();
213a5182991SAlex Zinenko   stack.insert(rec);
2145fe53c41SJohn Demme   Type subtype = parseTestType(ctxt, parser, stack);
215a5182991SAlex Zinenko   stack.pop_back();
216a5182991SAlex Zinenko   if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
217a5182991SAlex Zinenko     return Type();
218a5182991SAlex Zinenko 
219a5182991SAlex Zinenko   return rec;
220a5182991SAlex Zinenko }
221a5182991SAlex Zinenko 
222a5182991SAlex Zinenko Type TestDialect::parseType(DialectAsmParser &parser) const {
223a5182991SAlex Zinenko   llvm::SetVector<Type> stack;
2245fe53c41SJohn Demme   return parseTestType(getContext(), parser, stack);
225a5182991SAlex Zinenko }
226a5182991SAlex Zinenko 
227a5182991SAlex Zinenko static void printTestType(Type type, DialectAsmPrinter &printer,
228a5182991SAlex Zinenko                           llvm::SetVector<Type> &stack) {
2295fe53c41SJohn Demme   if (succeeded(generatedTypePrinter(type, printer)))
2305fe53c41SJohn Demme     return;
231a5182991SAlex Zinenko   if (type.isa<TestType>()) {
232a5182991SAlex Zinenko     printer << "test_type";
233a5182991SAlex Zinenko     return;
234a5182991SAlex Zinenko   }
235a5182991SAlex Zinenko 
236a5182991SAlex Zinenko   auto rec = type.cast<TestRecursiveType>();
237a5182991SAlex Zinenko   printer << "test_rec<" << rec.getName();
238a5182991SAlex Zinenko   if (!stack.contains(rec)) {
239a5182991SAlex Zinenko     printer << ", ";
240a5182991SAlex Zinenko     stack.insert(rec);
241a5182991SAlex Zinenko     printTestType(rec.getBody(), printer, stack);
242a5182991SAlex Zinenko     stack.pop_back();
243a5182991SAlex Zinenko   }
244a5182991SAlex Zinenko   printer << ">";
2452e2cdd0aSRiver Riddle }
2462e2cdd0aSRiver Riddle 
2472e2cdd0aSRiver Riddle void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
248a5182991SAlex Zinenko   llvm::SetVector<Type> stack;
249a5182991SAlex Zinenko   printTestType(type, printer, stack);
2502e2cdd0aSRiver Riddle }
2512e2cdd0aSRiver Riddle 
252fec6c5acSUday Bondhugula LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
253fec6c5acSUday Bondhugula                                                     NamedAttribute namedAttr) {
254fec6c5acSUday Bondhugula   if (namedAttr.first == "test.invalid_attr")
255fec6c5acSUday Bondhugula     return op->emitError() << "invalid to use 'test.invalid_attr'";
256fec6c5acSUday Bondhugula   return success();
257fec6c5acSUday Bondhugula }
258fec6c5acSUday Bondhugula 
259fec6c5acSUday Bondhugula LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
260fec6c5acSUday Bondhugula                                                     unsigned regionIndex,
261fec6c5acSUday Bondhugula                                                     unsigned argIndex,
262fec6c5acSUday Bondhugula                                                     NamedAttribute namedAttr) {
263fec6c5acSUday Bondhugula   if (namedAttr.first == "test.invalid_attr")
264fec6c5acSUday Bondhugula     return op->emitError() << "invalid to use 'test.invalid_attr'";
265fec6c5acSUday Bondhugula   return success();
266fec6c5acSUday Bondhugula }
267fec6c5acSUday Bondhugula 
268fec6c5acSUday Bondhugula LogicalResult
269fec6c5acSUday Bondhugula TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
270fec6c5acSUday Bondhugula                                          unsigned resultIndex,
271fec6c5acSUday Bondhugula                                          NamedAttribute namedAttr) {
272fec6c5acSUday Bondhugula   if (namedAttr.first == "test.invalid_attr")
273fec6c5acSUday Bondhugula     return op->emitError() << "invalid to use 'test.invalid_attr'";
274fec6c5acSUday Bondhugula   return success();
275fec6c5acSUday Bondhugula }
276fec6c5acSUday Bondhugula 
277fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
278fec6c5acSUday Bondhugula // TestBranchOp
279fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
280fec6c5acSUday Bondhugula 
2810752d98cSRiver Riddle Optional<MutableOperandRange>
2820752d98cSRiver Riddle TestBranchOp::getMutableSuccessorOperands(unsigned index) {
283fec6c5acSUday Bondhugula   assert(index == 0 && "invalid successor index");
2840752d98cSRiver Riddle   return targetOperandsMutable();
285fec6c5acSUday Bondhugula }
286fec6c5acSUday Bondhugula 
287fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
288f4ef77cbSRiver Riddle // TestFoldToCallOp
289f4ef77cbSRiver Riddle //===----------------------------------------------------------------------===//
290f4ef77cbSRiver Riddle 
291f4ef77cbSRiver Riddle namespace {
292f4ef77cbSRiver Riddle struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
293f4ef77cbSRiver Riddle   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
294f4ef77cbSRiver Riddle 
295f4ef77cbSRiver Riddle   LogicalResult matchAndRewrite(FoldToCallOp op,
296f4ef77cbSRiver Riddle                                 PatternRewriter &rewriter) const override {
29708e4f078SRahul Joshi     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
298f4ef77cbSRiver Riddle                                         ValueRange());
299f4ef77cbSRiver Riddle     return success();
300f4ef77cbSRiver Riddle   }
301f4ef77cbSRiver Riddle };
302f4ef77cbSRiver Riddle } // end anonymous namespace
303f4ef77cbSRiver Riddle 
304f4ef77cbSRiver Riddle void FoldToCallOp::getCanonicalizationPatterns(
305f4ef77cbSRiver Riddle     OwningRewritePatternList &results, MLIRContext *context) {
306f4ef77cbSRiver Riddle   results.insert<FoldToCallOpPattern>(context);
307f4ef77cbSRiver Riddle }
308f4ef77cbSRiver Riddle 
309f4ef77cbSRiver Riddle //===----------------------------------------------------------------------===//
31088c6e25eSRiver Riddle // Test Format* operations
31188c6e25eSRiver Riddle //===----------------------------------------------------------------------===//
31288c6e25eSRiver Riddle 
31388c6e25eSRiver Riddle //===----------------------------------------------------------------------===//
31488c6e25eSRiver Riddle // Parsing
31588c6e25eSRiver Riddle 
31688c6e25eSRiver Riddle static ParseResult parseCustomDirectiveOperands(
31788c6e25eSRiver Riddle     OpAsmParser &parser, OpAsmParser::OperandType &operand,
31888c6e25eSRiver Riddle     Optional<OpAsmParser::OperandType> &optOperand,
31988c6e25eSRiver Riddle     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
32088c6e25eSRiver Riddle   if (parser.parseOperand(operand))
32188c6e25eSRiver Riddle     return failure();
32288c6e25eSRiver Riddle   if (succeeded(parser.parseOptionalComma())) {
32388c6e25eSRiver Riddle     optOperand.emplace();
32488c6e25eSRiver Riddle     if (parser.parseOperand(*optOperand))
32588c6e25eSRiver Riddle       return failure();
32688c6e25eSRiver Riddle   }
32788c6e25eSRiver Riddle   if (parser.parseArrow() || parser.parseLParen() ||
32888c6e25eSRiver Riddle       parser.parseOperandList(varOperands) || parser.parseRParen())
32988c6e25eSRiver Riddle     return failure();
33088c6e25eSRiver Riddle   return success();
33188c6e25eSRiver Riddle }
33288c6e25eSRiver Riddle static ParseResult
33388c6e25eSRiver Riddle parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
33488c6e25eSRiver Riddle                             Type &optOperandType,
33588c6e25eSRiver Riddle                             SmallVectorImpl<Type> &varOperandTypes) {
33688c6e25eSRiver Riddle   if (parser.parseColon())
33788c6e25eSRiver Riddle     return failure();
33888c6e25eSRiver Riddle 
33988c6e25eSRiver Riddle   if (parser.parseType(operandType))
34088c6e25eSRiver Riddle     return failure();
34188c6e25eSRiver Riddle   if (succeeded(parser.parseOptionalComma())) {
34288c6e25eSRiver Riddle     if (parser.parseType(optOperandType))
34388c6e25eSRiver Riddle       return failure();
34488c6e25eSRiver Riddle   }
34588c6e25eSRiver Riddle   if (parser.parseArrow() || parser.parseLParen() ||
34688c6e25eSRiver Riddle       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
34788c6e25eSRiver Riddle     return failure();
34888c6e25eSRiver Riddle   return success();
34988c6e25eSRiver Riddle }
35093fd30baSNicolas Vasilache static ParseResult
35193fd30baSNicolas Vasilache parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
35293fd30baSNicolas Vasilache                                  Type optOperandType,
35393fd30baSNicolas Vasilache                                  const SmallVectorImpl<Type> &varOperandTypes) {
35493fd30baSNicolas Vasilache   if (parser.parseKeyword("type_refs_capture"))
35593fd30baSNicolas Vasilache     return failure();
35693fd30baSNicolas Vasilache 
35793fd30baSNicolas Vasilache   Type operandType2, optOperandType2;
35893fd30baSNicolas Vasilache   SmallVector<Type, 1> varOperandTypes2;
35993fd30baSNicolas Vasilache   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
36093fd30baSNicolas Vasilache                                   varOperandTypes2))
36193fd30baSNicolas Vasilache     return failure();
36293fd30baSNicolas Vasilache 
36393fd30baSNicolas Vasilache   if (operandType != operandType2 || optOperandType != optOperandType2 ||
36493fd30baSNicolas Vasilache       varOperandTypes != varOperandTypes2)
36593fd30baSNicolas Vasilache     return failure();
36693fd30baSNicolas Vasilache 
36793fd30baSNicolas Vasilache   return success();
36893fd30baSNicolas Vasilache }
36988c6e25eSRiver Riddle static ParseResult parseCustomDirectiveOperandsAndTypes(
37088c6e25eSRiver Riddle     OpAsmParser &parser, OpAsmParser::OperandType &operand,
37188c6e25eSRiver Riddle     Optional<OpAsmParser::OperandType> &optOperand,
37288c6e25eSRiver Riddle     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
37388c6e25eSRiver Riddle     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
37488c6e25eSRiver Riddle   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
37588c6e25eSRiver Riddle       parseCustomDirectiveResults(parser, operandType, optOperandType,
37688c6e25eSRiver Riddle                                   varOperandTypes))
37788c6e25eSRiver Riddle     return failure();
37888c6e25eSRiver Riddle   return success();
37988c6e25eSRiver Riddle }
380eaeadce9SRiver Riddle static ParseResult parseCustomDirectiveRegions(
381eaeadce9SRiver Riddle     OpAsmParser &parser, Region &region,
382eaeadce9SRiver Riddle     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
383eaeadce9SRiver Riddle   if (parser.parseRegion(region))
384eaeadce9SRiver Riddle     return failure();
385eaeadce9SRiver Riddle   if (failed(parser.parseOptionalComma()))
386eaeadce9SRiver Riddle     return success();
387eaeadce9SRiver Riddle   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
388eaeadce9SRiver Riddle   if (parser.parseRegion(*varRegion))
389eaeadce9SRiver Riddle     return failure();
390eaeadce9SRiver Riddle   varRegions.emplace_back(std::move(varRegion));
391eaeadce9SRiver Riddle   return success();
392eaeadce9SRiver Riddle }
39388c6e25eSRiver Riddle static ParseResult
39488c6e25eSRiver Riddle parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
39588c6e25eSRiver Riddle                                SmallVectorImpl<Block *> &varSuccessors) {
39688c6e25eSRiver Riddle   if (parser.parseSuccessor(successor))
39788c6e25eSRiver Riddle     return failure();
39888c6e25eSRiver Riddle   if (failed(parser.parseOptionalComma()))
39988c6e25eSRiver Riddle     return success();
40088c6e25eSRiver Riddle   Block *varSuccessor;
40188c6e25eSRiver Riddle   if (parser.parseSuccessor(varSuccessor))
40288c6e25eSRiver Riddle     return failure();
40388c6e25eSRiver Riddle   varSuccessors.append(2, varSuccessor);
40488c6e25eSRiver Riddle   return success();
40588c6e25eSRiver Riddle }
406d14cfe10SMike Urbach static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
407d14cfe10SMike Urbach                                                   IntegerAttr &attr,
408d14cfe10SMike Urbach                                                   IntegerAttr &optAttr) {
409d14cfe10SMike Urbach   if (parser.parseAttribute(attr))
410d14cfe10SMike Urbach     return failure();
411d14cfe10SMike Urbach   if (succeeded(parser.parseOptionalComma())) {
412d14cfe10SMike Urbach     if (parser.parseAttribute(optAttr))
413d14cfe10SMike Urbach       return failure();
414d14cfe10SMike Urbach   }
415d14cfe10SMike Urbach   return success();
416d14cfe10SMike Urbach }
41788c6e25eSRiver Riddle 
418035e12e6SJohn Demme static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
419035e12e6SJohn Demme                                                 NamedAttrList &attrs) {
420035e12e6SJohn Demme   return parser.parseOptionalAttrDict(attrs);
421035e12e6SJohn Demme }
422035e12e6SJohn Demme 
42388c6e25eSRiver Riddle //===----------------------------------------------------------------------===//
42488c6e25eSRiver Riddle // Printing
42588c6e25eSRiver Riddle 
426035e12e6SJohn Demme static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
427035e12e6SJohn Demme                                          Value operand, Value optOperand,
42888c6e25eSRiver Riddle                                          OperandRange varOperands) {
42988c6e25eSRiver Riddle   printer << operand;
43088c6e25eSRiver Riddle   if (optOperand)
43188c6e25eSRiver Riddle     printer << ", " << optOperand;
43288c6e25eSRiver Riddle   printer << " -> (" << varOperands << ")";
43388c6e25eSRiver Riddle }
434035e12e6SJohn Demme static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
435035e12e6SJohn Demme                                         Type operandType, Type optOperandType,
43688c6e25eSRiver Riddle                                         TypeRange varOperandTypes) {
43788c6e25eSRiver Riddle   printer << " : " << operandType;
43888c6e25eSRiver Riddle   if (optOperandType)
43988c6e25eSRiver Riddle     printer << ", " << optOperandType;
44088c6e25eSRiver Riddle   printer << " -> (" << varOperandTypes << ")";
44188c6e25eSRiver Riddle }
44293fd30baSNicolas Vasilache static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
443035e12e6SJohn Demme                                              Operation *op, Type operandType,
44493fd30baSNicolas Vasilache                                              Type optOperandType,
44593fd30baSNicolas Vasilache                                              TypeRange varOperandTypes) {
44693fd30baSNicolas Vasilache   printer << " type_refs_capture ";
447035e12e6SJohn Demme   printCustomDirectiveResults(printer, op, operandType, optOperandType,
44893fd30baSNicolas Vasilache                               varOperandTypes);
44993fd30baSNicolas Vasilache }
450035e12e6SJohn Demme static void printCustomDirectiveOperandsAndTypes(
451035e12e6SJohn Demme     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
452035e12e6SJohn Demme     OperandRange varOperands, Type operandType, Type optOperandType,
45388c6e25eSRiver Riddle     TypeRange varOperandTypes) {
454035e12e6SJohn Demme   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
455035e12e6SJohn Demme   printCustomDirectiveResults(printer, op, operandType, optOperandType,
45688c6e25eSRiver Riddle                               varOperandTypes);
45788c6e25eSRiver Riddle }
458035e12e6SJohn Demme static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
459035e12e6SJohn Demme                                         Region &region,
460eaeadce9SRiver Riddle                                         MutableArrayRef<Region> varRegions) {
461eaeadce9SRiver Riddle   printer.printRegion(region);
462eaeadce9SRiver Riddle   if (!varRegions.empty()) {
463eaeadce9SRiver Riddle     printer << ", ";
464eaeadce9SRiver Riddle     for (Region &region : varRegions)
465eaeadce9SRiver Riddle       printer.printRegion(region);
466eaeadce9SRiver Riddle   }
467eaeadce9SRiver Riddle }
468035e12e6SJohn Demme static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
46988c6e25eSRiver Riddle                                            Block *successor,
47088c6e25eSRiver Riddle                                            SuccessorRange varSuccessors) {
47188c6e25eSRiver Riddle   printer << successor;
47288c6e25eSRiver Riddle   if (!varSuccessors.empty())
47388c6e25eSRiver Riddle     printer << ", " << varSuccessors.front();
47488c6e25eSRiver Riddle }
475035e12e6SJohn Demme static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
476d14cfe10SMike Urbach                                            Attribute attribute,
477d14cfe10SMike Urbach                                            Attribute optAttribute) {
478d14cfe10SMike Urbach   printer << attribute;
479d14cfe10SMike Urbach   if (optAttribute)
480d14cfe10SMike Urbach     printer << ", " << optAttribute;
481d14cfe10SMike Urbach }
48288c6e25eSRiver Riddle 
483035e12e6SJohn Demme static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
484035e12e6SJohn Demme                                          MutableDictionaryAttr attrs) {
485035e12e6SJohn Demme   printer.printOptionalAttrDict(attrs.getAttrs());
486035e12e6SJohn Demme }
48788c6e25eSRiver Riddle //===----------------------------------------------------------------------===//
488fec6c5acSUday Bondhugula // Test IsolatedRegionOp - parse passthrough region arguments.
489fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
490fec6c5acSUday Bondhugula 
491fec6c5acSUday Bondhugula static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
492fec6c5acSUday Bondhugula                                          OperationState &result) {
493fec6c5acSUday Bondhugula   OpAsmParser::OperandType argInfo;
494fec6c5acSUday Bondhugula   Type argType = parser.getBuilder().getIndexType();
495fec6c5acSUday Bondhugula 
496fec6c5acSUday Bondhugula   // Parse the input operand.
497fec6c5acSUday Bondhugula   if (parser.parseOperand(argInfo) ||
498fec6c5acSUday Bondhugula       parser.resolveOperand(argInfo, argType, result.operands))
499fec6c5acSUday Bondhugula     return failure();
500fec6c5acSUday Bondhugula 
501fec6c5acSUday Bondhugula   // Parse the body region, and reuse the operand info as the argument info.
502fec6c5acSUday Bondhugula   Region *body = result.addRegion();
503fec6c5acSUday Bondhugula   return parser.parseRegion(*body, argInfo, argType,
504fec6c5acSUday Bondhugula                             /*enableNameShadowing=*/true);
505fec6c5acSUday Bondhugula }
506fec6c5acSUday Bondhugula 
507fec6c5acSUday Bondhugula static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
508fec6c5acSUday Bondhugula   p << "test.isolated_region ";
509fec6c5acSUday Bondhugula   p.printOperand(op.getOperand());
510fec6c5acSUday Bondhugula   p.shadowRegionArgs(op.region(), op.getOperand());
511fec6c5acSUday Bondhugula   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
512fec6c5acSUday Bondhugula }
513fec6c5acSUday Bondhugula 
514fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
51562828865SStephen Neuendorffer // Test SSACFGRegionOp
51662828865SStephen Neuendorffer //===----------------------------------------------------------------------===//
51762828865SStephen Neuendorffer 
51862828865SStephen Neuendorffer RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
51962828865SStephen Neuendorffer   return RegionKind::SSACFG;
52062828865SStephen Neuendorffer }
52162828865SStephen Neuendorffer 
52262828865SStephen Neuendorffer //===----------------------------------------------------------------------===//
52362828865SStephen Neuendorffer // Test GraphRegionOp
52462828865SStephen Neuendorffer //===----------------------------------------------------------------------===//
52562828865SStephen Neuendorffer 
52662828865SStephen Neuendorffer static ParseResult parseGraphRegionOp(OpAsmParser &parser,
52762828865SStephen Neuendorffer                                       OperationState &result) {
52862828865SStephen Neuendorffer   // Parse the body region, and reuse the operand info as the argument info.
52962828865SStephen Neuendorffer   Region *body = result.addRegion();
53062828865SStephen Neuendorffer   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
53162828865SStephen Neuendorffer }
53262828865SStephen Neuendorffer 
53362828865SStephen Neuendorffer static void print(OpAsmPrinter &p, GraphRegionOp op) {
53462828865SStephen Neuendorffer   p << "test.graph_region ";
53562828865SStephen Neuendorffer   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
53662828865SStephen Neuendorffer }
53762828865SStephen Neuendorffer 
53862828865SStephen Neuendorffer RegionKind GraphRegionOp::getRegionKind(unsigned index) {
53962828865SStephen Neuendorffer   return RegionKind::Graph;
54062828865SStephen Neuendorffer }
54162828865SStephen Neuendorffer 
54262828865SStephen Neuendorffer //===----------------------------------------------------------------------===//
54357d361bdSUday Bondhugula // Test AffineScopeOp
54448034538SUday Bondhugula //===----------------------------------------------------------------------===//
54548034538SUday Bondhugula 
54657d361bdSUday Bondhugula static ParseResult parseAffineScopeOp(OpAsmParser &parser,
54748034538SUday Bondhugula                                       OperationState &result) {
54848034538SUday Bondhugula   // Parse the body region, and reuse the operand info as the argument info.
54948034538SUday Bondhugula   Region *body = result.addRegion();
55048034538SUday Bondhugula   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
55148034538SUday Bondhugula }
55248034538SUday Bondhugula 
55357d361bdSUday Bondhugula static void print(OpAsmPrinter &p, AffineScopeOp op) {
55457d361bdSUday Bondhugula   p << "test.affine_scope ";
55548034538SUday Bondhugula   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
55648034538SUday Bondhugula }
55748034538SUday Bondhugula 
55848034538SUday Bondhugula //===----------------------------------------------------------------------===//
559fec6c5acSUday Bondhugula // Test parser.
560fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
561fec6c5acSUday Bondhugula 
562fec6c5acSUday Bondhugula static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
563fec6c5acSUday Bondhugula                                          OperationState &result) {
564fec6c5acSUday Bondhugula   StringRef keyword;
565fec6c5acSUday Bondhugula   if (parser.parseKeyword(&keyword))
566fec6c5acSUday Bondhugula     return failure();
567fec6c5acSUday Bondhugula   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
568fec6c5acSUday Bondhugula   return success();
569fec6c5acSUday Bondhugula }
570fec6c5acSUday Bondhugula 
571fec6c5acSUday Bondhugula static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
572fec6c5acSUday Bondhugula   p << WrappedKeywordOp::getOperationName() << " " << op.keyword();
573fec6c5acSUday Bondhugula }
574fec6c5acSUday Bondhugula 
//===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
//===----------------------------------------------------------------------===//
577fec6c5acSUday Bondhugula 
578fec6c5acSUday Bondhugula static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
579fec6c5acSUday Bondhugula                                          OperationState &result) {
580fec6c5acSUday Bondhugula   if (parser.parseKeyword("wraps"))
581fec6c5acSUday Bondhugula     return failure();
582fec6c5acSUday Bondhugula 
583fec6c5acSUday Bondhugula   // Parse the wrapped op in a region
584fec6c5acSUday Bondhugula   Region &body = *result.addRegion();
585fec6c5acSUday Bondhugula   body.push_back(new Block);
586fec6c5acSUday Bondhugula   Block &block = body.back();
587fec6c5acSUday Bondhugula   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
588fec6c5acSUday Bondhugula   if (!wrapped_op)
589fec6c5acSUday Bondhugula     return failure();
590fec6c5acSUday Bondhugula 
591fec6c5acSUday Bondhugula   // Create a return terminator in the inner region, pass as operand to the
592fec6c5acSUday Bondhugula   // terminator the returned values from the wrapped operation.
593fec6c5acSUday Bondhugula   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
594fec6c5acSUday Bondhugula   OpBuilder builder(parser.getBuilder().getContext());
595fec6c5acSUday Bondhugula   builder.setInsertionPointToEnd(&block);
596fec6c5acSUday Bondhugula   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
597fec6c5acSUday Bondhugula 
598fec6c5acSUday Bondhugula   // Get the results type for the wrapping op from the terminator operands.
599fec6c5acSUday Bondhugula   Operation &return_op = body.back().back();
600fec6c5acSUday Bondhugula   result.types.append(return_op.operand_type_begin(),
601fec6c5acSUday Bondhugula                       return_op.operand_type_end());
602fec6c5acSUday Bondhugula 
603fec6c5acSUday Bondhugula   // Use the location of the wrapped op for the "test.wrapping_region" op.
604fec6c5acSUday Bondhugula   result.location = wrapped_op->getLoc();
605fec6c5acSUday Bondhugula 
606fec6c5acSUday Bondhugula   return success();
607fec6c5acSUday Bondhugula }
608fec6c5acSUday Bondhugula 
609fec6c5acSUday Bondhugula static void print(OpAsmPrinter &p, WrappingRegionOp op) {
610fec6c5acSUday Bondhugula   p << op.getOperationName() << " wraps ";
611fec6c5acSUday Bondhugula   p.printGenericOp(&op.region().front().front());
612fec6c5acSUday Bondhugula }
613fec6c5acSUday Bondhugula 
614fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
615fec6c5acSUday Bondhugula // Test PolyForOp - parse list of region arguments.
616fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
617fec6c5acSUday Bondhugula 
618fec6c5acSUday Bondhugula static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
619fec6c5acSUday Bondhugula   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
620fec6c5acSUday Bondhugula   // Parse list of region arguments without a delimiter.
621fec6c5acSUday Bondhugula   if (parser.parseRegionArgumentList(ivsInfo))
622fec6c5acSUday Bondhugula     return failure();
623fec6c5acSUday Bondhugula 
624fec6c5acSUday Bondhugula   // Parse the body region.
625fec6c5acSUday Bondhugula   Region *body = result.addRegion();
626fec6c5acSUday Bondhugula   auto &builder = parser.getBuilder();
627fec6c5acSUday Bondhugula   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
628fec6c5acSUday Bondhugula   return parser.parseRegion(*body, ivsInfo, argTypes);
629fec6c5acSUday Bondhugula }
630fec6c5acSUday Bondhugula 
631fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
632fec6c5acSUday Bondhugula // Test removing op with inner ops.
633fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
634fec6c5acSUday Bondhugula 
635fec6c5acSUday Bondhugula namespace {
636fec6c5acSUday Bondhugula struct TestRemoveOpWithInnerOps
637fec6c5acSUday Bondhugula     : public OpRewritePattern<TestOpWithRegionPattern> {
638fec6c5acSUday Bondhugula   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
639fec6c5acSUday Bondhugula 
640fec6c5acSUday Bondhugula   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
641fec6c5acSUday Bondhugula                                 PatternRewriter &rewriter) const override {
642fec6c5acSUday Bondhugula     rewriter.eraseOp(op);
643fec6c5acSUday Bondhugula     return success();
644fec6c5acSUday Bondhugula   }
645fec6c5acSUday Bondhugula };
646fec6c5acSUday Bondhugula } // end anonymous namespace
647fec6c5acSUday Bondhugula 
648fec6c5acSUday Bondhugula void TestOpWithRegionPattern::getCanonicalizationPatterns(
649fec6c5acSUday Bondhugula     OwningRewritePatternList &results, MLIRContext *context) {
650fec6c5acSUday Bondhugula   results.insert<TestRemoveOpWithInnerOps>(context);
651fec6c5acSUday Bondhugula }
652fec6c5acSUday Bondhugula 
653fec6c5acSUday Bondhugula OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
654fec6c5acSUday Bondhugula   return operand();
655fec6c5acSUday Bondhugula }
656fec6c5acSUday Bondhugula 
6572bf423b0SRob Suderman OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
6582bf423b0SRob Suderman   return getValue();
6592bf423b0SRob Suderman }
6602bf423b0SRob Suderman 
661fec6c5acSUday Bondhugula LogicalResult TestOpWithVariadicResultsAndFolder::fold(
662fec6c5acSUday Bondhugula     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
663fec6c5acSUday Bondhugula   for (Value input : this->operands()) {
664fec6c5acSUday Bondhugula     results.push_back(input);
665fec6c5acSUday Bondhugula   }
666fec6c5acSUday Bondhugula   return success();
667fec6c5acSUday Bondhugula }
668fec6c5acSUday Bondhugula 
66926f93d9fSAlex Zinenko OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
67026f93d9fSAlex Zinenko   assert(operands.size() == 1);
67126f93d9fSAlex Zinenko   if (operands.front()) {
67226f93d9fSAlex Zinenko     setAttr("attr", operands.front());
67326f93d9fSAlex Zinenko     return getResult();
67426f93d9fSAlex Zinenko   }
67526f93d9fSAlex Zinenko   return {};
67626f93d9fSAlex Zinenko }
67726f93d9fSAlex Zinenko 
67862828865SStephen Neuendorffer LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
679fec6c5acSUday Bondhugula     MLIRContext *, Optional<Location> location, ValueRange operands,
6805eae715aSJacques Pienaar     DictionaryAttr attributes, RegionRange regions,
681fec6c5acSUday Bondhugula     SmallVectorImpl<Type> &inferredReturnTypes) {
682fec6c5acSUday Bondhugula   if (operands[0].getType() != operands[1].getType()) {
683fec6c5acSUday Bondhugula     return emitOptionalError(location, "operand type mismatch ",
684fec6c5acSUday Bondhugula                              operands[0].getType(), " vs ",
685fec6c5acSUday Bondhugula                              operands[1].getType());
686fec6c5acSUday Bondhugula   }
687fec6c5acSUday Bondhugula   inferredReturnTypes.assign({operands[0].getType()});
688fec6c5acSUday Bondhugula   return success();
689fec6c5acSUday Bondhugula }
690fec6c5acSUday Bondhugula 
691fec6c5acSUday Bondhugula LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
692fec6c5acSUday Bondhugula     MLIRContext *context, Optional<Location> location, ValueRange operands,
6935eae715aSJacques Pienaar     DictionaryAttr attributes, RegionRange regions,
694fec6c5acSUday Bondhugula     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
695fec6c5acSUday Bondhugula   // Create return type consisting of the last element of the first operand.
696fec6c5acSUday Bondhugula   auto operandType = *operands.getTypes().begin();
697fec6c5acSUday Bondhugula   auto sval = operandType.dyn_cast<ShapedType>();
698fec6c5acSUday Bondhugula   if (!sval) {
699fec6c5acSUday Bondhugula     return emitOptionalError(location, "only shaped type operands allowed");
700fec6c5acSUday Bondhugula   }
701fec6c5acSUday Bondhugula   int64_t dim =
702fec6c5acSUday Bondhugula       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
703fec6c5acSUday Bondhugula   auto type = IntegerType::get(17, context);
704fec6c5acSUday Bondhugula   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
705fec6c5acSUday Bondhugula   return success();
706fec6c5acSUday Bondhugula }
707fec6c5acSUday Bondhugula 
708fec6c5acSUday Bondhugula LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
709fec6c5acSUday Bondhugula     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
710fec6c5acSUday Bondhugula   shapes = SmallVector<Value, 1>{
71162828865SStephen Neuendorffer       builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)};
712fec6c5acSUday Bondhugula   return success();
713fec6c5acSUday Bondhugula }
714fec6c5acSUday Bondhugula 
715fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
716fec6c5acSUday Bondhugula // Test SideEffect interfaces
717fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
718fec6c5acSUday Bondhugula 
719fec6c5acSUday Bondhugula namespace {
720fec6c5acSUday Bondhugula /// A test resource for side effects.
721fec6c5acSUday Bondhugula struct TestResource : public SideEffects::Resource::Base<TestResource> {
722fec6c5acSUday Bondhugula   StringRef getName() final { return "<Test>"; }
723fec6c5acSUday Bondhugula };
724fec6c5acSUday Bondhugula } // end anonymous namespace
725fec6c5acSUday Bondhugula 
726fec6c5acSUday Bondhugula void SideEffectOp::getEffects(
727fec6c5acSUday Bondhugula     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
728fec6c5acSUday Bondhugula   // Check for an effects attribute on the op instance.
729fec6c5acSUday Bondhugula   ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects");
730fec6c5acSUday Bondhugula   if (!effectsAttr)
731fec6c5acSUday Bondhugula     return;
732fec6c5acSUday Bondhugula 
733fec6c5acSUday Bondhugula   // If there is one, it is an array of dictionary attributes that hold
734fec6c5acSUday Bondhugula   // information on the effects of this operation.
735fec6c5acSUday Bondhugula   for (Attribute element : effectsAttr) {
736fec6c5acSUday Bondhugula     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
737fec6c5acSUday Bondhugula 
738fec6c5acSUday Bondhugula     // Get the specific memory effect.
739fec6c5acSUday Bondhugula     MemoryEffects::Effect *effect =
740cc83dc19SChristian Sigg         StringSwitch<MemoryEffects::Effect *>(
741fec6c5acSUday Bondhugula             effectElement.get("effect").cast<StringAttr>().getValue())
742fec6c5acSUday Bondhugula             .Case("allocate", MemoryEffects::Allocate::get())
743fec6c5acSUday Bondhugula             .Case("free", MemoryEffects::Free::get())
744fec6c5acSUday Bondhugula             .Case("read", MemoryEffects::Read::get())
745fec6c5acSUday Bondhugula             .Case("write", MemoryEffects::Write::get());
746fec6c5acSUday Bondhugula 
747fec6c5acSUday Bondhugula     // Check for a result to affect.
748fec6c5acSUday Bondhugula     Value value;
749fec6c5acSUday Bondhugula     if (effectElement.get("on_result"))
750fec6c5acSUday Bondhugula       value = getResult();
751fec6c5acSUday Bondhugula 
752fec6c5acSUday Bondhugula     // Check for a non-default resource to use.
753fec6c5acSUday Bondhugula     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
754fec6c5acSUday Bondhugula     if (effectElement.get("test_resource"))
755fec6c5acSUday Bondhugula       resource = TestResource::get();
756fec6c5acSUday Bondhugula 
757fec6c5acSUday Bondhugula     effects.emplace_back(effect, value, resource);
758fec6c5acSUday Bondhugula   }
759fec6c5acSUday Bondhugula }
760fec6c5acSUday Bondhugula 
761*052d24afSAlex Zinenko void SideEffectOp::getEffects(
762*052d24afSAlex Zinenko     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
763*052d24afSAlex Zinenko   auto effectsAttr = getAttrOfType<AffineMapAttr>("effect_parameter");
764*052d24afSAlex Zinenko   if (!effectsAttr)
765*052d24afSAlex Zinenko     return;
766*052d24afSAlex Zinenko 
767*052d24afSAlex Zinenko   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
768*052d24afSAlex Zinenko }
769*052d24afSAlex Zinenko 
770fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
771fec6c5acSUday Bondhugula // StringAttrPrettyNameOp
772fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
773fec6c5acSUday Bondhugula 
774fec6c5acSUday Bondhugula // This op has fancy handling of its SSA result name.
775fec6c5acSUday Bondhugula static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
776fec6c5acSUday Bondhugula                                                OperationState &result) {
777fec6c5acSUday Bondhugula   // Add the result types.
778fec6c5acSUday Bondhugula   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
779fec6c5acSUday Bondhugula     result.addTypes(parser.getBuilder().getIntegerType(32));
780fec6c5acSUday Bondhugula 
781fec6c5acSUday Bondhugula   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
782fec6c5acSUday Bondhugula     return failure();
783fec6c5acSUday Bondhugula 
784fec6c5acSUday Bondhugula   // If the attribute dictionary contains no 'names' attribute, infer it from
785fec6c5acSUday Bondhugula   // the SSA name (if specified).
786fec6c5acSUday Bondhugula   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
78774e6a5b2SChris Lattner     return attr.first == "names";
788fec6c5acSUday Bondhugula   });
789fec6c5acSUday Bondhugula 
790fec6c5acSUday Bondhugula   // If there was no name specified, check to see if there was a useful name
791fec6c5acSUday Bondhugula   // specified in the asm file.
792fec6c5acSUday Bondhugula   if (hadNames || parser.getNumResults() == 0)
793fec6c5acSUday Bondhugula     return success();
794fec6c5acSUday Bondhugula 
795fec6c5acSUday Bondhugula   SmallVector<StringRef, 4> names;
796fec6c5acSUday Bondhugula   auto *context = result.getContext();
797fec6c5acSUday Bondhugula 
798fec6c5acSUday Bondhugula   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
799fec6c5acSUday Bondhugula     auto resultName = parser.getResultName(i);
800fec6c5acSUday Bondhugula     StringRef nameStr;
801fec6c5acSUday Bondhugula     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
802fec6c5acSUday Bondhugula       nameStr = resultName.first;
803fec6c5acSUday Bondhugula 
804fec6c5acSUday Bondhugula     names.push_back(nameStr);
805fec6c5acSUday Bondhugula   }
806fec6c5acSUday Bondhugula 
807fec6c5acSUday Bondhugula   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
808fec6c5acSUday Bondhugula   result.attributes.push_back({Identifier::get("names", context), namesAttr});
809fec6c5acSUday Bondhugula   return success();
810fec6c5acSUday Bondhugula }
811fec6c5acSUday Bondhugula 
812fec6c5acSUday Bondhugula static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
813fec6c5acSUday Bondhugula   p << "test.string_attr_pretty_name";
814fec6c5acSUday Bondhugula 
815fec6c5acSUday Bondhugula   // Note that we only need to print the "name" attribute if the asmprinter
816fec6c5acSUday Bondhugula   // result name disagrees with it.  This can happen in strange cases, e.g.
817fec6c5acSUday Bondhugula   // when there are conflicts.
818fec6c5acSUday Bondhugula   bool namesDisagree = op.names().size() != op.getNumResults();
819fec6c5acSUday Bondhugula 
820fec6c5acSUday Bondhugula   SmallString<32> resultNameStr;
821fec6c5acSUday Bondhugula   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
822fec6c5acSUday Bondhugula     resultNameStr.clear();
823fec6c5acSUday Bondhugula     llvm::raw_svector_ostream tmpStream(resultNameStr);
824fec6c5acSUday Bondhugula     p.printOperand(op.getResult(i), tmpStream);
825fec6c5acSUday Bondhugula 
826fec6c5acSUday Bondhugula     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
827fec6c5acSUday Bondhugula     if (!expectedName ||
828fec6c5acSUday Bondhugula         tmpStream.str().drop_front() != expectedName.getValue()) {
829fec6c5acSUday Bondhugula       namesDisagree = true;
830fec6c5acSUday Bondhugula     }
831fec6c5acSUday Bondhugula   }
832fec6c5acSUday Bondhugula 
833fec6c5acSUday Bondhugula   if (namesDisagree)
834fec6c5acSUday Bondhugula     p.printOptionalAttrDictWithKeyword(op.getAttrs());
835fec6c5acSUday Bondhugula   else
836fec6c5acSUday Bondhugula     p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
837fec6c5acSUday Bondhugula }
838fec6c5acSUday Bondhugula 
839fec6c5acSUday Bondhugula // We set the SSA name in the asm syntax to the contents of the name
840fec6c5acSUday Bondhugula // attribute.
841fec6c5acSUday Bondhugula void StringAttrPrettyNameOp::getAsmResultNames(
842fec6c5acSUday Bondhugula     function_ref<void(Value, StringRef)> setNameFn) {
843fec6c5acSUday Bondhugula 
844fec6c5acSUday Bondhugula   auto value = names();
845fec6c5acSUday Bondhugula   for (size_t i = 0, e = value.size(); i != e; ++i)
846fec6c5acSUday Bondhugula     if (auto str = value[i].dyn_cast<StringAttr>())
847fec6c5acSUday Bondhugula       if (!str.getValue().empty())
848fec6c5acSUday Bondhugula         setNameFn(getResult(i), str.getValue());
849fec6c5acSUday Bondhugula }
850fec6c5acSUday Bondhugula 
851fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===//
8526f5da84fSMarcel Koester // RegionIfOp
8536f5da84fSMarcel Koester //===----------------------------------------------------------------------===//
8546f5da84fSMarcel Koester 
8556f5da84fSMarcel Koester static void print(OpAsmPrinter &p, RegionIfOp op) {
8566f5da84fSMarcel Koester   p << RegionIfOp::getOperationName() << " ";
8576f5da84fSMarcel Koester   p.printOperands(op.getOperands());
8586f5da84fSMarcel Koester   p << ": " << op.getOperandTypes();
8596f5da84fSMarcel Koester   p.printArrowTypeList(op.getResultTypes());
8606f5da84fSMarcel Koester   p << " then";
8616f5da84fSMarcel Koester   p.printRegion(op.thenRegion(),
8626f5da84fSMarcel Koester                 /*printEntryBlockArgs=*/true,
8636f5da84fSMarcel Koester                 /*printBlockTerminators=*/true);
8646f5da84fSMarcel Koester   p << " else";
8656f5da84fSMarcel Koester   p.printRegion(op.elseRegion(),
8666f5da84fSMarcel Koester                 /*printEntryBlockArgs=*/true,
8676f5da84fSMarcel Koester                 /*printBlockTerminators=*/true);
8686f5da84fSMarcel Koester   p << " join";
8696f5da84fSMarcel Koester   p.printRegion(op.joinRegion(),
8706f5da84fSMarcel Koester                 /*printEntryBlockArgs=*/true,
8716f5da84fSMarcel Koester                 /*printBlockTerminators=*/true);
8726f5da84fSMarcel Koester }
8736f5da84fSMarcel Koester 
8746f5da84fSMarcel Koester static ParseResult parseRegionIfOp(OpAsmParser &parser,
8756f5da84fSMarcel Koester                                    OperationState &result) {
8766f5da84fSMarcel Koester   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
8776f5da84fSMarcel Koester   SmallVector<Type, 2> operandTypes;
8786f5da84fSMarcel Koester 
8796f5da84fSMarcel Koester   result.regions.reserve(3);
8806f5da84fSMarcel Koester   Region *thenRegion = result.addRegion();
8816f5da84fSMarcel Koester   Region *elseRegion = result.addRegion();
8826f5da84fSMarcel Koester   Region *joinRegion = result.addRegion();
8836f5da84fSMarcel Koester 
8846f5da84fSMarcel Koester   // Parse operand, type and arrow type lists.
8856f5da84fSMarcel Koester   if (parser.parseOperandList(operandInfos) ||
8866f5da84fSMarcel Koester       parser.parseColonTypeList(operandTypes) ||
8876f5da84fSMarcel Koester       parser.parseArrowTypeList(result.types))
8886f5da84fSMarcel Koester     return failure();
8896f5da84fSMarcel Koester 
8906f5da84fSMarcel Koester   // Parse all attached regions.
8916f5da84fSMarcel Koester   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
8926f5da84fSMarcel Koester       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
8936f5da84fSMarcel Koester       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
8946f5da84fSMarcel Koester     return failure();
8956f5da84fSMarcel Koester 
8966f5da84fSMarcel Koester   return parser.resolveOperands(operandInfos, operandTypes,
8976f5da84fSMarcel Koester                                 parser.getCurrentLocation(), result.operands);
8986f5da84fSMarcel Koester }
8996f5da84fSMarcel Koester 
9006f5da84fSMarcel Koester OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
9016f5da84fSMarcel Koester   assert(index < 2 && "invalid region index");
9026f5da84fSMarcel Koester   return getOperands();
9036f5da84fSMarcel Koester }
9046f5da84fSMarcel Koester 
9056f5da84fSMarcel Koester void RegionIfOp::getSuccessorRegions(
9066f5da84fSMarcel Koester     Optional<unsigned> index, ArrayRef<Attribute> operands,
9076f5da84fSMarcel Koester     SmallVectorImpl<RegionSuccessor> &regions) {
9086f5da84fSMarcel Koester   // We always branch to the join region.
9096f5da84fSMarcel Koester   if (index.hasValue()) {
9106f5da84fSMarcel Koester     if (index.getValue() < 2)
9116f5da84fSMarcel Koester       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
9126f5da84fSMarcel Koester     else
9136f5da84fSMarcel Koester       regions.push_back(RegionSuccessor(getResults()));
9146f5da84fSMarcel Koester     return;
9156f5da84fSMarcel Koester   }
9166f5da84fSMarcel Koester 
9176f5da84fSMarcel Koester   // The then and else regions are the entry regions of this op.
9186f5da84fSMarcel Koester   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
9196f5da84fSMarcel Koester   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
9206f5da84fSMarcel Koester }
9216f5da84fSMarcel Koester 
922fec6c5acSUday Bondhugula #include "TestOpEnums.cpp.inc"
923*052d24afSAlex Zinenko #include "TestOpInterfaces.cpp.inc"
9249c9f479aSSean Silva #include "TestOpStructs.cpp.inc"
9252e2cdd0aSRiver Riddle #include "TestTypeInterfaces.cpp.inc"
926fec6c5acSUday Bondhugula 
927fec6c5acSUday Bondhugula #define GET_OP_CLASSES
928fec6c5acSUday Bondhugula #include "TestOps.cpp.inc"
929