diff options
Diffstat (limited to 'python/vyos')
82 files changed, 6006 insertions, 2107 deletions
diff --git a/python/vyos/accel_ppp.py b/python/vyos/accel_ppp.py index bfc8ee5a9..0b4f8a9fe 100644 --- a/python/vyos/accel_ppp.py +++ b/python/vyos/accel_ppp.py @@ -18,7 +18,7 @@ import sys import vyos.opmode -from vyos.util import rc_cmd +from vyos.utils.process import rc_cmd def get_server_statistics(accel_statistics, pattern, sep=':') -> dict: @@ -38,6 +38,9 @@ def get_server_statistics(accel_statistics, pattern, sep=':') -> dict: if key in ['starting', 'active', 'finishing']: stat_dict['sessions'][key] = value.strip() continue + if key == 'cpu': + stat_dict['cpu_load_percentage'] = int(re.sub(r'%', '', value.strip())) + continue stat_dict[key] = value.strip() return stat_dict diff --git a/python/vyos/component_version.py b/python/vyos/component_version.py index a4e318d08..84e0ae51a 100644 --- a/python/vyos/component_version.py +++ b/python/vyos/component_version.py @@ -37,7 +37,7 @@ import re import sys import fileinput -from vyos.xml import component_version +from vyos.xml_ref import component_version from vyos.version import get_version from vyos.defaults import directories diff --git a/python/vyos/config.py b/python/vyos/config.py index 287fd2ed1..0ca41718f 100644 --- a/python/vyos/config.py +++ b/python/vyos/config.py @@ -1,4 +1,4 @@ -# Copyright 2017, 2019 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2017, 2019-2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -66,11 +66,31 @@ In operational mode, all functions return values from the running config. import re import json from copy import deepcopy +from typing import Union -import vyos.xml -import vyos.util import vyos.configtree -from vyos.configsource import ConfigSource, ConfigSourceSession +from vyos.xml_ref import multi_to_list +from vyos.xml_ref import from_source +from vyos.xml_ref import ext_dict_merge +from vyos.xml_ref import relative_defaults +from vyos.utils.dict import get_sub_dict +from vyos.utils.dict import mangle_dict_keys +from vyos.configsource import ConfigSource +from vyos.configsource import ConfigSourceSession + +class ConfigDict(dict): + _from_defaults = {} + _dict_kwargs = {} + def from_defaults(self, path: list[str]) -> bool: + return from_source(self._from_defaults, path) + @property + def kwargs(self) -> dict: + return self._dict_kwargs + +def config_dict_merge(src: dict, dest: Union[dict, ConfigDict]) -> ConfigDict: + if not isinstance(dest, ConfigDict): + dest = ConfigDict(dest) + return ext_dict_merge(src, dest) class Config(object): """ @@ -93,6 +113,11 @@ class Config(object): (self._running_config, self._session_config) = self._config_source.get_configtree_tuple() + def get_config_tree(self, effective=False): + if effective: + return self._running_config + return self._session_config + def _make_path(self, path): # Backwards-compatibility stuff: original implementation used string paths # libvyosconfig paths are lists, but since node names cannot contain whitespace, @@ -223,9 +248,17 @@ class Config(object): return config_dict + def verify_mangling(self, key_mangling): + if not (isinstance(key_mangling, tuple) and \ + (len(key_mangling) == 2) and \ + isinstance(key_mangling[0], str) and \ + isinstance(key_mangling[1], str)): + raise ValueError("key_mangling must be a tuple of two strings") + def get_config_dict(self, path=[], effective=False, key_mangling=None, get_first_key=False, no_multi_convert=False, - no_tag_node_value_mangle=False): + no_tag_node_value_mangle=False, + with_defaults=False, with_recursive_defaults=False): """ Args: path (str list): Configuration tree path, can be empty @@ -236,32 +269,73 @@ class Config(object): Returns: a dict representation of the config under path """ + kwargs = locals().copy() + del kwargs['self'] + del kwargs['no_multi_convert'] + del kwargs['with_defaults'] + del kwargs['with_recursive_defaults'] + lpath = self._make_path(path) root_dict = self.get_cached_root_dict(effective) - conf_dict = vyos.util.get_sub_dict(root_dict, lpath, get_first_key) - - if not key_mangling and no_multi_convert: - return deepcopy(conf_dict) + conf_dict = get_sub_dict(root_dict, lpath, get_first_key=get_first_key) - xmlpath = lpath if get_first_key else lpath[:-1] + rpath = lpath if get_first_key else lpath[:-1] - if not key_mangling: - conf_dict = vyos.xml.multi_to_list(xmlpath, conf_dict) - return conf_dict + if not no_multi_convert: + conf_dict = multi_to_list(rpath, conf_dict) - if no_multi_convert is False: - conf_dict = vyos.xml.multi_to_list(xmlpath, conf_dict) + if key_mangling is not None: + self.verify_mangling(key_mangling) + conf_dict = mangle_dict_keys(conf_dict, + key_mangling[0], key_mangling[1], + abs_path=rpath, + no_tag_node_value_mangle=no_tag_node_value_mangle) - if not (isinstance(key_mangling, tuple) and \ - (len(key_mangling) == 2) and \ - isinstance(key_mangling[0], str) and \ - isinstance(key_mangling[1], str)): - raise ValueError("key_mangling must be a tuple of two strings") + if with_defaults or with_recursive_defaults: + defaults = self.get_config_defaults(**kwargs, + recursive=with_recursive_defaults) + conf_dict = config_dict_merge(defaults, conf_dict) + else: + conf_dict = ConfigDict(conf_dict) - conf_dict = vyos.util.mangle_dict_keys(conf_dict, key_mangling[0], key_mangling[1], abs_path=xmlpath, no_tag_node_value_mangle=no_tag_node_value_mangle) + # save optional args for a call to get_config_defaults + setattr(conf_dict, '_dict_kwargs', kwargs) return conf_dict + def get_config_defaults(self, path=[], effective=False, key_mangling=None, + no_tag_node_value_mangle=False, get_first_key=False, + recursive=False) -> dict: + lpath = self._make_path(path) + root_dict = self.get_cached_root_dict(effective) + conf_dict = get_sub_dict(root_dict, lpath, get_first_key) + + defaults = relative_defaults(lpath, conf_dict, + get_first_key=get_first_key, + recursive=recursive) + + rpath = lpath if get_first_key else lpath[:-1] + + if key_mangling is not None: + self.verify_mangling(key_mangling) + defaults = mangle_dict_keys(defaults, + key_mangling[0], key_mangling[1], + abs_path=rpath, + no_tag_node_value_mangle=no_tag_node_value_mangle) + + return defaults + + def merge_defaults(self, config_dict: ConfigDict, recursive=False): + if not isinstance(config_dict, ConfigDict): + raise TypeError('argument is not of type ConfigDict') + if not config_dict.kwargs: + raise ValueError('argument missing metadata') + + args = config_dict.kwargs + d = self.get_config_defaults(**args, recursive=recursive) + config_dict = config_dict_merge(d, config_dict) + return config_dict + def is_multi(self, path): """ Args: diff --git a/python/vyos/config_mgmt.py b/python/vyos/config_mgmt.py new file mode 100644 index 000000000..dbf17ade4 --- /dev/null +++ b/python/vyos/config_mgmt.py @@ -0,0 +1,700 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + +import os +import re +import sys +import gzip +import logging + +from typing import Optional, Tuple, Union +from filecmp import cmp +from datetime import datetime +from textwrap import dedent +from pathlib import Path +from tabulate import tabulate +from shutil import copy + +from vyos.config import Config +from vyos.configtree import ConfigTree, ConfigTreeError, show_diff +from vyos.defaults import directories +from vyos.version import get_full_version_data +from vyos.utils.io import ask_yes_no +from vyos.utils.boot import boot_configuration_complete +from vyos.utils.process import is_systemd_service_active +from vyos.utils.process import rc_cmd + +SAVE_CONFIG = '/usr/libexec/vyos/vyos-save-config.py' + +# created by vyatta-cfg-postinst +commit_post_hook_dir = '/etc/commit/post-hooks.d' + +commit_hooks = {'commit_revision': '01vyos-commit-revision', + 'commit_archive': '02vyos-commit-archive'} + +DEFAULT_TIME_MINUTES = 10 +timer_name = 'commit-confirm' + +config_file = os.path.join(directories['config'], 'config.boot') +archive_dir = os.path.join(directories['config'], 'archive') +archive_config_file = os.path.join(archive_dir, 'config.boot') +commit_log_file = os.path.join(archive_dir, 'commits') +logrotate_conf = os.path.join(archive_dir, 'lr.conf') +logrotate_state = os.path.join(archive_dir, 'lr.state') +rollback_config = os.path.join(archive_dir, 'config.boot-rollback') +prerollback_config = os.path.join(archive_dir, 'config.boot-prerollback') +tmp_log_entry = '/tmp/commit-rev-entry' + +logger = logging.getLogger('config_mgmt') +logger.setLevel(logging.INFO) +ch = logging.StreamHandler() +formatter = logging.Formatter('%(funcName)s: %(levelname)s:%(message)s') +ch.setFormatter(formatter) +logger.addHandler(ch) + +def save_config(target): + cmd = f'{SAVE_CONFIG} {target}' + rc, out = rc_cmd(cmd) + if rc != 0: + logger.critical(f'save config failed: {out}') + +def unsaved_commits() -> bool: + if get_full_version_data()['boot_via'] == 'livecd': + return False + tmp_save = '/tmp/config.running' + save_config(tmp_save) + ret = not cmp(tmp_save, config_file, shallow=False) + os.unlink(tmp_save) + return ret + +def get_file_revision(rev: int): + revision = os.path.join(archive_dir, f'config.boot.{rev}.gz') + try: + with gzip.open(revision) as f: + r = f.read().decode() + except FileNotFoundError: + logger.warning(f'commit revision {rev} not available') + return '' + return r + +def get_config_tree_revision(rev: int): + c = get_file_revision(rev) + return ConfigTree(c) + +def is_node_revised(path: list = [], rev1: int = 1, rev2: int = 0) -> bool: + from vyos.configtree import DiffTree + left = get_config_tree_revision(rev1) + right = get_config_tree_revision(rev2) + diff_tree = DiffTree(left, right) + if diff_tree.add.exists(path) or diff_tree.sub.exists(path): + return True + return False + +class ConfigMgmtError(Exception): + pass + +class ConfigMgmt: + def __init__(self, session_env=None, config=None): + if session_env: + self._session_env = session_env + else: + self._session_env = None + + if config is None: + config = Config() + + d = config.get_config_dict(['system', 'config-management'], + key_mangling=('-', '_'), + get_first_key=True) + + self.max_revisions = int(d.get('commit_revisions', 0)) + self.locations = d.get('commit_archive', {}).get('location', []) + self.source_address = d.get('commit_archive', + {}).get('source_address', '') + if config.exists(['system', 'host-name']): + self.hostname = config.return_value(['system', 'host-name']) + else: + self.hostname = 'vyos' + + # upload only on existence of effective values, notably, on boot. + # one still needs session self.locations (above) for setting + # post-commit hook in conf_mode script + path = ['system', 'config-management', 'commit-archive', 'location'] + if config.exists_effective(path): + self.effective_locations = config.return_effective_values(path) + else: + self.effective_locations = [] + + # a call to compare without args is edit_level aware + edit_level = os.getenv('VYATTA_EDIT_LEVEL', '') + self.edit_path = [l for l in edit_level.split('/') if l] + + self.active_config = config._running_config + self.working_config = config._session_config + + # Console script functions + # + def commit_confirm(self, minutes: int=DEFAULT_TIME_MINUTES, + no_prompt: bool=False) -> Tuple[str,int]: + """Commit with reboot to saved config in 'minutes' minutes if + 'confirm' call is not issued. + """ + if is_systemd_service_active(f'{timer_name}.timer'): + msg = 'Another confirm is pending' + return msg, 1 + + if unsaved_commits(): + W = '\nYou should save previous commits before commit-confirm !\n' + else: + W = '' + + prompt_str = f''' +commit-confirm will automatically reboot in {minutes} minutes unless changes +are confirmed.\n +Proceed ?''' + prompt_str = W + prompt_str + if not no_prompt and not ask_yes_no(prompt_str, default=True): + msg = 'commit-confirm canceled' + return msg, 1 + + action = 'sg vyattacfg "/usr/bin/config-mgmt revert"' + cmd = f'sudo systemd-run --quiet --on-active={minutes}m --unit={timer_name} {action}' + rc, out = rc_cmd(cmd) + if rc != 0: + raise ConfigMgmtError(out) + + # start notify + cmd = f'sudo -b /usr/libexec/vyos/commit-confirm-notify.py {minutes}' + os.system(cmd) + + msg = f'Initialized commit-confirm; {minutes} minutes to confirm before reboot' + return msg, 0 + + def confirm(self) -> Tuple[str,int]: + """Do not reboot to saved config following 'commit-confirm'. + Update commit log and archive. + """ + if not is_systemd_service_active(f'{timer_name}.timer'): + msg = 'No confirm pending' + return msg, 0 + + cmd = f'sudo systemctl stop --quiet {timer_name}.timer' + rc, out = rc_cmd(cmd) + if rc != 0: + raise ConfigMgmtError(out) + + # kill notify + cmd = 'sudo pkill -f commit-confirm-notify.py' + rc, out = rc_cmd(cmd) + if rc != 0: + raise ConfigMgmtError(out) + + entry = self._read_tmp_log_entry() + + if self._archive_active_config(): + self._add_log_entry(**entry) + self._update_archive() + + msg = 'Reboot timer stopped' + return msg, 0 + + def revert(self) -> Tuple[str,int]: + """Reboot to saved config, dropping commits from 'commit-confirm'. + """ + _ = self._read_tmp_log_entry() + + # archived config will be reverted on boot + rc, out = rc_cmd('sudo systemctl reboot') + if rc != 0: + raise ConfigMgmtError(out) + + return '', 0 + + def rollback(self, rev: int, no_prompt: bool=False) -> Tuple[str,int]: + """Reboot to config revision 'rev'. + """ + msg = '' + + if not self._check_revision_number(rev): + msg = f'Invalid revision number {rev}: must be 0 < rev < {maxrev}' + return msg, 1 + + prompt_str = 'Proceed with reboot ?' + if not no_prompt and not ask_yes_no(prompt_str, default=True): + msg = 'Canceling rollback' + return msg, 0 + + rc, out = rc_cmd(f'sudo cp {archive_config_file} {prerollback_config}') + if rc != 0: + raise ConfigMgmtError(out) + + path = os.path.join(archive_dir, f'config.boot.{rev}.gz') + with gzip.open(path) as f: + config = f.read() + try: + with open(rollback_config, 'wb') as f: + f.write(config) + copy(rollback_config, config_file) + except OSError as e: + raise ConfigMgmtError from e + + rc, out = rc_cmd('sudo systemctl reboot') + if rc != 0: + raise ConfigMgmtError(out) + + return msg, 0 + + def compare(self, saved: bool=False, commands: bool=False, + rev1: Optional[int]=None, + rev2: Optional[int]=None) -> Tuple[str,int]: + """General compare function for config file revisions: + revision n vs. revision m; working version vs. active version; + or working version vs. saved version. + """ + ct1 = self.active_config + ct2 = self.working_config + msg = 'No changes between working and active configurations.\n' + if saved: + ct1 = self._get_saved_config_tree() + ct2 = self.working_config + msg = 'No changes between working and saved configurations.\n' + if rev1 is not None: + if not self._check_revision_number(rev1): + return f'Invalid revision number {rev1}', 1 + ct1 = self._get_config_tree_revision(rev1) + ct2 = self.working_config + msg = f'No changes between working and revision {rev1} configurations.\n' + if rev2 is not None: + if not self._check_revision_number(rev2): + return f'Invalid revision number {rev2}', 1 + # compare older to newer + ct2 = ct1 + ct1 = self._get_config_tree_revision(rev2) + msg = f'No changes between revisions {rev2} and {rev1} configurations.\n' + + out = '' + path = [] if commands else self.edit_path + try: + if commands: + out = show_diff(ct1, ct2, path=path, commands=True) + else: + out = show_diff(ct1, ct2, path=path) + except ConfigTreeError as e: + return e, 1 + + if out: + msg = out + + return msg, 0 + + def wrap_compare(self, options) -> Tuple[str,int]: + """Interface to vyatta-cfg-run: args collected as 'options' to parse + for compare. + """ + cmnds = False + r1 = None + r2 = None + if 'commands' in options: + cmnds=True + options.remove('commands') + for i in options: + if not i.isnumeric(): + options.remove(i) + if len(options) > 0: + r1 = int(options[0]) + if len(options) > 1: + r2 = int(options[1]) + + return self.compare(commands=cmnds, rev1=r1, rev2=r2) + + # Initialization and post-commit hooks for conf-mode + # + def initialize_revision(self): + """Initialize config archive, logrotate conf, and commit log. + """ + mask = os.umask(0o002) + os.makedirs(archive_dir, exist_ok=True) + + self._add_logrotate_conf() + + if (not os.path.exists(commit_log_file) or + self._get_number_of_revisions() == 0): + user = self._get_user() + via = 'init' + comment = '' + # add empty init config before boot-config load for revision + # and diff consistency + if self._archive_active_config(): + self._add_log_entry(user, via, comment) + self._update_archive() + + os.umask(mask) + + def commit_revision(self): + """Update commit log and rotate archived config.boot. + + commit_revision is called in post-commit-hooks, if + ['commit-archive', 'commit-revisions'] is configured. + """ + if os.getenv('IN_COMMIT_CONFIRM', ''): + self._new_log_entry(tmp_file=tmp_log_entry) + return + + if self._archive_active_config(): + self._add_log_entry() + self._update_archive() + + def commit_archive(self): + """Upload config to remote archive. + """ + from vyos.remote import upload + + hostname = self.hostname + t = datetime.now() + timestamp = t.strftime('%Y%m%d_%H%M%S') + remote_file = f'config.boot-{hostname}.{timestamp}' + source_address = self.source_address + + for location in self.effective_locations: + upload(archive_config_file, f'{location}/{remote_file}', + source_host=source_address) + + # op-mode functions + # + def get_raw_log_data(self) -> list: + """Return list of dicts of log data: + keys: [timestamp, user, commit_via, commit_comment] + """ + log = self._get_log_entries() + res_l = [] + for line in log: + d = self._get_log_entry(line) + res_l.append(d) + + return res_l + + @staticmethod + def format_log_data(data: list) -> str: + """Return formatted log data as str. + """ + res_l = [] + for l_no, l in enumerate(data): + time_d = datetime.fromtimestamp(int(l['timestamp'])) + time_str = time_d.strftime("%Y-%m-%d %H:%M:%S") + + res_l.append([l_no, time_str, + f"by {l['user']}", f"via {l['commit_via']}"]) + + if l['commit_comment'] != 'commit': # default comment + res_l.append([None, l['commit_comment']]) + + ret = tabulate(res_l, tablefmt="plain") + return ret + + @staticmethod + def format_log_data_brief(data: list) -> str: + """Return 'brief' form of log data as str. + + Slightly compacted format used in completion help for + 'rollback'. + """ + res_l = [] + for l_no, l in enumerate(data): + time_d = datetime.fromtimestamp(int(l['timestamp'])) + time_str = time_d.strftime("%Y-%m-%d %H:%M:%S") + + res_l.append(['\t', l_no, time_str, + f"{l['user']}", f"by {l['commit_via']}"]) + + ret = tabulate(res_l, tablefmt="plain") + return ret + + def show_commit_diff(self, rev: int, rev2: Optional[int]=None, + commands: bool=False) -> str: + """Show commit diff at revision number, compared to previous + revision, or to another revision. + """ + if rev2 is None: + out, _ = self.compare(commands=commands, rev1=rev, rev2=(rev+1)) + return out + + out, _ = self.compare(commands=commands, rev1=rev, rev2=rev2) + return out + + def show_commit_file(self, rev: int) -> str: + return self._get_file_revision(rev) + + # utility functions + # + @staticmethod + def _strip_version(s): + return re.split(r'(^//)', s, maxsplit=1, flags=re.MULTILINE)[0] + + def _get_saved_config_tree(self): + with open(config_file) as f: + c = self._strip_version(f.read()) + return ConfigTree(c) + + def _get_file_revision(self, rev: int): + if rev not in range(0, self._get_number_of_revisions()): + raise ConfigMgmtError('revision not available') + revision = os.path.join(archive_dir, f'config.boot.{rev}.gz') + with gzip.open(revision) as f: + r = f.read().decode() + return r + + def _get_config_tree_revision(self, rev: int): + c = self._strip_version(self._get_file_revision(rev)) + return ConfigTree(c) + + def _add_logrotate_conf(self): + conf: str = dedent(f"""\ + {archive_config_file} {{ + su root vyattacfg + rotate {self.max_revisions} + start 0 + compress + copy + }} + """) + conf_file = Path(logrotate_conf) + conf_file.write_text(conf) + conf_file.chmod(0o644) + + def _archive_active_config(self) -> bool: + save_to_tmp = (boot_configuration_complete() or not + os.path.isfile(archive_config_file)) + mask = os.umask(0o113) + + ext = os.getpid() + cmp_saved = f'/tmp/config.boot.{ext}' + if save_to_tmp: + save_config(cmp_saved) + else: + copy(config_file, cmp_saved) + + try: + if cmp(cmp_saved, archive_config_file, shallow=False): + os.unlink(cmp_saved) + os.umask(mask) + return False + except FileNotFoundError: + pass + + rc, out = rc_cmd(f'sudo mv {cmp_saved} {archive_config_file}') + os.umask(mask) + + if rc != 0: + logger.critical(f'mv file to archive failed: {out}') + return False + + return True + + @staticmethod + def _update_archive(): + cmd = f"sudo logrotate -f -s {logrotate_state} {logrotate_conf}" + rc, out = rc_cmd(cmd) + if rc != 0: + logger.critical(f'logrotate failure: {out}') + + @staticmethod + def _get_log_entries() -> list: + """Return lines of commit log as list of strings + """ + entries = [] + if os.path.exists(commit_log_file): + with open(commit_log_file) as f: + entries = f.readlines() + + return entries + + def _get_number_of_revisions(self) -> int: + l = self._get_log_entries() + return len(l) + + def _check_revision_number(self, rev: int) -> bool: + maxrev = self._get_number_of_revisions() + if not 0 <= rev < maxrev: + return False + return True + + @staticmethod + def _get_user() -> str: + import pwd + + try: + user = os.getlogin() + except OSError: + try: + user = pwd.getpwuid(os.geteuid())[0] + except KeyError: + user = 'unknown' + return user + + def _new_log_entry(self, user: str='', commit_via: str='', + commit_comment: str='', timestamp: Optional[int]=None, + tmp_file: str=None) -> Optional[str]: + # Format log entry and return str or write to file. + # + # Usage is within a post-commit hook, using env values. In case of + # commit-confirm, it can be written to a temporary file for + # inclusion on 'confirm'. + from time import time + + if timestamp is None: + timestamp = int(time()) + + if not user: + user = self._get_user() + if not commit_via: + commit_via = os.getenv('COMMIT_VIA', 'other') + if not commit_comment: + commit_comment = os.getenv('COMMIT_COMMENT', 'commit') + + # the commit log reserves '|' as field demarcation, so replace in + # comment if present; undo this in _get_log_entry, below + if re.search(r'\|', commit_comment): + commit_comment = commit_comment.replace('|', '%%') + + entry = f'|{timestamp}|{user}|{commit_via}|{commit_comment}|\n' + + mask = os.umask(0o113) + if tmp_file is not None: + try: + with open(tmp_file, 'w') as f: + f.write(entry) + except OSError as e: + logger.critical(f'write to {tmp_file} failed: {e}') + os.umask(mask) + return None + + os.umask(mask) + return entry + + @staticmethod + def _get_log_entry(line: str) -> dict: + log_fmt = re.compile(r'\|.*\|\n?$') + keys = ['user', 'commit_via', 'commit_comment', 'timestamp'] + if not log_fmt.match(line): + logger.critical(f'Invalid log format {line}') + return {} + + timestamp, user, commit_via, commit_comment = ( + tuple(line.strip().strip('|').split('|'))) + + commit_comment = commit_comment.replace('%%', '|') + d = dict(zip(keys, [user, commit_via, + commit_comment, timestamp])) + + return d + + def _read_tmp_log_entry(self) -> dict: + try: + with open(tmp_log_entry) as f: + entry = f.read() + os.unlink(tmp_log_entry) + except OSError as e: + logger.critical(f'error on file {tmp_log_entry}: {e}') + + return self._get_log_entry(entry) + + def _add_log_entry(self, user: str='', commit_via: str='', + commit_comment: str='', timestamp: Optional[int]=None): + mask = os.umask(0o113) + + entry = self._new_log_entry(user=user, commit_via=commit_via, + commit_comment=commit_comment, + timestamp=timestamp) + + log_entries = self._get_log_entries() + log_entries.insert(0, entry) + if len(log_entries) > self.max_revisions: + log_entries = log_entries[:-1] + + try: + with open(commit_log_file, 'w') as f: + f.writelines(log_entries) + except OSError as e: + logger.critical(e) + + os.umask(mask) + +# entry_point for console script +# +def run(): + from argparse import ArgumentParser, REMAINDER + + config_mgmt = ConfigMgmt() + + for s in list(commit_hooks): + if sys.argv[0].replace('-', '_').endswith(s): + func = getattr(config_mgmt, s) + try: + func() + except Exception as e: + print(f'{s}: {e}') + sys.exit(0) + + parser = ArgumentParser() + subparsers = parser.add_subparsers(dest='subcommand') + + commit_confirm = subparsers.add_parser('commit_confirm', + help="Commit with opt-out reboot to saved config") + commit_confirm.add_argument('-t', dest='minutes', type=int, + default=DEFAULT_TIME_MINUTES, + help="Minutes until reboot, unless 'confirm'") + commit_confirm.add_argument('-y', dest='no_prompt', action='store_true', + help="Execute without prompt") + + subparsers.add_parser('confirm', help="Confirm commit") + subparsers.add_parser('revert', help="Revert commit-confirm") + + rollback = subparsers.add_parser('rollback', + help="Rollback to earlier config") + rollback.add_argument('--rev', type=int, + help="Revision number for rollback") + rollback.add_argument('-y', dest='no_prompt', action='store_true', + help="Excute without prompt") + + compare = subparsers.add_parser('compare', + help="Compare config files") + + compare.add_argument('--saved', action='store_true', + help="Compare session config with saved config") + compare.add_argument('--commands', action='store_true', + help="Show difference between commands") + compare.add_argument('--rev1', type=int, default=None, + help="Compare revision with session config or other revision") + compare.add_argument('--rev2', type=int, default=None, + help="Compare revisions") + + wrap_compare = subparsers.add_parser('wrap_compare', + help="Wrapper interface for vyatta-cfg-run") + wrap_compare.add_argument('--options', nargs=REMAINDER) + + args = vars(parser.parse_args()) + + func = getattr(config_mgmt, args['subcommand']) + del args['subcommand'] + + res = '' + try: + res, rc = func(**args) + except ConfigMgmtError as e: + print(e) + sys.exit(1) + if res: + print(res) + sys.exit(rc) diff --git a/python/vyos/configdep.py b/python/vyos/configdep.py index d4b2cc78f..05d9a3fa3 100644 --- a/python/vyos/configdep.py +++ b/python/vyos/configdep.py @@ -1,4 +1,4 @@ -# Copyright 2022 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -17,8 +17,10 @@ import os import json import typing from inspect import stack +from graphlib import TopologicalSorter, CycleError -from vyos.util import load_as_module +from vyos.utils.system import load_as_module +from vyos.configdict import dict_merge from vyos.defaults import directories from vyos.configsource import VyOSError from vyos import ConfigError @@ -28,6 +30,9 @@ from vyos import ConfigError if typing.TYPE_CHECKING: from vyos.config import Config +dependency_dir = os.path.join(directories['data'], + 'config-mode-dependencies') + dependent_func: dict[str, list[typing.Callable]] = {} def canon_name(name: str) -> str: @@ -40,12 +45,20 @@ def canon_name_of_path(path: str) -> str: def caller_name() -> str: return stack()[-1].filename -def read_dependency_dict() -> dict: - path = os.path.join(directories['data'], - 'config-mode-dependencies.json') - with open(path) as f: - d = json.load(f) - return d +def read_dependency_dict(dependency_dir: str = dependency_dir) -> dict: + res = {} + for dep_file in os.listdir(dependency_dir): + if not dep_file.endswith('.json'): + continue + path = os.path.join(dependency_dir, dep_file) + with open(path) as f: + d = json.load(f) + if dep_file == 'vyos-1x.json': + res = dict_merge(res, d) + else: + res = dict_merge(d, res) + + return res def get_dependency_dict(config: 'Config') -> dict: if hasattr(config, 'cached_dependency_dict'): @@ -93,3 +106,37 @@ def call_dependents(): while l: f = l.pop(0) f() + +def graph_from_dependency_dict(d: dict) -> dict: + g = {} + for k in list(d): + g[k] = set() + # add the dependencies for every sub-case; should there be cases + # that are mutally exclusive in the future, the graphs will be + # distinguished + for el in list(d[k]): + g[k] |= set(d[k][el]) + + return g + +def is_acyclic(d: dict) -> bool: + g = graph_from_dependency_dict(d) + ts = TopologicalSorter(g) + try: + # get node iterator + order = ts.static_order() + # try iteration + _ = [*order] + except CycleError: + return False + + return True + +def check_dependency_graph(dependency_dir: str = dependency_dir, + supplement: str = None) -> bool: + d = read_dependency_dict(dependency_dir=dependency_dir) + if supplement is not None: + with open(supplement) as f: + d = dict_merge(json.load(f), d) + + return is_acyclic(d) diff --git a/python/vyos/configdict.py b/python/vyos/configdict.py index 53decfbf5..71a06b625 100644 --- a/python/vyos/configdict.py +++ b/python/vyos/configdict.py @@ -19,9 +19,8 @@ A library for retrieving value dicts from VyOS configs in a declarative fashion. import os import json -from vyos.util import dict_search -from vyos.xml import defaults -from vyos.util import cmd +from vyos.utils.dict import dict_search +from vyos.utils.process import cmd def retrieve_config(path_hash, base_path, config): """ @@ -177,24 +176,6 @@ def get_removed_vlans(conf, path, dict): return dict -def T2665_set_dhcpv6pd_defaults(config_dict): - """ Properly configure DHCPv6 default options in the dictionary. If there is - no DHCPv6 configured at all, it is safe to remove the entire configuration. - """ - # As this is the same for every interface type it is safe to assume this - # for ethernet - pd_defaults = defaults(['interfaces', 'ethernet', 'dhcpv6-options', 'pd']) - - # Implant default dictionary for DHCPv6-PD instances - if dict_search('dhcpv6_options.pd.length', config_dict): - del config_dict['dhcpv6_options']['pd']['length'] - - for pd in (dict_search('dhcpv6_options.pd', config_dict) or []): - config_dict['dhcpv6_options']['pd'][pd] = dict_merge(pd_defaults, - config_dict['dhcpv6_options']['pd'][pd]) - - return config_dict - def is_member(conf, interface, intftype=None): """ Checks if passed interface is member of other interface of specified type. @@ -263,6 +244,48 @@ def is_mirror_intf(conf, interface, direction=None): return ret_val +def has_address_configured(conf, intf): + """ + Checks if interface has an address configured. + Checks the following config nodes: + 'address', 'ipv6 address eui64', 'ipv6 address autoconf' + + Returns True if interface has address configured, False if it doesn't. + """ + from vyos.ifconfig import Section + ret = False + + old_level = conf.get_level() + conf.set_level([]) + + intfpath = 'interfaces ' + Section.get_config_path(intf) + if ( conf.exists(f'{intfpath} address') or + conf.exists(f'{intfpath} ipv6 address autoconf') or + conf.exists(f'{intfpath} ipv6 address eui64') ): + ret = True + + conf.set_level(old_level) + return ret + +def has_vrf_configured(conf, intf): + """ + Checks if interface has a VRF configured. + + Returns True if interface has VRF configured, False if it doesn't. + """ + from vyos.ifconfig import Section + ret = False + + old_level = conf.get_level() + conf.set_level([]) + + tmp = ['interfaces', Section.get_config_path(intf), 'vrf'] + if conf.exists(tmp): + ret = True + + conf.set_level(old_level) + return ret + def has_vlan_subinterface_configured(conf, intf): """ Checks if interface has an VLAN subinterface configured. @@ -333,8 +356,9 @@ def get_dhcp_interfaces(conf, vrf=None): if dict_search('dhcp_options.default_route_distance', config) != None: options.update({'dhcp_options' : config['dhcp_options']}) if 'vrf' in config: - if vrf is config['vrf']: tmp.update({ifname : options}) - else: tmp.update({ifname : options}) + if vrf == config['vrf']: tmp.update({ifname : options}) + else: + if vrf is None: tmp.update({ifname : options}) return tmp @@ -382,12 +406,13 @@ def get_pppoe_interfaces(conf, vrf=None): if 'no_default_route' in ifconfig: options.update({'no_default_route' : {}}) if 'vrf' in ifconfig: - if vrf is ifconfig['vrf']: pppoe_interfaces.update({ifname : options}) - else: pppoe_interfaces.update({ifname : options}) + if vrf == ifconfig['vrf']: pppoe_interfaces.update({ifname : options}) + else: + if vrf is None: pppoe_interfaces.update({ifname : options}) return pppoe_interfaces -def get_interface_dict(config, base, ifname=''): +def get_interface_dict(config, base, ifname='', recursive_defaults=True): """ Common utility function to retrieve and mangle the interfaces configuration from the CLI input nodes. All interfaces have a common base where value @@ -403,42 +428,23 @@ def get_interface_dict(config, base, ifname=''): raise ConfigError('Interface (VYOS_TAGNODE_VALUE) not specified') ifname = os.environ['VYOS_TAGNODE_VALUE'] - # retrieve interface default values - default_values = defaults(base) - - # We take care about VLAN (vif, vif-s, vif-c) default values later on when - # parsing vlans in default dict and merge the "proper" values in correctly, - # see T2665. - for vif in ['vif', 'vif_s']: - if vif in default_values: del default_values[vif] - - dict = config.get_config_dict(base + [ifname], key_mangling=('-', '_'), - get_first_key=True, - no_tag_node_value_mangle=True) - # Check if interface has been removed. We must use exists() as # get_config_dict() will always return {} - even when an empty interface # node like the following exists. # +macsec macsec1 { # +} if not config.exists(base + [ifname]): + dict = config.get_config_dict(base + [ifname], key_mangling=('-', '_'), + get_first_key=True, + no_tag_node_value_mangle=True) dict.update({'deleted' : {}}) - - # Add interface instance name into dictionary - dict.update({'ifname': ifname}) - - # XXX: T2665: When there is no DHCPv6-PD configuration given, we can safely - # remove the default values from the dict. - if 'dhcpv6_options' not in dict: - if 'dhcpv6_options' in default_values: - del default_values['dhcpv6_options'] - - # We have gathered the dict representation of the CLI, but there are - # default options which we need to update into the dictionary retrived. - # But we should only add them when interface is not deleted - as this might - # confuse parsers - if 'deleted' not in dict: - dict = dict_merge(default_values, dict) + else: + # Get config_dict with default values + dict = config.get_config_dict(base + [ifname], key_mangling=('-', '_'), + get_first_key=True, + no_tag_node_value_mangle=True, + with_defaults=True, + with_recursive_defaults=recursive_defaults) # If interface does not request an IPv4 DHCP address there is no need # to keep the dhcp-options key @@ -446,8 +452,12 @@ def get_interface_dict(config, base, ifname=''): if 'dhcp_options' in dict: del dict['dhcp_options'] - # XXX: T2665: blend in proper DHCPv6-PD default values - dict = T2665_set_dhcpv6pd_defaults(dict) + # Add interface instance name into dictionary + dict.update({'ifname': ifname}) + + # Check if QoS policy applied on this interface - See ifconfig.interface.set_mirror_redirect() + if config.exists(['qos', 'interface', ifname]): + dict.update({'traffic_policy': {}}) address = leaf_node_changed(config, base + [ifname, 'address']) if address: dict.update({'address_old' : address}) @@ -468,6 +478,10 @@ def get_interface_dict(config, base, ifname=''): dhcp = is_node_changed(config, base + [ifname, 'dhcp-options']) if dhcp: dict.update({'dhcp_options_changed' : {}}) + # Changine interface VRF assignemnts require a DHCP restart, too + dhcp = is_node_changed(config, base + [ifname, 'vrf']) + if dhcp: dict.update({'dhcp_options_changed' : {}}) + # Some interfaces come with a source_interface which must also not be part # of any other bond or bridge interface as it is exclusivly assigned as the # Kernels "lower" interface to this new "virtual/upper" interface. @@ -491,29 +505,17 @@ def get_interface_dict(config, base, ifname=''): else: dict['ipv6']['address'].update({'eui64_old': eui64}) - # Implant default dictionary in vif/vif-s VLAN interfaces. Values are - # identical for all types of VLAN interfaces as they all include the same - # XML definitions which hold the defaults. for vif, vif_config in dict.get('vif', {}).items(): # Add subinterface name to dictionary dict['vif'][vif].update({'ifname' : f'{ifname}.{vif}'}) - default_vif_values = defaults(base + ['vif']) - # XXX: T2665: When there is no DHCPv6-PD configuration given, we can safely - # remove the default values from the dict. - if not 'dhcpv6_options' in vif_config: - del default_vif_values['dhcpv6_options'] + if config.exists(['qos', 'interface', f'{ifname}.{vif}']): + dict['vif'][vif].update({'traffic_policy': {}}) - # Only add defaults if interface is not about to be deleted - this is - # to keep a cleaner config dict. if 'deleted' not in dict: address = leaf_node_changed(config, base + [ifname, 'vif', vif, 'address']) if address: dict['vif'][vif].update({'address_old' : address}) - dict['vif'][vif] = dict_merge(default_vif_values, dict['vif'][vif]) - # XXX: T2665: blend in proper DHCPv6-PD default values - dict['vif'][vif] = T2665_set_dhcpv6pd_defaults(dict['vif'][vif]) - # If interface does not request an IPv4 DHCP address there is no need # to keep the dhcp-options key if 'address' not in dict['vif'][vif] or 'dhcp' not in dict['vif'][vif]['address']: @@ -532,26 +534,13 @@ def get_interface_dict(config, base, ifname=''): # Add subinterface name to dictionary dict['vif_s'][vif_s].update({'ifname' : f'{ifname}.{vif_s}'}) - default_vif_s_values = defaults(base + ['vif-s']) - # XXX: T2665: we only wan't the vif-s defaults - do not care about vif-c - if 'vif_c' in default_vif_s_values: del default_vif_s_values['vif_c'] - - # XXX: T2665: When there is no DHCPv6-PD configuration given, we can safely - # remove the default values from the dict. - if not 'dhcpv6_options' in vif_s_config: - del default_vif_s_values['dhcpv6_options'] + if config.exists(['qos', 'interface', f'{ifname}.{vif_s}']): + dict['vif_s'][vif_s].update({'traffic_policy': {}}) - # Only add defaults if interface is not about to be deleted - this is - # to keep a cleaner config dict. if 'deleted' not in dict: address = leaf_node_changed(config, base + [ifname, 'vif-s', vif_s, 'address']) if address: dict['vif_s'][vif_s].update({'address_old' : address}) - dict['vif_s'][vif_s] = dict_merge(default_vif_s_values, - dict['vif_s'][vif_s]) - # XXX: T2665: blend in proper DHCPv6-PD default values - dict['vif_s'][vif_s] = T2665_set_dhcpv6pd_defaults(dict['vif_s'][vif_s]) - # If interface does not request an IPv4 DHCP address there is no need # to keep the dhcp-options key if 'address' not in dict['vif_s'][vif_s] or 'dhcp' not in \ @@ -571,26 +560,14 @@ def get_interface_dict(config, base, ifname=''): # Add subinterface name to dictionary dict['vif_s'][vif_s]['vif_c'][vif_c].update({'ifname' : f'{ifname}.{vif_s}.{vif_c}'}) - default_vif_c_values = defaults(base + ['vif-s', 'vif-c']) - - # XXX: T2665: When there is no DHCPv6-PD configuration given, we can safely - # remove the default values from the dict. - if not 'dhcpv6_options' in vif_c_config: - del default_vif_c_values['dhcpv6_options'] + if config.exists(['qos', 'interface', f'{ifname}.{vif_s}.{vif_c}']): + dict['vif_s'][vif_s]['vif_c'][vif_c].update({'traffic_policy': {}}) - # Only add defaults if interface is not about to be deleted - this is - # to keep a cleaner config dict. if 'deleted' not in dict: address = leaf_node_changed(config, base + [ifname, 'vif-s', vif_s, 'vif-c', vif_c, 'address']) if address: dict['vif_s'][vif_s]['vif_c'][vif_c].update( {'address_old' : address}) - dict['vif_s'][vif_s]['vif_c'][vif_c] = dict_merge( - default_vif_c_values, dict['vif_s'][vif_s]['vif_c'][vif_c]) - # XXX: T2665: blend in proper DHCPv6-PD default values - dict['vif_s'][vif_s]['vif_c'][vif_c] = T2665_set_dhcpv6pd_defaults( - dict['vif_s'][vif_s]['vif_c'][vif_c]) - # If interface does not request an IPv4 DHCP address there is no need # to keep the dhcp-options key if 'address' not in dict['vif_s'][vif_s]['vif_c'][vif_c] or 'dhcp' \ @@ -640,45 +617,13 @@ def get_accel_dict(config, base, chap_secrets): Return a dictionary with the necessary interface config keys. """ - from vyos.util import get_half_cpus + from vyos.utils.system import get_half_cpus from vyos.template import is_ipv4 dict = config.get_config_dict(base, key_mangling=('-', '_'), get_first_key=True, - no_tag_node_value_mangle=True) - - # We have gathered the dict representation of the CLI, but there are default - # options which we need to update into the dictionary retrived. - default_values = defaults(base) - - # T2665: defaults include RADIUS server specifics per TAG node which need to - # be added to individual RADIUS servers instead - so we can simply delete them - if dict_search('authentication.radius.server', default_values): - del default_values['authentication']['radius']['server'] - - # T2665: defaults include static-ip address per TAG node which need to be - # added to individual local users instead - so we can simply delete them - if dict_search('authentication.local_users.username', default_values): - del default_values['authentication']['local_users']['username'] - - # T2665: defaults include IPv6 client-pool mask per TAG node which need to be - # added to individual local users instead - so we can simply delete them - if dict_search('client_ipv6_pool.prefix.mask', default_values): - del default_values['client_ipv6_pool']['prefix']['mask'] - # delete empty dicts - if len (default_values['client_ipv6_pool']['prefix']) == 0: - del default_values['client_ipv6_pool']['prefix'] - if len (default_values['client_ipv6_pool']) == 0: - del default_values['client_ipv6_pool'] - - # T2665: IPoE only - it has an interface tag node - # added to individual local users instead - so we can simply delete them - if dict_search('authentication.interface', default_values): - del default_values['authentication']['interface'] - if dict_search('interface', default_values): - del default_values['interface'] - - dict = dict_merge(default_values, dict) + no_tag_node_value_mangle=True, + with_recursive_defaults=True) # set CPUs cores to process requests dict.update({'thread_count' : get_half_cpus()}) @@ -698,43 +643,9 @@ def get_accel_dict(config, base, chap_secrets): dict.update({'name_server_ipv4' : ns_v4, 'name_server_ipv6' : ns_v6}) del dict['name_server'] - # T2665: Add individual RADIUS server default values - if dict_search('authentication.radius.server', dict): - default_values = defaults(base + ['authentication', 'radius', 'server']) - for server in dict_search('authentication.radius.server', dict): - dict['authentication']['radius']['server'][server] = dict_merge( - default_values, dict['authentication']['radius']['server'][server]) - - # Check option "disable-accounting" per server and replace default value from '1813' to '0' - # set vpn sstp authentication radius server x.x.x.x disable-accounting - if 'disable_accounting' in dict['authentication']['radius']['server'][server]: - dict['authentication']['radius']['server'][server]['acct_port'] = '0' - - # T2665: Add individual local-user default values - if dict_search('authentication.local_users.username', dict): - default_values = defaults(base + ['authentication', 'local-users', 'username']) - for username in dict_search('authentication.local_users.username', dict): - dict['authentication']['local_users']['username'][username] = dict_merge( - default_values, dict['authentication']['local_users']['username'][username]) - - # T2665: Add individual IPv6 client-pool default mask if required - if dict_search('client_ipv6_pool.prefix', dict): - default_values = defaults(base + ['client-ipv6-pool', 'prefix']) - for prefix in dict_search('client_ipv6_pool.prefix', dict): - dict['client_ipv6_pool']['prefix'][prefix] = dict_merge( - default_values, dict['client_ipv6_pool']['prefix'][prefix]) - - # T2665: IPoE only - add individual local-user default values - if dict_search('authentication.interface', dict): - default_values = defaults(base + ['authentication', 'interface']) - for interface in dict_search('authentication.interface', dict): - dict['authentication']['interface'][interface] = dict_merge( - default_values, dict['authentication']['interface'][interface]) - - if dict_search('interface', dict): - default_values = defaults(base + ['interface']) - for interface in dict_search('interface', dict): - dict['interface'][interface] = dict_merge(default_values, - dict['interface'][interface]) + # Check option "disable-accounting" per server and replace default value from '1813' to '0' + for server in (dict_search('authentication.radius.server', dict) or []): + if 'disable_accounting' in dict['authentication']['radius']['server'][server]: + dict['authentication']['radius']['server'][server]['acct_port'] = '0' return dict diff --git a/python/vyos/configdiff.py b/python/vyos/configdiff.py index 9185575df..1ec2dfafe 100644 --- a/python/vyos/configdiff.py +++ b/python/vyos/configdiff.py @@ -19,9 +19,10 @@ from vyos.config import Config from vyos.configtree import DiffTree from vyos.configdict import dict_merge from vyos.configdict import list_diff -from vyos.util import get_sub_dict, mangle_dict_keys -from vyos.util import dict_search_args -from vyos.xml import defaults +from vyos.utils.dict import get_sub_dict +from vyos.utils.dict import mangle_dict_keys +from vyos.utils.dict import dict_search_args +from vyos.xml_ref import get_defaults class ConfigDiffError(Exception): """ @@ -78,23 +79,34 @@ def get_config_diff(config, key_mangling=None): isinstance(key_mangling[1], str)): raise ValueError("key_mangling must be a tuple of two strings") - diff_t = DiffTree(config._running_config, config._session_config) + if hasattr(config, 'cached_diff_tree'): + diff_t = getattr(config, 'cached_diff_tree') + else: + diff_t = DiffTree(config._running_config, config._session_config) + setattr(config, 'cached_diff_tree', diff_t) - return ConfigDiff(config, key_mangling, diff_tree=diff_t) + if hasattr(config, 'cached_diff_dict'): + diff_d = getattr(config, 'cached_diff_dict') + else: + diff_d = diff_t.dict + setattr(config, 'cached_diff_dict', diff_d) + + return ConfigDiff(config, key_mangling, diff_tree=diff_t, + diff_dict=diff_d) class ConfigDiff(object): """ The class of config changes as represented by comparison between the session config dict and the effective config dict. """ - def __init__(self, config, key_mangling=None, diff_tree=None): + def __init__(self, config, key_mangling=None, diff_tree=None, diff_dict=None): self._level = config.get_level() self._session_config_dict = config.get_cached_root_dict(effective=False) self._effective_config_dict = config.get_cached_root_dict(effective=True) self._key_mangling = key_mangling self._diff_tree = diff_tree - self._diff_dict = diff_tree.dict if diff_tree else {} + self._diff_dict = diff_dict # mirrored from Config; allow path arguments relative to level def _make_path(self, path): @@ -209,9 +221,9 @@ class ConfigDiff(object): if self._diff_tree is None: raise NotImplementedError("diff_tree class not available") else: - add = get_sub_dict(self._diff_tree.dict, ['add'], get_first_key=True) - sub = get_sub_dict(self._diff_tree.dict, ['sub'], get_first_key=True) - inter = get_sub_dict(self._diff_tree.dict, ['inter'], get_first_key=True) + add = get_sub_dict(self._diff_dict, ['add'], get_first_key=True) + sub = get_sub_dict(self._diff_dict, ['sub'], get_first_key=True) + inter = get_sub_dict(self._diff_dict, ['inter'], get_first_key=True) ret = {} ret[enum_to_key(Diff.MERGE)] = session_dict ret[enum_to_key(Diff.DELETE)] = get_sub_dict(sub, self._make_path(path), @@ -228,7 +240,9 @@ class ConfigDiff(object): if self._key_mangling: ret[k] = self._mangle_dict_keys(ret[k]) if k in target_defaults and not no_defaults: - default_values = defaults(self._make_path(path)) + default_values = get_defaults(self._make_path(path), + get_first_key=True, + recursive=True) ret[k] = dict_merge(default_values, ret[k]) return ret @@ -252,7 +266,9 @@ class ConfigDiff(object): ret[k] = self._mangle_dict_keys(ret[k]) if k in target_defaults and not no_defaults: - default_values = defaults(self._make_path(path)) + default_values = get_defaults(self._make_path(path), + get_first_key=True, + recursive=True) ret[k] = dict_merge(default_values, ret[k]) return ret @@ -284,9 +300,9 @@ class ConfigDiff(object): if self._diff_tree is None: raise NotImplementedError("diff_tree class not available") else: - add = get_sub_dict(self._diff_tree.dict, ['add'], get_first_key=True) - sub = get_sub_dict(self._diff_tree.dict, ['sub'], get_first_key=True) - inter = get_sub_dict(self._diff_tree.dict, ['inter'], get_first_key=True) + add = get_sub_dict(self._diff_dict, ['add'], get_first_key=True) + sub = get_sub_dict(self._diff_dict, ['sub'], get_first_key=True) + inter = get_sub_dict(self._diff_dict, ['inter'], get_first_key=True) ret = {} ret[enum_to_key(Diff.MERGE)] = session_dict ret[enum_to_key(Diff.DELETE)] = get_sub_dict(sub, self._make_path(path)) @@ -300,7 +316,9 @@ class ConfigDiff(object): if self._key_mangling: ret[k] = self._mangle_dict_keys(ret[k]) if k in target_defaults and not no_defaults: - default_values = defaults(self._make_path(path)) + default_values = get_defaults(self._make_path(path), + get_first_key=True, + recursive=True) ret[k] = dict_merge(default_values, ret[k]) return ret @@ -323,7 +341,9 @@ class ConfigDiff(object): ret[k] = self._mangle_dict_keys(ret[k]) if k in target_defaults and not no_defaults: - default_values = defaults(self._make_path(path)) + default_values = get_defaults(self._make_path(path), + get_first_key=True, + recursive=True) ret[k] = dict_merge(default_values, ret[k]) return ret diff --git a/python/vyos/configquery.py b/python/vyos/configquery.py index 5b097b312..71ad5b4f0 100644 --- a/python/vyos/configquery.py +++ b/python/vyos/configquery.py @@ -1,4 +1,4 @@ -# Copyright 2021 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2021-2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -19,9 +19,11 @@ settings from op mode, and execution of arbitrary op mode commands. ''' import os -from subprocess import STDOUT -from vyos.util import popen, boot_configuration_complete +from vyos.utils.process import STDOUT +from vyos.utils.process import popen + +from vyos.utils.boot import boot_configuration_complete from vyos.config import Config from vyos.configsource import ConfigSourceSession, ConfigSourceString from vyos.defaults import directories @@ -88,7 +90,7 @@ class ConfigTreeQuery(GenericConfigQuery): with open(config_file) as f: config_string = f.read() except OSError as err: - raise ConfigQueryError('No config file available') from err + config_string = '' config_source = ConfigSourceString(running_config_text=config_string, session_config_text=config_string) diff --git a/python/vyos/configsession.py b/python/vyos/configsession.py index 3a60f6d92..6d4b2af59 100644 --- a/python/vyos/configsession.py +++ b/python/vyos/configsession.py @@ -1,5 +1,5 @@ # configsession -- the write API for the VyOS running config -# Copyright (C) 2019 VyOS maintainers and contributors +# Copyright (C) 2019-2023 VyOS maintainers and contributors # # This library is free software; you can redistribute it and/or modify it under the terms of # the GNU Lesser General Public License as published by the Free Software Foundation; @@ -17,7 +17,8 @@ import re import sys import subprocess -from vyos.util import is_systemd_service_running +from vyos.utils.process import is_systemd_service_running +from vyos.utils.dict import dict_to_paths CLI_SHELL_API = '/bin/cli-shell-api' SET = '/opt/vyatta/sbin/my_set' @@ -28,12 +29,14 @@ DISCARD = '/opt/vyatta/sbin/my_discard' SHOW_CONFIG = ['/bin/cli-shell-api', 'showConfig'] LOAD_CONFIG = ['/bin/cli-shell-api', 'loadFile'] MIGRATE_LOAD_CONFIG = ['/usr/libexec/vyos/vyos-load-config.py'] -SAVE_CONFIG = ['/opt/vyatta/sbin/vyatta-save-config.pl'] +SAVE_CONFIG = ['/usr/libexec/vyos/vyos-save-config.py'] INSTALL_IMAGE = ['/opt/vyatta/sbin/install-image', '--url'] REMOVE_IMAGE = ['/opt/vyatta/bin/vyatta-boot-image.pl', '--del'] GENERATE = ['/opt/vyatta/bin/vyatta-op-cmd-wrapper', 'generate'] SHOW = ['/opt/vyatta/bin/vyatta-op-cmd-wrapper', 'show'] RESET = ['/opt/vyatta/bin/vyatta-op-cmd-wrapper', 'reset'] +OP_CMD_ADD = ['/opt/vyatta/bin/vyatta-op-cmd-wrapper', 'add'] +OP_CMD_DELETE = ['/opt/vyatta/bin/vyatta-op-cmd-wrapper', 'delete'] # Default "commit via" string APP = "vyos-http-api" @@ -146,6 +149,13 @@ class ConfigSession(object): value = [value] self.__run_command([SET] + path + value) + def set_section(self, path: list, d: dict): + try: + for p in dict_to_paths(d): + self.set(path + p) + except (ValueError, ConfigSessionError) as e: + raise ConfigSessionError(e) + def delete(self, path, value=None): if not value: value = [] @@ -153,6 +163,15 @@ class ConfigSession(object): value = [value] self.__run_command([DELETE] + path + value) + def load_section(self, path: list, d: dict): + try: + self.delete(path) + if d: + for p in dict_to_paths(d): + self.set(path + p) + except (ValueError, ConfigSessionError) as e: + raise ConfigSessionError(e) + def comment(self, path, value=None): if not value: value = [""] @@ -204,3 +223,15 @@ class ConfigSession(object): def reset(self, path): out = self.__run_command(RESET + path) return out + + def add_container_image(self, name): + out = self.__run_command(OP_CMD_ADD + ['container', 'image'] + [name]) + return out + + def delete_container_image(self, name): + out = self.__run_command(OP_CMD_DELETE + ['container', 'image'] + [name]) + return out + + def show_container_image(self): + out = self.__run_command(SHOW + ['container', 'image']) + return out diff --git a/python/vyos/configsource.py b/python/vyos/configsource.py index 510b5b65a..f582bdfab 100644 --- a/python/vyos/configsource.py +++ b/python/vyos/configsource.py @@ -1,5 +1,5 @@ -# Copyright 2020 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2020-2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -19,7 +19,7 @@ import re import subprocess from vyos.configtree import ConfigTree -from vyos.util import boot_configuration_complete +from vyos.utils.boot import boot_configuration_complete class VyOSError(Exception): """ diff --git a/python/vyos/configtree.py b/python/vyos/configtree.py index b88615513..09cfd43d3 100644 --- a/python/vyos/configtree.py +++ b/python/vyos/configtree.py @@ -16,7 +16,7 @@ import os import re import json -from ctypes import cdll, c_char_p, c_void_p, c_int +from ctypes import cdll, c_char_p, c_void_p, c_int, c_bool LIBPATH = '/usr/lib/libvyosconfig.so.0' @@ -60,7 +60,7 @@ class ConfigTree(object): self.__get_error.restype = c_char_p self.__to_string = self.__lib.to_string - self.__to_string.argtypes = [c_void_p] + self.__to_string.argtypes = [c_void_p, c_bool] self.__to_string.restype = c_char_p self.__to_commands = self.__lib.to_commands @@ -160,8 +160,8 @@ class ConfigTree(object): def _get_config(self): return self.__config - def to_string(self): - config_string = self.__to_string(self.__config).decode() + def to_string(self, ordered_values=False): + config_string = self.__to_string(self.__config, ordered_values).decode() config_string = "{0}\n{1}".format(config_string, self.__version) return config_string @@ -201,7 +201,9 @@ class ConfigTree(object): check_path(path) path_str = " ".join(map(str, path)).encode() - self.__delete(self.__config, path_str) + res = self.__delete(self.__config, path_str) + if (res != 0): + raise ConfigTreeError(f"Path doesn't exist: {path}") if self.__migration: print(f"- op: delete path: {path}") @@ -210,7 +212,14 @@ class ConfigTree(object): check_path(path) path_str = " ".join(map(str, path)).encode() - self.__delete_value(self.__config, path_str, value.encode()) + res = self.__delete_value(self.__config, path_str, value.encode()) + if (res != 0): + if res == 1: + raise ConfigTreeError(f"Path doesn't exist: {path}") + elif res == 2: + raise ConfigTreeError(f"Value doesn't exist: '{value}'") + else: + raise ConfigTreeError() if self.__migration: print(f"- op: delete_value path: {path} value: {value}") @@ -242,7 +251,8 @@ class ConfigTree(object): raise ConfigTreeError() res = self.__copy(self.__config, oldpath_str, newpath_str) if (res != 0): - raise ConfigTreeError("Path [{}] doesn't exist".format(old_path)) + msg = self.__get_error().decode() + raise ConfigTreeError(msg) if self.__migration: print(f"- op: copy old_path: {old_path} new_path: {new_path}") @@ -321,6 +331,72 @@ class ConfigTree(object): subt = ConfigTree(address=res) return subt +def show_diff(left, right, path=[], commands=False, libpath=LIBPATH): + if left is None: + left = ConfigTree(config_string='\n') + if right is None: + right = ConfigTree(config_string='\n') + if not (isinstance(left, ConfigTree) and isinstance(right, ConfigTree)): + raise TypeError("Arguments must be instances of ConfigTree") + if path: + if (not left.exists(path)) and (not right.exists(path)): + raise ConfigTreeError(f"Path {path} doesn't exist") + + check_path(path) + path_str = " ".join(map(str, path)).encode() + + __lib = cdll.LoadLibrary(libpath) + __show_diff = __lib.show_diff + __show_diff.argtypes = [c_bool, c_char_p, c_void_p, c_void_p] + __show_diff.restype = c_char_p + __get_error = __lib.get_error + __get_error.argtypes = [] + __get_error.restype = c_char_p + + res = __show_diff(commands, path_str, left._get_config(), right._get_config()) + res = res.decode() + if res == "#1@": + msg = __get_error().decode() + raise ConfigTreeError(msg) + + return res + +def union(left, right, libpath=LIBPATH): + if left is None: + left = ConfigTree(config_string='\n') + if right is None: + right = ConfigTree(config_string='\n') + if not (isinstance(left, ConfigTree) and isinstance(right, ConfigTree)): + raise TypeError("Arguments must be instances of ConfigTree") + + __lib = cdll.LoadLibrary(libpath) + __tree_union = __lib.tree_union + __tree_union.argtypes = [c_void_p, c_void_p] + __tree_union.restype = c_void_p + __get_error = __lib.get_error + __get_error.argtypes = [] + __get_error.restype = c_char_p + + res = __tree_union( left._get_config(), right._get_config()) + tree = ConfigTree(address=res) + + return tree + +def reference_tree_to_json(from_dir, to_file, libpath=LIBPATH): + try: + __lib = cdll.LoadLibrary(libpath) + __reference_tree_to_json = __lib.reference_tree_to_json + __reference_tree_to_json.argtypes = [c_char_p, c_char_p] + __get_error = __lib.get_error + __get_error.argtypes = [] + __get_error.restype = c_char_p + res = __reference_tree_to_json(from_dir.encode(), to_file.encode()) + except Exception as e: + raise ConfigTreeError(e) + if res == 1: + msg = __get_error().decode() + raise ConfigTreeError(msg) + class DiffTree: def __init__(self, left, right, path=[], libpath=LIBPATH): if left is None: @@ -344,10 +420,6 @@ class DiffTree: self.__diff_tree.argtypes = [c_char_p, c_void_p, c_void_p] self.__diff_tree.restype = c_void_p - self.__trim_tree = self.__lib.trim_tree - self.__trim_tree.argtypes = [c_void_p, c_void_p] - self.__trim_tree.restype = c_void_p - check_path(path) path_str = " ".join(map(str, path)).encode() @@ -361,11 +433,7 @@ class DiffTree: self.add = self.full.get_subtree(['add']) self.sub = self.full.get_subtree(['sub']) self.inter = self.full.get_subtree(['inter']) - - # trim sub(-tract) tree to get delete tree for commands - ref = self.right.get_subtree(path, with_node=True) if path else self.right - res = self.__trim_tree(self.sub._get_config(), ref._get_config()) - self.delete = ConfigTree(address=res) + self.delete = self.full.get_subtree(['del']) def to_commands(self): add = self.add.to_commands() diff --git a/python/vyos/configverify.py b/python/vyos/configverify.py index 8e0ce701e..52f9238b8 100644 --- a/python/vyos/configverify.py +++ b/python/vyos/configverify.py @@ -1,4 +1,4 @@ -# Copyright 2020-2022 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2020-2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -22,7 +22,8 @@ # makes use of it! from vyos import ConfigError -from vyos.util import dict_search +from vyos.utils.dict import dict_search +from vyos.utils.dict import dict_search_recursive def verify_mtu(config): """ @@ -35,8 +36,14 @@ def verify_mtu(config): mtu = int(config['mtu']) tmp = Interface(config['ifname']) - min_mtu = tmp.get_min_mtu() - max_mtu = tmp.get_max_mtu() + # Not all interfaces support min/max MTU + # https://vyos.dev/T5011 + try: + min_mtu = tmp.get_min_mtu() + max_mtu = tmp.get_max_mtu() + except: # Fallback to defaults + min_mtu = 68 + max_mtu = 9000 if mtu < min_mtu: raise ConfigError(f'Interface MTU too low, ' \ @@ -180,15 +187,14 @@ def verify_eapol(config): if 'ca' not in config['pki']: raise ConfigError('Invalid CA certificate specified for EAPoL') - ca_cert_name = config['eapol']['ca_certificate'] + for ca_cert_name in config['eapol']['ca_certificate']: + if ca_cert_name not in config['pki']['ca']: + raise ConfigError('Invalid CA certificate specified for EAPoL') - if ca_cert_name not in config['pki']['ca']: - raise ConfigError('Invalid CA certificate specified for EAPoL') - - ca_cert = config['pki']['ca'][ca_cert_name] + ca_cert = config['pki']['ca'][ca_cert_name] - if 'certificate' not in ca_cert: - raise ConfigError('Invalid CA certificate specified for EAPoL') + if 'certificate' not in ca_cert: + raise ConfigError('Invalid CA certificate specified for EAPoL') def verify_mirror_redirect(config): """ @@ -232,7 +238,7 @@ def verify_authentication(config): """ if 'authentication' not in config: return - if not {'user', 'password'} <= set(config['authentication']): + if not {'username', 'password'} <= set(config['authentication']): raise ConfigError('Authentication requires both username and ' \ 'password to be set!') @@ -307,15 +313,13 @@ def verify_dhcpv6(config): recurring validation of DHCPv6 options which are mutually exclusive. """ if 'dhcpv6_options' in config: - from vyos.util import dict_search - if {'parameters_only', 'temporary'} <= set(config['dhcpv6_options']): raise ConfigError('DHCPv6 temporary and parameters-only options ' 'are mutually exclusive!') # It is not allowed to have duplicate SLA-IDs as those identify an # assigned IPv6 subnet from a delegated prefix - for pd in dict_search('dhcpv6_options.pd', config): + for pd in (dict_search('dhcpv6_options.pd', config) or []): sla_ids = [] interfaces = dict_search(f'dhcpv6_options.pd.{pd}.interface', config) @@ -414,7 +418,18 @@ def verify_accel_ppp_base_service(config, local_users=True): if 'key' not in radius_config: raise ConfigError(f'Missing RADIUS secret key for server "{server}"') - if 'gateway_address' not in config: + # Check global gateway or gateway in named pool + gateway = False + if 'gateway_address' in config: + gateway = True + else: + if 'client_ip_pool' in config: + if dict_search_recursive(config, 'gateway_address', ['client_ip_pool', 'name']): + for _, v in config['client_ip_pool']['name'].items(): + if 'gateway_address' in v: + gateway = True + break + if not gateway: raise ConfigError('Server requires gateway-address to be configured!') if 'name_server_ipv4' in config: @@ -442,7 +457,7 @@ def verify_diffie_hellman_length(file, min_keysize): then or equal to min_keysize """ import os import re - from vyos.util import cmd + from vyos.utils.process import cmd try: keysize = str(min_keysize) diff --git a/python/vyos/cpu.py b/python/vyos/cpu.py index 488ae79fb..d2e5f6504 100644 --- a/python/vyos/cpu.py +++ b/python/vyos/cpu.py @@ -73,7 +73,7 @@ def _find_physical_cpus(): # On other architectures, e.g. on ARM, there's no such field. # We just assume they are different CPUs, # whether single core ones or cores of physical CPUs. - phys_cpus[num] = cpu[num] + phys_cpus[num] = cpus[num] return phys_cpus diff --git a/python/vyos/defaults.py b/python/vyos/defaults.py index 7de458960..a5314790d 100644 --- a/python/vyos/defaults.py +++ b/python/vyos/defaults.py @@ -1,4 +1,4 @@ -# Copyright 2018 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2018-2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -15,19 +15,26 @@ import os +base_dir = '/usr/libexec/vyos/' + directories = { - "data": "/usr/share/vyos/", - "conf_mode": "/usr/libexec/vyos/conf_mode", - "op_mode": "/usr/libexec/vyos/op_mode", - "config": "/opt/vyatta/etc/config", - "current": "/opt/vyatta/etc/config-migrate/current", - "migrate": "/opt/vyatta/etc/config-migrate/migrate", - "log": "/var/log/vyatta", - "templates": "/usr/share/vyos/templates/", - "certbot": "/config/auth/letsencrypt", - "api_schema": "/usr/libexec/vyos/services/api/graphql/graphql/schema/", - "api_templates": "/usr/libexec/vyos/services/api/graphql/session/templates/", - "vyos_udev_dir": "/run/udev/vyos" + 'base' : base_dir, + 'data' : '/usr/share/vyos/', + 'conf_mode' : f'{base_dir}/conf_mode', + 'op_mode' : f'{base_dir}/op_mode', + 'services' : f'{base_dir}/services', + 'config' : '/opt/vyatta/etc/config', + 'current' : '/opt/vyatta/etc/config-migrate/current', + 'migrate' : '/opt/vyatta/etc/config-migrate/migrate', + 'log' : '/var/log/vyatta', + 'templates' : '/usr/share/vyos/templates/', + 'certbot' : '/config/auth/letsencrypt', + 'api_schema': f'{base_dir}/services/api/graphql/graphql/schema/', + 'api_client_op': f'{base_dir}/services/api/graphql/graphql/client_op/', + 'api_templates': f'{base_dir}/services/api/graphql/session/templates/', + 'vyos_udev_dir' : '/run/udev/vyos', + 'isc_dhclient_dir' : '/run/dhclient', + 'dhcp6_client_dir' : '/run/dhcp6c', } config_status = '/tmp/vyos-config-status' @@ -50,12 +57,12 @@ api_data = { 'socket' : False, 'strict' : False, 'debug' : False, - 'api_keys' : [ {"id": "testapp", "key": "qwerty"} ] + 'api_keys' : [ {'id' : 'testapp', 'key' : 'qwerty'} ] } vyos_cert_data = { - "conf": "/etc/nginx/snippets/vyos-cert.conf", - "crt": "/etc/ssl/certs/vyos-selfsigned.crt", - "key": "/etc/ssl/private/vyos-selfsign", - "lifetime": "365", + 'conf' : '/etc/nginx/snippets/vyos-cert.conf', + 'crt' : '/etc/ssl/certs/vyos-selfsigned.crt', + 'key' : '/etc/ssl/private/vyos-selfsign', + 'lifetime' : '365', } diff --git a/python/vyos/dicts.py b/python/vyos/dicts.py deleted file mode 100644 index b12cda40f..000000000 --- a/python/vyos/dicts.py +++ /dev/null @@ -1,53 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright (C) 2019 VyOS maintainers and contributors -# -# This program is free software; you can redistribute it and/or modify -# it under the terms of the GNU General Public License version 2 or later as -# published by the Free Software Foundation. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see <http://www.gnu.org/licenses/>. - - -from vyos import ConfigError - - -class FixedDict(dict): - """ - FixedDict: A dictionnary not allowing new keys to be created after initialisation. - - >>> f = FixedDict(**{'count':1}) - >>> f['count'] = 2 - >>> f['king'] = 3 - File "...", line ..., in __setitem__ - raise ConfigError(f'Option "{k}" has no defined default') - """ - - def __init__(self, **options): - self._allowed = options.keys() - super().__init__(**options) - - def __setitem__(self, k, v): - """ - __setitem__ is a builtin which is called by python when setting dict values: - >>> d = dict() - >>> d['key'] = 'value' - >>> d - {'key': 'value'} - - is syntaxic sugar for - - >>> d = dict() - >>> d.__setitem__('key','value') - >>> d - {'key': 'value'} - """ - if k not in self._allowed: - raise ConfigError(f'Option "{k}" has no defined default') - super().__setitem__(k, v) diff --git a/python/vyos/ethtool.py b/python/vyos/ethtool.py index 2b6012a73..ca3bcfc3d 100644 --- a/python/vyos/ethtool.py +++ b/python/vyos/ethtool.py @@ -1,4 +1,4 @@ -# Copyright 2021-2022 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2021-2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -16,12 +16,13 @@ import os import re -from vyos.util import popen +from vyos.utils.process import popen # These drivers do not support using ethtool to change the speed, duplex, or # flow control settings _drivers_without_speed_duplex_flow = ['vmxnet3', 'virtio_net', 'xen_netfront', - 'iavf', 'ice', 'i40e', 'hv_netvsc'] + 'iavf', 'ice', 'i40e', 'hv_netvsc', 'veth', 'ixgbevf', + 'tun'] class Ethtool: """ @@ -51,15 +52,16 @@ class Ethtool: _ring_buffers_max = { } _driver_name = None _auto_negotiation = False + _auto_negotiation_supported = None _flow_control = False _flow_control_enabled = None def __init__(self, ifname): # Get driver used for interface - sysfs_file = f'/sys/class/net/{ifname}/device/driver/module' - if os.path.exists(sysfs_file): - link = os.readlink(sysfs_file) - self._driver_name = os.path.basename(link) + out, err = popen(f'ethtool --driver {ifname}') + driver = re.search(r'driver:\s(\w+)', out) + if driver: + self._driver_name = driver.group(1) # Build a dictinary of supported link-speed and dupley settings. out, err = popen(f'ethtool {ifname}') @@ -80,7 +82,13 @@ class Ethtool: self._speed_duplex.update({ speed : {}}) if duplex not in self._speed_duplex[speed]: self._speed_duplex[speed].update({ duplex : ''}) - if 'Auto-negotiation:' in line: + if 'Supports auto-negotiation:' in line: + # Split the following string: Auto-negotiation: off + # we are only interested in off or on + tmp = line.split()[-1] + self._auto_negotiation_supported = bool(tmp == 'Yes') + # Only read in if Auto-negotiation is supported + if self._auto_negotiation_supported and 'Auto-negotiation:' in line: # Split the following string: Auto-negotiation: off # we are only interested in off or on tmp = line.split()[-1] @@ -132,8 +140,12 @@ class Ethtool: # ['Autonegotiate:', 'on'] self._flow_control_enabled = out.splitlines()[1].split()[-1] + def check_auto_negotiation_supported(self): + """ Check if the NIC supports changing auto-negotiation """ + return self._auto_negotiation_supported + def get_auto_negotiation(self): - return self._auto_negotiation + return self._auto_negotiation_supported and self._auto_negotiation def get_driver_name(self): return self._driver_name diff --git a/python/vyos/firewall.py b/python/vyos/firewall.py index 48263eef5..53ff8259e 100644 --- a/python/vyos/firewall.py +++ b/python/vyos/firewall.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -# Copyright (C) 2021-2022 VyOS maintainers and contributors +# Copyright (C) 2021-2023 VyOS maintainers and contributors # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 2 or later as @@ -28,11 +28,11 @@ from time import strftime from vyos.remote import download from vyos.template import is_ipv4 from vyos.template import render -from vyos.util import call -from vyos.util import cmd -from vyos.util import dict_search_args -from vyos.util import dict_search_recursive -from vyos.util import run +from vyos.utils.dict import dict_search_args +from vyos.utils.dict import dict_search_recursive +from vyos.utils.process import call +from vyos.utils.process import cmd +from vyos.utils.process import run # Domain Resolver @@ -41,14 +41,19 @@ def fqdn_config_parse(firewall): firewall['ip6_fqdn'] = {} for domain, path in dict_search_recursive(firewall, 'fqdn'): - fw_name = path[1] # name/ipv6-name - rule = path[3] # rule id - suffix = path[4][0] # source/destination (1 char) - set_name = f'{fw_name}_{rule}_{suffix}' + hook_name = path[1] + priority = path[2] + + fw_name = path[2] + rule = path[4] + suffix = path[5][0] + set_name = f'{hook_name}_{priority}_{rule}_{suffix}' - if path[0] == 'name': + if (path[0] == 'ipv4') and (path[1] == 'forward' or path[1] == 'input' or path[1] == 'output' or path[1] == 'name'): firewall['ip_fqdn'][set_name] = domain - elif path[0] == 'ipv6_name': + elif (path[0] == 'ipv6') and (path[1] == 'forward' or path[1] == 'input' or path[1] == 'output' or path[1] == 'name'): + if path[1] == 'name': + set_name = f'name6_{priority}_{rule}_{suffix}' firewall['ip6_fqdn'][set_name] = domain def fqdn_resolve(fqdn, ipv6=False): @@ -80,7 +85,7 @@ def nft_action(vyos_action): return 'return' return vyos_action -def parse_rule(rule_conf, fw_name, rule_id, ip_name): +def parse_rule(rule_conf, hook, fw_name, rule_id, ip_name): output = [] def_suffix = '6' if ip_name == 'ip6' else '' @@ -129,16 +134,34 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name): if 'fqdn' in side_conf: fqdn = side_conf['fqdn'] + hook_name = '' operator = '' if fqdn[0] == '!': operator = '!=' - output.append(f'{ip_name} {prefix}addr {operator} @FQDN_{fw_name}_{rule_id}_{prefix}') + if hook == 'FWD': + hook_name = 'forward' + if hook == 'INP': + hook_name = 'input' + if hook == 'OUT': + hook_name = 'output' + if hook == 'NAM': + hook_name = f'name{def_suffix}' + output.append(f'{ip_name} {prefix}addr {operator} @FQDN_{hook_name}_{fw_name}_{rule_id}_{prefix}') if dict_search_args(side_conf, 'geoip', 'country_code'): operator = '' + hook_name = '' if dict_search_args(side_conf, 'geoip', 'inverse_match') != None: operator = '!=' - output.append(f'{ip_name} {prefix}addr {operator} @GEOIP_CC_{fw_name}_{rule_id}') + if hook == 'FWD': + hook_name = 'forward' + if hook == 'INP': + hook_name = 'input' + if hook == 'OUT': + hook_name = 'output' + if hook == 'NAM': + hook_name = f'name' + output.append(f'{ip_name} {prefix}addr {operator} @GEOIP_CC{def_suffix}_{hook_name}_{fw_name}_{rule_id}') if 'mac_address' in side_conf: suffix = side_conf["mac_address"] @@ -223,10 +246,23 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name): action = rule_conf['action'] if 'action' in rule_conf else 'accept' output.append(f'log prefix "[{fw_name[:19]}-{rule_id}-{action[:1].upper()}]"') - if 'log_level' in rule_conf: - log_level = rule_conf['log_level'] - output.append(f'level {log_level}') + if 'log_options' in rule_conf: + + if 'level' in rule_conf['log_options']: + log_level = rule_conf['log_options']['level'] + output.append(f'log level {log_level}') + + if 'group' in rule_conf['log_options']: + log_group = rule_conf['log_options']['group'] + output.append(f'log group {log_group}') + if 'queue_threshold' in rule_conf['log_options']: + queue_threshold = rule_conf['log_options']['queue_threshold'] + output.append(f'queue-threshold {queue_threshold}') + + if 'snapshot_length' in rule_conf['log_options']: + log_snaplen = rule_conf['log_options']['snapshot_length'] + output.append(f'snaplen {log_snaplen}') if 'hop_limit' in rule_conf: operators = {'eq': '==', 'gt': '>', 'lt': '<'} @@ -236,12 +272,34 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name): output.append(f'ip6 hoplimit {operator} {value}') if 'inbound_interface' in rule_conf: - iiface = rule_conf['inbound_interface'] - output.append(f'iifname {iiface}') + operator = '' + if 'interface_name' in rule_conf['inbound_interface']: + iiface = rule_conf['inbound_interface']['interface_name'] + if iiface[0] == '!': + operator = '!=' + iiface = iiface[1:] + output.append(f'iifname {operator} {{{iiface}}}') + else: + iiface = rule_conf['inbound_interface']['interface_group'] + if iiface[0] == '!': + operator = '!=' + iiface = iiface[1:] + output.append(f'iifname {operator} @I_{iiface}') if 'outbound_interface' in rule_conf: - oiface = rule_conf['outbound_interface'] - output.append(f'oifname {oiface}') + operator = '' + if 'interface_name' in rule_conf['outbound_interface']: + oiface = rule_conf['outbound_interface']['interface_name'] + if oiface[0] == '!': + operator = '!=' + oiface = oiface[1:] + output.append(f'oifname {operator} {{{oiface}}}') + else: + oiface = rule_conf['outbound_interface']['interface_group'] + if oiface[0] == '!': + operator = '!=' + oiface = oiface[1:] + output.append(f'oifname {operator} @I_{oiface}') if 'ttl' in rule_conf: operators = {'eq': '==', 'gt': '>', 'lt': '<'} @@ -269,6 +327,9 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name): negated_lengths_str = ','.join(rule_conf['packet_length_exclude']) output.append(f'ip{def_suffix} length != {{{negated_lengths_str}}}') + if 'packet_type' in rule_conf: + output.append(f'pkttype ' + rule_conf['packet_type']) + if 'dscp' in rule_conf: dscp_str = ','.join(rule_conf['dscp']) output.append(f'ip{def_suffix} dscp {{{dscp_str}}}') @@ -280,7 +341,7 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name): if 'ipsec' in rule_conf: if 'match_ipsec' in rule_conf['ipsec']: output.append('meta ipsec == 1') - if 'match_non_ipsec' in rule_conf['ipsec']: + if 'match_none' in rule_conf['ipsec']: output.append('meta ipsec == 0') if 'fragment' in rule_conf: @@ -300,7 +361,7 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name): if 'recent' in rule_conf: count = rule_conf['recent']['count'] time = rule_conf['recent']['time'] - output.append(f'add @RECENT{def_suffix}_{fw_name}_{rule_id} {{ {ip_name} saddr limit rate over {count}/{time} burst {count} packets }}') + output.append(f'add @RECENT{def_suffix}_{hook}_{fw_name}_{rule_id} {{ {ip_name} saddr limit rate over {count}/{time} burst {count} packets }}') if 'time' in rule_conf: output.append(parse_time(rule_conf['time'])) @@ -314,21 +375,36 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name): if tcp_mss: output.append(f'tcp option maxseg size {tcp_mss}') + if 'connection_mark' in rule_conf: + conn_mark_str = ','.join(rule_conf['connection_mark']) + output.append(f'ct mark {{{conn_mark_str}}}') + output.append('counter') if 'set' in rule_conf: output.append(parse_policy_set(rule_conf['set'], def_suffix)) if 'action' in rule_conf: - output.append(nft_action(rule_conf['action'])) + # Change action=return to action=action + # #output.append(nft_action(rule_conf['action'])) + output.append(f'{rule_conf["action"]}') if 'jump' in rule_conf['action']: target = rule_conf['jump_target'] output.append(f'NAME{def_suffix}_{target}') + if 'queue' in rule_conf['action']: + if 'queue' in rule_conf: + target = rule_conf['queue'] + output.append(f'num {target}') + + if 'queue_options' in rule_conf: + queue_opts = ','.join(rule_conf['queue_options']) + output.append(f'{queue_opts}') + else: output.append('return') - output.append(f'comment "{fw_name}-{rule_id}"') + output.append(f'comment "{hook}-{fw_name}-{rule_id}"') return " ".join(output) def parse_tcp_flags(flags): @@ -360,6 +436,9 @@ def parse_time(time): def parse_policy_set(set_conf, def_suffix): out = [] + if 'connection_mark' in set_conf: + conn_mark = set_conf['connection_mark'] + out.append(f'ct mark set {conn_mark}') if 'dscp' in set_conf: dscp = set_conf['dscp'] out.append(f'ip{def_suffix} dscp set {dscp}') @@ -453,11 +532,12 @@ def geoip_update(firewall, force=False): # Map country codes to set names for codes, path in dict_search_recursive(firewall, 'country_code'): - set_name = f'GEOIP_CC_{path[1]}_{path[3]}' - if path[0] == 'name': + set_name = f'GEOIP_CC_{path[1]}_{path[2]}_{path[4]}' + if ( path[0] == 'ipv4'): for code in codes: ipv4_codes.setdefault(code, []).append(set_name) - elif path[0] == 'ipv6_name': + elif ( path[0] == 'ipv6' ): + set_name = f'GEOIP_CC6_{path[1]}_{path[2]}_{path[4]}' for code in codes: ipv6_codes.setdefault(code, []).append(set_name) diff --git a/python/vyos/frr.py b/python/vyos/frr.py index ccb132dd5..2e3c8a271 100644 --- a/python/vyos/frr.py +++ b/python/vyos/frr.py @@ -67,9 +67,12 @@ Apply the new configuration: import tempfile import re -from vyos import util -from vyos.util import chown -from vyos.util import cmd + +from vyos.utils.permission import chown +from vyos.utils.process import cmd +from vyos.utils.process import popen +from vyos.utils.process import STDOUT + import logging from logging.handlers import SysLogHandler import os @@ -85,7 +88,7 @@ LOG.addHandler(ch2) _frr_daemons = ['zebra', 'bgpd', 'fabricd', 'isisd', 'ospf6d', 'ospfd', 'pbrd', 'pimd', 'ripd', 'ripngd', 'sharpd', 'staticd', 'vrrpd', 'ldpd', - 'bfdd', 'eigrpd'] + 'bfdd', 'eigrpd', 'babeld'] path_vtysh = '/usr/bin/vtysh' path_frr_reload = '/usr/lib/frr/frr-reload.py' @@ -144,7 +147,7 @@ def get_configuration(daemon=None, marked=False): if daemon: cmd += f' -d {daemon}' - output, code = util.popen(cmd, stderr=util.STDOUT) + output, code = popen(cmd, stderr=STDOUT) if code: raise OSError(code, output) @@ -166,7 +169,7 @@ def mark_configuration(config): config: The configuration string to mark/test return: The marked configuration from FRR """ - output, code = util.popen(f"{path_vtysh} -m -f -", stderr=util.STDOUT, input=config) + output, code = popen(f"{path_vtysh} -m -f -", stderr=STDOUT, input=config) if code == 2: raise ConfigurationNotValid(str(output)) @@ -206,7 +209,7 @@ def reload_configuration(config, daemon=None): cmd += f' {f.name}' LOG.debug(f'reload_configuration: Executing command against frr-reload: "{cmd}"') - output, code = util.popen(cmd, stderr=util.STDOUT) + output, code = popen(cmd, stderr=STDOUT) f.close() for i, e in enumerate(output.split('\n')): LOG.debug(f'frr-reload output: {i:3} {e}') @@ -235,7 +238,7 @@ def execute(command): cmd = f"{path_vtysh} -c '{command}'" - output, code = util.popen(cmd, stderr=util.STDOUT) + output, code = popen(cmd, stderr=STDOUT) if code: raise OSError(code, output) @@ -267,7 +270,7 @@ def configure(lines, daemon=False): for x in lines: cmd += f" -c '{x}'" - output, code = util.popen(cmd, stderr=util.STDOUT) + output, code = popen(cmd, stderr=STDOUT) if code == 1: raise ConfigurationNotValid(f'Configuration FRR failed: {repr(output)}') elif code: diff --git a/python/vyos/ifconfig/bond.py b/python/vyos/ifconfig/bond.py index 0edd17055..d1d7d48c4 100644 --- a/python/vyos/ifconfig/bond.py +++ b/python/vyos/ifconfig/bond.py @@ -16,10 +16,10 @@ import os from vyos.ifconfig.interface import Interface -from vyos.util import cmd -from vyos.util import dict_search -from vyos.validate import assert_list -from vyos.validate import assert_positive +from vyos.utils.process import cmd +from vyos.utils.dict import dict_search +from vyos.utils.assertion import assert_list +from vyos.utils.assertion import assert_positive @Interface.register class BondIf(Interface): diff --git a/python/vyos/ifconfig/bridge.py b/python/vyos/ifconfig/bridge.py index aa818bc5f..b29e71394 100644 --- a/python/vyos/ifconfig/bridge.py +++ b/python/vyos/ifconfig/bridge.py @@ -17,10 +17,10 @@ from netifaces import interfaces import json from vyos.ifconfig.interface import Interface -from vyos.validate import assert_boolean -from vyos.validate import assert_positive -from vyos.util import cmd -from vyos.util import dict_search +from vyos.utils.assertion import assert_boolean +from vyos.utils.assertion import assert_positive +from vyos.utils.process import cmd +from vyos.utils.dict import dict_search from vyos.configdict import get_vlan_ids from vyos.configdict import list_diff diff --git a/python/vyos/ifconfig/control.py b/python/vyos/ifconfig/control.py index 7a6b36e7c..7402da55a 100644 --- a/python/vyos/ifconfig/control.py +++ b/python/vyos/ifconfig/control.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2019-2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -19,10 +19,10 @@ from inspect import signature from inspect import _empty from vyos.ifconfig.section import Section -from vyos.util import popen -from vyos.util import cmd -from vyos.util import read_file -from vyos.util import write_file +from vyos.utils.process import popen +from vyos.utils.process import cmd +from vyos.utils.file import read_file +from vyos.utils.file import write_file from vyos import debug class Control(Section): @@ -49,6 +49,18 @@ class Control(Section): return popen(command, self.debug) def _cmd(self, command): + import re + if 'netns' in self.config: + # This command must be executed from default netns 'ip link set dev X netns X' + # exclude set netns cmd from netns to avoid: + # failed to run command: ip netns exec ns01 ip link set dev veth20 netns ns01 + pattern = r'ip link set dev (\S+) netns (\S+)' + matches = re.search(pattern, command) + if matches and matches.group(2) == self.config['netns']: + # Command already includes netns and matches desired namespace: + command = command + else: + command = f'ip netns exec {self.config["netns"]} {command}' return cmd(command, self.debug) def _get_command(self, config, name): diff --git a/python/vyos/ifconfig/ethernet.py b/python/vyos/ifconfig/ethernet.py index 519cfc58c..24ce3a803 100644 --- a/python/vyos/ifconfig/ethernet.py +++ b/python/vyos/ifconfig/ethernet.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2019-2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -14,15 +14,16 @@ # License along with this library. If not, see <http://www.gnu.org/licenses/>. import os -import re from glob import glob + +from vyos.base import Warning from vyos.ethtool import Ethtool from vyos.ifconfig.interface import Interface -from vyos.util import run -from vyos.util import dict_search -from vyos.util import read_file -from vyos.validate import assert_list +from vyos.utils.dict import dict_search +from vyos.utils.file import read_file +from vyos.utils.process import run +from vyos.utils.assertion import assert_list @Interface.register class EthernetIf(Interface): @@ -118,7 +119,7 @@ class EthernetIf(Interface): cmd = f'ethtool --pause {ifname} autoneg {enable} tx {enable} rx {enable}' output, code = self._popen(cmd) if code: - print(f'Could not set flowcontrol for {ifname}') + Warning(f'could not change "{ifname}" flow control setting!') return output return None @@ -134,6 +135,7 @@ class EthernetIf(Interface): >>> i = EthernetIf('eth0') >>> i.set_speed_duplex('auto', 'auto') """ + ifname = self.config['ifname'] if speed not in ['auto', '10', '100', '1000', '2500', '5000', '10000', '25000', '40000', '50000', '100000', '400000']: @@ -143,7 +145,11 @@ class EthernetIf(Interface): raise ValueError("Value out of range (duplex)") if not self.ethtool.check_speed_duplex(speed, duplex): - self._debug_msg(f'NIC driver does not support changing speed/duplex settings!') + Warning(f'changing speed/duplex setting on "{ifname}" is unsupported!') + return + + if not self.ethtool.check_auto_negotiation_supported(): + Warning(f'changing auto-negotiation setting on "{ifname}" is unsupported!') return # Get current speed and duplex settings: @@ -239,7 +245,7 @@ class EthernetIf(Interface): if not isinstance(state, bool): raise ValueError('Value out of range') - rps_cpus = '0' + rps_cpus = 0 queues = len(glob(f'/sys/class/net/{self.ifname}/queues/rx-*')) if state: # Enable RPS on all available CPUs except CPU0 which we will not @@ -248,10 +254,16 @@ class EthernetIf(Interface): # representation of the CPUs which should participate on RPS, we # can enable more CPUs that are physically present on the system, # Linux will clip that internally! - rps_cpus = 'ffffffff,ffffffff,ffffffff,fffffffe' + rps_cpus = (1 << os.cpu_count()) -1 + + # XXX: we should probably reserve one core when the system is under + # high preasure so we can still have a core left for housekeeping. + # This is done by masking out the lowst bit so CPU0 is spared from + # receive packet steering. + rps_cpus &= ~1 for i in range(0, queues): - self._write_sysfs(f'/sys/class/net/{self.ifname}/queues/rx-{i}/rps_cpus', rps_cpus) + self._write_sysfs(f'/sys/class/net/{self.ifname}/queues/rx-{i}/rps_cpus', f'{rps_cpus:x}') # send bitmask representation as hex string without leading '0x' return True @@ -362,10 +374,11 @@ class EthernetIf(Interface): self.set_tso(dict_search('offload.tso', config) != None) # Set physical interface speed and duplex - if {'speed', 'duplex'} <= set(config): - speed = config.get('speed') - duplex = config.get('duplex') - self.set_speed_duplex(speed, duplex) + if 'speed_duplex_changed' in config: + if {'speed', 'duplex'} <= set(config): + speed = config.get('speed') + duplex = config.get('duplex') + self.set_speed_duplex(speed, duplex) # Set interface ring buffer if 'ring_buffer' in config: diff --git a/python/vyos/ifconfig/geneve.py b/python/vyos/ifconfig/geneve.py index 276c34cd7..fbb261a35 100644 --- a/python/vyos/ifconfig/geneve.py +++ b/python/vyos/ifconfig/geneve.py @@ -14,7 +14,7 @@ # License along with this library. If not, see <http://www.gnu.org/licenses/>. from vyos.ifconfig import Interface -from vyos.util import dict_search +from vyos.utils.dict import dict_search @Interface.register class GeneveIf(Interface): @@ -45,6 +45,7 @@ class GeneveIf(Interface): 'parameters.ip.df' : 'df', 'parameters.ip.tos' : 'tos', 'parameters.ip.ttl' : 'ttl', + 'parameters.ip.innerproto' : 'innerprotoinherit', 'parameters.ipv6.flowlabel' : 'flowlabel', } diff --git a/python/vyos/ifconfig/input.py b/python/vyos/ifconfig/input.py index db7d2b6b4..3e5f5790d 100644 --- a/python/vyos/ifconfig/input.py +++ b/python/vyos/ifconfig/input.py @@ -1,4 +1,4 @@ -# Copyright 2020 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -17,6 +17,16 @@ from vyos.ifconfig.interface import Interface @Interface.register class InputIf(Interface): + """ + The Intermediate Functional Block (ifb) pseudo network interface acts as a + QoS concentrator for multiple different sources of traffic. Packets from + or to other interfaces have to be redirected to it using the mirred action + in order to be handled, regularly routed traffic will be dropped. This way, + a single stack of qdiscs, classes and filters can be shared between + multiple interfaces. + """ + + iftype = 'ifb' definition = { **Interface.definition, **{ diff --git a/python/vyos/ifconfig/interface.py b/python/vyos/ifconfig/interface.py index c50ead89f..050095364 100644 --- a/python/vyos/ifconfig/interface.py +++ b/python/vyos/ifconfig/interface.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2019-2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -31,23 +31,26 @@ from vyos import ConfigError from vyos.configdict import list_diff from vyos.configdict import dict_merge from vyos.configdict import get_vlan_ids +from vyos.defaults import directories from vyos.template import render -from vyos.util import mac2eui64 -from vyos.util import dict_search -from vyos.util import read_file -from vyos.util import get_interface_config -from vyos.util import get_interface_namespace -from vyos.util import is_systemd_service_active +from vyos.utils.network import mac2eui64 +from vyos.utils.dict import dict_search +from vyos.utils.file import read_file +from vyos.utils.network import get_interface_config +from vyos.utils.network import get_interface_namespace +from vyos.utils.network import is_netns_interface +from vyos.utils.process import is_systemd_service_active +from vyos.utils.process import run from vyos.template import is_ipv4 from vyos.template import is_ipv6 -from vyos.validate import is_intf_addr_assigned -from vyos.validate import is_ipv6_link_local -from vyos.validate import assert_boolean -from vyos.validate import assert_list -from vyos.validate import assert_mac -from vyos.validate import assert_mtu -from vyos.validate import assert_positive -from vyos.validate import assert_range +from vyos.utils.network import is_intf_addr_assigned +from vyos.utils.network import is_ipv6_link_local +from vyos.utils.assertion import assert_boolean +from vyos.utils.assertion import assert_list +from vyos.utils.assertion import assert_mac +from vyos.utils.assertion import assert_mtu +from vyos.utils.assertion import assert_positive +from vyos.utils.assertion import assert_range from vyos.ifconfig.control import Control from vyos.ifconfig.vrrp import VRRP @@ -57,6 +60,8 @@ from vyos.ifconfig import Section from netaddr import EUI from netaddr import mac_unix_expanded +link_local_prefix = 'fe80::/64' + class Interface(Control): # This is the class which will be used to create # self.operational, it allows subclasses, such as @@ -135,9 +140,6 @@ class Interface(Control): 'validate': assert_mtu, 'shellcmd': 'ip link set dev {ifname} mtu {value}', }, - 'netns': { - 'shellcmd': 'ip link set dev {ifname} netns {value}', - }, 'vrf': { 'convert': lambda v: f'master {v}' if v else 'nomaster', 'shellcmd': 'ip link set dev {ifname} {value}', @@ -172,10 +174,6 @@ class Interface(Control): 'validate': assert_boolean, 'location': '/proc/sys/net/ipv4/conf/{ifname}/bc_forwarding', }, - 'rp_filter': { - 'validate': lambda flt: assert_range(flt,0,3), - 'location': '/proc/sys/net/ipv4/conf/{ifname}/rp_filter', - }, 'ipv6_accept_ra': { 'validate': lambda ara: assert_range(ara,0,3), 'location': '/proc/sys/net/ipv6/conf/{ifname}/accept_ra', @@ -188,6 +186,10 @@ class Interface(Control): 'validate': lambda fwd: assert_range(fwd,0,2), 'location': '/proc/sys/net/ipv6/conf/{ifname}/forwarding', }, + 'ipv6_accept_dad': { + 'validate': lambda dad: assert_range(dad,0,3), + 'location': '/proc/sys/net/ipv6/conf/{ifname}/accept_dad', + }, 'ipv6_dad_transmits': { 'validate': assert_positive, 'location': '/proc/sys/net/ipv6/conf/{ifname}/dad_transmits', @@ -217,6 +219,10 @@ class Interface(Control): 'validate': lambda link: assert_range(link,0,3), 'location': '/proc/sys/net/ipv4/conf/{ifname}/link_filter', }, + 'per_client_thread': { + 'validate': assert_boolean, + 'location': '/sys/class/net/{ifname}/threaded', + }, } _sysfs_get = { @@ -241,9 +247,6 @@ class Interface(Control): 'ipv4_directed_broadcast': { 'location': '/proc/sys/net/ipv4/conf/{ifname}/bc_forwarding', }, - 'rp_filter': { - 'location': '/proc/sys/net/ipv4/conf/{ifname}/rp_filter', - }, 'ipv6_accept_ra': { 'location': '/proc/sys/net/ipv6/conf/{ifname}/accept_ra', }, @@ -253,6 +256,9 @@ class Interface(Control): 'ipv6_forwarding': { 'location': '/proc/sys/net/ipv6/conf/{ifname}/forwarding', }, + 'ipv6_accept_dad': { + 'location': '/proc/sys/net/ipv6/conf/{ifname}/accept_dad', + }, 'ipv6_dad_transmits': { 'location': '/proc/sys/net/ipv6/conf/{ifname}/dad_transmits', }, @@ -265,11 +271,18 @@ class Interface(Control): 'link_detect': { 'location': '/proc/sys/net/ipv4/conf/{ifname}/link_filter', }, + 'per_client_thread': { + 'validate': assert_boolean, + 'location': '/sys/class/net/{ifname}/threaded', + }, } @classmethod - def exists(cls, ifname): - return os.path.exists(f'/sys/class/net/{ifname}') + def exists(cls, ifname: str, netns: str=None) -> bool: + cmd = f'ip link show dev {ifname}' + if netns: + cmd = f'ip netns exec {netns} {cmd}' + return run(cmd) == 0 @classmethod def get_config(cls): @@ -337,7 +350,13 @@ class Interface(Control): self.vrrp = VRRP(ifname) def _create(self): + # Do not create interface that already exist or exists in netns + netns = self.config.get('netns', None) + if self.exists(f'{self.ifname}', netns=netns): + return + cmd = 'ip link add dev {ifname} type {type}'.format(**self.config) + if 'netns' in self.config: cmd = f'ip netns exec {netns} {cmd}' self._cmd(cmd) def remove(self): @@ -372,6 +391,9 @@ class Interface(Control): # after interface removal no other commands should be allowed # to be called and instead should raise an Exception: cmd = 'ip link del dev {ifname}'.format(**self.config) + # for delete we can't get data from self.config{'netns'} + netns = get_interface_namespace(self.ifname) + if netns: cmd = f'ip netns exec {netns} {cmd}' return self._cmd(cmd) def _set_vrf_ct_zone(self, vrf): @@ -379,6 +401,10 @@ class Interface(Control): Add/Remove rules in nftables to associate traffic in VRF to an individual conntack zone """ + # Don't allow for netns yet + if 'netns' in self.config: + return None + if vrf: # Get routing table ID for VRF vrf_table_id = get_interface_config(vrf).get('linkinfo', {}).get( @@ -522,36 +548,30 @@ class Interface(Control): if prev_state == 'up': self.set_admin_state('up') - def del_netns(self, netns): - """ - Remove interface from given NETNS. - """ - - # If NETNS does not exist then there is nothing to delete + def del_netns(self, netns: str) -> bool: + """ Remove interface from given network namespace """ + # If network namespace does not exist then there is nothing to delete if not os.path.exists(f'/run/netns/{netns}'): - return None - - # As a PoC we only allow 'dummy' interfaces - if 'dum' not in self.ifname: - return None + return False - # Check if interface realy exists in namespace - if get_interface_namespace(self.ifname) != None: - self._cmd(f'ip netns exec {get_interface_namespace(self.ifname)} ip link del dev {self.ifname}') - return + # Check if interface exists in network namespace + if is_netns_interface(self.ifname, netns): + self._cmd(f'ip netns exec {netns} ip link del dev {self.ifname}') + return True + return False - def set_netns(self, netns): + def set_netns(self, netns: str) -> bool: """ - Add interface from given NETNS. + Add interface from given network namespace Example: >>> from vyos.ifconfig import Interface >>> Interface('dum0').set_netns('foo') """ + self._cmd(f'ip link set dev {self.ifname} netns {netns}') + return True - self.set_interface('netns', netns) - - def set_vrf(self, vrf): + def set_vrf(self, vrf: str) -> bool: """ Add/Remove interface from given VRF instance. @@ -563,10 +583,11 @@ class Interface(Control): tmp = self.get_interface('vrf') if tmp == vrf: - return None + return False self.set_interface('vrf', vrf) self._set_vrf_ct_zone(vrf) + return True def set_arp_cache_tmo(self, tmo): """ @@ -603,6 +624,10 @@ class Interface(Control): >>> from vyos.ifconfig import Interface >>> Interface('eth0').set_tcp_ipv4_mss(1340) """ + # Don't allow for netns yet + if 'netns' in self.config: + return None + self._cleanup_mss_rules('raw', self.ifname) nft_prefix = 'nft add rule raw VYOS_TCP_MSS' base_cmd = f'oifname "{self.ifname}" tcp flags & (syn|rst) == syn' @@ -623,6 +648,10 @@ class Interface(Control): >>> from vyos.ifconfig import Interface >>> Interface('eth0').set_tcp_mss(1320) """ + # Don't allow for netns yet + if 'netns' in self.config: + return None + self._cleanup_mss_rules('ip6 raw', self.ifname) nft_prefix = 'nft add rule ip6 raw VYOS_TCP_MSS' base_cmd = f'oifname "{self.ifname}" tcp flags & (syn|rst) == syn' @@ -727,37 +756,63 @@ class Interface(Control): return None return self.set_interface('ipv4_directed_broadcast', forwarding) - def set_ipv4_source_validation(self, value): + def _cleanup_ipv4_source_validation_rules(self, ifname): + results = self._cmd(f'nft -a list chain ip raw vyos_rpfilter').split("\n") + for line in results: + if f'iifname "{ifname}"' in line: + handle_search = re.search('handle (\d+)', line) + if handle_search: + self._cmd(f'nft delete rule ip raw vyos_rpfilter handle {handle_search[1]}') + + def set_ipv4_source_validation(self, mode): """ - Help prevent attacks used by Spoofing IP Addresses. Reverse path - filtering is a Kernel feature that, when enabled, is designed to ensure - packets that are not routable to be dropped. The easiest example of this - would be and IP Address of the range 10.0.0.0/8, a private IP Address, - being received on the Internet facing interface of the router. + Set IPv4 reverse path validation - As per RFC3074. + Example: + >>> from vyos.ifconfig import Interface + >>> Interface('eth0').set_ipv4_source_validation('strict') """ - if value == 'strict': - value = 1 - elif value == 'loose': - value = 2 - else: - value = 0 + # Don't allow for netns yet + if 'netns' in self.config: + return None - all_rp_filter = int(read_file('/proc/sys/net/ipv4/conf/all/rp_filter')) - if all_rp_filter > value: - global_setting = 'disable' - if all_rp_filter == 1: global_setting = 'strict' - elif all_rp_filter == 2: global_setting = 'loose' + self._cleanup_ipv4_source_validation_rules(self.ifname) + nft_prefix = f'nft insert rule ip raw vyos_rpfilter iifname "{self.ifname}"' + if mode in ['strict', 'loose']: + self._cmd(f"{nft_prefix} counter return") + if mode == 'strict': + self._cmd(f"{nft_prefix} fib saddr . iif oif 0 counter drop") + elif mode == 'loose': + self._cmd(f"{nft_prefix} fib saddr oif 0 counter drop") + + def _cleanup_ipv6_source_validation_rules(self, ifname): + results = self._cmd(f'nft -a list chain ip6 raw vyos_rpfilter').split("\n") + for line in results: + if f'iifname "{ifname}"' in line: + handle_search = re.search('handle (\d+)', line) + if handle_search: + self._cmd(f'nft delete rule ip6 raw vyos_rpfilter handle {handle_search[1]}') - from vyos.base import Warning - Warning(f'Global source-validation is set to "{global_setting} '\ - f'this overrides per interface setting!') + def set_ipv6_source_validation(self, mode): + """ + Set IPv6 reverse path validation - tmp = self.get_interface('rp_filter') - if int(tmp) == value: + Example: + >>> from vyos.ifconfig import Interface + >>> Interface('eth0').set_ipv6_source_validation('strict') + """ + # Don't allow for netns yet + if 'netns' in self.config: return None - return self.set_interface('rp_filter', value) + + self._cleanup_ipv6_source_validation_rules(self.ifname) + nft_prefix = f'nft insert rule ip6 raw vyos_rpfilter iifname "{self.ifname}"' + if mode in ['strict', 'loose']: + self._cmd(f"{nft_prefix} counter return") + if mode == 'strict': + self._cmd(f"{nft_prefix} fib saddr . iif oif 0 counter drop") + elif mode == 'loose': + self._cmd(f"{nft_prefix} fib saddr oif 0 counter drop") def set_ipv6_accept_ra(self, accept_ra): """ @@ -843,6 +898,13 @@ class Interface(Control): return None return self.set_interface('ipv6_forwarding', forwarding) + def set_ipv6_dad_accept(self, dad): + """Whether to accept DAD (Duplicate Address Detection)""" + tmp = self.get_interface('ipv6_accept_dad') + if tmp == dad: + return None + return self.set_interface('ipv6_accept_dad', dad) + def set_ipv6_dad_messages(self, dad): """ The amount of Duplicate Address Detection probes to send. @@ -1094,13 +1156,17 @@ class Interface(Control): if addr in self._addr: return False + # get interface network namespace if specified + netns = self.config.get('netns', None) + # add to interface if addr == 'dhcp': self.set_dhcp(True) elif addr == 'dhcpv6': self.set_dhcpv6(True) - elif not is_intf_addr_assigned(self.ifname, addr): - tmp = f'ip addr add {addr} dev {self.ifname}' + elif not is_intf_addr_assigned(self.ifname, addr, netns=netns): + netns_cmd = f'ip netns exec {netns}' if netns else '' + tmp = f'{netns_cmd} ip addr add {addr} dev {self.ifname}' # Add broadcast address for IPv4 if is_ipv4(addr): tmp += ' brd +' @@ -1140,13 +1206,17 @@ class Interface(Control): if not addr: raise ValueError() + # get interface network namespace if specified + netns = self.config.get('netns', None) + # remove from interface if addr == 'dhcp': self.set_dhcp(False) elif addr == 'dhcpv6': self.set_dhcpv6(False) - elif is_intf_addr_assigned(self.ifname, addr): - self._cmd(f'ip addr del "{addr}" dev "{self.ifname}"') + elif is_intf_addr_assigned(self.ifname, addr, netns=netns): + netns_cmd = f'ip netns exec {netns}' if netns else '' + self._cmd(f'{netns_cmd} ip addr del {addr} dev {self.ifname}') else: return False @@ -1166,8 +1236,11 @@ class Interface(Control): self.set_dhcp(False) self.set_dhcpv6(False) + netns = get_interface_namespace(self.ifname) + netns_cmd = f'ip netns exec {netns}' if netns else '' + cmd = f'{netns_cmd} ip addr flush dev {self.ifname}' # flush all addresses - self._cmd(f'ip addr flush dev "{self.ifname}"') + self._cmd(cmd) def add_to_bridge(self, bridge_dict): """ @@ -1238,44 +1311,49 @@ class Interface(Control): raise ValueError() ifname = self.ifname - config_base = r'/var/lib/dhcp/dhclient' - config_file = f'{config_base}_{ifname}.conf' - options_file = f'{config_base}_{ifname}.options' - pid_file = f'{config_base}_{ifname}.pid' - lease_file = f'{config_base}_{ifname}.leases' + config_base = directories['isc_dhclient_dir'] + '/dhclient' + dhclient_config_file = f'{config_base}_{ifname}.conf' + dhclient_lease_file = f'{config_base}_{ifname}.leases' + systemd_override_file = f'/run/systemd/system/dhclient@{ifname}.service.d/10-override.conf' systemd_service = f'dhclient@{ifname}.service' + # Rendered client configuration files require the apsolute config path + self.config['isc_dhclient_dir'] = directories['isc_dhclient_dir'] + # 'up' check is mandatory b/c even if the interface is A/D, as soon as # the DHCP client is started the interface will be placed in u/u state. # This is not what we intended to do when disabling an interface. - if enable and 'disable' not in self._config: - if dict_search('dhcp_options.host_name', self._config) == None: + if enable and 'disable' not in self.config: + if dict_search('dhcp_options.host_name', self.config) == None: # read configured system hostname. # maybe change to vyos hostd client ??? hostname = 'vyos' with open('/etc/hostname', 'r') as f: hostname = f.read().rstrip('\n') tmp = {'dhcp_options' : { 'host_name' : hostname}} - self._config = dict_merge(tmp, self._config) + self.config = dict_merge(tmp, self.config) + + render(systemd_override_file, 'dhcp-client/override.conf.j2', self.config) + render(dhclient_config_file, 'dhcp-client/ipv4.j2', self.config) - render(options_file, 'dhcp-client/daemon-options.j2', self._config) - render(config_file, 'dhcp-client/ipv4.j2', self._config) + # Reload systemd unit definitons as some options are dynamically generated + self._cmd('systemctl daemon-reload') # When the DHCP client is restarted a brief outage will occur, as # the old lease is released a new one is acquired (T4203). We will # only restart DHCP client if it's option changed, or if it's not # running, but it should be running (e.g. on system startup) - if 'dhcp_options_changed' in self._config or not is_systemd_service_active(systemd_service): + if 'dhcp_options_changed' in self.config or not is_systemd_service_active(systemd_service): return self._cmd(f'systemctl restart {systemd_service}') - return None else: if is_systemd_service_active(systemd_service): self._cmd(f'systemctl stop {systemd_service}') # cleanup old config files - for file in [config_file, options_file, pid_file, lease_file]: + for file in [dhclient_config_file, systemd_override_file, dhclient_lease_file]: if os.path.isfile(file): os.remove(file) + return None def set_dhcpv6(self, enable): """ @@ -1285,11 +1363,20 @@ class Interface(Control): raise ValueError() ifname = self.ifname - config_file = f'/run/dhcp6c/dhcp6c.{ifname}.conf' + config_base = directories['dhcp6_client_dir'] + config_file = f'{config_base}/dhcp6c.{ifname}.conf' + systemd_override_file = f'/run/systemd/system/dhcp6c@{ifname}.service.d/10-override.conf' systemd_service = f'dhcp6c@{ifname}.service' - if enable and 'disable' not in self._config: - render(config_file, 'dhcp-client/ipv6.j2', self._config) + # Rendered client configuration files require the apsolute config path + self.config['dhcp6_client_dir'] = directories['dhcp6_client_dir'] + + if enable and 'disable' not in self.config: + render(systemd_override_file, 'dhcp-client/ipv6.override.conf.j2', self.config) + render(config_file, 'dhcp-client/ipv6.j2', self.config) + + # Reload systemd unit definitons as some options are dynamically generated + self._cmd('systemctl daemon-reload') # We must ignore any return codes. This is required to enable # DHCPv6-PD for interfaces which are yet not up and running. @@ -1300,26 +1387,33 @@ class Interface(Control): if os.path.isfile(config_file): os.remove(config_file) + return None + def set_mirror_redirect(self): # Please refer to the document for details # - https://man7.org/linux/man-pages/man8/tc.8.html # - https://man7.org/linux/man-pages/man8/tc-mirred.8.html # Depening if we are the source or the target interface of the port # mirror we need to setup some variables. - source_if = self._config['ifname'] + + # Don't allow for netns yet + if 'netns' in self.config: + return None + + source_if = self.config['ifname'] mirror_config = None - if 'mirror' in self._config: - mirror_config = self._config['mirror'] - if 'is_mirror_intf' in self._config: - source_if = next(iter(self._config['is_mirror_intf'])) - mirror_config = self._config['is_mirror_intf'][source_if].get('mirror', None) + if 'mirror' in self.config: + mirror_config = self.config['mirror'] + if 'is_mirror_intf' in self.config: + source_if = next(iter(self.config['is_mirror_intf'])) + mirror_config = self.config['is_mirror_intf'][source_if].get('mirror', None) redirect_config = None # clear existing ingess - ignore errors (e.g. "Error: Cannot find specified # qdisc on specified device") - we simply cleanup all stuff here - if not 'traffic_policy' in self._config: + if not 'traffic_policy' in self.config: self._popen(f'tc qdisc del dev {source_if} parent ffff: 2>/dev/null'); self._popen(f'tc qdisc del dev {source_if} parent 1: 2>/dev/null'); @@ -1343,43 +1437,39 @@ class Interface(Control): if err: print('tc qdisc(filter for mirror port failed') # Apply interface traffic redirection policy - elif 'redirect' in self._config: + elif 'redirect' in self.config: _, err = self._popen(f'tc qdisc add dev {source_if} handle ffff: ingress') if err: print(f'tc qdisc add for redirect failed!') - target_if = self._config['redirect'] + target_if = self.config['redirect'] _, err = self._popen(f'tc filter add dev {source_if} parent ffff: protocol '\ f'all prio 10 u32 match u32 0 0 flowid 1:1 action mirred '\ f'egress redirect dev {target_if}') if err: print('tc filter add for redirect failed') - def set_xdp(self, state): + def set_per_client_thread(self, enable): """ - Enable Kernel XDP support. State can be either True or False. + Per-device control to enable/disable the threaded mode for all the napi + instances of the given network device, without the need for a device up/down. + + User sets it to 1 or 0 to enable or disable threaded mode. Example: >>> from vyos.ifconfig import Interface - >>> i = Interface('eth0') - >>> i.set_xdp(True) - """ - if not isinstance(state, bool): - raise ValueError("Value out of range") - - # https://phabricator.vyos.net/T3448 - there is (yet) no RPI support for XDP - if not os.path.exists('/usr/sbin/xdp_loader'): - return - - ifname = self.config['ifname'] - cmd = f'xdp_loader -d {ifname} -U --auto-mode' - if state: - # Using 'xdp' will automatically decide if the driver supports - # 'xdpdrv' or only 'xdpgeneric'. A user later sees which driver is - # actually in use by calling 'ip a' or 'show interfaces ethernet' - cmd = f'xdp_loader -d {ifname} --auto-mode -F --progsec xdp_router ' \ - f'--filename /usr/share/vyos/xdp/xdp_prog_kern.o && ' \ - f'xdp_prog_user -d {ifname}' + >>> Interface('wg1').set_per_client_thread(1) + """ + # In the case of a "virtual" interface like wireguard, the sysfs + # node is only created once there is a peer configured. We can now + # add a verify() code-path for this or make this dynamic without + # nagging the user + tmp = self._sysfs_get['per_client_thread']['location'] + if not os.path.exists(tmp): + return None - return self._cmd(cmd) + tmp = self.get_interface('per_client_thread') + if tmp == enable: + return None + self.set_interface('per_client_thread', enable) def update(self, config): """ General helper function which works on a dictionary retrived by @@ -1394,7 +1484,7 @@ class Interface(Control): # Cache the configuration - it will be reused inside e.g. DHCP handler # XXX: maybe pass the option via __init__ in the future and rename this # method to apply()? - self._config = config + self.config = config # Change interface MAC address - re-set to real hardware address (hw-id) # if custom mac is removed. Skip if bond member. @@ -1410,8 +1500,8 @@ class Interface(Control): # Since the interface is pushed onto a separate logical stack # Configure NETNS if dict_search('netns', config) != None: - self.set_netns(config.get('netns', '')) - return + if not is_netns_interface(self.ifname, self.config['netns']): + self.set_netns(config.get('netns', '')) else: self.del_netns(config.get('netns', '')) @@ -1444,7 +1534,7 @@ class Interface(Control): # we will delete all interface specific IP addresses if they are not # explicitly configured on the CLI if is_ipv6_link_local(addr): - eui64 = mac2eui64(self.get_mac(), 'fe80::/64') + eui64 = mac2eui64(self.get_mac(), link_local_prefix) if addr != f'{eui64}/64': self.del_addr(addr) else: @@ -1531,6 +1621,11 @@ class Interface(Control): value = tmp if (tmp != None) else '0' self.set_ipv4_source_validation(value) + # IPv6 source-validation + tmp = dict_search('ipv6.source_validation', config) + value = tmp if (tmp != None) else '0' + self.set_ipv6_source_validation(value) + # MTU - Maximum Transfer Unit has a default value. It must ALWAYS be set # before mangling any IPv6 option. If MTU is less then 1280 IPv6 will be # automatically disabled by the kernel. Also MTU must be increased before @@ -1560,10 +1655,17 @@ class Interface(Control): value = '1' if (tmp != None) else '0' self.set_ipv6_autoconf(value) - # IPv6 Duplicate Address Detection (DAD) tries + # Whether to accept IPv6 DAD (Duplicate Address Detection) packets + tmp = dict_search('ipv6.accept_dad', config) + # Not all interface types got this CLI option, but if they do, there + # is an XML defaultValue available + if (tmp != None): self.set_ipv6_dad_accept(tmp) + + # IPv6 DAD tries tmp = dict_search('ipv6.dup_addr_detect_transmits', config) - value = tmp if (tmp != None) else '1' - self.set_ipv6_dad_messages(value) + # Not all interface types got this CLI option, but if they do, there + # is an XML defaultValue available + if (tmp != None): self.set_ipv6_dad_messages(tmp) # Delete old IPv6 EUI64 addresses before changing MAC for addr in (dict_search('ipv6.address.eui64_old', config) or []): @@ -1571,9 +1673,9 @@ class Interface(Control): # Manage IPv6 link-local addresses if dict_search('ipv6.address.no_default_link_local', config) != None: - self.del_ipv6_eui64_address('fe80::/64') + self.del_ipv6_eui64_address(link_local_prefix) else: - self.add_ipv6_eui64_address('fe80::/64') + self.add_ipv6_eui64_address(link_local_prefix) # Add IPv6 EUI-based addresses tmp = dict_search('ipv6.address.eui64', config) @@ -1586,12 +1688,14 @@ class Interface(Control): tmp = config.get('is_bridge_member') self.add_to_bridge(tmp) - # eXpress Data Path - highly experimental - self.set_xdp('xdp' in config) - # configure interface mirror or redirection target self.set_mirror_redirect() + # enable/disable NAPI threading mode + tmp = dict_search('per_client_thread', config) + value = '1' if (tmp != None) else '0' + self.set_per_client_thread(value) + # Enable/Disable of an interface must always be done at the end of the # derived class to make use of the ref-counting set_admin_state() # function. We will only enable the interface if 'up' was called as @@ -1709,6 +1813,14 @@ class VLANIf(Interface): if self.exists(f'{self.ifname}'): return + # If source_interface or vlan_id was not explicitly defined (e.g. when + # calling VLANIf('eth0.1').remove() we can define source_interface and + # vlan_id here, as it's quiet obvious that it would be eth0 in that case. + if 'source_interface' not in self.config: + self.config['source_interface'] = '.'.join(self.ifname.split('.')[:-1]) + if 'vlan_id' not in self.config: + self.config['vlan_id'] = self.ifname.split('.')[-1] + cmd = 'ip link add link {source_interface} name {ifname} type vlan id {vlan_id}' if 'protocol' in self.config: cmd += ' protocol {protocol}' diff --git a/python/vyos/ifconfig/l2tpv3.py b/python/vyos/ifconfig/l2tpv3.py index fcd1fbf81..85a89ef8b 100644 --- a/python/vyos/ifconfig/l2tpv3.py +++ b/python/vyos/ifconfig/l2tpv3.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2019-2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -15,7 +15,8 @@ from time import sleep from time import time -from vyos.util import run + +from vyos.utils.process import run from vyos.ifconfig.interface import Interface def wait_for_add_l2tpv3(timeout=10, sleep_interval=1, cmd=None): diff --git a/python/vyos/ifconfig/loopback.py b/python/vyos/ifconfig/loopback.py index b3babfadc..e1d041839 100644 --- a/python/vyos/ifconfig/loopback.py +++ b/python/vyos/ifconfig/loopback.py @@ -46,7 +46,7 @@ class LoopbackIf(Interface): if addr in self._persistent_addresses: # Do not allow deletion of the default loopback addresses as # this will cause weird system behavior like snmp/ssh no longer - # operating as expected, see https://phabricator.vyos.net/T2034. + # operating as expected, see https://vyos.dev/T2034. continue self.del_addr(addr) diff --git a/python/vyos/ifconfig/macsec.py b/python/vyos/ifconfig/macsec.py index 1a78d18d8..9329c5ee7 100644 --- a/python/vyos/ifconfig/macsec.py +++ b/python/vyos/ifconfig/macsec.py @@ -1,4 +1,4 @@ -# Copyright 2020-2021 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2020-2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -41,10 +41,30 @@ class MACsecIf(Interface): Create MACsec interface in OS kernel. Interface is administrative down by default. """ + # create tunnel interface cmd = 'ip link add link {source_interface} {ifname} type {type}'.format(**self.config) cmd += f' cipher {self.config["security"]["cipher"]}' self._cmd(cmd) + # Check if using static keys + if 'static' in self.config["security"]: + # Set static TX key + cmd = 'ip macsec add {ifname} tx sa 0 pn 1 on key 00'.format(**self.config) + cmd += f' {self.config["security"]["static"]["key"]}' + self._cmd(cmd) + + for peer, peer_config in self.config["security"]["static"]["peer"].items(): + if 'disable' in peer_config: + continue + + # Create the address + cmd = 'ip macsec add {ifname} rx port 1 address'.format(**self.config) + cmd += f' {peer_config["mac"]}' + self._cmd(cmd) + # Add the rx-key to the address + cmd += f' sa 0 pn 1 on key 01 {peer_config["key"]}' + self._cmd(cmd) + # interface is always A/D down. It needs to be enabled explicitly self.set_admin_state('down') diff --git a/python/vyos/ifconfig/operational.py b/python/vyos/ifconfig/operational.py index 33e8614f0..dc2742123 100644 --- a/python/vyos/ifconfig/operational.py +++ b/python/vyos/ifconfig/operational.py @@ -143,15 +143,17 @@ class Operational(Control): except IOError: return no_stats - def clear_counters(self, counters=None): - clear = self._stats_all if counters is None else [] - stats = self.load_counters() + def clear_counters(self): + stats = self.get_stats() for counter, value in stats.items(): - stats[counter] = 0 if counter in clear else value + stats[counter] = value self.save_counters(stats) def reset_counters(self): - os.remove(self.cachefile(self.ifname)) + try: + os.remove(self.cachefile(self.ifname)) + except FileNotFoundError: + pass def get_stats(self): """ return a dict() with the value for each interface counter """ diff --git a/python/vyos/ifconfig/pppoe.py b/python/vyos/ifconfig/pppoe.py index 437fe0cae..febf1452d 100644 --- a/python/vyos/ifconfig/pppoe.py +++ b/python/vyos/ifconfig/pppoe.py @@ -14,8 +14,8 @@ # License along with this library. If not, see <http://www.gnu.org/licenses/>. from vyos.ifconfig.interface import Interface -from vyos.validate import assert_range -from vyos.util import get_interface_config +from vyos.utils.assertion import assert_range +from vyos.utils.network import get_interface_config @Interface.register class PPPoEIf(Interface): diff --git a/python/vyos/ifconfig/tunnel.py b/python/vyos/ifconfig/tunnel.py index 5258a2cb1..9ba7b31a6 100644 --- a/python/vyos/ifconfig/tunnel.py +++ b/python/vyos/ifconfig/tunnel.py @@ -17,8 +17,8 @@ # https://community.hetzner.com/tutorials/linux-setup-gre-tunnel from vyos.ifconfig.interface import Interface -from vyos.util import dict_search -from vyos.validate import assert_list +from vyos.utils.dict import dict_search +from vyos.utils.assertion import assert_list def enable_to_on(value): if value == 'enable': @@ -83,11 +83,6 @@ class TunnelIf(Interface): 'convert': enable_to_on, 'shellcmd': 'ip link set dev {ifname} multicast {value}', }, - 'allmulticast': { - 'validate': lambda v: assert_list(v, ['enable', 'disable']), - 'convert': enable_to_on, - 'shellcmd': 'ip link set dev {ifname} allmulticast {value}', - }, } } @@ -162,6 +157,10 @@ class TunnelIf(Interface): """ Get a synthetic MAC address. """ return self.get_mac_synthetic() + def set_multicast(self, enable): + """ Change the MULTICAST flag on the device """ + return self.set_interface('multicast', enable) + def update(self, config): """ General helper function which works on a dictionary retrived by get_config_dict(). It's main intention is to consolidate the scattered @@ -170,5 +169,10 @@ class TunnelIf(Interface): # Adjust iproute2 tunnel parameters if necessary self._change_options() + # IP Multicast + tmp = dict_search('enable_multicast', config) + value = 'enable' if (tmp != None) else 'disable' + self.set_multicast(value) + # call base class first super().update(config) diff --git a/python/vyos/ifconfig/vrrp.py b/python/vyos/ifconfig/vrrp.py index 47aaadecd..fde903a53 100644 --- a/python/vyos/ifconfig/vrrp.py +++ b/python/vyos/ifconfig/vrrp.py @@ -1,4 +1,4 @@ -# Copyright 2019 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2019-2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -21,8 +21,11 @@ from time import time from time import sleep from tabulate import tabulate -from vyos import util from vyos.configquery import ConfigTreeQuery +from vyos.utils.convert import seconds_to_human +from vyos.utils.file import read_file +from vyos.utils.file import wait_for_file_write_complete +from vyos.utils.process import process_running class VRRPError(Exception): pass @@ -84,21 +87,21 @@ class VRRP(object): def is_running(cls): if not os.path.exists(cls.location['pid']): return False - return util.process_running(cls.location['pid']) + return process_running(cls.location['pid']) @classmethod def collect(cls, what): fname = cls.location[what] try: # send signal to generate the configuration file - pid = util.read_file(cls.location['pid']) - util.wait_for_file_write_complete(fname, + pid = read_file(cls.location['pid']) + wait_for_file_write_complete(fname, pre_hook=(lambda: os.kill(int(pid), cls._signal[what])), timeout=30) - return util.read_file(fname) + return read_file(fname) except OSError: - # raised by vyos.util.read_file + # raised by vyos.utils.file.read_file raise VRRPNoData("VRRP data is not available (wait time exceeded)") except FileNotFoundError: raise VRRPNoData("VRRP data is not available (process not running or no active groups)") @@ -145,7 +148,7 @@ class VRRP(object): priority = data['effective_priority'] since = int(time() - float(data['last_transition'])) - last = util.seconds_to_human(since) + last = seconds_to_human(since) groups.append([name, intf, vrid, state, priority, last]) diff --git a/python/vyos/ifconfig/vti.py b/python/vyos/ifconfig/vti.py index dc99d365a..9ebbeb9ed 100644 --- a/python/vyos/ifconfig/vti.py +++ b/python/vyos/ifconfig/vti.py @@ -14,7 +14,7 @@ # License along with this library. If not, see <http://www.gnu.org/licenses/>. from vyos.ifconfig.interface import Interface -from vyos.util import dict_search +from vyos.utils.dict import dict_search @Interface.register class VTIIf(Interface): diff --git a/python/vyos/ifconfig/vxlan.py b/python/vyos/ifconfig/vxlan.py index 5baff10a9..6a9911588 100644 --- a/python/vyos/ifconfig/vxlan.py +++ b/python/vyos/ifconfig/vxlan.py @@ -15,7 +15,7 @@ from vyos import ConfigError from vyos.ifconfig import Interface -from vyos.util import dict_search +from vyos.utils.dict import dict_search @Interface.register class VXLANIf(Interface): diff --git a/python/vyos/ifconfig/wireguard.py b/python/vyos/ifconfig/wireguard.py index fe5e9c519..4aac103ec 100644 --- a/python/vyos/ifconfig/wireguard.py +++ b/python/vyos/ifconfig/wireguard.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2019-2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -25,6 +25,7 @@ from hurry.filesize import alternative from vyos.ifconfig import Interface from vyos.ifconfig import Operational from vyos.template import is_ipv6 +from vyos.base import Warning class WireGuardOperational(Operational): def _dump(self): @@ -184,7 +185,6 @@ class WireGuardIf(Interface): base_cmd += f' private-key {tmp_file.name}' base_cmd = base_cmd.format(**config) - if 'peer' in config: for peer, peer_config in config['peer'].items(): # T4702: No need to configure this peer when it was explicitly diff --git a/python/vyos/initialsetup.py b/python/vyos/initialsetup.py index 574e7892d..3b280dc6b 100644 --- a/python/vyos/initialsetup.py +++ b/python/vyos/initialsetup.py @@ -1,7 +1,7 @@ # initialsetup -- functions for setting common values in config file, # for use in installation and first boot scripts # -# Copyright (C) 2018 VyOS maintainers and contributors +# Copyright (C) 2018-2023 VyOS maintainers and contributors # # This library is free software; you can redistribute it and/or modify it under the terms of # the GNU Lesser General Public License as published by the Free Software Foundation; @@ -12,10 +12,12 @@ # See the GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License along with this library; -# if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +# if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA import vyos.configtree -import vyos.authutils + +from vyos.utils.auth import make_password_hash +from vyos.utils.auth import split_ssh_public_key def set_interface_address(config, intf, addr, intf_type="ethernet"): config.set(["interfaces", intf_type, intf, "address"], value=addr) @@ -35,8 +37,8 @@ def set_default_gateway(config, gateway): def set_user_password(config, user, password): # Make a password hash - hash = vyos.authutils.make_password_hash(password) - + hash = make_password_hash(password) + config.set(["system", "login", "user", user, "authentication", "encrypted-password"], value=hash) config.set(["system", "login", "user", user, "authentication", "plaintext-password"], value="") @@ -48,7 +50,7 @@ def set_user_level(config, user, level): config.set(["system", "login", "user", user, "level"], value=level) def set_user_ssh_key(config, user, key_string): - key = vyos.authutils.split_ssh_public_key(key_string, defaultname=user) + key = split_ssh_public_key(key_string, defaultname=user) config.set(["system", "login", "user", user, "authentication", "public-keys", key["name"], "key"], value=key["data"]) config.set(["system", "login", "user", user, "authentication", "public-keys", key["name"], "type"], value=key["type"]) diff --git a/python/vyos/ipsec.py b/python/vyos/ipsec.py new file mode 100644 index 000000000..4603aab22 --- /dev/null +++ b/python/vyos/ipsec.py @@ -0,0 +1,183 @@ +# Copyright 2020-2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +#Package to communicate with Strongswan VICI + +class ViciInitiateError(Exception): + """ + VICI can't initiate a session. + """ + pass +class ViciCommandError(Exception): + """ + VICI can't execute a command by any reason. + """ + pass + +def get_vici_sas(): + from vici import Session as vici_session + + try: + session = vici_session() + except Exception: + raise ViciInitiateError("IPsec not initialized") + try: + sas = list(session.list_sas()) + return sas + except Exception: + raise ViciCommandError(f'Failed to get SAs') + +def get_vici_connections(): + from vici import Session as vici_session + + try: + session = vici_session() + except Exception: + raise ViciInitiateError("IPsec not initialized") + try: + connections = list(session.list_conns()) + return connections + except Exception: + raise ViciCommandError(f'Failed to get connections') + +def get_vici_sas_by_name(ike_name: str, tunnel: str) -> list: + """ + Find sas by IKE_SA name and/or CHILD_SA name + and return list of OrdinaryDicts with SASs info + If tunnel is not None return value is list of OrdenaryDicts contained only + CHILD_SAs wich names equal tunnel value. + :param ike_name: IKE SA name + :type ike_name: str + :param tunnel: CHILD SA name + :type tunnel: str + :return: list of Ordinary Dicts with SASs + :rtype: list + """ + from vici import Session as vici_session + + try: + session = vici_session() + except Exception: + raise ViciInitiateError("IPsec not initialized") + vici_dict = {} + if ike_name: + vici_dict['ike'] = ike_name + if tunnel: + vici_dict['child'] = tunnel + try: + sas = list(session.list_sas(vici_dict)) + return sas + except Exception: + raise ViciCommandError(f'Failed to get SAs') + + +def terminate_vici_ikeid_list(ike_id_list: list) -> None: + """ + Terminate IKE SAs by their id that contained in the list + :param ike_id_list: list of IKE SA id + :type ike_id_list: list + """ + from vici import Session as vici_session + + try: + session = vici_session() + except Exception: + raise ViciInitiateError("IPsec not initialized") + try: + for ikeid in ike_id_list: + session_generator = session.terminate( + {'ike-id': ikeid, 'timeout': '-1'}) + # a dummy `for` loop is required because of requirements + # from vici. Without a full iteration on the output, the + # command to vici may not be executed completely + for _ in session_generator: + pass + except Exception: + raise ViciCommandError( + f'Failed to terminate SA for IKE ids {ike_id_list}') + + +def terminate_vici_by_name(ike_name: str, child_name: str) -> None: + """ + Terminate IKE SAs by name if CHILD SA name is None. + Terminate CHILD SAs by name if CHILD SA name is specified + :param ike_name: IKE SA name + :type ike_name: str + :param child_name: CHILD SA name + :type child_name: str + """ + from vici import Session as vici_session + + try: + session = vici_session() + except Exception: + raise ViciInitiateError("IPsec not initialized") + try: + vici_dict: dict= {} + if ike_name: + vici_dict['ike'] = ike_name + if child_name: + vici_dict['child'] = child_name + session_generator = session.terminate(vici_dict) + # a dummy `for` loop is required because of requirements + # from vici. Without a full iteration on the output, the + # command to vici may not be executed completely + for _ in session_generator: + pass + except Exception: + if child_name: + raise ViciCommandError( + f'Failed to terminate SA for IPSEC {child_name}') + else: + raise ViciCommandError( + f'Failed to terminate SA for IKE {ike_name}') + + +def vici_initiate(ike_sa_name: str, child_sa_name: str, src_addr: str, + dst_addr: str) -> bool: + """Initiate IKE SA connection with specific peer + + Args: + ike_sa_name (str): an IKE SA connection name + child_sa_name (str): a child SA profile name + src_addr (str): source address + dst_addr (str): remote address + + Returns: + bool: a result of initiation command + """ + from vici import Session as vici_session + + try: + session = vici_session() + except Exception: + raise ViciInitiateError("IPsec not initialized") + + try: + session_generator = session.initiate({ + 'ike': ike_sa_name, + 'child': child_sa_name, + 'timeout': '-1', + 'my-host': src_addr, + 'other-host': dst_addr + }) + # a dummy `for` loop is required because of requirements + # from vici. Without a full iteration on the output, the + # command to vici may not be executed completely + for _ in session_generator: + pass + return True + except Exception: + raise ViciCommandError(f'Failed to initiate SA for IKE {ike_sa_name}')
\ No newline at end of file diff --git a/python/vyos/migrator.py b/python/vyos/migrator.py index 87c74e1ea..872682bc0 100644 --- a/python/vyos/migrator.py +++ b/python/vyos/migrator.py @@ -20,7 +20,7 @@ import logging import vyos.defaults import vyos.component_version as component_version -from vyos.util import cmd +from vyos.utils.process import cmd log_file = os.path.join(vyos.defaults.directories['config'], 'vyos-migrate.log') diff --git a/python/vyos/nat.py b/python/vyos/nat.py index 8a311045a..9cbc2b96e 100644 --- a/python/vyos/nat.py +++ b/python/vyos/nat.py @@ -15,7 +15,7 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. from vyos.template import is_ip_network -from vyos.util import dict_search_args +from vyos.utils.dict import dict_search_args from vyos.template import bracketize_ipv6 @@ -47,32 +47,42 @@ def parse_nat_rule(rule_conf, rule_id, nat_type, ipv6=False): protocol = '{ tcp, udp }' output.append(f'meta l4proto {protocol}') + if 'packet_type' in rule_conf: + output.append(f'pkttype ' + rule_conf['packet_type']) + if 'exclude' in rule_conf: translation_str = 'return' log_suffix = '-EXCL' elif 'translation' in rule_conf: - translation_prefix = nat_type[:1] - translation_output = [f'{translation_prefix}nat'] addr = dict_search_args(rule_conf, 'translation', 'address') port = dict_search_args(rule_conf, 'translation', 'port') + if 'redirect' in rule_conf['translation']: + translation_output = [f'redirect'] + redirect_port = dict_search_args(rule_conf, 'translation', 'redirect', 'port') + if redirect_port: + translation_output.append(f'to {redirect_port}') + else: - if addr and is_ip_network(addr): - if not ipv6: - map_addr = dict_search_args(rule_conf, nat_type, 'address') - translation_output.append(f'{ip_prefix} prefix to {ip_prefix} {translation_prefix}addr map {{ {map_addr} : {addr} }}') - ignore_type_addr = True + translation_prefix = nat_type[:1] + translation_output = [f'{translation_prefix}nat'] + + if addr and is_ip_network(addr): + if not ipv6: + map_addr = dict_search_args(rule_conf, nat_type, 'address') + translation_output.append(f'{ip_prefix} prefix to {ip_prefix} {translation_prefix}addr map {{ {map_addr} : {addr} }}') + ignore_type_addr = True + else: + translation_output.append(f'prefix to {addr}') + elif addr == 'masquerade': + if port: + addr = f'{addr} to ' + translation_output = [addr] + log_suffix = '-MASQ' else: - translation_output.append(f'prefix to {addr}') - elif addr == 'masquerade': - if port: - addr = f'{addr} to ' - translation_output = [addr] - log_suffix = '-MASQ' - else: - translation_output.append('to') - if addr: - addr = bracketize_ipv6(addr) - translation_output.append(addr) + translation_output.append('to') + if addr: + addr = bracketize_ipv6(addr) + translation_output.append(addr) options = [] addr_mapping = dict_search_args(rule_conf, 'translation', 'options', 'address_mapping') @@ -87,6 +97,39 @@ def parse_nat_rule(rule_conf, rule_id, nat_type, ipv6=False): if options: translation_str += f' {",".join(options)}' + if not ipv6 and 'backend' in rule_conf['load_balance']: + hash_input_items = [] + current_prob = 0 + nat_map = [] + + for trans_addr, addr in rule_conf['load_balance']['backend'].items(): + item_prob = int(addr['weight']) + upper_limit = current_prob + item_prob - 1 + hash_val = str(current_prob) + '-' + str(upper_limit) + element = hash_val + " : " + trans_addr + nat_map.append(element) + current_prob = current_prob + item_prob + + elements = ' , '.join(nat_map) + + if 'hash' in rule_conf['load_balance'] and 'random' in rule_conf['load_balance']['hash']: + translation_str += ' numgen random mod 100 map ' + '{ ' + f'{elements}' + ' }' + else: + for input_param in rule_conf['load_balance']['hash']: + if input_param == 'source-address': + param = 'ip saddr' + elif input_param == 'destination-address': + param = 'ip daddr' + elif input_param == 'source-port': + prot = rule_conf['protocol'] + param = f'{prot} sport' + elif input_param == 'destination-port': + prot = rule_conf['protocol'] + param = f'{prot} dport' + hash_input_items.append(param) + hash_input = ' . '.join(hash_input_items) + translation_str += f' jhash ' + f'{hash_input}' + ' mod 100 map ' + '{ ' + f'{elements}' + ' }' + for target in ['source', 'destination']: if target not in rule_conf: continue diff --git a/python/vyos/opmode.py b/python/vyos/opmode.py index 5ff768859..230a85541 100644 --- a/python/vyos/opmode.py +++ b/python/vyos/opmode.py @@ -1,4 +1,4 @@ -# Copyright 2022 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2022-2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -22,6 +22,10 @@ from humps import decamelize class Error(Exception): """ Any error that makes requested operation impossible to complete for reasons unrelated to the user input or script logic. + + This is the base class, scripts should not use it directly + and should raise more specific errors instead, + whenever possible. """ pass @@ -45,6 +49,17 @@ class PermissionDenied(Error): """ pass +class InsufficientResources(Error): + """ Requested operation and its arguments are valid but the system + does not have enough resources (such as drive space or memory) + to complete it. + """ + pass + +class UnsupportedOperation(Error): + """ Requested operation is technically valid but is not implemented yet. """ + pass + class IncorrectValue(Error): """ Requested operation is valid, but an argument provided has an incorrect value, preventing successful completion. @@ -66,13 +81,13 @@ class InternalError(Error): def _is_op_mode_function_name(name): - if re.match(r"^(show|clear|reset|restart)", name): + if re.match(r"^(show|clear|reset|restart|add|delete|generate|set)", name): return True else: return False -def _is_show(name): - if re.match(r"^show", name): +def _capture_output(name): + if re.match(r"^(show|generate)", name): return True else: return False @@ -113,6 +128,25 @@ def _get_arg_type(t): else: return t +def _is_literal_type(t): + if _is_optional_type(t): + t = _get_arg_type(t) + + if typing.get_origin(t) == typing.Literal: + return True + + return False + +def _get_literal_values(t): + """ Returns the tuple of allowed values for a Literal type + """ + if not _is_literal_type(t): + return tuple() + if _is_optional_type(t): + t = _get_arg_type(t) + + return typing.get_args(t) + def _normalize_field_name(name): # Convert the name to string if it is not # (in some cases they may be numbers) @@ -175,13 +209,30 @@ def run(module): for opt in type_hints: th = type_hints[opt] + # Function argument names use underscores as separators + # but command-line options should use hyphens + # Without this, we'd get options like "--foo_bar" + opt = re.sub(r'_', '-', opt) + if _get_arg_type(th) == bool: subparser.add_argument(f"--{opt}", action='store_true') else: if _is_optional_type(th): - subparser.add_argument(f"--{opt}", type=_get_arg_type(th), default=None) + if _is_literal_type(th): + subparser.add_argument(f"--{opt}", + choices=list(_get_literal_values(th)), + default=None) + else: + subparser.add_argument(f"--{opt}", + type=_get_arg_type(th), default=None) else: - subparser.add_argument(f"--{opt}", type=_get_arg_type(th), required=True) + if _is_literal_type(th): + subparser.add_argument(f"--{opt}", + choices=list(_get_literal_values(th)), + required=True) + else: + subparser.add_argument(f"--{opt}", + type=_get_arg_type(th), required=True) # Get options as a dict rather than a namespace, # so that we can modify it and pack for passing to functions @@ -199,20 +250,23 @@ def run(module): # it would cause an extra argument error when we pass the dict to a function del args["subcommand"] - # Show commands must always get the "raw" argument, - # but other commands (clear/reset/restart) should not, + # Show and generate commands must always get the "raw" argument, + # but other commands (clear/reset/restart/add/delete) should not, # because they produce no output and it makes no sense for them. - if ("raw" not in args) and _is_show(function_name): + if ("raw" not in args) and _capture_output(function_name): args["raw"] = False - if re.match(r"^show", function_name): - # Show commands are slightly special: + if _capture_output(function_name): + # Show and generate commands are slightly special: # they may return human-formatted output # or a raw dict that we need to serialize in JSON for printing res = func(**args) if not args["raw"]: return res else: + if not isinstance(res, dict) and not isinstance(res, list): + raise InternalError(f"Bare literal is not an acceptable raw output, must be a list or an object.\ + The output was:{res}") res = decamelize(res) res = _normalize_field_names(res) from json import dumps diff --git a/python/vyos/pki.py b/python/vyos/pki.py index cd15e3878..792e24b76 100644 --- a/python/vyos/pki.py +++ b/python/vyos/pki.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -# Copyright (C) 2021 VyOS maintainers and contributors +# Copyright (C) 2023 VyOS maintainers and contributors # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 2 or later as @@ -63,6 +63,18 @@ private_format_map = { 'OpenSSH': serialization.PrivateFormat.OpenSSH } +hash_map = { + 'sha256': hashes.SHA256, + 'sha384': hashes.SHA384, + 'sha512': hashes.SHA512, +} + +def get_certificate_fingerprint(cert, hash): + hash_algorithm = hash_map[hash]() + fp = cert.fingerprint(hash_algorithm) + + return fp.hex(':').upper() + def encode_certificate(cert): return cert.public_bytes(encoding=serialization.Encoding.PEM).decode('utf-8') diff --git a/python/vyos/qos/__init__.py b/python/vyos/qos/__init__.py new file mode 100644 index 000000000..a2980ccde --- /dev/null +++ b/python/vyos/qos/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2022 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +from vyos.qos.base import QoSBase +from vyos.qos.cake import CAKE +from vyos.qos.droptail import DropTail +from vyos.qos.fairqueue import FairQueue +from vyos.qos.fqcodel import FQCodel +from vyos.qos.limiter import Limiter +from vyos.qos.netem import NetEm +from vyos.qos.priority import Priority +from vyos.qos.randomdetect import RandomDetect +from vyos.qos.ratelimiter import RateLimiter +from vyos.qos.roundrobin import RoundRobin +from vyos.qos.trafficshaper import TrafficShaper +from vyos.qos.trafficshaper import TrafficShaperHFSC diff --git a/python/vyos/qos/base.py b/python/vyos/qos/base.py new file mode 100644 index 000000000..d8bbfe970 --- /dev/null +++ b/python/vyos/qos/base.py @@ -0,0 +1,393 @@ +# Copyright 2022-2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +import os + +from vyos.base import Warning +from vyos.utils.process import cmd +from vyos.utils.dict import dict_search +from vyos.utils.file import read_file + +from vyos.utils.network import get_protocol_by_name + + +class QoSBase: + _debug = False + _direction = ['egress'] + _parent = 0xffff + _dsfields = { + "default": 0x0, + "lowdelay": 0x10, + "throughput": 0x08, + "reliability": 0x04, + "mincost": 0x02, + "priority": 0x20, + "immediate": 0x40, + "flash": 0x60, + "flash-override": 0x80, + "critical": 0x0A, + "internet": 0xC0, + "network": 0xE0, + "AF11": 0x28, + "AF12": 0x30, + "AF13": 0x38, + "AF21": 0x48, + "AF22": 0x50, + "AF23": 0x58, + "AF31": 0x68, + "AF32": 0x70, + "AF33": 0x78, + "AF41": 0x88, + "AF42": 0x90, + "AF43": 0x98, + "CS1": 0x20, + "CS2": 0x40, + "CS3": 0x60, + "CS4": 0x80, + "CS5": 0xA0, + "CS6": 0xC0, + "CS7": 0xE0, + "EF": 0xB8 + } + qostype = None + + def __init__(self, interface): + if os.path.exists('/tmp/vyos.qos.debug'): + self._debug = True + self._interface = interface + + def _cmd(self, command): + if self._debug: + print(f'DEBUG/QoS: {command}') + return cmd(command) + + def get_direction(self) -> list: + return self._direction + + def _get_class_max_id(self, config) -> int: + if 'class' in config: + tmp = list(config['class'].keys()) + tmp.sort(key=lambda ii: int(ii)) + return tmp[-1] + return None + + def _get_dsfield(self, value): + if value in self._dsfields: + return self._dsfields[value] + else: + return value + + def _build_base_qdisc(self, config : dict, cls_id : int): + """ + Add/replace qdisc for every class (also default is a class). This is + a genetic method which need an implementation "per" queue-type. + + This matches the old mapping as defined in Perl here: + https://github.com/vyos/vyatta-cfg-qos/blob/equuleus/lib/Vyatta/Qos/ShaperClass.pm#L223-L229 + """ + queue_type = dict_search('queue_type', config) + default_tc = f'tc qdisc replace dev {self._interface} parent {self._parent}:{cls_id:x}' + + if queue_type == 'priority': + handle = 0x4000 + cls_id + default_tc += f' handle {handle:x}: prio' + self._cmd(default_tc) + + queue_limit = dict_search('queue_limit', config) + for ii in range(1, 4): + tmp = f'tc qdisc replace dev {self._interface} parent {handle:x}:{ii:x} pfifo' + if queue_limit: tmp += f' limit {queue_limit}' + self._cmd(tmp) + + elif queue_type == 'fair-queue': + default_tc += f' sfq' + + tmp = dict_search('queue_limit', config) + if tmp: default_tc += f' limit {tmp}' + + self._cmd(default_tc) + + elif queue_type == 'fq-codel': + default_tc += f' fq_codel' + tmp = dict_search('codel_quantum', config) + if tmp: default_tc += f' quantum {tmp}' + + tmp = dict_search('flows', config) + if tmp: default_tc += f' flows {tmp}' + + tmp = dict_search('interval', config) + if tmp: default_tc += f' interval {tmp}' + + tmp = dict_search('interval', config) + if tmp: default_tc += f' interval {tmp}' + + tmp = dict_search('queue_limit', config) + if tmp: default_tc += f' limit {tmp}' + + tmp = dict_search('target', config) + if tmp: default_tc += f' target {tmp}' + + default_tc += f' noecn' + + self._cmd(default_tc) + + elif queue_type == 'random-detect': + default_tc += f' red' + + self._cmd(default_tc) + + elif queue_type == 'drop-tail': + default_tc += f' pfifo' + + tmp = dict_search('queue_limit', config) + if tmp: default_tc += f' limit {tmp}' + + self._cmd(default_tc) + + def _rate_convert(self, rate) -> int: + rates = { + 'bit' : 1, + 'kbit' : 1000, + 'mbit' : 1000000, + 'gbit' : 1000000000, + 'tbit' : 1000000000000, + } + + if rate == 'auto' or rate.endswith('%'): + speed = 10 + # Not all interfaces have valid entries in the speed file. PPPoE + # interfaces have the appropriate speed file, but you can not read it: + # cat: /sys/class/net/pppoe7/speed: Invalid argument + try: + speed = read_file(f'/sys/class/net/{self._interface}/speed') + if not speed.isnumeric(): + Warning('Interface speed cannot be determined (assuming 10 Mbit/s)') + if rate.endswith('%'): + percent = rate.rstrip('%') + speed = int(speed) * int(percent) // 100 + except: + pass + + return int(speed) *1000000 # convert to MBit/s + + rate_numeric = int(''.join([n for n in rate if n.isdigit()])) + rate_scale = ''.join([n for n in rate if not n.isdigit()]) + + if int(rate_numeric) <= 0: + raise ValueError(f'{rate_numeric} is not a valid bandwidth <= 0') + + if rate_scale: + return int(rate_numeric * rates[rate_scale]) + else: + # No suffix implies Kbps just as Cisco IOS + return int(rate_numeric * 1000) + + def update(self, config, direction, priority=None): + """ method must be called from derived class after it has completed qdisc setup """ + if self._debug: + import pprint + pprint.pprint(config) + + if 'class' in config: + for cls, cls_config in config['class'].items(): + self._build_base_qdisc(cls_config, int(cls)) + + # every match criteria has it's tc instance + filter_cmd_base = f'tc filter add dev {self._interface} parent {self._parent:x}:' + + if priority: + filter_cmd_base += f' prio {cls}' + elif 'priority' in cls_config: + prio = cls_config['priority'] + filter_cmd_base += f' prio {prio}' + + filter_cmd_base += ' protocol all' + + if 'match' in cls_config: + for index, (match, match_config) in enumerate(cls_config['match'].items(), start=1): + filter_cmd = filter_cmd_base + if self.qostype == 'shaper' and 'prio ' not in filter_cmd: + filter_cmd += f' prio {index}' + if 'mark' in match_config: + mark = match_config['mark'] + filter_cmd += f' handle {mark} fw' + + for af in ['ip', 'ipv6']: + tc_af = af + if af == 'ipv6': + tc_af = 'ip6' + + if af in match_config: + filter_cmd += ' u32' + + tmp = dict_search(f'{af}.source.address', match_config) + if tmp: filter_cmd += f' match {tc_af} src {tmp}' + + tmp = dict_search(f'{af}.source.port', match_config) + if tmp: filter_cmd += f' match {tc_af} sport {tmp} 0xffff' + + tmp = dict_search(f'{af}.destination.address', match_config) + if tmp: filter_cmd += f' match {tc_af} dst {tmp}' + + tmp = dict_search(f'{af}.destination.port', match_config) + if tmp: filter_cmd += f' match {tc_af} dport {tmp} 0xffff' + + tmp = dict_search(f'{af}.protocol', match_config) + if tmp: + tmp = get_protocol_by_name(tmp) + filter_cmd += f' match {tc_af} protocol {tmp} 0xff' + + tmp = dict_search(f'{af}.dscp', match_config) + if tmp: + tmp = self._get_dsfield(tmp) + if af == 'ip': + filter_cmd += f' match {tc_af} dsfield {tmp} 0xff' + elif af == 'ipv6': + filter_cmd += f' match u16 {tmp} 0x0ff0 at 0' + + # Will match against total length of an IPv4 packet and + # payload length of an IPv6 packet. + # + # IPv4 : match u16 0x0000 ~MAXLEN at 2 + # IPv6 : match u16 0x0000 ~MAXLEN at 4 + tmp = dict_search(f'{af}.max_length', match_config) + if tmp: + # We need the 16 bit two's complement of the maximum + # packet length + tmp = hex(0xffff & ~int(tmp)) + + if af == 'ip': + filter_cmd += f' match u16 0x0000 {tmp} at 2' + elif af == 'ipv6': + filter_cmd += f' match u16 0x0000 {tmp} at 4' + + # We match against specific TCP flags - we assume the IPv4 + # header length is 20 bytes and assume the IPv6 packet is + # not using extension headers (hence a ip header length of 40 bytes) + # TCP Flags are set on byte 13 of the TCP header. + # IPv4 : match u8 X X at 33 + # IPv6 : match u8 X X at 53 + # with X = 0x02 for SYN and X = 0x10 for ACK + tmp = dict_search(f'{af}.tcp', match_config) + if tmp: + mask = 0 + if 'ack' in tmp: + mask |= 0x10 + if 'syn' in tmp: + mask |= 0x02 + mask = hex(mask) + + if af == 'ip': + filter_cmd += f' match u8 {mask} {mask} at 33' + elif af == 'ipv6': + filter_cmd += f' match u8 {mask} {mask} at 53' + + cls = int(cls) + filter_cmd += f' flowid {self._parent:x}:{cls:x}' + self._cmd(filter_cmd) + + if any(tmp in ['exceed', 'bandwidth', 'burst'] for tmp in cls_config): + filter_cmd += f' action police' + + if 'exceed' in cls_config: + action = cls_config['exceed'] + filter_cmd += f' conform-exceed {action}' + if 'not_exceed' in cls_config: + action = cls_config['not_exceed'] + filter_cmd += f'/{action}' + + if 'bandwidth' in cls_config: + rate = self._rate_convert(cls_config['bandwidth']) + filter_cmd += f' rate {rate}' + + if 'burst' in cls_config: + burst = cls_config['burst'] + filter_cmd += f' burst {burst}' + cls = int(cls) + filter_cmd += f' flowid {self._parent:x}:{cls:x}' + self._cmd(filter_cmd) + + else: + + filter_cmd += ' basic' + + cls = int(cls) + filter_cmd += f' flowid {self._parent:x}:{cls:x}' + self._cmd(filter_cmd) + + + # The police block allows limiting of the byte or packet rate of + # traffic matched by the filter it is attached to. + # https://man7.org/linux/man-pages/man8/tc-police.8.html + + # T5295: We do not handle rate via tc filter directly, + # but rather set the tc filter to direct traffic to the correct tc class flow. + # + # if any(tmp in ['exceed', 'bandwidth', 'burst'] for tmp in cls_config): + # filter_cmd += f' action police' + # + # if 'exceed' in cls_config: + # action = cls_config['exceed'] + # filter_cmd += f' conform-exceed {action}' + # if 'not_exceed' in cls_config: + # action = cls_config['not_exceed'] + # filter_cmd += f'/{action}' + # + # if 'bandwidth' in cls_config: + # rate = self._rate_convert(cls_config['bandwidth']) + # filter_cmd += f' rate {rate}' + # + # if 'burst' in cls_config: + # burst = cls_config['burst'] + # filter_cmd += f' burst {burst}' + + if 'default' in config: + default_cls_id = 1 + if 'class' in config: + class_id_max = self._get_class_max_id(config) + default_cls_id = int(class_id_max) +1 + self._build_base_qdisc(config['default'], default_cls_id) + + if self.qostype == 'limiter': + if 'default' in config: + filter_cmd = f'tc filter replace dev {self._interface} parent {self._parent:x}: ' + filter_cmd += 'prio 255 protocol all basic' + + # The police block allows limiting of the byte or packet rate of + # traffic matched by the filter it is attached to. + # https://man7.org/linux/man-pages/man8/tc-police.8.html + if any(tmp in ['exceed', 'bandwidth', 'burst'] for tmp in + config['default']): + filter_cmd += f' action police' + + if 'exceed' in config['default']: + action = config['default']['exceed'] + filter_cmd += f' conform-exceed {action}' + if 'not_exceed' in config['default']: + action = config['default']['not_exceed'] + filter_cmd += f'/{action}' + + if 'bandwidth' in config['default']: + rate = self._rate_convert(config['default']['bandwidth']) + filter_cmd += f' rate {rate}' + + if 'burst' in config['default']: + burst = config['default']['burst'] + filter_cmd += f' burst {burst}' + + if 'class' in config: + filter_cmd += f' flowid {self._parent:x}:{default_cls_id:x}' + + self._cmd(filter_cmd) diff --git a/python/vyos/qos/cake.py b/python/vyos/qos/cake.py new file mode 100644 index 000000000..a89b1de1e --- /dev/null +++ b/python/vyos/qos/cake.py @@ -0,0 +1,55 @@ +# Copyright 2022 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +from vyos.qos.base import QoSBase + +class CAKE(QoSBase): + _direction = ['egress'] + + # https://man7.org/linux/man-pages/man8/tc-cake.8.html + def update(self, config, direction): + tmp = f'tc qdisc add dev {self._interface} root handle 1: cake {direction}' + if 'bandwidth' in config: + bandwidth = self._rate_convert(config['bandwidth']) + tmp += f' bandwidth {bandwidth}' + + if 'rtt' in config: + rtt = config['rtt'] + tmp += f' rtt {rtt}ms' + + if 'flow_isolation' in config: + if 'blind' in config['flow_isolation']: + tmp += f' flowblind' + if 'dst_host' in config['flow_isolation']: + tmp += f' dsthost' + if 'dual_dst_host' in config['flow_isolation']: + tmp += f' dual-dsthost' + if 'dual_src_host' in config['flow_isolation']: + tmp += f' dual-srchost' + if 'flow' in config['flow_isolation']: + tmp += f' flows' + if 'host' in config['flow_isolation']: + tmp += f' hosts' + if 'nat' in config['flow_isolation']: + tmp += f' nat' + if 'src_host' in config['flow_isolation']: + tmp += f' srchost ' + else: + tmp += f' nonat' + + self._cmd(tmp) + + # call base class + super().update(config, direction) diff --git a/python/vyos/qos/droptail.py b/python/vyos/qos/droptail.py new file mode 100644 index 000000000..427d43d19 --- /dev/null +++ b/python/vyos/qos/droptail.py @@ -0,0 +1,28 @@ +# Copyright 2022 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +from vyos.qos.base import QoSBase + +class DropTail(QoSBase): + # https://man7.org/linux/man-pages/man8/tc-pfifo.8.html + def update(self, config, direction): + tmp = f'tc qdisc add dev {self._interface} root pfifo' + if 'queue_limit' in config: + limit = config["queue_limit"] + tmp += f' limit {limit}' + self._cmd(tmp) + + # call base class + super().update(config, direction) diff --git a/python/vyos/qos/fairqueue.py b/python/vyos/qos/fairqueue.py new file mode 100644 index 000000000..f41d098fb --- /dev/null +++ b/python/vyos/qos/fairqueue.py @@ -0,0 +1,31 @@ +# Copyright 2022 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +from vyos.qos.base import QoSBase + +class FairQueue(QoSBase): + # https://man7.org/linux/man-pages/man8/tc-sfq.8.html + def update(self, config, direction): + tmp = f'tc qdisc add dev {self._interface} root sfq' + + if 'hash_interval' in config: + tmp += f' perturb {config["hash_interval"]}' + if 'queue_limit' in config: + tmp += f' limit {config["queue_limit"]}' + + self._cmd(tmp) + + # call base class + super().update(config, direction) diff --git a/python/vyos/qos/fqcodel.py b/python/vyos/qos/fqcodel.py new file mode 100644 index 000000000..cd2340aa2 --- /dev/null +++ b/python/vyos/qos/fqcodel.py @@ -0,0 +1,40 @@ +# Copyright 2022 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +from vyos.qos.base import QoSBase + +class FQCodel(QoSBase): + # https://man7.org/linux/man-pages/man8/tc-fq_codel.8.html + def update(self, config, direction): + tmp = f'tc qdisc add dev {self._interface} root fq_codel' + + if 'codel_quantum' in config: + tmp += f' quantum {config["codel_quantum"]}' + if 'flows' in config: + tmp += f' flows {config["flows"]}' + if 'interval' in config: + interval = int(config['interval']) * 1000 + tmp += f' interval {interval}' + if 'queue_limit' in config: + tmp += f' limit {config["queue_limit"]}' + if 'target' in config: + target = int(config['target']) * 1000 + tmp += f' target {target}' + + tmp += f' noecn' + self._cmd(tmp) + + # call base class + super().update(config, direction) diff --git a/python/vyos/qos/limiter.py b/python/vyos/qos/limiter.py new file mode 100644 index 000000000..3f5c11112 --- /dev/null +++ b/python/vyos/qos/limiter.py @@ -0,0 +1,28 @@ +# Copyright 2022 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +from vyos.qos.base import QoSBase + +class Limiter(QoSBase): + _direction = ['ingress'] + qostype = 'limiter' + + def update(self, config, direction): + tmp = f'tc qdisc add dev {self._interface} handle {self._parent:x}: {direction}' + self._cmd(tmp) + + # base class must be called last + super().update(config, direction) + diff --git a/python/vyos/qos/netem.py b/python/vyos/qos/netem.py new file mode 100644 index 000000000..8bdef300b --- /dev/null +++ b/python/vyos/qos/netem.py @@ -0,0 +1,53 @@ +# Copyright 2022 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +from vyos.qos.base import QoSBase + +class NetEm(QoSBase): + # https://man7.org/linux/man-pages/man8/tc-netem.8.html + def update(self, config, direction): + tmp = f'tc qdisc add dev {self._interface} root netem' + if 'bandwidth' in config: + rate = self._rate_convert(config["bandwidth"]) + tmp += f' rate {rate}' + + if 'queue_limit' in config: + limit = config["queue_limit"] + tmp += f' limit {limit}' + + if 'delay' in config: + delay = config["delay"] + tmp += f' delay {delay}ms' + + if 'loss' in config: + drop = config["loss"] + tmp += f' drop {drop}%' + + if 'corruption' in config: + corrupt = config["corruption"] + tmp += f' corrupt {corrupt}%' + + if 'reordering' in config: + reorder = config["reordering"] + tmp += f' reorder {reorder}%' + + if 'duplicate' in config: + duplicate = config["duplicate"] + tmp += f' duplicate {duplicate}%' + + self._cmd(tmp) + + # call base class + super().update(config, direction) diff --git a/python/vyos/qos/priority.py b/python/vyos/qos/priority.py new file mode 100644 index 000000000..8182400f9 --- /dev/null +++ b/python/vyos/qos/priority.py @@ -0,0 +1,41 @@ +# Copyright 2022 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +from vyos.qos.base import QoSBase +from vyos.utils.dict import dict_search + +class Priority(QoSBase): + _parent = 1 + + # https://man7.org/linux/man-pages/man8/tc-prio.8.html + def update(self, config, direction): + if 'class' in config: + class_id_max = self._get_class_max_id(config) + bands = int(class_id_max) +1 + + tmp = f'tc qdisc add dev {self._interface} root handle {self._parent:x}: prio bands {bands} priomap ' \ + f'{class_id_max} {class_id_max} {class_id_max} {class_id_max} ' \ + f'{class_id_max} {class_id_max} {class_id_max} {class_id_max} ' \ + f'{class_id_max} {class_id_max} {class_id_max} {class_id_max} ' \ + f'{class_id_max} {class_id_max} {class_id_max} {class_id_max} ' + self._cmd(tmp) + + for cls in config['class']: + cls = int(cls) + tmp = f'tc qdisc add dev {self._interface} parent {self._parent:x}:{cls:x} pfifo' + self._cmd(tmp) + + # base class must be called last + super().update(config, direction, priority=True) diff --git a/python/vyos/qos/randomdetect.py b/python/vyos/qos/randomdetect.py new file mode 100644 index 000000000..d7d84260f --- /dev/null +++ b/python/vyos/qos/randomdetect.py @@ -0,0 +1,54 @@ +# Copyright 2022 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +from vyos.qos.base import QoSBase + +class RandomDetect(QoSBase): + _parent = 1 + + # https://man7.org/linux/man-pages/man8/tc.8.html + def update(self, config, direction): + + tmp = f'tc qdisc add dev {self._interface} root handle {self._parent}:0 dsmark indices 8 set_tc_index' + self._cmd(tmp) + + tmp = f'tc filter add dev {self._interface} parent {self._parent}:0 protocol ip prio 1 tcindex mask 0xe0 shift 5' + self._cmd(tmp) + + # Generalized Random Early Detection + handle = self._parent +1 + tmp = f'tc qdisc add dev {self._interface} parent {self._parent}:0 handle {handle}:0 gred setup DPs 8 default 0 grio' + self._cmd(tmp) + + bandwidth = self._rate_convert(config['bandwidth']) + + # set VQ (virtual queue) parameters + for precedence, precedence_config in config['precedence'].items(): + precedence = int(precedence) + avg_pkt = int(precedence_config['average_packet']) + limit = int(precedence_config['queue_limit']) * avg_pkt + min_val = int(precedence_config['minimum_threshold']) * avg_pkt + max_val = int(precedence_config['maximum_threshold']) * avg_pkt + + tmp = f'tc qdisc change dev {self._interface} handle {handle}:0 gred limit {limit} min {min_val} max {max_val} avpkt {avg_pkt} ' + + burst = (2 * int(precedence_config['minimum_threshold']) + int(precedence_config['maximum_threshold'])) // 3 + probability = 1 / int(precedence_config['mark_probability']) + tmp += f'burst {burst} bandwidth {bandwidth} probability {probability} DP {precedence} prio {8 - precedence:x}' + + self._cmd(tmp) + + # call base class + super().update(config, direction) diff --git a/python/vyos/qos/ratelimiter.py b/python/vyos/qos/ratelimiter.py new file mode 100644 index 000000000..a4f80a1be --- /dev/null +++ b/python/vyos/qos/ratelimiter.py @@ -0,0 +1,37 @@ +# Copyright 2022 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +from vyos.qos.base import QoSBase + +class RateLimiter(QoSBase): + # https://man7.org/linux/man-pages/man8/tc-tbf.8.html + def update(self, config, direction): + # call base class + super().update(config, direction) + + tmp = f'tc qdisc add dev {self._interface} root tbf' + if 'bandwidth' in config: + rate = self._rate_convert(config['bandwidth']) + tmp += f' rate {rate}' + + if 'burst' in config: + burst = config['burst'] + tmp += f' burst {burst}' + + if 'latency' in config: + latency = config['latency'] + tmp += f' latency {latency}ms' + + self._cmd(tmp) diff --git a/python/vyos/qos/roundrobin.py b/python/vyos/qos/roundrobin.py new file mode 100644 index 000000000..80814ddfb --- /dev/null +++ b/python/vyos/qos/roundrobin.py @@ -0,0 +1,44 @@ +# Copyright 2022 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +from vyos.qos.base import QoSBase + +class RoundRobin(QoSBase): + _parent = 1 + + # https://man7.org/linux/man-pages/man8/tc-drr.8.html + def update(self, config, direction): + tmp = f'tc qdisc add dev {self._interface} root handle 1: drr' + self._cmd(tmp) + + if 'class' in config: + for cls in config['class']: + cls = int(cls) + tmp = f'tc class replace dev {self._interface} parent 1:1 classid 1:{cls:x} drr' + self._cmd(tmp) + + tmp = f'tc qdisc replace dev {self._interface} parent 1:{cls:x} pfifo' + self._cmd(tmp) + + if 'default' in config: + class_id_max = self._get_class_max_id(config) + default_cls_id = int(class_id_max) +1 + + # class ID via CLI is in range 1-4095, thus 1000 hex = 4096 + tmp = f'tc class replace dev {self._interface} parent 1:1 classid 1:{default_cls_id:x} drr' + self._cmd(tmp) + + # call base class + super().update(config, direction, priority=True) diff --git a/python/vyos/qos/trafficshaper.py b/python/vyos/qos/trafficshaper.py new file mode 100644 index 000000000..c63c7cf39 --- /dev/null +++ b/python/vyos/qos/trafficshaper.py @@ -0,0 +1,117 @@ +# Copyright 2022 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +from math import ceil +from vyos.qos.base import QoSBase + +# Kernel limits on quantum (bytes) +MAXQUANTUM = 200000 +MINQUANTUM = 1000 + +class TrafficShaper(QoSBase): + _parent = 1 + qostype = 'shaper' + + # https://man7.org/linux/man-pages/man8/tc-htb.8.html + def update(self, config, direction): + class_id_max = 0 + if 'class' in config: + tmp = list(config['class']) + tmp.sort() + class_id_max = tmp[-1] + + r2q = 10 + # bandwidth is a mandatory CLI node + speed = self._rate_convert(config['bandwidth']) + speed_bps = int(speed) // 8 + + # need a bigger r2q if going fast than 16 mbits/sec + if (speed_bps // r2q) >= MAXQUANTUM: # integer division + r2q = ceil(speed_bps // MAXQUANTUM) + else: + # if there is a slow class then may need smaller value + if 'class' in config: + min_speed = speed_bps + for cls, cls_options in config['class'].items(): + # find class with the lowest bandwidth used + if 'bandwidth' in cls_options: + bw_bps = int(self._rate_convert(cls_options['bandwidth'])) // 8 # bandwidth in bytes per second + if bw_bps < min_speed: + min_speed = bw_bps + + while (r2q > 1) and (min_speed // r2q) < MINQUANTUM: + tmp = r2q -1 + if (speed_bps // tmp) >= MAXQUANTUM: + break + r2q = tmp + + + default_minor_id = int(class_id_max) +1 + tmp = f'tc qdisc replace dev {self._interface} root handle {self._parent:x}: htb r2q {r2q} default {default_minor_id:x}' # default is in hex + self._cmd(tmp) + + tmp = f'tc class replace dev {self._interface} parent {self._parent:x}: classid {self._parent:x}:1 htb rate {speed}' + self._cmd(tmp) + + if 'class' in config: + for cls, cls_config in config['class'].items(): + # class id is used later on and passed as hex, thus this needs to be an int + cls = int(cls) + + # bandwidth is a mandatory CLI node + # T5296 if bandwidth 'auto' or 'xx%' get value from config shaper total "bandwidth" + # i.e from set shaper test bandwidth '300mbit' + # without it, it tries to get value from qos.base /sys/class/net/{self._interface}/speed + if cls_config['bandwidth'] == 'auto': + rate = self._rate_convert(config['bandwidth']) + elif cls_config['bandwidth'].endswith('%'): + percent = cls_config['bandwidth'].rstrip('%') + rate = self._rate_convert(config['bandwidth']) * int(percent) // 100 + else: + rate = self._rate_convert(cls_config['bandwidth']) + + burst = cls_config['burst'] + quantum = cls_config['codel_quantum'] + + tmp = f'tc class replace dev {self._interface} parent {self._parent:x}:1 classid {self._parent:x}:{cls:x} htb rate {rate} burst {burst} quantum {quantum}' + if 'priority' in cls_config: + priority = cls_config['priority'] + tmp += f' prio {priority}' + self._cmd(tmp) + + tmp = f'tc qdisc replace dev {self._interface} parent {self._parent:x}:{cls:x} sfq' + self._cmd(tmp) + + if 'default' in config: + rate = self._rate_convert(config['default']['bandwidth']) + burst = config['default']['burst'] + quantum = config['default']['codel_quantum'] + tmp = f'tc class replace dev {self._interface} parent {self._parent:x}:1 classid {self._parent:x}:{default_minor_id:x} htb rate {rate} burst {burst} quantum {quantum}' + if 'priority' in config['default']: + priority = config['default']['priority'] + tmp += f' prio {priority}' + self._cmd(tmp) + + tmp = f'tc qdisc replace dev {self._interface} parent {self._parent:x}:{default_minor_id:x} sfq' + self._cmd(tmp) + + # call base class + super().update(config, direction) + +class TrafficShaperHFSC(TrafficShaper): + def update(self, config, direction): + # call base class + super().update(config, direction) + diff --git a/python/vyos/remote.py b/python/vyos/remote.py index 66044fa52..cf731c881 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -25,22 +25,21 @@ import urllib.parse from ftplib import FTP from ftplib import FTP_TLS -from paramiko import SSHClient +from paramiko import SSHClient, SSHException from paramiko import MissingHostKeyPolicy from requests import Session from requests.adapters import HTTPAdapter from requests.packages.urllib3 import PoolManager -from vyos.util import ask_yes_no -from vyos.util import begin -from vyos.util import cmd -from vyos.util import make_incremental_progressbar -from vyos.util import make_progressbar -from vyos.util import print_error +from vyos.utils.io import ask_yes_no +from vyos.utils.io import make_incremental_progressbar +from vyos.utils.io import make_progressbar +from vyos.utils.io import print_error +from vyos.utils.misc import begin +from vyos.utils.process import cmd from vyos.version import get_version - CHUNK_SIZE = 8192 class InteractivePolicy(MissingHostKeyPolicy): @@ -51,7 +50,7 @@ class InteractivePolicy(MissingHostKeyPolicy): def missing_host_key(self, client, hostname, key): print_error(f"Host '{hostname}' not found in known hosts.") print_error('Fingerprint: ' + key.get_fingerprint().hex()) - if ask_yes_no('Do you wish to continue?'): + if sys.stdout.isatty() and ask_yes_no('Do you wish to continue?'): if client._host_keys_filename\ and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'): client._host_keys.add(hostname, key.get_name(), key) @@ -97,7 +96,13 @@ def check_storage(path, size): class FtpC: - def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0): + def __init__(self, + url, + progressbar=False, + check_space=False, + source_host='', + source_port=0, + timeout=10): self.secure = url.scheme == 'ftps' self.hostname = url.hostname self.path = url.path @@ -107,12 +112,15 @@ class FtpC: self.source = (source_host, source_port) self.progressbar = progressbar self.check_space = check_space + self.timeout = timeout def _establish(self): if self.secure: - return FTP_TLS(source_address=self.source, context=ssl.create_default_context()) + return FTP_TLS(source_address=self.source, + context=ssl.create_default_context(), + timeout=self.timeout) else: - return FTP(source_address=self.source) + return FTP(source_address=self.source, timeout=self.timeout) def download(self, location: str): # Open the file upfront before establishing connection. @@ -151,7 +159,13 @@ class FtpC: class SshC: known_hosts = os.path.expanduser('~/.ssh/known_hosts') - def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0): + def __init__(self, + url, + progressbar=False, + check_space=False, + source_host='', + source_port=0, + timeout=10.0): self.hostname = url.hostname self.path = url.path self.username = url.username or os.getenv('REMOTE_USERNAME') @@ -160,6 +174,7 @@ class SshC: self.source = (source_host, source_port) self.progressbar = progressbar self.check_space = check_space + self.timeout = timeout def _establish(self): ssh = SSHClient() @@ -170,7 +185,7 @@ class SshC: ssh.set_missing_host_key_policy(InteractivePolicy()) # `socket.create_connection()` automatically picks a NIC and an IPv4/IPv6 address family # for us on dual-stack systems. - sock = socket.create_connection((self.hostname, self.port), socket.getdefaulttimeout(), self.source) + sock = socket.create_connection((self.hostname, self.port), self.timeout, self.source) ssh.connect(self.hostname, self.port, self.username, self.password, sock=sock) return ssh @@ -199,13 +214,20 @@ class SshC: class HttpC: - def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0): + def __init__(self, + url, + progressbar=False, + check_space=False, + source_host='', + source_port=0, + timeout=10.0): self.urlstring = urllib.parse.urlunsplit(url) self.progressbar = progressbar self.check_space = check_space self.source_pair = (source_host, source_port) self.username = url.username or os.getenv('REMOTE_USERNAME') self.password = url.password or os.getenv('REMOTE_PASSWORD') + self.timeout = timeout def _establish(self): session = Session() @@ -221,8 +243,11 @@ class HttpC: # Not only would it potentially mess up with the progress bar but # `shutil.copyfileobj(request.raw, file)` does not handle automatic decoding. s.headers.update({'Accept-Encoding': 'identity'}) - with s.head(self.urlstring, allow_redirects=True) as r: + with s.head(self.urlstring, + allow_redirects=True, + timeout=self.timeout) as r: # Abort early if the destination is inaccessible. + print('pre-3') r.raise_for_status() # If the request got redirected, keep the last URL we ended up with. final_urlstring = r.url @@ -236,7 +261,8 @@ class HttpC: size = None if self.check_space: check_storage(location, size) - with s.get(final_urlstring, stream=True) as r, open(location, 'wb') as f: + with s.get(final_urlstring, stream=True, + timeout=self.timeout) as r, open(location, 'wb') as f: if self.progressbar and size: progress = make_incremental_progressbar(CHUNK_SIZE / size) next(progress) @@ -250,7 +276,10 @@ class HttpC: def upload(self, location: str): # Does not yet support progressbars. with self._establish() as s, open(location, 'rb') as f: - s.post(self.urlstring, data=f, allow_redirects=True) + s.post(self.urlstring, + data=f, + allow_redirects=True, + timeout=self.timeout) class TftpC: @@ -259,10 +288,16 @@ class TftpC: # 2. Since there's no concept authentication, we don't need to deal with keys/passwords. # 3. It would be a waste to import, audit and maintain a third-party library for TFTP. # 4. I'd rather not implement the entire protocol here, no matter how simple it is. - def __init__(self, url, progressbar=False, check_space=False, source_host=None, source_port=0): + def __init__(self, + url, + progressbar=False, + check_space=False, + source_host=None, + source_port=0, + timeout=10): source_option = f'--interface {source_host} --local-port {source_port}' if source_host else '' progress_flag = '--progress-bar' if progressbar else '-s' - self.command = f'curl {source_option} {progress_flag}' + self.command = f'curl {source_option} {progress_flag} --connect-timeout {timeout}' self.urlstring = urllib.parse.urlunsplit(url) def download(self, location: str): @@ -287,10 +322,16 @@ def urlc(urlstring, *args, **kwargs): raise ValueError(f'Unsupported URL scheme: "{url.scheme}"') def download(local_path, urlstring, *args, **kwargs): - urlc(urlstring, *args, **kwargs).download(local_path) + try: + urlc(urlstring, *args, **kwargs).download(local_path) + except Exception as err: + print_error(f'Unable to download "{urlstring}": {err}') def upload(local_path, urlstring, *args, **kwargs): - urlc(urlstring, *args, **kwargs).upload(local_path) + try: + urlc(urlstring, *args, **kwargs).upload(local_path) + except Exception as err: + print_error(f'Unable to upload "{urlstring}": {err}') def get_remote_config(urlstring, source_host='', source_port=0): """ diff --git a/python/vyos/template.py b/python/vyos/template.py index 2a4135f9e..c1b57b883 100644 --- a/python/vyos/template.py +++ b/python/vyos/template.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2019-2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -20,10 +20,10 @@ from jinja2 import Environment from jinja2 import FileSystemLoader from jinja2 import ChainableUndefined from vyos.defaults import directories -from vyos.util import chmod -from vyos.util import chown -from vyos.util import dict_search_args -from vyos.util import makedir +from vyos.utils.dict import dict_search_args +from vyos.utils.file import makedir +from vyos.utils.permission import chmod +from vyos.utils.permission import chown # Holds template filters registered via register_filter() _FILTERS = {} @@ -44,6 +44,7 @@ def _get_environment(location=None): loader=loc_loader, trim_blocks=True, undefined=ChainableUndefined, + extensions=['jinja2.ext.loopcontrols'] ) env.filters.update(_FILTERS) env.tests.update(_TESTS) @@ -158,6 +159,24 @@ def force_to_list(value): else: return [value] +@register_filter('seconds_to_human') +def seconds_to_human(seconds, separator=""): + """ Convert seconds to human-readable values like 1d6h15m23s """ + from vyos.utils.convert import seconds_to_human + return seconds_to_human(seconds, separator=separator) + +@register_filter('bytes_to_human') +def bytes_to_human(bytes, initial_exponent=0, precision=2): + """ Convert bytes to human-readable values like 1.44M """ + from vyos.utils.convert import bytes_to_human + return bytes_to_human(bytes, initial_exponent=initial_exponent, precision=precision) + +@register_filter('human_to_bytes') +def human_to_bytes(value): + """ Convert a data amount with a unit suffix to bytes, like 2K to 2048 """ + from vyos.utils.convert import human_to_bytes + return human_to_bytes(value) + @register_filter('ip_from_cidr') def ip_from_cidr(prefix): """ Take an IPv4/IPv6 CIDR host and strip cidr mask. @@ -193,6 +212,16 @@ def dot_colon_to_dash(text): text = text.replace(".", "-") return text +@register_filter('generate_uuid4') +def generate_uuid4(text): + """ Generate random unique ID + Example: + % uuid4() + UUID('958ddf6a-ef14-4e81-8cfb-afb12456d1c5') + """ + from uuid import uuid4 + return uuid4() + @register_filter('netmask_from_cidr') def netmask_from_cidr(prefix): """ Take CIDR prefix and convert the prefix length to a "subnet mask". @@ -391,11 +420,11 @@ def get_dhcp_router(interface): Returns False of no router is found, returns the IP address as string if a router is found. """ - lease_file = f'/var/lib/dhcp/dhclient_{interface}.leases' + lease_file = directories['isc_dhclient_dir'] + f'/dhclient_{interface}.leases' if not os.path.exists(lease_file): return None - from vyos.util import read_file + from vyos.utils.file import read_file for line in read_file(lease_file).splitlines(): if 'option routers' in line: (_, _, address) = line.split() @@ -476,6 +505,8 @@ def get_esp_ike_cipher(group_config, ike_group=None): continue tmp = '{encryption}-{hash}'.format(**proposal) + if 'prf' in proposal: + tmp += '-' + proposal['prf'] if 'dh_group' in proposal: tmp += '-' + pfs_lut[ 'dh-group' + proposal['dh_group'] ] elif 'pfs' in group_config and group_config['pfs'] != 'disable': @@ -543,9 +574,9 @@ def nft_action(vyos_action): return vyos_action @register_filter('nft_rule') -def nft_rule(rule_conf, fw_name, rule_id, ip_name='ip'): +def nft_rule(rule_conf, fw_hook, fw_name, rule_id, ip_name='ip'): from vyos.firewall import parse_rule - return parse_rule(rule_conf, fw_name, rule_id, ip_name) + return parse_rule(rule_conf, fw_hook, fw_name, rule_id, ip_name) @register_filter('nft_default_rule') def nft_default_rule(fw_conf, fw_name, ipv6=False): @@ -556,7 +587,8 @@ def nft_default_rule(fw_conf, fw_name, ipv6=False): action_suffix = default_action[:1].upper() output.append(f'log prefix "[{fw_name[:19]}-default-{action_suffix}]"') - output.append(nft_action(default_action)) + #output.append(nft_action(default_action)) + output.append(f'{default_action}') if 'default_jump_target' in fw_conf: target = fw_conf['default_jump_target'] def_suffix = '6' if ipv6 else '' @@ -631,9 +663,104 @@ def nat_static_rule(rule_conf, rule_id, nat_type): from vyos.nat import parse_nat_static_rule return parse_nat_static_rule(rule_conf, rule_id, nat_type) +@register_filter('conntrack_ignore_rule') +def conntrack_ignore_rule(rule_conf, rule_id, ipv6=False): + ip_prefix = 'ip6' if ipv6 else 'ip' + def_suffix = '6' if ipv6 else '' + output = [] + + if 'inbound_interface' in rule_conf: + ifname = rule_conf['inbound_interface'] + output.append(f'iifname {ifname}') + + if 'protocol' in rule_conf: + proto = rule_conf['protocol'] + output.append(f'meta l4proto {proto}') + + for side in ['source', 'destination']: + if side in rule_conf: + side_conf = rule_conf[side] + prefix = side[0] + + if 'address' in side_conf: + address = side_conf['address'] + operator = '' + if address[0] == '!': + operator = '!=' + address = address[1:] + output.append(f'{ip_prefix} {prefix}addr {operator} {address}') + + if 'port' in side_conf: + port = side_conf['port'] + operator = '' + if port[0] == '!': + operator = '!=' + port = port[1:] + output.append(f'th {prefix}port {operator} {port}') + + if 'group' in side_conf: + group = side_conf['group'] + + if 'address_group' in group: + group_name = group['address_group'] + operator = '' + if group_name[0] == '!': + operator = '!=' + group_name = group_name[1:] + output.append(f'{ip_prefix} {prefix}addr {operator} @A{def_suffix}_{group_name}') + # Generate firewall group domain-group + elif 'domain_group' in group: + group_name = group['domain_group'] + operator = '' + if group_name[0] == '!': + operator = '!=' + group_name = group_name[1:] + output.append(f'{ip_prefix} {prefix}addr {operator} @D_{group_name}') + elif 'network_group' in group: + group_name = group['network_group'] + operator = '' + if group_name[0] == '!': + operator = '!=' + group_name = group_name[1:] + output.append(f'{ip_prefix} {prefix}addr {operator} @N{def_suffix}_{group_name}') + if 'port_group' in group: + group_name = group['port_group'] + + if proto == 'tcp_udp': + proto = 'th' + + operator = '' + if group_name[0] == '!': + operator = '!=' + group_name = group_name[1:] + + output.append(f'{proto} {prefix}port {operator} @P_{group_name}') + + output.append('counter notrack') + output.append(f'comment "ignore-{rule_id}"') + + return " ".join(output) + @register_filter('range_to_regex') def range_to_regex(num_range): + """Convert range of numbers or list of ranges + to regex + + % range_to_regex('11-12') + '(1[1-2])' + % range_to_regex(['11-12', '14-15']) + '(1[1-2]|1[4-5])' + """ from vyos.range_regex import range_to_regex + if isinstance(num_range, list): + data = [] + for entry in num_range: + if '-' not in entry: + data.append(entry) + else: + data.append(range_to_regex(entry)) + return f'({"|".join(data)})' + if '-' not in num_range: return num_range diff --git a/python/vyos/util.py b/python/vyos/util.py deleted file mode 100644 index 6a828c0ac..000000000 --- a/python/vyos/util.py +++ /dev/null @@ -1,1156 +0,0 @@ -# Copyright 2020-2022 VyOS maintainers and contributors <maintainers@vyos.io> -# -# This library is free software; you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public -# License as published by the Free Software Foundation; either -# version 2.1 of the License, or (at your option) any later version. -# -# This library is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public -# License along with this library. If not, see <http://www.gnu.org/licenses/>. - -import os -import re -import sys - -# -# NOTE: Do not import full classes here, move your import to the function -# where it is used so it is as local as possible to the execution -# - -from subprocess import Popen -from subprocess import PIPE -from subprocess import STDOUT -from subprocess import DEVNULL - -def popen(command, flag='', shell=None, input=None, timeout=None, env=None, - stdout=PIPE, stderr=PIPE, decode='utf-8'): - """ - popen is a wrapper helper aound subprocess.Popen - with it default setting it will return a tuple (out, err) - out: the output of the program run - err: the error code returned by the program - - it can be affected by the following flags: - shell: do not try to auto-detect if a shell is required - for example if a pipe (|) or redirection (>, >>) is used - input: data to sent to the child process via STDIN - the data should be bytes but string will be converted - timeout: time after which the command will be considered to have failed - env: mapping that defines the environment variables for the new process - stdout: define how the output of the program should be handled - - PIPE (default), sends stdout to the output - - DEVNULL, discard the output - stderr: define how the output of the program should be handled - - None (default), send/merge the data to/with stderr - - PIPE, popen will append it to output - - STDOUT, send the data to be merged with stdout - - DEVNULL, discard the output - decode: specify the expected text encoding (utf-8, ascii, ...) - the default is explicitely utf-8 which is python's own default - - usage: - get both stdout and stderr: popen('command', stdout=PIPE, stderr=STDOUT) - discard stdout and get stderr: popen('command', stdout=DEVNUL, stderr=PIPE) - """ - - # airbag must be left as an import in the function as otherwise we have a - # a circual import dependency - from vyos import debug - from vyos import airbag - - # log if the flag is set, otherwise log if command is set - if not debug.enabled(flag): - flag = 'command' - - cmd_msg = f"cmd '{command}'" - debug.message(cmd_msg, flag) - - use_shell = shell - stdin = None - if shell is None: - use_shell = False - if ' ' in command: - use_shell = True - if env: - use_shell = True - - if input: - stdin = PIPE - input = input.encode() if type(input) is str else input - - p = Popen(command, stdin=stdin, stdout=stdout, stderr=stderr, - env=env, shell=use_shell) - - pipe = p.communicate(input, timeout) - - pipe_out = b'' - if stdout == PIPE: - pipe_out = pipe[0] - - pipe_err = b'' - if stderr == PIPE: - pipe_err = pipe[1] - - str_out = pipe_out.decode(decode).replace('\r\n', '\n').strip() - str_err = pipe_err.decode(decode).replace('\r\n', '\n').strip() - - out_msg = f"returned (out):\n{str_out}" - if str_out: - debug.message(out_msg, flag) - - if str_err: - err_msg = f"returned (err):\n{str_err}" - # this message will also be send to syslog via airbag - debug.message(err_msg, flag, destination=sys.stderr) - - # should something go wrong, report this too via airbag - airbag.noteworthy(cmd_msg) - airbag.noteworthy(out_msg) - airbag.noteworthy(err_msg) - - return str_out, p.returncode - - -def run(command, flag='', shell=None, input=None, timeout=None, env=None, - stdout=DEVNULL, stderr=PIPE, decode='utf-8'): - """ - A wrapper around popen, which discard the stdout and - will return the error code of a command - """ - _, code = popen( - command, flag, - stdout=stdout, stderr=stderr, - input=input, timeout=timeout, - env=env, shell=shell, - decode=decode, - ) - return code - - -def cmd(command, flag='', shell=None, input=None, timeout=None, env=None, - stdout=PIPE, stderr=PIPE, decode='utf-8', raising=None, message='', - expect=[0]): - """ - A wrapper around popen, which returns the stdout and - will raise the error code of a command - - raising: specify which call should be used when raising - the class should only require a string as parameter - (default is OSError) with the error code - expect: a list of error codes to consider as normal - """ - decoded, code = popen( - command, flag, - stdout=stdout, stderr=stderr, - input=input, timeout=timeout, - env=env, shell=shell, - decode=decode, - ) - if code not in expect: - feedback = message + '\n' if message else '' - feedback += f'failed to run command: {command}\n' - feedback += f'returned: {decoded}\n' - feedback += f'exit code: {code}' - if raising is None: - # error code can be recovered with .errno - raise OSError(code, feedback) - else: - raise raising(feedback) - return decoded - - -def rc_cmd(command, flag='', shell=None, input=None, timeout=None, env=None, - stdout=PIPE, stderr=STDOUT, decode='utf-8'): - """ - A wrapper around popen, which returns the return code - of a command and stdout - - % rc_cmd('uname') - (0, 'Linux') - % rc_cmd('ip link show dev eth99') - (1, 'Device "eth99" does not exist.') - """ - out, code = popen( - command, flag, - stdout=stdout, stderr=stderr, - input=input, timeout=timeout, - env=env, shell=shell, - decode=decode, - ) - return code, out - - -def call(command, flag='', shell=None, input=None, timeout=None, env=None, - stdout=PIPE, stderr=PIPE, decode='utf-8'): - """ - A wrapper around popen, which print the stdout and - will return the error code of a command - """ - out, code = popen( - command, flag, - stdout=stdout, stderr=stderr, - input=input, timeout=timeout, - env=env, shell=shell, - decode=decode, - ) - if out: - print(out) - return code - - -def read_file(fname, defaultonfailure=None): - """ - read the content of a file, stripping any end characters (space, newlines) - should defaultonfailure be not None, it is returned on failure to read - """ - try: - """ Read a file to string """ - with open(fname, 'r') as f: - data = f.read().strip() - return data - except Exception as e: - if defaultonfailure is not None: - return defaultonfailure - raise e - -def write_file(fname, data, defaultonfailure=None, user=None, group=None, mode=None, append=False): - """ - Write content of data to given fname, should defaultonfailure be not None, - it is returned on failure to read. - - If directory of file is not present, it is auto-created. - """ - dirname = os.path.dirname(fname) - if not os.path.isdir(dirname): - os.makedirs(dirname, mode=0o755, exist_ok=False) - chown(dirname, user, group) - - try: - """ Write a file to string """ - bytes = 0 - with open(fname, 'w' if not append else 'a') as f: - bytes = f.write(data) - chown(fname, user, group) - chmod(fname, mode) - return bytes - except Exception as e: - if defaultonfailure is not None: - return defaultonfailure - raise e - -def read_json(fname, defaultonfailure=None): - """ - read and json decode the content of a file - should defaultonfailure be not None, it is returned on failure to read - """ - import json - try: - with open(fname, 'r') as f: - data = json.load(f) - return data - except Exception as e: - if defaultonfailure is not None: - return defaultonfailure - raise e - - -def chown(path, user, group): - """ change file/directory owner """ - from pwd import getpwnam - from grp import getgrnam - - if user is None or group is None: - return False - - # path may also be an open file descriptor - if not isinstance(path, int) and not os.path.exists(path): - return False - - uid = getpwnam(user).pw_uid - gid = getgrnam(group).gr_gid - os.chown(path, uid, gid) - return True - - -def chmod(path, bitmask): - # path may also be an open file descriptor - if not isinstance(path, int) and not os.path.exists(path): - return - if bitmask is None: - return - os.chmod(path, bitmask) - - -def chmod_600(path): - """ make file only read/writable by owner """ - from stat import S_IRUSR, S_IWUSR - - bitmask = S_IRUSR | S_IWUSR - chmod(path, bitmask) - - -def chmod_750(path): - """ make file/directory only executable to user and group """ - from stat import S_IRUSR, S_IWUSR, S_IXUSR, S_IRGRP, S_IXGRP - - bitmask = S_IRUSR | S_IWUSR | S_IXUSR | S_IRGRP | S_IXGRP - chmod(path, bitmask) - - -def chmod_755(path): - """ make file executable by all """ - from stat import S_IRUSR, S_IWUSR, S_IXUSR, S_IRGRP, S_IXGRP, S_IROTH, S_IXOTH - - bitmask = S_IRUSR | S_IWUSR | S_IXUSR | S_IRGRP | S_IXGRP | \ - S_IROTH | S_IXOTH - chmod(path, bitmask) - - -def makedir(path, user=None, group=None): - if os.path.exists(path): - return - os.makedirs(path, mode=0o755) - chown(path, user, group) - -def colon_separated_to_dict(data_string, uniquekeys=False): - """ Converts a string containing newline-separated entries - of colon-separated key-value pairs into a dict. - - Such files are common in Linux /proc filesystem - - Args: - data_string (str): data string - uniquekeys (bool): whether to insist that keys are unique or not - - Returns: dict - - Raises: - ValueError: if uniquekeys=True and the data string has - duplicate keys. - - Note: - If uniquekeys=True, then dict entries are always strings, - otherwise they are always lists of strings. - """ - import re - key_value_re = re.compile('([^:]+)\s*\:\s*(.*)') - - data_raw = re.split('\n', data_string) - - data = {} - - for l in data_raw: - l = l.strip() - if l: - match = re.match(key_value_re, l) - if match: - key = match.groups()[0].strip() - value = match.groups()[1].strip() - if key in data.keys(): - if uniquekeys: - raise ValueError("Data string has duplicate keys: {0}".format(key)) - else: - data[key].append(value) - else: - if uniquekeys: - data[key] = value - else: - data[key] = [value] - else: - pass - - return data - -def _mangle_dict_keys(data, regex, replacement, abs_path=[], no_tag_node_value_mangle=False, mod=0): - """ Mangles dict keys according to a regex and replacement character. - Some libraries like Jinja2 do not like certain characters in dict keys. - This function can be used for replacing all offending characters - with something acceptable. - - Args: - data (dict): Original dict to mangle - - Returns: dict - """ - from vyos.xml import is_tag - - new_dict = {} - - for key in data.keys(): - save_mod = mod - save_path = abs_path[:] - - abs_path.append(key) - - if not is_tag(abs_path): - new_key = re.sub(regex, replacement, key) - else: - if mod%2: - new_key = key - else: - new_key = re.sub(regex, replacement, key) - if no_tag_node_value_mangle: - mod += 1 - - value = data[key] - - if isinstance(value, dict): - new_dict[new_key] = _mangle_dict_keys(value, regex, replacement, abs_path=abs_path, mod=mod, no_tag_node_value_mangle=no_tag_node_value_mangle) - else: - new_dict[new_key] = value - - mod = save_mod - abs_path = save_path[:] - - return new_dict - -def mangle_dict_keys(data, regex, replacement, abs_path=[], no_tag_node_value_mangle=False): - return _mangle_dict_keys(data, regex, replacement, abs_path=abs_path, no_tag_node_value_mangle=no_tag_node_value_mangle, mod=0) - -def _get_sub_dict(d, lpath): - k = lpath[0] - if k not in d.keys(): - return {} - c = {k: d[k]} - lpath = lpath[1:] - if not lpath: - return c - elif not isinstance(c[k], dict): - return {} - return _get_sub_dict(c[k], lpath) - -def get_sub_dict(source, lpath, get_first_key=False): - """ Returns the sub-dict of a nested dict, defined by path of keys. - - Args: - source (dict): Source dict to extract from - lpath (list[str]): sequence of keys - - Returns: source, if lpath is empty, else - {key : source[..]..[key]} for key the last element of lpath, if exists - {} otherwise - """ - if not isinstance(source, dict): - raise TypeError("source must be of type dict") - if not isinstance(lpath, list): - raise TypeError("path must be of type list") - if not lpath: - return source - - ret = _get_sub_dict(source, lpath) - - if get_first_key and lpath and ret: - tmp = next(iter(ret.values())) - if not isinstance(tmp, dict): - raise TypeError("Data under node is not of type dict") - ret = tmp - - return ret - -def process_running(pid_file): - """ Checks if a process with PID in pid_file is running """ - from psutil import pid_exists - if not os.path.isfile(pid_file): - return False - with open(pid_file, 'r') as f: - pid = f.read().strip() - return pid_exists(int(pid)) - -def process_named_running(name): - """ Checks if process with given name is running and returns its PID. - If Process is not running, return None - """ - from psutil import process_iter - for p in process_iter(): - if name in p.name(): - return p.pid - return None - -def is_list_equal(first: list, second: list) -> bool: - """ Check if 2 lists are equal and list not empty """ - if len(first) != len(second) or len(first) == 0: - return False - return sorted(first) == sorted(second) - -def is_listen_port_bind_service(port: int, service: str) -> bool: - """Check if listen port bound to expected program name - :param port: Bind port - :param service: Program name - :return: bool - - Example: - % is_listen_port_bind_service(443, 'nginx') - True - % is_listen_port_bind_service(443, 'ocservr-main') - False - """ - from psutil import net_connections as connections - from psutil import Process as process - for connection in connections(): - addr = connection.laddr - pid = connection.pid - pid_name = process(pid).name() - pid_port = addr.port - if service == pid_name and port == pid_port: - return True - return False - -def seconds_to_human(s, separator=""): - """ Converts number of seconds passed to a human-readable - interval such as 1w4d18h35m59s - """ - s = int(s) - - week = 60 * 60 * 24 * 7 - day = 60 * 60 * 24 - hour = 60 * 60 - - remainder = 0 - result = "" - - weeks = s // week - if weeks > 0: - result = "{0}w".format(weeks) - s = s % week - - days = s // day - if days > 0: - result = "{0}{1}{2}d".format(result, separator, days) - s = s % day - - hours = s // hour - if hours > 0: - result = "{0}{1}{2}h".format(result, separator, hours) - s = s % hour - - minutes = s // 60 - if minutes > 0: - result = "{0}{1}{2}m".format(result, separator, minutes) - s = s % 60 - - seconds = s - if seconds > 0: - result = "{0}{1}{2}s".format(result, separator, seconds) - - return result - -def bytes_to_human(bytes, initial_exponent=0, precision=2): - """ Converts a value in bytes to a human-readable size string like 640 KB - - The initial_exponent parameter is the exponent of 2, - e.g. 10 (1024) for kilobytes, 20 (1024 * 1024) for megabytes. - """ - - if bytes == 0: - return "0 B" - - from math import log2 - - bytes = bytes * (2**initial_exponent) - - # log2 is a float, while range checking requires an int - exponent = int(log2(bytes)) - - if exponent < 10: - value = bytes - suffix = "B" - elif exponent in range(10, 20): - value = bytes / 1024 - suffix = "KB" - elif exponent in range(20, 30): - value = bytes / 1024**2 - suffix = "MB" - elif exponent in range(30, 40): - value = bytes / 1024**3 - suffix = "GB" - else: - value = bytes / 1024**4 - suffix = "TB" - # Add a new case when the first machine with petabyte RAM - # hits the market. - - size_string = "{0:.{1}f} {2}".format(value, precision, suffix) - return size_string - -def human_to_bytes(value): - """ Converts a data amount with a unit suffix to bytes, like 2K to 2048 """ - - from re import match as re_match - - res = re_match(r'^\s*(\d+(?:\.\d+)?)\s*([a-zA-Z]+)\s*$', value) - - if not res: - raise ValueError(f"'{value}' is not a valid data amount") - else: - amount = float(res.group(1)) - unit = res.group(2).lower() - - if unit == 'b': - res = amount - elif (unit == 'k') or (unit == 'kb'): - res = amount * 1024 - elif (unit == 'm') or (unit == 'mb'): - res = amount * 1024**2 - elif (unit == 'g') or (unit == 'gb'): - res = amount * 1024**3 - elif (unit == 't') or (unit == 'tb'): - res = amount * 1024**4 - else: - raise ValueError(f"Unsupported data unit '{unit}'") - - # There cannot be fractional bytes, so we convert them to integer. - # However, truncating causes problems with conversion back to human unit, - # so we round instead -- that seems to work well enough. - return round(res) - -def get_cfg_group_id(): - from grp import getgrnam - from vyos.defaults import cfg_group - - group_data = getgrnam(cfg_group) - return group_data.gr_gid - - -def file_is_persistent(path): - import re - location = r'^(/config|/opt/vyatta/etc/config)' - absolute = os.path.abspath(os.path.dirname(path)) - return re.match(location,absolute) - -def wait_for_inotify(file_path, pre_hook=None, event_type=None, timeout=None, sleep_interval=0.1): - """ Waits for an inotify event to occur """ - if not os.path.dirname(file_path): - raise ValueError( - "File path {} does not have a directory part (required for inotify watching)".format(file_path)) - if not os.path.basename(file_path): - raise ValueError( - "File path {} does not have a file part, do not know what to watch for".format(file_path)) - - from inotify.adapters import Inotify - from time import time - from time import sleep - - time_start = time() - - i = Inotify() - i.add_watch(os.path.dirname(file_path)) - - if pre_hook: - pre_hook() - - for event in i.event_gen(yield_nones=True): - if (timeout is not None) and ((time() - time_start) > timeout): - # If the function didn't return until this point, - # the file failed to have been written to and closed within the timeout - raise OSError("Waiting for file {} to be written has failed".format(file_path)) - - # Most such events don't take much time, so it's better to check right away - # and sleep later. - if event is not None: - (_, type_names, path, filename) = event - if filename == os.path.basename(file_path): - if event_type in type_names: - return - sleep(sleep_interval) - -def wait_for_file_write_complete(file_path, pre_hook=None, timeout=None, sleep_interval=0.1): - """ Waits for a process to close a file after opening it in write mode. """ - wait_for_inotify(file_path, - event_type='IN_CLOSE_WRITE', pre_hook=pre_hook, timeout=timeout, sleep_interval=sleep_interval) - -def commit_in_progress(): - """ Not to be used in normal op mode scripts! """ - - # The CStore backend locks the config by opening a file - # The file is not removed after commit, so just checking - # if it exists is insufficient, we need to know if it's open by anyone - - # There are two ways to check if any other process keeps a file open. - # The first one is to try opening it and see if the OS objects. - # That's faster but prone to race conditions and can be intrusive. - # The other one is to actually check if any process keeps it open. - # It's non-intrusive but needs root permissions, else you can't check - # processes of other users. - # - # Since this will be used in scripts that modify the config outside of the CLI - # framework, those knowingly have root permissions. - # For everything else, we add a safeguard. - from psutil import process_iter - from psutil import NoSuchProcess - from getpass import getuser - from vyos.defaults import commit_lock - - if getuser() != 'root': - raise OSError('This functions needs to be run as root to return correct results!') - - for proc in process_iter(): - try: - files = proc.open_files() - if files: - for f in files: - if f.path == commit_lock: - return True - except NoSuchProcess as err: - # Process died before we could examine it - pass - # Default case - return False - - -def wait_for_commit_lock(): - """ Not to be used in normal op mode scripts! """ - from time import sleep - # Very synchronous approach to multiprocessing - while commit_in_progress(): - sleep(1) - -def ask_input(question, default='', numeric_only=False, valid_responses=[]): - question_out = question - if default: - question_out += f' (Default: {default})' - response = '' - while True: - response = input(question_out + ' ').strip() - if not response and default: - return default - if numeric_only: - if not response.isnumeric(): - print("Invalid value, try again.") - continue - response = int(response) - if valid_responses and response not in valid_responses: - print("Invalid value, try again.") - continue - break - return response - -def ask_yes_no(question, default=False) -> bool: - """Ask a yes/no question via input() and return their answer.""" - from sys import stdout - default_msg = "[Y/n]" if default else "[y/N]" - while True: - try: - stdout.write("%s %s " % (question, default_msg)) - c = input().lower() - if c == '': - return default - elif c in ("y", "ye", "yes"): - return True - elif c in ("n", "no"): - return False - else: - stdout.write("Please respond with yes/y or no/n\n") - except EOFError: - stdout.write("\nPlease respond with yes/y or no/n\n") - -def is_admin() -> bool: - """Look if current user is in sudo group""" - from getpass import getuser - from grp import getgrnam - current_user = getuser() - (_, _, _, admin_group_members) = getgrnam('sudo') - return current_user in admin_group_members - - -def mac2eui64(mac, prefix=None): - """ - Convert a MAC address to a EUI64 address or, with prefix provided, a full - IPv6 address. - Thankfully copied from https://gist.github.com/wido/f5e32576bb57b5cc6f934e177a37a0d3 - """ - import re - from ipaddress import ip_network - # http://tools.ietf.org/html/rfc4291#section-2.5.1 - eui64 = re.sub(r'[.:-]', '', mac).lower() - eui64 = eui64[0:6] + 'fffe' + eui64[6:] - eui64 = hex(int(eui64[0:2], 16) ^ 2)[2:].zfill(2) + eui64[2:] - - if prefix is None: - return ':'.join(re.findall(r'.{4}', eui64)) - else: - try: - net = ip_network(prefix, strict=False) - euil = int('0x{0}'.format(eui64), 16) - return str(net[euil]) - except: # pylint: disable=bare-except - return - -def get_half_cpus(): - """ return 1/2 of the numbers of available CPUs """ - cpu = os.cpu_count() - if cpu > 1: - cpu /= 2 - return int(cpu) - -def check_kmod(k_mod): - """ Common utility function to load required kernel modules on demand """ - from vyos import ConfigError - if isinstance(k_mod, str): - k_mod = k_mod.split() - for module in k_mod: - if not os.path.exists(f'/sys/module/{module}'): - if call(f'modprobe {module}') != 0: - raise ConfigError(f'Loading Kernel module {module} failed') - -def find_device_file(device): - """ Recurively search /dev for the given device file and return its full path. - If no device file was found 'None' is returned """ - from fnmatch import fnmatch - - for root, dirs, files in os.walk('/dev'): - for basename in files: - if fnmatch(basename, device): - return os.path.join(root, basename) - - return None - -def dict_search(path, dict_object): - """ Traverse Python dictionary (dict_object) delimited by dot (.). - Return value of key if found, None otherwise. - - This is faster implementation then jmespath.search('foo.bar', dict_object)""" - if not isinstance(dict_object, dict) or not path: - return None - - parts = path.split('.') - inside = parts[:-1] - if not inside: - if path not in dict_object: - return None - return dict_object[path] - c = dict_object - for p in parts[:-1]: - c = c.get(p, {}) - return c.get(parts[-1], None) - -def dict_search_args(dict_object, *path): - # Traverse dictionary using variable arguments - # Added due to above function not allowing for '.' in the key names - # Example: dict_search_args(some_dict, 'key', 'subkey', 'subsubkey', ...) - if not isinstance(dict_object, dict) or not path: - return None - - for item in path: - if item not in dict_object: - return None - dict_object = dict_object[item] - return dict_object - -def dict_search_recursive(dict_object, key, path=[]): - """ Traverse a dictionary recurisvely and return the value of the key - we are looking for. - - Thankfully copied from https://stackoverflow.com/a/19871956 - - Modified to yield optional path to found keys - """ - if isinstance(dict_object, list): - for i in dict_object: - new_path = path + [i] - for x in dict_search_recursive(i, key, new_path): - yield x - elif isinstance(dict_object, dict): - if key in dict_object: - new_path = path + [key] - yield dict_object[key], new_path - for k, j in dict_object.items(): - new_path = path + [k] - for x in dict_search_recursive(j, key, new_path): - yield x - -def convert_data(data): - """Convert multiple types of data to types usable in CLI - - Args: - data (str | bytes | list | OrderedDict): input data - - Returns: - str | list | dict: converted data - """ - from collections import OrderedDict - - if isinstance(data, str): - return data - if isinstance(data, bytes): - return data.decode() - if isinstance(data, list): - list_tmp = [] - for item in data: - list_tmp.append(convert_data(item)) - return list_tmp - if isinstance(data, OrderedDict): - dict_tmp = {} - for key, value in data.items(): - dict_tmp[key] = convert_data(value) - return dict_tmp - -def get_bridge_fdb(interface): - """ Returns the forwarding database entries for a given interface """ - if not os.path.exists(f'/sys/class/net/{interface}'): - return None - from json import loads - tmp = loads(cmd(f'bridge -j fdb show dev {interface}')) - return tmp - -def get_interface_config(interface): - """ Returns the used encapsulation protocol for given interface. - If interface does not exist, None is returned. - """ - if not os.path.exists(f'/sys/class/net/{interface}'): - return None - from json import loads - tmp = loads(cmd(f'ip -d -j link show {interface}'))[0] - return tmp - -def get_interface_address(interface): - """ Returns the used encapsulation protocol for given interface. - If interface does not exist, None is returned. - """ - if not os.path.exists(f'/sys/class/net/{interface}'): - return None - from json import loads - tmp = loads(cmd(f'ip -d -j addr show {interface}'))[0] - return tmp - -def get_interface_namespace(iface): - """ - Returns wich netns the interface belongs to - """ - from json import loads - # Check if netns exist - tmp = loads(cmd(f'ip --json netns ls')) - if len(tmp) == 0: - return None - - for ns in tmp: - namespace = f'{ns["name"]}' - # Search interface in each netns - data = loads(cmd(f'ip netns exec {namespace} ip -j link show')) - for compare in data: - if iface == compare["ifname"]: - return namespace - -def get_all_vrfs(): - """ Return a dictionary of all system wide known VRF instances """ - from json import loads - tmp = loads(cmd('ip -j vrf list')) - # Result is of type [{"name":"red","table":1000},{"name":"blue","table":2000}] - # so we will re-arrange it to a more nicer representation: - # {'red': {'table': 1000}, 'blue': {'table': 2000}} - data = {} - for entry in tmp: - name = entry.pop('name') - data[name] = entry - return data - -def print_error(str='', end='\n'): - """ - Print `str` to stderr, terminated with `end`. - Used for warnings and out-of-band messages to avoid mangling precious - stdout output. - """ - sys.stderr.write(str) - sys.stderr.write(end) - sys.stderr.flush() - -def make_progressbar(): - """ - Make a procedure that takes two arguments `done` and `total` and prints a - progressbar based on the ratio thereof, whose length is determined by the - width of the terminal. - """ - import shutil, math - col, _ = shutil.get_terminal_size() - col = max(col - 15, 20) - def print_progressbar(done, total): - if done <= total: - increment = total / col - length = math.ceil(done / increment) - percentage = str(math.ceil(100 * done / total)).rjust(3) - print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r') - # Print a newline so that the subsequent prints don't overwrite the full bar. - if done == total: - print_error() - return print_progressbar - -def make_incremental_progressbar(increment: float): - """ - Make a generator that displays a progressbar that grows monotonically with - every iteration. - First call displays it at 0% and every subsequent iteration displays it - at `increment` increments where 0.0 < `increment` < 1.0. - Intended for FTP and HTTP transfers with stateless callbacks. - """ - print_progressbar = make_progressbar() - total = 0.0 - while total < 1.0: - print_progressbar(total, 1.0) - yield - total += increment - print_progressbar(1, 1) - # Ignore further calls. - while True: - yield - -def begin(*args): - """ - Evaluate arguments in order and return the result of the *last* argument. - For combining multiple expressions in one statement. Useful for lambdas. - """ - return args[-1] - -def begin0(*args): - """ - Evaluate arguments in order and return the result of the *first* argument. - For combining multiple expressions in one statement. Useful for lambdas. - """ - return args[0] - -def is_systemd_service_active(service): - """ Test is a specified systemd service is activated. - Returns True if service is active, false otherwise. - Copied from: https://unix.stackexchange.com/a/435317 """ - tmp = cmd(f'systemctl show --value -p ActiveState {service}') - return bool((tmp == 'active')) - -def is_systemd_service_running(service): - """ Test is a specified systemd service is actually running. - Returns True if service is running, false otherwise. - Copied from: https://unix.stackexchange.com/a/435317 """ - tmp = cmd(f'systemctl show --value -p SubState {service}') - return bool((tmp == 'running')) - -def check_port_availability(ipaddress, port, protocol): - """ - Check if port is available and not used by any service - Return False if a port is busy or IP address does not exists - Should be used carefully for services that can start listening - dynamically, because IP address may be dynamic too - """ - from socketserver import TCPServer, UDPServer - from ipaddress import ip_address - - # verify arguments - try: - ipaddress = ip_address(ipaddress).compressed - except: - raise ValueError(f'The {ipaddress} is not a valid IPv4 or IPv6 address') - if port not in range(1, 65536): - raise ValueError(f'The port number {port} is not in the 1-65535 range') - if protocol not in ['tcp', 'udp']: - raise ValueError( - f'The protocol {protocol} is not supported. Only tcp and udp are allowed' - ) - - # check port availability - try: - if protocol == 'tcp': - server = TCPServer((ipaddress, port), None, bind_and_activate=True) - if protocol == 'udp': - server = UDPServer((ipaddress, port), None, bind_and_activate=True) - server.server_close() - return True - except: - return False - -def install_into_config(conf, config_paths, override_prompt=True): - # Allows op-mode scripts to install values if called from an active config session - # config_paths: dict of config paths - # override_prompt: if True, user will be prompted before existing nodes are overwritten - - if not config_paths: - return None - - from vyos.config import Config - - if not Config().in_session(): - print('You are not in configure mode, commands to install manually from configure mode:') - for path in config_paths: - print(f'set {path}') - return None - - count = 0 - failed = [] - - for path in config_paths: - if override_prompt and conf.exists(path) and not conf.is_multi(path): - if not ask_yes_no(f'Config node "{node}" already exists. Do you want to overwrite it?'): - continue - - try: - cmd(f'/opt/vyatta/sbin/my_set {path}') - count += 1 - except: - failed.append(path) - - if failed: - print(f'Failed to install {len(failed)} value(s). Commands to manually install:') - for path in failed: - print(f'set {path}') - - if count > 0: - print(f'{count} value(s) installed. Use "compare" to see the pending changes, and "commit" to apply.') - -def is_wwan_connected(interface): - """ Determine if a given WWAN interface, e.g. wwan0 is connected to the - carrier network or not """ - import json - - if not interface.startswith('wwan'): - raise ValueError(f'Specified interface "{interface}" is not a WWAN interface') - - # ModemManager is required for connection(s) - if service is not running, - # there won't be any connection at all! - if not is_systemd_service_active('ModemManager.service'): - return False - - modem = interface.lstrip('wwan') - - tmp = cmd(f'mmcli --modem {modem} --output-json') - tmp = json.loads(tmp) - - # return True/False if interface is in connected state - return dict_search('modem.generic.state', tmp) == 'connected' - -def boot_configuration_complete() -> bool: - """ Check if the boot config loader has completed - """ - from vyos.defaults import config_status - - if os.path.isfile(config_status): - return True - return False - -def sysctl_read(name): - """ Read and return current value of sysctl() option """ - tmp = cmd(f'sysctl {name}') - return tmp.split()[-1] - -def sysctl_write(name, value): - """ Change value via sysctl() - return True if changed, False otherwise """ - tmp = cmd(f'sysctl {name}') - # last list index contains the actual value - only write if value differs - if sysctl_read(name) != str(value): - call(f'sysctl -wq {name}={value}') - return True - return False - -# approach follows a discussion in: -# https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case -def camel_to_snake_case(name: str) -> str: - pattern = r'\d+|[A-Z]?[a-z]+|\W|[A-Z]{2,}(?=[A-Z][a-z]|\d|\W|$)' - words = re.findall(pattern, name) - return '_'.join(map(str.lower, words)) - -def load_as_module(name: str, path: str): - import importlib.util - - spec = importlib.util.spec_from_file_location(name, path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - return mod diff --git a/python/vyos/utils/__init__.py b/python/vyos/utils/__init__.py new file mode 100644 index 000000000..12ef2d3b8 --- /dev/null +++ b/python/vyos/utils/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +from vyos.utils import assertion +from vyos.utils import auth +from vyos.utils import boot +from vyos.utils import commit +from vyos.utils import convert +from vyos.utils import dict +from vyos.utils import file +from vyos.utils import io +from vyos.utils import kernel +from vyos.utils import list +from vyos.utils import misc +from vyos.utils import network +from vyos.utils import permission +from vyos.utils import process +from vyos.utils import system diff --git a/python/vyos/utils/assertion.py b/python/vyos/utils/assertion.py new file mode 100644 index 000000000..1aaa54dff --- /dev/null +++ b/python/vyos/utils/assertion.py @@ -0,0 +1,81 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +def assert_boolean(b): + if int(b) not in (0, 1): + raise ValueError(f'Value {b} out of range') + +def assert_range(value, lower=0, count=3): + if int(value, 16) not in range(lower, lower+count): + raise ValueError("Value out of range") + +def assert_list(s, l): + if s not in l: + o = ' or '.join([f'"{n}"' for n in l]) + raise ValueError(f'state must be {o}, got {s}') + +def assert_number(n): + if not str(n).isnumeric(): + raise ValueError(f'{n} must be a number') + +def assert_positive(n, smaller=0): + assert_number(n) + if int(n) < smaller: + raise ValueError(f'{n} is smaller than {smaller}') + +def assert_mtu(mtu, ifname): + assert_number(mtu) + + import json + from vyos.utils.process import cmd + out = cmd(f'ip -j -d link show dev {ifname}') + # [{"ifindex":2,"ifname":"eth0","flags":["BROADCAST","MULTICAST","UP","LOWER_UP"],"mtu":1500,"qdisc":"pfifo_fast","operstate":"UP","linkmode":"DEFAULT","group":"default","txqlen":1000,"link_type":"ether","address":"08:00:27:d9:5b:04","broadcast":"ff:ff:ff:ff:ff:ff","promiscuity":0,"min_mtu":46,"max_mtu":16110,"inet6_addr_gen_mode":"none","num_tx_queues":1,"num_rx_queues":1,"gso_max_size":65536,"gso_max_segs":65535}] + parsed = json.loads(out)[0] + min_mtu = int(parsed.get('min_mtu', '0')) + # cur_mtu = parsed.get('mtu',0), + max_mtu = int(parsed.get('max_mtu', '0')) + cur_mtu = int(mtu) + + if (min_mtu and cur_mtu < min_mtu) or cur_mtu < 68: + raise ValueError(f'MTU is too small for interface "{ifname}": {mtu} < {min_mtu}') + if (max_mtu and cur_mtu > max_mtu) or cur_mtu > 65536: + raise ValueError(f'MTU is too small for interface "{ifname}": {mtu} > {max_mtu}') + +def assert_mac(m): + split = m.split(':') + size = len(split) + + # a mac address consits out of 6 octets + if size != 6: + raise ValueError(f'wrong number of MAC octets ({size}): {m}') + + octets = [] + try: + for octet in split: + octets.append(int(octet, 16)) + except ValueError: + raise ValueError(f'invalid hex number "{octet}" in : {m}') + + # validate against the first mac address byte if it's a multicast + # address + if octets[0] & 1: + raise ValueError(f'{m} is a multicast MAC address') + + # overall mac address is not allowed to be 00:00:00:00:00:00 + if sum(octets) == 0: + raise ValueError('00:00:00:00:00:00 is not a valid MAC address') + + if octets[:5] == (0, 0, 94, 0, 1): + raise ValueError(f'{m} is a VRRP MAC address') diff --git a/python/vyos/authutils.py b/python/vyos/utils/auth.py index 66b5f4a74..a59858d72 100644 --- a/python/vyos/authutils.py +++ b/python/vyos/utils/auth.py @@ -15,7 +15,7 @@ import re -from vyos.util import cmd +from vyos.utils.process import cmd def make_password_hash(password): diff --git a/python/vyos/utils/boot.py b/python/vyos/utils/boot.py new file mode 100644 index 000000000..3aecbec64 --- /dev/null +++ b/python/vyos/utils/boot.py @@ -0,0 +1,35 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +import os + +def boot_configuration_complete() -> bool: + """ Check if the boot config loader has completed + """ + from vyos.defaults import config_status + if os.path.isfile(config_status): + return True + return False + +def boot_configuration_success() -> bool: + from vyos.defaults import config_status + try: + with open(config_status) as f: + res = f.read().strip() + except FileNotFoundError: + return False + if int(res) == 0: + return True + return False diff --git a/python/vyos/utils/commit.py b/python/vyos/utils/commit.py new file mode 100644 index 000000000..105aed8c2 --- /dev/null +++ b/python/vyos/utils/commit.py @@ -0,0 +1,60 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +def commit_in_progress(): + """ Not to be used in normal op mode scripts! """ + + # The CStore backend locks the config by opening a file + # The file is not removed after commit, so just checking + # if it exists is insufficient, we need to know if it's open by anyone + + # There are two ways to check if any other process keeps a file open. + # The first one is to try opening it and see if the OS objects. + # That's faster but prone to race conditions and can be intrusive. + # The other one is to actually check if any process keeps it open. + # It's non-intrusive but needs root permissions, else you can't check + # processes of other users. + # + # Since this will be used in scripts that modify the config outside of the CLI + # framework, those knowingly have root permissions. + # For everything else, we add a safeguard. + from psutil import process_iter + from psutil import NoSuchProcess + from getpass import getuser + from vyos.defaults import commit_lock + + if getuser() != 'root': + raise OSError('This functions needs to be run as root to return correct results!') + + for proc in process_iter(): + try: + files = proc.open_files() + if files: + for f in files: + if f.path == commit_lock: + return True + except NoSuchProcess as err: + # Process died before we could examine it + pass + # Default case + return False + + +def wait_for_commit_lock(): + """ Not to be used in normal op mode scripts! """ + from time import sleep + # Very synchronous approach to multiprocessing + while commit_in_progress(): + sleep(1) diff --git a/python/vyos/utils/convert.py b/python/vyos/utils/convert.py new file mode 100644 index 000000000..9a8a1ff7d --- /dev/null +++ b/python/vyos/utils/convert.py @@ -0,0 +1,197 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +def seconds_to_human(s, separator=""): + """ Converts number of seconds passed to a human-readable + interval such as 1w4d18h35m59s + """ + s = int(s) + + week = 60 * 60 * 24 * 7 + day = 60 * 60 * 24 + hour = 60 * 60 + + remainder = 0 + result = "" + + weeks = s // week + if weeks > 0: + result = "{0}w".format(weeks) + s = s % week + + days = s // day + if days > 0: + result = "{0}{1}{2}d".format(result, separator, days) + s = s % day + + hours = s // hour + if hours > 0: + result = "{0}{1}{2}h".format(result, separator, hours) + s = s % hour + + minutes = s // 60 + if minutes > 0: + result = "{0}{1}{2}m".format(result, separator, minutes) + s = s % 60 + + seconds = s + if seconds > 0: + result = "{0}{1}{2}s".format(result, separator, seconds) + + return result + +def bytes_to_human(bytes, initial_exponent=0, precision=2): + """ Converts a value in bytes to a human-readable size string like 640 KB + + The initial_exponent parameter is the exponent of 2, + e.g. 10 (1024) for kilobytes, 20 (1024 * 1024) for megabytes. + """ + + if bytes == 0: + return "0 B" + + from math import log2 + + bytes = bytes * (2**initial_exponent) + + # log2 is a float, while range checking requires an int + exponent = int(log2(bytes)) + + if exponent < 10: + value = bytes + suffix = "B" + elif exponent in range(10, 20): + value = bytes / 1024 + suffix = "KB" + elif exponent in range(20, 30): + value = bytes / 1024**2 + suffix = "MB" + elif exponent in range(30, 40): + value = bytes / 1024**3 + suffix = "GB" + else: + value = bytes / 1024**4 + suffix = "TB" + # Add a new case when the first machine with petabyte RAM + # hits the market. + + size_string = "{0:.{1}f} {2}".format(value, precision, suffix) + return size_string + +def human_to_bytes(value): + """ Converts a data amount with a unit suffix to bytes, like 2K to 2048 """ + + from re import match as re_match + + res = re_match(r'^\s*(\d+(?:\.\d+)?)\s*([a-zA-Z]+)\s*$', value) + + if not res: + raise ValueError(f"'{value}' is not a valid data amount") + else: + amount = float(res.group(1)) + unit = res.group(2).lower() + + if unit == 'b': + res = amount + elif (unit == 'k') or (unit == 'kb'): + res = amount * 1024 + elif (unit == 'm') or (unit == 'mb'): + res = amount * 1024**2 + elif (unit == 'g') or (unit == 'gb'): + res = amount * 1024**3 + elif (unit == 't') or (unit == 'tb'): + res = amount * 1024**4 + else: + raise ValueError(f"Unsupported data unit '{unit}'") + + # There cannot be fractional bytes, so we convert them to integer. + # However, truncating causes problems with conversion back to human unit, + # so we round instead -- that seems to work well enough. + return round(res) + +def mac_to_eui64(mac, prefix=None): + """ + Convert a MAC address to a EUI64 address or, with prefix provided, a full + IPv6 address. + Thankfully copied from https://gist.github.com/wido/f5e32576bb57b5cc6f934e177a37a0d3 + """ + import re + from ipaddress import ip_network + # http://tools.ietf.org/html/rfc4291#section-2.5.1 + eui64 = re.sub(r'[.:-]', '', mac).lower() + eui64 = eui64[0:6] + 'fffe' + eui64[6:] + eui64 = hex(int(eui64[0:2], 16) ^ 2)[2:].zfill(2) + eui64[2:] + + if prefix is None: + return ':'.join(re.findall(r'.{4}', eui64)) + else: + try: + net = ip_network(prefix, strict=False) + euil = int('0x{0}'.format(eui64), 16) + return str(net[euil]) + except: # pylint: disable=bare-except + return + + +def convert_data(data) -> dict | list | tuple | str | int | float | bool | None: + """Filter and convert multiple types of data to types usable in CLI/API + + WARNING: Must not be used for anything except formatting output for API or CLI + + On the output allowed everything supported in JSON. + + Args: + data (Any): input data + + Returns: + dict | list | tuple | str | int | float | bool | None: converted data + """ + from base64 import b64encode + + # return original data for types which do not require conversion + if isinstance(data, str | int | float | bool | None): + return data + + if isinstance(data, list): + list_tmp = [] + for item in data: + list_tmp.append(convert_data(item)) + return list_tmp + + if isinstance(data, tuple): + list_tmp = list(data) + tuple_tmp = tuple(convert_data(list_tmp)) + return tuple_tmp + + if isinstance(data, bytes | bytearray): + try: + return data.decode() + except UnicodeDecodeError: + return b64encode(data).decode() + + if isinstance(data, set | frozenset): + list_tmp = convert_data(list(data)) + return list_tmp + + if isinstance(data, dict): + dict_tmp = {} + for key, value in data.items(): + dict_tmp[key] = convert_data(value) + return dict_tmp + + # do not return anything for other types + # which cannot be converted to JSON + # for example: complex | range | memoryview + return diff --git a/python/vyos/utils/dict.py b/python/vyos/utils/dict.py new file mode 100644 index 000000000..9484eacdd --- /dev/null +++ b/python/vyos/utils/dict.py @@ -0,0 +1,307 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +def colon_separated_to_dict(data_string, uniquekeys=False): + """ Converts a string containing newline-separated entries + of colon-separated key-value pairs into a dict. + + Such files are common in Linux /proc filesystem + + Args: + data_string (str): data string + uniquekeys (bool): whether to insist that keys are unique or not + + Returns: dict + + Raises: + ValueError: if uniquekeys=True and the data string has + duplicate keys. + + Note: + If uniquekeys=True, then dict entries are always strings, + otherwise they are always lists of strings. + """ + import re + key_value_re = re.compile('([^:]+)\s*\:\s*(.*)') + + data_raw = re.split('\n', data_string) + + data = {} + + for l in data_raw: + l = l.strip() + if l: + match = re.match(key_value_re, l) + if match and (len(match.groups()) == 2): + key = match.groups()[0].strip() + value = match.groups()[1].strip() + else: + raise ValueError(f"""Line "{l}" could not be parsed a colon-separated pair """, l) + if key in data.keys(): + if uniquekeys: + raise ValueError("Data string has duplicate keys: {0}".format(key)) + else: + data[key].append(value) + else: + if uniquekeys: + data[key] = value + else: + data[key] = [value] + else: + pass + + return data + +def mangle_dict_keys(data, regex, replacement, abs_path=None, no_tag_node_value_mangle=False): + """ Mangles dict keys according to a regex and replacement character. + Some libraries like Jinja2 do not like certain characters in dict keys. + This function can be used for replacing all offending characters + with something acceptable. + + Args: + data (dict): Original dict to mangle + regex, replacement (str): arguments to re.sub(regex, replacement, ...) + abs_path (list): if data is a config dict and no_tag_node_value_mangle is True + then abs_path should be the absolute config path to the first + keys of data, non-inclusive + no_tag_node_value_mangle (bool): do not mangle keys of tag node values + + Returns: dict + """ + import re + from vyos.xml_ref import is_tag_value + + if abs_path is None: + abs_path = [] + + new_dict = type(data)() + + for k in data.keys(): + if no_tag_node_value_mangle and is_tag_value(abs_path + [k]): + new_key = k + else: + new_key = re.sub(regex, replacement, k) + + value = data[k] + + if isinstance(value, dict): + new_dict[new_key] = mangle_dict_keys(value, regex, replacement, + abs_path=abs_path + [k], + no_tag_node_value_mangle=no_tag_node_value_mangle) + else: + new_dict[new_key] = value + + return new_dict + +def _get_sub_dict(d, lpath): + k = lpath[0] + if k not in d.keys(): + return {} + c = {k: d[k]} + lpath = lpath[1:] + if not lpath: + return c + elif not isinstance(c[k], dict): + return {} + return _get_sub_dict(c[k], lpath) + +def get_sub_dict(source, lpath, get_first_key=False): + """ Returns the sub-dict of a nested dict, defined by path of keys. + + Args: + source (dict): Source dict to extract from + lpath (list[str]): sequence of keys + + Returns: source, if lpath is empty, else + {key : source[..]..[key]} for key the last element of lpath, if exists + {} otherwise + """ + if not isinstance(source, dict): + raise TypeError("source must be of type dict") + if not isinstance(lpath, list): + raise TypeError("path must be of type list") + if not lpath: + return source + + ret = _get_sub_dict(source, lpath) + + if get_first_key and lpath and ret: + tmp = next(iter(ret.values())) + if not isinstance(tmp, dict): + raise TypeError("Data under node is not of type dict") + ret = tmp + + return ret + +def dict_search(path, dict_object): + """ Traverse Python dictionary (dict_object) delimited by dot (.). + Return value of key if found, None otherwise. + + This is faster implementation then jmespath.search('foo.bar', dict_object)""" + if not isinstance(dict_object, dict) or not path: + return None + + parts = path.split('.') + inside = parts[:-1] + if not inside: + if path not in dict_object: + return None + return dict_object[path] + c = dict_object + for p in parts[:-1]: + c = c.get(p, {}) + return c.get(parts[-1], None) + +def dict_search_args(dict_object, *path): + # Traverse dictionary using variable arguments + # Added due to above function not allowing for '.' in the key names + # Example: dict_search_args(some_dict, 'key', 'subkey', 'subsubkey', ...) + if not isinstance(dict_object, dict) or not path: + return None + + for item in path: + if item not in dict_object: + return None + dict_object = dict_object[item] + return dict_object + +def dict_search_recursive(dict_object, key, path=[]): + """ Traverse a dictionary recurisvely and return the value of the key + we are looking for. + + Thankfully copied from https://stackoverflow.com/a/19871956 + + Modified to yield optional path to found keys + """ + if isinstance(dict_object, list): + for i in dict_object: + new_path = path + [i] + for x in dict_search_recursive(i, key, new_path): + yield x + elif isinstance(dict_object, dict): + if key in dict_object: + new_path = path + [key] + yield dict_object[key], new_path + for k, j in dict_object.items(): + new_path = path + [k] + for x in dict_search_recursive(j, key, new_path): + yield x + +def dict_to_list(d, save_key_to=None): + """ Convert a dict to a list of dicts. + + Optionally, save the original key of the dict inside + dicts stores in that list. + """ + def save_key(i, k): + if isinstance(i, dict): + i[save_key_to] = k + return + elif isinstance(i, list): + for _i in i: + save_key(_i, k) + else: + raise ValueError(f"Cannot save the key: the item is {type(i)}, not a dict") + + collect = [] + + for k,_ in d.items(): + item = d[k] + if save_key_to is not None: + save_key(item, k) + if isinstance(item, list): + collect += item + else: + collect.append(item) + + return collect + +def dict_to_paths(d: dict) -> list: + """ Generator to return list of paths from dict of list[str]|str + """ + def func(d, path): + if isinstance(d, dict): + if not d: + yield path + for k, v in d.items(): + for r in func(v, path + [k]): + yield r + elif isinstance(d, list): + for i in d: + for r in func(i, path): + yield r + elif isinstance(d, str): + yield path + [d] + else: + raise ValueError('object is not a dict of strings/list of strings') + for r in func(d, []): + yield r + +def check_mutually_exclusive_options(d, keys, required=False): + """ Checks if a dict has at most one or only one of + mutually exclusive keys. + """ + present_keys = [] + + for k in d: + if k in keys: + present_keys.append(k) + + # Un-mangle the keys to make them match CLI option syntax + from re import sub + orig_keys = list(map(lambda s: sub(r'_', '-', s), keys)) + orig_present_keys = list(map(lambda s: sub(r'_', '-', s), present_keys)) + + if len(present_keys) > 1: + raise ValueError(f"Options {orig_keys} are mutually-exclusive but more than one of them is present: {orig_present_keys}") + + if required and (len(present_keys) < 1): + raise ValueError(f"At least one of the following options is required: {orig_keys}") + +class FixedDict(dict): + """ + FixedDict: A dictionnary not allowing new keys to be created after initialisation. + + >>> f = FixedDict(**{'count':1}) + >>> f['count'] = 2 + >>> f['king'] = 3 + File "...", line ..., in __setitem__ + raise ConfigError(f'Option "{k}" has no defined default') + """ + + from vyos import ConfigError + + def __init__(self, **options): + self._allowed = options.keys() + super().__init__(**options) + + def __setitem__(self, k, v): + """ + __setitem__ is a builtin which is called by python when setting dict values: + >>> d = dict() + >>> d['key'] = 'value' + >>> d + {'key': 'value'} + + is syntaxic sugar for + + >>> d = dict() + >>> d.__setitem__('key','value') + >>> d + {'key': 'value'} + """ + if k not in self._allowed: + raise ConfigError(f'Option "{k}" has no defined default') + super().__setitem__(k, v) diff --git a/python/vyos/utils/file.py b/python/vyos/utils/file.py new file mode 100644 index 000000000..667a2464b --- /dev/null +++ b/python/vyos/utils/file.py @@ -0,0 +1,183 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +import os +from vyos.utils.permission import chown + +def makedir(path, user=None, group=None): + if os.path.exists(path): + return + os.makedirs(path, mode=0o755) + chown(path, user, group) + +def file_is_persistent(path): + import re + location = r'^(/config|/opt/vyatta/etc/config)' + absolute = os.path.abspath(os.path.dirname(path)) + return re.match(location,absolute) + +def read_file(fname, defaultonfailure=None): + """ + read the content of a file, stripping any end characters (space, newlines) + should defaultonfailure be not None, it is returned on failure to read + """ + try: + """ Read a file to string """ + with open(fname, 'r') as f: + data = f.read().strip() + return data + except Exception as e: + if defaultonfailure is not None: + return defaultonfailure + raise e + +def write_file(fname, data, defaultonfailure=None, user=None, group=None, mode=None, append=False): + """ + Write content of data to given fname, should defaultonfailure be not None, + it is returned on failure to read. + + If directory of file is not present, it is auto-created. + """ + dirname = os.path.dirname(fname) + if not os.path.isdir(dirname): + os.makedirs(dirname, mode=0o755, exist_ok=False) + chown(dirname, user, group) + + try: + """ Write a file to string """ + bytes = 0 + with open(fname, 'w' if not append else 'a') as f: + bytes = f.write(data) + chown(fname, user, group) + chmod(fname, mode) + return bytes + except Exception as e: + if defaultonfailure is not None: + return defaultonfailure + raise e + +def read_json(fname, defaultonfailure=None): + """ + read and json decode the content of a file + should defaultonfailure be not None, it is returned on failure to read + """ + import json + try: + with open(fname, 'r') as f: + data = json.load(f) + return data + except Exception as e: + if defaultonfailure is not None: + return defaultonfailure + raise e + +def chown(path, user, group): + """ change file/directory owner """ + from pwd import getpwnam + from grp import getgrnam + + if user is None or group is None: + return False + + # path may also be an open file descriptor + if not isinstance(path, int) and not os.path.exists(path): + return False + + uid = getpwnam(user).pw_uid + gid = getgrnam(group).gr_gid + os.chown(path, uid, gid) + return True + + +def chmod(path, bitmask): + # path may also be an open file descriptor + if not isinstance(path, int) and not os.path.exists(path): + return + if bitmask is None: + return + os.chmod(path, bitmask) + + +def chmod_600(path): + """ Make file only read/writable by owner """ + from stat import S_IRUSR, S_IWUSR + + bitmask = S_IRUSR | S_IWUSR + chmod(path, bitmask) + + +def chmod_750(path): + """ Make file/directory only executable to user and group """ + from stat import S_IRUSR, S_IWUSR, S_IXUSR, S_IRGRP, S_IXGRP + + bitmask = S_IRUSR | S_IWUSR | S_IXUSR | S_IRGRP | S_IXGRP + chmod(path, bitmask) + + +def chmod_755(path): + """ Make file executable by all """ + from stat import S_IRUSR, S_IWUSR, S_IXUSR, S_IRGRP, S_IXGRP, S_IROTH, S_IXOTH + + bitmask = S_IRUSR | S_IWUSR | S_IXUSR | S_IRGRP | S_IXGRP | \ + S_IROTH | S_IXOTH + chmod(path, bitmask) + + +def makedir(path, user=None, group=None): + if os.path.exists(path): + return + os.makedirs(path, mode=0o755) + chown(path, user, group) + +def wait_for_inotify(file_path, pre_hook=None, event_type=None, timeout=None, sleep_interval=0.1): + """ Waits for an inotify event to occur """ + if not os.path.dirname(file_path): + raise ValueError( + "File path {} does not have a directory part (required for inotify watching)".format(file_path)) + if not os.path.basename(file_path): + raise ValueError( + "File path {} does not have a file part, do not know what to watch for".format(file_path)) + + from inotify.adapters import Inotify + from time import time + from time import sleep + + time_start = time() + + i = Inotify() + i.add_watch(os.path.dirname(file_path)) + + if pre_hook: + pre_hook() + + for event in i.event_gen(yield_nones=True): + if (timeout is not None) and ((time() - time_start) > timeout): + # If the function didn't return until this point, + # the file failed to have been written to and closed within the timeout + raise OSError("Waiting for file {} to be written has failed".format(file_path)) + + # Most such events don't take much time, so it's better to check right away + # and sleep later. + if event is not None: + (_, type_names, path, filename) = event + if filename == os.path.basename(file_path): + if event_type in type_names: + return + sleep(sleep_interval) + +def wait_for_file_write_complete(file_path, pre_hook=None, timeout=None, sleep_interval=0.1): + """ Waits for a process to close a file after opening it in write mode. """ + wait_for_inotify(file_path, + event_type='IN_CLOSE_WRITE', pre_hook=pre_hook, timeout=timeout, sleep_interval=sleep_interval) diff --git a/python/vyos/utils/io.py b/python/vyos/utils/io.py new file mode 100644 index 000000000..843494855 --- /dev/null +++ b/python/vyos/utils/io.py @@ -0,0 +1,103 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +def print_error(str='', end='\n'): + """ + Print `str` to stderr, terminated with `end`. + Used for warnings and out-of-band messages to avoid mangling precious + stdout output. + """ + import sys + sys.stderr.write(str) + sys.stderr.write(end) + sys.stderr.flush() + +def make_progressbar(): + """ + Make a procedure that takes two arguments `done` and `total` and prints a + progressbar based on the ratio thereof, whose length is determined by the + width of the terminal. + """ + import shutil, math + col, _ = shutil.get_terminal_size() + col = max(col - 15, 20) + def print_progressbar(done, total): + if done <= total: + increment = total / col + length = math.ceil(done / increment) + percentage = str(math.ceil(100 * done / total)).rjust(3) + print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r') + # Print a newline so that the subsequent prints don't overwrite the full bar. + if done == total: + print_error() + return print_progressbar + +def make_incremental_progressbar(increment: float): + """ + Make a generator that displays a progressbar that grows monotonically with + every iteration. + First call displays it at 0% and every subsequent iteration displays it + at `increment` increments where 0.0 < `increment` < 1.0. + Intended for FTP and HTTP transfers with stateless callbacks. + """ + print_progressbar = make_progressbar() + total = 0.0 + while total < 1.0: + print_progressbar(total, 1.0) + yield + total += increment + print_progressbar(1, 1) + # Ignore further calls. + while True: + yield + +def ask_input(question, default='', numeric_only=False, valid_responses=[]): + question_out = question + if default: + question_out += f' (Default: {default})' + response = '' + while True: + response = input(question_out + ' ').strip() + if not response and default: + return default + if numeric_only: + if not response.isnumeric(): + print("Invalid value, try again.") + continue + response = int(response) + if valid_responses and response not in valid_responses: + print("Invalid value, try again.") + continue + break + return response + +def ask_yes_no(question, default=False) -> bool: + """Ask a yes/no question via input() and return their answer.""" + from sys import stdout + default_msg = "[Y/n]" if default else "[y/N]" + while True: + try: + stdout.write("%s %s " % (question, default_msg)) + c = input().lower() + if c == '': + return default + elif c in ("y", "ye", "yes"): + return True + elif c in ("n", "no"): + return False + else: + stdout.write("Please respond with yes/y or no/n\n") + except EOFError: + stdout.write("\nPlease respond with yes/y or no/n\n") diff --git a/python/vyos/utils/kernel.py b/python/vyos/utils/kernel.py new file mode 100644 index 000000000..1f3bbdffe --- /dev/null +++ b/python/vyos/utils/kernel.py @@ -0,0 +1,38 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +import os + +def check_kmod(k_mod): + """ Common utility function to load required kernel modules on demand """ + from vyos import ConfigError + from vyos.utils.process import call + if isinstance(k_mod, str): + k_mod = k_mod.split() + for module in k_mod: + if not os.path.exists(f'/sys/module/{module}'): + if call(f'modprobe {module}') != 0: + raise ConfigError(f'Loading Kernel module {module} failed') + +def unload_kmod(k_mod): + """ Common utility function to unload required kernel modules on demand """ + from vyos import ConfigError + from vyos.utils.process import call + if isinstance(k_mod, str): + k_mod = k_mod.split() + for module in k_mod: + if os.path.exists(f'/sys/module/{module}'): + if call(f'rmmod {module}') != 0: + raise ConfigError(f'Unloading Kernel module {module} failed') diff --git a/python/vyos/utils/list.py b/python/vyos/utils/list.py new file mode 100644 index 000000000..63ef720ab --- /dev/null +++ b/python/vyos/utils/list.py @@ -0,0 +1,20 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +def is_list_equal(first: list, second: list) -> bool: + """ Check if 2 lists are equal and list not empty """ + if len(first) != len(second) or len(first) == 0: + return False + return sorted(first) == sorted(second) diff --git a/python/vyos/utils/misc.py b/python/vyos/utils/misc.py new file mode 100644 index 000000000..d82655914 --- /dev/null +++ b/python/vyos/utils/misc.py @@ -0,0 +1,66 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +def begin(*args): + """ + Evaluate arguments in order and return the result of the *last* argument. + For combining multiple expressions in one statement. Useful for lambdas. + """ + return args[-1] + +def begin0(*args): + """ + Evaluate arguments in order and return the result of the *first* argument. + For combining multiple expressions in one statement. Useful for lambdas. + """ + return args[0] + +def install_into_config(conf, config_paths, override_prompt=True): + # Allows op-mode scripts to install values if called from an active config session + # config_paths: dict of config paths + # override_prompt: if True, user will be prompted before existing nodes are overwritten + if not config_paths: + return None + + from vyos.config import Config + from vyos.utils.io import ask_yes_no + from vyos.utils.process import cmd + if not Config().in_session(): + print('You are not in configure mode, commands to install manually from configure mode:') + for path in config_paths: + print(f'set {path}') + return None + + count = 0 + failed = [] + + for path in config_paths: + if override_prompt and conf.exists(path) and not conf.is_multi(path): + if not ask_yes_no(f'Config node "{node}" already exists. Do you want to overwrite it?'): + continue + + try: + cmd(f'/opt/vyatta/sbin/my_set {path}') + count += 1 + except: + failed.append(path) + + if failed: + print(f'Failed to install {len(failed)} value(s). Commands to manually install:') + for path in failed: + print(f'set {path}') + + if count > 0: + print(f'{count} value(s) installed. Use "compare" to see the pending changes, and "commit" to apply.') diff --git a/python/vyos/utils/network.py b/python/vyos/utils/network.py new file mode 100644 index 000000000..55ff29f0c --- /dev/null +++ b/python/vyos/utils/network.py @@ -0,0 +1,427 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +def _are_same_ip(one, two): + from socket import AF_INET + from socket import AF_INET6 + from socket import inet_pton + from vyos.template import is_ipv4 + # compare the binary representation of the IP + f_one = AF_INET if is_ipv4(one) else AF_INET6 + s_two = AF_INET if is_ipv4(two) else AF_INET6 + return inet_pton(f_one, one) == inet_pton(f_one, two) + +def get_protocol_by_name(protocol_name): + """Get protocol number by protocol name + + % get_protocol_by_name('tcp') + % 6 + """ + import socket + try: + protocol_number = socket.getprotobyname(protocol_name) + return protocol_number + except socket.error: + return protocol_name + +def interface_exists(interface) -> bool: + import os + return os.path.exists(f'/sys/class/net/{interface}') + +def is_netns_interface(interface, netns): + from vyos.utils.process import rc_cmd + rc, out = rc_cmd(f'sudo ip netns exec {netns} ip link show dev {interface}') + if rc == 0: + return True + return False + +def get_netns_all() -> list: + from json import loads + from vyos.utils.process import cmd + tmp = loads(cmd('ip --json netns ls')) + return [ netns['name'] for netns in tmp ] + +def get_vrf_members(vrf: str) -> list: + """ + Get list of interface VRF members + :param vrf: str + :return: list + """ + import json + from vyos.utils.process import cmd + if not interface_exists(vrf): + raise ValueError(f'VRF "{vrf}" does not exist!') + output = cmd(f'ip --json --brief link show master {vrf}') + answer = json.loads(output) + interfaces = [] + for data in answer: + if 'ifname' in data: + interfaces.append(data.get('ifname')) + return interfaces + +def get_interface_vrf(interface): + """ Returns VRF of given interface """ + from vyos.utils.dict import dict_search + from vyos.utils.network import get_interface_config + tmp = get_interface_config(interface) + if dict_search('linkinfo.info_slave_kind', tmp) == 'vrf': + return tmp['master'] + return 'default' + +def get_interface_config(interface): + """ Returns the used encapsulation protocol for given interface. + If interface does not exist, None is returned. + """ + if not interface_exists(interface): + return None + from json import loads + from vyos.utils.process import cmd + tmp = loads(cmd(f'ip --detail --json link show dev {interface}'))[0] + return tmp + +def get_interface_address(interface): + """ Returns the used encapsulation protocol for given interface. + If interface does not exist, None is returned. + """ + if not interface_exists(interface): + return None + from json import loads + from vyos.utils.process import cmd + tmp = loads(cmd(f'ip --detail --json addr show dev {interface}'))[0] + return tmp + +def get_interface_namespace(interface: str): + """ + Returns wich netns the interface belongs to + """ + from json import loads + from vyos.utils.process import cmd + + # Bail out early if netns does not exist + tmp = cmd(f'ip --json netns ls') + if not tmp: return None + + for ns in loads(tmp): + netns = f'{ns["name"]}' + # Search interface in each netns + data = loads(cmd(f'ip netns exec {netns} ip --json link show')) + for tmp in data: + if interface == tmp["ifname"]: + return netns + +def is_ipv6_tentative(iface: str, ipv6_address: str) -> bool: + """Check if IPv6 address is in tentative state. + + This function checks if an IPv6 address on a specific network interface is + in the tentative state. IPv6 tentative addresses are not fully configured + and are undergoing Duplicate Address Detection (DAD) to ensure they are + unique on the network. + + Args: + iface (str): The name of the network interface. + ipv6_address (str): The IPv6 address to check. + + Returns: + bool: True if the IPv6 address is tentative, False otherwise. + """ + import json + from vyos.utils.process import rc_cmd + + rc, out = rc_cmd(f'ip -6 --json address show dev {iface} scope global') + if rc: + return False + + data = json.loads(out) + for addr_info in data[0]['addr_info']: + if ( + addr_info.get('local') == ipv6_address and + addr_info.get('tentative', False) + ): + return True + return False + +def is_wwan_connected(interface): + """ Determine if a given WWAN interface, e.g. wwan0 is connected to the + carrier network or not """ + import json + from vyos.utils.process import cmd + + if not interface.startswith('wwan'): + raise ValueError(f'Specified interface "{interface}" is not a WWAN interface') + + # ModemManager is required for connection(s) - if service is not running, + # there won't be any connection at all! + if not is_systemd_service_active('ModemManager.service'): + return False + + modem = interface.lstrip('wwan') + + tmp = cmd(f'mmcli --modem {modem} --output-json') + tmp = json.loads(tmp) + + # return True/False if interface is in connected state + return dict_search('modem.generic.state', tmp) == 'connected' + +def get_bridge_fdb(interface): + """ Returns the forwarding database entries for a given interface """ + if not interface_exists(interface): + return None + from json import loads + from vyos.utils.process import cmd + tmp = loads(cmd(f'bridge -j fdb show dev {interface}')) + return tmp + +def get_all_vrfs(): + """ Return a dictionary of all system wide known VRF instances """ + from json import loads + from vyos.utils.process import cmd + tmp = loads(cmd('ip --json vrf list')) + # Result is of type [{"name":"red","table":1000},{"name":"blue","table":2000}] + # so we will re-arrange it to a more nicer representation: + # {'red': {'table': 1000}, 'blue': {'table': 2000}} + data = {} + for entry in tmp: + name = entry.pop('name') + data[name] = entry + return data + +def mac2eui64(mac, prefix=None): + """ + Convert a MAC address to a EUI64 address or, with prefix provided, a full + IPv6 address. + Thankfully copied from https://gist.github.com/wido/f5e32576bb57b5cc6f934e177a37a0d3 + """ + import re + from ipaddress import ip_network + # http://tools.ietf.org/html/rfc4291#section-2.5.1 + eui64 = re.sub(r'[.:-]', '', mac).lower() + eui64 = eui64[0:6] + 'fffe' + eui64[6:] + eui64 = hex(int(eui64[0:2], 16) ^ 2)[2:].zfill(2) + eui64[2:] + + if prefix is None: + return ':'.join(re.findall(r'.{4}', eui64)) + else: + try: + net = ip_network(prefix, strict=False) + euil = int('0x{0}'.format(eui64), 16) + return str(net[euil]) + except: # pylint: disable=bare-except + return + +def check_port_availability(ipaddress, port, protocol): + """ + Check if port is available and not used by any service + Return False if a port is busy or IP address does not exists + Should be used carefully for services that can start listening + dynamically, because IP address may be dynamic too + """ + from socketserver import TCPServer, UDPServer + from ipaddress import ip_address + + # verify arguments + try: + ipaddress = ip_address(ipaddress).compressed + except: + raise ValueError(f'The {ipaddress} is not a valid IPv4 or IPv6 address') + if port not in range(1, 65536): + raise ValueError(f'The port number {port} is not in the 1-65535 range') + if protocol not in ['tcp', 'udp']: + raise ValueError(f'The protocol {protocol} is not supported. Only tcp and udp are allowed') + + # check port availability + try: + if protocol == 'tcp': + server = TCPServer((ipaddress, port), None, bind_and_activate=True) + if protocol == 'udp': + server = UDPServer((ipaddress, port), None, bind_and_activate=True) + server.server_close() + except Exception as e: + # errno.h: + #define EADDRINUSE 98 /* Address already in use */ + if e.errno == 98: + return False + + return True + +def is_listen_port_bind_service(port: int, service: str) -> bool: + """Check if listen port bound to expected program name + :param port: Bind port + :param service: Program name + :return: bool + + Example: + % is_listen_port_bind_service(443, 'nginx') + True + % is_listen_port_bind_service(443, 'ocserv-main') + False + """ + from psutil import net_connections as connections + from psutil import Process as process + for connection in connections(): + addr = connection.laddr + pid = connection.pid + pid_name = process(pid).name() + pid_port = addr.port + if service == pid_name and port == pid_port: + return True + return False + +def is_ipv6_link_local(addr): + """ Check if addrsss is an IPv6 link-local address. Returns True/False """ + from ipaddress import ip_interface + from vyos.template import is_ipv6 + addr = addr.split('%')[0] + if is_ipv6(addr): + if ip_interface(addr).is_link_local: + return True + + return False + +def is_addr_assigned(ip_address, vrf=None) -> bool: + """ Verify if the given IPv4/IPv6 address is assigned to any interface """ + from netifaces import interfaces + from vyos.utils.network import get_interface_config + from vyos.utils.dict import dict_search + + for interface in interfaces(): + # Check if interface belongs to the requested VRF, if this is not the + # case there is no need to proceed with this data set - continue loop + # with next element + tmp = get_interface_config(interface) + if dict_search('master', tmp) != vrf: + continue + + if is_intf_addr_assigned(interface, ip_address): + return True + + return False + +def is_intf_addr_assigned(ifname: str, addr: str, netns: str=None) -> bool: + """ + Verify if the given IPv4/IPv6 address is assigned to specific interface. + It can check both a single IP address (e.g. 192.0.2.1 or a assigned CIDR + address 192.0.2.1/24. + """ + import json + import jmespath + + from vyos.utils.process import rc_cmd + from ipaddress import ip_interface + + netns_cmd = f'ip netns exec {netns}' if netns else '' + rc, out = rc_cmd(f'{netns_cmd} ip --json address show dev {ifname}') + if rc == 0: + json_out = json.loads(out) + addresses = jmespath.search("[].addr_info[].{family: family, address: local, prefixlen: prefixlen}", json_out) + for address_info in addresses: + family = address_info['family'] + address = address_info['address'] + prefixlen = address_info['prefixlen'] + # Remove the interface name if present in the given address + if '%' in addr: + addr = addr.split('%')[0] + interface = ip_interface(f"{address}/{prefixlen}") + if ip_interface(addr) == interface or address == addr: + return True + + return False + +def is_loopback_addr(addr): + """ Check if supplied IPv4/IPv6 address is a loopback address """ + from ipaddress import ip_address + return ip_address(addr).is_loopback + +def is_wireguard_key_pair(private_key: str, public_key:str) -> bool: + """ + Checks if public/private keys are keypair + :param private_key: Wireguard private key + :type private_key: str + :param public_key: Wireguard public key + :type public_key: str + :return: If public/private keys are keypair returns True else False + :rtype: bool + """ + from vyos.utils.process import cmd + gen_public_key = cmd('wg pubkey', input=private_key) + if gen_public_key == public_key: + return True + else: + return False + +def is_subnet_connected(subnet, primary=False): + """ + Verify is the given IPv4/IPv6 subnet is connected to any interface on this + system. + + primary check if the subnet is reachable via the primary IP address of this + interface, or in other words has a broadcast address configured. ISC DHCP + for instance will complain if it should listen on non broadcast interfaces. + + Return True/False + """ + from ipaddress import ip_address + from ipaddress import ip_network + + from netifaces import ifaddresses + from netifaces import interfaces + from netifaces import AF_INET + from netifaces import AF_INET6 + + from vyos.template import is_ipv6 + + # determine IP version (AF_INET or AF_INET6) depending on passed address + addr_type = AF_INET + if is_ipv6(subnet): + addr_type = AF_INET6 + + for interface in interfaces(): + # check if the requested address type is configured at all + if addr_type not in ifaddresses(interface).keys(): + continue + + # An interface can have multiple addresses, but some software components + # only support the primary address :( + if primary: + ip = ifaddresses(interface)[addr_type][0]['addr'] + if ip_address(ip) in ip_network(subnet): + return True + else: + # Check every assigned IP address if it is connected to the subnet + # in question + for ip in ifaddresses(interface)[addr_type]: + # remove interface extension (e.g. %eth0) that gets thrown on the end of _some_ addrs + addr = ip['addr'].split('%')[0] + if ip_address(addr) in ip_network(subnet): + return True + + return False + +def is_afi_configured(interface, afi): + """ Check if given address family is configured, or in other words - an IP + address is assigned to the interface. """ + from netifaces import ifaddresses + from netifaces import AF_INET + from netifaces import AF_INET6 + + if afi not in [AF_INET, AF_INET6]: + raise ValueError('Address family must be in [AF_INET, AF_INET6]') + + try: + addresses = ifaddresses(interface) + except ValueError as e: + print(e) + return False + + return afi in addresses diff --git a/python/vyos/utils/permission.py b/python/vyos/utils/permission.py new file mode 100644 index 000000000..d938b494f --- /dev/null +++ b/python/vyos/utils/permission.py @@ -0,0 +1,78 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +import os + +def chown(path, user, group): + """ change file/directory owner """ + from pwd import getpwnam + from grp import getgrnam + + if user is None or group is None: + return False + + # path may also be an open file descriptor + if not isinstance(path, int) and not os.path.exists(path): + return False + + uid = getpwnam(user).pw_uid + gid = getgrnam(group).gr_gid + os.chown(path, uid, gid) + return True + +def chmod(path, bitmask): + # path may also be an open file descriptor + if not isinstance(path, int) and not os.path.exists(path): + return + if bitmask is None: + return + os.chmod(path, bitmask) + +def chmod_600(path): + """ make file only read/writable by owner """ + from stat import S_IRUSR, S_IWUSR + + bitmask = S_IRUSR | S_IWUSR + chmod(path, bitmask) + +def chmod_750(path): + """ make file/directory only executable to user and group """ + from stat import S_IRUSR, S_IWUSR, S_IXUSR, S_IRGRP, S_IXGRP + + bitmask = S_IRUSR | S_IWUSR | S_IXUSR | S_IRGRP | S_IXGRP + chmod(path, bitmask) + +def chmod_755(path): + """ make file executable by all """ + from stat import S_IRUSR, S_IWUSR, S_IXUSR, S_IRGRP, S_IXGRP, S_IROTH, S_IXOTH + + bitmask = S_IRUSR | S_IWUSR | S_IXUSR | S_IRGRP | S_IXGRP | \ + S_IROTH | S_IXOTH + chmod(path, bitmask) + +def is_admin() -> bool: + """Look if current user is in sudo group""" + from getpass import getuser + from grp import getgrnam + current_user = getuser() + (_, _, _, admin_group_members) = getgrnam('sudo') + return current_user in admin_group_members + +def get_cfg_group_id(): + from grp import getgrnam + from vyos.defaults import cfg_group + + group_data = getgrnam(cfg_group) + return group_data.gr_gid diff --git a/python/vyos/utils/process.py b/python/vyos/utils/process.py new file mode 100644 index 000000000..c2ef98140 --- /dev/null +++ b/python/vyos/utils/process.py @@ -0,0 +1,232 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +import os + +from subprocess import Popen +from subprocess import PIPE +from subprocess import STDOUT +from subprocess import DEVNULL + +def popen(command, flag='', shell=None, input=None, timeout=None, env=None, + stdout=PIPE, stderr=PIPE, decode='utf-8'): + """ + popen is a wrapper helper aound subprocess.Popen + with it default setting it will return a tuple (out, err) + out: the output of the program run + err: the error code returned by the program + + it can be affected by the following flags: + shell: do not try to auto-detect if a shell is required + for example if a pipe (|) or redirection (>, >>) is used + input: data to sent to the child process via STDIN + the data should be bytes but string will be converted + timeout: time after which the command will be considered to have failed + env: mapping that defines the environment variables for the new process + stdout: define how the output of the program should be handled + - PIPE (default), sends stdout to the output + - DEVNULL, discard the output + stderr: define how the output of the program should be handled + - None (default), send/merge the data to/with stderr + - PIPE, popen will append it to output + - STDOUT, send the data to be merged with stdout + - DEVNULL, discard the output + decode: specify the expected text encoding (utf-8, ascii, ...) + the default is explicitely utf-8 which is python's own default + + usage: + get both stdout and stderr: popen('command', stdout=PIPE, stderr=STDOUT) + discard stdout and get stderr: popen('command', stdout=DEVNUL, stderr=PIPE) + """ + + # airbag must be left as an import in the function as otherwise we have a + # a circual import dependency + from vyos import debug + from vyos import airbag + + # log if the flag is set, otherwise log if command is set + if not debug.enabled(flag): + flag = 'command' + + cmd_msg = f"cmd '{command}'" + debug.message(cmd_msg, flag) + + use_shell = shell + stdin = None + if shell is None: + use_shell = False + if ' ' in command: + use_shell = True + if env: + use_shell = True + + if input: + stdin = PIPE + input = input.encode() if type(input) is str else input + + p = Popen(command, stdin=stdin, stdout=stdout, stderr=stderr, + env=env, shell=use_shell) + + pipe = p.communicate(input, timeout) + + pipe_out = b'' + if stdout == PIPE: + pipe_out = pipe[0] + + pipe_err = b'' + if stderr == PIPE: + pipe_err = pipe[1] + + str_out = pipe_out.decode(decode).replace('\r\n', '\n').strip() + str_err = pipe_err.decode(decode).replace('\r\n', '\n').strip() + + out_msg = f"returned (out):\n{str_out}" + if str_out: + debug.message(out_msg, flag) + + if str_err: + from sys import stderr + err_msg = f"returned (err):\n{str_err}" + # this message will also be send to syslog via airbag + debug.message(err_msg, flag, destination=stderr) + + # should something go wrong, report this too via airbag + airbag.noteworthy(cmd_msg) + airbag.noteworthy(out_msg) + airbag.noteworthy(err_msg) + + return str_out, p.returncode + + +def run(command, flag='', shell=None, input=None, timeout=None, env=None, + stdout=DEVNULL, stderr=PIPE, decode='utf-8'): + """ + A wrapper around popen, which discard the stdout and + will return the error code of a command + """ + _, code = popen( + command, flag, + stdout=stdout, stderr=stderr, + input=input, timeout=timeout, + env=env, shell=shell, + decode=decode, + ) + return code + + +def cmd(command, flag='', shell=None, input=None, timeout=None, env=None, + stdout=PIPE, stderr=PIPE, decode='utf-8', raising=None, message='', + expect=[0]): + """ + A wrapper around popen, which returns the stdout and + will raise the error code of a command + + raising: specify which call should be used when raising + the class should only require a string as parameter + (default is OSError) with the error code + expect: a list of error codes to consider as normal + """ + decoded, code = popen( + command, flag, + stdout=stdout, stderr=stderr, + input=input, timeout=timeout, + env=env, shell=shell, + decode=decode, + ) + if code not in expect: + feedback = message + '\n' if message else '' + feedback += f'failed to run command: {command}\n' + feedback += f'returned: {decoded}\n' + feedback += f'exit code: {code}' + if raising is None: + # error code can be recovered with .errno + raise OSError(code, feedback) + else: + raise raising(feedback) + return decoded + + +def rc_cmd(command, flag='', shell=None, input=None, timeout=None, env=None, + stdout=PIPE, stderr=STDOUT, decode='utf-8'): + """ + A wrapper around popen, which returns the return code + of a command and stdout + + % rc_cmd('uname') + (0, 'Linux') + % rc_cmd('ip link show dev eth99') + (1, 'Device "eth99" does not exist.') + """ + out, code = popen( + command.lstrip(), flag, + stdout=stdout, stderr=stderr, + input=input, timeout=timeout, + env=env, shell=shell, + decode=decode, + ) + return code, out + +def call(command, flag='', shell=None, input=None, timeout=None, env=None, + stdout=None, stderr=None, decode='utf-8'): + """ + A wrapper around popen, which print the stdout and + will return the error code of a command + """ + out, code = popen( + command, flag, + stdout=stdout, stderr=stderr, + input=input, timeout=timeout, + env=env, shell=shell, + decode=decode, + ) + if out: + print(out) + return code + +def process_running(pid_file): + """ Checks if a process with PID in pid_file is running """ + from psutil import pid_exists + if not os.path.isfile(pid_file): + return False + with open(pid_file, 'r') as f: + pid = f.read().strip() + return pid_exists(int(pid)) + +def process_named_running(name, cmdline: str=None): + """ Checks if process with given name is running and returns its PID. + If Process is not running, return None + """ + from psutil import process_iter + for p in process_iter(['name', 'pid', 'cmdline']): + if cmdline: + if p.info['name'] == name and cmdline in p.info['cmdline']: + return p.info['pid'] + elif p.info['name'] == name: + return p.info['pid'] + return None + +def is_systemd_service_active(service): + """ Test is a specified systemd service is activated. + Returns True if service is active, false otherwise. + Copied from: https://unix.stackexchange.com/a/435317 """ + tmp = cmd(f'systemctl show --value -p ActiveState {service}') + return bool((tmp == 'active')) + +def is_systemd_service_running(service): + """ Test is a specified systemd service is actually running. + Returns True if service is running, false otherwise. + Copied from: https://unix.stackexchange.com/a/435317 """ + tmp = cmd(f'systemctl show --value -p SubState {service}') + return bool((tmp == 'running')) diff --git a/python/vyos/utils/system.py b/python/vyos/utils/system.py new file mode 100644 index 000000000..5d41c0c05 --- /dev/null +++ b/python/vyos/utils/system.py @@ -0,0 +1,107 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +import os +from subprocess import run + +def sysctl_read(name: str) -> str: + """Read and return current value of sysctl() option + + Args: + name (str): sysctl key name + + Returns: + str: sysctl key value + """ + tmp = run(['sysctl', '-nb', name], capture_output=True) + return tmp.stdout.decode() + +def sysctl_write(name: str, value: str | int) -> bool: + """Change value via sysctl() + + Args: + name (str): sysctl key name + value (str | int): sysctl key value + + Returns: + bool: True if changed, False otherwise + """ + # convert other types to string before comparison + if not isinstance(value, str): + value = str(value) + # do not change anything if a value is already configured + if sysctl_read(name) == value: + return True + # return False if sysctl call failed + if run(['sysctl', '-wq', f'{name}={value}']).returncode != 0: + return False + # compare old and new values + # sysctl may apply value, but its actual value will be + # different from requested + if sysctl_read(name) == value: + return True + # False in other cases + return False + +def sysctl_apply(sysctl_dict: dict[str, str], revert: bool = True) -> bool: + """Apply sysctl values. + + Args: + sysctl_dict (dict[str, str]): dictionary with sysctl keys with values + revert (bool, optional): Revert to original values if new were not + applied. Defaults to True. + + Returns: + bool: True if all params configured properly, False in other cases + """ + # get current values + sysctl_original: dict[str, str] = {} + for key_name in sysctl_dict.keys(): + sysctl_original[key_name] = sysctl_read(key_name) + # apply new values and revert in case one of them was not applied + for key_name, value in sysctl_dict.items(): + if not sysctl_write(key_name, value): + if revert: + sysctl_apply(sysctl_original, revert=False) + return False + # everything applied + return True + +def get_half_cpus(): + """ return 1/2 of the numbers of available CPUs """ + cpu = os.cpu_count() + if cpu > 1: + cpu /= 2 + return int(cpu) + +def find_device_file(device): + """ Recurively search /dev for the given device file and return its full path. + If no device file was found 'None' is returned """ + from fnmatch import fnmatch + + for root, dirs, files in os.walk('/dev'): + for basename in files: + if fnmatch(basename, device): + return os.path.join(root, basename) + + return None + +def load_as_module(name: str, path: str): + import importlib.util + + spec = importlib.util.spec_from_file_location(name, path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod diff --git a/python/vyos/validate.py b/python/vyos/validate.py deleted file mode 100644 index a83193363..000000000 --- a/python/vyos/validate.py +++ /dev/null @@ -1,285 +0,0 @@ -# Copyright 2018-2021 VyOS maintainers and contributors <maintainers@vyos.io> -# -# This library is free software; you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public -# License as published by the Free Software Foundation; either -# version 2.1 of the License, or (at your option) any later version. -# -# This library is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public -# License along with this library. If not, see <http://www.gnu.org/licenses/>. - -# Important note when you are adding new validation functions: -# -# The Control class will analyse the signature of the function in this file -# and will build the parameters to be passed to it. -# -# The parameter names "ifname" and "self" will get the Interface name and class -# parameters with default will be left unset -# all other paramters will receive the value to check - -def is_ipv6_link_local(addr): - """ Check if addrsss is an IPv6 link-local address. Returns True/False """ - from ipaddress import ip_interface - from vyos.template import is_ipv6 - addr = addr.split('%')[0] - if is_ipv6(addr): - if ip_interface(addr).is_link_local: - return True - - return False - -def _are_same_ip(one, two): - from socket import AF_INET - from socket import AF_INET6 - from socket import inet_pton - from vyos.template import is_ipv4 - # compare the binary representation of the IP - f_one = AF_INET if is_ipv4(one) else AF_INET6 - s_two = AF_INET if is_ipv4(two) else AF_INET6 - return inet_pton(f_one, one) == inet_pton(f_one, two) - -def is_intf_addr_assigned(intf, address) -> bool: - """ - Verify if the given IPv4/IPv6 address is assigned to specific interface. - It can check both a single IP address (e.g. 192.0.2.1 or a assigned CIDR - address 192.0.2.1/24. - """ - from vyos.template import is_ipv4 - - from netifaces import ifaddresses - from netifaces import AF_INET - from netifaces import AF_INET6 - - # check if the requested address type is configured at all - # { - # 17: [{'addr': '08:00:27:d9:5b:04', 'broadcast': 'ff:ff:ff:ff:ff:ff'}], - # 2: [{'addr': '10.0.2.15', 'netmask': '255.255.255.0', 'broadcast': '10.0.2.255'}], - # 10: [{'addr': 'fe80::a00:27ff:fed9:5b04%eth0', 'netmask': 'ffff:ffff:ffff:ffff::'}] - # } - try: - ifaces = ifaddresses(intf) - except ValueError as e: - print(e) - return False - - # determine IP version (AF_INET or AF_INET6) depending on passed address - addr_type = AF_INET if is_ipv4(address) else AF_INET6 - - # Check every IP address on this interface for a match - netmask = None - if '/' in address: - address, netmask = address.split('/') - for ip in ifaces.get(addr_type,[]): - # ip can have the interface name in the 'addr' field, we need to remove it - # {'addr': 'fe80::a00:27ff:fec5:f821%eth2', 'netmask': 'ffff:ffff:ffff:ffff::'} - ip_addr = ip['addr'].split('%')[0] - - if not _are_same_ip(address, ip_addr): - continue - - # we do not have a netmask to compare against, they are the same - if not netmask: - return True - - prefixlen = '' - if is_ipv4(ip_addr): - prefixlen = sum([bin(int(_)).count('1') for _ in ip['netmask'].split('.')]) - else: - prefixlen = sum([bin(int(_,16)).count('1') for _ in ip['netmask'].split('/')[0].split(':') if _]) - - if str(prefixlen) == netmask: - return True - - return False - -def is_addr_assigned(ip_address, vrf=None) -> bool: - """ Verify if the given IPv4/IPv6 address is assigned to any interfac """ - from netifaces import interfaces - from vyos.util import get_interface_config - from vyos.util import dict_search - for interface in interfaces(): - # Check if interface belongs to the requested VRF, if this is not the - # case there is no need to proceed with this data set - continue loop - # with next element - tmp = get_interface_config(interface) - if dict_search('master', tmp) != vrf: - continue - - if is_intf_addr_assigned(interface, ip_address): - return True - - return False - -def is_loopback_addr(addr): - """ Check if supplied IPv4/IPv6 address is a loopback address """ - from ipaddress import ip_address - return ip_address(addr).is_loopback - -def is_subnet_connected(subnet, primary=False): - """ - Verify is the given IPv4/IPv6 subnet is connected to any interface on this - system. - - primary check if the subnet is reachable via the primary IP address of this - interface, or in other words has a broadcast address configured. ISC DHCP - for instance will complain if it should listen on non broadcast interfaces. - - Return True/False - """ - from ipaddress import ip_address - from ipaddress import ip_network - - from netifaces import ifaddresses - from netifaces import interfaces - from netifaces import AF_INET - from netifaces import AF_INET6 - - from vyos.template import is_ipv6 - - # determine IP version (AF_INET or AF_INET6) depending on passed address - addr_type = AF_INET - if is_ipv6(subnet): - addr_type = AF_INET6 - - for interface in interfaces(): - # check if the requested address type is configured at all - if addr_type not in ifaddresses(interface).keys(): - continue - - # An interface can have multiple addresses, but some software components - # only support the primary address :( - if primary: - ip = ifaddresses(interface)[addr_type][0]['addr'] - if ip_address(ip) in ip_network(subnet): - return True - else: - # Check every assigned IP address if it is connected to the subnet - # in question - for ip in ifaddresses(interface)[addr_type]: - # remove interface extension (e.g. %eth0) that gets thrown on the end of _some_ addrs - addr = ip['addr'].split('%')[0] - if ip_address(addr) in ip_network(subnet): - return True - - return False - - -def assert_boolean(b): - if int(b) not in (0, 1): - raise ValueError(f'Value {b} out of range') - - -def assert_range(value, lower=0, count=3): - if int(value, 16) not in range(lower, lower+count): - raise ValueError("Value out of range") - - -def assert_list(s, l): - if s not in l: - o = ' or '.join([f'"{n}"' for n in l]) - raise ValueError(f'state must be {o}, got {s}') - - -def assert_number(n): - if not str(n).isnumeric(): - raise ValueError(f'{n} must be a number') - - -def assert_positive(n, smaller=0): - assert_number(n) - if int(n) < smaller: - raise ValueError(f'{n} is smaller than {smaller}') - - -def assert_mtu(mtu, ifname): - assert_number(mtu) - - import json - from vyos.util import cmd - out = cmd(f'ip -j -d link show dev {ifname}') - # [{"ifindex":2,"ifname":"eth0","flags":["BROADCAST","MULTICAST","UP","LOWER_UP"],"mtu":1500,"qdisc":"pfifo_fast","operstate":"UP","linkmode":"DEFAULT","group":"default","txqlen":1000,"link_type":"ether","address":"08:00:27:d9:5b:04","broadcast":"ff:ff:ff:ff:ff:ff","promiscuity":0,"min_mtu":46,"max_mtu":16110,"inet6_addr_gen_mode":"none","num_tx_queues":1,"num_rx_queues":1,"gso_max_size":65536,"gso_max_segs":65535}] - parsed = json.loads(out)[0] - min_mtu = int(parsed.get('min_mtu', '0')) - # cur_mtu = parsed.get('mtu',0), - max_mtu = int(parsed.get('max_mtu', '0')) - cur_mtu = int(mtu) - - if (min_mtu and cur_mtu < min_mtu) or cur_mtu < 68: - raise ValueError(f'MTU is too small for interface "{ifname}": {mtu} < {min_mtu}') - if (max_mtu and cur_mtu > max_mtu) or cur_mtu > 65536: - raise ValueError(f'MTU is too small for interface "{ifname}": {mtu} > {max_mtu}') - - -def assert_mac(m): - split = m.split(':') - size = len(split) - - # a mac address consits out of 6 octets - if size != 6: - raise ValueError(f'wrong number of MAC octets ({size}): {m}') - - octets = [] - try: - for octet in split: - octets.append(int(octet, 16)) - except ValueError: - raise ValueError(f'invalid hex number "{octet}" in : {m}') - - # validate against the first mac address byte if it's a multicast - # address - if octets[0] & 1: - raise ValueError(f'{m} is a multicast MAC address') - - # overall mac address is not allowed to be 00:00:00:00:00:00 - if sum(octets) == 0: - raise ValueError('00:00:00:00:00:00 is not a valid MAC address') - - if octets[:5] == (0, 0, 94, 0, 1): - raise ValueError(f'{m} is a VRRP MAC address') - -def has_address_configured(conf, intf): - """ - Checks if interface has an address configured. - Checks the following config nodes: - 'address', 'ipv6 address eui64', 'ipv6 address autoconf' - - Returns True if interface has address configured, False if it doesn't. - """ - from vyos.ifconfig import Section - ret = False - - old_level = conf.get_level() - conf.set_level([]) - - intfpath = 'interfaces ' + Section.get_config_path(intf) - if ( conf.exists(f'{intfpath} address') or - conf.exists(f'{intfpath} ipv6 address autoconf') or - conf.exists(f'{intfpath} ipv6 address eui64') ): - ret = True - - conf.set_level(old_level) - return ret - -def has_vrf_configured(conf, intf): - """ - Checks if interface has a VRF configured. - - Returns True if interface has VRF configured, False if it doesn't. - """ - from vyos.ifconfig import Section - ret = False - - old_level = conf.get_level() - conf.set_level([]) - - tmp = ['interfaces', Section.get_config_path(intf), 'vrf'] - if conf.exists(tmp): - ret = True - - conf.set_level(old_level) - return ret diff --git a/python/vyos/version.py b/python/vyos/version.py index fb706ad44..1c5651c83 100644 --- a/python/vyos/version.py +++ b/python/vyos/version.py @@ -1,4 +1,4 @@ -# Copyright 2017-2020 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2017-2023 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -34,11 +34,11 @@ import json import requests import vyos.defaults -from vyos.util import read_file -from vyos.util import read_json -from vyos.util import popen -from vyos.util import run -from vyos.util import DEVNULL +from vyos.utils.file import read_file +from vyos.utils.file import read_json +from vyos.utils.process import popen +from vyos.utils.process import run +from vyos.utils.process import DEVNULL version_file = os.path.join(vyos.defaults.directories['data'], 'version.json') diff --git a/python/vyos/vpp.py b/python/vyos/vpp.py new file mode 100644 index 000000000..76e5d29c3 --- /dev/null +++ b/python/vyos/vpp.py @@ -0,0 +1,315 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +from functools import wraps +from pathlib import Path +from re import search as re_search, fullmatch as re_fullmatch, MULTILINE as re_M +from subprocess import run +from time import sleep + +from vpp_papi import VPPApiClient +from vpp_papi import VPPIOError, VPPValueError + + +class VPPControl: + """Control VPP network stack + """ + + class _Decorators: + """Decorators for VPPControl + """ + + @classmethod + def api_call(cls, decorated_func): + """Check if API is connected before API call + + Args: + decorated_func: function to decorate + + Raises: + VPPIOError: Connection to API is not established + """ + + @wraps(decorated_func) + def api_safe_wrapper(cls, *args, **kwargs): + if not cls.vpp_api_client.transport.connected: + raise VPPIOError(2, 'VPP API is not connected') + return decorated_func(cls, *args, **kwargs) + + return api_safe_wrapper + + @classmethod + def check_retval(cls, decorated_func): + """Check retval from API response + + Args: + decorated_func: function to decorate + + Raises: + VPPValueError: raised when retval is not 0 + """ + + @wraps(decorated_func) + def check_retval_wrapper(cls, *args, **kwargs): + return_value = decorated_func(cls, *args, **kwargs) + if not return_value.retval == 0: + raise VPPValueError( + f'VPP API call failed: {return_value.retval}') + return return_value + + return check_retval_wrapper + + def __init__(self, attempts: int = 5, interval: int = 1000) -> None: + """Create VPP API connection + + Args: + attempts (int, optional): attempts to connect. Defaults to 5. + interval (int, optional): interval between attempts in ms. Defaults to 1000. + + Raises: + VPPIOError: Connection to API cannot be established + """ + self.vpp_api_client = VPPApiClient() + # connect with interval + while attempts: + try: + attempts -= 1 + self.vpp_api_client.connect('vpp-vyos') + break + except (ConnectionRefusedError, FileNotFoundError) as err: + print(f'VPP API connection timeout: {err}') + sleep(interval / 1000) + # raise exception if connection was not successful in the end + if not self.vpp_api_client.transport.connected: + raise VPPIOError(2, 'Cannot connect to VPP API') + + def __del__(self) -> None: + """Disconnect from VPP API (destructor) + """ + self.disconnect() + + def disconnect(self) -> None: + """Disconnect from VPP API + """ + if self.vpp_api_client.transport.connected: + self.vpp_api_client.disconnect() + + @_Decorators.check_retval + @_Decorators.api_call + def cli_cmd(self, command: str): + """Send raw CLI command + + Args: + command (str): command to send + + Returns: + vpp_papi.vpp_serializer.cli_inband_reply: CLI reply class + """ + return self.vpp_api_client.api.cli_inband(cmd=command) + + @_Decorators.api_call + def get_mac(self, ifname: str) -> str: + """Find MAC address by interface name in VPP + + Args: + ifname (str): interface name inside VPP + + Returns: + str: MAC address + """ + for iface in self.vpp_api_client.api.sw_interface_dump(): + if iface.interface_name == ifname: + return iface.l2_address.mac_string + return '' + + @_Decorators.api_call + def get_sw_if_index(self, ifname: str) -> int | None: + """Find interface index by interface name in VPP + + Args: + ifname (str): interface name inside VPP + + Returns: + int | None: Interface index or None (if was not fount) + """ + for iface in self.vpp_api_client.api.sw_interface_dump(): + if iface.interface_name == ifname: + return iface.sw_if_index + return None + + @_Decorators.check_retval + @_Decorators.api_call + def lcp_pair_add(self, iface_name_vpp: str, iface_name_kernel: str) -> None: + """Create LCP interface pair between VPP and kernel + + Args: + iface_name_vpp (str): interface name in VPP + iface_name_kernel (str): interface name in kernel + """ + iface_index = self.get_sw_if_index(iface_name_vpp) + if iface_index: + return self.vpp_api_client.api.lcp_itf_pair_add_del( + is_add=True, + sw_if_index=iface_index, + host_if_name=iface_name_kernel) + + @_Decorators.check_retval + @_Decorators.api_call + def lcp_pair_del(self, iface_name_vpp: str, iface_name_kernel: str) -> None: + """Delete LCP interface pair between VPP and kernel + + Args: + iface_name_vpp (str): interface name in VPP + iface_name_kernel (str): interface name in kernel + """ + iface_index = self.get_sw_if_index(iface_name_vpp) + if iface_index: + return self.vpp_api_client.api.lcp_itf_pair_add_del( + is_add=False, + sw_if_index=iface_index, + host_if_name=iface_name_kernel) + + @_Decorators.check_retval + @_Decorators.api_call + def iface_rxmode(self, iface_name: str, rx_mode: str) -> None: + """Set interface rx-mode in VPP + + Args: + iface_name (str): interface name in VPP + rx_mode (str): mode (polling, interrupt, adaptive) + """ + modes_dict: dict[str, int] = { + 'polling': 1, + 'interrupt': 2, + 'adaptive': 3 + } + if rx_mode not in modes_dict: + raise VPPValueError(f'Mode {rx_mode} is not known') + iface_index = self.get_sw_if_index(iface_name) + return self.vpp_api_client.api.sw_interface_set_rx_mode( + sw_if_index=iface_index, mode=modes_dict[rx_mode]) + + @_Decorators.api_call + def get_pci_addr(self, ifname: str) -> str: + """Find PCI address of interface by interface name in VPP + + Args: + ifname (str): interface name inside VPP + + Returns: + str: PCI address + """ + hw_info = self.cli_cmd(f'show hardware-interfaces {ifname}').reply + + regex_filter = r'^\s+pci: device (?P<device>\w+:\w+) subsystem (?P<subsystem>\w+:\w+) address (?P<address>\w+:\w+:\w+\.\w+) numa (?P<numa>\w+)$' + re_obj = re_search(regex_filter, hw_info, re_M) + + # return empty string if no interface or no PCI info was found + if not hw_info or not re_obj: + return '' + + address = re_obj.groupdict().get('address', '') + + # we need to modify address to math kernel style + # for example: 0000:06:14.00 -> 0000:06:14.0 + address_chunks: list[str] = address.split('.') + address_normalized: str = f'{address_chunks[0]}.{int(address_chunks[1])}' + + return address_normalized + + +class HostControl: + """Control Linux host + """ + + @staticmethod + def pci_rescan(pci_addr: str = '') -> None: + """Rescan PCI device by removing it and rescan PCI bus + + If PCI address is not defined - just rescan PCI bus + + Args: + address (str, optional): PCI address of device. Defaults to ''. + """ + if pci_addr: + device_file = Path(f'/sys/bus/pci/devices/{pci_addr}/remove') + if device_file.exists(): + device_file.write_text('1') + # wait 10 seconds max until device will be removed + attempts = 100 + while device_file.exists() and attempts: + attempts -= 1 + sleep(0.1) + if device_file.exists(): + raise TimeoutError( + f'Timeout was reached for removing PCI device {pci_addr}' + ) + else: + raise FileNotFoundError(f'PCI device {pci_addr} does not exist') + rescan_file = Path('/sys/bus/pci/rescan') + rescan_file.write_text('1') + if pci_addr: + # wait 10 seconds max until device will be installed + attempts = 100 + while not device_file.exists() and attempts: + attempts -= 1 + sleep(0.1) + if not device_file.exists(): + raise TimeoutError( + f'Timeout was reached for installing PCI device {pci_addr}') + + @staticmethod + def get_eth_name(pci_addr: str) -> str: + """Find Ethernet interface name by PCI address + + Args: + pci_addr (str): PCI address + + Raises: + FileNotFoundError: no Ethernet interface was found + + Returns: + str: Ethernet interface name + """ + # find all PCI devices with eth* names + net_devs: dict[str, str] = {} + net_devs_dir = Path('/sys/class/net') + regex_filter = r'^/sys/devices/pci[\w/:\.]+/(?P<pci_addr>\w+:\w+:\w+\.\w+)/[\w/:\.]+/(?P<iface_name>eth\d+)$' + for dir in net_devs_dir.iterdir(): + real_dir: str = dir.resolve().as_posix() + re_obj = re_fullmatch(regex_filter, real_dir) + if re_obj: + iface_name: str = re_obj.group('iface_name') + iface_addr: str = re_obj.group('pci_addr') + net_devs.update({iface_addr: iface_name}) + # match to provided PCI address and return a name if found + if pci_addr in net_devs: + return net_devs[pci_addr] + # raise error if device was not found + raise FileNotFoundError( + f'PCI device {pci_addr} not found in ethernet interfaces') + + @staticmethod + def rename_iface(name_old: str, name_new: str) -> None: + """Rename interface + + Args: + name_old (str): old name + name_new (str): new name + """ + rename_cmd: list[str] = [ + 'ip', 'link', 'set', name_old, 'name', name_new + ] + run(rename_cmd) diff --git a/python/vyos/xml/load.py b/python/vyos/xml/load.py index c3022f3d6..f842ff9ce 100644 --- a/python/vyos/xml/load.py +++ b/python/vyos/xml/load.py @@ -71,16 +71,12 @@ def _merge(dict1, dict2): continue if isinstance(dict1[k], dict) and isinstance(dict2[k], dict): dict1[k] = _merge(dict1[k], dict2[k]) - elif isinstance(dict1[k], dict) and isinstance(dict2[k], dict): + elif isinstance(dict1[k], list) and isinstance(dict2[k], list): dict1[k].extend(dict2[k]) elif dict1[k] == dict2[k]: - # A definition shared between multiple files - if k in (kw.valueless, kw.multi, kw.hidden, kw.node, kw.summary, kw.owner, kw.priority): - continue - _fatal() - raise RuntimeError('parsing issue - undefined leaf?') + continue else: - raise RuntimeError('parsing issue - we messed up?') + dict1[k] = dict2[k] return dict1 @@ -131,7 +127,7 @@ def _format_nodes(inside, conf, xml): name = node.pop('@name') into = inside + [name] if name in r: - r[name].update(_format_node(into, node, xml)) + _merge(r[name], _format_node(into, node, xml)) else: r[name] = _format_node(into, node, xml) r[name][kw.node] = nodename @@ -141,7 +137,7 @@ def _format_nodes(inside, conf, xml): name = node.pop('@name') into = inside + [name] if name in r: - r[name].update(_format_node(inside + [name], node, xml)) + _merge(r[name], _format_node(inside + [name], node, xml)) else: r[name] = _format_node(inside + [name], node, xml) r[name][kw.node] = nodename @@ -180,10 +176,10 @@ def _format_node(inside, conf, xml): if isinstance(conf, list): for child in children: - r = _safe_update(r, _format_nodes(inside, child, xml)) + _merge(r, _format_nodes(inside, child, xml)) else: child = children - r = _safe_update(r, _format_nodes(inside, child, xml)) + _merge(r, _format_nodes(inside, child, xml)) elif 'properties' in keys: properties = conf.pop('properties') diff --git a/python/vyos/xml_ref/__init__.py b/python/vyos/xml_ref/__init__.py new file mode 100644 index 000000000..bf434865d --- /dev/null +++ b/python/vyos/xml_ref/__init__.py @@ -0,0 +1,83 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + +from typing import Optional, Union, TYPE_CHECKING +from vyos.xml_ref import definition + +if TYPE_CHECKING: + from vyos.config import ConfigDict + +def load_reference(cache=[]): + if cache: + return cache[0] + + xml = definition.Xml() + + try: + from vyos.xml_ref.cache import reference + except Exception: + raise ImportError('no xml reference cache !!') + + if not reference: + raise ValueError('empty xml reference cache !!') + + xml.define(reference) + cache.append(xml) + + return xml + +def is_tag(path: list) -> bool: + return load_reference().is_tag(path) + +def is_tag_value(path: list) -> bool: + return load_reference().is_tag_value(path) + +def is_multi(path: list) -> bool: + return load_reference().is_multi(path) + +def is_valueless(path: list) -> bool: + return load_reference().is_valueless(path) + +def is_leaf(path: list) -> bool: + return load_reference().is_leaf(path) + +def cli_defined(path: list, node: str, non_local=False) -> bool: + return load_reference().cli_defined(path, node, non_local=non_local) + +def component_version() -> dict: + return load_reference().component_version() + +def default_value(path: list) -> Optional[Union[str, list]]: + return load_reference().default_value(path) + +def multi_to_list(rpath: list, conf: dict) -> dict: + return load_reference().multi_to_list(rpath, conf) + +def get_defaults(path: list, get_first_key=False, recursive=False) -> dict: + return load_reference().get_defaults(path, get_first_key=get_first_key, + recursive=recursive) + +def relative_defaults(rpath: list, conf: dict, get_first_key=False, + recursive=False) -> dict: + + return load_reference().relative_defaults(rpath, conf, + get_first_key=get_first_key, + recursive=recursive) + +def from_source(d: dict, path: list) -> bool: + return definition.from_source(d, path) + +def ext_dict_merge(source: dict, destination: Union[dict, 'ConfigDict']): + return definition.ext_dict_merge(source, destination) diff --git a/python/vyos/xml_ref/definition.py b/python/vyos/xml_ref/definition.py new file mode 100644 index 000000000..c90c5ddbc --- /dev/null +++ b/python/vyos/xml_ref/definition.py @@ -0,0 +1,302 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + +from typing import Optional, Union, Any, TYPE_CHECKING + +# https://peps.python.org/pep-0484/#forward-references +# for type 'ConfigDict' +if TYPE_CHECKING: + from vyos.config import ConfigDict + +def set_source_recursive(o: Union[dict, str, list], b: bool): + d = {} + if not isinstance(o, dict): + d = {'_source': b} + else: + for k, v in o.items(): + d[k] = set_source_recursive(v, b) + d |= {'_source': b} + return d + +def source_dict_merge(src: dict, dest: dict): + from copy import deepcopy + dst = deepcopy(dest) + from_src = {} + + for key, value in src.items(): + if key not in dst: + dst[key] = value + from_src[key] = set_source_recursive(value, True) + elif isinstance(src[key], dict): + dst[key], f = source_dict_merge(src[key], dst[key]) + f |= {'_source': False} + from_src[key] = f + + return dst, from_src + +def ext_dict_merge(src: dict, dest: Union[dict, 'ConfigDict']): + d, f = source_dict_merge(src, dest) + if hasattr(d, '_from_defaults'): + setattr(d, '_from_defaults', f) + return d + +def from_source(d: dict, path: list) -> bool: + for key in path: + d = d[key] if key in d else {} + if not d or not isinstance(d, dict): + return False + return d.get('_source', False) + +class Xml: + def __init__(self): + self.ref = {} + + def define(self, ref: dict): + self.ref = ref + + def _get_ref_node_data(self, node: dict, data: str) -> Union[bool, str]: + res = node.get('node_data', {}) + if not res: + raise ValueError("non-existent node data") + if data not in res: + raise ValueError("non-existent data field") + + return res.get(data) + + def _get_ref_path(self, path: list) -> dict: + ref_path = path.copy() + d = self.ref + while ref_path and d: + d = d.get(ref_path[0], {}) + ref_path.pop(0) + if self._is_tag_node(d) and ref_path: + ref_path.pop(0) + + return d + + def _is_tag_node(self, node: dict) -> bool: + res = self._get_ref_node_data(node, 'node_type') + return res == 'tag' + + def is_tag(self, path: list) -> bool: + ref_path = path.copy() + d = self.ref + while ref_path and d: + d = d.get(ref_path[0], {}) + ref_path.pop(0) + if self._is_tag_node(d) and ref_path: + if len(ref_path) == 1: + return False + ref_path.pop(0) + + return self._is_tag_node(d) + + def is_tag_value(self, path: list) -> bool: + if len(path) < 2: + return False + + return self.is_tag(path[:-1]) + + def _is_multi_node(self, node: dict) -> bool: + b = self._get_ref_node_data(node, 'multi') + assert isinstance(b, bool) + return b + + def is_multi(self, path: list) -> bool: + d = self._get_ref_path(path) + return self._is_multi_node(d) + + def _is_valueless_node(self, node: dict) -> bool: + b = self._get_ref_node_data(node, 'valueless') + assert isinstance(b, bool) + return b + + def is_valueless(self, path: list) -> bool: + d = self._get_ref_path(path) + return self._is_valueless_node(d) + + def _is_leaf_node(self, node: dict) -> bool: + res = self._get_ref_node_data(node, 'node_type') + return res == 'leaf' + + def is_leaf(self, path: list) -> bool: + d = self._get_ref_path(path) + return self._is_leaf_node(d) + + @staticmethod + def _dict_get(d: dict, path: list) -> dict: + for i in path: + d = d.get(i, {}) + if not isinstance(d, dict): + return {} + if not d: + break + return d + + def _dict_find(self, d: dict, key: str, non_local=False) -> bool: + for k in list(d): + if k in ('node_data', 'component_version'): + continue + if k == key: + return True + if non_local and isinstance(d[k], dict): + if self._dict_find(d[k], key): + return True + return False + + def cli_defined(self, path: list, node: str, non_local=False) -> bool: + d = self._dict_get(self.ref, path) + return self._dict_find(d, node, non_local=non_local) + + def component_version(self) -> dict: + d = {} + for k, v in self.ref['component_version'].items(): + d[k] = int(v) + return d + + def multi_to_list(self, rpath: list, conf: dict) -> dict: + res: Any = {} + + for k in list(conf): + d = self._get_ref_path(rpath + [k]) + if self._is_leaf_node(d): + if self._is_multi_node(d) and not isinstance(conf[k], list): + res[k] = [conf[k]] + else: + res[k] = conf[k] + else: + res[k] = self.multi_to_list(rpath + [k], conf[k]) + + return res + + def _get_default_value(self, node: dict) -> Optional[str]: + return self._get_ref_node_data(node, "default_value") + + def _get_default(self, node: dict) -> Optional[Union[str, list]]: + default = self._get_default_value(node) + if default is None: + return None + if self._is_multi_node(node): + return default.split() + return default + + def default_value(self, path: list) -> Optional[Union[str, list]]: + d = self._get_ref_path(path) + default = self._get_default_value(d) + if default is None: + return None + if self._is_multi_node(d) or self._is_tag_node(d): + return default.split() + return default + + def get_defaults(self, path: list, get_first_key=False, recursive=False) -> dict: + """Return dict containing default values below path + + Note that descent below path will not proceed beyond an encountered + tag node, as no tag node value is known. For a default dict relative + to an existing config dict containing tag node values, see function: + 'relative_defaults' + """ + res: dict = {} + if self.is_tag(path): + return res + + d = self._get_ref_path(path) + + if self._is_leaf_node(d): + default_value = self._get_default(d) + if default_value is not None: + return {path[-1]: default_value} if path else {} + + for k in list(d): + if k in ('node_data', 'component_version') : + continue + if self._is_leaf_node(d[k]): + default_value = self._get_default(d[k]) + if default_value is not None: + res |= {k: default_value} + elif self.is_tag(path + [k]): + # tag node defaults are used as suggestion, not default value; + # should this change, append to path and continue if recursive + pass + else: + if recursive: + pos = self.get_defaults(path + [k], recursive=True) + res |= pos + if res: + if get_first_key or not path: + return res + return {path[-1]: res} + + return {} + + def _well_defined(self, path: list, conf: dict) -> bool: + # test disjoint path + conf for sensible config paths + def step(c): + return [next(iter(c.keys()))] if c else [] + try: + tmp = step(conf) + if tmp and self.is_tag_value(path + tmp): + c = conf[tmp[0]] + if not isinstance(c, dict): + raise ValueError + tmp = tmp + step(c) + self._get_ref_path(path + tmp) + else: + self._get_ref_path(path + tmp) + except ValueError: + return False + return True + + def _relative_defaults(self, rpath: list, conf: dict, recursive=False) -> dict: + res: dict = {} + res = self.get_defaults(rpath, recursive=recursive, + get_first_key=True) + for k in list(conf): + if isinstance(conf[k], dict): + step = self._relative_defaults(rpath + [k], conf=conf[k], + recursive=recursive) + res |= step + + if res: + return {rpath[-1]: res} if rpath else res + + return {} + + def relative_defaults(self, path: list, conf: dict, get_first_key=False, + recursive=False) -> dict: + """Return dict containing defaults along paths of a config dict + """ + if not conf: + return self.get_defaults(path, get_first_key=get_first_key, + recursive=recursive) + if not self._well_defined(path, conf): + # adjust for possible overlap: + if path and path[-1] in list(conf): + conf = conf[path[-1]] + conf = {} if not isinstance(conf, dict) else conf + if not self._well_defined(path, conf): + print('path to config dict does not define full config paths') + return {} + + res = self._relative_defaults(path, conf, recursive=recursive) + + if get_first_key and path: + if res.values(): + res = next(iter(res.values())) + else: + res = {} + + return res diff --git a/python/vyos/xml_ref/generate_cache.py b/python/vyos/xml_ref/generate_cache.py new file mode 100755 index 000000000..6a05d4608 --- /dev/null +++ b/python/vyos/xml_ref/generate_cache.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2023 VyOS maintainers and contributors +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 or later as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# + +import sys +import json +from argparse import ArgumentParser +from argparse import ArgumentTypeError +from os import getcwd +from os import makedirs +from os.path import join +from os.path import abspath +from os.path import dirname +from os.path import basename +from xmltodict import parse + +_here = dirname(__file__) + +sys.path.append(join(_here, '..')) +from configtree import reference_tree_to_json, ConfigTreeError + +xml_cache_json = 'xml_cache.json' +xml_tmp = join('/tmp', xml_cache_json) +pkg_cache = abspath(join(_here, 'pkg_cache')) +ref_cache = abspath(join(_here, 'cache.py')) + +node_data_fields = ("node_type", "multi", "valueless", "default_value") + +def trim_node_data(cache: dict): + for k in list(cache): + if k == "node_data": + for l in list(cache[k]): + if l not in node_data_fields: + del cache[k][l] + else: + if isinstance(cache[k], dict): + trim_node_data(cache[k]) + +def non_trivial(s): + if not s: + raise ArgumentTypeError("Argument must be non empty string") + return s + +def main(): + parser = ArgumentParser(description='generate and save dict from xml defintions') + parser.add_argument('--xml-dir', type=str, required=True, + help='transcluded xml interface-definition directory') + parser.add_argument('--package-name', type=non_trivial, default='vyos-1x', + help='name of current package') + parser.add_argument('--output-path', help='path to generated cache') + args = vars(parser.parse_args()) + + xml_dir = abspath(args['xml_dir']) + pkg_name = args['package_name'].replace('-','_') + cache_name = pkg_name + '_cache.py' + out_path = args['output_path'] + path = out_path if out_path is not None else pkg_cache + xml_cache = abspath(join(path, cache_name)) + + try: + reference_tree_to_json(xml_dir, xml_tmp) + except ConfigTreeError as e: + print(e) + sys.exit(1) + + with open(xml_tmp) as f: + d = json.loads(f.read()) + + trim_node_data(d) + + syntax_version = join(xml_dir, 'xml-component-version.xml') + try: + with open(syntax_version) as f: + component = f.read() + except FileNotFoundError: + if pkg_name != 'vyos_1x': + component = '' + else: + print("\nWARNING: missing xml-component-version.xml\n") + sys.exit(1) + + if component: + parsed = parse(component) + else: + parsed = None + version = {} + # addon package definitions may have empty (== 0) version info + if parsed is not None and parsed['interfaceDefinition'] is not None: + converted = parsed['interfaceDefinition']['syntaxVersion'] + if not isinstance(converted, list): + converted = [converted] + for i in converted: + tmp = {i['@component']: i['@version']} + version |= tmp + + version = {"component_version": version} + + d |= version + + with open(xml_cache, 'w') as f: + f.write(f'reference = {str(d)}') + + print(cache_name) + +if __name__ == '__main__': + main() diff --git a/python/vyos/xml_ref/pkg_cache/__init__.py b/python/vyos/xml_ref/pkg_cache/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/python/vyos/xml_ref/pkg_cache/__init__.py diff --git a/python/vyos/xml_ref/update_cache.py b/python/vyos/xml_ref/update_cache.py new file mode 100755 index 000000000..0842bcbe9 --- /dev/null +++ b/python/vyos/xml_ref/update_cache.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2023 VyOS maintainers and contributors +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 or later as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# +import os +from copy import deepcopy +from generate_cache import pkg_cache +from generate_cache import ref_cache + +def dict_merge(source, destination): + dest = deepcopy(destination) + + for key, value in source.items(): + if key not in dest: + dest[key] = value + elif isinstance(source[key], dict): + dest[key] = dict_merge(source[key], dest[key]) + + return dest + +def main(): + res = {} + cache_dir = os.path.basename(pkg_cache) + for mod in os.listdir(pkg_cache): + mod = os.path.splitext(mod)[0] + if not mod.endswith('_cache'): + continue + d = getattr(__import__(f'{cache_dir}.{mod}', fromlist=[mod]), 'reference') + if mod == 'vyos_1x_cache': + res = dict_merge(res, d) + else: + res = dict_merge(d, res) + + with open(ref_cache, 'w') as f: + f.write(f'reference = {str(res)}') + +if __name__ == '__main__': + main() |