From 77a9473915b46879bae504dfa3c1c4d0d60fa2e9 Mon Sep 17 00:00:00 2001
From: sarthurdev <965089+sarthurdev@users.noreply.github.com>
Date: Fri, 23 Jul 2021 13:39:14 +0200
Subject: pki: T3642: Add ability to write generated certificates/keys to
 specified filenames

---
 src/op_mode/pki.py | 187 +++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 131 insertions(+), 56 deletions(-)

(limited to 'src')

diff --git a/src/op_mode/pki.py b/src/op_mode/pki.py
index b4a68b31c..297270cf1 100755
--- a/src/op_mode/pki.py
+++ b/src/op_mode/pki.py
@@ -38,6 +38,8 @@ from vyos.util import cmd
 
 CERT_REQ_END = '-----END CERTIFICATE REQUEST-----'
 
+auth_dir = '/config/auth'
+
 # Helper Functions
 
 def get_default_values():
@@ -230,6 +232,22 @@ def ask_passphrase():
         passphrase = ask_input('Enter passphrase:')
     return passphrase
 
+def write_file(filename, contents):
+    full_path = os.path.join(auth_dir, filename)
+    directory = os.path.dirname(full_path)
+
+    if not os.path.exists(directory):
+        print('Failed to write file: directory does not exist')
+        return False
+
+    if os.path.exists(full_path) and not ask_yes_no('Do you want to overwrite the existing file?'):
+        return False
+
+    with open(full_path, 'w') as f:
+        f.write(contents)
+
+    print(f'File written to {full_path}')
+
 # Generation functions
 
 def generate_private_key():
@@ -266,7 +284,7 @@ def parse_san_string(san_string):
             output.append(value)
     return output
 
-def generate_certificate_request(private_key=None, key_type=None, return_request=False, name=None, install=False, ask_san=True):
+def generate_certificate_request(private_key=None, key_type=None, return_request=False, name=None, install=False, file=False, ask_san=True):
     if not private_key:
         private_key, key_type = generate_private_key()
 
@@ -291,14 +309,19 @@ def generate_certificate_request(private_key=None, key_type=None, return_request
 
     passphrase = ask_passphrase()
 
-    if not install:
+    if not install and not file:
         print(encode_certificate(cert_req))
         print(encode_private_key(private_key, passphrase=passphrase))
         return None
 
-    print("Certificate request:")
-    print(encode_certificate(cert_req) + "\n")
-    install_certificate(name, private_key=private_key, key_type=key_type, key_passphrase=passphrase, is_ca=False)
+    if install:
+        print("Certificate request:")
+        print(encode_certificate(cert_req) + "\n")
+        install_certificate(name, private_key=private_key, key_type=key_type, key_passphrase=passphrase, is_ca=False)
+
+    if file:
+        write_file(f'{name}.csr', encode_certificate(cert_req))
+        write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase))
 
 def generate_certificate(cert_req, ca_cert, ca_private_key, is_ca=False, is_sub_ca=False):
     valid_days = ask_input('Enter how many days certificate will be valid:', default='365' if not is_ca else '1825', numeric_only=True)
@@ -307,20 +330,25 @@ def generate_certificate(cert_req, ca_cert, ca_private_key, is_ca=False, is_sub_
         cert_type = ask_input('Enter certificate type: (client, server)', default='server', valid_responses=['client', 'server'])
     return create_certificate(cert_req, ca_cert, ca_private_key, valid_days, cert_type, is_ca, is_sub_ca)
 
-def generate_ca_certificate(name, install=False):
+def generate_ca_certificate(name, install=False, file=False):
     private_key, key_type = generate_private_key()
     cert_req = generate_certificate_request(private_key, key_type, return_request=True, ask_san=False)
     cert = generate_certificate(cert_req, cert_req, private_key, is_ca=True)
     passphrase = ask_passphrase()
 
-    if not install:
+    if not install and not file:
         print(encode_certificate(cert))
         print(encode_private_key(private_key, passphrase=passphrase))
         return None
 
-    install_certificate(name, cert, private_key, key_type, key_passphrase=passphrase, is_ca=True)
+    if install:
+        install_certificate(name, cert, private_key, key_type, key_passphrase=passphrase, is_ca=True)
+
+    if file:
+        write_file(f'{name}.pem', encode_certificate(cert))
+        write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase))
 
-def generate_ca_certificate_sign(name, ca_name, install=False):
+def generate_ca_certificate_sign(name, ca_name, install=False, file=False):
     ca_dict = get_config_ca_certificate(ca_name)
 
     if not ca_dict:
@@ -374,14 +402,19 @@ def generate_ca_certificate_sign(name, ca_name, install=False):
     cert = generate_certificate(cert_req, ca_cert, ca_private_key, is_ca=True, is_sub_ca=True)
     passphrase = ask_passphrase()
 
-    if not install:
+    if not install and not file:
         print(encode_certificate(cert))
         print(encode_private_key(private_key, passphrase=passphrase))
         return None
 
