From d79cbf74142d4ab9fb00ff583c147a95a134ed92 Mon Sep 17 00:00:00 2001
From: Simon <965089+sarthurdev@users.noreply.github.com>
Date: Sun, 30 May 2021 12:07:18 +0200
Subject: ipsec: T2816: Refactor to remove global variable and tidy up

---
 src/conf_mode/vpn_ipsec.py    | 173 +++++++++++++++++++++---------------------
 src/conf_mode/vpn_rsa-keys.py |   5 +-
 2 files changed, 88 insertions(+), 90 deletions(-)

(limited to 'src/conf_mode')

diff --git a/src/conf_mode/vpn_ipsec.py b/src/conf_mode/vpn_ipsec.py
index e59f20a5d..4632310ca 100755
--- a/src/conf_mode/vpn_ipsec.py
+++ b/src/conf_mode/vpn_ipsec.py
@@ -22,9 +22,10 @@ from sys import exit
 from time import sleep
 
 from vyos.config import Config
-from vyos.configdiff import ConfigDiff
+from vyos.configdiff import get_config_diff
 from vyos.template import render
 from vyos.util import call
+from vyos.util import dict_search
 from vyos.util import get_interface_address
 from vyos.util import process_named_running
 from vyos.util import run
@@ -82,24 +83,7 @@ DHCP_BASE = "/var/lib/dhcp/dhclient"
 LOCAL_KEY_PATHS = ['/config/auth/', '/config/ipsec.d/rsa-keys/']
 X509_PATH = '/config/auth/'
 
-conf = None
-
-def resync_l2tp(conf):
-    if not conf.exists('vpn l2tp remote-access ipsec-settings '):
-        return
-
-    tmp = run('/usr/libexec/vyos/conf_mode/ipsec-settings.py')
-    if tmp > 0:
-        print('ERROR: failed to reapply L2TP IPSec settings!')
-
-def resync_nhrp(conf):
-    if not conf.exists('protocols nhrp tunnel'):
-        return
-
-    run('/opt/vyatta/sbin/vyos-update-nhrp.pl --set_ipsec')
-
 def get_config(config=None):
-    global conf
     if config:
         conf = config
     else:
@@ -109,7 +93,13 @@ def get_config(config=None):
         return None
 
     # retrieve common dictionary keys
-    ipsec = conf.get_config_dict(base, key_mangling=('-', '_'), get_first_key=True, no_tag_node_value_mangle=True)
+    ipsec = conf.get_config_dict(base, key_mangling=('-', '_'),
+                                 get_first_key=True, no_tag_node_value_mangle=True)
+
+    ipsec['l2tp_exists'] = conf.exists('vpn l2tp remote-access ipsec-settings ')
+    ipsec['nhrp_exists'] = conf.exists('protocols nhrp tunnel')
+    ipsec['rsa_keys'] = conf.get_config_dict(['vpn', 'rsa-keys'], key_mangling=('-', '_'),
+                                             get_first_key=True, no_tag_node_value_mangle=True)
 
     default_ike_pfs = None
 
@@ -150,8 +140,33 @@ def get_config(config=None):
                         ciphers.append(f"{enc}-{hash}-{pfs_translate[pfs]}" if pfs else f"{enc}-{hash}")
                 esp_ciphers[group] = ','.join(ciphers) + '!'
 
+    diff = get_config_diff(conf, key_mangling=('-', '_'))
+    diff.set_level(base)
+
+    new_if, old_if = diff.get_value_diff(['ipsec-interfaces', 'interface'])
+    ipsec['interface_change'] = (old_if != new_if)
+
     return ipsec
 
+def get_rsa_local_key(ipsec):
+    return dict_search('local_key.file', ipsec['rsa_keys'])
+
+def verify_rsa_local_key(ipsec):
+    file = get_rsa_local_key(ipsec)
+
+    if not file:
+        return False
+
+    for path in LOCAL_KEY_PATHS:
+        full_path = os.path.join(path, file)
+        if os.path.exists(full_path):
+            return full_path
+
+    return False
+
+def verify_rsa_key(ipsec, key_name):
+    return dict_search(f'rsa_key_name.{key_name}.rsa_key', ipsec['rsa_keys'])
+
 def verify(ipsec):
     if not ipsec:
         return None
@@ -194,30 +209,33 @@ def verify(ipsec):
                 if 'x509' not in peer_conf['authentication']:
                     raise ConfigError(f"Missing x509 settings on site-to-site peer {peer}")
 
-                if 'key' not in peer_conf['authentication']['x509'] or 'ca_cert_file' not in peer_conf['authentication']['x509'] or 'cert_file' not in peer_conf['authentication']['x509']:
+                if 'key' not in peer_conf['authentication']['x509']:
+                    raise ConfigError(f"Missing x509 key on site-to-site peer {peer}")
+
+                if 'ca_cert_file' not in peer_conf['authentication']['x509'] or 'cert_file' not in peer_conf['authentication']['x509']:
                     raise ConfigError(f"Missing x509 settings on site-to-site peer {peer}")
 
                 if 'file' not in peer_conf['authentication']['x509']['key']:
