1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Async/IR/Async.h"
10 
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Traits.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/DialectImplementation.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/StandardTypes.h"
17 #include "mlir/Transforms/InliningUtils.h"
18 #include "llvm/ADT/SmallString.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/Support/raw_ostream.h"
21 
22 namespace mlir {
23 namespace async {
24 
25 void AsyncDialect::initialize() {
26   addOperations<
27 #define GET_OP_LIST
28 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
29       >();
30   addTypes<TokenType>();
31   addTypes<ValueType>();
32 }
33 
34 /// Parse a type registered to this dialect.
35 Type AsyncDialect::parseType(DialectAsmParser &parser) const {
36   StringRef keyword;
37   if (parser.parseKeyword(&keyword))
38     return Type();
39 
40   if (keyword == "token")
41     return TokenType::get(getContext());
42 
43   if (keyword == "value") {
44     Type ty;
45     if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
46       parser.emitError(parser.getNameLoc(), "failed to parse async value type");
47       return Type();
48     }
49     return ValueType::get(ty);
50   }
51 
52   parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword;
53   return Type();
54 }
55 
56 /// Print a type registered to this dialect.
57 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
58   TypeSwitch<Type>(type)
59       .Case<TokenType>([&](TokenType) { os << "token"; })
60       .Case<ValueType>([&](ValueType valueTy) {
61         os << "value<";
62         os.printType(valueTy.getValueType());
63         os << '>';
64       })
65       .Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); });
66 }
67 
68 //===----------------------------------------------------------------------===//
// ValueType
70 //===----------------------------------------------------------------------===//
71 
72 namespace detail {
73 
74 // Storage for `async.value<T>` type, the only member is the wrapped type.
75 struct ValueTypeStorage : public TypeStorage {
76   ValueTypeStorage(Type valueType) : valueType(valueType) {}
77 
78   /// The hash key used for uniquing.
79   using KeyTy = Type;
80   bool operator==(const KeyTy &key) const { return key == valueType; }
81 
82   /// Construction.
83   static ValueTypeStorage *construct(TypeStorageAllocator &allocator,
84                                      Type valueType) {
85     return new (allocator.allocate<ValueTypeStorage>())
86         ValueTypeStorage(valueType);
87   }
88 
89   Type valueType;
90 };
91 
92 } // namespace detail
93 
94 ValueType ValueType::get(Type valueType) {
95   return Base::get(valueType.getContext(), valueType);
96 }
97 
98 Type ValueType::getValueType() { return getImpl()->valueType; }
99 
100 //===----------------------------------------------------------------------===//
101 // YieldOp
102 //===----------------------------------------------------------------------===//
103 
104 static LogicalResult verify(YieldOp op) {
105   // Get the underlying value types from async values returned from the
106   // parent `async.execute` operation.
107   auto executeOp = op.getParentOfType<ExecuteOp>();
108   auto types = llvm::map_range(executeOp.values(), [](const OpResult &result) {
109     return result.getType().cast<ValueType>().getValueType();
110   });
111 
112   if (!std::equal(types.begin(), types.end(), op.getOperandTypes().begin()))
113     return op.emitOpError("Operand types do not match the types returned from "
114                           "the parent ExecuteOp");
115 
116   return success();
117 }
118 
119 //===----------------------------------------------------------------------===//
// ExecuteOp
121 //===----------------------------------------------------------------------===//
122 
123 static void print(OpAsmPrinter &p, ExecuteOp op) {
124   p << "async.execute ";
125   p.printRegion(op.body());
126   p.printOptionalAttrDict(op.getAttrs());
127   p << " : ";
128   p.printType(op.done().getType());
129   if (!op.values().empty())
130     p << ", ";
131   llvm::interleaveComma(op.values(), p, [&](const OpResult &result) {
132     p.printType(result.getType());
133   });
134 }
135 
136 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
137   MLIRContext *ctx = result.getContext();
138 
139   // Parse asynchronous region.
140   Region *body = result.addRegion();
141   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{},
142                          /*enableNameShadowing=*/false))
143     return failure();
144 
145   // Parse operation attributes.
146   NamedAttrList attrs;
147   if (parser.parseOptionalAttrDict(attrs))
148     return failure();
149   result.addAttributes(attrs);
150 
151   // Parse result types.
152   SmallVector<Type, 4> resultTypes;
153   if (parser.parseColonTypeList(resultTypes))
154     return failure();
155 
156   // First result type must be an async token type.
157   if (resultTypes.empty() || resultTypes.front() != TokenType::get(ctx))
158     return failure();
159   parser.addTypesToList(resultTypes, result.types);
160 
161   return success();
162 }
163 
164 } // namespace async
165 } // namespace mlir
166 
167 #define GET_OP_CLASSES
168 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
169