1 import ExpoModulesCore
2 import sqlite3
3 
/// Expo native module that exposes SQLite to JavaScript as `ExpoSQLite`.
///
/// Connections are opened lazily per database name, cached in `cachedDatabases`, and closed
/// via `close` / `closeSync` or when the module is destroyed. Row changes reported by SQLite's
/// update hook are forwarded to JS as `onDatabaseChange` events while a listener is attached.
public final class SQLiteModule: Module {
  /// Open connection handles, keyed by the database name supplied from JS.
  private var cachedDatabases = [String: OpaquePointer]()
  /// True between `OnStartObserving` and `OnStopObserving`; gates `onDatabaseChange` events.
  private var hasListeners = false
  /// `self` as an opaque pointer, passed as the context argument of `sqlite3_update_hook`.
  /// NOTE(review): created with `passRetained` and never released in this file, so it keeps
  /// the module retained indefinitely — confirm this is intentional.
  private lazy var selfPointer = Unmanaged.passRetained(self).toOpaque()

  public func definition() -> ModuleDefinition {
    Name("ExpoSQLite")

    Events("onDatabaseChange")

    OnCreate {
      // External C entry point; presumably registers the cr-sqlite extension — defined outside this file.
      crsqlite_init_from_swift()
    }

    // Runs each `[sql, args]` pair against `dbName` and returns one `executeSql` result per
    // query, in order. Throws if the database cannot be opened or a query is malformed.
    AsyncFunction("exec") { (dbName: String, queries: [[Any]], readOnly: Bool) -> [Any?] in
      guard let db = openDatabase(dbName: dbName) else {
        throw DatabaseException()
      }

      let results = try queries.map { query -> [Any?] in
        // Check the element count before subscripting so a malformed query throws
        // instead of trapping on an out-of-range index.
        guard query.count > 0, let sql = query[0] as? String else {
          throw InvalidSqlException()
        }

        guard query.count > 1, let args = query[1] as? [Any] else {
          throw InvalidArgumentsException()
        }

        return executeSql(sql: sql, with: args, for: db, readOnly: readOnly)
      }

      return results
    }

    // Closes and evicts the cached connection for `dbName`; a no-op when none is open.
    AsyncFunction("close") { (dbName: String) in
      closeCachedDatabase(named: dbName)
    }

    // Synchronous variant of `close`.
    Function("closeSync") { (dbName: String) in
      closeCachedDatabase(named: dbName)
    }

    // Deletes the database file. Throws if the database is still open (cached), if the
    // file-system module is unavailable, if the file does not exist, or if removal fails.
    AsyncFunction("deleteAsync") { (dbName: String) in
      if cachedDatabases[dbName] != nil {
        throw DeleteDatabaseException(dbName)
      }

      guard let path = self.pathForDatabaseName(name: dbName) else {
        throw Exceptions.FileSystemModuleNotFound()
      }

      if !FileManager.default.fileExists(atPath: path.absoluteString) {
        throw DatabaseNotFoundException(dbName)
      }

      do {
        try FileManager.default.removeItem(atPath: path.absoluteString)
      } catch {
        throw DeleteDatabaseFileException(dbName)
      }
    }

    OnStartObserving {
      hasListeners = true
    }

    OnStopObserving {
      hasListeners = false
    }

    // Tear down every open connection and clear the cache so no dangling handles remain.
    OnDestroy {
      for db in cachedDatabases.values {
        closeDatabase(db)
      }
      cachedDatabases.removeAll()
    }
  }

  /// Closes the connection cached under `name`, if any, and removes it from the cache.
  private func closeCachedDatabase(named name: String) {
    guard let db = cachedDatabases.removeValue(forKey: name) else {
      return
    }
    closeDatabase(db)
  }

  /// Runs `crsql_finalize()` on `db` and then closes the connection.
  ///
  /// This is the single teardown path used by `close`, `closeSync`, `OnDestroy`, and stale
  /// handles discovered in `openDatabase`.
  private func closeDatabase(_ db: OpaquePointer) {
    _ = executeSql(sql: "SELECT crsql_finalize()", with: [], for: db, readOnly: false)
    sqlite3_close(db)
  }

