1 #include "MethodMetadata.h"
2 #include "JSIInteropModuleRegistry.h"
3 #include "JavaScriptValue.h"
4 #include "JavaScriptObject.h"
5 #include "JavaScriptTypedArray.h"
6 #include "JavaReferencesCache.h"
7 #include "Exceptions.h"
8 #include "JavaCallback.h"
9 
#include <string>
#include <utility>
11 
12 #include <react/jni/ReadableNativeMap.h>
13 #include <react/jni/ReadableNativeArray.h>
14 #include <react/jni/WritableNativeArray.h>
15 #include <react/jni/WritableNativeMap.h>
16 #include "JSReferencesCache.h"
17 
18 namespace jni = facebook::jni;
19 namespace jsi = facebook::jsi;
20 namespace react = facebook::react;
21 
22 namespace expo {
23 
24 // Modified version of the RN implementation
25 // https://github.com/facebook/react-native/blob/7dceb9b63c0bfd5b13bf6d26f9530729506e9097/ReactCommon/react/nativemodule/core/platform/android/ReactCommon/JavaTurboModule.cpp#L57
26 jni::local_ref<JavaCallback::JavaPart> createJavaCallbackFromJSIFunction(
27   jsi::Function &&function,
28   std::weak_ptr<react::LongLivedObjectCollection> longLivedObjectCollection,
29   jsi::Runtime &rt,
30   JSIInteropModuleRegistry *moduleRegistry,
31   bool isRejectCallback = false
32 ) {
33   std::shared_ptr<react::CallInvoker> jsInvoker = moduleRegistry->runtimeHolder->jsInvoker;
34   auto strongLongLiveObjectCollection = longLivedObjectCollection.lock();
35   if (!strongLongLiveObjectCollection) {
36     throw std::runtime_error("The LongLivedObjectCollection for MethodMetadata is not alive.");
37   }
38   auto weakWrapper = react::CallbackWrapper::createWeak(strongLongLiveObjectCollection,
39                                                         std::move(function), rt,
40                                                         std::move(jsInvoker));
41 
42   // This needs to be a shared_ptr because:
43   // 1. It cannot be unique_ptr. std::function is copyable but unique_ptr is
44   // not.
45   // 2. It cannot be weak_ptr since we need this object to live on.
46   // 3. It cannot be a value, because that would be deleted as soon as this
47   // function returns.
48   auto callbackWrapperOwner =
49     std::make_shared<react::RAIICallbackWrapperDestroyer>(weakWrapper);
50 
51   std::function<void(folly::dynamic)> fn =
52     [
53       weakWrapper,
54       callbackWrapperOwner = std::move(callbackWrapperOwner),
55       wrapperWasCalled = false,
56       isRejectCallback
57     ](
58       folly::dynamic responses) mutable {
59       if (wrapperWasCalled) {
60         throw std::runtime_error(
61           "callback 2 arg cannot be called more than once");
62       }
63 
64       auto strongWrapper = weakWrapper.lock();
65       if (!strongWrapper) {
66         return;
67       }
68 
69       strongWrapper->jsInvoker().invokeAsync(
70         [
71           weakWrapper,
72           callbackWrapperOwner = std::move(callbackWrapperOwner),
73           responses = std::move(responses),
74           isRejectCallback
75         ]() mutable {
76           auto strongWrapper2 = weakWrapper.lock();
77           if (!strongWrapper2) {
78             return;
79           }
80 
81           jsi::Value arg = jsi::valueFromDynamic(strongWrapper2->runtime(), responses);
82           if (!isRejectCallback) {
83             strongWrapper2->callback().call(
84               strongWrapper2->runtime(),
85               (const jsi::Value *) &arg,
86               (size_t) 1
87             );
88           } else {
89             auto &rt = strongWrapper2->runtime();
90             auto jsErrorObject = arg.getObject(rt);
91             auto errorCode = jsErrorObject.getProperty(rt, "code").asString(rt);
92             auto message = jsErrorObject.getProperty(rt, "message").asString(rt);
93 
94             auto codedError = makeCodedError(
95               rt,
96               std::move(errorCode),
97               std::move(message)
98             );
99 
100             strongWrapper2->callback().call(
101               strongWrapper2->runtime(),
102               (const jsi::Value *) &codedError,
103               (size_t) 1
104             );
105           }
106 
107           callbackWrapperOwner.reset();
108         });
109 
110       wrapperWasCalled = true;
111     };
112 
113   return JavaCallback::newObjectCxxArgs(std::move(fn));
114 }
115 
116 jobjectArray MethodMetadata::convertJSIArgsToJNI(
117   JSIInteropModuleRegistry *moduleRegistry,
118   JNIEnv *env,
119   jsi::Runtime &rt,
120   const jsi::Value *args,
121   size_t count
122 ) {
123   auto argumentArray = env->NewObjectArray(
124     count,
125     JavaReferencesCache::instance()->getJClass("java/lang/Object").clazz,
126     nullptr
127   );
128 
129   std::vector<jobject> result(count);
130 
131   for (unsigned int argIndex = 0; argIndex < count; argIndex++) {
132     const jsi::Value &arg = args[argIndex];
133     auto &type = argTypes[argIndex];
134     if (arg.isNull() || arg.isUndefined()) {
135       // If value is null or undefined, we just passes a null
136       // Kotlin code will check if expected type is nullable.
137       result[argIndex] = nullptr;
138     } else {
139       if (type->converter->canConvert(rt, arg)) {
140         auto converterValue = type->converter->convert(rt, env, moduleRegistry, arg);
141         env->SetObjectArrayElement(argumentArray, argIndex, converterValue);
142         env->DeleteLocalRef(converterValue);
143       } else {
144         auto stringRepresentation = arg.toString(rt).utf8(rt);
145         throwNewJavaException(
146           UnexpectedException::create(
147             "Cannot convert '" + stringRepresentation + "' to a Kotlin type.").get()
148         );
149       }
150     }
151   }
152 
153   return argumentArray;
154 }
155 
156 MethodMetadata::MethodMetadata(
157   std::weak_ptr<react::LongLivedObjectCollection> longLivedObjectCollection,
158   std::string name,
159   int args,
160   bool isAsync,
161   jni::local_ref<jni::JArrayClass<ExpectedType>> expectedArgTypes,
162   jni::global_ref<jobject> &&jBodyReference
163 ) : name(std::move(name)),
164     args(args),
165     isAsync(isAsync),
166     jBodyReference(std::move(jBodyReference)),
167     longLivedObjectCollection_(std::move(longLivedObjectCollection)) {
168   argTypes.reserve(args);
169   for (size_t i = 0; i < args; i++) {
170     auto expectedType = expectedArgTypes->getElement(i);
171     argTypes.push_back(
172       std::make_unique<AnyType>(std::move(expectedType))
173     );
174   }
175 }
176 
177 MethodMetadata::MethodMetadata(
178   std::weak_ptr<react::LongLivedObjectCollection> longLivedObjectCollection,
179   std::string name,
180   int args,
181   bool isAsync,
182   std::vector<std::unique_ptr<AnyType>> &&expectedArgTypes,
183   jni::global_ref<jobject> &&jBodyReference
184 ) : name(std::move(name)),
185     args(args),
186     isAsync(isAsync),
187     argTypes(std::move(expectedArgTypes)),
188     jBodyReference(std::move(jBodyReference)),
189     longLivedObjectCollection_(std::move(longLivedObjectCollection)) {
190 }
191 
192 std::shared_ptr<jsi::Function> MethodMetadata::toJSFunction(
193   jsi::Runtime &runtime,
194   JSIInteropModuleRegistry *moduleRegistry
195 ) {
196   if (body == nullptr) {
197     if (isAsync) {
198       body = std::make_shared<jsi::Function>(toAsyncFunction(runtime, moduleRegistry));
199     } else {
200       body = std::make_shared<jsi::Function>(toSyncFunction(runtime, moduleRegistry));
201     }
202   }
203 
204   return body;
205 }
206 
207 jsi::Function MethodMetadata::toSyncFunction(
208   jsi::Runtime &runtime,
209   JSIInteropModuleRegistry *moduleRegistry
210 ) {
211   return jsi::Function::createFromHostFunction(
212     runtime,
213     moduleRegistry->jsRegistry->getPropNameID(runtime, name),
214     args,
215     [this, moduleRegistry](
216       jsi::Runtime &rt,
217       const jsi::Value &thisValue,
218       const jsi::Value *args,
219       size_t count
220     ) -> jsi::Value {
221       try {
222         return this->callSync(
223           rt,
224           moduleRegistry,
225           args,
226           count
227         );
228       } catch (jni::JniException &jniException) {
229         rethrowAsCodedError(rt, jniException);
230       }
231     });
232 }
233 
234 jsi::Value MethodMetadata::callSync(
235   jsi::Runtime &rt,
236   JSIInteropModuleRegistry *moduleRegistry,
237   const jsi::Value *args,
238   size_t count
239 ) {
240   if (this->jBodyReference == nullptr) {
241     return jsi::Value::undefined();
242   }
243 
244   JNIEnv *env = jni::Environment::current();
245 
246   /**
247    * This will push a new JNI stack frame for the LocalReferences in this
248    * function call. When the stack frame for this lambda is popped,
249    * all LocalReferences are deleted.
250    */
251   jni::JniLocalScope scope(env, (int) count);
252 
253   auto convertedArgs = convertJSIArgsToJNI(moduleRegistry, env, rt, args, count);
254 
255   // Cast in this place is safe, cause we know that this function is promise-less.
256   auto syncFunction = jni::static_ref_cast<JNIFunctionBody>(this->jBodyReference);
257   auto result = syncFunction->invoke(
258     convertedArgs
259   );
260 
261   env->DeleteLocalRef(convertedArgs);
262   if (result == nullptr) {
263     return jsi::Value::undefined();
264   }
265   auto unpackedResult = result.get();
266   auto cache = JavaReferencesCache::instance();
267   if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Double").clazz)) {
268     return {jni::static_ref_cast<jni::JDouble>(result)->value()};
269   }
270   if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Integer").clazz)) {
271     return {jni::static_ref_cast<jni::JInteger>(result)->value()};
272   }
273   if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Long").clazz)) {
274     return {(double) jni::static_ref_cast<jni::JLong>(result)->value()};
275   }
276   if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/String").clazz)) {
277     return jsi::String::createFromUtf8(
278       rt,
279       jni::static_ref_cast<jni::JString>(result)->toStdString()
280     );
281   }
282   if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Boolean").clazz)) {
283     return {(bool) jni::static_ref_cast<jni::JBoolean>(result)->value()};
284   }
285   if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Float").clazz)) {
286     return {(double) jni::static_ref_cast<jni::JFloat>(result)->value()};
287   }
288   if (env->IsInstanceOf(
289     unpackedResult,
290     cache->getJClass("com/facebook/react/bridge/WritableNativeArray").clazz
291   )) {
292     auto dynamic = jni::static_ref_cast<react::WritableNativeArray::javaobject>(result)
293       ->cthis()
294       ->consume();
295     return jsi::valueFromDynamic(rt, dynamic);
296   }
297   if (env->IsInstanceOf(
298     unpackedResult,
299     cache->getJClass("com/facebook/react/bridge/WritableNativeMap").clazz
300   )) {
301     auto dynamic = jni::static_ref_cast<react::WritableNativeMap::javaobject>(result)
302       ->cthis()
303       ->consume();
304     return jsi::valueFromDynamic(rt, dynamic);
305   }
306 
307   return jsi::Value::undefined();
308 }
309 
310 jsi::Function MethodMetadata::toAsyncFunction(
311   jsi::Runtime &runtime,
312   JSIInteropModuleRegistry *moduleRegistry
313 ) {
314   return jsi::Function::createFromHostFunction(
315     runtime,
316     moduleRegistry->jsRegistry->getPropNameID(runtime, name),
317     args,
318     [this, moduleRegistry](
319       jsi::Runtime &rt,
320       const jsi::Value &thisValue,
321       const jsi::Value *args,
322       size_t count
323     ) -> jsi::Value {
324       JNIEnv *env = jni::Environment::current();
325 
326       /**
327        * This will push a new JNI stack frame for the LocalReferences in this
328        * function call. When the stack frame for this lambda is popped,
329        * all LocalReferences are deleted.
330        */
331       jni::JniLocalScope scope(env, (int) count);
332 
333       auto &Promise = moduleRegistry->jsRegistry->getObject<jsi::Function>(
334         JSReferencesCache::JSKeys::PROMISE
335       );
336 
337       try {
338         auto convertedArgs = convertJSIArgsToJNI(moduleRegistry, env, rt, args, count);
339         auto globalConvertedArgs = (jobjectArray) env->NewGlobalRef(convertedArgs);
340         env->DeleteLocalRef(convertedArgs);
341 
342         // Creates a JSI promise
343         jsi::Value promise = Promise.callAsConstructor(
344           rt,
345           createPromiseBody(rt, moduleRegistry, globalConvertedArgs)
346         );
347         return promise;
348       } catch (jni::JniException &jniException) {
349         jni::local_ref<jni::JThrowable> unboxedThrowable = jniException.getThrowable();
350         if (!unboxedThrowable->isInstanceOf(CodedException::javaClassLocal())) {
351           unboxedThrowable = UnexpectedException::create(jniException.what());
352         }
353 
354         auto codedException = jni::static_ref_cast<CodedException>(unboxedThrowable);
355         auto code = codedException->getCode();
356         auto message = codedException->getLocalizedMessage().value_or("");
357 
358         jsi::Value promise = Promise.callAsConstructor(
359           rt,
360           jsi::Function::createFromHostFunction(
361             rt,
362             moduleRegistry->jsRegistry->getPropNameID(rt, "promiseFn"),
363             2,
364             [code, message](
365               jsi::Runtime &rt,
366               const jsi::Value &thisVal,
367               const jsi::Value *promiseConstructorArgs,
368               size_t promiseConstructorArgCount
369             ) {
370               if (promiseConstructorArgCount != 2) {
371                 throw std::invalid_argument("Promise fn arg count must be 2");
372               }
373 
374               jsi::Function rejectJSIFn = promiseConstructorArgs[1].getObject(rt).getFunction(rt);
375               rejectJSIFn.call(
376                 rt,
377                 makeCodedError(
378                   rt,
379                   jsi::String::createFromUtf8(rt, code),
380                   jsi::String::createFromUtf8(rt, message)
381                 )
382               );
383               return jsi::Value::undefined();
384             }
385           )
386         );
387 
388         return promise;
389       }
390     }
391   );
392 }
393 
394 jsi::Function MethodMetadata::createPromiseBody(
395   jsi::Runtime &runtime,
396   JSIInteropModuleRegistry *moduleRegistry,
397   jobjectArray globalArgs
398 ) {
399   return jsi::Function::createFromHostFunction(
400     runtime,
401     moduleRegistry->jsRegistry->getPropNameID(runtime, "promiseFn"),
402     2,
403     [this, globalArgs, moduleRegistry](
404       jsi::Runtime &rt,
405       const jsi::Value &thisVal,
406       const jsi::Value *promiseConstructorArgs,
407       size_t promiseConstructorArgCount
408     ) {
409       if (promiseConstructorArgCount != 2) {
410         throw std::invalid_argument("Promise fn arg count must be 2");
411       }
412 
413       jsi::Function resolveJSIFn = promiseConstructorArgs[0].getObject(rt).getFunction(rt);
414       jsi::Function rejectJSIFn = promiseConstructorArgs[1].getObject(rt).getFunction(rt);
415 
416       jobject resolve = createJavaCallbackFromJSIFunction(
417         std::move(resolveJSIFn),
418         longLivedObjectCollection_,
419         rt,
420         moduleRegistry
421       ).release();
422 
423       jobject reject = createJavaCallbackFromJSIFunction(
424         std::move(rejectJSIFn),
425         longLivedObjectCollection_,
426         rt,
427         moduleRegistry,
428         true
429       ).release();
430 
431       JNIEnv *env = jni::Environment::current();
432 
433       auto &jPromise = JavaReferencesCache::instance()->getJClass(
434         "expo/modules/kotlin/jni/PromiseImpl");
435       jmethodID jPromiseConstructor = jPromise.getMethod(
436         "<init>",
437         "(Lexpo/modules/kotlin/jni/JavaCallback;Lexpo/modules/kotlin/jni/JavaCallback;)V"
438       );
439 
440       // Creates a promise object
441       jobject promise = env->NewObject(
442         jPromise.clazz,
443         jPromiseConstructor,
444         resolve,
445         reject
446       );
447 
448       // Cast in this place is safe, cause we know that this function expects promise.
449       auto asyncFunction = jni::static_ref_cast<JNIAsyncFunctionBody>(this->jBodyReference);
450       asyncFunction->invoke(
451         globalArgs,
452         promise
453       );
454 
455       // We have to remove the local reference to the promise object.
456       // It doesn't mean that the promise will be deallocated, but rather that we move
457       // the ownership to the `JNIAsyncFunctionBody`.
458       env->DeleteLocalRef(promise);
459       env->DeleteGlobalRef(globalArgs);
460 
461       return jsi::Value::undefined();
462     }
463   );
464 }
465 } // namespace expo
466