diff --git a/README.md b/README.md index 7312728..d7bf280 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,8 @@ Full YAML example: ~~~ yaml # name of the network namespace name: ns-example +# namespace where the interface is initialized, defaults to the main/default namespace +base_netns: null # if false, the netns itself won't be created or deleted, just the interfaces inside it managed: true # list of dns servers, if empty dns servers from default netns will be used diff --git a/wgnetns/main.py b/wgnetns/main.py index 0d0fb39..fde7a43 100755 --- a/wgnetns/main.py +++ b/wgnetns/main.py @@ -4,7 +4,6 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter from pathlib import Path from typing import Any, Optional import dataclasses -import itertools import json import os import subprocess @@ -135,6 +134,7 @@ class Peer: @dataclasses.dataclass class Interface: name: str + base_netns: str private_key: str public_key: Optional[str] = None address: list[str] = dataclasses.field(default_factory=list) @@ -144,10 +144,10 @@ class Interface: peers: list[Peer] = dataclasses.field(default_factory=list) @classmethod - def from_dict(cls, data: dict[str, Any]) -> Interface: + def from_dict(cls, data: dict[str, Any], base_netns=None) -> Interface: peers = data.pop('peers', list()) peers = [Peer.from_dict({key.replace('-', '_'): value for key, value in peer.items()}) for peer in peers] - return cls(**data, peers=peers) + return cls(**data, peers=peers, base_netns=base_netns) def setup(self, namespace: Namespace) -> Interface: self._create(namespace) @@ -160,8 +160,8 @@ class Interface: return self def _create(self, namespace: Namespace) -> None: - ip('link', 'add', self.name, 'type', 'wireguard') - ip('link', 'set', self.name, 'netns', namespace.name) + ip('link', 'add', self.name, 'type', 'wireguard', netns=self.base_netns) + ip('link', 'set', self.name, 'netns', namespace.name, netns=self.base_netns) def _configure_wireguard(self, namespace: Namespace) -> None: wg('set', self.name, 'listen-port', self.listen_port, netns=namespace.name) @@ -170,26 +170,26 @@ class Interface: def _assign_addresses(self, namespace: Namespace) -> None: for address in self.address: - ip('-n', namespace.name, '-6' if ':' in address else '-4', 'address', 'add', address, 'dev', self.name) + ip('-6' if ':' in address else '-4', 'address', 'add', address, 'dev', self.name, netns=namespace.name) def _bring_up(self, namespace: Namespace) -> None: - ip('-n', namespace.name, 'link', 'set', 'dev', self.name, 'mtu', self.mtu, 'up') + ip('link', 'set', 'dev', self.name, 'mtu', self.mtu, 'up', netns=namespace.name) def _create_routes(self, namespace: Namespace): for peer in self.peers: networks = peer.routes if peer.routes is not None else peer.allowed_ips for network in networks: - ip('-n', namespace.name, '-6' if ':' in network else '-4', 'route', 'add', network, 'dev', self.name) + ip('-6' if ':' in network else '-4', 'route', 'add', network, 'dev', self.name, netns=namespace.name) def teardown(self, namespace: Namespace, check=True) -> Interface: if self.exists(namespace): - ip('-n', namespace.name, 'link', 'set', self.name, 'down', check=check) - ip('-n', namespace.name, 'link', 'delete', self.name, check=check) + ip('link', 'set', self.name, 'down', check=check, netns=namespace.name) + ip('link', 'delete', self.name, check=check, netns=namespace.name) return self def exists(self, namespace: Namespace) -> bool: try: - ip('-n', namespace.name, 'link', 'show', self.name, capture=True) + ip('link', 'show', self.name, capture=True, netns=namespace.name) return True except Exception: return False @@ -290,7 +290,8 @@ class Namespace: scriptlets = {key: data.pop(key, None) for key in ['pre_up', 'post_up', 'pre_down', 'post_down']} scriptlets = {key: Scriptlet.from_value(value) for key, value in scriptlets.items() if value is not None} interfaces = data.pop('interfaces', list()) - interfaces = [Interface.from_dict({key.replace('-', '_'): value for key, value in interface.items()}) for interface in interfaces] + base_netns = data.pop('base_netns', None) + interfaces = [Interface.from_dict({key.replace('-', '_'): value for key, value in interface.items()}, base_netns=base_netns) for interface in interfaces] return cls(**data, **scriptlets, interfaces=interfaces) def setup(self) -> Namespace: @@ -323,7 +324,7 @@ class Namespace: def _create(self) -> None: ip('netns', 'add', self.name) - ip('-n', self.name, 'link', 'set', 'dev', 'lo', 'up') + ip('link', 'set', 'dev', 'lo', 'up', netns=self.name) def _delete(self, check=True) -> None: ip('netns', 'delete', self.name, check=check) @@ -359,8 +360,8 @@ def ip_netns_exec(*args, netns: str = None, stdin: str = None, check=True, captu return ip('netns', 'exec', netns, *args, stdin=stdin, check=check, capture=capture) -def ip(*args, stdin: str = None, check=True, capture=False) -> str: - return run('ip', *args, stdin=stdin, check=check, capture=capture) +def ip(*args, stdin: str = None, netns=None, check=True, capture=False) -> str: + return run('ip', *([] if netns is None else ['-n', netns]), *args, stdin=stdin, check=check, capture=capture) def host_eval(*args, stdin: str = None, check=True, capture=False) -> str: