1 import ExpoModulesCore
2 import sqlite3
3 
public final class SQLiteModule: Module {
  /// Open connections keyed by the database name passed from JS.
  private var cachedDatabases = [String: OpaquePointer]()
  /// Set while JS observes this module's events; gates `onDatabaseChange` emission in the update hook.
  private var hasListeners = false
  /// Opaque context handed to `sqlite3_update_hook` so the C callback can reach this module.
  /// NOTE(review): `passRetained` is never balanced by a release, so once a database has been opened
  /// the module stays alive for the rest of the process — confirm this is intended.
  private lazy var selfPointer = Unmanaged.passRetained(self).toOpaque()

  /// SQLite's `SQLITE_TRANSIENT` destructor sentinel (the C macro is not importable into Swift).
  /// Tells SQLite to copy bound text immediately.
  private static let sqliteTransient = unsafeBitCast(OpaquePointer(bitPattern: -1), to: sqlite3_destructor_type.self)

  public func definition() -> ModuleDefinition {
    Name("ExpoSQLite")

    Events("onDatabaseChange")

    OnCreate {
      crsqlite_init_from_swift()
    }

    // Runs each `[sql, args]` pair in order on `dbName` and returns one result array per query
    // (layout documented on `executeSql`).
    AsyncFunction("exec") { (dbName: String, queries: [[Any]], readOnly: Bool) -> [Any?] in
      guard let db = openDatabase(dbName: dbName) else {
        throw DatabaseException()
      }

      return try queries.map { query -> [Any?] in
        // Bounds-checked access: a malformed query tuple must throw, not trap on `query[i]`.
        guard let sql = query.first as? String else {
          throw InvalidSqlException()
        }
        guard query.count > 1, let args = query[1] as? [Any] else {
          throw InvalidArgumentsException()
        }
        return executeSql(sql: sql, with: args, for: db, readOnly: readOnly)
      }
    }

    AsyncFunction("close") { (dbName: String) in
      closeDatabase(named: dbName)
    }

    Function("closeSync") { (dbName: String) in
      closeDatabase(named: dbName)
    }

    // Deletes the database file; refuses while a connection to it is still open.
    AsyncFunction("deleteAsync") { (dbName: String) in
      if cachedDatabases[dbName] != nil {
        throw DeleteDatabaseException(dbName)
      }

      guard let path = self.pathForDatabaseName(name: dbName) else {
        throw Exceptions.FileSystemModuleNotFound()
      }

      if !FileManager.default.fileExists(atPath: path.absoluteString) {
        throw DatabaseNotFoundException(dbName)
      }

      do {
        try FileManager.default.removeItem(atPath: path.absoluteString)
      } catch {
        throw DeleteDatabaseFileException(dbName)
      }
    }

    OnStartObserving {
      hasListeners = true
    }

    OnStopObserving {
      hasListeners = false
    }

    OnDestroy {
      for db in cachedDatabases.values {
        closeConnection(db)
      }
      cachedDatabases.removeAll()
    }
  }

  /// Builds `<documentDirectory>/SQLite/<name>`, first asking the file system module to ensure the
  /// `SQLite` directory exists. Returns `nil` when the file system module is unavailable.
  private func pathForDatabaseName(name: String) -> URL? {
    guard let fileSystem = appContext?.fileSystem else {
      return nil
    }

    let directory = URL(string: fileSystem.documentDirectory)?.appendingPathComponent("SQLite")
    fileSystem.ensureDirExists(withPath: directory?.absoluteString)

    return directory?.appendingPathComponent(name)
  }

