diff options
-rw-r--r-- | cloudinit/ssh_util.py | 72 | ||||
-rw-r--r-- | tests/unittests/test_sshutil.py | 83 |
2 files changed, 126 insertions, 29 deletions
diff --git a/cloudinit/ssh_util.py b/cloudinit/ssh_util.py index 3f99b58c..bcb23a5a 100644 --- a/cloudinit/ssh_util.py +++ b/cloudinit/ssh_util.py @@ -160,19 +160,19 @@ class AuthKeyLineParser(object): comment=comment, options=options) -def parse_authorized_keys(fname): +def parse_authorized_keys(fnames): lines = [] - try: - if os.path.isfile(fname): - lines = util.load_file(fname).splitlines() - except (IOError, OSError): - util.logexc(LOG, "Error reading lines from %s", fname) - lines = [] - parser = AuthKeyLineParser() contents = [] - for line in lines: - contents.append(parser.parse(line)) + for fname in fnames: + try: + if os.path.isfile(fname): + lines = util.load_file(fname).splitlines() + for line in lines: + contents.append(parser.parse(line)) + except (IOError, OSError): + util.logexc(LOG, "Error reading lines from %s", fname) + return contents @@ -211,32 +211,46 @@ def users_ssh_info(username): return (os.path.join(pw_ent.pw_dir, '.ssh'), pw_ent) -def extract_authorized_keys(username): +def render_authorizedkeysfile_paths(value, homedir, username): + # The 'AuthorizedKeysFile' may contain tokens + # of the form %T which are substituted during connection set-up. + # The following tokens are defined: %% is replaced by a literal + # '%', %h is replaced by the home directory of the user being + # authenticated and %u is replaced by the username of that user. + macros = (("%h", homedir), ("%u", username), ("%%", "%")) + if not value: + value = "%h/.ssh/authorized_keys" + paths = value.split() + rendered = [] + for path in paths: + for macro, field in macros: + path = path.replace(macro, field) + if not path.startswith("/"): + path = os.path.join(homedir, path) + rendered.append(path) + return rendered + + +def extract_authorized_keys(username, sshd_cfg_file=DEF_SSHD_CFG): (ssh_dir, pw_ent) = users_ssh_info(username) - auth_key_fn = None + default_authorizedkeys_file = os.path.join(ssh_dir, 'authorized_keys') + auth_key_fns = [] with util.SeLinuxGuard(ssh_dir, recursive=True): try: - # The 'AuthorizedKeysFile' may contain tokens - # of the form %T which are substituted during connection set-up. - # The following tokens are defined: %% is replaced by a literal - # '%', %h is replaced by the home directory of the user being - # authenticated and %u is replaced by the username of that user. - ssh_cfg = parse_ssh_config_map(DEF_SSHD_CFG) - auth_key_fn = ssh_cfg.get("authorizedkeysfile", '').strip() - if not auth_key_fn: - auth_key_fn = "%h/.ssh/authorized_keys" - auth_key_fn = auth_key_fn.replace("%h", pw_ent.pw_dir) - auth_key_fn = auth_key_fn.replace("%u", username) - auth_key_fn = auth_key_fn.replace("%%", '%') - if not auth_key_fn.startswith('/'): - auth_key_fn = os.path.join(pw_ent.pw_dir, auth_key_fn) + ssh_cfg = parse_ssh_config_map(sshd_cfg_file) + auth_key_fns = render_authorizedkeysfile_paths( + ssh_cfg.get("authorizedkeysfile", "%h/.ssh/authorized_keys"), + pw_ent.pw_dir, username) + except (IOError, OSError): # Give up and use a default key filename - auth_key_fn = os.path.join(ssh_dir, 'authorized_keys') + auth_key_fns[0] = default_authorizedkeys_file util.logexc(LOG, "Failed extracting 'AuthorizedKeysFile' in ssh " "config from %r, using 'AuthorizedKeysFile' file " - "%r instead", DEF_SSHD_CFG, auth_key_fn) - return (auth_key_fn, parse_authorized_keys(auth_key_fn)) + "%r instead", DEF_SSHD_CFG, auth_key_fns[0]) + + # always store all the keys in the user's private file + return (default_authorizedkeys_file, parse_authorized_keys(auth_key_fns)) def setup_user_keys(keys, username, options=None): diff --git a/tests/unittests/test_sshutil.py b/tests/unittests/test_sshutil.py index 73ae897f..b227c20b 100644 --- a/tests/unittests/test_sshutil.py +++ b/tests/unittests/test_sshutil.py @@ -1,11 +1,19 @@ # This file is part of cloud-init. See LICENSE file for license information. from mock import patch +from collections import namedtuple from cloudinit import ssh_util from cloudinit.tests import helpers as test_helpers from cloudinit import util +# https://stackoverflow.com/questions/11351032/ +FakePwEnt = namedtuple( + 'FakePwEnt', + ['pw_dir', 'pw_gecos', 'pw_name', 'pw_passwd', 'pw_shell', 'pwd_uid']) +FakePwEnt.__new__.__defaults__ = tuple( + "UNSET_%s" % n for n in FakePwEnt._fields) + VALID_CONTENT = { 'dsa': ( @@ -326,4 +334,79 @@ class TestUpdateSshConfig(test_helpers.CiTestCase): m_write_file.assert_not_called() +class TestBasicAuthorizedKeyParse(test_helpers.CiTestCase): + def test_user(self): + self.assertEqual( + ["/opt/bobby/keys"], + ssh_util.render_authorizedkeysfile_paths( + "/opt/%u/keys", "/home/bobby", "bobby")) + + def test_multiple(self): + self.assertEqual( + ["/keys/path1", "/keys/path2"], + ssh_util.render_authorizedkeysfile_paths( + "/keys/path1 /keys/path2", "/home/bobby", "bobby")) + + def test_relative(self): + self.assertEqual( + ["/home/bobby/.secret/keys"], + ssh_util.render_authorizedkeysfile_paths( + ".secret/keys", "/home/bobby", "bobby")) + + def test_home(self): + self.assertEqual( + ["/homedirs/bobby/.keys"], + ssh_util.render_authorizedkeysfile_paths( + "%h/.keys", "/homedirs/bobby", "bobby")) + + +class TestMultipleSshAuthorizedKeysFile(test_helpers.CiTestCase): + + @patch("cloudinit.ssh_util.pwd.getpwnam") + def test_multiple_authorizedkeys_file_order1(self, m_getpwnam): + fpw = FakePwEnt(pw_name='bobby', pw_dir='/home2/bobby') + m_getpwnam.return_value = fpw + authorized_keys = self.tmp_path('authorized_keys') + util.write_file(authorized_keys, VALID_CONTENT['rsa']) + + user_keys = self.tmp_path('user_keys') + util.write_file(user_keys, VALID_CONTENT['dsa']) + + sshd_config = self.tmp_path('sshd_config') + util.write_file( + sshd_config, + "AuthorizedKeysFile %s %s" % (authorized_keys, user_keys)) + + (auth_key_fn, auth_key_entries) = ssh_util.extract_authorized_keys( + fpw.pw_name, sshd_config) + content = ssh_util.update_authorized_keys( + auth_key_entries, []) + + self.assertEqual("%s/.ssh/authorized_keys" % fpw.pw_dir, auth_key_fn) + self.assertTrue(VALID_CONTENT['rsa'] in content) + self.assertTrue(VALID_CONTENT['dsa'] in content) + + @patch("cloudinit.ssh_util.pwd.getpwnam") + def test_multiple_authorizedkeys_file_order2(self, m_getpwnam): + fpw = FakePwEnt(pw_name='suzie', pw_dir='/home/suzie') + m_getpwnam.return_value = fpw + authorized_keys = self.tmp_path('authorized_keys') + util.write_file(authorized_keys, VALID_CONTENT['rsa']) + + user_keys = self.tmp_path('user_keys') + util.write_file(user_keys, VALID_CONTENT['dsa']) + + sshd_config = self.tmp_path('sshd_config') + util.write_file( + sshd_config, + "AuthorizedKeysFile %s %s" % (authorized_keys, user_keys)) + + (auth_key_fn, auth_key_entries) = ssh_util.extract_authorized_keys( + fpw.pw_name, sshd_config) + content = ssh_util.update_authorized_keys(auth_key_entries, []) + + self.assertEqual("%s/.ssh/authorized_keys" % fpw.pw_dir, auth_key_fn) + self.assertTrue(VALID_CONTENT['rsa'] in content) + self.assertTrue(VALID_CONTENT['dsa'] in content) + # vi: ts=4 expandtab |