summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cloudinit/ssh_util.py118
1 files changed, 71 insertions, 47 deletions
diff --git a/cloudinit/ssh_util.py b/cloudinit/ssh_util.py
index ba252e7f..663afd92 100644
--- a/cloudinit/ssh_util.py
+++ b/cloudinit/ssh_util.py
@@ -32,7 +32,38 @@ LOG = logging.getLogger(__name__)
DEF_SSHD_CFG = "/etc/ssh/sshd_config"
-class AuthKeyEntry(object):
+class AuthKeyLine(object):
+ def __init__(self, source, keytype=None, base64=None,
+ comment=None, options=None):
+ self.base64 = base64
+ self.comment = comment
+ self.options = options
+ self.keytype = keytype
+ self.source = source
+
+ def empty(self):
+ if (not self.base64 and
+ not self.comment and not self.keytype and not self.options):
+ return True
+ return False
+
+ def __str__(self):
+ toks = []
+ if self.options:
+ toks.append(self.options)
+ if self.keytype:
+ toks.append(self.keytype)
+ if self.base64:
+ toks.append(self.base64)
+ if self.comment:
+ toks.append(self.comment)
+ if not toks:
+ return self.source
+ else:
+ return ' '.join(toks)
+
+
+class AuthKeyLineParser(object):
"""
AUTHORIZED_KEYS FILE FORMAT
AuthorizedKeysFile specifies the file containing public keys for public
@@ -52,10 +83,6 @@ class AuthKeyEntry(object):
case-insensitive):
"""
- def __init__(self, line, def_opt=None):
- self.line = str(line)
- (self.value, self.components) = self._parse(self.line, def_opt)
-
def _extract_options(self, ent):
"""
The options (if present) consist of comma-separated option specifica-
@@ -97,10 +124,11 @@ class AuthKeyEntry(object):
# as long as there is room to do this...
toks = []
if i + 1 < len(ent):
- toks = ent[i + 1:].split(None, 3)
+ rest = ent[i + 1:]
+ toks = rest.split(None, 2)
return (options_lst, toks)
- def _form_components(self, toks):
+ def _form_components(self, line, toks, options=None):
components = {}
if len(toks) == 1:
components['base64'] = toks[0]
@@ -111,50 +139,31 @@ class AuthKeyEntry(object):
components['keytype'] = toks[0]
components['base64'] = toks[1]
components['comment'] = toks[2]
- return components
-
- def get(self, piece):
- return self.components.get(piece)
+ components['options'] = options
+ if not components:
+ return AuthKeyLine(line)
+ else:
+ return AuthKeyLine(line, **components)
- def _parse(self, in_line, def_opt):
+ def parse(self, in_line, def_opt=None):
line = in_line.rstrip("\r\n")
if line.startswith("#") or line.strip() == '':
- return (False, {})
+ return AuthKeyLine(source=line)
else:
ent = line.strip()
toks = ent.split(None, 3)
- tmp_components = {}
- if def_opt:
- tmp_components['options'] = def_opt
if len(toks) < 4:
- tmp_components.update(self._form_components(toks))
+ return self._form_components(line, toks, def_opt)
else:
(options, toks) = self._extract_options(ent)
if options:
- tmp_components['options'] = ",".join(options)
- tmp_components.update(self._form_components(toks))
- # We got some useful value!
- return (True, tmp_components)
+ options = ",".join(options)
+ else:
+ options = def_opt
+ return self._form_components(line, toks, options)
- def __str__(self):
- if not self.value:
- return self.line
- else:
- toks = []
- if 'options' in self.components:
- toks.append(self.components['options'])
- if 'keytype' in self.components:
- toks.append(self.components['keytype'])
- if 'base64' in self.components:
- toks.append(self.components['base64'])
- if 'comment' in self.components:
- toks.append(self.components['comment'])
- if not toks:
- return ''
- return ' '.join(toks)
-
-def update_authorized_keys(fname, keys):
+def parse_authorized_keys(fname):
lines = []
try:
if os.path.isfile(fname):
@@ -163,25 +172,38 @@ def update_authorized_keys(fname, keys):
util.logexc(LOG, "Error reading lines from %s", fname)
lines = []
+ parser = AuthKeyLineParser()
+ contents = []
+ for line in lines:
+ contents.append(parser.parse(line))
+ return contents
+
+
+def update_authorized_keys(fname, keys):
+ entries = parse_authorized_keys(fname)
to_add = list(keys)
- for i in range(0, len(lines)):
- ent = AuthKeyEntry(lines[i])
- if not ent.value:
+
+ for i in range(0, len(entries)):
+ ent = entries[i]
+ if ent.empty() or not ent.base64:
continue
# Replace those with the same base64
for k in keys:
- if not k.value:
+ if k.empty() or not k.base64:
continue
- if k.get('base64') == ent.get('base64'):
+ if k.base64 == ent.base64:
# Replace it with our better one
ent = k
# Don't add it later
to_add.remove(k)
- lines[i] = str(ent)
+ entries[i] = ent
# Now append any entries we did not match above
for key in to_add:
- lines.append(str(key))
+ entries.append(key)
+
+ # Now format them back to strings...
+ lines = [str(b) for b in entries]
# Ensure it ends with a newline
lines.append('')
@@ -198,9 +220,11 @@ def setup_user_keys(keys, user, key_prefix, sshd_config_fn=None):
util.ensure_dir(ssh_dir, mode=0700)
util.chownbyid(ssh_dir, pwent.pw_uid, pwent.pw_gid)
+ # Turn the keys given into actual entries
+ parser = AuthKeyLineParser()
key_entries = []
for k in keys:
- key_entries.append(AuthKeyEntry(k, def_opt=key_prefix))
+ key_entries.append(parser.parse(str(k), def_opt=key_prefix))
with util.SeLinuxGuard(ssh_dir, recursive=True):
try: