diff --git a/wa/framework/configuration/core.py b/wa/framework/configuration/core.py
index 90ecbf59..4d6d6446 100644
--- a/wa/framework/configuration/core.py
+++ b/wa/framework/configuration/core.py
@@ -16,13 +16,16 @@ import os
 import logging
 from copy import copy, deepcopy
 from collections import OrderedDict, defaultdict
+from itertools import product
+from string import Formatter
+
 
 from wa.framework.exception import ConfigError, NotFoundError
-from wa.framework.configuration.tree import SectionNode
+from wa.framework.configuration.tree import SectionNode, WorkloadEntry
 from wa.utils import log
 from wa.utils.misc import (get_article, merge_config_values)
 from wa.utils.types import (identifier, integer, boolean, list_of_strings,
-                            list_of, toggle_set, obj_dict, enum)
+                            list_of, toggle_set, obj_dict, enum, sweep)
 from wa.utils.serializer import is_pod, Podable
 
 
@@ -302,7 +305,11 @@ class ConfigurationPoint(object):
                 raise ConfigError(msg.format(self.name, obj.name))
         else:
             try:
-                value = self.kind(value)
+                if isinstance(value, sweep) and not value.auto:
+                    for i, v in enumerate(value):
+                        value[i] = self.kind(v)
+                else:
+                    value = self.kind(value)
             except (ValueError, TypeError):
                 typename = get_type_name(self.kind)
                 msg = 'Bad value "{}" for {}; must be {} {}'
@@ -332,6 +339,13 @@ class ConfigurationPoint(object):
             self.validate_constraint(name, value)
 
     def validate_allowed_values(self, name, value):
+        if isinstance(value, sweep) and not value.auto:
+            for v in value:
+                self._do_validate_allowed_values(name, v)
+        else:
+            self._do_validate_allowed_values(name, value)
+
+    def _do_validate_allowed_values(self, name, value):
         if 'list' in str(self.kind):
             for v in value:
                 if v not in self.allowed_values:
@@ -343,6 +357,13 @@ class ConfigurationPoint(object):
                 raise ConfigError(msg.format(value, self.name, name, self.allowed_values))
 
     def validate_constraint(self, name, value):
+        if isinstance(value, sweep) and not value.auto:
+            for v in value:
+                self._do_validate_constraint(name, v)
+        else:
+            self._do_validate_constraint(name, value)
+
+    def _do_validate_constraint(self, name, value):
         msg_vals = {'value': value, 'param': self.name, 'plugin': name}
         if isinstance(self.constraint, tuple) and len(self.constraint) == 2:
             constraint, msg = self.constraint  # pylint: disable=unpacking-non-sequence
@@ -1084,6 +1105,10 @@ class JobGenerator(object):
                 specs.append(job_spec)
         return specs
 
+    def expand_sweeps(self, tm):
+        """Expand every sweep found in the job tree below the root node."""
+        _expand_children_sweeps(self.root_node, tm, self.plugin_cache)
+
 
 def create_job_spec(workload_entry, sections, target_manager, plugin_cache,
                     disabled_augmentations):
@@ -1124,3 +1149,119 @@ def get_config_point_map(params):
 
 
 settings = MetaConfiguration(os.environ)
