1 #include "MethodMetadata.h"
2 #include "JSIInteropModuleRegistry.h"
3 #include "JavaScriptValue.h"
4 #include "JavaScriptObject.h"
5 #include "JavaScriptTypedArray.h"
6 #include "JavaReferencesCache.h"
7 #include "Exceptions.h"
8 #include "JavaCallback.h"
9 
10 #include <utility>
11 
12 #include <react/jni/ReadableNativeMap.h>
13 #include <react/jni/ReadableNativeArray.h>
14 #include <react/jni/WritableNativeArray.h>
15 #include <react/jni/WritableNativeMap.h>
16 #include "JSReferencesCache.h"
17 
18 namespace jni = facebook::jni;
19 namespace jsi = facebook::jsi;
20 namespace react = facebook::react;
21 
22 namespace expo {
23 
24 // Modified version of the RN implementation
25 // https://github.com/facebook/react-native/blob/7dceb9b63c0bfd5b13bf6d26f9530729506e9097/ReactCommon/react/nativemodule/core/platform/android/ReactCommon/JavaTurboModule.cpp#L57
/**
 * Wraps a JS function so that it can be invoked from Java as a `JavaCallback`.
 *
 * The JS function is stored in a weak `react::CallbackWrapper` registered in
 * `longLivedObjectCollection`. Invoking the returned callback does not call the
 * JS function directly; it schedules the call through the runtime's
 * `CallInvoker` (`invokeAsync`).
 *
 * @param function the JS function to wrap (moved into the wrapper).
 * @param longLivedObjectCollection collection the wrapper is registered in.
 * @param rt the JS runtime the function belongs to.
 * @param moduleRegistry provides the `jsInvoker` used to schedule the JS call.
 * @param isRejectCallback when true, the value passed from Java is read as an
 *        object with `code` and `message` properties and turned into a coded
 *        error (via `makeCodedError`) before being passed to the JS function.
 * @return a new Java-side `JavaCallback` that owns the C++ callback.
 * @throws std::runtime_error if `longLivedObjectCollection` has already expired.
 */
jni::local_ref<JavaCallback::JavaPart> createJavaCallbackFromJSIFunction(
  jsi::Function &&function,
  std::weak_ptr<react::LongLivedObjectCollection> longLivedObjectCollection,
  jsi::Runtime &rt,
  JSIInteropModuleRegistry *moduleRegistry,
  bool isRejectCallback = false
) {
  std::shared_ptr<react::CallInvoker> jsInvoker = moduleRegistry->runtimeHolder->jsInvoker;
  auto strongLongLiveObjectCollection = longLivedObjectCollection.lock();
  if (!strongLongLiveObjectCollection) {
    throw std::runtime_error("The LongLivedObjectCollection for MethodMetadata is not alive.");
  }
  auto weakWrapper = react::CallbackWrapper::createWeak(strongLongLiveObjectCollection,
                                                        std::move(function), rt,
                                                        std::move(jsInvoker));

  // This needs to be a shared_ptr because:
  // 1. It cannot be unique_ptr. std::function is copyable but unique_ptr is
  // not.
  // 2. It cannot be weak_ptr since we need this object to live on.
  // 3. It cannot be a value, because that would be deleted as soon as this
  // function returns.
  auto callbackWrapperOwner =
    std::make_shared<react::RAIICallbackWrapperDestroyer>(weakWrapper);

  std::function<void(folly::dynamic)> fn =
    [
      weakWrapper,
      callbackWrapperOwner = std::move(callbackWrapperOwner),
      wrapperWasCalled = false,
      isRejectCallback
    ](
      folly::dynamic responses) mutable {
      // The callback may be invoked at most once; a second invocation throws.
      if (wrapperWasCalled) {
        throw std::runtime_error(
          "callback 2 arg cannot be called more than once");
      }

      // The wrapper has already been released: the call is silently dropped.
      // Note that `wrapperWasCalled` is not set on this path.
      auto strongWrapper = weakWrapper.lock();
      if (!strongWrapper) {
        return;
      }

      // The JS call runs later on the invoker. `callbackWrapperOwner` moves into
      // the scheduled task, so it lives at least until that task has run or
      // has been destroyed.
      strongWrapper->jsInvoker().invokeAsync(
        [
          weakWrapper,
          callbackWrapperOwner = std::move(callbackWrapperOwner),
          responses = std::move(responses),
          isRejectCallback
        ]() mutable {
          // Re-check: the wrapper may have been released before the task ran.
          auto strongWrapper2 = weakWrapper.lock();
          if (!strongWrapper2) {
            return;
          }

          jsi::Value arg = jsi::valueFromDynamic(strongWrapper2->runtime(), responses);
          if (!isRejectCallback) {
            strongWrapper2->callback().call(
              strongWrapper2->runtime(),
              (const jsi::Value *) &arg,
              (size_t) 1
            );
          } else {
            // Assumes `arg` is an object with string `code` and `message`
            // properties (`asString` throws otherwise).
            // NOTE(review): confirm the Kotlin side always sends both.
            auto &rt = strongWrapper2->runtime();
            auto jsErrorObject = arg.getObject(rt);
            auto errorCode = jsErrorObject.getProperty(rt, "code").asString(rt);
            auto message = jsErrorObject.getProperty(rt, "message").asString(rt);

            auto codedError = makeCodedError(
              rt,
              std::move(errorCode),
              std::move(message)
            );

            strongWrapper2->callback().call(
              strongWrapper2->runtime(),
              (const jsi::Value *) &codedError,
              (size_t) 1
            );
          }

          // Drop the destroyer explicitly once the JS callback has been called.
          callbackWrapperOwner.reset();
        });

      wrapperWasCalled = true;
    };

  return JavaCallback::newObjectCxxArgs(std::move(fn));
}
115 
116 jobjectArray MethodMetadata::convertJSIArgsToJNI(
117   JSIInteropModuleRegistry *moduleRegistry,
118   JNIEnv *env,
119   jsi::Runtime &rt,
120   const jsi::Value *args,
121   size_t count
122 ) {
123   auto argumentArray = env->NewObjectArray(
124     count,
125     JavaReferencesCache::instance()->getJClass("java/lang/Object").clazz,
126     nullptr
127   );
128 
129   std::vector<jobject> result(count);
130 
131   for (unsigned int argIndex = 0; argIndex < count; argIndex++) {
132     const jsi::Value &arg = args[argIndex];
133     auto &type = argTypes[argIndex];
134     if (arg.isNull() || arg.isUndefined()) {
135       // If value is null or undefined, we just passes a null
136       // Kotlin code will check if expected type is nullable.
137       result[argIndex] = nullptr;
138     } else {
139       if (type->converter->canConvert(rt, arg)) {
140         auto converterValue = type->converter->convert(rt, env, moduleRegistry, arg);
141         env->SetObjectArrayElement(argumentArray, argIndex, converterValue);
142         env->DeleteLocalRef(converterValue);
143       } else {
144         auto stringRepresentation = arg.toString(rt).utf8(rt);
145         throwNewJavaException(
146           UnexpectedException::create(
147             "Cannot convert '" + stringRepresentation + "' to a Kotlin type.").get()
148         );
149       }
150     }
151   }
152 
153   return argumentArray;
154 }
155 
156 MethodMetadata::MethodMetadata(
157   std::weak_ptr<react::LongLivedObjectCollection> longLivedObjectCollection,
158   std::string name,
159   int args,
160   bool isAsync,
161   jni::local_ref<jni::JArrayClass<ExpectedType>> expectedArgTypes,
162   jni::global_ref<jobject> &&jBodyReference
163 ) : name(std::move(name)),
164     args(args),
165     isAsync(isAsync),
166     jBodyReference(std::move(jBodyReference)),
167     longLivedObjectCollection_(std::move(longLivedObjectCollection)) {
168   argTypes.reserve(args);
169   for (size_t i = 0; i < args; i++) {
170     auto expectedType = expectedArgTypes->getElement(i);
171     argTypes.push_back(
172       std::make_unique<AnyType>(std::move(expectedType))
173     );
174   }
175 }
176 
177 MethodMetadata::MethodMetadata(
178   std::weak_ptr<react::LongLivedObjectCollection> longLivedObjectCollection,
179   std::string name,
180   int args,
181   bool isAsync,
182   std::vector<std::unique_ptr<AnyType>> &&expectedArgTypes,
183   jni::global_ref<jobject> &&jBodyReference
184 ) : name(std::move(name)),
185     args(args),
186     isAsync(isAsync),
187     argTypes(std::move(expectedArgTypes)),
188     jBodyReference(std::move(jBodyReference)),
189     longLivedObjectCollection_(std::move(longLivedObjectCollection)) {
190 }
191 
192 std::shared_ptr<jsi::Function> MethodMetadata::toJSFunction(
193   jsi::Runtime &runtime,
194   JSIInteropModuleRegistry *moduleRegistry
195 ) {
196   if (body == nullptr) {
197     if (isAsync) {
198       body = std::make_shared<jsi::Function>(toAsyncFunction(runtime, moduleRegistry));
199     } else {
200       body = std::make_shared<jsi::Function>(toSyncFunction(runtime, moduleRegistry));
201     }
202   }
203 
204   return body;
205 }
206 
207 jsi::Function MethodMetadata::toSyncFunction(
208   jsi::Runtime &runtime,
209   JSIInteropModuleRegistry *moduleRegistry
210 ) {
211   return jsi::Function::createFromHostFunction(
212     runtime,
213     moduleRegistry->jsRegistry->getPropNameID(runtime, name),
214     args,
215     [this, moduleRegistry](
216       jsi::Runtime &rt,
217       const jsi::Value &thisValue,
218       const jsi::Value *args,
219       size_t count
220     ) -> jsi::Value {
221       try {
222         return this->callSync(
223           rt,
224           moduleRegistry,
225           args,
226           count
227         );
228       } catch (jni::JniException &jniException) {
229         rethrowAsCodedError(rt, jniException);
230       }
231     });
232 }
233 
/**
 * Invokes the Kotlin body synchronously and converts its result to a JS value.
 *
 * @param rt the JS runtime used to build the returned value.
 * @param moduleRegistry passed to argument conversion and attached to returned
 *        `JavaScriptModuleObject`s.
 * @param args pointer to `count` JS argument values.
 * @param count number of JS arguments.
 * @return the converted result. `undefined` is returned when there is no body,
 *         when the body returns `null`, or when the result's runtime class is
 *         not one of those handled below.
 *
 * Java exceptions thrown during argument conversion or by the body propagate
 * as `jni::JniException`.
 */
jsi::Value MethodMetadata::callSync(
  jsi::Runtime &rt,
  JSIInteropModuleRegistry *moduleRegistry,
  const jsi::Value *args,
  size_t count
) {
  if (this->jBodyReference == nullptr) {
    return jsi::Value::undefined();
  }

  JNIEnv *env = jni::Environment::current();

  /**
   * This will push a new JNI stack frame for the LocalReferences in this
   * function call. When the stack frame for this lambda is popped,
   * all LocalReferences are deleted.
   */
  jni::JniLocalScope scope(env, (int) count);

  auto convertedArgs = convertJSIArgsToJNI(moduleRegistry, env, rt, args, count);

  // Cast in this place is safe, cause we know that this function is promise-less.
  auto syncFunction = jni::static_ref_cast<JNIFunctionBody>(this->jBodyReference);
  auto result = syncFunction->invoke(
    convertedArgs
  );

  env->DeleteLocalRef(convertedArgs);
  if (result == nullptr) {
    return jsi::Value::undefined();
  }
  // Dispatch on the runtime class of the returned object.
  auto unpackedResult = result.get();
  auto cache = JavaReferencesCache::instance();
  if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Double").clazz)) {
    return {jni::static_ref_cast<jni::JDouble>(result)->value()};
  }
  if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Integer").clazz)) {
    return {jni::static_ref_cast<jni::JInteger>(result)->value()};
  }
  // JS numbers are doubles: Long values beyond 2^53 lose precision here.
  if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Long").clazz)) {
    return {(double) jni::static_ref_cast<jni::JLong>(result)->value()};
  }
  if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/String").clazz)) {
    return jsi::String::createFromUtf8(
      rt,
      jni::static_ref_cast<jni::JString>(result)->toStdString()
    );
  }
  if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Boolean").clazz)) {
    return {(bool) jni::static_ref_cast<jni::JBoolean>(result)->value()};
  }
  if (env->IsInstanceOf(unpackedResult, cache->getJClass("java/lang/Float").clazz)) {
    return {(double) jni::static_ref_cast<jni::JFloat>(result)->value()};
  }
  // Writable native collections are consumed into folly::dynamic and then
  // converted to plain JS arrays/objects.
  if (env->IsInstanceOf(
    unpackedResult,
    cache->getJClass("com/facebook/react/bridge/WritableNativeArray").clazz
  )) {
    auto dynamic = jni::static_ref_cast<react::WritableNativeArray::javaobject>(result)
      ->cthis()
      ->consume();
    return jsi::valueFromDynamic(rt, dynamic);
  }
  if (env->IsInstanceOf(
    unpackedResult,
    cache->getJClass("com/facebook/react/bridge/WritableNativeMap").clazz
  )) {
    auto dynamic = jni::static_ref_cast<react::WritableNativeMap::javaobject>(result)
      ->cthis()
      ->consume();
    return jsi::valueFromDynamic(rt, dynamic);
  }
  // A returned module object is exposed as a JSI host object. The host object
  // keeps a global reference to the Java object (`jObjectRef`).
  if (env->IsInstanceOf(unpackedResult, JavaScriptModuleObject::javaClassStatic().get())) {
    auto anonymousObject = jni::static_ref_cast<JavaScriptModuleObject::javaobject>(result)
      ->cthis();
    anonymousObject->jsiInteropModuleRegistry = moduleRegistry;
    auto hostObject = std::make_shared<JavaScriptModuleObject::HostObject>(anonymousObject);
    hostObject->jObjectRef = jni::make_global(result);
    return jsi::Object::createFromHostObject(rt, hostObject);
  }

  return jsi::Value::undefined();
}
317 
/**
 * Creates the JS host function for an asynchronous method.
 *
 * Each call returns a JS `Promise`. On the normal path the JS arguments are
 * converted to a Java array. The array is promoted to a global reference and
 * handed to the promise executor built by `createPromiseBody`.
 *
 * If a `jni::JniException` escapes the try block (for example from argument
 * conversion), the call instead returns a promise that is rejected with a
 * coded error. The error's code and message come from the Java
 * `CodedException`; any other throwable is first wrapped in an
 * `UnexpectedException`.
 */
