@ -4,7 +4,6 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter
from pathlib import Path
from typing import Any , Optional
import dataclasses
import getpass
import json
import os
import subprocess
@ -15,7 +14,6 @@ try:
YAML_SUPPORTED = True
except ModuleNotFoundError :
YAML_SUPPORTED = False
yaml = NotImplemented
WIREGUARD_DIR = Path ( ' /etc/wireguard ' )
NETNS_DIR = Path ( ' /etc/netns ' )
@ -23,18 +21,7 @@ VERBOSE = 0
SHELL = Path ( ' /bin/sh ' )
def main ( ) :
try :
cli ( sys . argv [ 1 : ] )
sys . exit ( 0 )
except Exception as e :
print ( f ' error: { e } ( { e . __class__ . __name__ } ) ' , file = sys . stderr )
if VERBOSE :
raise
sys . exit ( 1 )
def cli ( args ) :
def main ( args ) :
global WIREGUARD_DIR
global NETNS_DIR
global VERBOSE
@ -60,15 +47,6 @@ def cli(args):
parser . add_argument ( ' -f ' , ' --force ' , action = ' store_true ' , help = ' ignore errors ' )
parser . add_argument ( ' profile ' , type = lambda x : Path ( x ) . expanduser ( ) , metavar = ' PROFILE ' , help = ' name or path of profile ' )
parser = subparsers . add_parser ( ' list ' , help = ' list network namespaces ' )
parser = subparsers . add_parser ( ' switch ' , help = ' open shell in namespace ' )
parser . add_argument ( ' netns ' , metavar = ' NETNS ' , help = ' network namespace name ' )
parser = subparsers . add_parser ( ' exec ' , help = ' run command in namespace ' )
parser . add_argument ( ' netns ' , metavar = ' NETNS ' , help = ' network namespace name ' )
parser . add_argument ( ' command ' , nargs = ' + ' , help = ' command ' )
opts = entrypoint . parse_args ( args )
try :
@ -79,9 +57,8 @@ def cli(args):
except Exception as e :
raise RuntimeError ( f ' failed to load environment variable: { e } (e.__class__.__name__) ' ) from e
namespace = Namespace . from_profile ( opts . profile )
if opts . action == ' up ' :
_conditional_elevate ( )
namespace = Namespace . from_profile ( opts . profile )
try :
namespace . setup ( )
except KeyboardInterrupt :
@ -90,33 +67,16 @@ def cli(args):
namespace . teardown ( check = False )
raise
elif opts . action == ' down ' :
_conditional_elevate ( )
namespace = Namespace . from_profile ( opts . profile )
namespace . teardown ( check = not opts . force )
elif opts . action == ' list ' :
output = ip ( ' -json ' , ' netns ' , capture = True )
if not output :
return
data = json . loads ( output )
print ( ' \n ' . join ( item [ ' name ' ] for item in data ) )
elif opts . action == ' switch ' :
os . execvp ( ' sudo ' , [ ' ip ' , ' ip ' , ' netns ' , ' exec ' , opts . netns , ' sudo ' , ' -u ' , getpass . getuser ( ) , ' -D ' , Path . cwd ( ) . as_posix ( ) , os . environ [ ' SHELL ' ] , ' -i ' ] )
elif opts . action == ' exec ' :
os . execvp ( ' sudo ' , [ ' ip ' , ' ip ' , ' netns ' , ' exec ' , opts . netns , ' sudo ' , ' -u ' , getpass . getuser ( ) , ' -D ' , Path . cwd ( ) . as_posix ( ) , * opts . command ] )
else :
raise RuntimeError ( ' congratulations, you reached unreachable code ' )
def _conditional_elevate ( ) - > None :
if os . getuid ( ) != 0 and os . isatty ( sys . stdin . fileno ( ) ) :
os . execvp ( ' sudo ' , [ sys . argv [ 0 ] , * sys . argv ] )
@dataclasses.dataclass
class Peer :
name : str
public_key : str
preshared_key : Optional [ str ] = None
name : Optional [ str ] = None
endpoint : Optional [ str ] = None
persistent_keepalive : int = 0
allowed_ips : list [ str ] = dataclasses . field ( default_factory = list )
@ -127,7 +87,7 @@ class Peer:
data = { key . replace ( ' - ' , ' _ ' ) : value for key , value in data . items ( ) }
return cls ( * * data )
def setup ( self , interface : Interface , namespace : str | Non e) - > Peer :
def setup ( self , interface : Interface , namespace : Namespac e) - > Peer :
options = [
' peer ' , self . public_key ,
' preshared-key ' , ' /dev/stdin ' if self . preshared_key else ' /dev/null ' ,
@ -137,16 +97,15 @@ class Peer:
options . extend ( ( ' endpoint ' , self . endpoint ) )
if self . allowed_ips :
options . extend ( ( ' allowed-ips ' , ' , ' . join ( self . allowed_ips ) ) )
wg ( ' set ' , interface . name , * options , stdin = self . preshared_key , netns = namespace )
wg ( ' set ' , interface . name , * options , stdin = self . preshared_key , netns = namespace . name )
return self
@dataclasses.dataclass
class Interface :
name : str
base_netns : str | None = None
private_key : Optional [ str ] = None
public_key : Optional [ str ] = None
public_key : str
private_key : str
address : list [ str ] = dataclasses . field ( default_factory = list )
listen_port : int = 0
fwmark : int = 0
@ -154,118 +113,64 @@ class Interface:
peers : list [ Peer ] = dataclasses . field ( default_factory = list )
@classmethod
def from_dict ( cls , data : dict [ str , Any ] , base_netns : str | None = None ) - > Interface :
def from_dict ( cls , data : dict [ str , Any ] ) - > Interface :
peers = data . pop ( ' peers ' , list ( ) )
peers = [ Peer . from_dict ( { key . replace ( ' - ' , ' _ ' ) : value for key , value in peer . items ( ) } ) for peer in peers ]
return cls ( * * data , peers = peers , base_netns = base_netns )
return cls ( * * data , peers = peers )
def setup ( self , namespace : Namespace ) - > Interface :
self . _create ( )
self . _configure_wireguard ( )
self . _create ( namespace )
self . _configure_wireguard ( namespace )
for peer in self . peers :
peer . setup ( self , self . base_netns )
self . _assign_namespace ( namespace . name )
self . _assign_addresses ( namespace . name )
self . _bring_up ( namespace . name )
self . _create_routes ( namespace . name )
peer . setup ( self , namespace )
self . _assign_addresses ( namespace )
self . _bring_up ( namespace )
self . _create_routes ( namespace )
return self
def _create ( self ) - > None :
ip ( ' link ' , ' add ' , self . name , ' type ' , ' wireguard ' , netns = self . base_netns )
def _create ( self , namespace : Namespace ) - > None :
ip ( ' link ' , ' add ' , self . name , ' type ' , ' wireguard ' )
ip ( ' link ' , ' set ' , self . name , ' netns ' , namespace . name )
def _configure_wireguard ( self ) - > None :
wg ( ' set ' , self . name , ' listen-port ' , self . listen_port , netns = self . base_netns )
wg ( ' set ' , self . name , ' fwmark ' , self . fwmark , netns = self . base_netns )
if self . private_key :
wg ( ' set ' , self . name , ' private-key ' , ' /dev/stdin ' , stdin = self . private_key , netns = self . base_netns )
def _configure_wireguard ( self , namespace : Namespace ) - > None :
wg ( ' set ' , self . name , ' listen-port ' , self . listen_port , netns = namespace . name )
wg ( ' set ' , self . name , ' fwmark ' , self . fwmark , netns = namespace . name )
wg ( ' set ' , self . name , ' private-key ' , ' /dev/stdin ' , stdin = self . private_key , netns = namespace . name )
def _assign_namespace ( self , namespace : str | None ) - > None :
ip ( ' link ' , ' set ' , self . name , ' netns ' , namespace if namespace else ' 1 ' , netns = self . base_netns )
def _assign_addresses ( self , namespace : str | None ) - > None :
def _assign_addresses ( self , namespace : Namespace ) - > None :
for address in self . address :
ip ( ' - 6' if ' : ' in address else ' -4 ' , ' address ' , ' add ' , address , ' dev ' , self . nam e, netns = namespac e)
ip ( ' -n ' , namespace . name , ' -6 ' if ' : ' in address else ' -4 ' , ' address ' , ' add ' , address , ' dev ' , self . name )
def _bring_up ( self , namespace : str | Non e) - > None :
ip ( ' link' , ' set ' , ' dev ' , self . name , ' mtu ' , self . mtu , ' up ' , netns = namespace )
def _bring_up ( self , namespace : Namespac e) - > None :
ip ( ' -n' , namespace . name , ' link' , ' set ' , ' dev ' , self . name , ' mtu ' , self . mtu , ' up ' )
def _create_routes ( self , namespace : str | Non e) :
def _create_routes ( self , namespace : Namespac e) :
for peer in self . peers :
networks = peer . routes if peer . routes is not None else peer . allowed_ips
for network in networks :
ip ( ' - 6' if ' : ' in network else ' -4 ' , ' route ' , ' add ' , network , ' dev ' , self . nam e, netns = namespac e)
ip ( ' - n' , namespace . name , ' - 6' if ' : ' in network else ' -4 ' , ' route ' , ' add ' , network , ' dev ' , self . nam e)
def teardown ( self , namespace : Namespace , check = True ) - > Interface :
if self . exists ( namespace ) :
ip ( ' link' , ' set ' , self . name , ' down ' , check = check , netns = namespace . name )
ip ( ' link' , ' delete ' , self . name , check = check , netns = namespace . name )
ip ( ' -n' , namespace . name , ' link' , ' set ' , self . name , ' down ' , check = check )
ip ( ' -n' , namespace . name , ' link' , ' delete ' , self . name , check = check )
return self
def exists ( self , namespace : Namespace ) - > bool :
try :
ip ( ' link' , ' show ' , self . name , capture = Tru e, netns = namespace . nam e)
ip ( ' -n' , namespace . name , ' link' , ' show ' , self . name , capture = Tru e)
return True
except Exception :
return False
@dataclasses.dataclass
class ScriptletItem :
command : str
host_namespace : bool = False
@classmethod
def from_str ( cls , data : str ) - > ScriptletItem :
return cls ( command = data )
@classmethod
def from_dict ( cls , data : dict [ str , Any ] ) - > ScriptletItem :
data = { key . replace ( ' - ' , ' _ ' ) : value for key , value in data . items ( ) }
host_namespace = bool ( data . pop ( ' host_namespace ' , None ) )
return cls ( * * data , host_namespace = host_namespace )
def run ( self , netns : str | None ) :
if self . host_namespace or not netns :
host_eval ( self . command )
else :
ip_netns_eval ( self . command , netns = netns )
@dataclasses.dataclass
class Scriptlet :
items : list [ ScriptletItem ] = dataclasses . field ( default_factory = list )
@classmethod
def from_value ( cls , data ) - > Scriptlet :
if isinstance ( data , list ) :
return cls . from_list ( data )
elif isinstance ( data , str ) :
return cls . from_singleton ( data )
else :
raise RuntimeError ( f ' unsupported scriptlet type: { data . __class__ . __name__ } ' )
@classmethod
def from_list ( cls , data : list [ Any ] ) - > Scriptlet :
items = [ ScriptletItem . from_dict ( item ) for item in data ]
return cls ( items = items )
@classmethod
def from_singleton ( cls , data ) - > Scriptlet :
item = ScriptletItem . from_str ( data )
return cls ( items = [ item ] )
def run ( self , netns : str | None ) :
for item in self . items :
item . run ( netns = netns )
@dataclasses.dataclass
class Namespace :
name : str | None
pre_up : Optional [ Scriptlet ] = None
post_up : Optional [ Scriptlet ] = None
pre_down : Optional [ Scriptlet ] = None
post_down : Optional [ Scriptlet ] = None
name : str
pre_up : Optional [ str ] = None
post_up : Optional [ str ] = None
pre_down : Optional [ str ] = None
post_down : Optional [ str ] = None
managed : bool = True
dns_server : list [ str ] = dataclasses . field ( default_factory = list )
interfaces : list [ Interface ] = dataclasses . field ( default_factory = list )
@ -275,7 +180,7 @@ class Namespace:
try :
return cls . from_dict ( cls . _read_profile ( cls . _find_profile ( path ) ) )
except Exception as e :
raise RuntimeError ( f ' failed to load profile : { e } ' ) from e
raise RuntimeError ( ' failed to load profile ' ) from e
@staticmethod
def _find_profile ( profile : Path ) - > Path :
@ -301,35 +206,32 @@ class Namespace:
@classmethod
def from_dict ( cls , data : dict [ str , Any ] ) - > Namespace :
data = { key . replace ( ' - ' , ' _ ' ) : value for key , value in data . items ( ) }
scriptlets = { key : data . pop ( key , None ) for key in [ ' pre_up ' , ' post_up ' , ' pre_down ' , ' post_down ' ] }
scriptlets = { key : Scriptlet . from_value ( value ) for key , value in scriptlets . items ( ) if value is not None }
interfaces = data . pop ( ' interfaces ' , list ( ) )
base_netns = data . pop ( ' base_netns ' , None )
interfaces = [ Interface . from_dict ( { key . replace ( ' - ' , ' _ ' ) : value for key , value in interface . items ( ) } , base_netns = base_netns ) for interface in interfaces ]
return cls ( * * data , * * scriptlets , interfaces = interfaces ) # type: ignore
interfaces = [ Interface . from_dict ( { key . replace ( ' - ' , ' _ ' ) : value for key , value in interface . items ( ) } ) for interface in interfaces ]
return cls ( * * data , interfaces = interfaces )
def setup ( self ) - > Namespace :
if self . managed and self . name :
if self . pre_up :
ip_netns_eval ( self . pre_up , netns = self . name )
if self . managed :
self . _create ( )
self . _write_resolvconf ( )
if self . pre_up :
self . pre_up . run ( netns = self . name )
for interface in self . interfaces :
interface . setup ( self )
if self . post_up :
self . post_up . run ( netns = self . name )
ip_netns_eval ( self . post_up , netns = self . name )
return self
def teardown ( self , check = True ) - > Namespace :
if self . pre_down :
self . pre_down . run ( netns = self . name )
ip_netns_eval ( self . pre_down , netns = self . name )
for interface in self . interfaces :
interface . teardown ( self , check = check )
if self . post_down :
self . post_down . run ( netns = self . name )
if self . managed and self . exists ( ) :
self . _delete ( check )
self . _delete_resolvconf ( )
if self . post_down :
ip_netns_eval ( self . post_down , netns = self . name )
return self
def exists ( self ) - > bool :
@ -338,14 +240,13 @@ class Namespace:
def _create ( self ) - > None :
ip ( ' netns ' , ' add ' , self . name )
ip ( ' link' , ' set ' , ' dev ' , ' lo ' , ' up ' , netns = self . name )
ip ( ' -n' , self . name , ' link' , ' set ' , ' dev ' , ' lo ' , ' up ' )
def _delete ( self , check = True ) - > None :
ip ( ' netns ' , ' delete ' , self . name , check = check )
@property
def _resolvconf_path ( self ) - > Path :
assert self . name
return NETNS_DIR / self . name / ' resolv.conf '
def _write_resolvconf ( self ) - > None :
@ -363,30 +264,23 @@ class Namespace:
pass
def wg ( * args , netns : str | None = None , stdin : str | None = None , check = True , capture = False ) - > str :
if netns :
return ip_netns_exec ( ' wg ' , * args , netns = netns , stdin = stdin , check = check , capture = capture )
else :
return run ( ' wg ' , * args , stdin = stdin , check = check , capture = capture )
def wg ( * args , netns : str = None , stdin : str = None , check = True , capture = False ) - > str :
return ip_netns_exec ( ' wg ' , * args , netns = netns , stdin = stdin , check = check , capture = capture )
def ip_netns_eval ( * args , netns : str , stdin : str | None = None , check = True , capture = False ) - > str :
def ip_netns_eval ( * args , netns : str = None , stdin : str = None , check = True , capture = False ) - > str :
return ip_netns_exec ( SHELL , ' -c ' , * args , netns = netns , stdin = stdin , check = check , capture = capture )
def ip_netns_exec ( * args , netns : str , stdin : str | None = None , check = True , capture = False ) - > str :
def ip_netns_exec ( * args , netns : str = None , stdin : str = None , check = True , capture = False ) - > str :
return ip ( ' netns ' , ' exec ' , netns , * args , stdin = stdin , check = check , capture = capture )
def ip ( * args , stdin : str | None = None , netns : str | None = None , check = True , capture = False ) - > str :
return run ( ' ip ' , * ( [ ' -n ' , netns ] if netns else [ ] ) , * args , stdin = stdin , check = check , capture = capture )
def ip ( * args , stdin : str = None , check = True , capture = False ) - > str :
return run ( ' ip ' , * args , stdin = stdin , check = check , capture = capture )
def host_eval ( * args , stdin : str | None = None , check = True , capture = False ) - > str :
return run ( SHELL , ' -c ' , * args , stdin = stdin , check = check , capture = capture )
def run ( * args , stdin : str | None = None , check = True , capture = False ) - > str :
def run ( * args , stdin : str = None , check = True , capture = False ) - > str :
args = [ str ( item ) if item is not None else ' ' for item in args ]
if VERBOSE :
print ( ' > ' , ' ' . join ( args ) , file = sys . stderr )
@ -398,4 +292,11 @@ def run(*args, stdin: str|None = None, check=True, capture=False) -> str:
if __name__ == ' __main__ ' :
main ( )
try :
main ( sys . argv [ 1 : ] )
sys . exit ( 0 )
except Exception as e :
print ( f ' error: { e } ( { e . __class__ . __name__ } ) ' , file = sys . stderr )
if VERBOSE :
raise
sys . exit ( 1 )