# vi: ts=4 expandtab
#
#    Copyright (C) 2009-2010 Canonical Ltd.
#
#    Author: Scott Moser
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License version 3, as
#    published by the Free Software Foundation.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.

import email

from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.mime.base import MIMEBase
from email import encoders
import yaml
import cloudinit
import cloudinit.util as util
import hashlib
import os
import urllib

# Maps a leading marker found at the start of a text/plain payload to the
# content type that payload should be treated as.
starts_with_mappings = {
    '#include': 'text/x-include-url',
    '#include-once': 'text/x-include-once-url',
    '#!': 'text/x-shellscript',
    '#cloud-config': 'text/cloud-config',
    '#upstart-job': 'text/upstart-job',
    '#part-handler': 'text/part-handler',
    '#cloud-boothook': 'text/cloud-boothook',
    '#cloud-config-archive': 'text/cloud-config-archive',
}


def decomp_str(str):
    """
    Return 'str' decompressed if it is gzip data, otherwise return it
    unchanged.

    Any error raised while reading the data as gzip is taken to mean that
    the input was not compressed.
    """
    import StringIO
    import gzip
    try:
        uncomp = gzip.GzipFile(None, "rb", 1, StringIO.StringIO(str)).read()
        return(uncomp)
    except Exception:
        # not gzip data (or unreadable as such): hand back the input as is
        return(str)


def do_include(content, appendmsg):
    """
    Fetch every url listed in 'content' and process what each one returns
    into 'appendmsg' via process_includes().

    'content' is a list of urls, one per line.  A url may also follow
    '#include' or '#include-once' on the same line.  Other lines starting
    with '#' and blank lines are skipped.

    Once an '#include-once' marker has been seen, fetched content is cached
    under <data ipath>/urlcache/<md5 of url> and read from that file instead
    of the network if the file already exists.

    Errors from reading the cache, fetching the url or writing the cache
    propagate to the caller.
    """
    includeonce = False
    for line in content.splitlines():
        if line == "#include":
            continue
        if line == "#include-once":
            includeonce = True
            continue
        if line.startswith("#include-once"):
            line = line[len("#include-once"):].lstrip()
            includeonce = True
        elif line.startswith("#include"):
            line = line[len("#include"):].lstrip()
        if line.startswith("#"):
            continue

        # urls cannot have leading or trailing white space
        url = line.strip()
        if not url:
            # a blank line carries no url; nothing to fetch
            continue

        msum = hashlib.md5()
        msum.update(url)
        includeonce_filename = "%s/urlcache/%s" % (
            cloudinit.get_ipath_cur("data"), msum.hexdigest())

        if includeonce and os.path.isfile(includeonce_filename):
            with open(includeonce_filename, "r") as fp:
                included = fp.read()
        else:
            resp = urllib.urlopen(url)
            try:
                included = resp.read()
            finally:
                resp.close()
            if includeonce:
                util.write_file(includeonce_filename, included, mode=0o600)

        process_includes(message_from_string(decomp_str(included)),
                         appendmsg)


def explode_cc_archive(archive, appendmsg):
    """
    Turn each entry of the yaml list 'archive' into a MIME part and attach
    it to 'appendmsg'.

    Each entry can be one of:
      dict { 'filename' : 'value', 'content' : 'value', 'type' : 'value' }
        ('filename' and 'type' need not be present)
      or
      scalar(payload)

    If no 'type' is given it is derived from the start of the content,
    falling back to text/cloud-config.  Any other key in a dict entry is
    added to the part as a header of that name.
    """
    # NOTE(review): yaml.load can construct arbitrary python objects and the
    # archive comes from user-data; consider yaml.safe_load if only plain
    # yaml types need to be supported.
    for ent in yaml.load(archive):
        def_type = "text/cloud-config"
        if isinstance(ent, str):
            ent = {'content': ent}

        content = ent.get('content', '')
        mtype = ent.get('type', None)
        if mtype is None:
            mtype = type_from_startswith(content, def_type)

        maintype, subtype = mtype.split('/', 1)
        if maintype == "text":
            msg = MIMEText(content, _subtype=subtype)
        else:
            msg = MIMEBase(maintype, subtype)
            msg.set_payload(content)

        if 'filename' in ent:
            msg.add_header('Content-Disposition', 'attachment',
                           filename=ent['filename'])

        for header in ent.keys():
            if header in ('content', 'filename', 'type'):
                continue
            # use the value stored under this key, not under the literal
            # key 'header'
            msg.add_header(header, ent[header])

        _attach_part(appendmsg, msg)


def multi_part_count(outermsg, newcount=None):
    """
    Return the number of attachments to this MIMEMultipart by looking
    at its 'Number-Attachments' header.

    The header is created with value "0" if it is missing.  If 'newcount'
    is given the header is set to that value first.
    """
    nfield = 'Number-Attachments'
    if nfield not in outermsg:
        outermsg[nfield] = "0"

    if newcount is not None:
        outermsg.replace_header(nfield, str(newcount))

    return(int(outermsg.get(nfield, 0)))


