1fec6c5acSUday Bondhugula //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// 2fec6c5acSUday Bondhugula // 3fec6c5acSUday Bondhugula // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4fec6c5acSUday Bondhugula // See https://llvm.org/LICENSE.txt for license information. 5fec6c5acSUday Bondhugula // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6fec6c5acSUday Bondhugula // 7fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 8fec6c5acSUday Bondhugula 9fec6c5acSUday Bondhugula #include "TestDialect.h" 102e2cdd0aSRiver Riddle #include "TestTypes.h" 11fec6c5acSUday Bondhugula #include "mlir/Dialect/StandardOps/IR/Ops.h" 1273ca690dSRiver Riddle #include "mlir/IR/BuiltinDialect.h" 132e2cdd0aSRiver Riddle #include "mlir/IR/DialectImplementation.h" 14fec6c5acSUday Bondhugula #include "mlir/IR/PatternMatch.h" 15fec6c5acSUday Bondhugula #include "mlir/IR/TypeUtilities.h" 16fec6c5acSUday Bondhugula #include "mlir/Transforms/FoldUtils.h" 17fec6c5acSUday Bondhugula #include "mlir/Transforms/InliningUtils.h" 18a5182991SAlex Zinenko #include "llvm/ADT/SetVector.h" 19fec6c5acSUday Bondhugula #include "llvm/ADT/StringSwitch.h" 20fec6c5acSUday Bondhugula 21fec6c5acSUday Bondhugula using namespace mlir; 2272c65b69SAlexander Belyaev using namespace mlir::test; 23fec6c5acSUday Bondhugula 2472c65b69SAlexander Belyaev void mlir::test::registerTestDialect(DialectRegistry ®istry) { 25f9dc2b70SMehdi Amini registry.insert<TestDialect>(); 26f9dc2b70SMehdi Amini } 27f9dc2b70SMehdi Amini 28fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 29fec6c5acSUday Bondhugula // TestDialect Interfaces 30fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 31fec6c5acSUday Bondhugula 32fec6c5acSUday Bondhugula namespace { 33fec6c5acSUday Bondhugula 34fec6c5acSUday Bondhugula // Test support for interacting with the AsmPrinter. 35fec6c5acSUday Bondhugula struct TestOpAsmInterface : public OpAsmDialectInterface { 36fec6c5acSUday Bondhugula using OpAsmDialectInterface::OpAsmDialectInterface; 37fec6c5acSUday Bondhugula 38a463ea50SRiver Riddle LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { 39a463ea50SRiver Riddle StringAttr strAttr = attr.dyn_cast<StringAttr>(); 40a463ea50SRiver Riddle if (!strAttr) 41a463ea50SRiver Riddle return failure(); 42a463ea50SRiver Riddle 43a463ea50SRiver Riddle // Check the contents of the string attribute to see what the test alias 44a463ea50SRiver Riddle // should be named. 45a463ea50SRiver Riddle Optional<StringRef> aliasName = 46a463ea50SRiver Riddle StringSwitch<Optional<StringRef>>(strAttr.getValue()) 47a463ea50SRiver Riddle .Case("alias_test:dot_in_name", StringRef("test.alias")) 48a463ea50SRiver Riddle .Case("alias_test:trailing_digit", StringRef("test_alias0")) 49a463ea50SRiver Riddle .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) 50a463ea50SRiver Riddle .Case("alias_test:sanitize_conflict_a", 51a463ea50SRiver Riddle StringRef("test_alias_conflict0")) 52a463ea50SRiver Riddle .Case("alias_test:sanitize_conflict_b", 53a463ea50SRiver Riddle StringRef("test_alias_conflict0_")) 54a463ea50SRiver Riddle .Default(llvm::None); 55a463ea50SRiver Riddle if (!aliasName) 56a463ea50SRiver Riddle return failure(); 57a463ea50SRiver Riddle 58a463ea50SRiver Riddle os << *aliasName; 59a463ea50SRiver Riddle return success(); 60a463ea50SRiver Riddle } 61a463ea50SRiver Riddle 62fec6c5acSUday Bondhugula void getAsmResultNames(Operation *op, 63fec6c5acSUday Bondhugula OpAsmSetValueNameFn setNameFn) const final { 64fec6c5acSUday Bondhugula if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op)) 65fec6c5acSUday Bondhugula setNameFn(asmOp, "result"); 66fec6c5acSUday Bondhugula } 67fec6c5acSUday Bondhugula 68fec6c5acSUday Bondhugula void getAsmBlockArgumentNames(Block *block, 69fec6c5acSUday Bondhugula OpAsmSetValueNameFn setNameFn) const final { 70fec6c5acSUday Bondhugula auto op = block->getParentOp(); 71fec6c5acSUday Bondhugula auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names"); 72fec6c5acSUday Bondhugula if (!arrayAttr) 73fec6c5acSUday Bondhugula return; 74fec6c5acSUday Bondhugula auto args = block->getArguments(); 75fec6c5acSUday Bondhugula auto e = std::min(arrayAttr.size(), args.size()); 76fec6c5acSUday Bondhugula for (unsigned i = 0; i < e; ++i) { 77fec6c5acSUday Bondhugula if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>()) 78fec6c5acSUday Bondhugula setNameFn(args[i], strAttr.getValue()); 79fec6c5acSUday Bondhugula } 80fec6c5acSUday Bondhugula } 81fec6c5acSUday Bondhugula }; 82fec6c5acSUday Bondhugula 83b28e3db8SMehdi Amini struct TestDialectFoldInterface : public DialectFoldInterface { 84b28e3db8SMehdi Amini using DialectFoldInterface::DialectFoldInterface; 85fec6c5acSUday Bondhugula 86fec6c5acSUday Bondhugula /// Registered hook to check if the given region, which is attached to an 87fec6c5acSUday Bondhugula /// operation that is *not* isolated from above, should be used when 88fec6c5acSUday Bondhugula /// materializing constants. 89fec6c5acSUday Bondhugula bool shouldMaterializeInto(Region *region) const final { 90fec6c5acSUday Bondhugula // If this is a one region operation, then insert into it. 91fec6c5acSUday Bondhugula return isa<OneRegionOp>(region->getParentOp()); 92fec6c5acSUday Bondhugula } 93fec6c5acSUday Bondhugula }; 94fec6c5acSUday Bondhugula 95fec6c5acSUday Bondhugula /// This class defines the interface for handling inlining with standard 96fec6c5acSUday Bondhugula /// operations. 97fec6c5acSUday Bondhugula struct TestInlinerInterface : public DialectInlinerInterface { 98fec6c5acSUday Bondhugula using DialectInlinerInterface::DialectInlinerInterface; 99fec6c5acSUday Bondhugula 100fec6c5acSUday Bondhugula //===--------------------------------------------------------------------===// 101fec6c5acSUday Bondhugula // Analysis Hooks 102fec6c5acSUday Bondhugula //===--------------------------------------------------------------------===// 103fec6c5acSUday Bondhugula 104fa417479SRiver Riddle bool isLegalToInline(Operation *call, Operation *callable, 105fa417479SRiver Riddle bool wouldBeCloned) const final { 106501fda01SRiver Riddle // Don't allow inlining calls that are marked `noinline`. 107501fda01SRiver Riddle return !call->hasAttr("noinline"); 108501fda01SRiver Riddle } 109fa417479SRiver Riddle bool isLegalToInline(Region *, Region *, bool, 110fa417479SRiver Riddle BlockAndValueMapping &) const final { 111fec6c5acSUday Bondhugula // Inlining into test dialect regions is legal. 112fec6c5acSUday Bondhugula return true; 113fec6c5acSUday Bondhugula } 114fa417479SRiver Riddle bool isLegalToInline(Operation *, Region *, bool, 115fec6c5acSUday Bondhugula BlockAndValueMapping &) const final { 116fec6c5acSUday Bondhugula return true; 117fec6c5acSUday Bondhugula } 118fec6c5acSUday Bondhugula 119fec6c5acSUday Bondhugula bool shouldAnalyzeRecursively(Operation *op) const final { 120fec6c5acSUday Bondhugula // Analyze recursively if this is not a functional region operation, it 121fec6c5acSUday Bondhugula // froms a separate functional scope. 122fec6c5acSUday Bondhugula return !isa<FunctionalRegionOp>(op); 123fec6c5acSUday Bondhugula } 124fec6c5acSUday Bondhugula 125fec6c5acSUday Bondhugula //===--------------------------------------------------------------------===// 126fec6c5acSUday Bondhugula // Transformation Hooks 127fec6c5acSUday Bondhugula //===--------------------------------------------------------------------===// 128fec6c5acSUday Bondhugula 129fec6c5acSUday Bondhugula /// Handle the given inlined terminator by replacing it with a new operation 130fec6c5acSUday Bondhugula /// as necessary. 131fec6c5acSUday Bondhugula void handleTerminator(Operation *op, 132fec6c5acSUday Bondhugula ArrayRef<Value> valuesToRepl) const final { 133fec6c5acSUday Bondhugula // Only handle "test.return" here. 134fec6c5acSUday Bondhugula auto returnOp = dyn_cast<TestReturnOp>(op); 135fec6c5acSUday Bondhugula if (!returnOp) 136fec6c5acSUday Bondhugula return; 137fec6c5acSUday Bondhugula 138fec6c5acSUday Bondhugula // Replace the values directly with the return operands. 139fec6c5acSUday Bondhugula assert(returnOp.getNumOperands() == valuesToRepl.size()); 140fec6c5acSUday Bondhugula for (const auto &it : llvm::enumerate(returnOp.getOperands())) 141fec6c5acSUday Bondhugula valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 142fec6c5acSUday Bondhugula } 143fec6c5acSUday Bondhugula 144fec6c5acSUday Bondhugula /// Attempt to materialize a conversion for a type mismatch between a call 145fec6c5acSUday Bondhugula /// from this dialect, and a callable region. This method should generate an 146fec6c5acSUday Bondhugula /// operation that takes 'input' as the only operand, and produces a single 147fec6c5acSUday Bondhugula /// result of 'resultType'. If a conversion can not be generated, nullptr 148fec6c5acSUday Bondhugula /// should be returned. 149fec6c5acSUday Bondhugula Operation *materializeCallConversion(OpBuilder &builder, Value input, 150fec6c5acSUday Bondhugula Type resultType, 151fec6c5acSUday Bondhugula Location conversionLoc) const final { 152fec6c5acSUday Bondhugula // Only allow conversion for i16/i32 types. 153fec6c5acSUday Bondhugula if (!(resultType.isSignlessInteger(16) || 154fec6c5acSUday Bondhugula resultType.isSignlessInteger(32)) || 155fec6c5acSUday Bondhugula !(input.getType().isSignlessInteger(16) || 156fec6c5acSUday Bondhugula input.getType().isSignlessInteger(32))) 157fec6c5acSUday Bondhugula return nullptr; 158fec6c5acSUday Bondhugula return builder.create<TestCastOp>(conversionLoc, resultType, input); 159fec6c5acSUday Bondhugula } 160fec6c5acSUday Bondhugula }; 161fec6c5acSUday Bondhugula } // end anonymous namespace 162fec6c5acSUday Bondhugula 163fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 164fec6c5acSUday Bondhugula // TestDialect 165fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 166fec6c5acSUday Bondhugula 167575b22b5SMehdi Amini void TestDialect::initialize() { 168fec6c5acSUday Bondhugula addOperations< 169fec6c5acSUday Bondhugula #define GET_OP_LIST 170fec6c5acSUday Bondhugula #include "TestOps.cpp.inc" 171fec6c5acSUday Bondhugula >(); 172b28e3db8SMehdi Amini addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, 173fec6c5acSUday Bondhugula TestInlinerInterface>(); 1745fe53c41SJohn Demme addTypes<TestType, TestRecursiveType, 1755fe53c41SJohn Demme #define GET_TYPEDEF_LIST 1765fe53c41SJohn Demme #include "TestTypeDefs.cpp.inc" 1775fe53c41SJohn Demme >(); 178fec6c5acSUday Bondhugula allowUnknownOperations(); 179fec6c5acSUday Bondhugula } 180fec6c5acSUday Bondhugula 1815fe53c41SJohn Demme static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, 182a5182991SAlex Zinenko llvm::SetVector<Type> &stack) { 183a5182991SAlex Zinenko StringRef typeTag; 184a5182991SAlex Zinenko if (failed(parser.parseKeyword(&typeTag))) 1852e2cdd0aSRiver Riddle return Type(); 186a5182991SAlex Zinenko 1875fe53c41SJohn Demme auto genType = generatedTypeParser(ctxt, parser, typeTag); 1885fe53c41SJohn Demme if (genType != Type()) 1895fe53c41SJohn Demme return genType; 1905fe53c41SJohn Demme 191a5182991SAlex Zinenko if (typeTag == "test_type") 192a5182991SAlex Zinenko return TestType::get(parser.getBuilder().getContext()); 193a5182991SAlex Zinenko 194a5182991SAlex Zinenko if (typeTag != "test_rec") 195a5182991SAlex Zinenko return Type(); 196a5182991SAlex Zinenko 197a5182991SAlex Zinenko StringRef name; 198a5182991SAlex Zinenko if (parser.parseLess() || parser.parseKeyword(&name)) 199a5182991SAlex Zinenko return Type(); 200250f43d3SRiver Riddle auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name); 201a5182991SAlex Zinenko 202a5182991SAlex Zinenko // If this type already has been parsed above in the stack, expect just the 203a5182991SAlex Zinenko // name. 204a5182991SAlex Zinenko if (stack.contains(rec)) { 205a5182991SAlex Zinenko if (failed(parser.parseGreater())) 206a5182991SAlex Zinenko return Type(); 207a5182991SAlex Zinenko return rec; 208a5182991SAlex Zinenko } 209a5182991SAlex Zinenko 210a5182991SAlex Zinenko // Otherwise, parse the body and update the type. 211a5182991SAlex Zinenko if (failed(parser.parseComma())) 212a5182991SAlex Zinenko return Type(); 213a5182991SAlex Zinenko stack.insert(rec); 2145fe53c41SJohn Demme Type subtype = parseTestType(ctxt, parser, stack); 215a5182991SAlex Zinenko stack.pop_back(); 216a5182991SAlex Zinenko if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) 217a5182991SAlex Zinenko return Type(); 218a5182991SAlex Zinenko 219a5182991SAlex Zinenko return rec; 220a5182991SAlex Zinenko } 221a5182991SAlex Zinenko 222a5182991SAlex Zinenko Type TestDialect::parseType(DialectAsmParser &parser) const { 223a5182991SAlex Zinenko llvm::SetVector<Type> stack; 2245fe53c41SJohn Demme return parseTestType(getContext(), parser, stack); 225a5182991SAlex Zinenko } 226a5182991SAlex Zinenko 227a5182991SAlex Zinenko static void printTestType(Type type, DialectAsmPrinter &printer, 228a5182991SAlex Zinenko llvm::SetVector<Type> &stack) { 2295fe53c41SJohn Demme if (succeeded(generatedTypePrinter(type, printer))) 2305fe53c41SJohn Demme return; 231a5182991SAlex Zinenko if (type.isa<TestType>()) { 232a5182991SAlex Zinenko printer << "test_type"; 233a5182991SAlex Zinenko return; 234a5182991SAlex Zinenko } 235a5182991SAlex Zinenko 236a5182991SAlex Zinenko auto rec = type.cast<TestRecursiveType>(); 237a5182991SAlex Zinenko printer << "test_rec<" << rec.getName(); 238a5182991SAlex Zinenko if (!stack.contains(rec)) { 239a5182991SAlex Zinenko printer << ", "; 240a5182991SAlex Zinenko stack.insert(rec); 241a5182991SAlex Zinenko printTestType(rec.getBody(), printer, stack); 242a5182991SAlex Zinenko stack.pop_back(); 243a5182991SAlex Zinenko } 244a5182991SAlex Zinenko printer << ">"; 2452e2cdd0aSRiver Riddle } 2462e2cdd0aSRiver Riddle 2472e2cdd0aSRiver Riddle void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { 248a5182991SAlex Zinenko llvm::SetVector<Type> stack; 249a5182991SAlex Zinenko printTestType(type, printer, stack); 2502e2cdd0aSRiver Riddle } 2512e2cdd0aSRiver Riddle 252fec6c5acSUday Bondhugula LogicalResult TestDialect::verifyOperationAttribute(Operation *op, 253fec6c5acSUday Bondhugula NamedAttribute namedAttr) { 254fec6c5acSUday Bondhugula if (namedAttr.first == "test.invalid_attr") 255fec6c5acSUday Bondhugula return op->emitError() << "invalid to use 'test.invalid_attr'"; 256fec6c5acSUday Bondhugula return success(); 257fec6c5acSUday Bondhugula } 258fec6c5acSUday Bondhugula 259fec6c5acSUday Bondhugula LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, 260fec6c5acSUday Bondhugula unsigned regionIndex, 261fec6c5acSUday Bondhugula unsigned argIndex, 262fec6c5acSUday Bondhugula NamedAttribute namedAttr) { 263fec6c5acSUday Bondhugula if (namedAttr.first == "test.invalid_attr") 264fec6c5acSUday Bondhugula return op->emitError() << "invalid to use 'test.invalid_attr'"; 265fec6c5acSUday Bondhugula return success(); 266fec6c5acSUday Bondhugula } 267fec6c5acSUday Bondhugula 268fec6c5acSUday Bondhugula LogicalResult 269fec6c5acSUday Bondhugula TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, 270fec6c5acSUday Bondhugula unsigned resultIndex, 271fec6c5acSUday Bondhugula NamedAttribute namedAttr) { 272fec6c5acSUday Bondhugula if (namedAttr.first == "test.invalid_attr") 273fec6c5acSUday Bondhugula return op->emitError() << "invalid to use 'test.invalid_attr'"; 274fec6c5acSUday Bondhugula return success(); 275fec6c5acSUday Bondhugula } 276fec6c5acSUday Bondhugula 277fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 278fec6c5acSUday Bondhugula // TestBranchOp 279fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 280fec6c5acSUday Bondhugula 2810752d98cSRiver Riddle Optional<MutableOperandRange> 2820752d98cSRiver Riddle TestBranchOp::getMutableSuccessorOperands(unsigned index) { 283fec6c5acSUday Bondhugula assert(index == 0 && "invalid successor index"); 2840752d98cSRiver Riddle return targetOperandsMutable(); 285fec6c5acSUday Bondhugula } 286fec6c5acSUday Bondhugula 287fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 288f4ef77cbSRiver Riddle // TestFoldToCallOp 289f4ef77cbSRiver Riddle //===----------------------------------------------------------------------===// 290f4ef77cbSRiver Riddle 291f4ef77cbSRiver Riddle namespace { 292f4ef77cbSRiver Riddle struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 293f4ef77cbSRiver Riddle using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 294f4ef77cbSRiver Riddle 295f4ef77cbSRiver Riddle LogicalResult matchAndRewrite(FoldToCallOp op, 296f4ef77cbSRiver Riddle PatternRewriter &rewriter) const override { 29708e4f078SRahul Joshi rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(), 298f4ef77cbSRiver Riddle ValueRange()); 299f4ef77cbSRiver Riddle return success(); 300f4ef77cbSRiver Riddle } 301f4ef77cbSRiver Riddle }; 302f4ef77cbSRiver Riddle } // end anonymous namespace 303f4ef77cbSRiver Riddle 304f4ef77cbSRiver Riddle void FoldToCallOp::getCanonicalizationPatterns( 305f4ef77cbSRiver Riddle OwningRewritePatternList &results, MLIRContext *context) { 306f4ef77cbSRiver Riddle results.insert<FoldToCallOpPattern>(context); 307f4ef77cbSRiver Riddle } 308f4ef77cbSRiver Riddle 309f4ef77cbSRiver Riddle //===----------------------------------------------------------------------===// 31088c6e25eSRiver Riddle // Test Format* operations 31188c6e25eSRiver Riddle //===----------------------------------------------------------------------===// 31288c6e25eSRiver Riddle 31388c6e25eSRiver Riddle //===----------------------------------------------------------------------===// 31488c6e25eSRiver Riddle // Parsing 31588c6e25eSRiver Riddle 31688c6e25eSRiver Riddle static ParseResult parseCustomDirectiveOperands( 31788c6e25eSRiver Riddle OpAsmParser &parser, OpAsmParser::OperandType &operand, 31888c6e25eSRiver Riddle Optional<OpAsmParser::OperandType> &optOperand, 31988c6e25eSRiver Riddle SmallVectorImpl<OpAsmParser::OperandType> &varOperands) { 32088c6e25eSRiver Riddle if (parser.parseOperand(operand)) 32188c6e25eSRiver Riddle return failure(); 32288c6e25eSRiver Riddle if (succeeded(parser.parseOptionalComma())) { 32388c6e25eSRiver Riddle optOperand.emplace(); 32488c6e25eSRiver Riddle if (parser.parseOperand(*optOperand)) 32588c6e25eSRiver Riddle return failure(); 32688c6e25eSRiver Riddle } 32788c6e25eSRiver Riddle if (parser.parseArrow() || parser.parseLParen() || 32888c6e25eSRiver Riddle parser.parseOperandList(varOperands) || parser.parseRParen()) 32988c6e25eSRiver Riddle return failure(); 33088c6e25eSRiver Riddle return success(); 33188c6e25eSRiver Riddle } 33288c6e25eSRiver Riddle static ParseResult 33388c6e25eSRiver Riddle parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, 33488c6e25eSRiver Riddle Type &optOperandType, 33588c6e25eSRiver Riddle SmallVectorImpl<Type> &varOperandTypes) { 33688c6e25eSRiver Riddle if (parser.parseColon()) 33788c6e25eSRiver Riddle return failure(); 33888c6e25eSRiver Riddle 33988c6e25eSRiver Riddle if (parser.parseType(operandType)) 34088c6e25eSRiver Riddle return failure(); 34188c6e25eSRiver Riddle if (succeeded(parser.parseOptionalComma())) { 34288c6e25eSRiver Riddle if (parser.parseType(optOperandType)) 34388c6e25eSRiver Riddle return failure(); 34488c6e25eSRiver Riddle } 34588c6e25eSRiver Riddle if (parser.parseArrow() || parser.parseLParen() || 34688c6e25eSRiver Riddle parser.parseTypeList(varOperandTypes) || parser.parseRParen()) 34788c6e25eSRiver Riddle return failure(); 34888c6e25eSRiver Riddle return success(); 34988c6e25eSRiver Riddle } 35093fd30baSNicolas Vasilache static ParseResult 35193fd30baSNicolas Vasilache parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, 35293fd30baSNicolas Vasilache Type optOperandType, 35393fd30baSNicolas Vasilache const SmallVectorImpl<Type> &varOperandTypes) { 35493fd30baSNicolas Vasilache if (parser.parseKeyword("type_refs_capture")) 35593fd30baSNicolas Vasilache return failure(); 35693fd30baSNicolas Vasilache 35793fd30baSNicolas Vasilache Type operandType2, optOperandType2; 35893fd30baSNicolas Vasilache SmallVector<Type, 1> varOperandTypes2; 35993fd30baSNicolas Vasilache if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, 36093fd30baSNicolas Vasilache varOperandTypes2)) 36193fd30baSNicolas Vasilache return failure(); 36293fd30baSNicolas Vasilache 36393fd30baSNicolas Vasilache if (operandType != operandType2 || optOperandType != optOperandType2 || 36493fd30baSNicolas Vasilache varOperandTypes != varOperandTypes2) 36593fd30baSNicolas Vasilache return failure(); 36693fd30baSNicolas Vasilache 36793fd30baSNicolas Vasilache return success(); 36893fd30baSNicolas Vasilache } 36988c6e25eSRiver Riddle static ParseResult parseCustomDirectiveOperandsAndTypes( 37088c6e25eSRiver Riddle OpAsmParser &parser, OpAsmParser::OperandType &operand, 37188c6e25eSRiver Riddle Optional<OpAsmParser::OperandType> &optOperand, 37288c6e25eSRiver Riddle SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType, 37388c6e25eSRiver Riddle Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) { 37488c6e25eSRiver Riddle if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || 37588c6e25eSRiver Riddle parseCustomDirectiveResults(parser, operandType, optOperandType, 37688c6e25eSRiver Riddle varOperandTypes)) 37788c6e25eSRiver Riddle return failure(); 37888c6e25eSRiver Riddle return success(); 37988c6e25eSRiver Riddle } 380eaeadce9SRiver Riddle static ParseResult parseCustomDirectiveRegions( 381eaeadce9SRiver Riddle OpAsmParser &parser, Region ®ion, 382eaeadce9SRiver Riddle SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { 383eaeadce9SRiver Riddle if (parser.parseRegion(region)) 384eaeadce9SRiver Riddle return failure(); 385eaeadce9SRiver Riddle if (failed(parser.parseOptionalComma())) 386eaeadce9SRiver Riddle return success(); 387eaeadce9SRiver Riddle std::unique_ptr<Region> varRegion = std::make_unique<Region>(); 388eaeadce9SRiver Riddle if (parser.parseRegion(*varRegion)) 389eaeadce9SRiver Riddle return failure(); 390eaeadce9SRiver Riddle varRegions.emplace_back(std::move(varRegion)); 391eaeadce9SRiver Riddle return success(); 392eaeadce9SRiver Riddle } 39388c6e25eSRiver Riddle static ParseResult 39488c6e25eSRiver Riddle parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, 39588c6e25eSRiver Riddle SmallVectorImpl<Block *> &varSuccessors) { 39688c6e25eSRiver Riddle if (parser.parseSuccessor(successor)) 39788c6e25eSRiver Riddle return failure(); 39888c6e25eSRiver Riddle if (failed(parser.parseOptionalComma())) 39988c6e25eSRiver Riddle return success(); 40088c6e25eSRiver Riddle Block *varSuccessor; 40188c6e25eSRiver Riddle if (parser.parseSuccessor(varSuccessor)) 40288c6e25eSRiver Riddle return failure(); 40388c6e25eSRiver Riddle varSuccessors.append(2, varSuccessor); 40488c6e25eSRiver Riddle return success(); 40588c6e25eSRiver Riddle } 406d14cfe10SMike Urbach static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, 407d14cfe10SMike Urbach IntegerAttr &attr, 408d14cfe10SMike Urbach IntegerAttr &optAttr) { 409d14cfe10SMike Urbach if (parser.parseAttribute(attr)) 410d14cfe10SMike Urbach return failure(); 411d14cfe10SMike Urbach if (succeeded(parser.parseOptionalComma())) { 412d14cfe10SMike Urbach if (parser.parseAttribute(optAttr)) 413d14cfe10SMike Urbach return failure(); 414d14cfe10SMike Urbach } 415d14cfe10SMike Urbach return success(); 416d14cfe10SMike Urbach } 41788c6e25eSRiver Riddle 418035e12e6SJohn Demme static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, 419035e12e6SJohn Demme NamedAttrList &attrs) { 420035e12e6SJohn Demme return parser.parseOptionalAttrDict(attrs); 421035e12e6SJohn Demme } 422035e12e6SJohn Demme 42388c6e25eSRiver Riddle //===----------------------------------------------------------------------===// 42488c6e25eSRiver Riddle // Printing 42588c6e25eSRiver Riddle 426035e12e6SJohn Demme static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, 427035e12e6SJohn Demme Value operand, Value optOperand, 42888c6e25eSRiver Riddle OperandRange varOperands) { 42988c6e25eSRiver Riddle printer << operand; 43088c6e25eSRiver Riddle if (optOperand) 43188c6e25eSRiver Riddle printer << ", " << optOperand; 43288c6e25eSRiver Riddle printer << " -> (" << varOperands << ")"; 43388c6e25eSRiver Riddle } 434035e12e6SJohn Demme static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, 435035e12e6SJohn Demme Type operandType, Type optOperandType, 43688c6e25eSRiver Riddle TypeRange varOperandTypes) { 43788c6e25eSRiver Riddle printer << " : " << operandType; 43888c6e25eSRiver Riddle if (optOperandType) 43988c6e25eSRiver Riddle printer << ", " << optOperandType; 44088c6e25eSRiver Riddle printer << " -> (" << varOperandTypes << ")"; 44188c6e25eSRiver Riddle } 44293fd30baSNicolas Vasilache static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, 443035e12e6SJohn Demme Operation *op, Type operandType, 44493fd30baSNicolas Vasilache Type optOperandType, 44593fd30baSNicolas Vasilache TypeRange varOperandTypes) { 44693fd30baSNicolas Vasilache printer << " type_refs_capture "; 447035e12e6SJohn Demme printCustomDirectiveResults(printer, op, operandType, optOperandType, 44893fd30baSNicolas Vasilache varOperandTypes); 44993fd30baSNicolas Vasilache } 450035e12e6SJohn Demme static void printCustomDirectiveOperandsAndTypes( 451035e12e6SJohn Demme OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, 452035e12e6SJohn Demme OperandRange varOperands, Type operandType, Type optOperandType, 45388c6e25eSRiver Riddle TypeRange varOperandTypes) { 454035e12e6SJohn Demme printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); 455035e12e6SJohn Demme printCustomDirectiveResults(printer, op, operandType, optOperandType, 45688c6e25eSRiver Riddle varOperandTypes); 45788c6e25eSRiver Riddle } 458035e12e6SJohn Demme static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, 459035e12e6SJohn Demme Region ®ion, 460eaeadce9SRiver Riddle MutableArrayRef<Region> varRegions) { 461eaeadce9SRiver Riddle printer.printRegion(region); 462eaeadce9SRiver Riddle if (!varRegions.empty()) { 463eaeadce9SRiver Riddle printer << ", "; 464eaeadce9SRiver Riddle for (Region ®ion : varRegions) 465eaeadce9SRiver Riddle printer.printRegion(region); 466eaeadce9SRiver Riddle } 467eaeadce9SRiver Riddle } 468035e12e6SJohn Demme static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, 46988c6e25eSRiver Riddle Block *successor, 47088c6e25eSRiver Riddle SuccessorRange varSuccessors) { 47188c6e25eSRiver Riddle printer << successor; 47288c6e25eSRiver Riddle if (!varSuccessors.empty()) 47388c6e25eSRiver Riddle printer << ", " << varSuccessors.front(); 47488c6e25eSRiver Riddle } 475035e12e6SJohn Demme static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, 476d14cfe10SMike Urbach Attribute attribute, 477d14cfe10SMike Urbach Attribute optAttribute) { 478d14cfe10SMike Urbach printer << attribute; 479d14cfe10SMike Urbach if (optAttribute) 480d14cfe10SMike Urbach printer << ", " << optAttribute; 481d14cfe10SMike Urbach } 48288c6e25eSRiver Riddle 483035e12e6SJohn Demme static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, 484035e12e6SJohn Demme MutableDictionaryAttr attrs) { 485035e12e6SJohn Demme printer.printOptionalAttrDict(attrs.getAttrs()); 486035e12e6SJohn Demme } 48788c6e25eSRiver Riddle //===----------------------------------------------------------------------===// 488fec6c5acSUday Bondhugula // Test IsolatedRegionOp - parse passthrough region arguments. 489fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 490fec6c5acSUday Bondhugula 491fec6c5acSUday Bondhugula static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, 492fec6c5acSUday Bondhugula OperationState &result) { 493fec6c5acSUday Bondhugula OpAsmParser::OperandType argInfo; 494fec6c5acSUday Bondhugula Type argType = parser.getBuilder().getIndexType(); 495fec6c5acSUday Bondhugula 496fec6c5acSUday Bondhugula // Parse the input operand. 497fec6c5acSUday Bondhugula if (parser.parseOperand(argInfo) || 498fec6c5acSUday Bondhugula parser.resolveOperand(argInfo, argType, result.operands)) 499fec6c5acSUday Bondhugula return failure(); 500fec6c5acSUday Bondhugula 501fec6c5acSUday Bondhugula // Parse the body region, and reuse the operand info as the argument info. 502fec6c5acSUday Bondhugula Region *body = result.addRegion(); 503fec6c5acSUday Bondhugula return parser.parseRegion(*body, argInfo, argType, 504fec6c5acSUday Bondhugula /*enableNameShadowing=*/true); 505fec6c5acSUday Bondhugula } 506fec6c5acSUday Bondhugula 507fec6c5acSUday Bondhugula static void print(OpAsmPrinter &p, IsolatedRegionOp op) { 508fec6c5acSUday Bondhugula p << "test.isolated_region "; 509fec6c5acSUday Bondhugula p.printOperand(op.getOperand()); 510fec6c5acSUday Bondhugula p.shadowRegionArgs(op.region(), op.getOperand()); 511fec6c5acSUday Bondhugula p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 512fec6c5acSUday Bondhugula } 513fec6c5acSUday Bondhugula 514fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 51562828865SStephen Neuendorffer // Test SSACFGRegionOp 51662828865SStephen Neuendorffer //===----------------------------------------------------------------------===// 51762828865SStephen Neuendorffer 51862828865SStephen Neuendorffer RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 51962828865SStephen Neuendorffer return RegionKind::SSACFG; 52062828865SStephen Neuendorffer } 52162828865SStephen Neuendorffer 52262828865SStephen Neuendorffer //===----------------------------------------------------------------------===// 52362828865SStephen Neuendorffer // Test GraphRegionOp 52462828865SStephen Neuendorffer //===----------------------------------------------------------------------===// 52562828865SStephen Neuendorffer 52662828865SStephen Neuendorffer static ParseResult parseGraphRegionOp(OpAsmParser &parser, 52762828865SStephen Neuendorffer OperationState &result) { 52862828865SStephen Neuendorffer // Parse the body region, and reuse the operand info as the argument info. 52962828865SStephen Neuendorffer Region *body = result.addRegion(); 53062828865SStephen Neuendorffer return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 53162828865SStephen Neuendorffer } 53262828865SStephen Neuendorffer 53362828865SStephen Neuendorffer static void print(OpAsmPrinter &p, GraphRegionOp op) { 53462828865SStephen Neuendorffer p << "test.graph_region "; 53562828865SStephen Neuendorffer p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 53662828865SStephen Neuendorffer } 53762828865SStephen Neuendorffer 53862828865SStephen Neuendorffer RegionKind GraphRegionOp::getRegionKind(unsigned index) { 53962828865SStephen Neuendorffer return RegionKind::Graph; 54062828865SStephen Neuendorffer } 54162828865SStephen Neuendorffer 54262828865SStephen Neuendorffer //===----------------------------------------------------------------------===// 54357d361bdSUday Bondhugula // Test AffineScopeOp 54448034538SUday Bondhugula //===----------------------------------------------------------------------===// 54548034538SUday Bondhugula 54657d361bdSUday Bondhugula static ParseResult parseAffineScopeOp(OpAsmParser &parser, 54748034538SUday Bondhugula OperationState &result) { 54848034538SUday Bondhugula // Parse the body region, and reuse the operand info as the argument info. 54948034538SUday Bondhugula Region *body = result.addRegion(); 55048034538SUday Bondhugula return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 55148034538SUday Bondhugula } 55248034538SUday Bondhugula 55357d361bdSUday Bondhugula static void print(OpAsmPrinter &p, AffineScopeOp op) { 55457d361bdSUday Bondhugula p << "test.affine_scope "; 55548034538SUday Bondhugula p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 55648034538SUday Bondhugula } 55748034538SUday Bondhugula 55848034538SUday Bondhugula //===----------------------------------------------------------------------===// 559fec6c5acSUday Bondhugula // Test parser. 560fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 561fec6c5acSUday Bondhugula 562fec6c5acSUday Bondhugula static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, 563fec6c5acSUday Bondhugula OperationState &result) { 564fec6c5acSUday Bondhugula StringRef keyword; 565fec6c5acSUday Bondhugula if (parser.parseKeyword(&keyword)) 566fec6c5acSUday Bondhugula return failure(); 567fec6c5acSUday Bondhugula result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); 568fec6c5acSUday Bondhugula return success(); 569fec6c5acSUday Bondhugula } 570fec6c5acSUday Bondhugula 571fec6c5acSUday Bondhugula static void print(OpAsmPrinter &p, WrappedKeywordOp op) { 572fec6c5acSUday Bondhugula p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); 573fec6c5acSUday Bondhugula } 574fec6c5acSUday Bondhugula 575fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 576fec6c5acSUday Bondhugula // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 577fec6c5acSUday Bondhugula 578fec6c5acSUday Bondhugula static ParseResult parseWrappingRegionOp(OpAsmParser &parser, 579fec6c5acSUday Bondhugula OperationState &result) { 580fec6c5acSUday Bondhugula if (parser.parseKeyword("wraps")) 581fec6c5acSUday Bondhugula return failure(); 582fec6c5acSUday Bondhugula 583fec6c5acSUday Bondhugula // Parse the wrapped op in a region 584fec6c5acSUday Bondhugula Region &body = *result.addRegion(); 585fec6c5acSUday Bondhugula body.push_back(new Block); 586fec6c5acSUday Bondhugula Block &block = body.back(); 587fec6c5acSUday Bondhugula Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); 588fec6c5acSUday Bondhugula if (!wrapped_op) 589fec6c5acSUday Bondhugula return failure(); 590fec6c5acSUday Bondhugula 591fec6c5acSUday Bondhugula // Create a return terminator in the inner region, pass as operand to the 592fec6c5acSUday Bondhugula // terminator the returned values from the wrapped operation. 593fec6c5acSUday Bondhugula SmallVector<Value, 8> return_operands(wrapped_op->getResults()); 594fec6c5acSUday Bondhugula OpBuilder builder(parser.getBuilder().getContext()); 595fec6c5acSUday Bondhugula builder.setInsertionPointToEnd(&block); 596fec6c5acSUday Bondhugula builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); 597fec6c5acSUday Bondhugula 598fec6c5acSUday Bondhugula // Get the results type for the wrapping op from the terminator operands. 599fec6c5acSUday Bondhugula Operation &return_op = body.back().back(); 600fec6c5acSUday Bondhugula result.types.append(return_op.operand_type_begin(), 601fec6c5acSUday Bondhugula return_op.operand_type_end()); 602fec6c5acSUday Bondhugula 603fec6c5acSUday Bondhugula // Use the location of the wrapped op for the "test.wrapping_region" op. 604fec6c5acSUday Bondhugula result.location = wrapped_op->getLoc(); 605fec6c5acSUday Bondhugula 606fec6c5acSUday Bondhugula return success(); 607fec6c5acSUday Bondhugula } 608fec6c5acSUday Bondhugula 609fec6c5acSUday Bondhugula static void print(OpAsmPrinter &p, WrappingRegionOp op) { 610fec6c5acSUday Bondhugula p << op.getOperationName() << " wraps "; 611fec6c5acSUday Bondhugula p.printGenericOp(&op.region().front().front()); 612fec6c5acSUday Bondhugula } 613fec6c5acSUday Bondhugula 614fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 615fec6c5acSUday Bondhugula // Test PolyForOp - parse list of region arguments. 616fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 617fec6c5acSUday Bondhugula 618fec6c5acSUday Bondhugula static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { 619fec6c5acSUday Bondhugula SmallVector<OpAsmParser::OperandType, 4> ivsInfo; 620fec6c5acSUday Bondhugula // Parse list of region arguments without a delimiter. 621fec6c5acSUday Bondhugula if (parser.parseRegionArgumentList(ivsInfo)) 622fec6c5acSUday Bondhugula return failure(); 623fec6c5acSUday Bondhugula 624fec6c5acSUday Bondhugula // Parse the body region. 625fec6c5acSUday Bondhugula Region *body = result.addRegion(); 626fec6c5acSUday Bondhugula auto &builder = parser.getBuilder(); 627fec6c5acSUday Bondhugula SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType()); 628fec6c5acSUday Bondhugula return parser.parseRegion(*body, ivsInfo, argTypes); 629fec6c5acSUday Bondhugula } 630fec6c5acSUday Bondhugula 631fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 632fec6c5acSUday Bondhugula // Test removing op with inner ops. 633fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 634fec6c5acSUday Bondhugula 635fec6c5acSUday Bondhugula namespace { 636fec6c5acSUday Bondhugula struct TestRemoveOpWithInnerOps 637fec6c5acSUday Bondhugula : public OpRewritePattern<TestOpWithRegionPattern> { 638fec6c5acSUday Bondhugula using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 639fec6c5acSUday Bondhugula 640fec6c5acSUday Bondhugula LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 641fec6c5acSUday Bondhugula PatternRewriter &rewriter) const override { 642fec6c5acSUday Bondhugula rewriter.eraseOp(op); 643fec6c5acSUday Bondhugula return success(); 644fec6c5acSUday Bondhugula } 645fec6c5acSUday Bondhugula }; 646fec6c5acSUday Bondhugula } // end anonymous namespace 647fec6c5acSUday Bondhugula 648fec6c5acSUday Bondhugula void TestOpWithRegionPattern::getCanonicalizationPatterns( 649fec6c5acSUday Bondhugula OwningRewritePatternList &results, MLIRContext *context) { 650fec6c5acSUday Bondhugula results.insert<TestRemoveOpWithInnerOps>(context); 651fec6c5acSUday Bondhugula } 652fec6c5acSUday Bondhugula 653fec6c5acSUday Bondhugula OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) { 654fec6c5acSUday Bondhugula return operand(); 655fec6c5acSUday Bondhugula } 656fec6c5acSUday Bondhugula 6572bf423b0SRob Suderman OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) { 6582bf423b0SRob Suderman return getValue(); 6592bf423b0SRob Suderman } 6602bf423b0SRob Suderman 661fec6c5acSUday Bondhugula LogicalResult TestOpWithVariadicResultsAndFolder::fold( 662fec6c5acSUday Bondhugula ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) { 663fec6c5acSUday Bondhugula for (Value input : this->operands()) { 664fec6c5acSUday Bondhugula results.push_back(input); 665fec6c5acSUday Bondhugula } 666fec6c5acSUday Bondhugula return success(); 667fec6c5acSUday Bondhugula } 668fec6c5acSUday Bondhugula 66926f93d9fSAlex Zinenko OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) { 67026f93d9fSAlex Zinenko assert(operands.size() == 1); 67126f93d9fSAlex Zinenko if (operands.front()) { 67226f93d9fSAlex Zinenko setAttr("attr", operands.front()); 67326f93d9fSAlex Zinenko return getResult(); 67426f93d9fSAlex Zinenko } 67526f93d9fSAlex Zinenko return {}; 67626f93d9fSAlex Zinenko } 67726f93d9fSAlex Zinenko 67862828865SStephen Neuendorffer LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 679fec6c5acSUday Bondhugula MLIRContext *, Optional<Location> location, ValueRange operands, 6805eae715aSJacques Pienaar DictionaryAttr attributes, RegionRange regions, 681fec6c5acSUday Bondhugula SmallVectorImpl<Type> &inferredReturnTypes) { 682fec6c5acSUday Bondhugula if (operands[0].getType() != operands[1].getType()) { 683fec6c5acSUday Bondhugula return emitOptionalError(location, "operand type mismatch ", 684fec6c5acSUday Bondhugula operands[0].getType(), " vs ", 685fec6c5acSUday Bondhugula operands[1].getType()); 686fec6c5acSUday Bondhugula } 687fec6c5acSUday Bondhugula inferredReturnTypes.assign({operands[0].getType()}); 688fec6c5acSUday Bondhugula return success(); 689fec6c5acSUday Bondhugula } 690fec6c5acSUday Bondhugula 691fec6c5acSUday Bondhugula LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 692fec6c5acSUday Bondhugula MLIRContext *context, Optional<Location> location, ValueRange operands, 6935eae715aSJacques Pienaar DictionaryAttr attributes, RegionRange regions, 694fec6c5acSUday Bondhugula SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 695fec6c5acSUday Bondhugula // Create return type consisting of the last element of the first operand. 696fec6c5acSUday Bondhugula auto operandType = *operands.getTypes().begin(); 697fec6c5acSUday Bondhugula auto sval = operandType.dyn_cast<ShapedType>(); 698fec6c5acSUday Bondhugula if (!sval) { 699fec6c5acSUday Bondhugula return emitOptionalError(location, "only shaped type operands allowed"); 700fec6c5acSUday Bondhugula } 701fec6c5acSUday Bondhugula int64_t dim = 702fec6c5acSUday Bondhugula sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; 703fec6c5acSUday Bondhugula auto type = IntegerType::get(17, context); 704fec6c5acSUday Bondhugula inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); 705fec6c5acSUday Bondhugula return success(); 706fec6c5acSUday Bondhugula } 707fec6c5acSUday Bondhugula 708fec6c5acSUday Bondhugula LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 709fec6c5acSUday Bondhugula OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) { 710fec6c5acSUday Bondhugula shapes = SmallVector<Value, 1>{ 71162828865SStephen Neuendorffer builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)}; 712fec6c5acSUday Bondhugula return success(); 713fec6c5acSUday Bondhugula } 714fec6c5acSUday Bondhugula 715fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 716fec6c5acSUday Bondhugula // Test SideEffect interfaces 717fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 718fec6c5acSUday Bondhugula 719fec6c5acSUday Bondhugula namespace { 720fec6c5acSUday Bondhugula /// A test resource for side effects. 721fec6c5acSUday Bondhugula struct TestResource : public SideEffects::Resource::Base<TestResource> { 722fec6c5acSUday Bondhugula StringRef getName() final { return "<Test>"; } 723fec6c5acSUday Bondhugula }; 724fec6c5acSUday Bondhugula } // end anonymous namespace 725fec6c5acSUday Bondhugula 726fec6c5acSUday Bondhugula void SideEffectOp::getEffects( 727fec6c5acSUday Bondhugula SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 728fec6c5acSUday Bondhugula // Check for an effects attribute on the op instance. 729fec6c5acSUday Bondhugula ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects"); 730fec6c5acSUday Bondhugula if (!effectsAttr) 731fec6c5acSUday Bondhugula return; 732fec6c5acSUday Bondhugula 733fec6c5acSUday Bondhugula // If there is one, it is an array of dictionary attributes that hold 734fec6c5acSUday Bondhugula // information on the effects of this operation. 735fec6c5acSUday Bondhugula for (Attribute element : effectsAttr) { 736fec6c5acSUday Bondhugula DictionaryAttr effectElement = element.cast<DictionaryAttr>(); 737fec6c5acSUday Bondhugula 738fec6c5acSUday Bondhugula // Get the specific memory effect. 739fec6c5acSUday Bondhugula MemoryEffects::Effect *effect = 740cc83dc19SChristian Sigg StringSwitch<MemoryEffects::Effect *>( 741fec6c5acSUday Bondhugula effectElement.get("effect").cast<StringAttr>().getValue()) 742fec6c5acSUday Bondhugula .Case("allocate", MemoryEffects::Allocate::get()) 743fec6c5acSUday Bondhugula .Case("free", MemoryEffects::Free::get()) 744fec6c5acSUday Bondhugula .Case("read", MemoryEffects::Read::get()) 745fec6c5acSUday Bondhugula .Case("write", MemoryEffects::Write::get()); 746fec6c5acSUday Bondhugula 747fec6c5acSUday Bondhugula // Check for a result to affect. 748fec6c5acSUday Bondhugula Value value; 749fec6c5acSUday Bondhugula if (effectElement.get("on_result")) 750fec6c5acSUday Bondhugula value = getResult(); 751fec6c5acSUday Bondhugula 752fec6c5acSUday Bondhugula // Check for a non-default resource to use. 753fec6c5acSUday Bondhugula SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 754fec6c5acSUday Bondhugula if (effectElement.get("test_resource")) 755fec6c5acSUday Bondhugula resource = TestResource::get(); 756fec6c5acSUday Bondhugula 757fec6c5acSUday Bondhugula effects.emplace_back(effect, value, resource); 758fec6c5acSUday Bondhugula } 759fec6c5acSUday Bondhugula } 760fec6c5acSUday Bondhugula 761*052d24afSAlex Zinenko void SideEffectOp::getEffects( 762*052d24afSAlex Zinenko SmallVectorImpl<TestEffects::EffectInstance> &effects) { 763*052d24afSAlex Zinenko auto effectsAttr = getAttrOfType<AffineMapAttr>("effect_parameter"); 764*052d24afSAlex Zinenko if (!effectsAttr) 765*052d24afSAlex Zinenko return; 766*052d24afSAlex Zinenko 767*052d24afSAlex Zinenko effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); 768*052d24afSAlex Zinenko } 769*052d24afSAlex Zinenko 770fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 771fec6c5acSUday Bondhugula // StringAttrPrettyNameOp 772fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 773fec6c5acSUday Bondhugula 774fec6c5acSUday Bondhugula // This op has fancy handling of its SSA result name. 775fec6c5acSUday Bondhugula static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, 776fec6c5acSUday Bondhugula OperationState &result) { 777fec6c5acSUday Bondhugula // Add the result types. 778fec6c5acSUday Bondhugula for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 779fec6c5acSUday Bondhugula result.addTypes(parser.getBuilder().getIntegerType(32)); 780fec6c5acSUday Bondhugula 781fec6c5acSUday Bondhugula if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 782fec6c5acSUday Bondhugula return failure(); 783fec6c5acSUday Bondhugula 784fec6c5acSUday Bondhugula // If the attribute dictionary contains no 'names' attribute, infer it from 785fec6c5acSUday Bondhugula // the SSA name (if specified). 786fec6c5acSUday Bondhugula bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 78774e6a5b2SChris Lattner return attr.first == "names"; 788fec6c5acSUday Bondhugula }); 789fec6c5acSUday Bondhugula 790fec6c5acSUday Bondhugula // If there was no name specified, check to see if there was a useful name 791fec6c5acSUday Bondhugula // specified in the asm file. 792fec6c5acSUday Bondhugula if (hadNames || parser.getNumResults() == 0) 793fec6c5acSUday Bondhugula return success(); 794fec6c5acSUday Bondhugula 795fec6c5acSUday Bondhugula SmallVector<StringRef, 4> names; 796fec6c5acSUday Bondhugula auto *context = result.getContext(); 797fec6c5acSUday Bondhugula 798fec6c5acSUday Bondhugula for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 799fec6c5acSUday Bondhugula auto resultName = parser.getResultName(i); 800fec6c5acSUday Bondhugula StringRef nameStr; 801fec6c5acSUday Bondhugula if (!resultName.first.empty() && !isdigit(resultName.first[0])) 802fec6c5acSUday Bondhugula nameStr = resultName.first; 803fec6c5acSUday Bondhugula 804fec6c5acSUday Bondhugula names.push_back(nameStr); 805fec6c5acSUday Bondhugula } 806fec6c5acSUday Bondhugula 807fec6c5acSUday Bondhugula auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 808fec6c5acSUday Bondhugula result.attributes.push_back({Identifier::get("names", context), namesAttr}); 809fec6c5acSUday Bondhugula return success(); 810fec6c5acSUday Bondhugula } 811fec6c5acSUday Bondhugula 812fec6c5acSUday Bondhugula static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { 813fec6c5acSUday Bondhugula p << "test.string_attr_pretty_name"; 814fec6c5acSUday Bondhugula 815fec6c5acSUday Bondhugula // Note that we only need to print the "name" attribute if the asmprinter 816fec6c5acSUday Bondhugula // result name disagrees with it. This can happen in strange cases, e.g. 817fec6c5acSUday Bondhugula // when there are conflicts. 818fec6c5acSUday Bondhugula bool namesDisagree = op.names().size() != op.getNumResults(); 819fec6c5acSUday Bondhugula 820fec6c5acSUday Bondhugula SmallString<32> resultNameStr; 821fec6c5acSUday Bondhugula for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { 822fec6c5acSUday Bondhugula resultNameStr.clear(); 823fec6c5acSUday Bondhugula llvm::raw_svector_ostream tmpStream(resultNameStr); 824fec6c5acSUday Bondhugula p.printOperand(op.getResult(i), tmpStream); 825fec6c5acSUday Bondhugula 826fec6c5acSUday Bondhugula auto expectedName = op.names()[i].dyn_cast<StringAttr>(); 827fec6c5acSUday Bondhugula if (!expectedName || 828fec6c5acSUday Bondhugula tmpStream.str().drop_front() != expectedName.getValue()) { 829fec6c5acSUday Bondhugula namesDisagree = true; 830fec6c5acSUday Bondhugula } 831fec6c5acSUday Bondhugula } 832fec6c5acSUday Bondhugula 833fec6c5acSUday Bondhugula if (namesDisagree) 834fec6c5acSUday Bondhugula p.printOptionalAttrDictWithKeyword(op.getAttrs()); 835fec6c5acSUday Bondhugula else 836fec6c5acSUday Bondhugula p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"}); 837fec6c5acSUday Bondhugula } 838fec6c5acSUday Bondhugula 839fec6c5acSUday Bondhugula // We set the SSA name in the asm syntax to the contents of the name 840fec6c5acSUday Bondhugula // attribute. 841fec6c5acSUday Bondhugula void StringAttrPrettyNameOp::getAsmResultNames( 842fec6c5acSUday Bondhugula function_ref<void(Value, StringRef)> setNameFn) { 843fec6c5acSUday Bondhugula 844fec6c5acSUday Bondhugula auto value = names(); 845fec6c5acSUday Bondhugula for (size_t i = 0, e = value.size(); i != e; ++i) 846fec6c5acSUday Bondhugula if (auto str = value[i].dyn_cast<StringAttr>()) 847fec6c5acSUday Bondhugula if (!str.getValue().empty()) 848fec6c5acSUday Bondhugula setNameFn(getResult(i), str.getValue()); 849fec6c5acSUday Bondhugula } 850fec6c5acSUday Bondhugula 851fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 8526f5da84fSMarcel Koester // RegionIfOp 8536f5da84fSMarcel Koester //===----------------------------------------------------------------------===// 8546f5da84fSMarcel Koester 8556f5da84fSMarcel Koester static void print(OpAsmPrinter &p, RegionIfOp op) { 8566f5da84fSMarcel Koester p << RegionIfOp::getOperationName() << " "; 8576f5da84fSMarcel Koester p.printOperands(op.getOperands()); 8586f5da84fSMarcel Koester p << ": " << op.getOperandTypes(); 8596f5da84fSMarcel Koester p.printArrowTypeList(op.getResultTypes()); 8606f5da84fSMarcel Koester p << " then"; 8616f5da84fSMarcel Koester p.printRegion(op.thenRegion(), 8626f5da84fSMarcel Koester /*printEntryBlockArgs=*/true, 8636f5da84fSMarcel Koester /*printBlockTerminators=*/true); 8646f5da84fSMarcel Koester p << " else"; 8656f5da84fSMarcel Koester p.printRegion(op.elseRegion(), 8666f5da84fSMarcel Koester /*printEntryBlockArgs=*/true, 8676f5da84fSMarcel Koester /*printBlockTerminators=*/true); 8686f5da84fSMarcel Koester p << " join"; 8696f5da84fSMarcel Koester p.printRegion(op.joinRegion(), 8706f5da84fSMarcel Koester /*printEntryBlockArgs=*/true, 8716f5da84fSMarcel Koester /*printBlockTerminators=*/true); 8726f5da84fSMarcel Koester } 8736f5da84fSMarcel Koester 8746f5da84fSMarcel Koester static ParseResult parseRegionIfOp(OpAsmParser &parser, 8756f5da84fSMarcel Koester OperationState &result) { 8766f5da84fSMarcel Koester SmallVector<OpAsmParser::OperandType, 2> operandInfos; 8776f5da84fSMarcel Koester SmallVector<Type, 2> operandTypes; 8786f5da84fSMarcel Koester 8796f5da84fSMarcel Koester result.regions.reserve(3); 8806f5da84fSMarcel Koester Region *thenRegion = result.addRegion(); 8816f5da84fSMarcel Koester Region *elseRegion = result.addRegion(); 8826f5da84fSMarcel Koester Region *joinRegion = result.addRegion(); 8836f5da84fSMarcel Koester 8846f5da84fSMarcel Koester // Parse operand, type and arrow type lists. 8856f5da84fSMarcel Koester if (parser.parseOperandList(operandInfos) || 8866f5da84fSMarcel Koester parser.parseColonTypeList(operandTypes) || 8876f5da84fSMarcel Koester parser.parseArrowTypeList(result.types)) 8886f5da84fSMarcel Koester return failure(); 8896f5da84fSMarcel Koester 8906f5da84fSMarcel Koester // Parse all attached regions. 8916f5da84fSMarcel Koester if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 8926f5da84fSMarcel Koester parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 8936f5da84fSMarcel Koester parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 8946f5da84fSMarcel Koester return failure(); 8956f5da84fSMarcel Koester 8966f5da84fSMarcel Koester return parser.resolveOperands(operandInfos, operandTypes, 8976f5da84fSMarcel Koester parser.getCurrentLocation(), result.operands); 8986f5da84fSMarcel Koester } 8996f5da84fSMarcel Koester 9006f5da84fSMarcel Koester OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { 9016f5da84fSMarcel Koester assert(index < 2 && "invalid region index"); 9026f5da84fSMarcel Koester return getOperands(); 9036f5da84fSMarcel Koester } 9046f5da84fSMarcel Koester 9056f5da84fSMarcel Koester void RegionIfOp::getSuccessorRegions( 9066f5da84fSMarcel Koester Optional<unsigned> index, ArrayRef<Attribute> operands, 9076f5da84fSMarcel Koester SmallVectorImpl<RegionSuccessor> ®ions) { 9086f5da84fSMarcel Koester // We always branch to the join region. 9096f5da84fSMarcel Koester if (index.hasValue()) { 9106f5da84fSMarcel Koester if (index.getValue() < 2) 9116f5da84fSMarcel Koester regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); 9126f5da84fSMarcel Koester else 9136f5da84fSMarcel Koester regions.push_back(RegionSuccessor(getResults())); 9146f5da84fSMarcel Koester return; 9156f5da84fSMarcel Koester } 9166f5da84fSMarcel Koester 9176f5da84fSMarcel Koester // The then and else regions are the entry regions of this op. 9186f5da84fSMarcel Koester regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); 9196f5da84fSMarcel Koester regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); 9206f5da84fSMarcel Koester } 9216f5da84fSMarcel Koester 922fec6c5acSUday Bondhugula #include "TestOpEnums.cpp.inc" 923*052d24afSAlex Zinenko #include "TestOpInterfaces.cpp.inc" 9249c9f479aSSean Silva #include "TestOpStructs.cpp.inc" 9252e2cdd0aSRiver Riddle #include "TestTypeInterfaces.cpp.inc" 926fec6c5acSUday Bondhugula 927fec6c5acSUday Bondhugula #define GET_OP_CLASSES 928fec6c5acSUday Bondhugula #include "TestOps.cpp.inc" 929