-                    raise ConfigError(f"Missing x509 settings on site-to-site peer {peer}")
+                    raise ConfigError(f"Missing x509 key file on site-to-site peer {peer}")
 
                 for key in ['ca_cert_file', 'cert_file', 'crl_file']:
                     if key in peer_conf['authentication']['x509']:
-                        path = peer_conf['authentication']['x509'][key]
-                        if not os.path.exists(path if path.startswith(X509_PATH) else (X509_PATH + path)):
+                        path = os.path.join(X509_PATH, peer_conf['authentication']['x509'][key])
+                        if not os.path.exists(path):
                             raise ConfigError(f"File not found for {key} on site-to-site peer {peer}")
 
-                key_path = peer_conf['authentication']['x509']['key']['file']
-                if not os.path.exists(key_path if key_path.startswith(X509_PATH) else (X509_PATH + key_path)):
+                key_path = os.path.join(X509_PATH, peer_conf['authentication']['x509']['key']['file'])
+                if not os.path.exists(key_path):
                     raise ConfigError(f"Private key not found on site-to-site peer {peer}")
 
             if peer_conf['authentication']['mode'] == 'rsa':
-                if not verify_rsa_local_key():
+                if not verify_rsa_local_key(ipsec):
                     raise ConfigError(f"Invalid key on rsa-keys local-key")
 
                 if 'rsa_key_name' not in peer_conf['authentication']:
                     raise ConfigError(f"Missing rsa-key-name on site-to-site peer {peer}")
 
-                if not verify_rsa_key(peer_conf['authentication']['rsa_key_name']):
+                if not verify_rsa_key(ipsec, peer_conf['authentication']['rsa_key_name']):
                     raise ConfigError(f"Invalid rsa-key-name on site-to-site peer {peer}")
 
             if 'local_address' not in peer_conf and 'dhcp_interface' not in peer_conf:
@@ -228,6 +246,9 @@ def verify(ipsec):
                 if not os.path.exists(f'{DHCP_BASE}_{dhcp_interface}.conf'):
                     raise ConfigError(f"Invalid dhcp-interface on site-to-site peer {peer}")
 
+                if not get_dhcp_address(dhcp_interface):
+                    raise ConfigError(f"Failed to get address from dhcp-interface on site-to-site peer {peer}")
+
             if 'vti' in peer_conf:
                 if 'local_address' in peer_conf and 'dhcp_interface' in peer_conf:
                     raise ConfigError(f"A single local-address or dhcp-interface is required when using VTI on site-to-site peer {peer}")
@@ -238,7 +259,7 @@ def verify(ipsec):
                         raise ConfigError(f'VTI interface {vti_interface} for site-to-site peer {peer} does not exist!')
 
             if 'vti' not in peer_conf and 'tunnel' not in peer_conf:
-                raise ConfigError(f"No vti or tunnels specified on site-to-site peer {peer}")
+                raise ConfigError(f"No VTI or tunnel specified on site-to-site peer {peer}")
 
             if 'tunnel' in peer_conf:
                 for tunnel, tunnel_conf in peer_conf['tunnel'].items():
@@ -259,33 +280,6 @@ def verify(ipsec):
                         if ('local' in tunnel_conf and 'prefix' in tunnel_conf['local']) or ('remote' in tunnel_conf and 'prefix' in tunnel_conf['remote']):
                             raise ConfigError(f"Local/remote prefix cannot be used with ESP transport mode on tunnel {tunnel} for site-to-site peer {peer}")
 
-def get_rsa_local_key():
-    global conf
-    base = ['vpn', 'rsa-keys']
-    if not conf.exists(base + ['local-key', 'file']):
-        return False
-
-    return conf.return_value(base + ['local-key', 'file'])
-
-def verify_rsa_local_key():
-    file = get_rsa_local_key()
-
-    if not file:
-        return False
-
-    for path in LOCAL_KEY_PATHS:
-        if os.path.exists(path + file):
-            return path + file
-
-    return False
-
-def verify_rsa_key(key_name):
-    global conf
-    base = ['vpn', 'rsa-keys']
-    if not conf.exists(base):
-        return False
-    return conf.exists(base + ['rsa-key-name', key_name, 'rsa-key'])
-
 def generate(ipsec):
     data = {}
 
@@ -294,7 +288,7 @@ def generate(ipsec):
         data['authby'] = authby_translate
         data['ciphers'] = {'ike': ike_ciphers, 'esp': esp_ciphers}
         data['marks'] = {}
-        data['rsa_local_key'] = verify_rsa_local_key()
+        data['rsa_local_key'] = verify_rsa_local_key(ipsec)
         data['x509_path'] = X509_PATH
 
         if 'site_to_site' in data and 'peer' in data['site_to_site']:
@@ -320,10 +314,12 @@ def generate(ipsec):
                     data['marks'][vti_interface] = get_mark(vti_interface)
                 else:
                     for tunnel, tunnel_conf in peer_conf['tunnel'].items():
