diff options
Diffstat (limited to 'python/vyos')
-rw-r--r-- | python/vyos/pki.py | 305 | ||||
-rw-r--r-- | python/vyos/util.py | 19 |
2 files changed, 324 insertions, 0 deletions
diff --git a/python/vyos/pki.py b/python/vyos/pki.py new file mode 100644 index 000000000..80efe26b2 --- /dev/null +++ b/python/vyos/pki.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2021 VyOS maintainers and contributors +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 or later as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +import datetime + +from cryptography import x509 +from cryptography.exceptions import InvalidSignature +from cryptography.x509.extensions import ExtensionNotFound +from cryptography.x509.oid import NameOID, ExtendedKeyUsageOID, ExtensionOID +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import dh +from cryptography.hazmat.primitives.asymmetric import dsa +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric import padding +from cryptography.hazmat.primitives.asymmetric import rsa + +CERT_BEGIN='-----BEGIN CERTIFICATE-----\n' +CERT_END='\n-----END CERTIFICATE-----' +KEY_BEGIN='-----BEGIN PRIVATE KEY-----\n' +KEY_END='\n-----END PRIVATE KEY-----' +KEY_ENC_BEGIN='-----BEGIN ENCRYPTED PRIVATE KEY-----\n' +KEY_ENC_END='\n-----END ENCRYPTED PRIVATE KEY-----' +KEY_PUB_BEGIN='-----BEGIN PUBLIC KEY-----\n' +KEY_PUB_END='\n-----END PUBLIC KEY-----' +CRL_BEGIN='-----BEGIN X509 CRL-----\n' +CRL_END='\n-----END X509 CRL-----' +CSR_BEGIN='-----BEGIN CERTIFICATE REQUEST-----\n' +CSR_END='\n-----END CERTIFICATE REQUEST-----' +DH_BEGIN='-----BEGIN DH PARAMETERS-----\n' +DH_END='\n-----END DH PARAMETERS-----' + +# Print functions + +encoding_map = { + 'PEM': serialization.Encoding.PEM, + 'OpenSSH': serialization.Encoding.OpenSSH +} + +public_format_map = { + 'SubjectPublicKeyInfo': serialization.PublicFormat.SubjectPublicKeyInfo, + 'OpenSSH': serialization.PublicFormat.OpenSSH +} + +private_format_map = { + 'PKCS8': serialization.PrivateFormat.PKCS8, + 'OpenSSH': serialization.PrivateFormat.OpenSSH +} + +def encode_certificate(cert): + return cert.public_bytes(encoding=serialization.Encoding.PEM).decode('utf-8') + +def encode_public_key(cert, encoding='PEM', key_format='SubjectPublicKeyInfo'): + if encoding not in encoding_map: + encoding = 'PEM' + if key_format not in public_format_map: + key_format = 'SubjectPublicKeyInfo' + return cert.public_bytes( + encoding=encoding_map[encoding], + format=public_format_map[key_format]).decode('utf-8') + +def encode_private_key(private_key, encoding='PEM', key_format='PKCS8', passphrase=None): + if encoding not in encoding_map: + encoding = 'PEM' + if key_format not in private_format_map: + key_format = 'PKCS8' + encryption = serialization.NoEncryption() if not passphrase else serialization.BestAvailableEncryption(bytes(passphrase, 'utf-8')) + return private_key.private_bytes( + encoding=encoding_map[encoding], + format=private_format_map[key_format], + encryption_algorithm=encryption).decode('utf-8') + +def encode_dh_parameters(dh_parameters): + return dh_parameters.parameter_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.ParameterFormat.PKCS3).decode('utf-8') + +# EC Helper + +def get_elliptic_curve(size): + curve_func = None + name = f'SECP{size}R1' + if hasattr(ec, name): + curve_func = getattr(ec, name) + else: + curve_func = ec.SECP256R1() # Default to SECP256R1 + return curve_func() + +# Creation functions + +def create_private_key(key_type, key_size=None): + private_key = None + if key_type == 'rsa': + private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size) + elif key_type == 'dsa': + private_key = dsa.generate_private_key(key_size=key_size) + elif key_type == 'ec': + curve = get_elliptic_curve(key_size) + private_key = ec.generate_private_key(curve) + return private_key + +def create_certificate_request(subject, private_key): + subject_obj = x509.Name([ + x509.NameAttribute(NameOID.COUNTRY_NAME, subject['country']), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, subject['state']), + x509.NameAttribute(NameOID.LOCALITY_NAME, subject['locality']), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, subject['organization']), + x509.NameAttribute(NameOID.COMMON_NAME, subject['common_name'])]) + + return x509.CertificateSigningRequestBuilder() \ + .subject_name(subject_obj) \ + .sign(private_key, hashes.SHA256()) + +def create_certificate(cert_req, ca_cert, ca_private_key, valid_days=365, cert_type='server', is_ca=False): + ext_key_usage = [] + if is_ca: + ext_key_usage = [ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH] + elif cert_type == 'client': + ext_key_usage = [ExtendedKeyUsageOID.CLIENT_AUTH] + elif cert_type == 'server': + ext_key_usage = [ExtendedKeyUsageOID.SERVER_AUTH] + + builder = x509.CertificateBuilder() \ + .subject_name(cert_req.subject) \ + .issuer_name(ca_cert.subject) \ + .public_key(cert_req.public_key()) \ + .serial_number(x509.random_serial_number()) \ + .not_valid_before(datetime.datetime.utcnow()) \ + .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=int(valid_days))) + + builder = builder.add_extension(x509.BasicConstraints(ca=is_ca, path_length=None), critical=True) + builder = builder.add_extension(x509.ExtendedKeyUsage(ext_key_usage), critical=True) + builder = builder.add_extension(x509.KeyUsage( + digital_signature=True, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=is_ca, + crl_sign=is_ca, + encipher_only=False, + decipher_only=False), critical=True) + + for ext in cert_req.extensions: + builder = builder.add_extension(ext, critical=False) + + return builder.sign(ca_private_key, hashes.SHA256()) + +def create_certificate_revocation_list(ca_cert, ca_private_key, serial_numbers=[]): + if not serial_numbers: + return False + + builder = x509.CertificateRevocationListBuilder() \ + .issuer_name(ca_cert.subject) \ + .last_update(datetime.datetime.today()) \ + .next_update(datetime.datetime.today() + datetime.timedelta(1, 0, 0)) + + for serial_number in serial_numbers: + revoked_cert = x509.RevokedCertificateBuilder() \ + .serial_number(serial_number) \ + .revocation_date(datetime.datetime.today()) \ + .build() + builder = builder.add_revoked_certificate(revoked_cert) + + return builder.sign(private_key=ca_private_key, algorithm=hashes.SHA256()) + +def create_dh_parameters(bits=2048): + if not bits or bits < 512: + print("Invalid DH parameter key size") + return False + + return dh.generate_parameters(generator=2, key_size=int(bits)) + +# Wrap functions + +def wrap_public_key(raw_data): + return KEY_PUB_BEGIN + raw_data + KEY_PUB_END + +def wrap_private_key(raw_data, passphrase=None): + return (KEY_ENC_BEGIN if passphrase else KEY_BEGIN) + raw_data + (KEY_ENC_END if passphrase else KEY_END) + +def wrap_certificate_request(raw_data): + return CSR_BEGIN + raw_data + CSR_END + +def wrap_certificate(raw_data): + return CERT_BEGIN + raw_data + CERT_END + +def wrap_crl(raw_data): + return CRL_BEGIN + raw_data + CRL_END + +def wrap_dh_parameters(raw_data): + return DH_BEGIN + raw_data + DH_END + +# Load functions + +def load_public_key(raw_data, wrap_tags=True): + if wrap_tags: + raw_data = wrap_public_key(raw_data) + + try: + return serialization.load_pem_public_key(bytes(raw_data, 'utf-8')) + except ValueError: + return False + +def load_private_key(raw_data, passphrase=None, wrap_tags=True): + if wrap_tags: + raw_data = wrap_private_key(raw_data, passphrase) + + if passphrase: + passphrase = bytes(passphrase, 'utf-8') + + try: + return serialization.load_pem_private_key(bytes(raw_data, 'utf-8'), password=passphrase) + except ValueError: + return False + +def load_certificate_request(raw_data, wrap_tags=True): + if wrap_tags: + raw_data = wrap_certificate_request(raw_data) + + try: + return x509.load_pem_x509_csr(bytes(raw_data, 'utf-8')) + except ValueError: + return False + +def load_certificate(raw_data, wrap_tags=True): + if wrap_tags: + raw_data = wrap_certificate(raw_data) + + try: + return x509.load_pem_x509_certificate(bytes(raw_data, 'utf-8')) + except ValueError: + return False + +def load_crl(raw_data, wrap_tags=True): + if wrap_tags: + raw_data = wrap_crl(raw_data) + + try: + return x509.load_pem_x509_crl(bytes(raw_data, 'utf-8')) + except ValueError: + return False + +def load_dh_parameters(raw_data, wrap_tags=True): + if wrap_tags: + raw_data = wrap_dh_parameters(raw_data) + + try: + return serialization.load_pem_parameters(bytes(raw_data, 'utf-8')) + except ValueError: + return False + +# Verify + +def is_ca_certificate(cert): + if not cert: + return False + + try: + ext = cert.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS) + return ext.value.ca + except ExtensionNotFound: + return False + +def verify_certificate(cert, ca_cert): + # Verify certificate was signed by specified CA + if ca_cert.subject != cert.issuer: + return False + + ca_public_key = ca_cert.public_key() + try: + if isinstance(ca_public_key, rsa.RSAPublicKeyWithSerialization): + ca_public_key.verify( + cert.signature, + cert.tbs_certificate_bytes, + padding=padding.PKCS1v15(), + algorithm=cert.signature_hash_algorithm) + elif isinstance(ca_public_key, dsa.DSAPublicKeyWithSerialization): + ca_public_key.verify( + cert.signature, + cert.tbs_certificate_bytes, + algorithm=cert.signature_hash_algorithm) + elif isinstance(ca_public_key, ec.EllipticCurvePublicKeyWithSerialization): + ca_public_key.verify( + cert.signature, + cert.tbs_certificate_bytes, + signature_algorithm=ec.ECDSA(cert.signature_hash_algorithm)) + else: + return False # We cannot verify it + return True + except InvalidSignature: + return False diff --git a/python/vyos/util.py b/python/vyos/util.py index c318d58de..c3bf481ea 100644 --- a/python/vyos/util.py +++ b/python/vyos/util.py @@ -566,6 +566,25 @@ def wait_for_commit_lock(): while commit_in_progress(): sleep(1) +def ask_input(question, default='', numeric_only=False, valid_responses=[]): + question_out = question + if default: + question_out += f' (Default: {default})' + response = '' + while True: + response = input(question_out + ' ').strip() + if not response and default: + return default + if numeric_only: + if not response.isnumeric(): + print("Invalid value, try again.") + continue + response = int(response) + if valid_responses and response not in valid_responses: + print("Invalid value, try again.") + continue + break + return response def ask_yes_no(question, default=False) -> bool: """Ask a yes/no question via input() and return their answer.""" |