1 #include "MethodMetadata.h"
2 #include "JSIInteropModuleRegistry.h"
3 #include "JavaScriptValue.h"
4 #include "JavaScriptObject.h"
5 #include "JavaScriptTypedArray.h"
6 #include "JavaReferencesCache.h"
7 #include "Exceptions.h"
8 
9 #include <utility>
10 
11 #include "react/jni/ReadableNativeMap.h"
12 #include "react/jni/ReadableNativeArray.h"
13 #include "JSReferencesCache.h"
14 
15 namespace jni = facebook::jni;
16 namespace jsi = facebook::jsi;
17 namespace react = facebook::react;
18 
19 namespace expo {
20 
21 // Modified version of the RN implementation
22 // https://github.com/facebook/react-native/blob/7dceb9b63c0bfd5b13bf6d26f9530729506e9097/ReactCommon/react/nativemodule/core/platform/android/ReactCommon/JavaTurboModule.cpp#L57
23 jni::local_ref<react::JCxxCallbackImpl::JavaPart> createJavaCallbackFromJSIFunction(
24   jsi::Function &&function,
25   jsi::Runtime &rt,
26   std::shared_ptr<react::CallInvoker> jsInvoker
27 ) {
28   auto weakWrapper = react::CallbackWrapper::createWeak(std::move(function), rt,
29                                                         std::move(jsInvoker));
30 
31   // This needs to be a shared_ptr because:
32   // 1. It cannot be unique_ptr. std::function is copyable but unique_ptr is
33   // not.
34   // 2. It cannot be weak_ptr since we need this object to live on.
35   // 3. It cannot be a value, because that would be deleted as soon as this
36   // function returns.
37   auto callbackWrapperOwner =
38     std::make_shared<react::RAIICallbackWrapperDestroyer>(weakWrapper);
39 
40   std::function<void(folly::dynamic)> fn =
41     [weakWrapper, callbackWrapperOwner, wrapperWasCalled = false](
42       folly::dynamic responses) mutable {
43       if (wrapperWasCalled) {
44         throw std::runtime_error(
45           "callback 2 arg cannot be called more than once");
46       }
47 
48       auto strongWrapper = weakWrapper.lock();
49       if (!strongWrapper) {
50         return;
51       }
52 
53       strongWrapper->jsInvoker().invokeAsync(
54         [weakWrapper, callbackWrapperOwner, responses]() mutable {
55           auto strongWrapper2 = weakWrapper.lock();
56           if (!strongWrapper2) {
57             return;
58           }
59 
60           jsi::Value args =
61             jsi::valueFromDynamic(strongWrapper2->runtime(), responses);
62           auto argsArray = args.getObject(strongWrapper2->runtime())
63             .asArray(strongWrapper2->runtime());
64           jsi::Value arg = argsArray.getValueAtIndex(strongWrapper2->runtime(), 0);
65 
66           strongWrapper2->callback().call(
67             strongWrapper2->runtime(),
68             (const jsi::Value *) &arg,
69             (size_t) 1
70           );
71 
72           callbackWrapperOwner.reset();
73         });
74 
75       wrapperWasCalled = true;
76     };
77 
78   return react::JCxxCallbackImpl::newObjectCxxArgs(fn);
79 }
80 
81 std::vector<jobject> MethodMetadata::convertJSIArgsToJNI(
82   JSIInteropModuleRegistry *moduleRegistry,
83   JNIEnv *env,
84   jsi::Runtime &rt,
85   const jsi::Value *args,
86   size_t count,
87   bool returnGlobalReferences
88 ) {
89   std::vector<jobject> result(count);
90 
91   auto makeGlobalIfNecessary = [env, returnGlobalReferences](jobject obj) -> jobject {
92     if (returnGlobalReferences) {
93       return env->NewGlobalRef(obj);
94     }
95     return obj;
96   };
97 
98   for (unsigned int argIndex = 0; argIndex < count; argIndex++) {
99     const jsi::Value *arg = &args[argIndex];
100     jobject *jarg = &result[argIndex];
101     int desiredType = desiredTypes[argIndex];
102 
103     if (desiredType & CppType::JS_VALUE) {
104       *jarg = makeGlobalIfNecessary(
105         JavaScriptValue::newObjectCxxArgs(
106           moduleRegistry->runtimeHolder->weak_from_this(),
107           // TODO(@lukmccall): make sure that copy here is necessary
108           std::make_shared<jsi::Value>(jsi::Value(rt, *arg))
109         ).release()
110       );
111     } else if (desiredType & CppType::JS_OBJECT) {
112       *jarg = makeGlobalIfNecessary(
113         JavaScriptObject::newObjectCxxArgs(
114           moduleRegistry->runtimeHolder->weak_from_this(),
115           std::make_shared<jsi::Object>(arg->getObject(rt))
116         ).release()
117       );
118     } else if (desiredType & CppType::TYPED_ARRAY) {
119       *jarg = makeGlobalIfNecessary(
120         JavaScriptTypedArray::newObjectCxxArgs(
121           moduleRegistry->runtimeHolder->weak_from_this(),
122           std::make_shared<jsi::Object>(arg->getObject(rt))
123         ).release()
124       );
125     } else if (arg->isNull() || arg->isUndefined()) {
126       *jarg = nullptr;
127     } else if (arg->isNumber()) {
128       if (desiredType & CppType::INT) {
129         auto &integerClass = JavaReferencesCache::instance()
130           ->getJClass("java/lang/Integer");
131         jmethodID integerConstructor = integerClass.getMethod("<init>", "(I)V");
132         *jarg = makeGlobalIfNecessary(
133           env->NewObject(integerClass.clazz, integerConstructor,
134                          static_cast<int>(arg->getNumber())));
135       } else if (desiredType & CppType::FLOAT) {
136         auto &floatClass = JavaReferencesCache::instance()
137           ->getJClass("java/lang/Float");
138         jmethodID floatConstructor = floatClass.getMethod("<init>", "(F)V");
139         *jarg = makeGlobalIfNecessary(
140           env->NewObject(floatClass.clazz, floatConstructor, static_cast<float>(arg->getNumber())));
141       } else {
142         auto &doubleClass = JavaReferencesCache::instance()
143           ->getJClass("java/lang/Double");
144         jmethodID doubleConstructor = doubleClass.getMethod("<init>", "(D)V");
145         *jarg = makeGlobalIfNecessary(
146           env->NewObject(doubleClass.clazz, doubleConstructor, arg->getNumber()));
147       }
148     } else if (arg->isBool()) {
149       auto &booleanClass = JavaReferencesCache::instance()
150         ->getJClass("java/lang/Boolean");
151       jmethodID booleanConstructor = booleanClass.getMethod("<init>", "(Z)V");
152       *jarg = makeGlobalIfNecessary(
153         env->NewObject(booleanClass.clazz, booleanConstructor, arg->getBool()));
154     } else if (arg->isString()) {
155       *jarg = makeGlobalIfNecessary(env->NewStringUTF(arg->getString(rt).utf8(rt).c_str()));
156     } else if (arg->isObject()) {
157       const jsi::Object object = arg->getObject(rt);
158 
159       // TODO(@lukmccall): stop using dynamic
160       auto dynamic = jsi::dynamicFromValue(rt, *arg);
161       if (arg->getObject(rt).isArray(rt)) {
162         *jarg = makeGlobalIfNecessary(
163           react::ReadableNativeArray::newObjectCxxArgs(std::move(dynamic)).release());
164       } else {
165         *jarg = makeGlobalIfNecessary(
166           react::ReadableNativeMap::createWithContents(std::move(dynamic)).release());
167       }
168     } else {
169       auto stringRepresentation = arg->toString(rt).utf8(rt);
170       jni::throwNewJavaException(
171         UnexpectedException::create(
172           "Cannot convert '" + stringRepresentation + "' to a Kotlin type.").get()
173       );
174     }
175   }
176 
177   return result;
178 }
179 
180 MethodMetadata::MethodMetadata(
181   std::string name,
182   int args,
183   bool isAsync,
184   std::unique_ptr<int[]> desiredTypes,
185   jni::global_ref<jobject> &&jBodyReference
186 ) : name(std::move(name)),
187     args(args),
188     isAsync(isAsync),
189     desiredTypes(std::move(desiredTypes)),
190     jBodyReference(std::move(jBodyReference)) {}
191 
192 std::shared_ptr<jsi::Function> MethodMetadata::toJSFunction(
193   jsi::Runtime &runtime,
194   JSIInteropModuleRegistry *moduleRegistry
195 ) {
196   if (body == nullptr) {
197     if (isAsync) {
198       body = std::make_shared<jsi::Function>(toAsyncFunction(runtime, moduleRegistry));
199     } else {
200       body = std::make_shared<jsi::Function>(toSyncFunction(runtime, moduleRegistry));
201     }
202   }
203 
204   return body;
205 }
206 
207 jsi::Function MethodMetadata::toSyncFunction(
208   jsi::Runtime &runtime,
209   JSIInteropModuleRegistry *moduleRegistry
210 ) {
211   return jsi::Function::createFromHostFunction(
212     runtime,
213     moduleRegistry->jsRegistry->getPropNameID(runtime, name),
214     args,
215     [this, moduleRegistry](
216       jsi::Runtime &rt,
217       const jsi::Value &thisValue,
218       const jsi::Value *args,
219       size_t count
220     ) -> jsi::Value {
221       try {
222         return this->callSync(
223           rt,
224           moduleRegistry,
225           args,
226           count
227         );
228       } catch (jni::JniException &jniException) {
229         rethrowAsCodedError(rt, moduleRegistry, jniException);
230       }
231     });
232 }
233 
234 jsi::Value MethodMetadata::callSync(
235   jsi::Runtime &rt,
236   JSIInteropModuleRegistry *moduleRegistry,
237   const jsi::Value *args,
238   size_t count
239 ) {
240   if (this->jBodyReference == nullptr) {
241     return jsi::Value::undefined();
242   }
243 
244   JNIEnv *env = jni::Environment::current();
245 
246   /**
247    * This will push a new JNI stack frame for the LocalReferences in this
248    * function call. When the stack frame for this lambda is popped,
249    * all LocalReferences are deleted.
250    */
251   jni::JniLocalScope scope(env, (int) count);
252 
253   std::vector<jobject> convertedArgs = convertJSIArgsToJNI(moduleRegistry, env, rt, args, count,
254                                                           false);
255 
256   // TODO(@lukmccall): Remove this temp array
257   auto tempArray = env->NewObjectArray(
258     convertedArgs.size(),
259     JavaReferencesCache::instance()->getJClass("java/lang/Object").clazz,
260     nullptr
261   );
262   for (size_t i = 0; i < convertedArgs.size(); i++) {
263     env->SetObjectArrayElement(tempArray, i, convertedArgs[i]);
264   }
265 
266   // Cast in this place is safe, cause we know that this function is promise-less.
267   auto syncFunction = jni::static_ref_cast<JNIFunctionBody>(this->jBodyReference);
268   auto result = syncFunction->invoke(
269     tempArray
270   );
271 
272   if (result == nullptr) {
273     return jsi::Value::undefined();
274   }
275 
276   return jsi::valueFromDynamic(rt, result->cthis()->consume())
277     .asObject(rt)
278     .asArray(rt)
279     .getValueAtIndex(rt, 0);
280 }
281 
282 jsi::Function MethodMetadata::toAsyncFunction(
283   jsi::Runtime &runtime,
284   JSIInteropModuleRegistry *moduleRegistry
285 ) {
286   return jsi::Function::createFromHostFunction(
287     runtime,
288     moduleRegistry->jsRegistry->getPropNameID(runtime, name),
289     args,
290     [this, moduleRegistry](
291       jsi::Runtime &rt,
292       const jsi::Value &thisValue,
293       const jsi::Value *args,
294       size_t count
295     ) -> jsi::Value {
296       JNIEnv *env = jni::Environment::current();
297 
298       /**
299        * This will push a new JNI stack frame for the LocalReferences in this
300        * function call. When the stack frame for this lambda is popped,
301        * all LocalReferences are deleted.
302        */
303       jni::JniLocalScope scope(env, (int) count);
304 
305       try {
306         std::vector<jobject> convertedArgs = convertJSIArgsToJNI(moduleRegistry, env, rt, args,
307                                                                 count,
308                                                                 true);
309         auto &Promise = moduleRegistry->jsRegistry->getObject<jsi::Function>(
310           JSReferencesCache::JSKeys::PROMISE
311         );
312         // Creates a JSI promise
313         jsi::Value promise = Promise.callAsConstructor(
314           rt,
315           createPromiseBody(rt, moduleRegistry, std::move(convertedArgs))
316         );
317         return promise;
318       } catch (jni::JniException &jniException) {
319         rethrowAsCodedError(rt, moduleRegistry, jniException);
320       }
321     }
322   );
323 }
324 
325 jsi::Function MethodMetadata::createPromiseBody(
326   jsi::Runtime &runtime,
327   JSIInteropModuleRegistry *moduleRegistry,
328   std::vector<jobject> &&args
329 ) {
330   return jsi::Function::createFromHostFunction(
331     runtime,
332     moduleRegistry->jsRegistry->getPropNameID(runtime, "promiseFn"),
333     2,
334     [this, args = std::move(args), moduleRegistry](
335       jsi::Runtime &rt,
336       const jsi::Value &thisVal,
337       const jsi::Value *promiseConstructorArgs,
338       size_t promiseConstructorArgCount
339     ) {
340       if (promiseConstructorArgCount != 2) {
341         throw std::invalid_argument("Promise fn arg count must be 2");
342       }
343 
344       jsi::Function resolveJSIFn = promiseConstructorArgs[0].getObject(rt).getFunction(rt);
345       jsi::Function rejectJSIFn = promiseConstructorArgs[1].getObject(rt).getFunction(rt);
346 
347       auto &runtimeHolder = moduleRegistry->runtimeHolder;
348       jobject resolve = createJavaCallbackFromJSIFunction(
349         std::move(resolveJSIFn),
350         rt,
351         runtimeHolder->jsInvoker
352       ).release();
353 
354       jobject reject = createJavaCallbackFromJSIFunction(
355         std::move(rejectJSIFn),
356         rt,
357         runtimeHolder->jsInvoker
358       ).release();
359 
360       JNIEnv *env = jni::Environment::current();
361 
362       auto &jPromise = JavaReferencesCache::instance()->getJClass(
363         "com/facebook/react/bridge/PromiseImpl");
364       jmethodID jPromiseConstructor = jPromise.getMethod(
365         "<init>",
366         "(Lcom/facebook/react/bridge/Callback;Lcom/facebook/react/bridge/Callback;)V"
367       );
368 
369       // Creates a promise object
370       jobject promise = env->NewObject(
371         jPromise.clazz,
372         jPromiseConstructor,
373         resolve,
374         reject
375       );
376 
377       auto argsSize = args.size();
378       // TODO(@lukmccall): Remove this temp array
379       auto tempArray = env->NewObjectArray(
380         argsSize,
381         JavaReferencesCache::instance()->getJClass("java/lang/Object").clazz,
382         nullptr
383       );
384       for (size_t i = 0; i < argsSize; i++) {
385         env->SetObjectArrayElement(tempArray, i, args[i]);
386       }
387 
388       // Cast in this place is safe, cause we know that this function expects promise.
389       auto asyncFunction = jni::static_ref_cast<JNIAsyncFunctionBody>(this->jBodyReference);
390       asyncFunction->invoke(
391         tempArray,
392         promise
393       );
394 
395       // We have to remove the local reference to the promise object.
396       // It doesn't mean that the promise will be deallocated, but rather that we move
397       // the ownership to the `JNIAsyncFunctionBody`.
398       env->DeleteLocalRef(promise);
399 
400       for (const auto &arg: args) {
401         env->DeleteGlobalRef(arg);
402       }
403       env->DeleteLocalRef(tempArray);
404 
405       return jsi::Value::undefined();
406     }
407   );
408 }
409 
410 } // namespace expo
411