<lambda>null1 package host.exp.exponent.utils
2 
3 import host.exp.exponent.di.NativeModuleDepsProvider
4 import org.mockito.Matchers
5 import org.mockito.Mockito
6 import java.lang.RuntimeException
7 import java.lang.reflect.Field
8 import javax.inject.Inject
9 
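/*
 * Test replacement for NativeModuleDepsProvider: initialize() installs a mocked provider whose
 * inject() fills @Inject fields from instances registered with addMock().
 */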
object MockExpoDI {
  // Name fragments that identify Mockito-generated mock subclasses.
  // ENHANCER: the CGLIB-based mock maker.
  // BYTE_BUDDY_MOCK: assumed to be the ByteBuddy subclass mock maker's naming scheme
  // (`Foo$MockitoMock$<n>`) — verify against the Mockito version in use.
  private const val ENHANCER = "$\$EnhancerByMockitoWithCGLIB$$"
  private const val BYTE_BUDDY_MOCK = "\$MockitoMock$"

  // Registered instances, keyed by the class they were registered under (see typeOf).
  private val classesToInstances = mutableMapOf<Class<*>, Any>()

  /**
   * Returns the class [instance] is registered under.
   *
   * Use this instead of `.javaClass`, because Mockito wraps our classes in generated subclasses.
   * Those subclasses are skipped by walking up the superclass chain. The walk stops early if a
   * class has no superclass.
   */
  private fun typeOf(instance: Any): Class<*> {
    var type: Class<*> = instance.javaClass
    while (isMockitoGenerated(type)) {
      type = type.superclass ?: break
    }
    return type
  }

  /** True when [type]'s binary name contains one of the Mockito generated-class markers. */
  private fun isMockitoGenerated(type: Class<*>): Boolean {
    val name = type.name
    return name.contains(ENHANCER) || name.contains(BYTE_BUDDY_MOCK)
  }

  /** Forgets every instance registered with [addMock]. */
  @JvmStatic fun clearMocks() {
    classesToInstances.clear()
  }

  /**
   * Registers [instances] for injection.
   *
   * Each instance is keyed by its class, with Mockito-generated subclasses unwrapped. A later
   * instance of the same class replaces an earlier one.
   */
  @JvmStatic fun addMock(vararg instances: Any) {
    instances.forEach { instance -> classesToInstances[typeOf(instance)] = instance }
  }

  /**
   * Creates a Mockito mock of [NativeModuleDepsProvider] and installs it via
   * `NativeModuleDepsProvider.setTestInstance`.
   *
   * The mock answers `inject(clazz, target)` by filling `target`'s `@Inject` fields from the
   * registered instances.
   */
  @JvmStatic fun initialize() {
    val mockInstance = Mockito.mock(NativeModuleDepsProvider::class.java)
    Mockito.doAnswer { invocation ->
      val args = invocation.arguments
      inject(args[0] as Class<*>, args[1])
      null
    }.`when`(mockInstance).inject(Matchers.any(Class::class.java), Matchers.any())
    NativeModuleDepsProvider.setTestInstance(mockInstance)
  }

  /**
   * Injects into [target] every `@Inject` field declared directly on [clazz].
   *
   * Only `clazz.declaredFields` is visited, so superclass fields are not injected.
   */
  private fun inject(clazz: Class<*>, target: Any) {
    clazz.declaredFields.forEach { field -> injectFieldInTarget(target, field) }
  }

  /**
   * If [field] is annotated with `@Inject`, sets it on [target] to the instance registered for
   * the field's declared type. Fields without `@Inject` are left untouched.
   *
   * @throws RuntimeException if no instance is registered for the field's type, or wrapping an
   *   [IllegalAccessException] raised by the reflective write.
   */
  private fun injectFieldInTarget(target: Any, field: Field) {
    if (!field.isAnnotationPresent(Inject::class.java)) {
      return
    }
    val fieldClazz = field.type
    // Map values are non-null, so null here can only mean "nothing registered for this type".
    val fieldObject = classesToInstances[fieldClazz]
      ?: throw RuntimeException("Mocked NativeModuleDepsProvider could not find object for class $fieldClazz")
    try {
      field.isAccessible = true
      field[target] = fieldObject
    } catch (e: IllegalAccessException) {
      throw RuntimeException(e)
    }
  }
}
69