1 #include "MethodMetadata.h"
2 #include "JSIInteropModuleRegistry.h"
3 #include "JavaScriptValue.h"
4 #include "JavaScriptObject.h"
5 #include "JavaScriptTypedArray.h"
6 #include "JavaReferencesCache.h"
7 #include "Exceptions.h"
8 #include "JavaCallback.h"
9 
10 #include <utility>
11 
12 #include <react/jni/ReadableNativeMap.h>
13 #include <react/jni/ReadableNativeArray.h>
14 #include <react/jni/WritableNativeArray.h>
15 #include <react/jni/WritableNativeMap.h>
16 #include "JSReferencesCache.h"
17 
18 namespace jni = facebook::jni;
19 namespace jsi = facebook::jsi;
20 namespace react = facebook::react;
21 
22 namespace expo {
23 
24 // Modified version of the RN implementation
25 // https://github.com/facebook/react-native/blob/7dceb9b63c0bfd5b13bf6d26f9530729506e9097/ReactCommon/react/nativemodule/core/platform/android/ReactCommon/JavaTurboModule.cpp#L57
/**
 * Wraps a JS function (used below for a promise's `resolve` / `reject`) in a `JavaCallback`
 * that the Kotlin side can invoke.
 *
 * The JS function is held through a weak `react::CallbackWrapper`. Calling the returned
 * Java callback does not call into JS directly; it schedules the call on the runtime's
 * `jsInvoker` via `invokeAsync`.
 *
 * Once a call has been scheduled, calling the callback again throws `std::runtime_error`.
 * If the wrapper has already been released when the callback is called, the call is a
 * no-op and is not counted.
 *
 * @param function JS function to wrap; moved into the wrapper.
 * @param rt Runtime that owns `function`.
 * @param moduleRegistry Provides the `jsInvoker` (`runtimeHolder->jsInvoker`).
 * @param isRejectCallback When false, the value received from Kotlin is converted with
 *        `jsi::valueFromDynamic` and passed as the single argument. When true, it is read as
 *        an object with `code` and `message` properties, and a coded error built by
 *        `makeCodedError` is passed instead.
 */
jni::local_ref<JavaCallback::JavaPart> createJavaCallbackFromJSIFunction(
  jsi::Function &&function,
  jsi::Runtime &rt,
  JSIInteropModuleRegistry *moduleRegistry,
  bool isRejectCallback = false
) {
  std::shared_ptr<react::CallInvoker> jsInvoker = moduleRegistry->runtimeHolder->jsInvoker;
  auto weakWrapper = react::CallbackWrapper::createWeak(std::move(function), rt,
                                                        std::move(jsInvoker));

  // This needs to be a shared_ptr because:
  // 1. It cannot be unique_ptr. std::function is copyable but unique_ptr is
  // not.
  // 2. It cannot be weak_ptr since we need this object to live on.
  // 3. It cannot be a value, because that would be deleted as soon as this
  // function returns.
  auto callbackWrapperOwner =
    std::make_shared<react::RAIICallbackWrapperDestroyer>(weakWrapper);

  std::function<void(folly::dynamic)> fn =
    [
      weakWrapper,
      callbackWrapperOwner = std::move(callbackWrapperOwner),
      wrapperWasCalled = false,
      isRejectCallback
    ](
      folly::dynamic responses) mutable {
      // Single-shot guard: only set after a JS call has actually been scheduled (see below).
      if (wrapperWasCalled) {
        throw std::runtime_error(
          "callback 2 arg cannot be called more than once");
      }

      // The wrapper is held weakly; if it is already gone there is nothing to call.
      auto strongWrapper = weakWrapper.lock();
      if (!strongWrapper) {
        return;
      }

      // Ownership of the destroyer moves into the scheduled task and is reset at its end.
      // NOTE(review): presumably the RAII destroyer releases the CallbackWrapper when it is
      // destroyed - confirm against react::RAIICallbackWrapperDestroyer.
      strongWrapper->jsInvoker().invokeAsync(
        [
          weakWrapper,
          callbackWrapperOwner = std::move(callbackWrapperOwner),
          responses = std::move(responses),
          isRejectCallback
        ]() mutable {
          // Re-check on the JS side: the wrapper may have been released in the meantime.
          auto strongWrapper2 = weakWrapper.lock();
          if (!strongWrapper2) {
            return;
          }

          jsi::Value arg = jsi::valueFromDynamic(strongWrapper2->runtime(), responses);
          if (!isRejectCallback) {
            strongWrapper2->callback().call(
              strongWrapper2->runtime(),
              (const jsi::Value *) &arg,
              (size_t) 1
            );
          } else {
            // Reject path: the payload is read as `{ code, message }` and converted to a coded error.
            // NOTE(review): `getObject`/`asString` throw if the payload is not an object or if
            // `code`/`message` are not strings (e.g. a null message) - confirm the Kotlin side
            // always sends both as strings.
            auto &rt = strongWrapper2->runtime();
            auto jsErrorObject = arg.getObject(rt);
            auto errorCode = jsErrorObject.getProperty(rt, "code").asString(rt);
            auto message = jsErrorObject.getProperty(rt, "message").asString(rt);

            auto codedError = makeCodedError(
              rt,
              std::move(errorCode),
              std::move(message)
            );

            strongWrapper2->callback().call(
              strongWrapper2->runtime(),
              (const jsi::Value *) &codedError,
              (size_t) 1
            );
          }

          callbackWrapperOwner.reset();
        });

      wrapperWasCalled = true;
    };

  return JavaCallback::newObjectCxxArgs(std::move(fn));
}
109 
110 jobjectArray MethodMetadata::convertJSIArgsToJNI(
111   JSIInteropModuleRegistry *moduleRegistry,
112   JNIEnv *env,
113   jsi::Runtime &rt,
114   const jsi::Value *args,
115   size_t count
116 ) {
117   auto argumentArray = env->NewObjectArray(
118     count,
119     JavaReferencesCache::instance()->getJClass("java/lang/Object").clazz,
120     nullptr
121   );
122 
123   std::vector<jobject> result(count);
124 
125   for (unsigned int argIndex = 0; argIndex < count; argIndex++) {
126     const jsi::Value &arg = args[argIndex];
127     auto &type = argTypes[argIndex];
128     if (arg.isNull() || arg.isUndefined()) {
129       // If value is null or undefined, we just passes a null
130       // Kotlin code will check if expected type is nullable.
131       result[argIndex] = nullptr;
132     } else {
133       if (type->converter->canConvert(rt, arg)) {
134         auto converterValue = type->converter->convert(rt, env, moduleRegistry, arg);
135         env->SetObjectArrayElement(argumentArray, argIndex, converterValue);
136         env->DeleteLocalRef(converterValue);
137       } else {
138         auto stringRepresentation = arg.toString(rt).utf8(rt);
139         throwNewJavaException(
140           UnexpectedException::create(
141             "Cannot convert '" + stringRepresentation + "' to a Kotlin type.").get()
142         );
143       }
144     }
145   }
146 
147   return argumentArray;
148 }
149 
150 MethodMetadata::MethodMetadata(
151   std::string name,
152   int args,
153   bool isAsync,
154   jni::local_ref<jni::JArrayClass<ExpectedType>> expectedArgTypes,
155   jni::global_ref<jobject> &&jBodyReference
156 ) : name(std::move(name)),
157     args(args),
158     isAsync(isAsync),
159     jBodyReference(std::move(jBodyReference)) {
160   argTypes.reserve(args);
161   for (size_t i = 0; i < args; i++) {
162     auto expectedType = expectedArgTypes->getElement(i);
163     argTypes.push_back(
164       std::make_unique<AnyType>(std::move(expectedType))
165     );
166   }
167 }
168 
169 MethodMetadata::MethodMetadata(
170   std::string name,
171   int args,
172   bool isAsync,
173   std::vector<std::unique_ptr<AnyType>> &&expectedArgTypes,
174   jni::global_ref<jobject> &&jBodyReference
175 ) : name(std::move(name)),
176     args(args),
177     isAsync(isAsync),
178     argTypes(std::move(expectedArgTypes)),
179     jBodyReference(std::move(jBodyReference)
180     ) {}
181 
182 std::shared_ptr<jsi::Function> MethodMetadata::toJSFunction(
183   jsi::Runtime &runtime,
184   JSIInteropModuleRegistry *moduleRegistry
185 ) {
186   if (body == nullptr) {
187     if (isAsync) {
188       body = std::make_shared<jsi::Function>(toAsyncFunction(runtime, moduleRegistry));
189     } else {
190       body = std::make_shared<jsi::Function>(toSyncFunction(runtime, moduleRegistry));
191     }
192   }
193 
194   return body;
195 }
196 
197 jsi::Function MethodMetadata::toSyncFunction(
198   jsi::Runtime &runtime,
199   JSIInteropModuleRegistry *moduleRegistry
200 ) {
201   return jsi::Function::createFromHostFunction(
202     runtime,
203     moduleRegistry->jsRegistry->getPropNameID(runtime, name),
204     args,
205     [this, moduleRegistry](
206       jsi::Runtime &rt,
207       const jsi::Value &thisValue,
208       const jsi::Value *args,
209       size_t count
210     ) -> jsi::Value {
211       try {
212         return this->callSync(
213           rt,
214           moduleRegistry,
215           args,
216           count
217         );
218       } catch (jni::JniException &jniException) {
219         rethrowAsCodedError(rt, jniException);
220       }
221     });
222 }
223 
224 jsi::Value MethodMetadata::callSync(
225   jsi::Runtime &rt,
226   JSIInteropModuleRegistry *moduleRegistry,
227   const jsi::Value *args,
228   size_t count
229 ) {
230   if (this->jBodyReference == nullptr) {
231     return jsi::Value::undefined();
232   }
233 
234   JNIEnv *env = jni::Environment::current();
235 
236   /**
237    * This will push a new JNI stack frame for the LocalReferences in this
238    * function call. When the stack frame for this lambda is popped,
239    * all LocalReferences are deleted.
240    */
241   jni::JniLocalScope scope(env, (int) count);
242 
243   auto convertedArgs = convertJSIArgsToJNI(moduleRegistry, env, rt, args, count);
244 
245   // Cast in this place is safe, cause we know that this function is promise-less.
246   auto syncFunction = jni::static_ref_cast<JNIFunctionBody>(this->jBodyReference);
247   auto result = syncFunction->invoke(
248     convertedArgs
249   );
250 
251   env->DeleteLocalRef(convertedArgs);
252   if (result == nullptr) {
253     return jsi::Value::undefined();
254   }
255   auto unpackedResult = result.get();
256   auto cache = JavaReferencesCache::instance();
257   if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Double").clazz)) {
258     return {jni::static_ref_cast<jni::JDouble>(result)->value()};
259   }
260   if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Integer").clazz)) {
261     return {jni::static_ref_cast<jni::JInteger>(result)->value()};
262   }
263   if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/String").clazz)) {
264     return jsi::String::createFromUtf8(
265       rt,
266       jni::static_ref_cast<jni::JString>(result)->toStdString()
267     );
268   }
269   if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Boolean").clazz)) {
270     return {(bool) jni::static_ref_cast<jni::JBoolean>(result)->value()};
271   }
272   if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Float").clazz)) {
273     return {(double) jni::static_ref_cast<jni::JFloat>(result)->value()};
274   }
275   if (env->IsInstanceOf(
276     unpackedResult,
277     cache->getJClass("com/facebook/react/bridge/WritableNativeArray").clazz
278   )) {
279     auto dynamic = jni::static_ref_cast<react::WritableNativeArray::javaobject>(result)
280       ->cthis()
281       ->consume();
282     return jsi::valueFromDynamic(rt, dynamic);
283   }
284   if (env->IsInstanceOf(
285     unpackedResult,
286     cache->getJClass("com/facebook/react/bridge/WritableNativeMap").clazz
287   )) {
288     auto dynamic = jni::static_ref_cast<react::WritableNativeMap::javaobject>(result)
289       ->cthis()
290       ->consume();
291     return jsi::valueFromDynamic(rt, dynamic);
292   }
293 
294   return jsi::Value::undefined();
295 }
296 
/**
 * Builds the JS host function for an asynchronous method. Every call returns a new JS `Promise`.
 *
 * Normal path: the JS arguments are converted to a Java `Object[]`, promoted to a JNI global
 * reference, and handed to `createPromiseBody`. That function builds the executor passed to the
 * `Promise` constructor.
 *
 * Error path: if anything in the `try` block throws a `jni::JniException`, the returned promise
 * is rejected with a coded error built from the exception's code and localized message. A
 * throwable that is not a `CodedException` is first wrapped in an `UnexpectedException`.
 * Only `jni::JniException` is caught here; other exceptions propagate to the JS caller.
 */
