From 181fd3ceeb6a93530af7ccebfa1d06a1f7412a12 Mon Sep 17 00:00:00 2001 From: Mike Milner Date: Tue, 17 Jan 2012 09:58:42 -0400 Subject: Add unit tests for util.write_file. --- cloudinit/util.py | 71 ++++++++++++++++++++++++------------ tests/unittests/test_util.py | 85 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 132 insertions(+), 24 deletions(-) diff --git a/cloudinit/util.py b/cloudinit/util.py index de95ec79..b690d517 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -86,10 +86,24 @@ def get_cfg_option_str(yobj, key, default=None): return yobj[key] def get_cfg_option_list_or_str(yobj, key, default=None): - if not yobj.has_key(key): return default - if yobj[key] is None: return [] - if isinstance(yobj[key],list): return yobj[key] - return([yobj[key]]) + """ + Gets the C{key} config option from C{yobj} as a list of strings. If the + key is present as a single string it will be returned as a list with one + string arg. + + @param yobj: The configuration object. + @param key: The configuration key to get. + @param default: The default to return if key is not found. + @return: The configuration option as a list of strings or default if key + is not found. + """ + if not key in yobj: + return default + if yobj[key] is None: + return [] + if isinstance(yobj[key], list): + return yobj[key] + return [yobj[key]] # get a cfg entry by its path array # for f['a']['b']: get_cfg_by_path(mycfg,('a','b')) @@ -100,30 +114,41 @@ def get_cfg_by_path(yobj,keyp,default=None): cur = cur[tok] return(cur) -# merge values from cand into source -# if src has a key, cand will not override -def mergedict(src,cand): - if isinstance(src,dict) and isinstance(cand,dict): - for k,v in cand.iteritems(): +def mergedict(src, cand): + """ + Merge values from C{cand} into C{src}. If C{src} has a key C{cand} will + not override. Nested dictionaries are merged recursively. + """ + if isinstance(src, dict) and isinstance(cand, dict): + for k, v in cand.iteritems(): if k not in src: src[k] = v else: - src[k] = mergedict(src[k],v) + src[k] = mergedict(src[k], v) return src -def write_file(file,content,mode=0644,omode="wb"): - try: - os.makedirs(os.path.dirname(file)) - except OSError as e: - if e.errno != errno.EEXIST: - raise e - - f=open(file,omode) - if mode != None: - os.chmod(file,mode) - f.write(content) - f.close() - restorecon_if_possible(file) +def write_file(filepath, content, mode=0644, omode="wb"): + """ + Writes a file with the given content and sets the file mode as specified. + Resotres the SELinux context if possible. + + @param filepath: The full path of the file to write. + @param content: The content to write to the file. + @param mode: The filesystem mode to set on the file. + @param omode: The open mode used when opening the file (r, rb, a, etc.) + """ + try: + os.makedirs(os.path.dirname(filepath)) + except OSError as e: + if e.errno != errno.EEXIST: + raise e + + f = open(filepath, omode) + if mode is not None: + os.chmod(filepath, mode) + f.write(content) + f.close() + restorecon_if_possible(filepath) def restorecon_if_possible(path, recursive=False): if HAVE_LIBSELINUX and selinux.is_selinux_enabled(): diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index ba15e44d..ecbaba1a 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -1,6 +1,11 @@ from unittest import TestCase +from mocker import MockerTestCase +from tempfile import mkdtemp +from shutil import rmtree +import os +import stat -from cloudinit.util import mergedict, get_cfg_option_list_or_str +from cloudinit.util import mergedict, get_cfg_option_list_or_str, write_file class TestMergeDict(TestCase): def test_simple_merge(self): @@ -77,3 +82,81 @@ class TestGetCfgOptionListOrStr(TestCase): config = {"key": None} result = get_cfg_option_list_or_str(config, "key") self.assertEqual([], result) + +class TestWriteFile(MockerTestCase): + def setUp(self): + super(TestWriteFile, self).setUp() + # Make a temp directoy for tests to use. + self.tmp = mkdtemp(prefix="unittest_") + + def tearDown(self): + super(TestWriteFile, self).tearDown() + # Clean up temp directory + rmtree(self.tmp) + + def test_basic_usage(self): + """Verify basic usage with default args.""" + path = os.path.join(self.tmp, "NewFile.txt") + contents = "Hey there" + + write_file(path, contents) + + self.assertTrue(os.path.exists(path)) + self.assertTrue(os.path.isfile(path)) + with open(path) as f: + create_contents = f.read() + self.assertEqual(contents, create_contents) + file_stat = os.stat(path) + self.assertEqual(0644, stat.S_IMODE(file_stat.st_mode)) + + def test_dir_is_created_if_required(self): + """Verifiy that directories are created is required.""" + dirname = os.path.join(self.tmp, "subdir") + path = os.path.join(dirname, "NewFile.txt") + contents = "Hey there" + + write_file(path, contents) + + self.assertTrue(os.path.isdir(dirname)) + self.assertTrue(os.path.isfile(path)) + + def test_custom_mode(self): + """Verify custom mode works properly.""" + path = os.path.join(self.tmp, "NewFile.txt") + contents = "Hey there" + + write_file(path, contents, mode=0666) + + self.assertTrue(os.path.exists(path)) + self.assertTrue(os.path.isfile(path)) + file_stat = os.stat(path) + self.assertEqual(0666, stat.S_IMODE(file_stat.st_mode)) + + def test_custom_omode(self): + """Verify custom omode works properly.""" + path = os.path.join(self.tmp, "NewFile.txt") + contents = "Hey there" + + # Create file first with basic content + with open(path, "wb") as f: + f.write("LINE1\n") + write_file(path, contents, omode="a") + + self.assertTrue(os.path.exists(path)) + self.assertTrue(os.path.isfile(path)) + with open(path) as f: + create_contents = f.read() + self.assertEqual("LINE1\nHey there", create_contents) + + def test_restorecon_if_possible_is_called(self): + """Make sure the restorecon_if_possible is called correctly.""" + path = os.path.join(self.tmp, "NewFile.txt") + contents = "Hey there" + + # Mock out the restorecon_if_possible call to test if it's called. + mock_restorecon = self.mocker.replace( + "cloudinit.util.restorecon_if_possible", passthrough=False) + mock_restorecon(path) + self.mocker.replay() + + write_file(path, contents) -- cgit v1.2.3