Make DNUpdater an actor to enforce thread safety

This commit is contained in:
Caleb Jasik 2025-02-11 15:42:24 -06:00
parent 9c19e39891
commit 06ed3dfaaa
No known key found for this signature in database
3 changed files with 185 additions and 150 deletions

View file

@ -116,7 +116,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
} }
self.nebula!.start() self.nebula!.start()
self.dnUpdater.updateSingleLoop(site: self.site!, onUpdate: self.handleDNUpdate) await self.dnUpdater.updateSingleLoop(site: self.site!, onUpdate: self.handleDNUpdate)
} }
private func handleDNUpdate(newSite: Site) { private func handleDNUpdate(newSite: Site) {

View file

@ -1,8 +1,8 @@
import UIKit
@preconcurrency import Flutter @preconcurrency import Flutter
import MobileNebula import MobileNebula
import NetworkExtension import NetworkExtension
import SwiftyJSON import SwiftyJSON
import UIKit
enum ChannelName { enum ChannelName {
static let vpn = "net.defined.mobileNebula/NebulaVpnService" static let vpn = "net.defined.mobileNebula/NebulaVpnService"
@ -19,25 +19,29 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
private var sites: Sites? private var sites: Sites?
private var ui: FlutterMethodChannel? private var ui: FlutterMethodChannel?
override func application( override func application(
_ application: UIApplication, _ application: UIApplication,
didFinishLaunchingWithOptions launchOptions: [UIApplication.LaunchOptionsKey: Any]? didFinishLaunchingWithOptions launchOptions: [UIApplication.LaunchOptionsKey: Any]?
) -> Bool { ) -> Bool {
GeneratedPluginRegistrant.register(with: self) GeneratedPluginRegistrant.register(with: self)
Task.detached {
dnUpdater.updateAllLoop { site in await self.dnUpdater.updateAllLoop { [weak self] site in
// Signal the site has changed in case the current site details screen is active // Signal the site has changed in case the current site details screen is active
let container = self.sites?.getContainer(id: site.id) let container = self?.sites?.getContainer(id: site.id)
if (container != nil) { if container != nil {
// Update references to the site with the new site config // Update references to the site with the new site config
container!.site = site container!.site = site
container!.updater.update(connected: site.connected ?? false, replaceSite: site) container!.updater.update(connected: site.connected ?? false, replaceSite: site)
} }
// Send the refresh sites command on the main thread
DispatchQueue.main.async {
// Signal to the main screen to reload // Signal to the main screen to reload
self.ui?.invokeMethod("refreshSites", arguments: nil) self?.ui?.invokeMethod("refreshSites", arguments: nil)
}
}
} }
guard let controller = window?.rootViewController as? FlutterViewController else { guard let controller = window?.rootViewController as? FlutterViewController else {
@ -47,7 +51,7 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
sites = Sites(messenger: controller.binaryMessenger) sites = Sites(messenger: controller.binaryMessenger)
ui = FlutterMethodChannel(name: ChannelName.vpn, binaryMessenger: controller.binaryMessenger) ui = FlutterMethodChannel(name: ChannelName.vpn, binaryMessenger: controller.binaryMessenger)
ui!.setMethodCallHandler({(call: FlutterMethodCall, result: @escaping FlutterResult) -> Void in ui!.setMethodCallHandler({ (call: FlutterMethodCall, result: @escaping FlutterResult) -> Void in
switch call.method { switch call.method {
case "nebula.parseCerts": return self.nebulaParseCerts(call: call, result: result) case "nebula.parseCerts": return self.nebulaParseCerts(call: call, result: result)
case "nebula.generateKeyPair": return self.nebulaGenerateKeyPair(result: result) case "nebula.generateKeyPair": return self.nebulaGenerateKeyPair(result: result)
@ -62,11 +66,16 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
case "startSite": return self.startSite(call: call, result: result) case "startSite": return self.startSite(call: call, result: result)
case "stopSite": return self.stopSite(call: call, result: result) case "stopSite": return self.stopSite(call: call, result: result)
case "active.listHostmap": self.vpnRequest(command: "listHostmap", arguments: call.arguments, result: result) case "active.listHostmap":
case "active.listPendingHostmap": self.vpnRequest(command: "listPendingHostmap", arguments: call.arguments, result: result) self.vpnRequest(command: "listHostmap", arguments: call.arguments, result: result)
case "active.getHostInfo": self.vpnRequest(command: "getHostInfo", arguments: call.arguments, result: result) case "active.listPendingHostmap":
case "active.setRemoteForTunnel": self.vpnRequest(command: "setRemoteForTunnel", arguments: call.arguments, result: result) self.vpnRequest(command: "listPendingHostmap", arguments: call.arguments, result: result)
case "active.closeTunnel": self.vpnRequest(command: "closeTunnel", arguments: call.arguments, result: result) case "active.getHostInfo":
self.vpnRequest(command: "getHostInfo", arguments: call.arguments, result: result)
case "active.setRemoteForTunnel":
self.vpnRequest(command: "setRemoteForTunnel", arguments: call.arguments, result: result)
case "active.closeTunnel":
self.vpnRequest(command: "closeTunnel", arguments: call.arguments, result: result)
default: default:
result(FlutterMethodNotImplemented) result(FlutterMethodNotImplemented)
@ -77,28 +86,35 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
} }
func nebulaParseCerts(call: FlutterMethodCall, result: FlutterResult) { func nebulaParseCerts(call: FlutterMethodCall, result: FlutterResult) {
guard let args = call.arguments as? Dictionary<String, String> else { return result(NoArgumentsError()) } guard let args = call.arguments as? [String: String] else { return result(NoArgumentsError()) }
guard let certs = args["certs"] else { return result(MissingArgumentError(message: "certs is a required argument")) } guard let certs = args["certs"] else {
return result(MissingArgumentError(message: "certs is a required argument"))
}
var err: NSError? var err: NSError?
let json = MobileNebulaParseCerts(certs, &err) let json = MobileNebulaParseCerts(certs, &err)
if (err != nil) { if err != nil {
return result(CallFailedError(message: "Error while parsing certificate(s)", details: err!.localizedDescription)) return result(
CallFailedError(message: "Error while parsing certificate(s)", details: err!.localizedDescription))
} }
return result(json) return result(json)
} }
func nebulaVerifyCertAndKey(call: FlutterMethodCall, result: FlutterResult) { func nebulaVerifyCertAndKey(call: FlutterMethodCall, result: FlutterResult) {
guard let args = call.arguments as? Dictionary<String, String> else { return result(NoArgumentsError()) } guard let args = call.arguments as? [String: String] else { return result(NoArgumentsError()) }
guard let cert = args["cert"] else { return result(MissingArgumentError(message: "cert is a required argument")) } guard let cert = args["cert"] else {
return result(MissingArgumentError(message: "cert is a required argument"))
}
guard let key = args["key"] else { return result(MissingArgumentError(message: "key is a required argument")) } guard let key = args["key"] else { return result(MissingArgumentError(message: "key is a required argument")) }
var err: NSError? var err: NSError?
var validd: ObjCBool = false var validd: ObjCBool = false
let valid = MobileNebulaVerifyCertAndKey(cert, key, &validd, &err) let valid = MobileNebulaVerifyCertAndKey(cert, key, &validd, &err)
if (err != nil) { if err != nil {
return result(CallFailedError(message: "Error while verifying certificate and private key", details: err!.localizedDescription)) return result(
CallFailedError(
message: "Error while verifying certificate and private key", details: err!.localizedDescription))
} }
return result(valid) return result(valid)
@ -107,8 +123,9 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
func nebulaGenerateKeyPair(result: FlutterResult) { func nebulaGenerateKeyPair(result: FlutterResult) {
var err: NSError? var err: NSError?
let kp = MobileNebulaGenerateKeyPair(&err) let kp = MobileNebulaGenerateKeyPair(&err)
if (err != nil) { if err != nil {
return result(CallFailedError(message: "Error while generating key pairs", details: err!.localizedDescription)) return result(
CallFailedError(message: "Error while generating key pairs", details: err!.localizedDescription))
} }
return result(kp) return result(kp)
@ -119,7 +136,7 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
var err: NSError? var err: NSError?
let yaml = MobileNebulaRenderConfig(config, "<hidden>", &err) let yaml = MobileNebulaRenderConfig(config, "<hidden>", &err)
if (err != nil) { if err != nil {
return result(CallFailedError(message: "Error while rendering config", details: err!.localizedDescription)) return result(CallFailedError(message: "Error while rendering config", details: err!.localizedDescription))
} }
@ -134,7 +151,7 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
let oldSite = self.sites?.getSite(id: site.id) let oldSite = self.sites?.getSite(id: site.id)
site.save(manager: oldSite?.manager) { error in site.save(manager: oldSite?.manager) { error in
if (error != nil) { if error != nil {
return result(CallFailedError(message: "Failed to enroll", details: error!.localizedDescription)) return result(CallFailedError(message: "Failed to enroll", details: error!.localizedDescription))
} }
@ -146,8 +163,8 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
} }
func listSites(result: @escaping FlutterResult) { func listSites(result: @escaping FlutterResult) {
self.sites?.loadSites { (sites, err) -> () in self.sites?.loadSites { (sites, err) -> Void in
if (err != nil) { if err != nil {
return result(CallFailedError(message: "Failed to load site list", details: err!.localizedDescription)) return result(CallFailedError(message: "Failed to load site list", details: err!.localizedDescription))
} }
@ -162,7 +179,7 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
guard let id = call.arguments as? String else { return result(NoArgumentsError()) } guard let id = call.arguments as? String else { return result(NoArgumentsError()) }
//TODO: stop the site if its running currently //TODO: stop the site if its running currently
self.sites?.deleteSite(id: id) { error in self.sites?.deleteSite(id: id) { error in
if (error != nil) { if error != nil {
result(CallFailedError(message: "Failed to delete site", details: error!.localizedDescription)) result(CallFailedError(message: "Failed to delete site", details: error!.localizedDescription))
} }
@ -180,7 +197,7 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
let oldSite = self.sites?.getSite(id: site.id) let oldSite = self.sites?.getSite(id: site.id)
site.save(manager: oldSite?.manager) { error in site.save(manager: oldSite?.manager) { error in
if (error != nil) { if error != nil {
return result(CallFailedError(message: "Failed to save site", details: error!.localizedDescription)) return result(CallFailedError(message: "Failed to save site", details: error!.localizedDescription))
} }
@ -191,59 +208,62 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
} }
func startSite(call: FlutterMethodCall, result: @escaping FlutterResult) { func startSite(call: FlutterMethodCall, result: @escaping FlutterResult) {
guard let args = call.arguments as? Dictionary<String, String> else { return result(NoArgumentsError()) } guard let args = call.arguments as? [String: String] else { return result(NoArgumentsError()) }
guard let id = args["id"] else { return result(MissingArgumentError(message: "id is a required argument")) } guard let id = args["id"] else { return result(MissingArgumentError(message: "id is a required argument")) }
#if targetEnvironment(simulator) #if targetEnvironment(simulator)
let updater = self.sites?.getUpdater(id: id) let updater = self.sites?.getUpdater(id: id)
updater?.update(connected: true) updater?.update(connected: true)
#else #else
let container = self.sites?.getContainer(id: id) let container = self.sites?.getContainer(id: id)
let manager = container?.site.manager let manager = container?.site.manager
manager?.loadFromPreferences{ error in manager?.loadFromPreferences { error in
//TODO: Handle load error //TODO: Handle load error
// This is silly but we need to enable the site each time to avoid situations where folks have multiple sites // This is silly but we need to enable the site each time to avoid situations where folks have multiple sites
manager?.isEnabled = true manager?.isEnabled = true
manager?.saveToPreferences{ error in manager?.saveToPreferences { error in
//TODO: Handle load error //TODO: Handle load error
manager?.loadFromPreferences{ error in manager?.loadFromPreferences { error in
//TODO: Handle load error //TODO: Handle load error
do { do {
container?.updater.startFunc = {() -> Void in container?.updater.startFunc = { () -> Void in
return self.vpnRequest(command: "start", arguments: args, result: result) return self.vpnRequest(command: "start", arguments: args, result: result)
} }
try manager?.connection.startVPNTunnel(options: ["expectStart": NSNumber(1)]) try manager?.connection.startVPNTunnel(options: ["expectStart": NSNumber(1)])
} catch { } catch {
return result(CallFailedError(message: "Could not start site", details: error.localizedDescription)) return result(
CallFailedError(message: "Could not start site", details: error.localizedDescription))
} }
} }
} }
} }
#endif #endif
} }
func stopSite(call: FlutterMethodCall, result: @escaping FlutterResult) { func stopSite(call: FlutterMethodCall, result: @escaping FlutterResult) {
guard let args = call.arguments as? Dictionary<String, String> else { return result(NoArgumentsError()) } guard let args = call.arguments as? [String: String] else { return result(NoArgumentsError()) }
guard let id = args["id"] else { return result(MissingArgumentError(message: "id is a required argument")) } guard let id = args["id"] else { return result(MissingArgumentError(message: "id is a required argument")) }
#if targetEnvironment(simulator) #if targetEnvironment(simulator)
let updater = self.sites?.getUpdater(id: id) let updater = self.sites?.getUpdater(id: id)
updater?.update(connected: false) updater?.update(connected: false)
#else #else
let manager = self.sites?.getSite(id: id)?.manager let manager = self.sites?.getSite(id: id)?.manager
manager?.loadFromPreferences{ error in manager?.loadFromPreferences { error in
//TODO: Handle load error //TODO: Handle load error
manager?.connection.stopVPNTunnel() manager?.connection.stopVPNTunnel()
return result(nil) return result(nil)
} }
#endif #endif
} }
func vpnRequest(command: String, arguments: Any?, result: @escaping FlutterResult) { func vpnRequest(command: String, arguments: Any?, result: @escaping FlutterResult) {
guard let args = arguments as? Dictionary<String, Any> else { return result(NoArgumentsError()) } guard let args = arguments as? [String: Any] else { return result(NoArgumentsError()) }
guard let id = args["id"] as? String else { return result(MissingArgumentError(message: "id is a required argument")) } guard let id = args["id"] as? String else {
return result(MissingArgumentError(message: "id is a required argument"))
}
let container = sites?.getContainer(id: id) let container = sites?.getContainer(id: id)
if container == nil { if container == nil {
@ -258,7 +278,9 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
if let session = container!.site.manager?.connection as? NETunnelProviderSession { if let session = container!.site.manager?.connection as? NETunnelProviderSession {
do { do {
try session.sendProviderMessage(try JSONEncoder().encode(IPCRequest(command: command, arguments: JSON(args)))) { data in try session.sendProviderMessage(
try JSONEncoder().encode(IPCRequest(command: command, arguments: JSON(args)))
) { data in
if data == nil { if data == nil {
return result(nil) return result(nil)
} }
@ -288,7 +310,9 @@ func MissingArgumentError(message: String, details: (any Error)? = nil) -> Flutt
return FlutterError(code: "missingArgument", message: message, details: details) return FlutterError(code: "missingArgument", message: message, details: details)
} }
func NoArgumentsError(message: String? = "no arguments were provided or could not be deserialized", details: (any Error)? = nil) -> FlutterError { func NoArgumentsError(
message: String? = "no arguments were provided or could not be deserialized", details: (any Error)? = nil
) -> FlutterError {
return FlutterError(code: "noArguments", message: message, details: details) return FlutterError(code: "noArguments", message: message, details: details)
} }

View file

@ -1,45 +1,52 @@
import Foundation import Foundation
import os.log import os.log
class DNUpdater { actor DNUpdater {
private let apiClient = APIClient() private let apiClient = APIClient()
private let timer = RepeatingTimer(timeInterval: 15 * 60) // 15 * 60 is 15 minutes private let timer = RepeatingTimer(timeInterval: 15 * 60) // 15 * 60 is 15 minutes
private let log = Logger(subsystem: "net.defined.mobileNebula", category: "DNUpdater") private let log = Logger(subsystem: "net.defined.mobileNebula", category: "DNUpdater")
func updateAll(onUpdate: @escaping (Site) -> ()) { func updateAll(onUpdate: @escaping (Site) -> Void) {
_ = SiteList{ (sites, _) -> () in _ = SiteList { (sites, _) -> Void in
switch sites
{
case .some(let sites):
// NEVPN seems to force us onto the main thread and we are about to make network calls that // NEVPN seems to force us onto the main thread and we are about to make network calls that
// could block for a while. Push ourselves onto another thread to avoid blocking the UI. // could block for a while. Push ourselves onto another thread to avoid blocking the UI.
Task.detached(priority: .userInitiated) { Task.detached(priority: .userInitiated) {
sites?.values.forEach { site in for site in sites.values {
if (site.connected == true) { if site.connected == true {
// The vpn service is in charge of updating the currently connected site // The vpn service is in charge of updating the currently connected site
return return
} }
self.updateSite(site: site, onUpdate: onUpdate) await self.updateSite(site: site, onUpdate: onUpdate)
} }
} }
default: break
}
} }
} }
func updateAllLoop(onUpdate: @escaping (Site) -> ()) { func updateAllLoop(onUpdate: @escaping (Site) -> Void) {
timer.eventHandler = { timer.eventHandler = {
self.updateAll(onUpdate: onUpdate) self.updateAll(onUpdate: onUpdate)
} }
timer.resume() timer.resume()
} }
func updateSingleLoop(site: Site, onUpdate: @escaping (Site) -> ()) { func updateSingleLoop(site: Site, onUpdate: @escaping (Site) -> Void) {
timer.eventHandler = { timer.eventHandler = {
self.updateSite(site: site, onUpdate: onUpdate) self.updateSite(site: site, onUpdate: onUpdate)
} }
timer.resume() timer.resume()
} }
func updateSite(site: Site, onUpdate: @escaping (Site) -> ()) { func updateSite(site: Site, onUpdate: @escaping (Site) -> Void) {
do { do {
if (!site.managed) { if !site.managed {
return return
} }
@ -55,7 +62,7 @@ class DNUpdater {
trustedKeys: credentials.trustedKeys trustedKeys: credentials.trustedKeys
) )
} catch (APIClientError.invalidCredentials) { } catch (APIClientError.invalidCredentials) {
if (!credentials.invalid) { if !credentials.invalid {
try site.invalidateDNCredentials() try site.invalidateDNCredentials()
log.notice("Invalidated credentials in site: \(site.name, privacy: .public)") log.notice("Invalidated credentials in site: \(site.name, privacy: .public)")
} }
@ -64,10 +71,13 @@ class DNUpdater {
} }
let siteManager = site.manager let siteManager = site.manager
let shouldSaveToManager = siteManager != nil || ProcessInfo().isOperatingSystemAtLeast(OperatingSystemVersion(majorVersion: 17, minorVersion: 0, patchVersion: 0)) let shouldSaveToManager =
siteManager != nil
|| ProcessInfo().isOperatingSystemAtLeast(
OperatingSystemVersion(majorVersion: 17, minorVersion: 0, patchVersion: 0))
newSite?.save(manager: site.manager, saveToManager: shouldSaveToManager) { error in newSite?.save(manager: site.manager, saveToManager: shouldSaveToManager) { error in
if (error != nil) { if error != nil {
self.log.error("failed to save update: \(error!.localizedDescription, privacy: .public)") self.log.error("failed to save update: \(error!.localizedDescription, privacy: .public)")
} }
@ -75,13 +85,14 @@ class DNUpdater {
onUpdate(Site(incoming: newSite!)) onUpdate(Site(incoming: newSite!))
} }
if (credentials.invalid) { if credentials.invalid {
try site.validateDNCredentials() try site.validateDNCredentials()
log.notice("Revalidated credentials in site \(site.name, privacy: .public)") log.notice("Revalidated credentials in site \(site.name, privacy: .public)")
} }
} catch { } catch {
log.error("Error while updating \(site.name, privacy: .public): \(error.localizedDescription, privacy: .public)") log.error(
"Error while updating \(site.name, privacy: .public): \(error.localizedDescription, privacy: .public)")
} }
} }
} }