summaryrefslogtreecommitdiff
path: root/cloudinit/includer.py
blob: d1022c5ae6e25c474c997ca63f77078d1eae2f36 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import re

from cloudinit import downloader as down
from cloudinit import exceptions as excp
from cloudinit import log as logging
from cloudinit import shell as sh

# Matches an include directive occupying a whole line: "#include <path>" or
# "#opt_include <path>", with one space or tab after the keyword.
# Group 1 is the directive name, group 2 the (unstripped) path text.
INCLUDE_PATT = re.compile("^#(opt_include|include)[ \t](.*)$", re.MULTILINE)
# Directive names whose target may be unreadable without aborting expansion.
OPT_PATS = ['opt_include']

LOG = logging.getLogger(__name__)


class Includer(object):
    """Expand ``#include`` / ``#opt_include`` directives in a file.

    Each directive line is replaced by the (recursively expanded) contents of
    the file it names. A relative path in a directive is resolved against the
    directory of the file containing that directive; the root file is
    resolved against the ``relative_to`` argument of ``read``.
    """

    def __init__(self, root_fn, stack_limit=10):
        """
        :param root_fn: path of the file to start expanding from.
        :param stack_limit: maximum number of files that may be open at once
            along a single include chain (the root file counts as one).
        """
        self.root_fn = root_fn
        self.stack_limit = stack_limit

    def _read_file(self, fname):
        """Return the raw contents of ``fname`` via ``sh.read_file``."""
        return sh.read_file(fname)

    def _read(self, fname, stack, rel):
        """Return ``fname``'s contents with all include directives expanded.

        :param fname: path to read; stripped, and joined to ``rel`` unless it
            starts with ``/``.
        :param stack: set of canonical paths currently being expanded on this
            include chain. It is used for the depth and recursion checks and
            is restored to its prior state before this call returns or raises.
        :param rel: directory that a relative ``fname`` is resolved against.
        :raises excp.StackExceeded: if ``stack`` already holds
            ``self.stack_limit`` entries.
        :raises excp.RecursiveInclude: if ``fname`` is already on ``stack``.
        :raises IOError: if ``fname`` or a non-optional include in it cannot
            be read.
        """
        if len(stack) >= self.stack_limit:
            raise excp.StackExceeded("Stack limit of %s reached while including %s"
                                     % (self.stack_limit, fname))

        canon_fname = self._canon_name(fname, rel)
        if canon_fname in stack:
            raise excp.RecursiveInclude("File %s recursively included"
                                        % (canon_fname))

        stack.add(canon_fname)
        try:
            contents = self._read_file(canon_fname)
            # Directives inside this file resolve relative to its own folder.
            new_rel = os.path.dirname(canon_fname)

            def include_cb(match):
                is_optional = (match.group(1).lower() in OPT_PATS)
                fn = match.group(2).strip()
                if not fn:
                    # Directive with no path: leave the line as-is.
                    # Should we die??
                    return match.group(0)
                try:
                    LOG.debug("Including file %s", fn)
                    return self._read(fn, stack, new_rel)
                except IOError:
                    # NOTE: this also swallows IOErrors raised by non-optional
                    # includes nested anywhere inside an optional file.
                    if is_optional:
                        return ''
                    raise

            return INCLUDE_PATT.sub(include_cb, contents)
        finally:
            # Pop the canonical name pushed above (``fname`` itself may be
            # relative or padded), and do so even on failure so that a
            # swallowed optional-include error cannot leave a stale entry
            # that would later trip RecursiveInclude.
            stack.remove(canon_fname)

    def _canon_name(self, fname, rel):
        """Return the absolute, symlink-resolved path for ``fname``.

        ``fname`` is stripped; if it does not start with ``/`` it is taken
        relative to ``rel``.
        """
        fname = fname.strip()
        if not fname.startswith("/"):
            fname = os.path.sep.join([rel, fname])
        return os.path.realpath(fname)

    def read(self, relative_to="."):
        """Return the root file's contents with all includes expanded.

        :param relative_to: directory a relative ``root_fn`` is resolved
            against (defaults to ``"."``, the current directory).
        """
        stack = set()
        return self._read(self.root_fn, stack, rel=relative_to)