jsi::Function MethodMetadata::toAsyncFunction(
  jsi::Runtime &runtime,
  JSIInteropModuleRegistry *moduleRegistry
) {
  return jsi::Function::createFromHostFunction(
    runtime,
    moduleRegistry->jsRegistry->getPropNameID(runtime, name),
    args,
    [this, moduleRegistry](
      jsi::Runtime &rt,
      const jsi::Value &thisValue,
      const jsi::Value *args,
      size_t count
    ) -> jsi::Value {
      JNIEnv *env = jni::Environment::current();

      /**
       * This will push a new JNI stack frame for the LocalReferences in this
       * function call. When the stack frame for this lambda is popped,
       * all LocalReferences are deleted.
       */
      jni::JniLocalScope scope(env, (int) count);

      // Cached JS `Promise` constructor.
      auto &Promise = moduleRegistry->jsRegistry->getObject<jsi::Function>(
        JSReferencesCache::JSKeys::PROMISE
      );

      try {
        auto convertedArgs = convertJSIArgsToJNI(moduleRegistry, env, rt, args, count);
        // The global reference is handed to `createPromiseBody`, whose executor deletes it.
        auto globalConvertedArgs = (jobjectArray) env->NewGlobalRef(convertedArgs);
        env->DeleteLocalRef(convertedArgs);

        // Creates a JSI promise
        jsi::Value promise = Promise.callAsConstructor(
          rt,
          createPromiseBody(rt, moduleRegistry, globalConvertedArgs)
        );
        return promise;
      } catch (jni::JniException &jniException) {
        // Normalize to a CodedException so a code and message can always be extracted.
        jni::local_ref<jni::JThrowable> unboxedThrowable = jniException.getThrowable();
        if (!unboxedThrowable->isInstanceOf(CodedException::javaClassLocal())) {
          unboxedThrowable = UnexpectedException::create(jniException.what());
        }

        auto codedException = jni::static_ref_cast<CodedException>(unboxedThrowable);
        auto code = codedException->getCode();
        auto message = codedException->getLocalizedMessage().value_or("");

        // Return an already-rejected promise instead of throwing synchronously into JS.
        jsi::Value promise = Promise.callAsConstructor(
          rt,
          jsi::Function::createFromHostFunction(
            rt,
            moduleRegistry->jsRegistry->getPropNameID(rt, "promiseFn"),
            2,
            [code, message](
              jsi::Runtime &rt,
              const jsi::Value &thisVal,
              const jsi::Value *promiseConstructorArgs,
              size_t promiseConstructorArgCount
            ) {
              if (promiseConstructorArgCount != 2) {
                throw std::invalid_argument("Promise fn arg count must be 2");
              }

              // Executor args are (resolve, reject); only `reject` is used here.
              jsi::Function rejectJSIFn = promiseConstructorArgs[1].getObject(rt).getFunction(rt);
              rejectJSIFn.call(
                rt,
                makeCodedError(
                  rt,
                  jsi::String::createFromUtf8(rt, code),
                  jsi::String::createFromUtf8(rt, message)
                )
              );
              return jsi::Value::undefined();
            }
          )
        );

        return promise;
      }
    }
  );
}
380 
381 jsi::Function MethodMetadata::createPromiseBody(
382   jsi::Runtime &runtime,
383   JSIInteropModuleRegistry *moduleRegistry,
384   jobjectArray globalArgs
385 ) {
386   return jsi::Function::createFromHostFunction(
387     runtime,
388     moduleRegistry->jsRegistry->getPropNameID(runtime, "promiseFn"),
389     2,
390     [this, globalArgs, moduleRegistry](
391       jsi::Runtime &rt,
392       const jsi::Value &thisVal,
393       const jsi::Value *promiseConstructorArgs,
394       size_t promiseConstructorArgCount
395     ) {
396       if (promiseConstructorArgCount != 2) {
397         throw std::invalid_argument("Promise fn arg count must be 2");
398       }
399 
400       jsi::Function resolveJSIFn = promiseConstructorArgs[0].getObject(rt).getFunction(rt);
401       jsi::Function rejectJSIFn = promiseConstructorArgs[1].getObject(rt).getFunction(rt);
402 
403       jobject resolve = createJavaCallbackFromJSIFunction(
404         std::move(resolveJSIFn),
405         rt,
406         moduleRegistry
407       ).release();
408 
409       jobject reject = createJavaCallbackFromJSIFunction(
410         std::move(rejectJSIFn),
411         rt,
412         moduleRegistry,
413         true
414       ).release();
415 
416       JNIEnv *env = jni::Environment::current();
417 
418       auto &jPromise = JavaReferencesCache::instance()->getJClass(
419         "expo/modules/kotlin/jni/PromiseImpl");
420       jmethodID jPromiseConstructor = jPromise.getMethod(
421         "<init>",
422         "(Lexpo/modules/kotlin/jni/JavaCallback;Lexpo/modules/kotlin/jni/JavaCallback;)V"
423       );
424 
425       // Creates a promise object
426       jobject promise = env->NewObject(
427         jPromise.clazz,
428         jPromiseConstructor,
429         resolve,
430         reject
431       );
432 
433       // Cast in this place is safe, cause we know that this function expects promise.
434       auto asyncFunction = jni::static_ref_cast<JNIAsyncFunctionBody>(this->jBodyReference);
435       asyncFunction->invoke(
436         globalArgs,
437         promise
438       );
439 
440       // We have to remove the local reference to the promise object.
441       // It doesn't mean that the promise will be deallocated, but rather that we move
442       // the ownership to the `JNIAsyncFunctionBody`.
443       env->DeleteLocalRef(promise);
444       env->DeleteGlobalRef(globalArgs);
445 
446       return jsi::Value::undefined();
447     }
448   );
449 }
450 } // namespace expo
451