You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
wg-netns/wg-netns.py

214 lines
7.3 KiB
Python

#!/usr/bin/env python3
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from pathlib import Path
import json
import os
import subprocess
import sys
# Defaults; main() overrides each one from the environment variable of the same name.
NETNS_CONFIG_DIR = '/etc/netns'  # parent directory of the per-namespace resolv.conf files
DEBUG_LEVEL = 0  # non-zero: echo executed commands and re-raise errors with a traceback
SHELL = '/bin/sh'  # program used to execute the profile's shell hooks
def main(args):
    """Parse the command line *args* and run the requested action.

    The module-level settings NETNS_CONFIG_DIR, DEBUG_LEVEL and SHELL are
    overridden from environment variables of the same name.

    Raises:
        RuntimeError: if an environment variable cannot be parsed.
        Exceptions raised by the selected action propagate unchanged.
    """
    global NETNS_CONFIG_DIR
    global DEBUG_LEVEL
    global SHELL

    entrypoint = ArgumentParser(
        formatter_class=RawDescriptionHelpFormatter,
        epilog=(
            'environment variables:\n'
            f' NETNS_CONFIG_DIR network namespace config directory, default: {NETNS_CONFIG_DIR}\n'
            f' DEBUG_LEVEL print stack traces, default: {DEBUG_LEVEL}\n'
            f' SHELL program for execution of shell hooks, default: {SHELL}\n'
        ),
    )
    subparsers = entrypoint.add_subparsers(dest='action', required=True)

    parser = subparsers.add_parser('up', help='setup namespace and associated interfaces')
    parser.add_argument('profile', type=lambda x: Path(x).expanduser(), help='path to profile')

    parser = subparsers.add_parser('down', help='teardown namespace and associated interfaces')
    parser.add_argument('-f', '--force', action='store_true', help='ignore errors')
    parser.add_argument('profile', type=lambda x: Path(x).expanduser(), help='path to profile')

    opts = entrypoint.parse_args(args)

    try:
        NETNS_CONFIG_DIR = Path(os.environ.get('NETNS_CONFIG_DIR', NETNS_CONFIG_DIR))
        DEBUG_LEVEL = int(os.environ.get('DEBUG_LEVEL', DEBUG_LEVEL))
        SHELL = Path(os.environ.get('SHELL', SHELL))
    except Exception as e:
        # The class name was previously printed literally because the braces were missing.
        raise RuntimeError(f'failed to load environment variable: {e} ({e.__class__.__name__})') from e

    if opts.action == 'up':
        setup_action(opts.profile)
    elif opts.action == 'down':
        teardown_action(opts.profile, check=not opts.force)
    else:
        raise RuntimeError('congratulations, you reached unreachable code')
def setup_action(path):
    """Bring up the namespace described by the JSON profile at *path*.

    If setup fails or is interrupted, whatever was created so far is torn
    down on a best-effort basis (``check=False``). The original exception is
    then re-raised, so the program does not report success.
    """
    namespace = profile_read(path)
    try:
        namespace_setup(namespace)
    except (Exception, KeyboardInterrupt):
        # Previously a KeyboardInterrupt was swallowed here, so an aborted
        # `up` fell through to sys.exit(0) as if it had succeeded.
        namespace_teardown(namespace, check=False)
        raise
def teardown_action(path, check=True):
    """Tear down the namespace described by the JSON profile at *path*.

    *check* is handed through to namespace_teardown(); main() passes
    ``check=False`` for ``down --force``.
    """
    profile = profile_read(path)
    namespace_teardown(profile, check=check)
def profile_read(path):
    """Load and return the JSON namespace profile stored at *path*.

    The file is decoded as UTF-8 (the JSON interchange encoding) rather
    than with the locale-dependent default of open().
    """
    with open(path, encoding='utf-8') as file:
        return json.load(file)
def namespace_setup(namespace):
    """Create the namespace with its resolv.conf and interfaces, running hooks.

    The ``pre-up`` hook runs in the host namespace because the target
    namespace does not exist yet. The ``post-up`` hook runs inside it.
    """
    if namespace.get('pre-up'):
        # `ip netns exec <name>` cannot work here: the namespace is only
        # created by namespace_create() below, so run the hook on the host.
        run(SHELL, '-c', namespace['pre-up'])
    namespace_create(namespace)
    namespace_resolvconf_write(namespace)
    for interface in namespace['interfaces']:
        interface_setup(interface, namespace)
    if namespace.get('post-up'):
        ip_netns_shell(namespace['post-up'], netns=namespace)
