diff options
Diffstat (limited to 'cloudinit/net/network_state.py')
| -rw-r--r-- | cloudinit/net/network_state.py | 284 | 
1 files changed, 122 insertions, 162 deletions
diff --git a/cloudinit/net/network_state.py b/cloudinit/net/network_state.py index 4c726ab4..c5aeadb5 100644 --- a/cloudinit/net/network_state.py +++ b/cloudinit/net/network_state.py @@ -15,9 +15,12 @@  #   You should have received a copy of the GNU Affero General Public License  #   along with Curtin.  If not, see <http://www.gnu.org/licenses/>. -from cloudinit import log as logging -from cloudinit import util -from cloudinit.util import yaml_dumps as dump_config +import copy +import logging + +import six + +from cloudinit import net  LOG = logging.getLogger(__name__) @@ -27,39 +30,104 @@ NETWORK_STATE_REQUIRED_KEYS = {  } +def parse_net_config_data(net_config, skip_broken=True): +    """Parses the config, returns NetworkState object + +    :param net_config: curtin network config dict +    """ +    state = None +    if 'version' in net_config and 'config' in net_config: +        ns = NetworkState(version=net_config.get('version'), +                          config=net_config.get('config')) +        ns.parse_config(skip_broken=skip_broken) +        state = ns.network_state +    return state + + +def parse_net_config(path, skip_broken=True): +    """Parses a curtin network configuration file and +       return network state""" +    ns = None +    net_config = net.read_yaml_file(path) +    if 'network' in net_config: +        ns = parse_net_config_data(net_config.get('network'), +                                   skip_broken=skip_broken) +    return ns + +  def from_state_file(state_file):      network_state = None -    state = util.read_conf(state_file) +    state = net.read_yaml_file(state_file)      network_state = NetworkState()      network_state.load(state) -      return network_state +def diff_keys(expected, actual): +    missing = set(expected) +    for key in actual: +        missing.discard(key) +    return missing + + +class InvalidCommand(Exception): +    pass + + +def ensure_command_keys(required_keys): + +    def wrapper(func): + +        @six.wraps(func) +        def decorator(self, command, *args, **kwargs): +            if required_keys: +                missing_keys = diff_keys(required_keys, command) +                if missing_keys: +                    raise InvalidCommand("Command missing %s of required" +                                         " keys %s" % (missing_keys, +                                                       required_keys)) +            return func(self, command, *args, **kwargs) + +        return decorator + +    return wrapper + + +class CommandHandlerMeta(type): +    """Metaclass that dynamically creates a 'command_handlers' attribute. + +    This will scan the to-be-created class for methods that start with +    'handle_' and on finding those will populate a class attribute mapping +    so that those methods can be quickly located and called. +    """ +    def __new__(cls, name, parents, dct): +        command_handlers = {} +        for attr_name, attr in six.iteritems(dct): +            if six.callable(attr) and attr_name.startswith('handle_'): +                handles_what = attr_name[len('handle_'):] +                if handles_what: +                    command_handlers[handles_what] = attr +        dct['command_handlers'] = command_handlers +        return super(CommandHandlerMeta, cls).__new__(cls, name, +                                                      parents, dct) + + +@six.add_metaclass(CommandHandlerMeta)  class NetworkState(object): + +    initial_network_state = { +        'interfaces': {}, +        'routes': [], +        'dns': { +            'nameservers': [], +            'search': [], +        } +    } +      def __init__(self, version=NETWORK_STATE_VERSION, config=None):          self.version = version          self.config = config -        self.network_state = { -            'interfaces': {}, -            'routes': [], -            'dns': { -                'nameservers': [], -                'search': [], -            } -        } -        self.command_handlers = self.get_command_handlers() - -    def get_command_handlers(self): -        METHOD_PREFIX = 'handle_' -        methods = filter(lambda x: callable(getattr(self, x)) and -                         x.startswith(METHOD_PREFIX), dir(self)) -        handlers = {} -        for m in methods: -            key = m.replace(METHOD_PREFIX, '') -            handlers[key] = getattr(self, m) - -        return handlers +        self.network_state = copy.deepcopy(self.initial_network_state)      def dump(self):          state = { @@ -67,7 +135,7 @@ class NetworkState(object):              'config': self.config,              'network_state': self.network_state,          } -        return dump_config(state) +        return net.dump_yaml(state)      def load(self, state):          if 'version' not in state: @@ -75,32 +143,39 @@ class NetworkState(object):              raise Exception('Invalid state, missing version field')          required_keys = NETWORK_STATE_REQUIRED_KEYS[state['version']] -        if not self.valid_command(state, required_keys): -            msg = 'Invalid state, missing keys: {}'.format(required_keys) +        missing_keys = diff_keys(required_keys, state) +        if missing_keys: +            msg = 'Invalid state, missing keys: %s' % (missing_keys)              LOG.error(msg) -            raise Exception(msg) +            raise ValueError(msg)          # v1 - direct attr mapping, except version          for key in [k for k in required_keys if k not in ['version']]:              setattr(self, key, state[key]) -        self.command_handlers = self.get_command_handlers()      def dump_network_state(self): -        return dump_config(self.network_state) +        return net.dump_yaml(self.network_state) -    def parse_config(self): +    def parse_config(self, skip_broken=True):          # rebuild network state          for command in self.config: -            handler = self.command_handlers.get(command['type']) -            handler(command) - -    def valid_command(self, command, required_keys): -        if not required_keys: -            return False - -        found_keys = [key for key in command.keys() if key in required_keys] -        return len(found_keys) == len(required_keys) - +            command_type = command['type'] +            try: +                handler = self.command_handlers[command_type] +            except KeyError: +                raise RuntimeError("No handler found for" +                                   " command '%s'" % command_type) +            try: +                handler(self, command) +            except InvalidCommand: +                if not skip_broken: +                    raise +                else: +                    LOG.warn("Skipping invalid command: %s", command, +                             exc_info=True) +                    LOG.debug(self.dump_network_state()) + +    @ensure_command_keys(['name'])      def handle_physical(self, command):          '''          command = { @@ -112,13 +187,6 @@ class NetworkState(object):               ]          }          ''' -        required_keys = [ -            'name', -        ] -        if not self.valid_command(command, required_keys): -            LOG.warn('Skipping Invalid command: {}'.format(command)) -            LOG.debug(self.dump_network_state()) -            return          interfaces = self.network_state.get('interfaces')          iface = interfaces.get(command['name'], {}) @@ -149,6 +217,7 @@ class NetworkState(object):          self.network_state['interfaces'].update({command.get('name'): iface})          self.dump_network_state() +    @ensure_command_keys(['name', 'vlan_id', 'vlan_link'])      def handle_vlan(self, command):          '''              auto eth0.222 @@ -158,16 +227,6 @@ class NetworkState(object):                      hwaddress ether BC:76:4E:06:96:B3                      vlan-raw-device eth0          ''' -        required_keys = [ -            'name', -            'vlan_link', -            'vlan_id', -        ] -        if not self.valid_command(command, required_keys): -            print('Skipping Invalid command: {}'.format(command)) -            print(self.dump_network_state()) -            return -          interfaces = self.network_state.get('interfaces')          self.handle_physical(command)          iface = interfaces.get(command.get('name'), {}) @@ -175,6 +234,7 @@ class NetworkState(object):          iface['vlan_id'] = command.get('vlan_id')          interfaces.update({iface['name']: iface}) +    @ensure_command_keys(['name', 'bond_interfaces', 'params'])      def handle_bond(self, command):          '''      #/etc/network/interfaces @@ -200,15 +260,6 @@ class NetworkState(object):           bond-updelay 200           bond-lacp-rate 4          ''' -        required_keys = [ -            'name', -            'bond_interfaces', -            'params', -        ] -        if not self.valid_command(command, required_keys): -            print('Skipping Invalid command: {}'.format(command)) -            print(self.dump_network_state()) -            return          self.handle_physical(command)          interfaces = self.network_state.get('interfaces') @@ -236,6 +287,7 @@ class NetworkState(object):                  bond_if.update({param: val})              self.network_state['interfaces'].update({ifname: bond_if}) +    @ensure_command_keys(['name', 'bridge_interfaces', 'params'])      def handle_bridge(self, command):          '''              auto br0 @@ -263,15 +315,6 @@ class NetworkState(object):              "bridge_waitport",          ]          ''' -        required_keys = [ -            'name', -            'bridge_interfaces', -            'params', -        ] -        if not self.valid_command(command, required_keys): -            print('Skipping Invalid command: {}'.format(command)) -            print(self.dump_network_state()) -            return          # find one of the bridge port ifaces to get mac_addr          # handle bridge_slaves @@ -295,15 +338,8 @@ class NetworkState(object):          interfaces.update({iface['name']: iface}) +    @ensure_command_keys(['address'])      def handle_nameserver(self, command): -        required_keys = [ -            'address', -        ] -        if not self.valid_command(command, required_keys): -            print('Skipping Invalid command: {}'.format(command)) -            print(self.dump_network_state()) -            return -          dns = self.network_state.get('dns')          if 'address' in command:              addrs = command['address'] @@ -318,15 +354,8 @@ class NetworkState(object):              for path in paths:                  dns['search'].append(path) +    @ensure_command_keys(['destination'])      def handle_route(self, command): -        required_keys = [ -            'destination', -        ] -        if not self.valid_command(command, required_keys): -            print('Skipping Invalid command: {}'.format(command)) -            print(self.dump_network_state()) -            return -          routes = self.network_state.get('routes')          network, cidr = command['destination'].split("/")          netmask = cidr2mask(int(cidr)) @@ -376,72 +405,3 @@ def mask2cidr(mask):          return ipv4mask2cidr(mask)      else:          return mask - - -if __name__ == '__main__': -    import random -    import sys - -    from cloudinit import net - -    def load_config(nc): -        version = nc.get('version') -        config = nc.get('config') -        return (version, config) - -    def test_parse(network_config): -        (version, config) = load_config(network_config) -        ns1 = NetworkState(version=version, config=config) -        ns1.parse_config() -        random.shuffle(config) -        ns2 = NetworkState(version=version, config=config) -        ns2.parse_config() -        print("----NS1-----") -        print(ns1.dump_network_state()) -        print() -        print("----NS2-----") -        print(ns2.dump_network_state()) -        print("NS1 == NS2 ?=> {}".format( -            ns1.network_state == ns2.network_state)) -        eni = net.render_interfaces(ns2.network_state) -        print(eni) -        udev_rules = net.render_persistent_net(ns2.network_state) -        print(udev_rules) - -    def test_dump_and_load(network_config): -        print("Loading network_config into NetworkState") -        (version, config) = load_config(network_config) -        ns1 = NetworkState(version=version, config=config) -        ns1.parse_config() -        print("Dumping state to file") -        ns1_dump = ns1.dump() -        ns1_state = "/tmp/ns1.state" -        with open(ns1_state, "w+") as f: -            f.write(ns1_dump) - -        print("Loading state from file") -        ns2 = from_state_file(ns1_state) -        print("NS1 == NS2 ?=> {}".format( -            ns1.network_state == ns2.network_state)) - -    def test_output(network_config): -        (version, config) = load_config(network_config) -        ns1 = NetworkState(version=version, config=config) -        ns1.parse_config() -        random.shuffle(config) -        ns2 = NetworkState(version=version, config=config) -        ns2.parse_config() -        print("NS1 == NS2 ?=> {}".format( -            ns1.network_state == ns2.network_state)) -        eni_1 = net.render_interfaces(ns1.network_state) -        eni_2 = net.render_interfaces(ns2.network_state) -        print(eni_1) -        print(eni_2) -        print("eni_1 == eni_2 ?=> {}".format( -            eni_1 == eni_2)) - -    y = util.read_conf(sys.argv[1]) -    network_config = y.get('network') -    test_parse(network_config) -    test_dump_and_load(network_config) -    test_output(network_config)  | 