  /// Returns the cached connection for `dbName`, opening and caching a new one when needed.
  ///
  /// A cached handle is reused only while its file still exists. Otherwise the stale handle is
  /// closed and a fresh connection is opened (which creates the file). Returns `nil` when the file
  /// system module is unavailable or `sqlite3_open` fails.
  private func openDatabase(dbName: String) -> OpaquePointer? {
    guard let path = pathForDatabaseName(name: dbName) else {
      return nil
    }

    if let cached = cachedDatabases[dbName] {
      if FileManager.default.fileExists(atPath: path.absoluteString) {
        return cached
      }
      // The file disappeared underneath the cached connection; release it instead of leaking it.
      cachedDatabases.removeValue(forKey: dbName)
      closeConnection(cached)
    }

    var handle: OpaquePointer?
    guard sqlite3_open(path.absoluteString, &handle) == SQLITE_OK, let db = handle else {
      // sqlite3_open may allocate a handle even when it fails, and that handle must still be closed.
      // sqlite3_close(nil) is a harmless no-op.
      sqlite3_close(handle)
      return nil
    }

    sqlite3_update_hook(
      db,
      { (obj, action, _, tableName, rowId) in
        guard let obj, let tableName else {
          return
        }
        let module = Unmanaged<SQLiteModule>.fromOpaque(obj).takeUnretainedValue()
        // Skip building and sending the event while nothing on the JS side is observing.
        guard module.hasListeners else {
          return
        }
        module.sendEvent("onDatabaseChange", [
          "tableName": String(cString: UnsafePointer(tableName)),
          "rowId": rowId,
          "typeId": SqlAction.fromCode(value: action)
        ])
      },
      selfPointer
    )

    cachedDatabases[dbName] = db
    return db
  }

  /// Removes `dbName` from the cache and closes its connection, if one is open.
  private func closeDatabase(named dbName: String) {
    guard let db = cachedDatabases.removeValue(forKey: dbName) else {
      return
    }
    closeConnection(db)
  }

  /// Runs cr-sqlite's `crsql_finalize()` on `db` and then closes it.
  /// This is the same teardown sequence used when the module is destroyed.
  private func closeConnection(_ db: OpaquePointer) {
    _ = executeSql(sql: "SELECT crsql_finalize()", with: [], for: db, readOnly: false)
    sqlite3_close(db)
  }

  /// Prepares, binds, and steps a single SQL statement on `db`.
  ///
  /// - Returns: `[errorMessage]` on failure. On success, returns
  ///   `[nil, insertId, rowsAffected, columnNames, rows]`, where:
  ///   - `rows` holds one array of column values per result row.
  ///   - `columnNames` is empty when no row was returned.
  ///   - `insertId`/`rowsAffected` are only filled for statements that are not read-only.
  ///
  ///   When `readOnly` is set, a statement that would write is rejected without being stepped.
  private func executeSql(sql: String, with args: [Any], for db: OpaquePointer, readOnly: Bool) -> [Any?] {
    var prepared: OpaquePointer?
    guard sqlite3_prepare_v2(db, sql, -1, &prepared, nil) == SQLITE_OK else {
      return [convertSqlLiteErrorToString(db: db)]
    }
    // Finalize on every exit path, including the read-only rejection below.
    defer { sqlite3_finalize(prepared) }

    // Input containing no SQL (blank or comment-only) prepares to a NULL statement. Treat it as a
    // successful no-op instead of stepping a NULL handle.
    guard let statement = prepared else {
      return [nil, Int64(0), Int32(0), [String](), [Any]()]
    }

    let queryIsReadOnly = sqlite3_stmt_readonly(statement) != 0
    if readOnly && !queryIsReadOnly {
      return ["could not prepare \(sql)"]
    }

    for (offset, arg) in args.enumerated() {
      // SQLite parameter indices are 1-based; values not bridgeable to NSObject are left unbound.
      if let object = arg as? NSObject {
        bindStatement(statement: statement, with: object, at: Int32(offset + 1))
      }
    }

    var columnCount: Int32 = 0
    var columnNames = [String]()
    var resultRows = [Any]()

    stepping: while true {
      switch sqlite3_step(statement) {
      case SQLITE_ROW:
        if resultRows.isEmpty {
          // Column metadata is collected once, on the first row.
          columnCount = sqlite3_column_count(statement)
          columnNames = (0..<columnCount).map { columnName(of: statement, at: $0) }
        }
        let row: [Any] = (0..<columnCount).map { column in
          getSqlValue(for: sqlite3_column_type(statement, column), with: statement, index: column) as Any
        }
        resultRows.append(row)
      case SQLITE_DONE:
        break stepping
      default:
        // Capture the message now; the deferred finalize runs after this expression is evaluated.
        return [convertSqlLiteErrorToString(db: db)]
      }
    }

    var rowsAffected: Int32 = 0
    var insertId: Int64 = 0
    if !queryIsReadOnly {
      rowsAffected = sqlite3_changes(db)
      if rowsAffected > 0 {
        insertId = sqlite3_last_insert_rowid(db)
      }
    }

    return [nil, insertId, rowsAffected, columnNames, resultRows]
  }

