diff --git a/wgnetns/main.py b/wgnetns/main.py index 16f678e..41ebd36 100755 --- a/wgnetns/main.py +++ b/wgnetns/main.py @@ -117,7 +117,7 @@ class Peer: data = {key.replace('-', '_'): value for key, value in data.items()} return cls(**data) - def setup(self, interface: Interface, namespace: Namespace) -> Peer: + def setup(self, interface: Interface, namespace: str|None) -> Peer: options = [ 'peer', self.public_key, 'preshared-key', '/dev/stdin' if self.preshared_key else '/dev/null', @@ -127,7 +127,7 @@ class Peer: options.extend(('endpoint', self.endpoint)) if self.allowed_ips: options.extend(('allowed-ips', ','.join(self.allowed_ips))) - wg('set', interface.name, *options, stdin=self.preshared_key, netns=namespace.name) + wg('set', interface.name, *options, stdin=self.preshared_key, netns=namespace) return self @@ -150,37 +150,40 @@ class Interface: return cls(**data, peers=peers, base_netns=base_netns) def setup(self, namespace: Namespace) -> Interface: - self._create(namespace) - self._configure_wireguard(namespace) + self._create() + self._configure_wireguard() for peer in self.peers: - peer.setup(self, namespace) - self._assign_addresses(namespace) - self._bring_up(namespace) - self._create_routes(namespace) + peer.setup(self, self.base_netns) + self._assign_namespace(namespace.name) + self._assign_addresses(namespace.name) + self._bring_up(namespace.name) + self._create_routes(namespace.name) return self - def _create(self, namespace: Namespace) -> None: + def _create(self) -> None: ip('link', 'add', self.name, 'type', 'wireguard', netns=self.base_netns) - ip('link', 'set', self.name, 'netns', namespace.name, netns=self.base_netns) - def _configure_wireguard(self, namespace: Namespace) -> None: - wg('set', self.name, 'listen-port', self.listen_port, netns=namespace.name) - wg('set', self.name, 'fwmark', self.fwmark, netns=namespace.name) + def _configure_wireguard(self) -> None: + wg('set', self.name, 'listen-port', self.listen_port, netns=self.base_netns) + wg('set', self.name, 'fwmark', self.fwmark, netns=self.base_netns) if self.private_key: - wg('set', self.name, 'private-key', '/dev/stdin', stdin=self.private_key, netns=namespace.name) + wg('set', self.name, 'private-key', '/dev/stdin', stdin=self.private_key, netns=self.base_netns) + + def _assign_namespace(self, namespace: str) -> None: + ip('link', 'set', self.name, 'netns', namespace, netns=self.base_netns) - def _assign_addresses(self, namespace: Namespace) -> None: + def _assign_addresses(self, namespace: str) -> None: for address in self.address: - ip('-6' if ':' in address else '-4', 'address', 'add', address, 'dev', self.name, netns=namespace.name) + ip('-6' if ':' in address else '-4', 'address', 'add', address, 'dev', self.name, netns=namespace) - def _bring_up(self, namespace: Namespace) -> None: - ip('link', 'set', 'dev', self.name, 'mtu', self.mtu, 'up', netns=namespace.name) + def _bring_up(self, namespace: str) -> None: + ip('link', 'set', 'dev', self.name, 'mtu', self.mtu, 'up', netns=namespace) - def _create_routes(self, namespace: Namespace): + def _create_routes(self, namespace: str): for peer in self.peers: networks = peer.routes if peer.routes is not None else peer.allowed_ips for network in networks: - ip('-6' if ':' in network else '-4', 'route', 'add', network, 'dev', self.name, netns=namespace.name) + ip('-6' if ':' in network else '-4', 'route', 'add', network, 'dev', self.name, netns=namespace) def teardown(self, namespace: Namespace, check=True) -> Interface: if self.exists(namespace): @@ -350,14 +353,17 @@ class Namespace: def wg(*args, netns: str|None = None, stdin: str|None = None, check=True, capture=False) -> str: - return ip_netns_exec('wg', *args, netns=netns, stdin=stdin, check=check, capture=capture) + if netns: + return ip_netns_exec('wg', *args, netns=netns, stdin=stdin, check=check, capture=capture) + else: + return run('wg', *args, stdin=stdin, check=check, capture=capture) -def ip_netns_eval(*args, netns: str|None = None, stdin: str|None = None, check=True, capture=False) -> str: +def ip_netns_eval(*args, netns: str, stdin: str|None = None, check=True, capture=False) -> str: return ip_netns_exec(SHELL, '-c', *args, netns=netns, stdin=stdin, check=check, capture=capture) -def ip_netns_exec(*args, netns: str|None = None, stdin: str|None = None, check=True, capture=False) -> str: +def ip_netns_exec(*args, netns: str, stdin: str|None = None, check=True, capture=False) -> str: return ip('netns', 'exec', netns, *args, stdin=stdin, check=check, capture=capture)