diff options
| -rw-r--r-- | cloudinit/util.py | 28 | ||||
| -rw-r--r-- | tests/unittests/test_util.py | 46 | 
2 files changed, 44 insertions, 30 deletions
| diff --git a/cloudinit/util.py b/cloudinit/util.py index 3ff3835a..0c592656 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -46,19 +46,13 @@ import urlparse  import yaml +from cloudinit import importer  from cloudinit import log as logging  from cloudinit import url_helper as uhelp  from cloudinit.settings import (CFG_BUILTIN, CLOUD_CONFIG) -try: -    import selinux -    HAVE_LIBSELINUX = True -except ImportError: -    HAVE_LIBSELINUX = False - -  LOG = logging.getLogger(__name__)  # Helps cleanup filenames to ensure they aren't FS incompatible @@ -126,31 +120,37 @@ class ProcessExecutionError(IOError):  class SeLinuxGuard(object):      def __init__(self, path, recursive=False): +        # Late import since it might not always +        # be possible to use this +        try: +            self.selinux = importer.import_module('selinux') +        except ImportError: +            self.selinux = None          self.path = path          self.recursive = recursive -        self.enabled = False -        if HAVE_LIBSELINUX and selinux.is_selinux_enabled(): -            self.enabled = True      def __enter__(self): -        return self.enabled +        if self.selinux: +            return True +        else: +            return False      def __exit__(self, excp_type, excp_value, excp_traceback): -        if self.enabled: +        if self.selinux:              path = os.path.realpath(os.path.expanduser(self.path))              do_restore = False              try:                  # See if even worth restoring??                  stats = os.lstat(path)                  if stat.ST_MODE in stats: -                    selinux.matchpathcon(path, stats[stat.ST_MODE]) +                    self.selinux.matchpathcon(path, stats[stat.ST_MODE])                      do_restore = True              except OSError:                  pass              if do_restore:                  LOG.debug("Restoring selinux mode for %s (recursive=%s)",                            path, self.recursive) -                selinux.restorecon(path, recursive=self.recursive) +                self.selinux.restorecon(path, recursive=self.recursive)  class MountFailedError(Exception): diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index 3be6e186..93979f06 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -5,6 +5,26 @@ from unittest import TestCase  from mocker import MockerTestCase  from cloudinit import util +from cloudinit import importer + + +class FakeSelinux(object): + +    def __init__(self, match_what): +        self.match_what = match_what +        self.restored = [] + +    def matchpathcon(self, path, mode): +        if path == self.match_what: +            return +        else: +            raise OSError("No match!") + +    def is_selinux_enabled(self): +        return True + +    def restorecon(self, path, recursive): +        self.restored.append(path)  class TestMergeDict(MockerTestCase): @@ -159,22 +179,16 @@ class TestWriteFile(MockerTestCase):      def test_restorecon_if_possible_is_called(self):          """Make sure the selinux guard is called correctly.""" -        try: -            # We can only mock these out if selinux actually -            # exists, so thats why we catch the import -            mock_restorecon = self.mocker.replace( -                "selinux.restorecon", passthrough=False) -            mock_is_selinux_enabled = self.mocker.replace( -                "selinux.is_selinux_enabled", passthrough=False) -            mock_is_selinux_enabled() -            self.mocker.result(True) -            mock_restorecon("/etc/hosts", recursive=False) -            self.mocker.result(True) -            self.mocker.replay() -            with util.SeLinuxGuard("/etc/hosts") as is_on: -                self.assertTrue(is_on) -        except ImportError: -            pass +        import_mock = self.mocker.replace(importer.import_module, +                                          passthrough=False) +        import_mock('selinux') +        fake_se = FakeSelinux('/etc/hosts') +        self.mocker.result(fake_se) +        self.mocker.replay() +        with util.SeLinuxGuard("/etc/hosts") as is_on: +            self.assertTrue(is_on) +        self.assertEqual(1, len(fake_se.restored)) +        self.assertEqual('/etc/hosts', fake_se.restored[0])  class TestDeleteDirContents(MockerTestCase): | 
