#!/usr/bin/env python3
#
# Copyright (C) 2020 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import os
import sys
import time
import logging
import unittest

from vyos.configsession import ConfigSession, ConfigSessionError
from vyos import ConfigError

# Directory whose top-level files are the configurations to load; one test
# class is generated per file.
config_dir = '/usr/libexec/vyos/tests/config'
# Directory of optional expected-result files. A file here is paired with the
# configuration of the same name in config_dir.
config_test_dir = '/usr/libexec/vyos/tests/config-tests'
# Scratch file holding the configuration saved before each test and loaded
# back after it.
save_config = '/tmp/vyos-configtest-save'

class DynamicClassBase(unittest.TestCase):
    """Common fixture for the per-configuration test classes.

    Each test runs inside its own config session. The configuration present
    before the test is saved to ``save_config`` in ``setUp`` and loaded and
    committed again in ``tearDown``.
    """

    def setUp(self):
        """Start the timer, open a config session and save the current config."""
        self._start_time = time.time()
        self.session = ConfigSession(os.getpid())
        self.session.save_config(save_config)

    def tearDown(self):
        """Load the saved config back, commit it, log the elapsed time and
        remove the scratch file."""
        self.session.migrate_and_load_config(save_config)
        self.session.commit()
        # Fetch the logger by name instead of relying on a module-level
        # ``log`` variable: that variable only exists when this file runs as
        # a script, so an importing test runner would hit a NameError here.
        logging.getLogger("TestConfigLog").info(
            f" time: {time.time() - self._start_time:.3f}")
        del self.session
        try:
            os.remove(save_config)
        except OSError:
            # Best-effort cleanup: a missing or unremovable scratch file must
            # not turn a finished test into an error.
            pass

def make_test_function(filename, test_path=None):
    """Build a test method that loads one configuration file and checks it.

    Args:
        filename: Name of the configuration file, relative to ``config_dir``.
        test_path: Optional path to a file of expected configuration
            commands, one per line. Blank lines and lines starting with
            ``#`` are skipped. When omitted, only the load and commit are
            tested.

    Returns:
        A function taking ``self``, meant to be installed as a method on a
        test case that provides ``self.session``.
    """
    def test_config_load(self):
        config_path = os.path.join(config_dir, filename)
        self.session.migrate_and_load_config(config_path)
        try:
            self.session.commit()
        except (ConfigError, ConfigSessionError) as e:
            # Drop the uncommitted changes, then report why the commit failed
            # rather than failing with an empty message.
            self.session.discard()
            self.fail(f'commit of "{filename}" failed: {e}')

        if not test_path:
            return

        config_commands = self.session.show(['configuration', 'commands'])

        with open(test_path, 'r') as f:
            for line in f:
                # Lines read from a file keep their trailing newline, so a
                # blank line is "\n" and must be detected after stripping.
                if not line.strip() or line.startswith("#"):
                    continue

                # The line is compared as read, trailing newline included.
                # NOTE(review): assumes show() returns the command output as
                # one string -- confirm.
                self.assertIn(line, config_commands)
    return test_config_load

def class_name_from_func_name(s):
    """Turn an underscore-separated name into a run-together capitalized one.

    Every ``_``-separated piece is passed through ``str.capitalize`` (first
    character upper-cased, the rest lower-cased) and the pieces are joined
    without a separator.
    """
    pieces = [piece.capitalize() for piece in s.split('_')]
    return ''.join(pieces)

if __name__ == '__main__':
    # Log bare messages (no level or logger-name prefix) to stdout at DEBUG.
    logging.basicConfig(format='%(message)s', level=logging.DEBUG,
                        stream=sys.stdout)
    log = logging.getLogger("TestConfigLog")

    start_time = time.time()
    log.info("Generating tests")

    # The first entry yielded by os.walk() describes config_dir itself; only
    # the files directly inside it are used, in sorted order.
    _, _, config_files = next(os.walk(config_dir))

    for config in sorted(config_files):
        # A file of the same name under config_test_dir, when present, is
        # handed to the generated test as its expected-result file.
        expected = os.path.join(config_test_dir, config)
        if os.path.exists(expected):
            log.info(f'Loaded migration result test for config "{config}"')
        else:
            expected = None

        # Dashes are replaced by underscores for the method and class names.
        method_suffix = config.replace('-', '_')
        klassname = 'TestConfig' + class_name_from_func_name(method_suffix)

        # Store one TestCase subclass per config in the module namespace.
        # The class is bound under its own name only, so it is picked up
        # exactly once.
        globals()[klassname] = type(
            klassname,
            (DynamicClassBase,),
            {'test_' + method_suffix: make_test_function(config, expected)},
        )

    log.info(f"... completed: {time.time() - start_time:.6f}")

    unittest.main(verbosity=2)