  /// Returns the name of result column `index`, or "" if SQLite cannot provide one.
  private func columnName(of statement: OpaquePointer, at index: Int32) -> String {
    // sqlite3_column_name yields UTF-8. Decode it as such rather than through a `%s` format
    // specifier, which uses the system default encoding and mangles non-ASCII names.
    guard let name = sqlite3_column_name(statement, index) else {
      return ""
    }
    return String(cString: name)
  }

  /// Binds `arg` to the 1-based parameter `index` of `statement`.
  ///
  /// Binding rules:
  /// - `NSNull` binds SQL NULL.
  /// - Values that bridge to `Double` bind as doubles.
  /// - Everything else binds as the UTF-8 text of its string value.
  private func bindStatement(statement: OpaquePointer?, with arg: NSObject, at index: Int32) {
    if arg is NSNull {
      sqlite3_bind_null(statement, index)
    } else if let number = arg as? Double {
      sqlite3_bind_double(statement, index, number)
    } else {
      let text = (arg as? NSString) ?? (arg.description as NSString)
      // Pass the explicit UTF-8 byte length. SQLITE_TRANSIENT makes SQLite copy the buffer before
      // `text` (which owns `utf8String`'s storage) goes away.
      sqlite3_bind_text(
        statement,
        index,
        text.utf8String,
        Int32(text.lengthOfBytes(using: NSUTF8StringEncoding)),
        SQLiteModule.sqliteTransient
      )
    }
  }

  /// Converts the value in column `index` of the current row into a bridgeable value.
  ///
  /// Conversions:
  /// - INTEGER becomes `Int64`.
  /// - FLOAT becomes `Double`.
  /// - TEXT and BLOB become an `NSString` decoded as UTF-8 (`nil` if the bytes are not valid UTF-8).
  /// - NULL and any other type become `nil`.
  private func getSqlValue(for columnType: Int32, with statement: OpaquePointer?, index: Int32) -> Any? {
    switch columnType {
    case SQLITE_INTEGER:
      return sqlite3_column_int64(statement, index)
    case SQLITE_FLOAT:
      return sqlite3_column_double(statement, index)
    case SQLITE_BLOB, SQLITE_TEXT:
      // Fetch the pointer before the byte count, as SQLite recommends. A NULL pointer (e.g. on
      // allocation failure) would otherwise reach a parameter that requires a non-NULL pointer.
      guard let bytes = sqlite3_column_text(statement, index) else {
        return nil
      }
      return NSString(bytes: bytes, length: Int(sqlite3_column_bytes(statement, index)), encoding: NSUTF8StringEncoding)
    default:
      return nil
    }
  }

  /// Formats the connection's most recent error as "Error code <code>: <message>".
  private func convertSqlLiteErrorToString(db: OpaquePointer?) -> String {
    let code = sqlite3_errcode(db)
    let message = NSString(utf8String: sqlite3_errmsg(db)) ?? ""
    return NSString(format: "Error code %i: %@", code, message) as String
  }
}
249 
/// Kind of row change reported through `onDatabaseChange`.
/// Case order fixes the raw values: insert = 0, delete = 1, update = 2, unknown = 3.
enum SqlAction: Int, Enumerable {
  case insert
  case delete
  case update
  case unknown

  /// SQLite operation codes as passed to `sqlite3_update_hook` callbacks:
  /// SQLITE_INSERT = 18, SQLITE_DELETE = 9, SQLITE_UPDATE = 23.
  private static let actionsByCode: [Int32: SqlAction] = [
    18: .insert,
    9: .delete,
    23: .update
  ]

  /// Maps an update-hook operation code onto a `SqlAction`; unrecognized codes become `.unknown`.
  static func fromCode(value: Int32) -> SqlAction {
    actionsByCode[value] ?? .unknown
  }
}
269