diff options
Diffstat (limited to 'cloudinit/ssh_util.py')
| -rw-r--r-- | cloudinit/ssh_util.py | 94 | 
1 files changed, 54 insertions, 40 deletions
| diff --git a/cloudinit/ssh_util.py b/cloudinit/ssh_util.py index e0a2f0ca..88a11a1a 100644 --- a/cloudinit/ssh_util.py +++ b/cloudinit/ssh_util.py @@ -181,12 +181,11 @@ def parse_authorized_keys(fname):      return contents -def update_authorized_keys(fname, keys): -    entries = parse_authorized_keys(fname) +def update_authorized_keys(old_entries, keys):      to_add = list(keys) -    for i in range(0, len(entries)): -        ent = entries[i] +    for i in range(0, len(old_entries)): +        ent = old_entries[i]          if ent.empty() or not ent.base64:              continue          # Replace those with the same base64 @@ -199,66 +198,81 @@ def update_authorized_keys(fname, keys):                  # Don't add it later                  if k in to_add:                      to_add.remove(k) -        entries[i] = ent +        old_entries[i] = ent      # Now append any entries we did not match above      for key in to_add: -        entries.append(key) +        old_entries.append(key)      # Now format them back to strings... -    lines = [str(b) for b in entries] +    lines = [str(b) for b in old_entries]      # Ensure it ends with a newline      lines.append('')      return '\n'.join(lines) -def setup_user_keys(keys, user, key_prefix, paths): -    # Make sure the users .ssh dir is setup accordingly -    pwent = pwd.getpwnam(user) -    ssh_dir = os.path.join(pwent.pw_dir, '.ssh') -    ssh_dir = paths.join(False, ssh_dir) -    if not os.path.exists(ssh_dir): -        util.ensure_dir(ssh_dir, mode=0700) -        util.chownbyid(ssh_dir, pwent.pw_uid, pwent.pw_gid) +def users_ssh_info(username, paths): +    pw_ent = pwd.getpwnam(username) +    if not pw_ent: +        raise RuntimeError("Unable to get ssh info for user %r" % (username)) +    ssh_dir = paths.join(False, os.path.join(pw_ent.pw_dir, '.ssh')) +    return (ssh_dir, pw_ent) -    # Turn the keys given into actual entries -    parser = AuthKeyLineParser() -    key_entries = [] -    for k in keys: -        key_entries.append(parser.parse(str(k), def_opt=key_prefix)) +def extract_authorized_keys(username, paths): +    (ssh_dir, pw_ent) = users_ssh_info(username, paths)      sshd_conf_fn = paths.join(True, DEF_SSHD_CFG) +    auth_key_fn = None      with util.SeLinuxGuard(ssh_dir, recursive=True):          try: -            # AuthorizedKeysFile may contain tokens +            # The 'AuthorizedKeysFile' may contain tokens              # of the form %T which are substituted during connection set-up.              # The following tokens are defined: %% is replaced by a literal              # '%', %h is replaced by the home directory of the user being              # authenticated and %u is replaced by the username of that user.              ssh_cfg = parse_ssh_config_map(sshd_conf_fn) -            akeys = ssh_cfg.get("authorizedkeysfile", '') -            akeys = akeys.strip() -            if not akeys: -                akeys = "%h/.ssh/authorized_keys" -            akeys = akeys.replace("%h", pwent.pw_dir) -            akeys = akeys.replace("%u", user) -            akeys = akeys.replace("%%", '%') -            if not akeys.startswith('/'): -                akeys = os.path.join(pwent.pw_dir, akeys) -            authorized_keys = paths.join(False, akeys) +            auth_key_fn = ssh_cfg.get("authorizedkeysfile", '').strip() +            if not auth_key_fn: +                auth_key_fn = "%h/.ssh/authorized_keys" +            auth_key_fn = auth_key_fn.replace("%h", pw_ent.pw_dir) +            auth_key_fn = auth_key_fn.replace("%u", username) +            auth_key_fn = auth_key_fn.replace("%%", '%') +            if not auth_key_fn.startswith('/'): +                auth_key_fn = os.path.join(pw_ent.pw_dir, auth_key_fn) +            auth_key_fn = paths.join(False, auth_key_fn)          except (IOError, OSError): -            authorized_keys = os.path.join(ssh_dir, 'authorized_keys') +            # Give up and use a default key filename +            auth_key_fn = os.path.join(ssh_dir, 'authorized_keys')              util.logexc(LOG, ("Failed extracting 'AuthorizedKeysFile'"                                " in ssh config" -                              " from %s, using 'AuthorizedKeysFile' file" -                              " %s instead"), -                        sshd_conf_fn, authorized_keys) - -        content = update_authorized_keys(authorized_keys, key_entries) -        util.ensure_dir(os.path.dirname(authorized_keys), mode=0700) -        util.write_file(authorized_keys, content, mode=0600) -        util.chownbyid(authorized_keys, pwent.pw_uid, pwent.pw_gid) +                              " from %r, using 'AuthorizedKeysFile' file" +                              " %r instead"), +                        sshd_conf_fn, auth_key_fn) +    auth_key_entries = parse_authorized_keys(auth_key_fn) +    return (auth_key_fn, auth_key_entries) + + +def setup_user_keys(keys, username, key_prefix, paths): +    # Make sure the users .ssh dir is setup accordingly +    (ssh_dir, pwent) = users_ssh_info(username, paths) +    if not os.path.isdir(ssh_dir): +        util.ensure_dir(ssh_dir, mode=0700) +        util.chownbyid(ssh_dir, pwent.pw_uid, pwent.pw_gid) + +    # Turn the 'update' keys given into actual entries +    parser = AuthKeyLineParser() +    key_entries = [] +    for k in keys: +        key_entries.append(parser.parse(str(k), def_opt=key_prefix)) + +    # Extract the old and make the new +    (auth_key_fn, auth_key_entries) = extract_authorized_keys(username, paths) +    with util.SeLinuxGuard(ssh_dir, recursive=True): +        content = update_authorized_keys(auth_key_entries, key_entries) +        util.ensure_dir(os.path.dirname(auth_key_fn), mode=0700) +        util.write_file(auth_key_fn, content, mode=0600) +        util.chownbyid(auth_key_fn, pwent.pw_uid, pwent.pw_gid)  class SshdConfigLine(object): | 
