1 //===- Async.cpp - MLIR Async Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Async/IR/Async.h" 10 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/Traits.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/DialectImplementation.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/IR/StandardTypes.h" 17 #include "mlir/Transforms/InliningUtils.h" 18 #include "llvm/ADT/SmallString.h" 19 #include "llvm/ADT/TypeSwitch.h" 20 #include "llvm/Support/raw_ostream.h" 21 22 namespace mlir { 23 namespace async { 24 25 void AsyncDialect::initialize() { 26 addOperations< 27 #define GET_OP_LIST 28 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 29 >(); 30 addTypes<TokenType>(); 31 addTypes<ValueType>(); 32 } 33 34 /// Parse a type registered to this dialect. 35 Type AsyncDialect::parseType(DialectAsmParser &parser) const { 36 StringRef keyword; 37 if (parser.parseKeyword(&keyword)) 38 return Type(); 39 40 if (keyword == "token") 41 return TokenType::get(getContext()); 42 43 if (keyword == "value") { 44 Type ty; 45 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) { 46 parser.emitError(parser.getNameLoc(), "failed to parse async value type"); 47 return Type(); 48 } 49 return ValueType::get(ty); 50 } 51 52 parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword; 53 return Type(); 54 } 55 56 /// Print a type registered to this dialect. 57 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const { 58 TypeSwitch<Type>(type) 59 .Case<TokenType>([&](TokenType) { os << "token"; }) 60 .Case<ValueType>([&](ValueType valueTy) { 61 os << "value<"; 62 os.printType(valueTy.getValueType()); 63 os << '>'; 64 }) 65 .Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); }); 66 } 67 68 //===----------------------------------------------------------------------===// 69 /// ValueType 70 //===----------------------------------------------------------------------===// 71 72 namespace detail { 73 74 // Storage for `async.value<T>` type, the only member is the wrapped type. 75 struct ValueTypeStorage : public TypeStorage { 76 ValueTypeStorage(Type valueType) : valueType(valueType) {} 77 78 /// The hash key used for uniquing. 79 using KeyTy = Type; 80 bool operator==(const KeyTy &key) const { return key == valueType; } 81 82 /// Construction. 83 static ValueTypeStorage *construct(TypeStorageAllocator &allocator, 84 Type valueType) { 85 return new (allocator.allocate<ValueTypeStorage>()) 86 ValueTypeStorage(valueType); 87 } 88 89 Type valueType; 90 }; 91 92 } // namespace detail 93 94 ValueType ValueType::get(Type valueType) { 95 return Base::get(valueType.getContext(), valueType); 96 } 97 98 Type ValueType::getValueType() { return getImpl()->valueType; } 99 100 //===----------------------------------------------------------------------===// 101 // YieldOp 102 //===----------------------------------------------------------------------===// 103 104 static LogicalResult verify(YieldOp op) { 105 // Get the underlying value types from async values returned from the 106 // parent `async.execute` operation. 107 auto executeOp = op.getParentOfType<ExecuteOp>(); 108 auto types = llvm::map_range(executeOp.values(), [](const OpResult &result) { 109 return result.getType().cast<ValueType>().getValueType(); 110 }); 111 112 if (!std::equal(types.begin(), types.end(), op.getOperandTypes().begin())) 113 return op.emitOpError("Operand types do not match the types returned from " 114 "the parent ExecuteOp"); 115 116 return success(); 117 } 118 119 //===----------------------------------------------------------------------===// 120 /// ExecuteOp 121 //===----------------------------------------------------------------------===// 122 123 static void print(OpAsmPrinter &p, ExecuteOp op) { 124 p << "async.execute "; 125 p.printRegion(op.body()); 126 p.printOptionalAttrDict(op.getAttrs()); 127 p << " : "; 128 p.printType(op.done().getType()); 129 if (!op.values().empty()) 130 p << ", "; 131 llvm::interleaveComma(op.values(), p, [&](const OpResult &result) { 132 p.printType(result.getType()); 133 }); 134 } 135 136 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) { 137 MLIRContext *ctx = result.getContext(); 138 139 // Parse asynchronous region. 140 Region *body = result.addRegion(); 141 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}, 142 /*enableNameShadowing=*/false)) 143 return failure(); 144 145 // Parse operation attributes. 146 NamedAttrList attrs; 147 if (parser.parseOptionalAttrDict(attrs)) 148 return failure(); 149 result.addAttributes(attrs); 150 151 // Parse result types. 152 SmallVector<Type, 4> resultTypes; 153 if (parser.parseColonTypeList(resultTypes)) 154 return failure(); 155 156 // First result type must be an async token type. 157 if (resultTypes.empty() || resultTypes.front() != TokenType::get(ctx)) 158 return failure(); 159 parser.addTypesToList(resultTypes, result.types); 160 161 return success(); 162 } 163 164 } // namespace async 165 } // namespace mlir 166 167 #define GET_OP_CLASSES 168 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 169