summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cloudinit/util.py28
-rw-r--r--tests/unittests/test_util.py46
2 files changed, 44 insertions, 30 deletions
diff --git a/cloudinit/util.py b/cloudinit/util.py
index 3ff3835a..0c592656 100644
--- a/cloudinit/util.py
+++ b/cloudinit/util.py
@@ -46,19 +46,13 @@ import urlparse
import yaml
+from cloudinit import importer
from cloudinit import log as logging
from cloudinit import url_helper as uhelp
from cloudinit.settings import (CFG_BUILTIN, CLOUD_CONFIG)
-try:
- import selinux
- HAVE_LIBSELINUX = True
-except ImportError:
- HAVE_LIBSELINUX = False
-
-
LOG = logging.getLogger(__name__)
# Helps cleanup filenames to ensure they aren't FS incompatible
@@ -126,31 +120,37 @@ class ProcessExecutionError(IOError):
class SeLinuxGuard(object):
def __init__(self, path, recursive=False):
+ # Late import since it might not always
+ # be possible to use this
+ try:
+ self.selinux = importer.import_module('selinux')
+ except ImportError:
+ self.selinux = None
self.path = path
self.recursive = recursive
- self.enabled = False
- if HAVE_LIBSELINUX and selinux.is_selinux_enabled():
- self.enabled = True
def __enter__(self):
- return self.enabled
+ if self.selinux:
+ return True
+ else:
+ return False
def __exit__(self, excp_type, excp_value, excp_traceback):
- if self.enabled:
+ if self.selinux:
path = os.path.realpath(os.path.expanduser(self.path))
do_restore = False
try:
# See if even worth restoring??
stats = os.lstat(path)
if stat.ST_MODE in stats:
- selinux.matchpathcon(path, stats[stat.ST_MODE])
+ self.selinux.matchpathcon(path, stats[stat.ST_MODE])
do_restore = True
except OSError:
pass
if do_restore:
LOG.debug("Restoring selinux mode for %s (recursive=%s)",
path, self.recursive)
- selinux.restorecon(path, recursive=self.recursive)
+ self.selinux.restorecon(path, recursive=self.recursive)
class MountFailedError(Exception):
diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py
index 3be6e186..93979f06 100644
--- a/tests/unittests/test_util.py
+++ b/tests/unittests/test_util.py
@@ -5,6 +5,26 @@ from unittest import TestCase
from mocker import MockerTestCase
from cloudinit import util
+from cloudinit import importer
+
+
+class FakeSelinux(object):
+
+ def __init__(self, match_what):
+ self.match_what = match_what
+ self.restored = []
+
+ def matchpathcon(self, path, mode):
+ if path == self.match_what:
+ return
+ else:
+ raise OSError("No match!")
+
+ def is_selinux_enabled(self):
+ return True
+
+ def restorecon(self, path, recursive):
+ self.restored.append(path)
class TestMergeDict(MockerTestCase):
@@ -159,22 +179,16 @@ class TestWriteFile(MockerTestCase):
def test_restorecon_if_possible_is_called(self):
"""Make sure the selinux guard is called correctly."""
- try:
- # We can only mock these out if selinux actually
- # exists, so thats why we catch the import
- mock_restorecon = self.mocker.replace(
- "selinux.restorecon", passthrough=False)
- mock_is_selinux_enabled = self.mocker.replace(
- "selinux.is_selinux_enabled", passthrough=False)
- mock_is_selinux_enabled()
- self.mocker.result(True)
- mock_restorecon("/etc/hosts", recursive=False)
- self.mocker.result(True)
- self.mocker.replay()
- with util.SeLinuxGuard("/etc/hosts") as is_on:
- self.assertTrue(is_on)
- except ImportError:
- pass
+ import_mock = self.mocker.replace(importer.import_module,
+ passthrough=False)
+ import_mock('selinux')
+ fake_se = FakeSelinux('/etc/hosts')
+ self.mocker.result(fake_se)
+ self.mocker.replay()
+ with util.SeLinuxGuard("/etc/hosts") as is_on:
+ self.assertTrue(is_on)
+ self.assertEqual(1, len(fake_se.restored))
+ self.assertEqual('/etc/hosts', fake_se.restored[0])
class TestDeleteDirContents(MockerTestCase):