summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cloudinit/ssh_util.py5
-rw-r--r--tests/unittests/test_sshutil.py42
2 files changed, 43 insertions, 4 deletions
diff --git a/cloudinit/ssh_util.py b/cloudinit/ssh_util.py
index b95b956f..882517f5 100644
--- a/cloudinit/ssh_util.py
+++ b/cloudinit/ssh_util.py
@@ -171,16 +171,13 @@ def parse_authorized_keys(fname):
def update_authorized_keys(old_entries, keys):
- to_add = list(keys)
-
+ to_add = list([k for k in keys if k.valid()])
for i in range(0, len(old_entries)):
ent = old_entries[i]
if not ent.valid():
continue
# Replace those with the same base64
for k in keys:
- if not ent.valid():
- continue
if k.base64 == ent.base64:
# Replace it with our better one
ent = k
diff --git a/tests/unittests/test_sshutil.py b/tests/unittests/test_sshutil.py
index 2a8e6abe..4c62c8be 100644
--- a/tests/unittests/test_sshutil.py
+++ b/tests/unittests/test_sshutil.py
@@ -126,6 +126,48 @@ class TestAuthKeyLineParser(test_helpers.TestCase):
self.assertFalse(key.valid())
+class TestUpdateAuthorizedKeys(test_helpers.TestCase):
+
+ def test_new_keys_replace(self):
+ """new entries with the same base64 should replace old."""
+ orig_entries = [
+ ' '.join(('rsa', VALID_CONTENT['rsa'], 'orig_comment1')),
+ ' '.join(('dsa', VALID_CONTENT['dsa'], 'orig_comment2'))]
+
+ new_entries = [
+ ' '.join(('rsa', VALID_CONTENT['rsa'], 'new_comment1')), ]
+
+ expected = '\n'.join([new_entries[0], orig_entries[1]]) + '\n'
+
+ parser = ssh_util.AuthKeyLineParser()
+ found = ssh_util.update_authorized_keys(
+ [parser.parse(p) for p in orig_entries],
+ [parser.parse(p) for p in new_entries])
+
+ self.assertEqual(expected, found)
+
+ def test_new_invalid_keys_are_ignored(self):
+ """new entries that are invalid should be skipped."""
+ orig_entries = [
+ ' '.join(('rsa', VALID_CONTENT['rsa'], 'orig_comment1')),
+ ' '.join(('dsa', VALID_CONTENT['dsa'], 'orig_comment2'))]
+
+ new_entries = [
+ ' '.join(('rsa', VALID_CONTENT['rsa'], 'new_comment1')),
+ 'xxx-invalid-thing1',
+ 'xxx-invalid-blob2'
+ ]
+
+ expected = '\n'.join([new_entries[0], orig_entries[1]]) + '\n'
+
+ parser = ssh_util.AuthKeyLineParser()
+ found = ssh_util.update_authorized_keys(
+ [parser.parse(p) for p in orig_entries],
+ [parser.parse(p) for p in new_entries])
+
+ self.assertEqual(expected, found)
+
+
class TestParseSSHConfig(test_helpers.TestCase):
def setUp(self):