diff --git a/wlauto/tests/test_config.py b/wlauto/tests/test_config.py index a7e9a5d2..0587ddb5 100644 --- a/wlauto/tests/test_config.py +++ b/wlauto/tests/test_config.py @@ -22,7 +22,7 @@ from unittest import TestCase from nose.tools import assert_equal, assert_in, raises from wlauto.core.bootstrap import ConfigLoader -from wlauto.core.agenda import AgendaWorkloadEntry, AgendaGlobalEntry +from wlauto.core.agenda import AgendaWorkloadEntry, AgendaGlobalEntry, Agenda from wlauto.core.configuration import RunConfiguration from wlauto.exceptions import ConfigError @@ -33,24 +33,35 @@ BAD_CONFIG_TEXT = """device = 'TEST device_config = 'TEST-CONFIG'""" +LIST_PARAMS_AGENDA_TEXT = """ +config: + instrumentation: [list_params] + list_params: + param: [0.1, 0.1, 0.1] +workloads: + - dhrystone +""" + + class MockExtensionLoader(object): def __init__(self): self.aliases = {} self.global_param_aliases = {} - self.extensions = {} + self.extensions = { + 'defaults_workload': DefaultsWorkload(), + 'list_params': ListParamstrument(), + } def get_extension_class(self, name, kind=None): # pylint: disable=unused-argument - if name == 'defaults_workload': - return DefaultsWorkload() - else: - return NamedMock(name) + return self.extensions.get(name, NamedMock(name)) def resolve_alias(self, name): return name, {} def get_default_config(self, name): # pylint: disable=unused-argument - return {} + ec = self.get_extension_class(name) + return {p.name: p.default for p in ec.parameters} def has_extension(self, name): return name in self.aliases or name in self.extensions @@ -88,6 +99,14 @@ class DefaultsWorkload(object): self.parameters[0].default = [1, 2] +class ListParamstrument(object): + + def __init__(self): + self.name = 'list_params' + self.parameters = [NamedMock('param')] + self.parameters[0].default = [] + + class ConfigLoaderTest(TestCase): def setUp(self): @@ -142,6 +161,12 @@ class ConfigTest(TestCase): spec = self.config.workload_specs[0] assert_equal(spec.workload_parameters, {'param': [3]}) + def test_exetension_params_lists(self): + a = Agenda(LIST_PARAMS_AGENDA_TEXT) + self.config.set_agenda(a) + self.config.finalize() + assert_equal(self.config.instrumentation['list_params']['param'], [0.1, 0.1, 0.1]) + def test_global_instrumentation(self): self.config.load_config({'instrumentation': ['global_instrument']}) ws = AgendaWorkloadEntry(id='a', iterations=1, name='linpack', instrumentation=['local_instrument']) diff --git a/wlauto/utils/misc.py b/wlauto/utils/misc.py index 8e93ece5..1eed15d8 100644 --- a/wlauto/utils/misc.py +++ b/wlauto/utils/misc.py @@ -313,7 +313,23 @@ def merge_lists(*args, **kwargs): def _merge_two_lists(base, other, duplicates='all', dict_type=dict): # pylint: disable=R0912 - """Merge lists, normalizing their entries.""" + """ + Merge lists, normalizing their entries. + + parameters: + + :base, other: the two lists to be merged. ``other`` will be merged on + top of base. + :duplicates: Indicates the strategy of handling entries that appear + in both lists. ``all`` will keep occurrences from both + lists; ``first`` will only keep occurrences from + ``base``; ``last`` will only keep occurrences from + ``other``; + + .. note:: duplicate entries that appear in the *same* list + will never be removed. + + """ if duplicates == 'all': merged_list = [] for v in normalize(base, dict_type) + normalize(other, dict_type): @@ -321,20 +337,21 @@ def _merge_two_lists(base, other, duplicates='all', dict_type=dict): # pylint: merged_list.append(v) return merged_list elif duplicates == 'first': - merged_list = [] - for v in normalize(base + other, dict_type): + base_norm = normalize(base, dict_type) + merged_list = normalize(base, dict_type) + for v in normalize(other, dict_type): if not _check_remove_item(merged_list, v): - if v not in merged_list: - merged_list.append(v) + if v not in base_norm: + merged_list.append(v) # pylint: disable=no-member return merged_list elif duplicates == 'last': + other_norm = normalize(other, dict_type) merged_list = [] - for v in normalize(base + other, dict_type): + for v in normalize(base, dict_type): if not _check_remove_item(merged_list, v): - if v in merged_list: - del merged_list[merged_list.index(v)] - merged_list.append(v) - return merged_list + if v not in other_norm: + merged_list.append(v) + return merged_list + other_norm else: raise ValueError('Unexpected value for list duplcates argument: {}. '.format(duplicates) + 'Must be in {"all", "first", "last"}.')