  /// Returns `<documentDirectory>/SQLite/<name>`, asking the file-system module to ensure the
  /// `SQLite` directory exists. Returns nil when the file-system module is unavailable.
  private func pathForDatabaseName(name: String) -> URL? {
    guard let fileSystem = appContext?.fileSystem else {
      return nil
    }

    let directory = URL(string: fileSystem.documentDirectory)?.appendingPathComponent("SQLite")
    fileSystem.ensureDirExists(withPath: directory?.absoluteString)

    return directory?.appendingPathComponent(name)
  }

  /// Returns the connection for `dbName`, reusing the cached handle while its file still exists.
  ///
  /// A newly opened connection gets an update hook that emits `onDatabaseChange` and is cached.
  /// Returns nil if the path cannot be resolved or `sqlite3_open` fails.
  private func openDatabase(dbName: String) -> OpaquePointer? {
    guard let path = pathForDatabaseName(name: dbName) else {
      return nil
    }

    if let cached = cachedDatabases[dbName] {
      if FileManager.default.fileExists(atPath: path.absoluteString) {
        return cached
      }
      // The file vanished underneath the cached handle: close it before reopening
      // instead of just dropping (and leaking) it.
      cachedDatabases.removeValue(forKey: dbName)
      closeDatabase(cached)
    }

    var db: OpaquePointer?
    if sqlite3_open(path.absoluteString, &db) != SQLITE_OK {
      // SQLite usually allocates a handle even when open fails; it must still be closed.
      // sqlite3_close(nil) is a harmless no-op.
      sqlite3_close(db)
      return nil
    }

    sqlite3_update_hook(
      db, { (obj, action, _, tableName, rowId) in
        if let obj, let tableName {
          let selfObj = Unmanaged<SQLiteModule>.fromOpaque(obj).takeUnretainedValue()
          if selfObj.hasListeners {
            selfObj.sendEvent("onDatabaseChange", [
              "tableName": String(cString: UnsafePointer(tableName)),
              "rowId": rowId,
              "typeId": SqlAction.fromCode(value: action)
            ])
          }
        }
      },
      selfPointer
    )

    cachedDatabases[dbName] = db
    return db
  }

  /// Prepares, binds, and steps `sql` on `db`.
  ///
  /// - Returns: `[errorString]` on failure, otherwise
  ///   `[nil, insertId, rowsAffected, columnNames, rows]`, where each row is an array of column
  ///   values. `insertId` and `rowsAffected` are only read for statements SQLite does not report
  ///   as read-only. Column names are collected only when at least one row is returned.
  ///   When `readOnly` is true, statements that are not read-only are rejected.
  private func executeSql(sql: String, with args: [Any], for db: OpaquePointer, readOnly: Bool) -> [Any?] {
    var statement: OpaquePointer?

    if sqlite3_prepare_v2(db, sql, -1, &statement, nil) != SQLITE_OK {
      return [convertSqlLiteErrorToString(db: db)]
    }
    // Finalize on every exit path, including the read-only rejection just below.
    defer { sqlite3_finalize(statement) }

    let queryIsReadOnly = sqlite3_stmt_readonly(statement) > 0

    if readOnly && !queryIsReadOnly {
      return ["could not prepare \(sql)"]
    }

    for (index, arg) in args.enumerated() {
      guard let obj = arg as? NSObject else { continue }
      bindStatement(statement: statement, with: obj, at: Int32(index + 1))
    }

    var resultRows = [Any]()
    var columnNames = [String]()
    var columnCount: Int32 = 0
    var fetchedColumns = false
    var error: String?
    var hasMore = true

    while hasMore {
      switch sqlite3_step(statement) {
      case SQLITE_ROW:
        if !fetchedColumns {
          columnCount = sqlite3_column_count(statement)
          for i in 0..<columnCount {
            // Column names are UTF-8; NSString's "%s" would decode them in the system
            // default encoding and garble non-ASCII names.
            let columnName = sqlite3_column_name(statement, i).map { String(cString: $0) } ?? ""
            columnNames.append(columnName)
          }
          fetchedColumns = true
        }

        var entry = [Any]()
        for i in 0..<columnCount {
          let columnType = sqlite3_column_type(statement, i)
          entry.append(getSqlValue(for: columnType, with: statement, index: i) as Any)
        }
        resultRows.append(entry)
      case SQLITE_DONE:
        hasMore = false
      default:
        error = convertSqlLiteErrorToString(db: db)
        hasMore = false
      }
    }

    if let error {
      return [error]
    }

    var rowsAffected: Int32 = 0
    var insertId: Int64 = 0
    if !queryIsReadOnly {
      rowsAffected = sqlite3_changes(db)
      if rowsAffected > 0 {
        insertId = sqlite3_last_insert_rowid(db)
      }
    }

    return [nil, insertId, rowsAffected, columnNames, resultRows]
  }