-    install_certificate(name, cert, private_key, key_type, key_passphrase=passphrase, is_ca=True)
+    if install:
+        install_certificate(name, cert, private_key, key_type, key_passphrase=passphrase, is_ca=True)
+
+    if file:
+        write_file(f'{name}.pem', encode_certificate(cert))
+        write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase))
 
-def generate_certificate_sign(name, ca_name, install=False):
+def generate_certificate_sign(name, ca_name, install=False, file=False):
     ca_dict = get_config_ca_certificate(ca_name)
 
     if not ca_dict:
@@ -435,27 +468,37 @@ def generate_certificate_sign(name, ca_name, install=False):
     cert = generate_certificate(cert_req, ca_cert, ca_private_key, is_ca=False)
     passphrase = ask_passphrase()
 
-    if not install:
+    if not install and not file:
         print(encode_certificate(cert))
         print(encode_private_key(private_key, passphrase=passphrase))
         return None
 
-    install_certificate(name, cert, private_key, key_type, key_passphrase=passphrase, is_ca=False)
+    if install:
+        install_certificate(name, cert, private_key, key_type, key_passphrase=passphrase, is_ca=False)
+
+    if file:
+        write_file(f'{name}.pem', encode_certificate(cert))
+        write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase))
 
-def generate_certificate_selfsign(name, install=False):
+def generate_certificate_selfsign(name, install=False, file=False):
     private_key, key_type = generate_private_key()
     cert_req = generate_certificate_request(private_key, key_type, return_request=True)
     cert = generate_certificate(cert_req, cert_req, private_key, is_ca=False)
     passphrase = ask_passphrase()
 
-    if not install:
+    if not install and not file:
         print(encode_certificate(cert))
         print(encode_private_key(private_key, passphrase=passphrase))
         return None
 
-    install_certificate(name, cert, private_key=private_key, key_type=key_type, key_passphrase=passphrase, is_ca=False)
+    if install:
+        install_certificate(name, cert, private_key=private_key, key_type=key_type, key_passphrase=passphrase, is_ca=False)
+
+    if file:
+        write_file(f'{name}.pem', encode_certificate(cert))
+        write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase))
 
-def generate_certificate_revocation_list(ca_name, install=False):
+def generate_certificate_revocation_list(ca_name, install=False, file=False):
     ca_dict = get_config_ca_certificate(ca_name)
 
     if not ca_dict:
@@ -505,26 +548,35 @@ def generate_certificate_revocation_list(ca_name, install=False):
         print("Failed to create CRL")
         return None
 
-    if not install:
+    if not install and not file:
         print(encode_certificate(crl))
         return None
 
-    install_crl(ca_name, crl)
+    if install:
+        install_crl(ca_name, crl)
+
+    if file:
+        write_file(f'{name}.crl', encode_certificate(crl))
 
-def generate_ssh_keypair(name, install=False):
+def generate_ssh_keypair(name, install=False, file=False):
     private_key, key_type = generate_private_key()
     public_key = private_key.public_key()
     passphrase = ask_passphrase()
 
-    if not install:
+    if not install and not file:
         print(encode_public_key(public_key, encoding='OpenSSH', key_format='OpenSSH'))
         print("")
         print(encode_private_key(private_key, encoding='PEM', key_format='OpenSSH', passphrase=passphrase))
         return None
 
-    install_ssh_key(name, public_key, private_key, passphrase)
+    if install:
+        install_ssh_key(name, public_key, private_key, passphrase)
+
+    if file:
+        write_file(f'{name}.pem', encode_public_key(public_key, encoding='OpenSSH', key_format='OpenSSH'))
+        write_file(f'{name}.key', encode_private_key(private_key, encoding='PEM', key_format='OpenSSH', passphrase=passphrase))
 
-def generate_dh_parameters(name, install=False):
+def generate_dh_parameters(name, install=False, file=False):
     bits = ask_input('Enter DH parameters key size:', default=2048, numeric_only=True)
 
     print("Generating parameters...")
@@ -534,49 +586,62 @@ def generate_dh_parameters(name, install=False):
         print("Failed to create DH parameters")
         return None
 
-    if not install:
+    if not install and not file:
         print("DH Parameters:")
         print(encode_dh_parameters(dh_params))
 
-    install_dh_parameters(name, dh_params)
+    if install:
+        install_dh_parameters(name, dh_params)
+
+    if file:
+        write_file(f'{name}.pem', encode_dh_parameters(dh_params))
 
-def generate_keypair(name, install=False):
+def generate_keypair(name, install=False, file=False):
     private_key, key_type = generate_private_key()
     public_key = private_key.public_key()
     passphrase = ask_passphrase()
 
-    if not install:
+    if not install and not file:
         print(encode_public_key(public_key))
         print("")
         print(encode_private_key(private_key, passphrase=passphrase))
         return None
 
-    install_keypair(name, key_type, private_key, public_key, passphrase)
+    if install:
+        install_keypair(name, key_type, private_key, public_key, passphrase)
+
+    if file:
+        write_file(f'{name}.pem', encode_public_key(public_key))
+        write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase))
 
