summaryrefslogtreecommitdiff
path: root/python/vyos
diff options
context:
space:
mode:
Diffstat (limited to 'python/vyos')
-rw-r--r--python/vyos/pki.py305
-rw-r--r--python/vyos/util.py19
2 files changed, 324 insertions, 0 deletions
diff --git a/python/vyos/pki.py b/python/vyos/pki.py
new file mode 100644
index 000000000..80efe26b2
--- /dev/null
+++ b/python/vyos/pki.py
@@ -0,0 +1,305 @@
+#!/usr/bin/env python3
+#
+# Copyright (C) 2021 VyOS maintainers and contributors
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 or later as
+# published by the Free Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+import datetime
+
+from cryptography import x509
+from cryptography.exceptions import InvalidSignature
+from cryptography.x509.extensions import ExtensionNotFound
+from cryptography.x509.oid import NameOID, ExtendedKeyUsageOID, ExtensionOID
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import dh
+from cryptography.hazmat.primitives.asymmetric import dsa
+from cryptography.hazmat.primitives.asymmetric import ec
+from cryptography.hazmat.primitives.asymmetric import padding
+from cryptography.hazmat.primitives.asymmetric import rsa
+
+CERT_BEGIN='-----BEGIN CERTIFICATE-----\n'
+CERT_END='\n-----END CERTIFICATE-----'
+KEY_BEGIN='-----BEGIN PRIVATE KEY-----\n'
+KEY_END='\n-----END PRIVATE KEY-----'
+KEY_ENC_BEGIN='-----BEGIN ENCRYPTED PRIVATE KEY-----\n'
+KEY_ENC_END='\n-----END ENCRYPTED PRIVATE KEY-----'
+KEY_PUB_BEGIN='-----BEGIN PUBLIC KEY-----\n'
+KEY_PUB_END='\n-----END PUBLIC KEY-----'
+CRL_BEGIN='-----BEGIN X509 CRL-----\n'
+CRL_END='\n-----END X509 CRL-----'
+CSR_BEGIN='-----BEGIN CERTIFICATE REQUEST-----\n'
+CSR_END='\n-----END CERTIFICATE REQUEST-----'
+DH_BEGIN='-----BEGIN DH PARAMETERS-----\n'
+DH_END='\n-----END DH PARAMETERS-----'
+
+# Print functions
+
+encoding_map = {
+ 'PEM': serialization.Encoding.PEM,
+ 'OpenSSH': serialization.Encoding.OpenSSH
+}
+
+public_format_map = {
+ 'SubjectPublicKeyInfo': serialization.PublicFormat.SubjectPublicKeyInfo,
+ 'OpenSSH': serialization.PublicFormat.OpenSSH
+}
+
+private_format_map = {
+ 'PKCS8': serialization.PrivateFormat.PKCS8,
+ 'OpenSSH': serialization.PrivateFormat.OpenSSH
+}
+
+def encode_certificate(cert):
+ return cert.public_bytes(encoding=serialization.Encoding.PEM).decode('utf-8')
+
+def encode_public_key(cert, encoding='PEM', key_format='SubjectPublicKeyInfo'):
+ if encoding not in encoding_map:
+ encoding = 'PEM'
+ if key_format not in public_format_map:
+ key_format = 'SubjectPublicKeyInfo'
+ return cert.public_bytes(
+ encoding=encoding_map[encoding],
+ format=public_format_map[key_format]).decode('utf-8')
+
+def encode_private_key(private_key, encoding='PEM', key_format='PKCS8', passphrase=None):
+ if encoding not in encoding_map:
+ encoding = 'PEM'
+ if key_format not in private_format_map:
+ key_format = 'PKCS8'
+ encryption = serialization.NoEncryption() if not passphrase else serialization.BestAvailableEncryption(bytes(passphrase, 'utf-8'))
+ return private_key.private_bytes(
+ encoding=encoding_map[encoding],
+ format=private_format_map[key_format],
+ encryption_algorithm=encryption).decode('utf-8')
+
+def encode_dh_parameters(dh_parameters):
+ return dh_parameters.parameter_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.ParameterFormat.PKCS3).decode('utf-8')
+
+# EC Helper
+
+def get_elliptic_curve(size):
+ curve_func = None
+ name = f'SECP{size}R1'
+ if hasattr(ec, name):
+ curve_func = getattr(ec, name)
+ else:
+ curve_func = ec.SECP256R1() # Default to SECP256R1
+ return curve_func()
+
+# Creation functions
+
+def create_private_key(key_type, key_size=None):
+ private_key = None
+ if key_type == 'rsa':
+ private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
+ elif key_type == 'dsa':
+ private_key = dsa.generate_private_key(key_size=key_size)
+ elif key_type == 'ec':
+ curve = get_elliptic_curve(key_size)
+ private_key = ec.generate_private_key(curve)
+ return private_key
+
+def create_certificate_request(subject, private_key):
+ subject_obj = x509.Name([
+ x509.NameAttribute(NameOID.COUNTRY_NAME, subject['country']),
+ x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, subject['state']),
+ x509.NameAttribute(NameOID.LOCALITY_NAME, subject['locality']),
+ x509.NameAttribute(NameOID.ORGANIZATION_NAME, subject['organization']),
+ x509.NameAttribute(NameOID.COMMON_NAME, subject['common_name'])])
+
+ return x509.CertificateSigningRequestBuilder() \
+ .subject_name(subject_obj) \
+ .sign(private_key, hashes.SHA256())
+
+def create_certificate(cert_req, ca_cert, ca_private_key, valid_days=365, cert_type='server', is_ca=False):
+ ext_key_usage = []
+ if is_ca:
+ ext_key_usage = [ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH]
+ elif cert_type == 'client':
+ ext_key_usage = [ExtendedKeyUsageOID.CLIENT_AUTH]
+ elif cert_type == 'server':
+ ext_key_usage = [ExtendedKeyUsageOID.SERVER_AUTH]
+
+ builder = x509.CertificateBuilder() \
+ .subject_name(cert_req.subject) \
+ .issuer_name(ca_cert.subject) \
+ .public_key(cert_req.public_key()) \
+ .serial_number(x509.random_serial_number()) \
+ .not_valid_before(datetime.datetime.utcnow()) \
+ .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=int(valid_days)))
+
+ builder = builder.add_extension(x509.BasicConstraints(ca=is_ca, path_length=None), critical=True)
+ builder = builder.add_extension(x509.ExtendedKeyUsage(ext_key_usage), critical=True)
+ builder = builder.add_extension(x509.KeyUsage(
+ digital_signature=True,
+ content_commitment=False,
+ key_encipherment=False,
+ data_encipherment=False,
+ key_agreement=False,
+ key_cert_sign=is_ca,
+ crl_sign=is_ca,
+ encipher_only=False,
+ decipher_only=False), critical=True)
+
+ for ext in cert_req.extensions:
+ builder = builder.add_extension(ext, critical=False)
+
+ return builder.sign(ca_private_key, hashes.SHA256())
+
+def create_certificate_revocation_list(ca_cert, ca_private_key, serial_numbers=[]):
+ if not serial_numbers:
+ return False
+
+ builder = x509.CertificateRevocationListBuilder() \
+ .issuer_name(ca_cert.subject) \
+ .last_update(datetime.datetime.today()) \
+ .next_update(datetime.datetime.today() + datetime.timedelta(1, 0, 0))
+
+ for serial_number in serial_numbers:
+ revoked_cert = x509.RevokedCertificateBuilder() \
+ .serial_number(serial_number) \
+ .revocation_date(datetime.datetime.today()) \
+ .build()
+ builder = builder.add_revoked_certificate(revoked_cert)
+
+ return builder.sign(private_key=ca_private_key, algorithm=hashes.SHA256())
+
+def create_dh_parameters(bits=2048):
+ if not bits or bits < 512:
+ print("Invalid DH parameter key size")
+ return False
+
+ return dh.generate_parameters(generator=2, key_size=int(bits))
+
+# Wrap functions
+
+def wrap_public_key(raw_data):
+ return KEY_PUB_BEGIN + raw_data + KEY_PUB_END
+
+def wrap_private_key(raw_data, passphrase=None):
+ return (KEY_ENC_BEGIN if passphrase else KEY_BEGIN) + raw_data + (KEY_ENC_END if passphrase else KEY_END)
+
+def wrap_certificate_request(raw_data):
+ return CSR_BEGIN + raw_data + CSR_END
+
+def wrap_certificate(raw_data):
+ return CERT_BEGIN + raw_data + CERT_END
+
+def wrap_crl(raw_data):
+ return CRL_BEGIN + raw_data + CRL_END
+
+def wrap_dh_parameters(raw_data):
+ return DH_BEGIN + raw_data + DH_END
+
+# Load functions
+
+def load_public_key(raw_data, wrap_tags=True):
+ if wrap_tags:
+ raw_data = wrap_public_key(raw_data)
+
+ try:
+ return serialization.load_pem_public_key(bytes(raw_data, 'utf-8'))
+ except ValueError:
+ return False
+
+def load_private_key(raw_data, passphrase=None, wrap_tags=True):
+ if wrap_tags:
+ raw_data = wrap_private_key(raw_data, passphrase)
+
+ if passphrase:
+ passphrase = bytes(passphrase, 'utf-8')
+
+ try:
+ return serialization.load_pem_private_key(bytes(raw_data, 'utf-8'), password=passphrase)
+ except ValueError:
+ return False
+
+def load_certificate_request(raw_data, wrap_tags=True):
+ if wrap_tags:
+ raw_data = wrap_certificate_request(raw_data)
+
+ try:
+ return x509.load_pem_x509_csr(bytes(raw_data, 'utf-8'))
+ except ValueError:
+ return False
+
+def load_certificate(raw_data, wrap_tags=True):
+ if wrap_tags:
+ raw_data = wrap_certificate(raw_data)
+
+ try:
+ return x509.load_pem_x509_certificate(bytes(raw_data, 'utf-8'))
+ except ValueError:
+ return False
+
+def load_crl(raw_data, wrap_tags=True):
+ if wrap_tags:
+ raw_data = wrap_crl(raw_data)
+
+ try:
+ return x509.load_pem_x509_crl(bytes(raw_data, 'utf-8'))
+ except ValueError:
+ return False
+
+def load_dh_parameters(raw_data, wrap_tags=True):
+ if wrap_tags:
+ raw_data = wrap_dh_parameters(raw_data)
+
+ try:
+ return serialization.load_pem_parameters(bytes(raw_data, 'utf-8'))
+ except ValueError:
+ return False
+
+# Verify
+
+def is_ca_certificate(cert):
+ if not cert:
+ return False
+
+ try:
+ ext = cert.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS)
+ return ext.value.ca
+ except ExtensionNotFound:
+ return False
+
+def verify_certificate(cert, ca_cert):
+ # Verify certificate was signed by specified CA
+ if ca_cert.subject != cert.issuer:
+ return False
+
+ ca_public_key = ca_cert.public_key()
+ try:
+ if isinstance(ca_public_key, rsa.RSAPublicKeyWithSerialization):
+ ca_public_key.verify(
+ cert.signature,
+ cert.tbs_certificate_bytes,
+ padding=padding.PKCS1v15(),
+ algorithm=cert.signature_hash_algorithm)
+ elif isinstance(ca_public_key, dsa.DSAPublicKeyWithSerialization):
+ ca_public_key.verify(
+ cert.signature,
+ cert.tbs_certificate_bytes,
+ algorithm=cert.signature_hash_algorithm)
+ elif isinstance(ca_public_key, ec.EllipticCurvePublicKeyWithSerialization):
+ ca_public_key.verify(
+ cert.signature,
+ cert.tbs_certificate_bytes,
+ signature_algorithm=ec.ECDSA(cert.signature_hash_algorithm))
+ else:
+ return False # We cannot verify it
+ return True
+ except InvalidSignature:
+ return False
diff --git a/python/vyos/util.py b/python/vyos/util.py
index c318d58de..c3bf481ea 100644
--- a/python/vyos/util.py
+++ b/python/vyos/util.py
@@ -566,6 +566,25 @@ def wait_for_commit_lock():
while commit_in_progress():
sleep(1)
+def ask_input(question, default='', numeric_only=False, valid_responses=[]):
+ question_out = question
+ if default:
+ question_out += f' (Default: {default})'
+ response = ''
+ while True:
+ response = input(question_out + ' ').strip()
+ if not response and default:
+ return default
+ if numeric_only:
+ if not response.isnumeric():
+ print("Invalid value, try again.")
+ continue
+ response = int(response)
+ if valid_responses and response not in valid_responses:
+ print("Invalid value, try again.")
+ continue
+ break
+ return response
def ask_yes_no(question, default=False) -> bool:
"""Ask a yes/no question via input() and return their answer."""