diff --git a/wlauto/tests/test_utils.py b/wlauto/tests/test_utils.py index 3541b0f4..9e483803 100644 --- a/wlauto/tests/test_utils.py +++ b/wlauto/tests/test_utils.py @@ -17,11 +17,12 @@ # pylint: disable=R0201 from unittest import TestCase -from nose.tools import raises, assert_equal, assert_not_equal # pylint: disable=E0611 +from nose.tools import raises, assert_equal, assert_not_equal, assert_true # pylint: disable=E0611 from wlauto.utils.android import check_output from wlauto.utils.misc import merge_dicts, merge_lists, TimeoutError -from wlauto.utils.types import list_or_integer, list_or_bool, caseless_string, arguments +from wlauto.utils.types import (list_or_integer, list_or_bool, caseless_string, arguments, + ParameterDict) class TestCheckOutput(TestCase): @@ -88,3 +89,126 @@ class TestTypes(TestCase): assert_equal(arguments('--foo 7 --bar "fizz buzz"'), ['--foo', '7', '--bar', 'fizz buzz']) assert_equal(arguments(['test', 42]), ['test', '42']) + +class TestParameterDict(TestCase): + + # Define test parameters + orig_params = { + 'string' : 'A Test String', + 'string_list' : ['A Test', 'List', 'With', '\n in.'], + 'bool_list' : [False, True, True], + 'int' : 42, + 'float' : 1.23, + 'long' : long(987), + 'none' : None, + } + + def setUp(self): + self.params = ParameterDict() + self.params['string'] = self.orig_params['string'] + self.params['string_list'] = self.orig_params['string_list'] + self.params['bool_list'] = self.orig_params['bool_list'] + self.params['int'] = self.orig_params['int'] + self.params['float'] = self.orig_params['float'] + self.params['long'] = self.orig_params['long'] + self.params['none'] = self.orig_params['none'] + + # Test values are encoded correctly + def test_getEncodedItems(self): + encoded = { + 'string' : 'ssA%20Test%20String', + 'string_list' : 'slA%20Test0newelement0List0newelement0With0newelement0%0A%20in.', + 'bool_list' : 'blFalse0newelement0True0newelement0True', + 'int' : 'is42', + 'float' : 'fs1.23', + 'long' : 'ds987', + 'none' : 'nsNone', + } + # Test iter_encoded_items + for k, v in self.params.iter_encoded_items(): + assert_equal(v, encoded[k]) + + # Test get single encoded value + assert_equal(self.params.get_encoded_value('string'), encoded['string']) + assert_equal(self.params.get_encoded_value('string_list'), encoded['string_list']) + assert_equal(self.params.get_encoded_value('bool_list'), encoded['bool_list']) + assert_equal(self.params.get_encoded_value('int'), encoded['int']) + assert_equal(self.params.get_encoded_value('float'), encoded['float']) + assert_equal(self.params.get_encoded_value('long'), encoded['long']) + assert_equal(self.params.get_encoded_value('none'), encoded['none']) + + # Test it behaves like a normal dict + def test_getitem(self): + assert_equal(self.params['string'], self.orig_params['string']) + assert_equal(self.params['string_list'], self.orig_params['string_list']) + assert_equal(self.params['bool_list'], self.orig_params['bool_list']) + assert_equal(self.params['int'], self.orig_params['int']) + assert_equal(self.params['float'], self.orig_params['float']) + assert_equal(self.params['long'], self.orig_params['long']) + assert_equal(self.params['none'], self.orig_params['none']) + + def test_get(self): + assert_equal(self.params.get('string'), self.orig_params['string']) + assert_equal(self.params.get('string_list'), self.orig_params['string_list']) + assert_equal(self.params.get('bool_list'), self.orig_params['bool_list']) + assert_equal(self.params.get('int'), self.orig_params['int']) + assert_equal(self.params.get('float'), self.orig_params['float']) + assert_equal(self.params.get('long'), self.orig_params['long']) + assert_equal(self.params.get('none'), self.orig_params['none']) + + def test_contains(self): + assert_true(self.orig_params['string'] in self.params.values()) + assert_true(self.orig_params['string_list'] in self.params.values()) + assert_true(self.orig_params['bool_list'] in self.params.values()) + assert_true(self.orig_params['int'] in self.params.values()) + assert_true(self.orig_params['float'] in self.params.values()) + assert_true(self.orig_params['long'] in self.params.values()) + assert_true(self.orig_params['none'] in self.params.values()) + + def test_pop(self): + assert_equal(self.params.pop('string'), self.orig_params['string']) + assert_equal(self.params.pop('string_list'), self.orig_params['string_list']) + assert_equal(self.params.pop('bool_list'), self.orig_params['bool_list']) + assert_equal(self.params.pop('int'), self.orig_params['int']) + assert_equal(self.params.pop('float'), self.orig_params['float']) + assert_equal(self.params.pop('long'), self.orig_params['long']) + assert_equal(self.params.pop('none'), self.orig_params['none']) + + self.params['string'] = self.orig_params['string'] + assert_equal(self.params.popitem(), ('string', self.orig_params['string'])) + + def test_iteritems(self): + for k, v in self.params.iteritems(): + assert_equal(v, self.orig_params[k]) + + def test_parameter_dict_update(self): + params_1 = ParameterDict() + params_2 = ParameterDict() + + # Test two ParameterDicts + params_1['string'] = self.orig_params['string'] + params_1['string_list'] = self.orig_params['string_list'] + params_1['bool_list'] = self.orig_params['bool_list'] + params_2['int'] = self.orig_params['int'] + params_2['float'] = self.orig_params['float'] + params_2['long'] = self.orig_params['long'] + params_2['none'] = self.orig_params['none'] + + params_1.update(params_2) + assert_equal(params_1, self.params) + + # Test update with normal dict + params_3 = ParameterDict() + std_dict = dict() + + params_3['string'] = self.orig_params['string'] + std_dict['string_list'] = self.orig_params['string_list'] + std_dict['bool_list'] = self.orig_params['bool_list'] + std_dict['int'] = self.orig_params['int'] + std_dict['float'] = self.orig_params['float'] + std_dict['long'] = self.orig_params['long'] + std_dict['none'] = self.orig_params['none'] + + params_3.update(std_dict) + for key in params_3.keys(): + assert_equal(params_3[key], self.params[key]) diff --git a/wlauto/utils/types.py b/wlauto/utils/types.py index 66c97998..14254599 100644 --- a/wlauto/utils/types.py +++ b/wlauto/utils/types.py @@ -30,6 +30,7 @@ import re import math import shlex from collections import defaultdict +from urllib import quote, unquote from wlauto.utils.misc import isiterable, to_identifier @@ -328,3 +329,119 @@ class range_dict(dict): def __setitem__(self, i, v): i = int(i) super(range_dict, self).__setitem__(i, v) + + +class ParameterDict(dict): + """ + A dict-like object that automatically encodes various types into a url safe string, + and enforces a single type for the contents in a list. + Each value is first prefixed with 2 letters to preserve type when encoding to a string. + The format used is "value_type, value_dimension" e.g a 'list of floats' would become 'fl'. + """ + + # Function to determine the appropriate prefix based on the parameters type + @staticmethod + def _get_prefix(obj): + if isinstance(obj, basestring): + prefix = 's' + elif isinstance(obj, float): + prefix = 'f' + elif isinstance(obj, long): + prefix = 'd' + elif isinstance(obj, bool): + prefix = 'b' + elif isinstance(obj, int): + prefix = 'i' + elif obj is None: + prefix = 'n' + else: + raise ValueError('Unable to encode {} {}'.format(obj, type(obj))) + return prefix + + # Function to add prefix and urlencode a provided parameter. + @staticmethod + def _encode(obj): + if isinstance(obj, list): + t = type(obj[0]) + prefix = ParameterDict._get_prefix(obj[0]) + 'l' + for item in obj: + if not isinstance(item, t): + msg = 'Lists must only contain a single type, contains {} and {}' + raise ValueError(msg.format(t, type(item))) + obj = '0newelement0'.join(str(x) for x in obj) + else: + prefix = ParameterDict._get_prefix(obj) + 's' + return quote(prefix + str(obj)) + + # Function to decode a string and return a value of the original parameter type. + # pylint: disable=too-many-return-statements + @staticmethod + def _decode(string): + value_type = string[:1] + value_dimension = string[1:2] + value = unquote(string[2:]) + if value_dimension == 's': + if value_type == 's': + return str(value) + elif value_type == 'b': + return boolean(value) + elif value_type == 'd': + return long(value) + elif value_type == 'f': + return float(value) + elif value_type == 'i': + return int(value) + elif value_type == 'n': + return None + elif value_dimension == 'l': + return [ParameterDict._decode(value_type + 's' + x) + for x in value.split('0newelement0')] + else: + raise ValueError('Unknown {} {}'.format(type(string), string)) + + def __init__(self, *args, **kwargs): + for k, v in kwargs.iteritems(): + self.__setitem__(k, v) + dict.__init__(self, *args) + + def __setitem__(self, name, value): + dict.__setitem__(self, name, self._encode(value)) + + def __getitem__(self, name): + return self._decode(dict.__getitem__(self, name)) + + def __contains__(self, item): + return dict.__contains__(self, self._encode(item)) + + def __iter__(self): + return iter((k, self._decode(v)) for (k, v) in self.items()) + + def iteritems(self): + return self.__iter__() + + def get(self, name): + return self[name] + + def pop(self, key): + return self._decode(dict.pop(self, key)) + + def popitem(self): + key, value = dict.popitem(self) + return (key, self._decode(value)) + + def iter_encoded_items(self): + return dict.iteritems(self) + + def get_encoded_value(self, name): + return dict.__getitem__(self, name) + + def values(self): + return [self[k] for k in dict.keys(self)] + + def update(self, *args, **kwargs): + for d in list(args) + [kwargs]: + if isinstance(d, ParameterDict): + dict.update(self, d) + else: + for k, v in d.iteritems(): + self[k] = v