-                        if ('local' not in tunnel_conf or 'prefix' not in tunnel_conf['local']) or ('remote' not in tunnel_conf or 'prefix' not in tunnel_conf['remote']):
+                        local_prefix = dict_search('local.prefix', tunnel_conf['local']['prefix'])
+                        remote_prefix = dict_search('remote.prefix', tunnel_conf['remote']['prefix'])
+
+                        if not local_prefix or not remote_prefix:
                             continue
-                        local_prefix = tunnel_conf['local']['prefix']
-                        remote_prefix = tunnel_conf['remote']['prefix']
+
                         passthrough = cidr_fit(local_prefix, remote_prefix)
                         data['site_to_site']['peer'][peer]['tunnel'][tunnel]['passthrough'] = passthrough
 
@@ -340,49 +336,48 @@ def generate(ipsec):
     render("/etc/ipsec.secrets", "ipsec/ipsec.secrets.tmpl", data)
     render("/etc/swanctl/swanctl.conf", "ipsec/swanctl.conf.tmpl", data)
 
+def resync_l2tp(ipsec):
+    if not ipsec['l2tp_exists']:
+        return
+
+    tmp = run('/usr/libexec/vyos/conf_mode/ipsec-settings.py')
+    if tmp > 0:
+        print('ERROR: failed to reapply L2TP IPSec settings!')
+
+def resync_nhrp(ipsec):
+    if not ipsec['nhrp_exists']:
+        return
+
+    run('/opt/vyatta/sbin/vyos-update-nhrp.pl --set_ipsec')
+
 def apply(ipsec):
     if not ipsec:
-        if conf.exists('vpn l2tp '):
+        if ipsec['l2tp_exists']:
             call('sudo /usr/sbin/ipsec rereadall')
             call('sudo /usr/sbin/ipsec reload')
             call('sudo /usr/sbin/swanctl -q')
         else:
             call('sudo /usr/sbin/ipsec stop')
+    else:
+        should_start = ('profile' in ipsec or dict_search('site_to_site.peer', ipsec))
 
-        resync_l2tp(conf)
-        resync_nhrp(conf)
-        return
-
-    diff = ConfigDiff(conf, key_mangling=('-', '_'))
-    diff.set_level(['vpn', 'ipsec'])
-
-    old_if, new_if = diff.get_value_diff(['ipsec-interfaces', 'interface'])
-    interface_change = (old_if != new_if)
-
-    should_start = ('profile' in ipsec or ('site_to_site' in ipsec and 'peer' in ipsec['site_to_site']))
-
-    if not process_named_running('charon'):
-        args = ''
-        if 'auto_update' in ipsec:
-            args = f'--auto-update {ipsec["auto_update"]}'
-
-        if should_start:
+        if not process_named_running('charon') and should_start:
+            args = f'--auto-update {ipsec["auto_update"]}' if 'auto_update' in ipsec else ''
             call(f'sudo /usr/sbin/ipsec start {args}')
-    else:
-        if not should_start:
-            call('sudo /usr/sbin/ipsec stop')
-        elif interface_change:
+        elif not should_start:
+            ipsec_stop()
+        elif ipsec['interface_change']:
             call('sudo /usr/sbin/ipsec restart')
         else:
             call('sudo /usr/sbin/ipsec rereadall')
             call('sudo /usr/sbin/ipsec reload')
 
-    if should_start:
-        sleep(2) # Give charon enough time to start
-        call('sudo /usr/sbin/swanctl -q')
+        if should_start:
+            sleep(2) # Give charon enough time to start
+            call('sudo /usr/sbin/swanctl -q')
 
-    resync_l2tp(conf)
-    resync_nhrp(conf)
+    resync_l2tp(ipsec)
+    resync_nhrp(ipsec)
 
 def get_mark(vti_interface):
     vti_num = int(vti_interface.lstrip('vti'))
@@ -390,10 +385,12 @@ def get_mark(vti_interface):
 
 def get_dhcp_address(interface):
     addr = get_interface_address(interface)
-    if not addr:
+    if not addr or 'addr_info' not in addr:
         return None
     if len(addr['addr_info']) == 0:
         return None
+    if 'local' not in addr['addr_info'][0]:
+        return None
     return addr['addr_info'][0]['local']
 
 if __name__ == '__main__':
diff --git a/src/conf_mode/vpn_rsa-keys.py b/src/conf_mode/vpn_rsa-keys.py
index a0e2e2690..6cf7eba6e 100755
--- a/src/conf_mode/vpn_rsa-keys.py
+++ b/src/conf_mode/vpn_rsa-keys.py
@@ -56,8 +56,9 @@ def verify(conf):
 
 def get_local_key(local_key):
     for path in LOCAL_KEY_PATHS:
-        if os.path.exists(path + local_key):
-            return path + local_key
+        full_path = os.path.join(path, local_key)
+        if os.path.exists(full_path):
+            return full_path
     return False
 
 def generate(conf):
-- 
cgit v1.2.3