mirror of
https://github.com/DefinedNet/mobile_nebula.git
synced 2025-02-23 11:35:26 +00:00
Make DNUpdater
an actor to enforce thread safety
This commit is contained in:
parent
9c19e39891
commit
06ed3dfaaa
3 changed files with 185 additions and 150 deletions
|
@ -116,7 +116,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
|
|||
}
|
||||
|
||||
self.nebula!.start()
|
||||
self.dnUpdater.updateSingleLoop(site: self.site!, onUpdate: self.handleDNUpdate)
|
||||
await self.dnUpdater.updateSingleLoop(site: self.site!, onUpdate: self.handleDNUpdate)
|
||||
}
|
||||
|
||||
private func handleDNUpdate(newSite: Site) {
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import UIKit
|
||||
@preconcurrency import Flutter
|
||||
import MobileNebula
|
||||
import NetworkExtension
|
||||
import SwiftyJSON
|
||||
import UIKit
|
||||
|
||||
enum ChannelName {
|
||||
static let vpn = "net.defined.mobileNebula/NebulaVpnService"
|
||||
|
@ -19,25 +19,29 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
|
|||
private var sites: Sites?
|
||||
private var ui: FlutterMethodChannel?
|
||||
|
||||
|
||||
override func application(
|
||||
_ application: UIApplication,
|
||||
didFinishLaunchingWithOptions launchOptions: [UIApplication.LaunchOptionsKey: Any]?
|
||||
) -> Bool {
|
||||
GeneratedPluginRegistrant.register(with: self)
|
||||
|
||||
Task.detached {
|
||||
await self.dnUpdater.updateAllLoop { [weak self] site in
|
||||
// Signal the site has changed in case the current site details screen is active
|
||||
let container = self?.sites?.getContainer(id: site.id)
|
||||
if container != nil {
|
||||
// Update references to the site with the new site config
|
||||
container!.site = site
|
||||
container!.updater.update(connected: site.connected ?? false, replaceSite: site)
|
||||
}
|
||||
|
||||
// Send the refresh sites command on the main thread
|
||||
DispatchQueue.main.async {
|
||||
// Signal to the main screen to reload
|
||||
self?.ui?.invokeMethod("refreshSites", arguments: nil)
|
||||
}
|
||||
|
||||
dnUpdater.updateAllLoop { site in
|
||||
// Signal the site has changed in case the current site details screen is active
|
||||
let container = self.sites?.getContainer(id: site.id)
|
||||
if (container != nil) {
|
||||
// Update references to the site with the new site config
|
||||
container!.site = site
|
||||
container!.updater.update(connected: site.connected ?? false, replaceSite: site)
|
||||
}
|
||||
|
||||
// Signal to the main screen to reload
|
||||
self.ui?.invokeMethod("refreshSites", arguments: nil)
|
||||
}
|
||||
|
||||
guard let controller = window?.rootViewController as? FlutterViewController else {
|
||||
|
@ -47,7 +51,7 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
|
|||
sites = Sites(messenger: controller.binaryMessenger)
|
||||
ui = FlutterMethodChannel(name: ChannelName.vpn, binaryMessenger: controller.binaryMessenger)
|
||||
|
||||
ui!.setMethodCallHandler({(call: FlutterMethodCall, result: @escaping FlutterResult) -> Void in
|
||||
ui!.setMethodCallHandler({ (call: FlutterMethodCall, result: @escaping FlutterResult) -> Void in
|
||||
switch call.method {
|
||||
case "nebula.parseCerts": return self.nebulaParseCerts(call: call, result: result)
|
||||
case "nebula.generateKeyPair": return self.nebulaGenerateKeyPair(result: result)
|
||||
|
@ -62,11 +66,16 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
|
|||
case "startSite": return self.startSite(call: call, result: result)
|
||||
case "stopSite": return self.stopSite(call: call, result: result)
|
||||
|
||||
case "active.listHostmap": self.vpnRequest(command: "listHostmap", arguments: call.arguments, result: result)
|
||||
case "active.listPendingHostmap": self.vpnRequest(command: "listPendingHostmap", arguments: call.arguments, result: result)
|
||||
case "active.getHostInfo": self.vpnRequest(command: "getHostInfo", arguments: call.arguments, result: result)
|
||||
case "active.setRemoteForTunnel": self.vpnRequest(command: "setRemoteForTunnel", arguments: call.arguments, result: result)
|
||||
case "active.closeTunnel": self.vpnRequest(command: "closeTunnel", arguments: call.arguments, result: result)
|
||||
case "active.listHostmap":
|
||||
self.vpnRequest(command: "listHostmap", arguments: call.arguments, result: result)
|
||||
case "active.listPendingHostmap":
|
||||
self.vpnRequest(command: "listPendingHostmap", arguments: call.arguments, result: result)
|
||||
case "active.getHostInfo":
|
||||
self.vpnRequest(command: "getHostInfo", arguments: call.arguments, result: result)
|
||||
case "active.setRemoteForTunnel":
|
||||
self.vpnRequest(command: "setRemoteForTunnel", arguments: call.arguments, result: result)
|
||||
case "active.closeTunnel":
|
||||
self.vpnRequest(command: "closeTunnel", arguments: call.arguments, result: result)
|
||||
|
||||
default:
|
||||
result(FlutterMethodNotImplemented)
|
||||
|
@ -77,28 +86,35 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
|
|||
}
|
||||
|
||||
func nebulaParseCerts(call: FlutterMethodCall, result: FlutterResult) {
|
||||
guard let args = call.arguments as? Dictionary<String, String> else { return result(NoArgumentsError()) }
|
||||
guard let certs = args["certs"] else { return result(MissingArgumentError(message: "certs is a required argument")) }
|
||||
guard let args = call.arguments as? [String: String] else { return result(NoArgumentsError()) }
|
||||
guard let certs = args["certs"] else {
|
||||
return result(MissingArgumentError(message: "certs is a required argument"))
|
||||
}
|
||||
|
||||
var err: NSError?
|
||||
let json = MobileNebulaParseCerts(certs, &err)
|
||||
if (err != nil) {
|
||||
return result(CallFailedError(message: "Error while parsing certificate(s)", details: err!.localizedDescription))
|
||||
if err != nil {
|
||||
return result(
|
||||
CallFailedError(message: "Error while parsing certificate(s)", details: err!.localizedDescription))
|
||||
}
|
||||
|
||||
return result(json)
|
||||
}
|
||||
|
||||
func nebulaVerifyCertAndKey(call: FlutterMethodCall, result: FlutterResult) {
|
||||
guard let args = call.arguments as? Dictionary<String, String> else { return result(NoArgumentsError()) }
|
||||
guard let cert = args["cert"] else { return result(MissingArgumentError(message: "cert is a required argument")) }
|
||||
guard let args = call.arguments as? [String: String] else { return result(NoArgumentsError()) }
|
||||
guard let cert = args["cert"] else {
|
||||
return result(MissingArgumentError(message: "cert is a required argument"))
|
||||
}
|
||||
guard let key = args["key"] else { return result(MissingArgumentError(message: "key is a required argument")) }
|
||||
|
||||
var err: NSError?
|
||||
var validd: ObjCBool = false
|
||||
let valid = MobileNebulaVerifyCertAndKey(cert, key, &validd, &err)
|
||||
if (err != nil) {
|
||||
return result(CallFailedError(message: "Error while verifying certificate and private key", details: err!.localizedDescription))
|
||||
if err != nil {
|
||||
return result(
|
||||
CallFailedError(
|
||||
message: "Error while verifying certificate and private key", details: err!.localizedDescription))
|
||||
}
|
||||
|
||||
return result(valid)
|
||||
|
@ -107,8 +123,9 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
|
|||
func nebulaGenerateKeyPair(result: FlutterResult) {
|
||||
var err: NSError?
|
||||
let kp = MobileNebulaGenerateKeyPair(&err)
|
||||
if (err != nil) {
|
||||
return result(CallFailedError(message: "Error while generating key pairs", details: err!.localizedDescription))
|
||||
if err != nil {
|
||||
return result(
|
||||
CallFailedError(message: "Error while generating key pairs", details: err!.localizedDescription))
|
||||
}
|
||||
|
||||
return result(kp)
|
||||
|
@ -119,7 +136,7 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
|
|||
|
||||
var err: NSError?
|
||||
let yaml = MobileNebulaRenderConfig(config, "<hidden>", &err)
|
||||
if (err != nil) {
|
||||
if err != nil {
|
||||
return result(CallFailedError(message: "Error while rendering config", details: err!.localizedDescription))
|
||||
}
|
||||
|
||||
|
@ -134,7 +151,7 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
|
|||
|
||||
let oldSite = self.sites?.getSite(id: site.id)
|
||||
site.save(manager: oldSite?.manager) { error in
|
||||
if (error != nil) {
|
||||
if error != nil {
|
||||
return result(CallFailedError(message: "Failed to enroll", details: error!.localizedDescription))
|
||||
}
|
||||
|
||||
|
@ -146,8 +163,8 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
|
|||
}
|
||||
|
||||
func listSites(result: @escaping FlutterResult) {
|
||||
self.sites?.loadSites { (sites, err) -> () in
|
||||
if (err != nil) {
|
||||
self.sites?.loadSites { (sites, err) -> Void in
|
||||
if err != nil {
|
||||
return result(CallFailedError(message: "Failed to load site list", details: err!.localizedDescription))
|
||||
}
|
||||
|
||||
|
@ -162,7 +179,7 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
|
|||
guard let id = call.arguments as? String else { return result(NoArgumentsError()) }
|
||||
//TODO: stop the site if its running currently
|
||||
self.sites?.deleteSite(id: id) { error in
|
||||
if (error != nil) {
|
||||
if error != nil {
|
||||
result(CallFailedError(message: "Failed to delete site", details: error!.localizedDescription))
|
||||
}
|
||||
|
||||
|
@ -180,7 +197,7 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
|
|||
|
||||
let oldSite = self.sites?.getSite(id: site.id)
|
||||
site.save(manager: oldSite?.manager) { error in
|
||||
if (error != nil) {
|
||||
if error != nil {
|
||||
return result(CallFailedError(message: "Failed to save site", details: error!.localizedDescription))
|
||||
}
|
||||
|
||||
|
@ -191,59 +208,62 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
|
|||
}
|
||||
|
||||
func startSite(call: FlutterMethodCall, result: @escaping FlutterResult) {
|
||||
guard let args = call.arguments as? Dictionary<String, String> else { return result(NoArgumentsError()) }
|
||||
guard let args = call.arguments as? [String: String] else { return result(NoArgumentsError()) }
|
||||
guard let id = args["id"] else { return result(MissingArgumentError(message: "id is a required argument")) }
|
||||
|
||||
#if targetEnvironment(simulator)
|
||||
let updater = self.sites?.getUpdater(id: id)
|
||||
updater?.update(connected: true)
|
||||
#else
|
||||
let container = self.sites?.getContainer(id: id)
|
||||
let manager = container?.site.manager
|
||||
#if targetEnvironment(simulator)
|
||||
let updater = self.sites?.getUpdater(id: id)
|
||||
updater?.update(connected: true)
|
||||
#else
|
||||
let container = self.sites?.getContainer(id: id)
|
||||
let manager = container?.site.manager
|
||||
|
||||
manager?.loadFromPreferences{ error in
|
||||
//TODO: Handle load error
|
||||
// This is silly but we need to enable the site each time to avoid situations where folks have multiple sites
|
||||
manager?.isEnabled = true
|
||||
manager?.saveToPreferences{ error in
|
||||
manager?.loadFromPreferences { error in
|
||||
//TODO: Handle load error
|
||||
manager?.loadFromPreferences{ error in
|
||||
// This is silly but we need to enable the site each time to avoid situations where folks have multiple sites
|
||||
manager?.isEnabled = true
|
||||
manager?.saveToPreferences { error in
|
||||
//TODO: Handle load error
|
||||
do {
|
||||
container?.updater.startFunc = {() -> Void in
|
||||
return self.vpnRequest(command: "start", arguments: args, result: result)
|
||||
manager?.loadFromPreferences { error in
|
||||
//TODO: Handle load error
|
||||
do {
|
||||
container?.updater.startFunc = { () -> Void in
|
||||
return self.vpnRequest(command: "start", arguments: args, result: result)
|
||||
}
|
||||
try manager?.connection.startVPNTunnel(options: ["expectStart": NSNumber(1)])
|
||||
} catch {
|
||||
return result(
|
||||
CallFailedError(message: "Could not start site", details: error.localizedDescription))
|
||||
}
|
||||
try manager?.connection.startVPNTunnel(options: ["expectStart": NSNumber(1)])
|
||||
} catch {
|
||||
return result(CallFailedError(message: "Could not start site", details: error.localizedDescription))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
func stopSite(call: FlutterMethodCall, result: @escaping FlutterResult) {
|
||||
guard let args = call.arguments as? Dictionary<String, String> else { return result(NoArgumentsError()) }
|
||||
guard let args = call.arguments as? [String: String] else { return result(NoArgumentsError()) }
|
||||
guard let id = args["id"] else { return result(MissingArgumentError(message: "id is a required argument")) }
|
||||
#if targetEnvironment(simulator)
|
||||
let updater = self.sites?.getUpdater(id: id)
|
||||
updater?.update(connected: false)
|
||||
#if targetEnvironment(simulator)
|
||||
let updater = self.sites?.getUpdater(id: id)
|
||||
updater?.update(connected: false)
|
||||
|
||||
#else
|
||||
let manager = self.sites?.getSite(id: id)?.manager
|
||||
manager?.loadFromPreferences{ error in
|
||||
//TODO: Handle load error
|
||||
#else
|
||||
let manager = self.sites?.getSite(id: id)?.manager
|
||||
manager?.loadFromPreferences { error in
|
||||
//TODO: Handle load error
|
||||
|
||||
manager?.connection.stopVPNTunnel()
|
||||
return result(nil)
|
||||
}
|
||||
#endif
|
||||
manager?.connection.stopVPNTunnel()
|
||||
return result(nil)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
func vpnRequest(command: String, arguments: Any?, result: @escaping FlutterResult) {
|
||||
guard let args = arguments as? Dictionary<String, Any> else { return result(NoArgumentsError()) }
|
||||
guard let id = args["id"] as? String else { return result(MissingArgumentError(message: "id is a required argument")) }
|
||||
guard let args = arguments as? [String: Any] else { return result(NoArgumentsError()) }
|
||||
guard let id = args["id"] as? String else {
|
||||
return result(MissingArgumentError(message: "id is a required argument"))
|
||||
}
|
||||
let container = sites?.getContainer(id: id)
|
||||
|
||||
if container == nil {
|
||||
|
@ -258,7 +278,9 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
|
|||
|
||||
if let session = container!.site.manager?.connection as? NETunnelProviderSession {
|
||||
do {
|
||||
try session.sendProviderMessage(try JSONEncoder().encode(IPCRequest(command: command, arguments: JSON(args)))) { data in
|
||||
try session.sendProviderMessage(
|
||||
try JSONEncoder().encode(IPCRequest(command: command, arguments: JSON(args)))
|
||||
) { data in
|
||||
if data == nil {
|
||||
return result(nil)
|
||||
}
|
||||
|
@ -288,7 +310,9 @@ func MissingArgumentError(message: String, details: (any Error)? = nil) -> Flutt
|
|||
return FlutterError(code: "missingArgument", message: message, details: details)
|
||||
}
|
||||
|
||||
func NoArgumentsError(message: String? = "no arguments were provided or could not be deserialized", details: (any Error)? = nil) -> FlutterError {
|
||||
func NoArgumentsError(
|
||||
message: String? = "no arguments were provided or could not be deserialized", details: (any Error)? = nil
|
||||
) -> FlutterError {
|
||||
return FlutterError(code: "noArguments", message: message, details: details)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,45 +1,52 @@
|
|||
import Foundation
|
||||
import os.log
|
||||
|
||||
class DNUpdater {
|
||||
actor DNUpdater {
|
||||
private let apiClient = APIClient()
|
||||
private let timer = RepeatingTimer(timeInterval: 15 * 60) // 15 * 60 is 15 minutes
|
||||
private let timer = RepeatingTimer(timeInterval: 15 * 60) // 15 * 60 is 15 minutes
|
||||
private let log = Logger(subsystem: "net.defined.mobileNebula", category: "DNUpdater")
|
||||
|
||||
func updateAll(onUpdate: @escaping (Site) -> ()) {
|
||||
_ = SiteList{ (sites, _) -> () in
|
||||
// NEVPN seems to force us onto the main thread and we are about to make network calls that
|
||||
// could block for a while. Push ourselves onto another thread to avoid blocking the UI.
|
||||
Task.detached(priority: .userInitiated) {
|
||||
sites?.values.forEach { site in
|
||||
if (site.connected == true) {
|
||||
// The vpn service is in charge of updating the currently connected site
|
||||
return
|
||||
func updateAll(onUpdate: @escaping (Site) -> Void) {
|
||||
_ = SiteList { (sites, _) -> Void in
|
||||
switch sites
|
||||
{
|
||||
case .some(let sites):
|
||||
// NEVPN seems to force us onto the main thread and we are about to make network calls that
|
||||
// could block for a while. Push ourselves onto another thread to avoid blocking the UI.
|
||||
Task.detached(priority: .userInitiated) {
|
||||
for site in sites.values {
|
||||
if site.connected == true {
|
||||
// The vpn service is in charge of updating the currently connected site
|
||||
return
|
||||
}
|
||||
|
||||
await self.updateSite(site: site, onUpdate: onUpdate)
|
||||
}
|
||||
|
||||
self.updateSite(site: site, onUpdate: onUpdate)
|
||||
}
|
||||
default: break
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func updateAllLoop(onUpdate: @escaping (Site) -> ()) {
|
||||
func updateAllLoop(onUpdate: @escaping (Site) -> Void) {
|
||||
timer.eventHandler = {
|
||||
self.updateAll(onUpdate: onUpdate)
|
||||
}
|
||||
timer.resume()
|
||||
}
|
||||
|
||||
func updateSingleLoop(site: Site, onUpdate: @escaping (Site) -> ()) {
|
||||
func updateSingleLoop(site: Site, onUpdate: @escaping (Site) -> Void) {
|
||||
timer.eventHandler = {
|
||||
self.updateSite(site: site, onUpdate: onUpdate)
|
||||
}
|
||||
timer.resume()
|
||||
}
|
||||
|
||||
func updateSite(site: Site, onUpdate: @escaping (Site) -> ()) {
|
||||
func updateSite(site: Site, onUpdate: @escaping (Site) -> Void) {
|
||||
do {
|
||||
if (!site.managed) {
|
||||
if !site.managed {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -55,7 +62,7 @@ class DNUpdater {
|
|||
trustedKeys: credentials.trustedKeys
|
||||
)
|
||||
} catch (APIClientError.invalidCredentials) {
|
||||
if (!credentials.invalid) {
|
||||
if !credentials.invalid {
|
||||
try site.invalidateDNCredentials()
|
||||
log.notice("Invalidated credentials in site: \(site.name, privacy: .public)")
|
||||
}
|
||||
|
@ -64,10 +71,13 @@ class DNUpdater {
|
|||
}
|
||||
|
||||
let siteManager = site.manager
|
||||
let shouldSaveToManager = siteManager != nil || ProcessInfo().isOperatingSystemAtLeast(OperatingSystemVersion(majorVersion: 17, minorVersion: 0, patchVersion: 0))
|
||||
let shouldSaveToManager =
|
||||
siteManager != nil
|
||||
|| ProcessInfo().isOperatingSystemAtLeast(
|
||||
OperatingSystemVersion(majorVersion: 17, minorVersion: 0, patchVersion: 0))
|
||||
|
||||
newSite?.save(manager: site.manager, saveToManager: shouldSaveToManager) { error in
|
||||
if (error != nil) {
|
||||
if error != nil {
|
||||
self.log.error("failed to save update: \(error!.localizedDescription, privacy: .public)")
|
||||
}
|
||||
|
||||
|
@ -75,13 +85,14 @@ class DNUpdater {
|
|||
onUpdate(Site(incoming: newSite!))
|
||||
}
|
||||
|
||||
if (credentials.invalid) {
|
||||
if credentials.invalid {
|
||||
try site.validateDNCredentials()
|
||||
log.notice("Revalidated credentials in site \(site.name, privacy: .public)")
|
||||
}
|
||||
|
||||
} catch {
|
||||
log.error("Error while updating \(site.name, privacy: .public): \(error.localizedDescription, privacy: .public)")
|
||||
log.error(
|
||||
"Error while updating \(site.name, privacy: .public): \(error.localizedDescription, privacy: .public)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue