# Copyright (C) 2018-2024 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Unit tests for the config-editing helpers in ``vyos.initialsetup``."""

import unittest
import vyos.configtree
import vyos.initialsetup as vis

from unittest import TestCase
from vyos.xml_ref import definition
from vyos.xml_ref.pkg_cache.vyos_1x_cache import reference

# Old password hash from the default config
_DEFAULT_PW_HASH = '$6$QxPS.uk6mfo$9QBSo8u1FkH16gMyAVhus6fU3LOzvLR9Z9.82m3tiHFAxTtIkhaZSWssSgzt4v4dGAL8rhVQxTg0oAG9/q11h/'


def _user_path(user, *tail):
    """Return the config path list ``system login user <user> <tail...>``."""
    return ["system", "login", "user", user, *tail]


def _auth_path(user, *tail):
    """Return the config path list under ``<user> authentication``."""
    return _user_path(user, "authentication", *tail)


class TestInitialSetup(TestCase):
    """Run each ``vyos.initialsetup`` helper against the default boot config."""

    def setUp(self):
        # Every test starts from a freshly parsed copy of the default config.
        with open('tests/data/config.boot.default', 'r') as boot_file:
            boot_text = boot_file.read()
        self.config = vyos.configtree.ConfigTree(boot_text)

        # XML reference definition, used by the is_tag() checks below.
        self.xml = definition.Xml()
        self.xml.define(reference)

    def test_set_user_password(self):
        vis.set_user_password(self.config, 'vyos', 'vyosvyos')

        stored_hash = self.config.return_value(
            _auth_path("vyos", "encrypted-password"))

        # Just check it changed the hash, don't try to check if hash is good
        self.assertNotEqual(_DEFAULT_PW_HASH, stored_hash)

    def test_disable_user_password(self):
        vis.disable_user_password(self.config, 'vyos')

        stored_hash = self.config.return_value(
            _auth_path("vyos", "encrypted-password"))
        self.assertEqual(stored_hash, '!')

    def test_set_ssh_key_with_name(self):
        test_ssh_key = " ssh-rsa fakedata vyos@vyos "
        vis.set_user_ssh_key(self.config, 'vyos', test_ssh_key)

        # The key is expected under the name embedded in the key string.
        key_node = _auth_path("vyos", "public-keys", "vyos@vyos")
        key_type = self.config.return_value(key_node + ["type"])
        key_data = self.config.return_value(key_node + ["key"])

        self.assertEqual(key_type, 'ssh-rsa')
        self.assertEqual(key_data, 'fakedata')
        self.assertTrue(self.xml.is_tag(_auth_path("vyos", "public-keys")))

    def test_set_ssh_key_without_name(self):
        # If key file doesn't include a name, the function will use user name
        # for the key name
        test_ssh_key = " ssh-rsa fakedata "
        vis.set_user_ssh_key(self.config, 'vyos', test_ssh_key)

        key_node = _auth_path("vyos", "public-keys", "vyos")
        key_type = self.config.return_value(key_node + ["type"])
        key_data = self.config.return_value(key_node + ["key"])

        self.assertEqual(key_type, 'ssh-rsa')
        self.assertEqual(key_data, 'fakedata')
        self.assertTrue(self.xml.is_tag(_auth_path("vyos", "public-keys")))

    def test_create_user(self):
        vis.create_user(self.config, 'jrandomhacker', password='qwerty',
                        key=" ssh-rsa fakedata jrandomhacker@foovax ")

        # The user node, its SSH key and its password hash must all exist.
        self.assertTrue(self.config.exists(_user_path("jrandomhacker")))
        self.assertTrue(self.config.exists(
            _auth_path("jrandomhacker", "public-keys", "jrandomhacker@foovax")))
        self.assertTrue(self.config.exists(
            _auth_path("jrandomhacker", "encrypted-password")))

        level = self.config.return_value(_user_path("jrandomhacker", "level"))
        self.assertEqual(level, "admin")

    def test_set_hostname(self):
        vis.set_host_name(self.config, "vyos-test")

        host_name = self.config.return_value(["system", "host-name"])
        self.assertEqual(host_name, "vyos-test")

    def test_set_name_servers(self):
        vis.set_name_servers(self.config, ["192.0.2.10", "203.0.113.20"])

        servers = self.config.return_values(["system", "name-server"])
        self.assertIn("192.0.2.10", servers)
        self.assertIn("203.0.113.20", servers)

    def test_set_gateway(self):
        vis.set_default_gateway(self.config, '192.0.2.1')

        static = ['protocols', 'static']
        self.assertTrue(self.config.exists(
            static + ['route', '0.0.0.0/0', 'next-hop', '192.0.2.1']))

        # NOTE(review): the config assertion above targets 'route' while the
        # tag checks below target 'mroute' -- confirm this is intended.
        self.assertTrue(self.xml.is_tag(
            static + ['mroute', '0.0.0.0/0', 'next-hop']))
        self.assertTrue(self.xml.is_tag(static + ['mroute']))


if __name__ == "__main__":
    unittest.main()