1 #include "MethodMetadata.h"
2 #include "JSIInteropModuleRegistry.h"
3 #include "JavaScriptValue.h"
4 #include "JavaScriptObject.h"
5 #include "JavaScriptTypedArray.h"
6 #include "JavaReferencesCache.h"
7 #include "Exceptions.h"
8 #include "JavaCallback.h"
9 
10 #include <utility>
11 #include <functional>
12 
13 #include <react/jni/ReadableNativeMap.h>
14 #include <react/jni/ReadableNativeArray.h>
15 #include <react/jni/WritableNativeArray.h>
16 #include <react/jni/WritableNativeMap.h>
17 #include "JSReferencesCache.h"
18 
19 namespace jni = facebook::jni;
20 namespace jsi = facebook::jsi;
21 namespace react = facebook::react;
22 
23 namespace expo {
24 
25 class JSI_EXPORT ObjectDeallocator : public jsi::HostObject {
26 public:
27   typedef std::function<void()> ObjectDeallocatorType;
28 
29   ObjectDeallocator(ObjectDeallocatorType deallocator) : deallocator(deallocator) {};
30 
31   virtual ~ObjectDeallocator() {
32     deallocator();
33   }
34 
35   const ObjectDeallocatorType deallocator;
36 
37 }; // class ObjectDeallocator
38 
// Modified version of the RN implementation
// https://github.com/facebook/react-native/blob/7dceb9b63c0bfd5b13bf6d26f9530729506e9097/ReactCommon/react/nativemodule/core/platform/android/ReactCommon/JavaTurboModule.cpp#L57
//
// Wraps a JS function in a Java `JavaCallback` that Kotlin can invoke with a
// `folly::dynamic` payload.
//
// The JS function is moved into a `react::CallbackWrapper` created in
// `longLivedObjectCollection`. The returned callback keeps a weak pointer to that
// wrapper plus a shared `RAIICallbackWrapperDestroyer` owner. When invoked, it does
// not call JS directly: it schedules the call through the wrapper's
// `jsInvoker().invokeAsync(...)`.
//
// - `isRejectCallback == false`: the payload is converted with `jsi::valueFromDynamic`
//   and passed as the single argument.
// - `isRejectCallback == true`: the payload must convert to an object with string
//   `code` and `message` properties (`asString` throws otherwise). Those are turned
//   into an error via `makeCodedError`, which is passed as the single argument.
//
// If the wrapper is already gone, an invocation is silently dropped. Once a call has
// been scheduled, any further invocation throws std::runtime_error.
//
// Throws std::runtime_error if `longLivedObjectCollection` has already expired.
jni::local_ref<JavaCallback::JavaPart> createJavaCallbackFromJSIFunction(
  jsi::Function &&function,
  std::weak_ptr<react::LongLivedObjectCollection> longLivedObjectCollection,
  jsi::Runtime &rt,
  JSIInteropModuleRegistry *moduleRegistry,
  bool isRejectCallback = false
) {
  std::shared_ptr<react::CallInvoker> jsInvoker = moduleRegistry->runtimeHolder->jsInvoker;
  auto strongLongLiveObjectCollection = longLivedObjectCollection.lock();
  if (!strongLongLiveObjectCollection) {
    throw std::runtime_error("The LongLivedObjectCollection for MethodMetadata is not alive.");
  }
  auto weakWrapper = react::CallbackWrapper::createWeak(strongLongLiveObjectCollection,
                                                        std::move(function), rt,
                                                        std::move(jsInvoker));

  // This needs to be a shared_ptr because:
  // 1. It cannot be unique_ptr. std::function is copyable but unique_ptr is
  // not.
  // 2. It cannot be weak_ptr since we need this object to live on.
  // 3. It cannot be a value, because that would be deleted as soon as this
  // function returns.
  auto callbackWrapperOwner =
    std::make_shared<react::RAIICallbackWrapperDestroyer>(weakWrapper);

  std::function<void(folly::dynamic)> fn =
    [
      weakWrapper,
      callbackWrapperOwner = std::move(callbackWrapperOwner),
      // Set only after a call was successfully scheduled (see the end of this lambda).
      wrapperWasCalled = false,
      isRejectCallback
    ](
      folly::dynamic responses) mutable {
      if (wrapperWasCalled) {
        throw std::runtime_error(
          "callback 2 arg cannot be called more than once");
      }

      auto strongWrapper = weakWrapper.lock();
      if (!strongWrapper) {
        // The JS side is gone; there is nothing to call.
        return;
      }

      strongWrapper->jsInvoker().invokeAsync(
        [
          weakWrapper,
          // Ownership moves into the scheduled task, so the owner stays alive until
          // the JS call below has run (or the task itself is destroyed).
          callbackWrapperOwner = std::move(callbackWrapperOwner),
          responses = std::move(responses),
          isRejectCallback
        ]() mutable {
          auto strongWrapper2 = weakWrapper.lock();
          if (!strongWrapper2) {
            return;
          }

          jsi::Value arg = jsi::valueFromDynamic(strongWrapper2->runtime(), responses);
          if (!isRejectCallback) {
            strongWrapper2->callback().call(
              strongWrapper2->runtime(),
              (const jsi::Value *) &arg,
              (size_t) 1
            );
          } else {
            // Reject path: rebuild a coded JS error from the `code`/`message` fields.
            auto &rt = strongWrapper2->runtime();
            auto jsErrorObject = arg.getObject(rt);
            auto errorCode = jsErrorObject.getProperty(rt, "code").asString(rt);
            auto message = jsErrorObject.getProperty(rt, "message").asString(rt);

            auto codedError = makeCodedError(
              rt,
              std::move(errorCode),
              std::move(message)
            );

            strongWrapper2->callback().call(
              strongWrapper2->runtime(),
              (const jsi::Value *) &codedError,
              (size_t) 1
            );
          }

          // Release the owner right after the one-shot call; presumably this lets the
          // destroyer drop the wrapper from the long-lived collection (verify in RN).
          callbackWrapperOwner.reset();
        });

      wrapperWasCalled = true;
    };

  return JavaCallback::newObjectCxxArgs(std::move(fn));
}
130 
131 jobjectArray MethodMetadata::convertJSIArgsToJNI(
132   JSIInteropModuleRegistry *moduleRegistry,
133   JNIEnv *env,
134   jsi::Runtime &rt,
135   const jsi::Value *args,
136   size_t count
137 ) {
138   auto argumentArray = env->NewObjectArray(
139     count,
140     JavaReferencesCache::instance()->getJClass("java/lang/Object").clazz,
141     nullptr
142   );
143 
144   std::vector<jobject> result(count);
145 
146   for (unsigned int argIndex = 0; argIndex < count; argIndex++) {
147     const jsi::Value &arg = args[argIndex];
148     auto &type = argTypes[argIndex];
149     if (arg.isNull() || arg.isUndefined()) {
150       // If value is null or undefined, we just passes a null
151       // Kotlin code will check if expected type is nullable.
152       result[argIndex] = nullptr;
153     } else {
154       if (type->converter->canConvert(rt, arg)) {
155         auto converterValue = type->converter->convert(rt, env, moduleRegistry, arg);
156         env->SetObjectArrayElement(argumentArray, argIndex, converterValue);
157         env->DeleteLocalRef(converterValue);
158       } else {
159         auto stringRepresentation = arg.toString(rt).utf8(rt);
160         throwNewJavaException(
161           UnexpectedException::create(
162             "Cannot convert '" + stringRepresentation + "' to a Kotlin type.").get()
163         );
164       }
165     }
166   }
167 
168   return argumentArray;
169 }
170 
171 MethodMetadata::MethodMetadata(
172   std::weak_ptr<react::LongLivedObjectCollection> longLivedObjectCollection,
173   std::string name,
174   int args,
175   bool isAsync,
176   jni::local_ref<jni::JArrayClass<ExpectedType>> expectedArgTypes,
177   jni::global_ref<jobject> &&jBodyReference
178 ) : name(std::move(name)),
179     args(args),
180     isAsync(isAsync),
181     jBodyReference(std::move(jBodyReference)),
182     longLivedObjectCollection_(std::move(longLivedObjectCollection)) {
183   argTypes.reserve(args);
184   for (size_t i = 0; i < args; i++) {
185     auto expectedType = expectedArgTypes->getElement(i);
186     argTypes.push_back(
187       std::make_unique<AnyType>(std::move(expectedType))
188     );
189   }
190 }
191 
// Creates metadata from argument types that were already built on the C++ side.
// Takes ownership of `expectedArgTypes` as is; no check is made here that its size
// matches `args` (which is forwarded as the JS function's declared parameter count).
MethodMetadata::MethodMetadata(
  std::weak_ptr<react::LongLivedObjectCollection> longLivedObjectCollection,
  std::string name,
  int args,
  bool isAsync,
  std::vector<std::unique_ptr<AnyType>> &&expectedArgTypes,
  jni::global_ref<jobject> &&jBodyReference
) : name(std::move(name)),
    args(args),
    isAsync(isAsync),
    argTypes(std::move(expectedArgTypes)),
    jBodyReference(std::move(jBodyReference)),
    longLivedObjectCollection_(std::move(longLivedObjectCollection)) {
}
206 
207 std::shared_ptr<jsi::Function> MethodMetadata::toJSFunction(
208   jsi::Runtime &runtime,
209   JSIInteropModuleRegistry *moduleRegistry
210 ) {
211   if (body == nullptr) {
212     if (isAsync) {
213       body = std::make_shared<jsi::Function>(toAsyncFunction(runtime, moduleRegistry));
214     } else {
215       body = std::make_shared<jsi::Function>(toSyncFunction(runtime, moduleRegistry));
216     }
217   }
218 
219   return body;
220 }
221 
222 jsi::Function MethodMetadata::toSyncFunction(
223   jsi::Runtime &runtime,
224   JSIInteropModuleRegistry *moduleRegistry
225 ) {
226   return jsi::Function::createFromHostFunction(
227     runtime,
228     moduleRegistry->jsRegistry->getPropNameID(runtime, name),
229     args,
230     [this, moduleRegistry](
231       jsi::Runtime &rt,
232       const jsi::Value &thisValue,
233       const jsi::Value *args,
234       size_t count
235     ) -> jsi::Value {
236       try {
237         return this->callSync(
238           rt,
239           moduleRegistry,
240           args,
241           count
242         );
243       } catch (jni::JniException &jniException) {
244         rethrowAsCodedError(rt, jniException);
245       }
246     });
247 }
248 
249 jsi::Value MethodMetadata::callSync(
250   jsi::Runtime &rt,
251   JSIInteropModuleRegistry *moduleRegistry,
252   const jsi::Value *args,
253   size_t count
254 ) {
255   if (this->jBodyReference == nullptr) {
256     return jsi::Value::undefined();
257   }
258 
259   JNIEnv *env = jni::Environment::current();
260 
261   /**
262    * This will push a new JNI stack frame for the LocalReferences in this
263    * function call. When the stack frame for this lambda is popped,
264    * all LocalReferences are deleted.
265    */
266   jni::JniLocalScope scope(env, (int) count);
267 
268   auto convertedArgs = convertJSIArgsToJNI(moduleRegistry, env, rt, args, count);
269 
270   // Cast in this place is safe, cause we know that this function is promise-less.
271   auto syncFunction = jni::static_ref_cast<JNIFunctionBody>(this->jBodyReference);
272   auto result = syncFunction->invoke(
273     convertedArgs
274   );
275 
276   env->DeleteLocalRef(convertedArgs);
277   if (result == nullptr) {
278     return jsi::Value::undefined();
279   }
280   auto unpackedResult = result.get();
281   auto cache = JavaReferencesCache::instance();
282   if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Double").clazz)) {
283     return {jni::static_ref_cast<jni::JDouble>(result)->value()};
284   }
285   if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Integer").clazz)) {
286     return {jni::static_ref_cast<jni::JInteger>(result)->value()};
287   }
288   if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Long").clazz)) {
289     return {(double) jni::static_ref_cast<jni::JLong>(result)->value()};
290   }
291   if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/String").clazz)) {
292     return jsi::String::createFromUtf8(
293       rt,
294       jni::static_ref_cast<jni::JString>(result)->toStdString()
295     );
296   }
297   if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Boolean").clazz)) {
298     return {(bool) jni::static_ref_cast<jni::JBoolean>(result)->value()};
299   }
300   if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Float").clazz)) {
301     return {(double) jni::static_ref_cast<jni::JFloat>(result)->value()};
302   }
303   if (env->IsInstanceOf(
304     unpackedResult,
305     cache->getJClass("com/facebook/react/bridge/WritableNativeArray").clazz
306   )) {
307     auto dynamic = jni::static_ref_cast<react::WritableNativeArray::javaobject>(result)
308       ->cthis()
309       ->consume();
310     return jsi::valueFromDynamic(rt, dynamic);
311   }
312   if (env->IsInstanceOf(
313     unpackedResult,
314     cache->getJClass("com/facebook/react/bridge/WritableNativeMap").clazz
315   )) {
316     auto dynamic = jni::static_ref_cast<react::WritableNativeMap::javaobject>(result)
317       ->cthis()
318       ->consume();
319     return jsi::valueFromDynamic(rt, dynamic);
320   }
321   if (env->IsInstanceOf(unpackedResult, JavaScriptModuleObject::javaClassStatic().get())) {
322     auto anonymousObject = jni::static_ref_cast<JavaScriptModuleObject::javaobject>(result)
323       ->cthis();
324     anonymousObject->jsiInteropModuleRegistry = moduleRegistry;
325     auto jsiObject = anonymousObject->getJSIObject(rt);
326 
327     jni::global_ref<jobject> globalRef = jni::make_global(result);
328     std::shared_ptr<expo::ObjectDeallocator> deallocator = std::make_shared<ObjectDeallocator>(
329       [globalRef = globalRef]() mutable {
330         globalRef.reset();
331       });
332 
333     auto descriptor = JavaScriptObject::preparePropertyDescriptor(rt, 0);
334     descriptor.setProperty(rt, "value", jsi::Object::createFromHostObject(rt, deallocator));
335     JavaScriptObject::defineProperty(rt, jsiObject, "__expo_object_deallocator__", std::move(descriptor));
336 
337     return jsi::Value(rt, *jsiObject);
338   }
339 
340   return jsi::Value::undefined();
341 }
342 
// Wraps this async method in a JSI host function (named after the method, declaring
// `args` parameters) that always returns a JS Promise.
//
// On success, the JS arguments are converted to a Java array, promoted to a global
// reference and handed to the executor built by `createPromiseBody`; that executor is
// responsible for deleting the global reference.
//
// If argument conversion or promise creation throws a `jni::JniException`, the
// host function does not throw. Instead it returns a promise rejected with a coded error
// built from the exception's code and localized message (empty if absent). A throwable
// that is not a CodedException is first wrapped in an UnexpectedException carrying the
// C++ exception's `what()` text.
jsi::Function MethodMetadata::toAsyncFunction(
  jsi::Runtime &runtime,
  JSIInteropModuleRegistry *moduleRegistry
) {
  return jsi::Function::createFromHostFunction(
    runtime,
    moduleRegistry->jsRegistry->getPropNameID(runtime, name),
    args,
    [this, moduleRegistry](
      jsi::Runtime &rt,
      const jsi::Value &thisValue,
      const jsi::Value *args,
      size_t count
    ) -> jsi::Value {
      JNIEnv *env = jni::Environment::current();

      /**
       * This will push a new JNI stack frame for the LocalReferences in this
       * function call. When the stack frame for this lambda is popped,
       * all LocalReferences are deleted.
       */
      jni::JniLocalScope scope(env, (int) count);

      // The JS `Promise` constructor, looked up from the JS references cache.
      auto &Promise = moduleRegistry->jsRegistry->getObject<jsi::Function>(
        JSReferencesCache::JSKeys::PROMISE
      );

      try {
        auto convertedArgs = convertJSIArgsToJNI(moduleRegistry, env, rt, args, count);
        // Promote the array to a global reference; the local one is deleted immediately.
        auto globalConvertedArgs = (jobjectArray) env->NewGlobalRef(convertedArgs);
        env->DeleteLocalRef(convertedArgs);

        // Creates a JSI promise
        jsi::Value promise = Promise.callAsConstructor(
          rt,
          createPromiseBody(rt, moduleRegistry, globalConvertedArgs)
        );
        return promise;
      } catch (jni::JniException &jniException) {
        // Normalize any Java throwable into a CodedException so it has a code.
        jni::local_ref<jni::JThrowable> unboxedThrowable = jniException.getThrowable();
        if (!unboxedThrowable->isInstanceOf(CodedException::javaClassLocal())) {
          unboxedThrowable = UnexpectedException::create(jniException.what());
        }

        auto codedException = jni::static_ref_cast<CodedException>(unboxedThrowable);
        auto code = codedException->getCode();
        auto message = codedException->getLocalizedMessage().value_or("");

        // Return an already-rejected promise instead of throwing into JS.
        jsi::Value promise = Promise.callAsConstructor(
          rt,
          jsi::Function::createFromHostFunction(
            rt,
            moduleRegistry->jsRegistry->getPropNameID(rt, "promiseFn"),
            2,
            [code, message](
              jsi::Runtime &rt,
              const jsi::Value &thisVal,
              const jsi::Value *promiseConstructorArgs,
              size_t promiseConstructorArgCount
            ) {
              if (promiseConstructorArgCount != 2) {
                throw std::invalid_argument("Promise fn arg count must be 2");
              }

              // Executor args are (resolve, reject); only reject is used here.
              jsi::Function rejectJSIFn = promiseConstructorArgs[1].getObject(rt).getFunction(rt);
              rejectJSIFn.call(
                rt,
                makeCodedError(
                  rt,
                  jsi::String::createFromUtf8(rt, code),
                  jsi::String::createFromUtf8(rt, message)
                )
              );
              return jsi::Value::undefined();
            }
          )
        );

        return promise;
      }
    }
  );
}
426 
427 jsi::Function MethodMetadata::createPromiseBody(
428   jsi::Runtime &runtime,
429   JSIInteropModuleRegistry *moduleRegistry,
430   jobjectArray globalArgs
431 ) {
432   return jsi::Function::createFromHostFunction(
433     runtime,
434     moduleRegistry->jsRegistry->getPropNameID(runtime, "promiseFn"),
435     2,
436     [this, globalArgs, moduleRegistry](
437       jsi::Runtime &rt,
438       const jsi::Value &thisVal,
439       const jsi::Value *promiseConstructorArgs,
440       size_t promiseConstructorArgCount
441     ) {
442       if (promiseConstructorArgCount != 2) {
443         throw std::invalid_argument("Promise fn arg count must be 2");
444       }
445 
446       jsi::Function resolveJSIFn = promiseConstructorArgs[0].getObject(rt).getFunction(rt);
447       jsi::Function rejectJSIFn = promiseConstructorArgs[1].getObject(rt).getFunction(rt);
448 
449       jobject resolve = createJavaCallbackFromJSIFunction(
450         std::move(resolveJSIFn),
451         longLivedObjectCollection_,
452         rt,
453         moduleRegistry
454       ).release();
455 
456       jobject reject = createJavaCallbackFromJSIFunction(
457         std::move(rejectJSIFn),
458         longLivedObjectCollection_,
459         rt,
460         moduleRegistry,
461         true
462       ).release();
463 
464       JNIEnv *env = jni::Environment::current();
465 
466       auto &jPromise = JavaReferencesCache::instance()->getJClass(
467         "expo/modules/kotlin/jni/PromiseImpl");
468       jmethodID jPromiseConstructor = jPromise.getMethod(
469         "<init>",
470         "(Lexpo/modules/kotlin/jni/JavaCallback;Lexpo/modules/kotlin/jni/JavaCallback;)V"
471       );
472 
473       // Creates a promise object
474       jobject promise = env->NewObject(
475         jPromise.clazz,
476         jPromiseConstructor,
477         resolve,
478         reject
479       );
480 
481       // Cast in this place is safe, cause we know that this function expects promise.
482       auto asyncFunction = jni::static_ref_cast<JNIAsyncFunctionBody>(this->jBodyReference);
483       asyncFunction->invoke(
484         globalArgs,
485         promise
486       );
487 
488       // We have to remove the local reference to the promise object.
489       // It doesn't mean that the promise will be deallocated, but rather that we move
490       // the ownership to the `JNIAsyncFunctionBody`.
491       env->DeleteLocalRef(promise);
492       env->DeleteGlobalRef(globalArgs);
493 
494       return jsi::Value::undefined();
495     }
496   );
497 }
498 } // namespace expo
499