+
+
+def _expand_children_sweeps(node, tm, pc):
+    """
+    Replace each child of ``node`` whose own config contains sweeps with the
+    concrete nodes generated from those sweeps.
+
+    Both the ``workload_entries`` and the ``children`` of ``node`` are
+    processed; a missing attribute is treated as an empty list.
+
+    :param node: node whose children should be expanded
+    :param tm: passed on for resolving auto sweeps
+    :param pc: passed on for resolving auto sweeps
+    """
+    for name in ('workload_entries', 'children'):
+        children_list = getattr(node, name, [])
+        replacements = {}
+        for index, child in enumerate(children_list):
+            expanded = _process_node_sweeps(child, tm, pc)
+            if expanded:
+                replacements[index] = expanded
+        if not replacements:
+            continue
+
+        new_children = []
+        for index, child in enumerate(children_list):
+            if index in replacements:
+                new_children.extend(replacements[index])
+            else:
+                new_children.append(child)
+
+        # The group of the first original child, if any, is applied to all
+        group = getattr(children_list[0], 'group', None)
+        for child in new_children:
+            if group:
+                child.group = group
+            child.parent = node
+        setattr(node, name, new_children)
+
+
+def _process_node_sweeps(node, tm, pc):
+    """
+    Expand the sweeps below ``node``, then those in the config of ``node``.
+
+    :returns: one copy of the subtree of ``node`` per combination of sweep
+              values, or an empty list if the config of ``node`` itself
+              contains no sweeps
+    """
+    # Children first, so that every copy made below carries the already
+    # expanded subtree
+    _expand_children_sweeps(node, tm, pc)
+
+    found = list(find_sweeps(node.config))
+    if not found:
+        return []
+    locations = [location for location, _ in found]
+    node_sweeps = [swp for _, swp in found]
+
+    # Auto sweeps must be resolved before their values are iterated over
+    resolve_auto_sweeps(node_sweeps, tm, pc)
+
+    new_nodes = []
+    for config in _construct_config_options(node_sweeps, locations, node.config):
+        new_node = node.copy_subtree()
+        new_node.config = config
+        new_nodes.append(new_node)
+    return new_nodes
+
+
+def _construct_config_options(sweeps, locations, config):
+    """
+    Build one config for every combination of values of ``sweeps``.
+
+    Each one is a deep copy of ``config`` in which every sweep is replaced, at
+    its entry of ``locations``, by one of its values, and whose ``id`` has the
+    index of the combination appended.
+    """
+    config_options = []
+    for index, option in enumerate(product(*sweeps)):
+        new_config = deepcopy(config)
+        new_config['id'] = new_config['id'] + '_{}'.format(index)
+        for value, location in zip(option, locations):
+            _nested_dict_set(location, new_config, value)
+        config_options.append(new_config)
+    return config_options
+
+
+def resolve_auto_sweeps(sweeps, tm, pc):
+    """Call ``resolve_auto_sweep`` on the handler of every auto sweep given."""
+    # Not named ``sweep``: that would hide the type imported from wa.utils.types
+    for swp in sweeps:
+        if swp.auto:
+            swp.handler.resolve_auto_sweep(tm, pc)
+
+
+def _nested_dict_set(loc, dictn, value):
+    """
+    Store ``value`` in ``dictn`` under the path of keys ``loc``, creating
+    the intermediate dicts that do not exist yet.
+    """
+    for key in loc[:-1]:
+        dictn = dictn.setdefault(key, {})
+    dictn[loc[-1]] = value
+
+
+def find_sweeps(dictn):
+    """
+    Yield a ``(location, sweep)`` pair for every sweep in ``dictn``, searching
+    nested dicts too. ``location`` is the tuple of keys leading to the sweep.
+    """
+    for key, value in dictn.items():
+        if isinstance(value, sweep):
+            yield (key,), value
+        elif isinstance(value, (dict, obj_dict)):
+            for loc, swp in find_sweeps(value):
+                yield (key,) + loc, swp
diff --git a/wa/framework/configuration/execution.py b/wa/framework/configuration/execution.py
index d83d8ac7..f0eb25e6 100644
--- a/wa/framework/configuration/execution.py
+++ b/wa/framework/configuration/execution.py
@@ -147,6 +147,7 @@ class ConfigManager(object):
         return self.get_config()
 
     def generate_jobs(self, context):
+        self.jobs_config.expand_sweeps(context.tm)
         job_specs = self.jobs_config.generate_job_specs(context.tm)
         if not job_specs:
             msg = 'No jobs available for running.'
diff --git a/wa/framework/configuration/parsers.py b/wa/framework/configuration/parsers.py
index a39e6188..ac28aa9f 100644
--- a/wa/framework/configuration/parsers.py
+++ b/wa/framework/configuration/parsers.py
@@ -16,7 +16,11 @@
 
 import os
 import logging
+import re
+from abc import ABC, abstractmethod
+from copy import deepcopy
 from functools import reduce  # pylint: disable=redefined-builtin
+from math import inf
 
 from devlib.utils.types import identifier
 
@@ -24,7 +28,7 @@ from wa.framework.configuration.core import JobSpec
 from wa.framework.exception import ConfigError
 from wa.utils import log
 from wa.utils.serializer import json, read_pod, SerializerSyntaxError