  /// Binds `arg` to parameter `index` (1-based).
  ///
  /// The mapping is: `NSNull` binds as NULL, values that bridge to `Double` bind as REAL, and
  /// everything else binds as UTF-8 TEXT using its string form (or `description`).
  private func bindStatement(statement: OpaquePointer?, with arg: NSObject, at index: Int32) {
    if arg == NSNull() {
      sqlite3_bind_null(statement, index)
    } else if arg is Double {
      sqlite3_bind_double(statement, index, arg as? Double ?? 0.0)
    } else {
      var stringArg: NSString

      if arg is NSString {
        stringArg = NSString(format: "%@", arg)
      } else {
        stringArg = arg.description as NSString
      }

      // SQLITE_TRANSIENT (-1) makes SQLite copy the bytes, because `utf8String`'s
      // buffer does not outlive `stringArg`.
      let SQLITE_TRANSIENT = unsafeBitCast(OpaquePointer(bitPattern: -1), to: sqlite3_destructor_type.self)

      let data = stringArg.data(using: NSUTF8StringEncoding)
      sqlite3_bind_text(statement, index, stringArg.utf8String, Int32(data?.count ?? 0), SQLITE_TRANSIENT)
    }
  }

  /// Reads column `index` of the current row according to its storage class.
  ///
  /// The mapping is: INTEGER returns `Int64`, FLOAT returns `Double`, TEXT and BLOB return an
  /// `NSString` decoded as UTF-8, and any other type (NULL) returns nil.
  private func getSqlValue(for columnType: Int32, with statement: OpaquePointer?, index: Int32) -> Any? {
    switch columnType {
    case SQLITE_INTEGER:
      return sqlite3_column_int64(statement, index)
    case SQLITE_FLOAT:
      return sqlite3_column_double(statement, index)
    case SQLITE_BLOB, SQLITE_TEXT:
      return NSString(bytes: sqlite3_column_text(statement, index), length: Int(sqlite3_column_bytes(statement, index)), encoding: NSUTF8StringEncoding)
    default:
      return nil
    }
  }

  /// Formats the most recent error on `db` as `"Error code <code>: <message>"`.
  private func convertSqlLiteErrorToString(db: OpaquePointer?) -> String {
    let code = sqlite3_errcode(db)
    let message = NSString(utf8String: sqlite3_errmsg(db)) ?? ""
    return NSString(format: "Error code %i: %@", code, message) as String
  }
}
252 
/// Kind of row change reported by the SQLite update hook; sent to JS as `typeId`.
enum SqlAction: String, Enumerable {
  case insert
  case delete
  case update
  case unknown

  /// Update-hook operation codes from `sqlite3.h`:
  /// 9 = SQLITE_DELETE, 18 = SQLITE_INSERT, 23 = SQLITE_UPDATE.
  private static let actionsByCode: [Int32: SqlAction] = [
    9: .delete,
    18: .insert,
    23: .update
  ]

  /// Maps an update-hook operation code to its `SqlAction`, or `.unknown` for any other code.
  static func fromCode(value: Int32) -> SqlAction {
    actionsByCode[value] ?? .unknown
  }
}
272