def _attach_part(outermsg, part):
    """
    Attach a part to an outer message.  outermsg must be a MIMEMultipart.
    Modifies a header in outermsg to keep track of number of attachments.

    A part without a filename is given one of the form 'part-NNN', where
    NNN is its (1-based) position among the attachments.
    """
    cur = multi_part_count(outermsg)
    if not part.get_filename(None):
        part.add_header('Content-Disposition', 'attachment',
                        filename='part-%03d' % (cur + 1))
    outermsg.attach(part)
    multi_part_count(outermsg, cur + 1)


def type_from_startswith(payload, default=None):
    """
    Return the content type whose marker in starts_with_mappings 'payload'
    starts with, or 'default' if none matches.
    """
    # longest marker first, so '#include-once' is tried before '#include'
    slist = sorted(starts_with_mappings.keys(), key=len, reverse=True)
    for sstr in slist:
        if payload.startswith(sstr):
            return(starts_with_mappings[sstr])
    return default


def process_includes(msg, appendmsg=None):
    """
    Walk 'msg' and collect its non-container parts into 'appendmsg'.

    text/plain parts are re-typed based on how their payload starts.
    Include parts are fetched and processed (see do_include), cloud-config
    archives are exploded (see explode_cc_archive); every other part gets
    its Content-Type set to the detected type and is attached.

    If 'appendmsg' is None a new MIMEMultipart is created.  The message
    that parts were attached to is returned.
    """
    if appendmsg is None:
        appendmsg = MIMEMultipart()

    for part in msg.walk():
        # multipart/* are just containers
        if part.get_content_maintype() == 'multipart':
            continue

        ctype = None
        ctype_orig = part.get_content_type()
        payload = part.get_payload(decode=True)

        if ctype_orig == "text/plain":
            ctype = type_from_startswith(payload)

        if ctype is None:
            ctype = ctype_orig

        if ctype in ('text/x-include-url', 'text/x-include-once-url'):
            do_include(payload, appendmsg)
            continue

        if ctype == "text/cloud-config-archive":
            explode_cc_archive(payload, appendmsg)
            continue

        # record the detected type on the part being attached (not on the
        # enclosing message, which may be a multipart container)
        if 'Content-Type' in part:
            part.replace_header('Content-Type', ctype)
        else:
            part['Content-Type'] = ctype

        _attach_part(appendmsg, part)

    return appendmsg


def _apply_headers(msg, headers):
    """Set each of 'headers' on 'msg', replacing any existing value."""
    for (key, val) in headers.items():
        if key in msg:
            msg.replace_header(key, val)
        else:
            msg[key] = val


def message_from_string(data, headers=None):
    """
    Return an email message object for 'data'.

    If 'mime-version:' appears (case-insensitively) in the first 4096
    characters, 'data' is parsed as a MIME message.  Otherwise it becomes
    the payload of a single part whose type is taken from the
    'Content-Type' entry of 'headers' (default text/plain).

    In both cases every entry of 'headers' is set on the result.
    """
    if headers is None:
        headers = {}

    if "mime-version:" in data[0:4096].lower():
        msg = email.message_from_string(data)
    else:
        mtype = headers.get("Content-Type", "text/plain")
        maintype, subtype = mtype.split("/", 1)
        msg = MIMEBase(maintype, subtype)
        msg.set_payload(data)

    _apply_headers(msg, headers)
    return(msg)


# this is heavily wasteful, reads through userdata string input
def preprocess_userdata(data):
    """
    Return 'data' (decompressed if needed) with all includes and archives
    processed, as the string form of a single multipart message.
    """
    newmsg = MIMEMultipart()
    process_includes(message_from_string(decomp_str(data)), newmsg)
    return(newmsg.as_string())


def walk_userdata(istr, callback, data=None):
    """
    Call 'callback' once for every non-container part of the message in
    'istr'.

    callback is a function that will be called with
    (data, content_type, filename, payload).  Parts without a filename are
    passed 'part-NNN', NNN being the 0-based index of the part.
    """
    partnum = 0
    for part in message_from_string(istr).walk():
        # multipart/* are just containers
        if part.get_content_maintype() == 'multipart':
            continue

        ctype = part.get_content_type()
        if ctype is None:
            ctype = 'application/octet-stream'

        filename = part.get_filename()
        if not filename:
            filename = 'part-%03d' % partnum

        callback(data, ctype, filename, part.get_payload(decode=True))

        partnum = partnum + 1


if __name__ == "__main__":
    import sys
    # read as bytes: the input may be gzip compressed
    with open(sys.argv[1], "rb") as fp:
        data = decomp_str(fp.read())
    newmsg = MIMEMultipart()
    process_includes(message_from_string(data), newmsg)
    print(newmsg)
    print("#found %s parts" % multi_part_count(newmsg))