-from wa.utils.types import toggle_set, counter
+from wa.utils.types import toggle_set, counter, sweep
 from wa.utils.misc import merge_config_values, isiterable
 
 
@@ -103,7 +107,6 @@ class AgendaParser(object):
             if not isinstance(raw, dict):
                 raise ConfigError('Invalid agenda, top level entry must be a dict')
 
-            self._populate_and_validate_config(state, raw, source)
             sections = self._pop_sections(raw)
             global_workloads = self._pop_workloads(raw)
             if not global_workloads:
@@ -111,11 +114,13 @@ class AgendaParser(object):
                       'least one workload to run.'
                 raise ConfigError(msg)
 
+            sect_ids, wkl_ids = self._collect_ids(sections, global_workloads)
+            self._populate_and_validate_config(state, raw, source, sect_ids)
+
             if raw:
                 msg = 'Invalid top level agenda entry(ies): "{}"'
                 raise ConfigError(msg.format('", "'.join(list(raw.keys()))))
 
-            sect_ids, wkl_ids = self._collect_ids(sections, global_workloads)
             self._process_global_workloads(state, global_workloads, wkl_ids)
             self._process_sections(state, sections, sect_ids, wkl_ids)
 
@@ -126,7 +131,7 @@ class AgendaParser(object):
         finally:
             log.dedent()
 
-    def _populate_and_validate_config(self, state, raw, source):
+    def _populate_and_validate_config(self, state, raw, source, sect_ids):
         for name in ['config', 'global']:
             entry = raw.pop(name, None)
             if entry is None:
@@ -136,6 +141,9 @@ class AgendaParser(object):
                 msg = 'Invalid entry "{}" - must be a dict'
                 raise ConfigError(msg.format(name))
 
+            # Want to take this entry and add any sweeps in a section
+            self._extract_global_sweeps(state, entry, sect_ids)
+
             if 'run_name' in entry:
                 value = entry.pop('run_name')
                 logger.debug('Setting run name to "{}"'.format(value))
@@ -177,12 +185,14 @@ class AgendaParser(object):
 
     def _process_global_workloads(self, state, global_workloads, seen_wkl_ids):
         for workload_entry in global_workloads:
+            find_sweeps(workload_entry, convert=True)  # Replace sweep text with sweep type
             workload = _process_workload_entry(workload_entry, seen_wkl_ids,
                                                state.jobs_config)
             state.jobs_config.add_workload(workload)
 
     def _process_sections(self, state, sections, seen_sect_ids, seen_wkl_ids):
         for section in sections:
+            find_sweeps(section, convert=True)  # Replace sweep text with sweep type
             workloads = []
             for workload_entry in section.pop("workloads", []):
                 workload = _process_workload_entry(workload_entry, seen_wkl_ids,
@@ -201,6 +211,37 @@ class AgendaParser(object):
                                              "s", state.jobs_config)
             state.jobs_config.add_section(section, workloads, group)
 
+    def _extract_global_sweeps(self, state, entry, seen_sect_ids):
+        """
+        Move the sweeps found in a global config entry into a new section.
+
+        Sweep definitions in ``entry`` are converted to sweep objects and
+        removed from it. If there were any, they become the config of a new
+        section, without workloads, that is added to ``state.jobs_config``.
+        """
+        extracted_config = {}
+        for keychain in find_sweeps(entry, convert=True):
+            *parents, param = keychain
+            source, target = entry, extracted_config
+            for key in parents:
+                source = source[key]
+                target = target.setdefault(key, {})
+            # Same location, but in the extracted config
+            target[param] = source.pop(param)
+
+        if not extracted_config:
+            return
+
+        # NOTE(review): ``counter`` used to be called once for the id and once
+        # for the group; assuming it advances on every call, that made the two
+        # names differ. A single call keeps them identical - confirm.
+        count = counter('global_sweeps')
+        name = 'global_sweeps{}'.format('' if count == 1 else '_{}'.format(count))
+        extracted_config['id'] = name
+
+        glbl_sweeps = _construct_valid_entry(extracted_config, seen_sect_ids,
+                                             None, state.jobs_config)
+        state.jobs_config.add_section(glbl_sweeps, [], name)
 
 ########################
 ### Helper functions ###
