Make DNUpdater an actor to enforce thread safety

This commit is contained in:
Caleb Jasik 2025-02-11 15:42:24 -06:00
parent 9c19e39891
commit 06ed3dfaaa
No known key found for this signature in database
3 changed files with 185 additions and 150 deletions

View file

@ -116,7 +116,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
}
self.nebula!.start()
self.dnUpdater.updateSingleLoop(site: self.site!, onUpdate: self.handleDNUpdate)
await self.dnUpdater.updateSingleLoop(site: self.site!, onUpdate: self.handleDNUpdate)
}
private func handleDNUpdate(newSite: Site) {

View file

@ -1,8 +1,8 @@
import UIKit
@preconcurrency import Flutter
import MobileNebula
import NetworkExtension
import SwiftyJSON
import UIKit
enum ChannelName {
static let vpn = "net.defined.mobileNebula/NebulaVpnService"
@ -19,25 +19,29 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
private var sites: Sites?
private var ui: FlutterMethodChannel?
override func application(
_ application: UIApplication,
didFinishLaunchingWithOptions launchOptions: [UIApplication.LaunchOptionsKey: Any]?
) -> Bool {
GeneratedPluginRegistrant.register(with: self)
dnUpdater.updateAllLoop { site in
Task.detached {
await self.dnUpdater.updateAllLoop { [weak self] site in
// Signal the site has changed in case the current site details screen is active
let container = self.sites?.getContainer(id: site.id)
if (container != nil) {
let container = self?.sites?.getContainer(id: site.id)
if container != nil {
// Update references to the site with the new site config
container!.site = site
container!.updater.update(connected: site.connected ?? false, replaceSite: site)
}
// Send the refresh sites command on the main thread
DispatchQueue.main.async {
// Signal to the main screen to reload
self.ui?.invokeMethod("refreshSites", arguments: nil)
self?.ui?.invokeMethod("refreshSites", arguments: nil)
}
}
}
guard let controller = window?.rootViewController as? FlutterViewController else {
@ -62,11 +66,16 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
case "startSite": return self.startSite(call: call, result: result)
case "stopSite": return self.stopSite(call: call, result: result)
case "active.listHostmap": self.vpnRequest(command: "listHostmap", arguments: call.arguments, result: result)
case "active.listPendingHostmap": self.vpnRequest(command: "listPendingHostmap", arguments: call.arguments, result: result)
case "active.getHostInfo": self.vpnRequest(command: "getHostInfo", arguments: call.arguments, result: result)
case "active.setRemoteForTunnel": self.vpnRequest(command: "setRemoteForTunnel", arguments: call.arguments, result: result)
case "active.closeTunnel": self.vpnRequest(command: "closeTunnel", arguments: call.arguments, result: result)
case "active.listHostmap":
self.vpnRequest(command: "listHostmap", arguments: call.arguments, result: result)
case "active.listPendingHostmap":
self.vpnRequest(command: "listPendingHostmap", arguments: call.arguments, result: result)
case "active.getHostInfo":
self.vpnRequest(command: "getHostInfo", arguments: call.arguments, result: result)
case "active.setRemoteForTunnel":
self.vpnRequest(command: "setRemoteForTunnel", arguments: call.arguments, result: result)
case "active.closeTunnel":
self.vpnRequest(command: "closeTunnel", arguments: call.arguments, result: result)
default:
result(FlutterMethodNotImplemented)
@ -77,28 +86,35 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
}
func nebulaParseCerts(call: FlutterMethodCall, result: FlutterResult) {
guard let args = call.arguments as? Dictionary<String, String> else { return result(NoArgumentsError()) }
guard let certs = args["certs"] else { return result(MissingArgumentError(message: "certs is a required argument")) }
guard let args = call.arguments as? [String: String] else { return result(NoArgumentsError()) }
guard let certs = args["certs"] else {
return result(MissingArgumentError(message: "certs is a required argument"))
}
var err: NSError?
let json = MobileNebulaParseCerts(certs, &err)
if (err != nil) {
return result(CallFailedError(message: "Error while parsing certificate(s)", details: err!.localizedDescription))
if err != nil {
return result(
CallFailedError(message: "Error while parsing certificate(s)", details: err!.localizedDescription))
}
return result(json)
}
func nebulaVerifyCertAndKey(call: FlutterMethodCall, result: FlutterResult) {
guard let args = call.arguments as? Dictionary<String, String> else { return result(NoArgumentsError()) }
guard let cert = args["cert"] else { return result(MissingArgumentError(message: "cert is a required argument")) }
guard let args = call.arguments as? [String: String] else { return result(NoArgumentsError()) }
guard let cert = args["cert"] else {
return result(MissingArgumentError(message: "cert is a required argument"))
}
guard let key = args["key"] else { return result(MissingArgumentError(message: "key is a required argument")) }
var err: NSError?
var validd: ObjCBool = false
let valid = MobileNebulaVerifyCertAndKey(cert, key, &validd, &err)
if (err != nil) {
return result(CallFailedError(message: "Error while verifying certificate and private key", details: err!.localizedDescription))
if err != nil {
return result(
CallFailedError(
message: "Error while verifying certificate and private key", details: err!.localizedDescription))
}
return result(valid)
@ -107,8 +123,9 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
func nebulaGenerateKeyPair(result: FlutterResult) {
var err: NSError?
let kp = MobileNebulaGenerateKeyPair(&err)
if (err != nil) {
return result(CallFailedError(message: "Error while generating key pairs", details: err!.localizedDescription))
if err != nil {
return result(
CallFailedError(message: "Error while generating key pairs", details: err!.localizedDescription))
}
return result(kp)
@ -119,7 +136,7 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
var err: NSError?
let yaml = MobileNebulaRenderConfig(config, "<hidden>", &err)
if (err != nil) {
if err != nil {
return result(CallFailedError(message: "Error while rendering config", details: err!.localizedDescription))
}
@ -134,7 +151,7 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
let oldSite = self.sites?.getSite(id: site.id)
site.save(manager: oldSite?.manager) { error in
if (error != nil) {
if error != nil {
return result(CallFailedError(message: "Failed to enroll", details: error!.localizedDescription))
}
@ -146,8 +163,8 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
}
func listSites(result: @escaping FlutterResult) {
self.sites?.loadSites { (sites, err) -> () in
if (err != nil) {
self.sites?.loadSites { (sites, err) -> Void in
if err != nil {
return result(CallFailedError(message: "Failed to load site list", details: err!.localizedDescription))
}
@ -162,7 +179,7 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
guard let id = call.arguments as? String else { return result(NoArgumentsError()) }
//TODO: stop the site if its running currently
self.sites?.deleteSite(id: id) { error in
if (error != nil) {
if error != nil {
result(CallFailedError(message: "Failed to delete site", details: error!.localizedDescription))
}
@ -180,7 +197,7 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
let oldSite = self.sites?.getSite(id: site.id)
site.save(manager: oldSite?.manager) { error in
if (error != nil) {
if error != nil {
return result(CallFailedError(message: "Failed to save site", details: error!.localizedDescription))
}
@ -191,7 +208,7 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
}
func startSite(call: FlutterMethodCall, result: @escaping FlutterResult) {
guard let args = call.arguments as? Dictionary<String, String> else { return result(NoArgumentsError()) }
guard let args = call.arguments as? [String: String] else { return result(NoArgumentsError()) }
guard let id = args["id"] else { return result(MissingArgumentError(message: "id is a required argument")) }
#if targetEnvironment(simulator)
@ -215,7 +232,8 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
}
try manager?.connection.startVPNTunnel(options: ["expectStart": NSNumber(1)])
} catch {
return result(CallFailedError(message: "Could not start site", details: error.localizedDescription))
return result(
CallFailedError(message: "Could not start site", details: error.localizedDescription))
}
}
}
@ -224,7 +242,7 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
}
func stopSite(call: FlutterMethodCall, result: @escaping FlutterResult) {
guard let args = call.arguments as? Dictionary<String, String> else { return result(NoArgumentsError()) }
guard let args = call.arguments as? [String: String] else { return result(NoArgumentsError()) }
guard let id = args["id"] else { return result(MissingArgumentError(message: "id is a required argument")) }
#if targetEnvironment(simulator)
let updater = self.sites?.getUpdater(id: id)
@ -242,8 +260,10 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
}
func vpnRequest(command: String, arguments: Any?, result: @escaping FlutterResult) {
guard let args = arguments as? Dictionary<String, Any> else { return result(NoArgumentsError()) }
guard let id = args["id"] as? String else { return result(MissingArgumentError(message: "id is a required argument")) }
guard let args = arguments as? [String: Any] else { return result(NoArgumentsError()) }
guard let id = args["id"] as? String else {
return result(MissingArgumentError(message: "id is a required argument"))
}
let container = sites?.getContainer(id: id)
if container == nil {
@ -258,7 +278,9 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
if let session = container!.site.manager?.connection as? NETunnelProviderSession {
do {
try session.sendProviderMessage(try JSONEncoder().encode(IPCRequest(command: command, arguments: JSON(args)))) { data in
try session.sendProviderMessage(
try JSONEncoder().encode(IPCRequest(command: command, arguments: JSON(args)))
) { data in
if data == nil {
return result(nil)
}
@ -288,7 +310,9 @@ func MissingArgumentError(message: String, details: (any Error)? = nil) -> Flutt
return FlutterError(code: "missingArgument", message: message, details: details)
}
func NoArgumentsError(message: String? = "no arguments were provided or could not be deserialized", details: (any Error)? = nil) -> FlutterError {
func NoArgumentsError(
message: String? = "no arguments were provided or could not be deserialized", details: (any Error)? = nil
) -> FlutterError {
return FlutterError(code: "noArguments", message: message, details: details)
}

View file

@ -1,45 +1,52 @@
import Foundation
import os.log
class DNUpdater {
actor DNUpdater {
private let apiClient = APIClient()
private let timer = RepeatingTimer(timeInterval: 15 * 60) // 15 * 60 is 15 minutes
private let log = Logger(subsystem: "net.defined.mobileNebula", category: "DNUpdater")
func updateAll(onUpdate: @escaping (Site) -> ()) {
_ = SiteList{ (sites, _) -> () in
func updateAll(onUpdate: @escaping (Site) -> Void) {
_ = SiteList { (sites, _) -> Void in
switch sites
{
case .some(let sites):
// NEVPN seems to force us onto the main thread and we are about to make network calls that
// could block for a while. Push ourselves onto another thread to avoid blocking the UI.
Task.detached(priority: .userInitiated) {
sites?.values.forEach { site in
if (site.connected == true) {
for site in sites.values {
if site.connected == true {
// The vpn service is in charge of updating the currently connected site
return
}
self.updateSite(site: site, onUpdate: onUpdate)
await self.updateSite(site: site, onUpdate: onUpdate)
}
}
default: break
}
}
}
func updateAllLoop(onUpdate: @escaping (Site) -> ()) {
func updateAllLoop(onUpdate: @escaping (Site) -> Void) {
timer.eventHandler = {
self.updateAll(onUpdate: onUpdate)
}
timer.resume()
}
func updateSingleLoop(site: Site, onUpdate: @escaping (Site) -> ()) {
func updateSingleLoop(site: Site, onUpdate: @escaping (Site) -> Void) {
timer.eventHandler = {
self.updateSite(site: site, onUpdate: onUpdate)
}
timer.resume()
}
func updateSite(site: Site, onUpdate: @escaping (Site) -> ()) {
func updateSite(site: Site, onUpdate: @escaping (Site) -> Void) {
do {
if (!site.managed) {
if !site.managed {
return
}
@ -55,7 +62,7 @@ class DNUpdater {
trustedKeys: credentials.trustedKeys
)
} catch (APIClientError.invalidCredentials) {
if (!credentials.invalid) {
if !credentials.invalid {
try site.invalidateDNCredentials()
log.notice("Invalidated credentials in site: \(site.name, privacy: .public)")
}
@ -64,10 +71,13 @@ class DNUpdater {
}
let siteManager = site.manager
let shouldSaveToManager = siteManager != nil || ProcessInfo().isOperatingSystemAtLeast(OperatingSystemVersion(majorVersion: 17, minorVersion: 0, patchVersion: 0))
let shouldSaveToManager =
siteManager != nil
|| ProcessInfo().isOperatingSystemAtLeast(
OperatingSystemVersion(majorVersion: 17, minorVersion: 0, patchVersion: 0))
newSite?.save(manager: site.manager, saveToManager: shouldSaveToManager) { error in
if (error != nil) {
if error != nil {
self.log.error("failed to save update: \(error!.localizedDescription, privacy: .public)")
}
@ -75,13 +85,14 @@ class DNUpdater {
onUpdate(Site(incoming: newSite!))
}
if (credentials.invalid) {
if credentials.invalid {
try site.validateDNCredentials()
log.notice("Revalidated credentials in site \(site.name, privacy: .public)")
}
} catch {
log.error("Error while updating \(site.name, privacy: .public): \(error.localizedDescription, privacy: .public)")
log.error(
"Error while updating \(site.name, privacy: .public): \(error.localizedDescription, privacy: .public)")
}
}
}