def namespace_create(namespace):
    """Add the named network namespace and bring its loopback device up."""
    netns_name = namespace['name']
    ip('netns', 'add', netns_name)
    ip('-n', netns_name, 'link', 'set', 'dev', 'lo', 'up')
def namespace_resolvconf_write(namespace):
    """Write ``NETNS_CONFIG_DIR/<name>/resolv.conf`` from the profile's DNS servers.

    The ``dns-server`` key is optional. When it is missing or empty, nothing
    is written. Relies on NETNS_CONFIG_DIR having been turned into a Path
    by main().
    """
    servers = namespace.get('dns-server') or ()
    if not servers:
        return
    directory = NETNS_CONFIG_DIR / namespace['name']
    directory.mkdir(parents=True, exist_ok=True)
    # One directive per line, newline-terminated like a regular text file.
    content = ''.join(f'nameserver {server}\n' for server in servers)
    (directory / 'resolv.conf').write_text(content)
def namespace_teardown(namespace, check=True):
    """Tear down *namespace* and everything namespace_setup() created.

    Steps, in order:
      1. Run ``pre-down`` inside the namespace.
      2. Delete the interfaces, then the namespace.
      3. Remove its resolv.conf.
      4. Run ``post-down`` on the host, since the namespace is gone by then.

    *check* is forwarded to every command. With ``check=False`` (used by
    ``down --force`` and by the cleanup after a failed ``up``), a failing
    step no longer aborts the remaining ones.
    """
    if namespace.get('pre-down'):
        ip_netns_shell(namespace['pre-down'], netns=namespace, check=check)
    for interface in namespace['interfaces']:
        interface_teardown(interface, namespace, check=check)
    namespace_delete(namespace, check=check)
    namespace_resolvconf_delete(namespace)
    if namespace.get('post-down'):
        # The namespace was deleted above, so `ip netns exec` would always fail.
        run(SHELL, '-c', namespace['post-down'], check=check)
def namespace_delete(namespace, check=True):
    """Delete the named network namespace; *check* is forwarded to run()."""
    netns_name = namespace['name']
    ip('netns', 'delete', netns_name, check=check)
def namespace_resolvconf_delete(namespace):
    """Remove ``NETNS_CONFIG_DIR/<name>/resolv.conf`` and, if empty, its directory.

    Missing files are ignored, so this is safe to call when no resolv.conf
    was ever written. The per-namespace directory is kept if it still holds
    other files.
    """
    directory = NETNS_CONFIG_DIR / namespace['name']
    path = directory / 'resolv.conf'
    if path.exists():
        path.unlink()
    try:
        # Previously this called NETNS_CONFIG_DIR.rmdir(), which left the
        # per-namespace directory (created with mkdir(parents=True)) behind.
        directory.rmdir()
    except OSError:
        # Not empty, already gone, or not permitted: best-effort cleanup only.
        pass
def interface_setup(interface, namespace):
    """Create and configure one WireGuard interface inside *namespace*.

    Order: create, then WireGuard keys and options, then peers, addresses,
    link up and routes.
    """
    interface_create(interface, namespace)
    interface_configure_wireguard(interface, namespace)
    for peer in interface['peers']:
        peer_setup(peer, interface, namespace)
    # The remaining steps all take the same arguments; run them in sequence.
    for step in (interface_assign_addresses, interface_bring_up, interface_create_routes):
        step(interface, namespace)
def interface_create(interface, namespace):
    """Create a WireGuard link, then move it into *namespace*.

    The link is created in the namespace this process runs in. Presumably
    this is so its UDP socket stays there: WireGuard keeps the socket in the
    creating namespace. See WireGuard's netns documentation.
    """
    link = interface['name']
    ip('link', 'add', link, 'type', 'wireguard')
    ip('link', 'set', link, 'netns', namespace['name'])
def interface_configure_wireguard(interface, namespace):
    """Apply listen port, fwmark and private key to the interface with ``wg set``.

    ``listen-port`` and ``fwmark`` default to 0 when absent from the profile.
    """
    link = interface['name']
    for option, value in (
        ('listen-port', interface.get('listen-port', 0)),
        ('fwmark', interface.get('fwmark', 0)),
    ):
        wg('set', link, option, value, netns=namespace)
    # `wg set ... private-key` takes a file path; stream the key through /dev/stdin.
    wg('set', link, 'private-key', '/dev/stdin', stdin=interface['private-key'], netns=namespace)