@@ -345,7 +386,13 @@ def _construct_valid_entry(raw, seen_ids, prefix, jobs_config):
     for name, cfg_point in JobSpec.configuration.items():
         value = pop_aliased_param(cfg_point, raw)
         if value is not None:
-            value = cfg_point.kind(value)
+            if isinstance(value, sweep) and not value.auto:
+                # The values are known already, so check their kind. The
+                # handler is kept: ``value.auto`` is derived from it.
+                value = sweep(values=map(cfg_point.kind, value),
+                              handler=value.handler)
+            else:
+                value = cfg_point.kind(value)
             cfg_point.validate_value(name, value)
             workload_entry[name] = value
 
@@ -394,3 +441,269 @@ def _process_workload_entry(workload, seen_workload_ids, jobs_config):
     if "workload_name" not in workload:
         raise ConfigError('No workload name specified in entry {}'.format(workload['id']))
     return workload
+
+
+def find_sweeps(raw, convert=False):
+    """
+    Find the sweep definitions, i.e. ``sweep(<handler>)`` keys, in ``raw``.
+
+    Dicts are searched recursively, including dicts that are items of lists.
+
+    :param raw: dict or list to search; nothing is found in anything else
+    :param convert: if ``True``, each definition found is replaced in place by
+                    a sweep object stored under the name of the swept parameter
+    :returns: list of keychains (tuples of keys and list indices), one per
+              sweep. The last key is the parameter name if ``convert`` is
+              set, the original ``sweep(...)`` key otherwise.
+    """
+    keychains = []
+    if isinstance(raw, dict):
+        converted = {}
+        for key, value in raw.items():
+            if _is_sweep(key):
+                if convert:
+                    swp = _create_sweep(_sweep_handler_name(key), value)
+                    converted[key] = swp
+                    keychains.append((swp.param_name,))
+                else:
+                    keychains.append((key,))
+            elif isinstance(value, (dict, list)):
+                subkeys = find_sweeps(value, convert=convert)
+                keychains.extend((key,) + subkey for subkey in subkeys)
+
+        # ``raw`` can only be modified once the iteration over it is finished
+        for key, swp in converted.items():
+            raw[swp.param_name] = swp
+            raw.pop(key)
+
+    elif isinstance(raw, list):
+        for index, value in enumerate(raw):
+            if isinstance(value, dict):
+                subkeys = find_sweeps(value, convert=convert)
+                keychains.extend((index,) + subkey for subkey in subkeys)
+
+    return keychains
+
+
+def _is_sweep(name):
+    """Return whether ``name`` is a key of the form ``sweep(<handler>)``."""
+    # Keys that are not strings (e.g. integers) are never sweep definitions
+    return (isinstance(name, str)
+            and name.startswith('sweep(') and name.endswith(')'))
+
+
+def _sweep_handler_name(name):
+    """Return the ``<handler>`` part of a ``sweep(<handler>)`` key."""
+    return name[6:-1]
+
+
+range_syntax = r'([0-9]+-[0-9]+)(,[0-9]+)?'
+
+
+def _create_sweep(handler_name, raw_definition):
+    """
+    Create a sweep using the handler registered under ``handler_name``.
+
+    :raises ConfigError: if there is no such handler
+    """
+    try:
+        handler_kind = _sweep_handlers[handler_name]
+    except KeyError:
+        msg = '{} is not a valid sweep handler'
+        raise ConfigError(msg.format(handler_name))
+    return sweep(handler=handler_kind(raw_definition))
+
+
+def _values_within(allowed_values, lower, upper):
+    """
+    Return a new list of the ``allowed_values`` strictly between ``lower`` and
+    ``upper``, or of all of them if ``lower`` is ``None``.
+    """
+    if lower is None:
+        return list(allowed_values)
+    return [value for value in allowed_values if lower < value < upper]
+
+
+class SweepHandler(ABC):
+    '''
+    Handles any functionality required for a sweep
+    '''
+
+    auto = False
+
+    def __init__(self, raw_value: dict):
+        # Copied, as ``parse`` implementations may pop entries from it
+        self.raw = deepcopy(raw_value)
+        self.plugin = None
+        self.param_name = None
+        self.values = None
+        self.parse()
+
+    @abstractmethod
+    def parse(self):
+        """
+        Extract all information required from ``self.raw``, the
+        value associated with the sweep key in the config
+        """
+
+
+class RangeHandler(SweepHandler):
+    """
+    Sweep through explicitly specified values.
+
+    The definition is a single ``param: values`` entry, ``values`` being
+    either a non-empty list or a ``start-stop[,step]`` string.
+    """
+
+    def parse(self):
+        if not isinstance(self.raw, dict) or not self.raw:
+            msg = 'A range sweep must specify a parameter and its values'
+            raise ConfigError(msg)
+        if len(self.raw) > 1:
+            msg = 'Too many entries for range sweep'
+            raise ConfigError(msg)
+
+        (self.param_name, vals), = self.raw.items()
+        if isinstance(vals, list):
+            self.values = vals
+        elif isinstance(vals, str):
+            self.values = self._parse_range(vals)
+        else:
+            msg = 'Invalid range sweep format for param {}'
+            raise ConfigError(msg.format(self.param_name))
+
+        if not self.values:
+            msg = 'At least 1 value must be specified in list for range '\
+                  'sweep {}'
+            raise ConfigError(msg.format(self.param_name))
+
+    def _parse_range(self, text):
+        """Convert a ``start-stop[,step]`` string to a list of integers."""
+        # fullmatch: anything trailing a valid range makes the whole invalid
+        match = re.fullmatch(range_syntax, ''.join(text.split()))
+        if match is None:
+            msg = 'Invalid range sweep format for param {}'
+            raise ConfigError(msg.format(self.param_name))
+        start, stop = (int(bound) for bound in match.group(1).split('-'))
+        step = int(match.group(2)[1:]) if match.group(2) else 1
+        if step == 0:
+            # range() would raise a ValueError
+            msg = 'Step of range sweep for param {} must not be 0'
+            raise ConfigError(msg.format(self.param_name))
+        # As for any range(), ``stop`` itself is not part of the values
+        return list(range(start, stop, step))
+
+
+class AutoSweepHandler(SweepHandler):
+    """
+    Base for sweeps whose values are only generated when
+    ``resolve_auto_sweep`` is called.
+
+    The definition may contain ``min``, ``max``, ``param`` and ``plugin``.
+    """
+
+    auto = True
+
+    def __init__(self, raw_value):
+        self.min = None
+        self.max = None
+        super().__init__(raw_value)
+
+    def parse(self):
+        # This ``parse`` method, if called, must be called *after* any child
+        # parse methods that pop from raw, as it assumes that raw should be
+        # empty by the end
+        if not self.raw:
+            return
+        if not isinstance(self.raw, dict):
+            raise ConfigError('Sweep definition must be a dict')
+
+        lower = self.raw.pop('min', None)
+        upper = self.raw.pop('max', None)
+        try:
+            if lower is not None:
+                self.min = float(lower)
+            if upper is not None:
+                self.max = float(upper)
+        except (ValueError, TypeError):
+            min_msg = 'minimum {} '.format(lower) if lower is not None else ''
+            max_msg = 'maximum {} '.format(upper) if upper is not None else ''
+            connective = 'and ' if min_msg and max_msg else ''
+            raise ConfigError('Sweep {}{}{}must be numeric'.format(
+                min_msg, connective, max_msg
+            ))
+
+        # A single bound leaves the other side open. Test against None, not
+        # truth: 0 is a perfectly valid bound.
+        if self.min is not None and self.max is None:
+            self.max = inf
+        elif self.max is not None and self.min is None:
+            self.min = -inf
+
+        self.param_name = self.raw.pop('param', None)
+        self.plugin = self.raw.pop('plugin', None)
+
+        if self.raw:
+            msg = 'Too many arguments to sweep definition'
+            raise ConfigError(msg)
+
+    @abstractmethod
+    def resolve_auto_sweep(self, tm, pc):
+        """
+        Convert the auto sweep found in ``self.raw`` to
+        a list of values, to be stored in ``self.values``
+
+        :param tm: TargetManager
+        :param pc: PluginCache
+        """
+
+
+class FreqHandler(AutoSweepHandler):
+    """Auto sweep of a runtime config parameter, ``frequency`` by default."""
+
+    def parse(self):
+        super().parse()
+        if self.param_name is None:
+            self.param_name = 'frequency'
+
+    def resolve_auto_sweep(self, tm, pc):
+        freq_cfg = tm.rpm.get_cfg_point(self.param_name)
+        allowed_values = None
+        if hasattr(freq_cfg, 'kind') and hasattr(freq_cfg.kind, 'values'):
+            allowed_values = freq_cfg.kind.values
+        elif hasattr(freq_cfg, 'allowed_values'):
+            allowed_values = freq_cfg.allowed_values
+        if allowed_values is None:
+            msg = 'Runtime config parameter {} can not be swept'
+            raise ConfigError(msg.format(self.param_name))
+        self.values = _values_within(allowed_values, self.min, self.max)
+
+
+class ParamHandler(AutoSweepHandler):
+    """
+    For auto sweeps of any parameter that specifies
+    ``allowed_values``
+    """
+
+    def parse(self):
+        super().parse()
+        # Both plugin and param name must be specified
+        if self.plugin is None or self.param_name is None:
+            msg = 'autoparam sweeps require both the plugin ' \
+                  'and parameter name to be specified'
+            raise ConfigError(msg)
+
+    def resolve_auto_sweep(self, tm, pc):
+        params = pc.get_plugin_parameters(self.plugin)
+        cfg_pt = params[self.param_name]
+        if cfg_pt.allowed_values is None:
+            msg = 'Parameter \'{}\' does not specify allowed values to sweep'
+            raise ConfigError(msg.format(self.param_name))
+        self.values = _values_within(cfg_pt.allowed_values, self.min, self.max)
+
+
+_sweep_handlers = {
+    'range': RangeHandler,
+    'autofreq': FreqHandler,
+    'autoparam': ParamHandler,
+}
diff --git a/wa/framework/configuration/tree.py b/wa/framework/configuration/tree.py
index 72d457da..c5f50d62 100644
--- a/wa/framework/configuration/tree.py
+++ b/wa/framework/configuration/tree.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
import logging +from copy import copy, deepcopy from wa.utils import log @@ -41,6 +42,9 @@ class JobSpecSource(object): with log.indentcontext(): for key, value in self.config.items(): logger.debug('"{}" to "{}"'.format(key, value)) + + def copy_subtree(self): + raise NotImplementedError() class WorkloadEntry(JobSpecSource): @@ -53,6 +57,8 @@ class WorkloadEntry(JobSpecSource): else: return 'workload "{}" from section "{}"'.format(self.id, self.parent.id) + def copy_subtree(self): + return WorkloadEntry(deepcopy(self.config), self.parent) class SectionNode(JobSpecSource): @@ -107,3 +113,17 @@ class SectionNode(JobSpecSource): for n in self.descendants(): if n.is_leaf: yield n + + def copy_subtree(self): + new_children = [child.copy_subtree() for child in self.children] + new_workloads = copy(self.workload_entries) + new_node = copy(self) + + for wkl in new_workloads: + wkl.parent = new_node + for child in new_children: + child.parent = new_node + + new_node.children = new_children + new_node.workload_entries = new_workloads + return new_node diff --git a/wa/utils/types.py b/wa/utils/types.py index 5968e60c..511709f2 100644 --- a/wa/utils/types.py +++ b/wa/utils/types.py @@ -46,10 +46,10 @@ from future.utils import with_metaclass from devlib.utils.types import identifier, boolean, integer, numeric, caseless_string +from wa.framework.exception import NotFoundError from wa.utils.misc import (isiterable, list_to_ranges, list_to_mask, mask_to_list, ranges_to_list) - def list_of_strs(value): """ Value must be iterable. All elements will be converted to strings. @@ -866,3 +866,41 @@ class cpu_mask(object): def to_pod(self): return {'cpu_mask': self._mask} + + +class sweep: + """ + Used to define a range of values a parameter may sweep through. 
+ """ + + @property + def auto(self): + return self.handler.auto + + @property + def param_name(self): + return self.handler.param_name + + @property + def values(self): + if self._values: + return self._values + elif self.handler.values: + self._values = self.handler.values + return self._values + else: + msg = 'sweep values for param {} not yet generated' + raise NotFoundError(msg.format(self.param_name)) + + def __init__(self, values=None, handler=None): + self._values = list(values) if values is not None else None + self.handler = handler + + def __iter__(self): + return self.values.__iter__() + + def __getitem__(self, index): + return self.values[index] + + def __setitem__(self, index, value): + self.values[index] = value