-def generate_openvpn_key(name, install=False):
+def generate_openvpn_key(name, install=False, file=False):
     result = cmd('openvpn --genkey secret /dev/stdout | grep -o "^[^#]*"')
 
     if not result:
         print("Failed to generate OpenVPN key")
         return None
 
-    if not install:
+    if not install and not file:
         print(result)
         return None
 
-    key_lines = result.split("\n")
-    key_data = "".join(key_lines[1:-1]) # Remove wrapper tags and line endings
-    key_version = '1'
+    if install:
+        key_lines = result.split("\n")
+        key_data = "".join(key_lines[1:-1]) # Remove wrapper tags and line endings
+        key_version = '1'
+
+        version_search = re.search(r'BEGIN OpenVPN Static key V(\d+)', result) # Future-proofing (hopefully)
+        if version_search:
+            key_version = version_search[1]
 
-    version_search = re.search(r'BEGIN OpenVPN Static key V(\d+)', result) # Future-proofing (hopefully)
-    if version_search:
-        key_version = version_search[1]
+        print("Configure mode commands to install OpenVPN key:")
+        print("set pki openvpn shared-secret %s key '%s'" % (name, key_data))
+        print("set pki openvpn shared-secret %s version '%s'" % (name, key_version))
 
-    print("Configure mode commands to install OpenVPN key:")
-    print("set pki openvpn shared-secret %s key '%s'" % (name, key_data))
-    print("set pki openvpn shared-secret %s version '%s'" % (name, key_version))
+    if file:
+        write_file(f'{name}.key', result)
 
-def generate_wireguard_key(name, install=False):
+def generate_wireguard_key(name, install=False, file=False):
     private_key = cmd('wg genkey')
     public_key = cmd('wg pubkey', input=private_key)
 
@@ -585,17 +650,26 @@ def generate_wireguard_key(name, install=False):
         print("Public key: " + public_key)
         return None
 
-    install_wireguard_key(name, private_key, public_key)
+    if install:
+        install_wireguard_key(name, private_key, public_key)
 
-def generate_wireguard_psk(name, install=False):
+    if file:
+        write_file(f'{name}_public.key', public_key)
+        write_file(f'{name}_private.key', private_key)
+
+def generate_wireguard_psk(name, install=False, file=False):
     psk = cmd('wg genpsk')
 
-    if not install:
+    if not install and not file:
         print("Pre-shared key:")
         print(psk)
         return None
 
-    install_wireguard_psk(name, psk)
+    if install:
+        install_wireguard_psk(name, psk)
+
+    if file:
+        write_file(f'{name}.key', psk)
 
 # Show functions
 
@@ -721,6 +795,7 @@ if __name__ == '__main__':
     parser.add_argument('--psk', help='Wireguard pre shared key', required=False)
 
     # Global
+    parser.add_argument('--file', help='Write generated keys into specified filename', action='store_true')
     parser.add_argument('--install', help='Install generated keys into running-config', action='store_true')
 
     args = parser.parse_args()
@@ -729,31 +804,31 @@ if __name__ == '__main__':
         if args.action == 'generate':
             if args.ca:
                 if args.sign:
-                    generate_ca_certificate_sign(args.ca, args.sign, args.install)
+                    generate_ca_certificate_sign(args.ca, args.sign, install=args.install, file=args.file)
                 else:
-                    generate_ca_certificate(args.ca, args.install)
+                    generate_ca_certificate(args.ca, install=args.install, file=args.file)
             elif args.certificate:
                 if args.sign:
-                    generate_certificate_sign(args.certificate, args.sign, args.install)
+                    generate_certificate_sign(args.certificate, args.sign, install=args.install, file=args.file)
                 elif args.self_sign:
-                    generate_certificate_selfsign(args.certificate, args.install)
+                    generate_certificate_selfsign(args.certificate, install=args.install, file=args.file)
                 else:
                     generate_certificate_request(name=args.certificate, install=args.install)
             elif args.crl:
-                generate_certificate_revocation_list(args.crl, args.install)
+                generate_certificate_revocation_list(args.crl, install=args.install, file=args.file)
             elif args.ssh:
-                generate_ssh_keypair(args.ssh, args.install)
+                generate_ssh_keypair(args.ssh, install=args.install, file=args.file)
             elif args.dh:
-                generate_dh_parameters(args.dh, args.install)
+                generate_dh_parameters(args.dh, install=args.install, file=args.file)
             elif args.keypair:
-                generate_keypair(args.keypair, args.install)
+                generate_keypair(args.keypair, install=args.install, file=args.file)
             elif args.openvpn:
-                generate_openvpn_key(args.openvpn, args.install)
+                generate_openvpn_key(args.openvpn, install=args.install, file=args.file)
             elif args.wireguard:
                 if args.key:
-                    generate_wireguard_key(args.key, args.install)
+                    generate_wireguard_key(args.key, install=args.install, file=args.file)
                 elif args.psk:
-                    generate_wireguard_psk(args.psk, args.install)
+                    generate_wireguard_psk(args.psk, install=args.install, file=args.file)
         elif args.action == 'show':
             if args.ca:
                 show_certificate_authority(None if args.ca == 'all' else args.ca)
-- 
cgit v1.2.3