jsi::Function MethodMetadata::toAsyncFunction(
  jsi::Runtime &runtime,
  JSIInteropModuleRegistry *moduleRegistry
) {
  return jsi::Function::createFromHostFunction(
    runtime,
    moduleRegistry->jsRegistry->getPropNameID(runtime, name),
    args,
    [this, moduleRegistry](
      jsi::Runtime &rt,
      const jsi::Value &thisValue,
      const jsi::Value *args,
      size_t count
    ) -> jsi::Value {
      JNIEnv *env = jni::Environment::current();

      /**
       * This will push a new JNI stack frame for the LocalReferences in this
       * function call. When the stack frame for this lambda is popped,
       * all LocalReferences are deleted.
       */
      jni::JniLocalScope scope(env, (int) count);

      // The JS `Promise` constructor, looked up from the JS references cache.
      auto &Promise = moduleRegistry->jsRegistry->getObject<jsi::Function>(
        JSReferencesCache::JSKeys::PROMISE
      );

      try {
        auto convertedArgs = convertJSIArgsToJNI(moduleRegistry, env, rt, args, count);
        // A global reference is needed because the arguments are used outside
        // this local JNI frame, inside the promise executor.
        // NOTE(review): if `Promise.callAsConstructor` or `createPromiseBody`
        // throws before the executor runs, this global reference is never
        // deleted.
        auto globalConvertedArgs = (jobjectArray) env->NewGlobalRef(convertedArgs);
        env->DeleteLocalRef(convertedArgs);

        // Creates a JSI promise
        jsi::Value promise = Promise.callAsConstructor(
          rt,
          createPromiseBody(rt, moduleRegistry, globalConvertedArgs)
        );
        return promise;
      } catch (jni::JniException &jniException) {
        // Normalize the Java throwable to a CodedException so that a code and
        // message are always available.
        jni::local_ref<jni::JThrowable> unboxedThrowable = jniException.getThrowable();
        if (!unboxedThrowable->isInstanceOf(CodedException::javaClassLocal())) {
          unboxedThrowable = UnexpectedException::create(jniException.what());
        }

        auto codedException = jni::static_ref_cast<CodedException>(unboxedThrowable);
        auto code = codedException->getCode();
        auto message = codedException->getLocalizedMessage().value_or("");

        // Return a promise whose executor immediately calls `reject`.
        jsi::Value promise = Promise.callAsConstructor(
          rt,
          jsi::Function::createFromHostFunction(
            rt,
            moduleRegistry->jsRegistry->getPropNameID(rt, "promiseFn"),
            2,
            [code, message](
              jsi::Runtime &rt,
              const jsi::Value &thisVal,
              const jsi::Value *promiseConstructorArgs,
              size_t promiseConstructorArgCount
            ) {
              if (promiseConstructorArgCount != 2) {
                throw std::invalid_argument("Promise fn arg count must be 2");
              }

              jsi::Function rejectJSIFn = promiseConstructorArgs[1].getObject(rt).getFunction(rt);
              rejectJSIFn.call(
                rt,
                makeCodedError(
                  rt,
                  jsi::String::createFromUtf8(rt, code),
                  jsi::String::createFromUtf8(rt, message)
                )
              );
              return jsi::Value::undefined();
            }
          )
        );

        return promise;
      }
    }
  );
}
401 
402 jsi::Function MethodMetadata::createPromiseBody(
403   jsi::Runtime &runtime,
404   JSIInteropModuleRegistry *moduleRegistry,
405   jobjectArray globalArgs
406 ) {
407   return jsi::Function::createFromHostFunction(
408     runtime,
409     moduleRegistry->jsRegistry->getPropNameID(runtime, "promiseFn"),
410     2,
411     [this, globalArgs, moduleRegistry](
412       jsi::Runtime &rt,
413       const jsi::Value &thisVal,
414       const jsi::Value *promiseConstructorArgs,
415       size_t promiseConstructorArgCount
416     ) {
417       if (promiseConstructorArgCount != 2) {
418         throw std::invalid_argument("Promise fn arg count must be 2");
419       }
420 
421       jsi::Function resolveJSIFn = promiseConstructorArgs[0].getObject(rt).getFunction(rt);
422       jsi::Function rejectJSIFn = promiseConstructorArgs[1].getObject(rt).getFunction(rt);
423 
424       jobject resolve = createJavaCallbackFromJSIFunction(
425         std::move(resolveJSIFn),
426         longLivedObjectCollection_,
427         rt,
428         moduleRegistry
429       ).release();
430 
431       jobject reject = createJavaCallbackFromJSIFunction(
432         std::move(rejectJSIFn),
433         longLivedObjectCollection_,
434         rt,
435         moduleRegistry,
436         true
437       ).release();
438 
439       JNIEnv *env = jni::Environment::current();
440 
441       auto &jPromise = JavaReferencesCache::instance()->getJClass(
442         "expo/modules/kotlin/jni/PromiseImpl");
443       jmethodID jPromiseConstructor = jPromise.getMethod(
444         "<init>",
445         "(Lexpo/modules/kotlin/jni/JavaCallback;Lexpo/modules/kotlin/jni/JavaCallback;)V"
446       );
447 
448       // Creates a promise object
449       jobject promise = env->NewObject(
450         jPromise.clazz,
451         jPromiseConstructor,
452         resolve,
453         reject
454       );
455 
456       // Cast in this place is safe, cause we know that this function expects promise.
457       auto asyncFunction = jni::static_ref_cast<JNIAsyncFunctionBody>(this->jBodyReference);
458       asyncFunction->invoke(
459         globalArgs,
460         promise
461       );
462 
463       // We have to remove the local reference to the promise object.
464       // It doesn't mean that the promise will be deallocated, but rather that we move
465       // the ownership to the `JNIAsyncFunctionBody`.
466       env->DeleteLocalRef(promise);
467       env->DeleteGlobalRef(globalArgs);
468 
469       return jsi::Value::undefined();
470     }
471   );
472 }
473 } // namespace expo
474