def interface_assign_addresses(interface, namespace):
    """Add every configured address to the interface.

    The address family is guessed from the text: anything containing ':'
    is treated as IPv6.
    """
    for addr in interface['address']:
        family = '-6' if ':' in addr else '-4'
        ip('-n', namespace['name'], family, 'address', 'add', addr, 'dev', interface['name'])
def interface_bring_up(interface, namespace):
    """Set the interface MTU (1420 unless the profile says otherwise) and bring it up."""
    mtu = interface.get('mtu', 1420)
    ip('-n', namespace['name'], 'link', 'set', 'dev', interface['name'], 'mtu', mtu, 'up')
def interface_create_routes(interface, namespace):
    """Route each peer's ``allowed-ips`` networks through the interface.

    Peers without ``allowed-ips`` contribute no routes. A ':' in the network
    selects IPv6.
    """
    networks = (net for peer in interface['peers'] for net in peer.get('allowed-ips', ()))
    for net in networks:
        family = '-6' if ':' in net else '-4'
        ip('-n', namespace['name'], family, 'route', 'add', net, 'dev', interface['name'])
def interface_teardown(interface, namespace, check=True):
    """Take the interface down, then delete it; *check* is forwarded to both commands."""
    netns_name, link = namespace['name'], interface['name']
    ip('-n', netns_name, 'link', 'set', link, 'down', check=check)
    ip('-n', netns_name, 'link', 'delete', link, check=check)
def peer_setup(peer, interface, namespace):
    """Configure one peer on the interface with a single ``wg set ... peer`` call.

    ``endpoint`` and ``allowed-ips`` are only passed when present.
    ``persistent-keepalive`` defaults to 0.
    """
    psk = peer.get('preshared-key')
    # The preshared key is read from a file: stdin when one is configured,
    # otherwise /dev/null (which wg(8) treats as "no preshared key").
    options = [
        'peer', peer['public-key'],
        'preshared-key', '/dev/stdin' if psk else '/dev/null',
        'persistent-keepalive', peer.get('persistent-keepalive', 0),
    ]
    endpoint = peer.get('endpoint')
    if endpoint:
        options += ['endpoint', endpoint]
    allowed = peer.get('allowed-ips')
    if allowed:
        options += ['allowed-ips', ','.join(allowed)]
    wg('set', interface['name'], *options, stdin=psk, netns=namespace)
def wg(*arguments, **options):
    """Run ``wg`` with *arguments* inside the namespace passed as ``netns=``."""
    command = ('wg',) + arguments
    return ip_netns_exec(*command, **options)
def ip_netns_shell(*commands, **options):
    """Evaluate shell code with ``$SHELL -c`` inside the namespace passed as ``netns=``."""
    shell_prefix = (SHELL, '-c')
    return ip_netns_exec(*shell_prefix, *commands, **options)
def ip_netns_exec(*command, netns=None, **options):
    """Run *command* via ``ip netns exec`` in the namespace named ``netns['name']``.

    *netns* must be a profile mapping with a ``name`` key; remaining keyword
    options are forwarded to run().
    """
    netns_name = netns['name']
    return ip('netns', 'exec', netns_name, *command, **options)
def ip(*arguments, **options):
    """Invoke the iproute2 ``ip`` tool; keyword options are forwarded to run()."""
    command = ('ip',) + arguments
    return run(*command, **options)
def run(*args, stdin=None, check=True, capture=False):
    """Execute *args* as a command (no shell) and return its stdout.

    ``None`` arguments become empty strings; everything else goes through
    ``str()``. *stdin* is fed to the process as text.

    When *check* is true, a non-zero exit status raises RuntimeError. The
    message includes stderr if *capture* is set, otherwise the exit code.
    Returns ``None`` unless *capture* is true.
    """
    argv = ['' if item is None else str(item) for item in args]
    if DEBUG_LEVEL:
        print('>', ' '.join(argv), file=sys.stderr)
    result = subprocess.run(argv, input=stdin, text=True, capture_output=capture)
    failed = result.returncode != 0
    if check and failed:
        if result.stderr:
            reason = result.stderr.strip()
        else:
            reason = f'exit code {result.returncode}'
        raise RuntimeError(f'subprocess failed: {" ".join(argv)}: {reason}')
    return result.stdout
if __name__ == '__main__':
    # Exit codes: 0 on success, 2 on any handled error.
    # With DEBUG_LEVEL set, the error is re-raised so the traceback is shown.
    try:
        main(sys.argv[1:])
    except Exception as exc:
        if DEBUG_LEVEL:
            raise
        print(f'error: {exc} ({exc.__class__.__name__})', file=sys.stderr)
        sys.exit(2)
    else:
        sys.exit(0)