Refactor DNUpdater to use async/await

This commit is contained in:
Caleb Jasik 2025-02-19 17:24:41 -06:00
parent 3580427aa3
commit 1af5c48b62
No known key found for this signature in database
4 changed files with 100 additions and 68 deletions

View file

@ -1,7 +1,10 @@
import NetworkExtension import NetworkExtension
class SiteList { typealias SiteDictionary = [String: Site]
private var sites = [String: Site]()
actor SiteList {
// This keeps a reference around to the sites that are loaded. It's not referenced elsewhere.
private var sites = SiteDictionary()
/// Gets the root directory that can be used to share files between the UI and VPN process. Does ensure the directory exists /// Gets the root directory that can be used to share files between the UI and VPN process. Does ensure the directory exists
static func getRootDir() throws -> URL { static func getRootDir() throws -> URL {
@ -50,25 +53,33 @@ class SiteList {
) )
} }
init(completion: @escaping ([String: Site]?, (any Error)?) -> Void) { init?() async {
#if targetEnvironment(simulator) _ = await loadSites()
SiteList.loadAllFromFS { sites, err in
if sites != nil {
self.sites = sites!
} }
completion(sites, err)
func loadSites() async -> Result<SiteDictionary, any Error> {
#if targetEnvironment(simulator)
let sitesResult = await SiteList.loadAllFromFS()
switch sitesResult {
case .success(let sites):
self.sites = sites
return .success(sites)
case .failure(let error):
return .failure(error)
} }
#else #else
SiteList.loadAllFromNETPM { sites, err in let sitesResult = await SiteList.loadAllFromNETPM()
if sites != nil { switch sitesResult {
self.sites = sites! case .success(let sites):
} self.sites = sites
completion(sites, err) return .success(sites)
case .failure(let error):
return .failure(error)
} }
#endif #endif
} }
private static func loadAllFromFS(completion: @escaping ([String: Site]?, (any Error)?) -> Void) { private static func loadAllFromFS() async -> Result<SiteDictionary, any Error> {
let fileManager = FileManager.default let fileManager = FileManager.default
var siteDirs: [URL] var siteDirs: [URL]
var sites = [String: Site]() var sites = [String: Site]()
@ -79,8 +90,7 @@ class SiteList {
) )
} catch { } catch {
completion(nil, error) return Result.failure(error)
return
} }
for path in siteDirs { for path in siteDirs {
@ -96,55 +106,50 @@ class SiteList {
} }
} }
completion(sites, nil) return Result.success(sites)
} }
private static func loadAllFromNETPM( private static func loadAllFromNETPM() async -> Result<SiteDictionary, any Error> {
completion: @escaping ([String: Site]?, (any Error)?) -> Void
) {
var sites = [String: Site]() var sites = [String: Site]()
// dispatchGroup is used to ensure we have migrated all sites before returning them do {
// If there are no sites to migrate, there are never any entrants let newManagers = try await NETunnelProviderManager.loadAllFromPreferences()
let dispatchGroup = DispatchGroup() for manager in newManagers {
NETunnelProviderManager.loadAllFromPreferences { newManagers, err in
if err != nil {
return completion(nil, err)
}
newManagers?.forEach { manager in
do { do {
let site = try Site(manager: manager) let site = try Site(manager: manager)
if site.needsToMigrateToFS { if site.needsToMigrateToFS {
dispatchGroup.enter() let error = await withCheckedContinuation({ continuation in
site.incomingSite?.save(manager: manager) { error in site.incomingSite?.save(manager: manager) { error in
continuation.resume(returning: error)
}
})
if error != nil { if error != nil {
print("Error while migrating site to fs: \(error!.localizedDescription)") print("Error while migrating site to fs: \(error!.localizedDescription)")
} }
print("Migrated site to fs: \(site.name)") print("Migrated site to fs: \(site.name)")
site.needsToMigrateToFS = false site.needsToMigrateToFS = false
dispatchGroup.leave()
}
} }
sites[site.id] = site sites[site.id] = site
} catch { } catch {
// TODO: notify the user about this // TODO: notify the user about this
print("Deleted non conforming site \(manager) \(error)") print("Deleted non conforming site \(manager) \(error)")
manager.removeFromPreferences() try await manager.removeFromPreferences()
// TODO: delete from disk, we need to try and discover the site id though // TODO: delete from disk, we need to try and discover the site id though
} }
} }
dispatchGroup.notify(queue: .main) { return Result.success(sites)
completion(sites, nil)
} } catch {
return Result.failure(error)
} }
} }
func getSites() -> [String: Site] { func getSites() -> SiteDictionary {
return sites return sites
} }
} }

View file

