diff --git a/ios/NebulaNetworkExtension/SiteList.swift b/ios/NebulaNetworkExtension/SiteList.swift index aaddc61..b6f3c54 100644 --- a/ios/NebulaNetworkExtension/SiteList.swift +++ b/ios/NebulaNetworkExtension/SiteList.swift @@ -1,7 +1,10 @@ import NetworkExtension -class SiteList { - private var sites = [String: Site]() +typealias SiteDictionary = [String: Site] + +actor SiteList { + // This keeps a reference around to the sites that are loaded. It's not referenced elsewhere. + private var sites = SiteDictionary() /// Gets the root directory that can be used to share files between the UI and VPN process. Does ensure the directory exists static func getRootDir() throws -> URL { @@ -50,25 +53,33 @@ class SiteList { ) } - init(completion: @escaping ([String: Site]?, (any Error)?) -> Void) { + init?() async { + _ = await loadSites() + } + + func loadSites() async -> Result { #if targetEnvironment(simulator) - SiteList.loadAllFromFS { sites, err in - if sites != nil { - self.sites = sites! - } - completion(sites, err) + let sitesResult = await SiteList.loadAllFromFS() + switch sitesResult { + case .success(let sites): + self.sites = sites + return .success(sites) + case .failure(let error): + return .failure(error) } #else - SiteList.loadAllFromNETPM { sites, err in - if sites != nil { - self.sites = sites! - } - completion(sites, err) + let sitesResult = await SiteList.loadAllFromNETPM() + switch sitesResult { + case .success(let sites): + self.sites = sites + return .success(sites) + case .failure(let error): + return .failure(error) } #endif } - private static func loadAllFromFS(completion: @escaping ([String: Site]?, (any Error)?) -> Void) { + private static func loadAllFromFS() async -> Result { let fileManager = FileManager.default var siteDirs: [URL] var sites = [String: Site]() @@ -79,8 +90,7 @@ class SiteList { ) } catch { - completion(nil, error) - return + return Result.failure(error) } for path in siteDirs { @@ -96,55 +106,50 @@ class SiteList { } } - completion(sites, nil) + return Result.success(sites) } - private static func loadAllFromNETPM( - completion: @escaping ([String: Site]?, (any Error)?) -> Void - ) { + private static func loadAllFromNETPM() async -> Result { var sites = [String: Site]() - // dispatchGroup is used to ensure we have migrated all sites before returning them - // If there are no sites to migrate, there are never any entrants - let dispatchGroup = DispatchGroup() - - NETunnelProviderManager.loadAllFromPreferences { newManagers, err in - if err != nil { - return completion(nil, err) - } - - newManagers?.forEach { manager in + do { + let newManagers = try await NETunnelProviderManager.loadAllFromPreferences() + for manager in newManagers { do { let site = try Site(manager: manager) if site.needsToMigrateToFS { - dispatchGroup.enter() - site.incomingSite?.save(manager: manager) { error in - if error != nil { - print("Error while migrating site to fs: \(error!.localizedDescription)") + let error = await withCheckedContinuation({ continuation in + site.incomingSite?.save(manager: manager) { error in + continuation.resume(returning: error) } + }) - print("Migrated site to fs: \(site.name)") - site.needsToMigrateToFS = false - dispatchGroup.leave() + if error != nil { + print("Error while migrating site to fs: \(error!.localizedDescription)") } + + print("Migrated site to fs: \(site.name)") + site.needsToMigrateToFS = false + } sites[site.id] = site } catch { // TODO: notify the user about this print("Deleted non conforming site \(manager) \(error)") - manager.removeFromPreferences() + try await manager.removeFromPreferences() // TODO: delete from disk, we need to try and discover the site id though } } - dispatchGroup.notify(queue: .main) { - completion(sites, nil) - } + return Result.success(sites) + + } catch { + return Result.failure(error) } } - func getSites() -> [String: Site] { + func getSites() -> SiteDictionary { return sites } } diff --git a/ios/Runner/AppDelegate.swift b/ios/Runner/AppDelegate.swift index 278138d..fe308cb 100644 --- a/ios/Runner/AppDelegate.swift +++ b/ios/Runner/AppDelegate.swift @@ -26,14 +26,8 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError { GeneratedPluginRegistrant.register(with: self) Task { - await dnUpdater.updateAllLoop { @MainActor site in - // Signal the site has changed in case the current site details screen is active - let container = self.sites?.getContainer(id: site.id) - if container != nil { - // Update references to the site with the new site config - container!.site = site - container!.updater.update(connected: site.connected ?? false, replaceSite: site) - } + for await site in await dnUpdater.siteUpdates { + self.sites?.updateSite(site: site) // Signal to the main screen to reload self.ui?.invokeMethod("refreshSites", arguments: nil) @@ -167,17 +161,24 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError { } func listSites(result: @escaping FlutterResult) { - sites?.loadSites { sites, err in - if err != nil { + Task { + let sitesResult = await sites?.loadSites() + switch sitesResult { + case let .success(sites): + let encoder = JSONEncoder() + let data = try! encoder.encode(sites) + let ret = String(data: data, encoding: .utf8) + result(ret) + case let .failure(error): return result( - CallFailedError(message: "Failed to load site list", details: err!.localizedDescription)) + CallFailedError(message: "Failed to load site list", details: error.localizedDescription)) + case nil: + return result( + CallFailedError(message: "Failed to load site list")) } - let encoder = JSONEncoder() - let data = try! encoder.encode(sites) - let ret = String(data: data, encoding: .utf8) - result(ret) } + } func deleteSite(call: FlutterMethodCall, result: @escaping FlutterResult) { @@ -208,9 +209,11 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError { CallFailedError(message: "Failed to save site", details: error!.localizedDescription)) } - self.sites?.loadSites { _, _ in + Task { + _ = await self.sites?.loadSites() result(nil) } + } } diff --git a/ios/Runner/DNUpdate.swift b/ios/Runner/DNUpdate.swift index 1739708..1a59319 100644 --- a/ios/Runner/DNUpdate.swift +++ b/ios/Runner/DNUpdate.swift @@ -7,8 +7,9 @@ actor DNUpdater { private let log = Logger(subsystem: "net.defined.mobileNebula", category: "DNUpdater") func updateAll(onUpdate: @Sendable @escaping (Site) -> Void) { - _ = SiteList { sites, _ in - guard let unwrappedSites = sites else { + Task { + let sitesResult = await SiteList()?.loadSites() + guard case let .success(unwrappedSites) = sitesResult else { // There was an error, let's bail. return } @@ -24,6 +25,22 @@ actor DNUpdater { await self.updateSite(site: site, onUpdate: onUpdate) } } + + } + + } + + // Site updates provides an async/await alternative to `.updateAllLoop` that doesn't require a sendable closure. + // https://developer.apple.com/documentation/swift/asyncstream + var siteUpdates: AsyncStream { + AsyncStream { continuation in + timer.eventHandler = { + self.updateAll(onUpdate: { site in + continuation.yield(site) + }) + } + timer.resume() + } } diff --git a/ios/Runner/Sites.swift b/ios/Runner/Sites.swift index 69f9acf..101fa2e 100644 --- a/ios/Runner/Sites.swift +++ b/ios/Runner/Sites.swift @@ -1,6 +1,10 @@ import MobileNebula import NetworkExtension +enum SitesListError: Error { + case missingSitesList +} + class SiteContainer { var site: Site var updater: SiteUpdater @@ -19,13 +23,13 @@ class Sites { self.messenger = messenger } - func loadSites(completion: @escaping ([String: Site]?, (any Error)?) -> Void) { - _ = SiteList { sites, err in - if err != nil { - return completion(nil, err) - } - - sites?.values.forEach { site in + func loadSites() async -> Result<[String: Site], any Error> { + let sitesResult = await SiteList()?.loadSites() + switch sitesResult { + case .failure(let error): + return Result.failure(error) + case .success(let sites): + sites.values.forEach { site in var updater = self.containers[site.id]?.updater if updater != nil { updater!.setSite(site: site) @@ -38,8 +42,11 @@ class Sites { let justSites = self.containers.mapValues { $0.site } - completion(justSites, nil) + return Result.success(justSites) + case nil: + return Result.failure(SitesListError.missingSitesList) } + } func deleteSite(id: String, callback: @escaping ((any Error)?) -> Void) {