Work on making DNUpdate use an async timer implementation

This commit is contained in:
Caleb Jasik 2025-02-12 15:23:20 -06:00
parent e4940d3e3a
commit 8ee9979b16
No known key found for this signature in database
4 changed files with 93 additions and 72 deletions

View file

@ -1,7 +1,7 @@
import NetworkExtension
import MobileNebula
import os.log
import NetworkExtension
import SwiftyJSON
import os.log
enum VPNStartError: Error {
case noManagers
@ -23,40 +23,41 @@ extension AppMessageError: LocalizedError {
}
}
// FIXME: marked as unchecked Sendable to allow sending `self.pathUpdate`, but we should refactor and re-enable linting.
class PacketTunnelProvider: NEPacketTunnelProvider, @unchecked Sendable {
private var networkMonitor: NWPathMonitor?
private var site: Site?
private let log = Logger(subsystem: "net.defined.mobileNebula", category: "PacketTunnelProvider")
private var nebula: MobileNebulaNebula?
private var dnUpdater = DNUpdater()
private var didSleep = false
private var cachedRouteDescription: String?
override func startTunnel(options: [String : NSObject]? = nil) async throws {
override func startTunnel(options: [String: NSObject]? = nil) async throws {
// There is currently no way to get initialization errors back to the UI via completionHandler here
// `expectStart` is sent only via the UI which means we should wait for the real start command which has another completion handler the UI can intercept
if options?["expectStart"] != nil {
// startTunnel must complete before IPC will work
return
}
// VPN is being booted out of band of the UI. Use the system completion handler as there will be nothing to route initialization errors to but we still need to report
// success/fail by the presence of an error or nil
try await start()
}
private func start() async throws {
var manager: NETunnelProviderManager?
var config: Data
var key: String
do {
// Cannot use NETunnelProviderManager.loadAllFromPreferences() in earlier versions of iOS
// TODO: Remove else once we drop support for iOS 16
if ProcessInfo().isOperatingSystemAtLeast(OperatingSystemVersion(majorVersion: 17, minorVersion: 0, patchVersion: 0)) {
if ProcessInfo().isOperatingSystemAtLeast(
OperatingSystemVersion(majorVersion: 17, minorVersion: 0, patchVersion: 0))
{
manager = try await self.findManager()
guard let foundManager = manager else {
throw VPNStartError.couldNotFindManager
@ -76,7 +77,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider, @unchecked Sendable {
let _site = self.site!
key = try _site.getKey()
guard let fileDescriptor = self.tunnelFileDescriptor else {
throw VPNStartError.noTunFileDescriptor
}
@ -88,7 +89,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider, @unchecked Sendable {
// Make sure our ip is routed to the tun device
var err: NSError?
let ipNet = MobileNebulaParseCIDR(_site.cert!.cert.details.ips[0], &err)
if (err != nil) {
if err != nil {
throw err!
}
tunnelNetworkSettings.ipv4Settings = NEIPv4Settings(addresses: [ipNet!.ip], subnetMasks: [ipNet!.maskCIDR])
@ -97,7 +98,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider, @unchecked Sendable {
// Add our unsafe routes
try _site.unsafeRoutes.forEach { unsafeRoute in
let ipNet = MobileNebulaParseCIDR(unsafeRoute.route, &err)
if (err != nil) {
if err != nil {
throw err!
}
routes.append(NEIPv4Route(destinationAddress: ipNet!.network, subnetMask: ipNet!.maskCIDR))
@ -108,41 +109,42 @@ class PacketTunnelProvider: NEPacketTunnelProvider, @unchecked Sendable {
try await self.setTunnelNetworkSettings(tunnelNetworkSettings)
var nebulaErr: NSError?
self.nebula = MobileNebulaNewNebula(String(data: config, encoding: .utf8), key, self.site!.logFile, tunFD, &nebulaErr)
self.nebula = MobileNebulaNewNebula(
String(data: config, encoding: .utf8), key, self.site!.logFile, tunFD, &nebulaErr)
self.startNetworkMonitor()
if nebulaErr != nil {
self.log.error("We had an error starting up: \(nebulaErr, privacy: .public)")
throw nebulaErr!
}
self.nebula!.start()
await self.dnUpdater.updateSingleLoop(site: self.site!, onUpdate: self.handleDNUpdate)
}
private func handleDNUpdate(newSite: Site) {
do {
self.site = newSite
try self.nebula?.reload(String(data: newSite.getConfig(), encoding: .utf8), key: newSite.getKey())
} catch {
log.error("Got an error while updating nebula \(error.localizedDescription, privacy: .public)")
}
}
//TODO: Sleep/wake get called aggressively and do nothing to help us here, we should locate why that is and make these work appropriately
// override func sleep(completionHandler: @escaping () -> Void) {
// nebula!.sleep()
// completionHandler()
// }
//TODO: Sleep/wake get called aggressively and do nothing to help us here, we should locate why that is and make these work appropriately
// override func sleep(completionHandler: @escaping () -> Void) {
// nebula!.sleep()
// completionHandler()
// }
private func findManager() async throws -> NETunnelProviderManager {
let targetProtoConfig = self.protocolConfiguration as? NETunnelProviderProtocol
guard let targetProviderConfig = targetProtoConfig?.providerConfiguration else {
throw VPNStartError.noProviderConfig
}
let targetID = targetProviderConfig["id"] as? String
// Load vpn configs from system, and find the manager matching the one being started
let managers = try await NETunnelProviderManager.loadAllFromPreferences()
for manager in managers {
@ -151,32 +153,32 @@ class PacketTunnelProvider: NEPacketTunnelProvider, @unchecked Sendable {
throw VPNStartError.noProviderConfig
}
let id = mgrProviderConfig["id"] as? String
if (id == targetID) {
if id == targetID {
return manager
}
}
// If we didn't find anything, throw an error
throw VPNStartError.noManagers
}
private func startNetworkMonitor() {
networkMonitor = NWPathMonitor()
networkMonitor!.pathUpdateHandler = self.pathUpdate
networkMonitor!.start(queue: DispatchQueue(label: "NetworkMonitor"))
}
private func stopNetworkMonitor() {
self.networkMonitor?.cancel()
networkMonitor = nil
}
override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) {
nebula?.stop()
stopNetworkMonitor()
completionHandler()
}
private func pathUpdate(path: Network.NWPath) {
let routeDescription = PacketTunnelProvider.collectAddresses(endpoints: path.gateways)
if routeDescription != cachedRouteDescription {
@ -187,10 +189,10 @@ class PacketTunnelProvider: NEPacketTunnelProvider, @unchecked Sendable {
cachedRouteDescription = routeDescription
}
}
static private func collectAddresses(endpoints: [Network.NWEndpoint]) -> String {
var str: [String] = []
endpoints.forEach{ endpoint in
endpoints.forEach { endpoint in
switch endpoint {
case let .hostPort(.ipv6(host), port):
str.append("[\(host)]:\(port)")
@ -200,19 +202,19 @@ class PacketTunnelProvider: NEPacketTunnelProvider, @unchecked Sendable {
return
}
}
return str.sorted().joined(separator: ", ")
}
override func handleAppMessage(_ data: Data) async -> Data? {
guard let call = try? JSONDecoder().decode(IPCRequest.self, from: data) else {
log.error("Failed to decode IPCRequest from network extension")
return nil
}
var error: (any Error)?
var data: JSON?
// start command has special treatment due to needing to call two completers
if call.command == "start" {
do {
@ -223,16 +225,17 @@ class PacketTunnelProvider: NEPacketTunnelProvider, @unchecked Sendable {
defer {
self.cancelTunnelWithError(error)
}
return try? JSONEncoder().encode(IPCResponse.init(type: .error, message: JSON(error.localizedDescription)))
return try? JSONEncoder().encode(
IPCResponse.init(type: .error, message: JSON(error.localizedDescription)))
}
}
if nebula == nil {
// Respond with an empty success message in the event a command comes in before we've truly started
log.warning("Received command but do not have a nebula instance")
return try? JSONEncoder().encode(IPCResponse.init(type: .success, message: nil))
}
//TODO: try catch over all this
switch call.command {
case "listHostmap": (data, error) = listHostmap(pending: false)
@ -240,41 +243,42 @@ class PacketTunnelProvider: NEPacketTunnelProvider, @unchecked Sendable {
case "getHostInfo": (data, error) = getHostInfo(args: call.arguments!)
case "setRemoteForTunnel": (data, error) = setRemoteForTunnel(args: call.arguments!)
case "closeTunnel": (data, error) = closeTunnel(args: call.arguments!)
default:
error = AppMessageError.unknownIPCType(command: call.command)
}
if (error != nil) {
return try? JSONEncoder().encode(IPCResponse.init(type: .error, message: JSON(error?.localizedDescription ?? "Unknown error")))
if error != nil {
return try? JSONEncoder().encode(
IPCResponse.init(type: .error, message: JSON(error?.localizedDescription ?? "Unknown error")))
} else {
return try? JSONEncoder().encode(IPCResponse.init(type: .success, message: data))
}
}
private func listHostmap(pending: Bool) -> (JSON?, (any Error)?) {
var err: NSError?
let res = nebula!.listHostmap(pending, error: &err)
return (JSON(res), err)
}
private func getHostInfo(args: JSON) -> (JSON?, (any Error)?) {
var err: NSError?
let res = nebula!.getHostInfo(byVpnIp: args["vpnIp"].string, pending: args["pending"].boolValue, error: &err)
return (JSON(res), err)
}
private func setRemoteForTunnel(args: JSON) -> (JSON?, (any Error)?) {
var err: NSError?
let res = nebula!.setRemoteForTunnel(args["vpnIp"].string, addr: args["addr"].string, error: &err)
return (JSON(res), err)
}
private func closeTunnel(args: JSON) -> (JSON?, (any Error)?) {
let res = nebula!.closeTunnel(args["vpnIp"].string)
return (JSON(res), nil)
}
private var tunnelFileDescriptor: Int32? {
var ctlInfo = ctl_info()
withUnsafeMutablePointer(to: &ctlInfo.ctl_name) {
@ -307,4 +311,3 @@ class PacketTunnelProvider: NEPacketTunnelProvider, @unchecked Sendable {
return nil
}
}

View file

@ -56,6 +56,26 @@ class SiteList {
#endif
}
static func loadAllAsync() async -> Result<[String: Site], any Error> {
await withCheckedContinuation { continuation in
#if targetEnvironment(simulator)
SiteList.loadAllFromFS { sites, err in
if err != nil {
continuation.resume(returning: Result.failure(err!))
}
continuation.resume(returning: Result.success(sites!))
}
#else
SiteList.loadAllFromNETPM { sites, err in
if err != nil {
continuation.resume(returning: Result.failure(err!))
}
continuation.resume(returning: Result.success(sites!))
}
#endif
}
}
private static func loadAllFromFS(completion: @escaping ([String: Site]?, (any Error)?) -> Void) {
let fileManager = FileManager.default
var siteDirs: [URL]

View file

@ -26,13 +26,11 @@ func MissingArgumentError(message: String, details: Any?) -> FlutterError {
GeneratedPluginRegistrant.register(with: self)
Task {
for await site in dnUpdater.siteUpdates {
for await site in await dnUpdater.siteUpdates {
self.sites?.updateSite(site: site)
// Send the refresh sites command on the main thread
DispatchQueue.main.async {
// Signal to the main screen to reload
self.ui?.invokeMethod("refreshSites", arguments: nil)
}
// Signal to the main screen to reload
self.ui?.invokeMethod("refreshSites", arguments: nil)
}
}

View file

@ -1,7 +1,7 @@
import Foundation
import os.log
class DNUpdater {
actor DNUpdater {
private let apiClient = APIClient()
private let timer = RepeatingTimer(timeInterval: 15 * 60) // 15 * 60 is 15 minutes
private let log = Logger(subsystem: "net.defined.mobileNebula", category: "DNUpdater")
@ -18,7 +18,7 @@ class DNUpdater {
return
}
self.updateSite(site: site, onUpdate: onUpdate)
await self.updateSite(site: site, onUpdate: onUpdate)
}
}
@ -28,6 +28,20 @@ class DNUpdater {
})
}
// Site updates provides an async/await alternative to `.updateAllLoop` that doesn't require a sendable closure.
// https://developer.apple.com/documentation/swift/asyncstream
var siteUpdates: AsyncStream<Site> {
AsyncStream { continuation in
timer.eventHandler = {
self.updateAll(onUpdate: { site in
continuation.yield(site)
})
}
timer.resume()
}
}
func updateAllLoop(onUpdate: @Sendable @escaping (Site) -> Void) {
timer.eventHandler = {
self.updateAll(onUpdate: onUpdate)
@ -95,20 +109,6 @@ class DNUpdater {
}
}
extension DNUpdater {
// Site updates provides an async/await alternative to `.updateAllLoop` that doesn't require a sendable closure.
// https://developer.apple.com/documentation/swift/asyncstream
var siteUpdates: AsyncStream<Site> {
AsyncStream { continuation in
self.updateAllLoop(onUpdate: { site in
continuation.yield(site)
})
}
}
}
// From https://medium.com/over-engineering/a-background-repeating-timer-in-swift-412cecfd2ef9
class RepeatingTimer {