@ -26,14 +26,8 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
GeneratedPluginRegistrant.register(with: self) GeneratedPluginRegistrant.register(with: self)
Task { Task {
await dnUpdater.updateAllLoop { @MainActor site in for await site in await dnUpdater.siteUpdates {
// Signal the site has changed in case the current site details screen is active self.sites?.updateSite(site: site)
let container = self.sites?.getContainer(id: site.id)
if container != nil {
// Update references to the site with the new site config
container!.site = site
container!.updater.update(connected: site.connected ?? false, replaceSite: site)
}
// Signal to the main screen to reload // Signal to the main screen to reload
self.ui?.invokeMethod("refreshSites", arguments: nil) self.ui?.invokeMethod("refreshSites", arguments: nil)
@ -167,17 +161,24 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
} }
func listSites(result: @escaping FlutterResult) { func listSites(result: @escaping FlutterResult) {
sites?.loadSites { sites, err in Task {
if err != nil { let sitesResult = await sites?.loadSites()
return result( switch sitesResult {
CallFailedError(message: "Failed to load site list", details: err!.localizedDescription)) case let .success(sites):
}
let encoder = JSONEncoder() let encoder = JSONEncoder()
let data = try! encoder.encode(sites) let data = try! encoder.encode(sites)
let ret = String(data: data, encoding: .utf8) let ret = String(data: data, encoding: .utf8)
result(ret) result(ret)
case let .failure(error):
return result(
CallFailedError(message: "Failed to load site list", details: error.localizedDescription))
case nil:
return result(
CallFailedError(message: "Failed to load site list"))
} }
}
} }
func deleteSite(call: FlutterMethodCall, result: @escaping FlutterResult) { func deleteSite(call: FlutterMethodCall, result: @escaping FlutterResult) {
@ -208,9 +209,11 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
CallFailedError(message: "Failed to save site", details: error!.localizedDescription)) CallFailedError(message: "Failed to save site", details: error!.localizedDescription))
} }
self.sites?.loadSites { _, _ in Task {
_ = await self.sites?.loadSites()
result(nil) result(nil)
} }
} }
} }

View file

@ -7,8 +7,9 @@ actor DNUpdater {
private let log = Logger(subsystem: "net.defined.mobileNebula", category: "DNUpdater") private let log = Logger(subsystem: "net.defined.mobileNebula", category: "DNUpdater")
func updateAll(onUpdate: @Sendable @escaping (Site) -> Void) { func updateAll(onUpdate: @Sendable @escaping (Site) -> Void) {
_ = SiteList { sites, _ in Task {
guard let unwrappedSites = sites else { let sitesResult = await SiteList()?.loadSites()
guard case let .success(unwrappedSites) = sitesResult else {
// There was an error, let's bail. // There was an error, let's bail.
return return
} }
@ -24,6 +25,22 @@ actor DNUpdater {
await self.updateSite(site: site, onUpdate: onUpdate) await self.updateSite(site: site, onUpdate: onUpdate)
} }
} }
}
}
// Site updates provides an async/await alternative to `.updateAllLoop` that doesn't require a sendable closure.
// https://developer.apple.com/documentation/swift/asyncstream
var siteUpdates: AsyncStream<Site> {
AsyncStream { continuation in
timer.eventHandler = {
self.updateAll(onUpdate: { site in
continuation.yield(site)
})
}
timer.resume()
} }
} }

View file

@ -1,6 +1,10 @@
import MobileNebula import MobileNebula
import NetworkExtension import NetworkExtension
enum SitesListError: Error {
case missingSitesList
}
class SiteContainer { class SiteContainer {
var site: Site var site: Site
var updater: SiteUpdater var updater: SiteUpdater
@ -19,13 +23,13 @@ class Sites {
self.messenger = messenger self.messenger = messenger
} }
func loadSites(completion: @escaping ([String: Site]?, (any Error)?) -> Void) { func loadSites() async -> Result<[String: Site], any Error> {
_ = SiteList { sites, err in let sitesResult = await SiteList()?.loadSites()
if err != nil { switch sitesResult {
return completion(nil, err) case .failure(let error):
} return Result.failure(error)
case .success(let sites):
sites?.values.forEach { site in sites.values.forEach { site in
var updater = self.containers[site.id]?.updater var updater = self.containers[site.id]?.updater
if updater != nil { if updater != nil {
updater!.setSite(site: site) updater!.setSite(site: site)
@ -38,8 +42,11 @@ class Sites {
let justSites = self.containers.mapValues { let justSites = self.containers.mapValues {
$0.site $0.site
} }
completion(justSites, nil) return Result.success(justSites)
case nil:
return Result.failure(SitesListError.missingSitesList)
} }
} }
func deleteSite(id: String, callback: @escaping ((any Error)?) -> Void) { func deleteSite(id: String, callback: @escaping ((any Error)?) -> Void) {