mirror of
				https://github.com/nvbn/thefuck.git
				synced 2025-10-30 22:54:14 +00:00 
			
		
		
		
	Move rule-related code to Rule
				
					
				
			This commit is contained in:
		| @@ -2,7 +2,6 @@ import pytest | ||||
| import os | ||||
| from thefuck.rules.fix_file import match, get_new_command | ||||
| from tests.utils import Command | ||||
| from thefuck.types import Settings | ||||
|  | ||||
|  | ||||
| # (script, file, line, col (or None), stdout, stderr) | ||||
|   | ||||
| @@ -1,42 +1,8 @@ | ||||
| import pytest | ||||
| from pathlib import PosixPath, Path | ||||
| from mock import Mock | ||||
| from pathlib import PosixPath | ||||
| from thefuck import corrector, conf | ||||
| from tests.utils import Rule, Command, CorrectedCommand | ||||
| from thefuck.corrector import make_corrected_commands, get_corrected_commands,\ | ||||
|     is_rule_enabled, organize_commands | ||||
|  | ||||
|  | ||||
| def test_load_rule(mocker): | ||||
|     match = object() | ||||
|     get_new_command = object() | ||||
|     load_source = mocker.patch( | ||||
|         'thefuck.corrector.load_source', | ||||
|         return_value=Mock(match=match, | ||||
|                           get_new_command=get_new_command, | ||||
|                           enabled_by_default=True, | ||||
|                           priority=900, | ||||
|                           requires_output=True)) | ||||
|     assert corrector.load_rule(Path('/rules/bash.py')) \ | ||||
|            == Rule('bash', match, get_new_command, priority=900) | ||||
|     load_source.assert_called_once_with('bash', '/rules/bash.py') | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize('rules, exclude_rules, rule, is_enabled', [ | ||||
|     (conf.DEFAULT_RULES, [], Rule('git', enabled_by_default=True), True), | ||||
|     (conf.DEFAULT_RULES, [], Rule('git', enabled_by_default=False), False), | ||||
|     ([], [], Rule('git', enabled_by_default=False), False), | ||||
|     ([], [], Rule('git', enabled_by_default=True), False), | ||||
|     (conf.DEFAULT_RULES + ['git'], [], Rule('git', enabled_by_default=False), True), | ||||
|     (['git'], [], Rule('git', enabled_by_default=False), True), | ||||
|     (conf.DEFAULT_RULES, ['git'], Rule('git', enabled_by_default=True), False), | ||||
|     (conf.DEFAULT_RULES, ['git'], Rule('git', enabled_by_default=False), False), | ||||
|     ([], ['git'], Rule('git', enabled_by_default=True), False), | ||||
|     ([], ['git'], Rule('git', enabled_by_default=False), False)]) | ||||
| def test_is_rule_enabled(settings, rules, exclude_rules, rule, is_enabled): | ||||
|     settings.update(rules=rules, | ||||
|                     exclude_rules=exclude_rules) | ||||
|     assert is_rule_enabled(rule) == is_enabled | ||||
| from thefuck.corrector import get_corrected_commands, organize_commands | ||||
|  | ||||
|  | ||||
| class TestGetRules(object): | ||||
| @@ -49,7 +15,7 @@ class TestGetRules(object): | ||||
|  | ||||
|     @pytest.fixture(autouse=True) | ||||
|     def load_source(self, monkeypatch): | ||||
|         monkeypatch.setattr('thefuck.corrector.load_source', | ||||
|         monkeypatch.setattr('thefuck.types.load_source', | ||||
|                             lambda x, _: Rule(x)) | ||||
|  | ||||
|     def _compare_names(self, rules, names): | ||||
| @@ -70,37 +36,6 @@ class TestGetRules(object): | ||||
|         self._compare_names(rules, loaded_rules) | ||||
|  | ||||
|  | ||||
| class TestIsRuleMatch(object): | ||||
|     def test_no_match(self): | ||||
|         assert not corrector.is_rule_match( | ||||
|             Command('ls'), Rule('', lambda _: False)) | ||||
|  | ||||
|     def test_match(self): | ||||
|         rule = Rule('', lambda x: x.script == 'cd ..') | ||||
|         assert corrector.is_rule_match(Command('cd ..'), rule) | ||||
|  | ||||
|     @pytest.mark.usefixtures('no_colors') | ||||
|     def test_when_rule_failed(self, capsys): | ||||
|         rule = Rule('test', Mock(side_effect=OSError('Denied')), | ||||
|                     requires_output=False) | ||||
|         assert not corrector.is_rule_match(Command('ls'), rule) | ||||
|         assert capsys.readouterr()[1].split('\n')[0] == '[WARN] Rule test:' | ||||
|  | ||||
|  | ||||
| class TestMakeCorrectedCommands(object): | ||||
|     def test_with_rule_returns_list(self): | ||||
|         rule = Rule(get_new_command=lambda x: [x.script + '!', x.script + '@'], | ||||
|                     priority=100) | ||||
|         assert list(make_corrected_commands(Command(script='test'), rule)) \ | ||||
|                == [CorrectedCommand(script='test!', priority=100), | ||||
|                    CorrectedCommand(script='test@', priority=200)] | ||||
|  | ||||
|     def test_with_rule_returns_command(self): | ||||
|         rule = Rule(get_new_command=lambda x: x.script + '!', | ||||
|                     priority=100) | ||||
|         assert list(make_corrected_commands(Command(script='test'), rule)) \ | ||||
|                == [CorrectedCommand(script='test!', priority=100)] | ||||
|  | ||||
| def test_get_corrected_commands(mocker): | ||||
|     command = Command('test', 'test', 'test') | ||||
|     rules = [Rule(match=lambda _: False), | ||||
|   | ||||
| @@ -1,4 +1,8 @@ | ||||
| from tests.utils import CorrectedCommand | ||||
| from mock import Mock | ||||
| from pathlib import Path | ||||
| import pytest | ||||
| from tests.utils import CorrectedCommand, Rule, Command | ||||
| from thefuck import conf | ||||
|  | ||||
|  | ||||
| class TestCorrectedCommand(object): | ||||
| @@ -12,3 +16,63 @@ class TestCorrectedCommand(object): | ||||
|     def test_hashable(self): | ||||
|         assert {CorrectedCommand('ls', None, 100), | ||||
|                 CorrectedCommand('ls', None, 200)} == {CorrectedCommand('ls')} | ||||
|  | ||||
|  | ||||
| class TestRule(object): | ||||
|     def test_from_path(self, mocker): | ||||
|         match = object() | ||||
|         get_new_command = object() | ||||
|         load_source = mocker.patch( | ||||
|             'thefuck.types.load_source', | ||||
|             return_value=Mock(match=match, | ||||
|                               get_new_command=get_new_command, | ||||
|                               enabled_by_default=True, | ||||
|                               priority=900, | ||||
|                               requires_output=True)) | ||||
|         assert Rule.from_path(Path('/rules/bash.py')) \ | ||||
|                == Rule('bash', match, get_new_command, priority=900) | ||||
|         load_source.assert_called_once_with('bash', '/rules/bash.py') | ||||
|  | ||||
|     @pytest.mark.parametrize('rules, exclude_rules, rule, is_enabled', [ | ||||
|         (conf.DEFAULT_RULES, [], Rule('git', enabled_by_default=True), True), | ||||
|         (conf.DEFAULT_RULES, [], Rule('git', enabled_by_default=False), False), | ||||
|         ([], [], Rule('git', enabled_by_default=False), False), | ||||
|         ([], [], Rule('git', enabled_by_default=True), False), | ||||
|         (conf.DEFAULT_RULES + ['git'], [], Rule('git', enabled_by_default=False), True), | ||||
|         (['git'], [], Rule('git', enabled_by_default=False), True), | ||||
|         (conf.DEFAULT_RULES, ['git'], Rule('git', enabled_by_default=True), False), | ||||
|         (conf.DEFAULT_RULES, ['git'], Rule('git', enabled_by_default=False), False), | ||||
|         ([], ['git'], Rule('git', enabled_by_default=True), False), | ||||
|         ([], ['git'], Rule('git', enabled_by_default=False), False)]) | ||||
|     def test_is_enabled(self, settings, rules, exclude_rules, rule, is_enabled): | ||||
|         settings.update(rules=rules, | ||||
|                         exclude_rules=exclude_rules) | ||||
|         assert rule.is_enabled == is_enabled | ||||
|  | ||||
|     def test_isnt_match(self): | ||||
|         assert not Rule('', lambda _: False).is_match( | ||||
|             Command('ls')) | ||||
|  | ||||
|     def test_is_match(self): | ||||
|         rule = Rule('', lambda x: x.script == 'cd ..') | ||||
|         assert rule.is_match(Command('cd ..')) | ||||
|  | ||||
|     @pytest.mark.usefixtures('no_colors') | ||||
|     def test_isnt_match_when_rule_failed(self, capsys): | ||||
|         rule = Rule('test', Mock(side_effect=OSError('Denied')), | ||||
|                     requires_output=False) | ||||
|         assert not rule.is_match(Command('ls')) | ||||
|         assert capsys.readouterr()[1].split('\n')[0] == '[WARN] Rule test:' | ||||
|  | ||||
|     def test_get_corrected_commands_with_rule_returns_list(self): | ||||
|         rule = Rule(get_new_command=lambda x: [x.script + '!', x.script + '@'], | ||||
|                     priority=100) | ||||
|         assert list(rule.get_corrected_commands(Command(script='test'))) \ | ||||
|                == [CorrectedCommand(script='test!', priority=100), | ||||
|                    CorrectedCommand(script='test@', priority=200)] | ||||
|  | ||||
|     def test_get_corrected_commands_with_rule_returns_command(self): | ||||
|         rule = Rule(get_new_command=lambda x: x.script + '!', | ||||
|                     priority=100) | ||||
|         assert list(rule.get_corrected_commands(Command(script='test'))) \ | ||||
|                == [CorrectedCommand(script='test!', priority=100)] | ||||
|   | ||||
| @@ -7,19 +7,22 @@ def Command(script='', stdout='', stderr=''): | ||||
|     return types.Command(script, stdout, stderr) | ||||
|  | ||||
|  | ||||
| def Rule(name='', match=lambda *_: True, | ||||
|          get_new_command=lambda *_: '', | ||||
|          enabled_by_default=True, | ||||
|          side_effect=None, | ||||
|          priority=DEFAULT_PRIORITY, | ||||
|          requires_output=True): | ||||
|     return types.Rule(name, match, get_new_command, | ||||
|                       enabled_by_default, side_effect, | ||||
|                       priority, requires_output) | ||||
| class Rule(types.Rule): | ||||
|     def __init__(self, name='', match=lambda *_: True, | ||||
|                  get_new_command=lambda *_: '', | ||||
|                  enabled_by_default=True, | ||||
|                  side_effect=None, | ||||
|                  priority=DEFAULT_PRIORITY, | ||||
|                  requires_output=True): | ||||
|         super(Rule, self).__init__(name, match, get_new_command, | ||||
|                                    enabled_by_default, side_effect, | ||||
|                                    priority, requires_output) | ||||
|  | ||||
|  | ||||
| def CorrectedCommand(script='', side_effect=None, priority=DEFAULT_PRIORITY): | ||||
|     return types.CorrectedCommand(script, side_effect, priority) | ||||
| class CorrectedCommand(types.CorrectedCommand): | ||||
|     def __init__(self, script='', side_effect=None, priority=DEFAULT_PRIORITY): | ||||
|         super(CorrectedCommand, self).__init__( | ||||
|             script, side_effect, priority) | ||||
|  | ||||
|  | ||||
| root = Path(__file__).parent.parent.resolve() | ||||
|   | ||||
| @@ -2,7 +2,15 @@ from imp import load_source | ||||
| import os | ||||
| import sys | ||||
| from six import text_type | ||||
| from .types import Settings | ||||
|  | ||||
|  | ||||
| class Settings(dict): | ||||
|     def __getattr__(self, item): | ||||
|         return self.get(item) | ||||
|  | ||||
|     def __setattr__(self, key, value): | ||||
|         self[key] = value | ||||
|  | ||||
|  | ||||
| ALL_ENABLED = object() | ||||
| DEFAULT_RULES = [ALL_ENABLED] | ||||
|   | ||||
| @@ -1,45 +1,16 @@ | ||||
| import sys | ||||
| from imp import load_source | ||||
| from pathlib import Path | ||||
| from .conf import settings, DEFAULT_PRIORITY, ALL_ENABLED | ||||
| from .types import Rule, CorrectedCommand | ||||
| from .utils import compatibility_call | ||||
| from .conf import settings | ||||
| from .types import Rule | ||||
| from . import logs | ||||
|  | ||||
|  | ||||
| def load_rule(rule): | ||||
|     """Imports rule module and returns it.""" | ||||
|     name = rule.name[:-3] | ||||
|     with logs.debug_time(u'Importing rule: {};'.format(name)): | ||||
|         rule_module = load_source(name, str(rule)) | ||||
|         priority = getattr(rule_module, 'priority', DEFAULT_PRIORITY) | ||||
|     return Rule(name, rule_module.match, | ||||
|                 rule_module.get_new_command, | ||||
|                 getattr(rule_module, 'enabled_by_default', True), | ||||
|                 getattr(rule_module, 'side_effect', None), | ||||
|                 settings.priority.get(name, priority), | ||||
|                 getattr(rule_module, 'requires_output', True)) | ||||
|  | ||||
|  | ||||
| def is_rule_enabled(rule): | ||||
|     """Returns `True` when rule enabled.""" | ||||
|     if rule.name in settings.exclude_rules: | ||||
|         return False | ||||
|     elif rule.name in settings.rules: | ||||
|         return True | ||||
|     elif rule.enabled_by_default and ALL_ENABLED in settings.rules: | ||||
|         return True | ||||
|     else: | ||||
|         return False | ||||
|  | ||||
|  | ||||
| def get_loaded_rules(rules): | ||||
| def get_loaded_rules(rules_paths): | ||||
|     """Yields all available rules.""" | ||||
|     for rule in rules: | ||||
|         if rule.name != '__init__.py': | ||||
|             loaded_rule = load_rule(rule) | ||||
|             if is_rule_enabled(loaded_rule): | ||||
|                 yield loaded_rule | ||||
|     for path in rules_paths: | ||||
|         if path.name != '__init__.py': | ||||
|             rule = Rule.from_path(path) | ||||
|             if rule.is_enabled: | ||||
|                 yield rule | ||||
|  | ||||
|  | ||||
| def get_rules(): | ||||
| @@ -52,30 +23,6 @@ def get_rules(): | ||||
|                   key=lambda rule: rule.priority) | ||||
|  | ||||
|  | ||||
| def is_rule_match(command, rule): | ||||
|     """Returns first matched rule for command.""" | ||||
|     script_only = command.stdout is None and command.stderr is None | ||||
|  | ||||
|     if script_only and rule.requires_output: | ||||
|         return False | ||||
|  | ||||
|     try: | ||||
|         with logs.debug_time(u'Trying rule: {};'.format(rule.name)): | ||||
|             if compatibility_call(rule.match, command): | ||||
|                 return True | ||||
|     except Exception: | ||||
|         logs.rule_failed(rule, sys.exc_info()) | ||||
|  | ||||
|  | ||||
| def make_corrected_commands(command, rule): | ||||
|     new_commands = compatibility_call(rule.get_new_command, command) | ||||
|     if not isinstance(new_commands, list): | ||||
|         new_commands = (new_commands,) | ||||
|     for n, new_command in enumerate(new_commands): | ||||
|         yield CorrectedCommand(script=new_command, | ||||
|                                side_effect=rule.side_effect, | ||||
|                                priority=(n + 1) * rule.priority) | ||||
|  | ||||
| def organize_commands(corrected_commands): | ||||
|     """Yields sorted commands without duplicates.""" | ||||
|     try: | ||||
| @@ -103,6 +50,6 @@ def organize_commands(corrected_commands): | ||||
| def get_corrected_commands(command): | ||||
|     corrected_commands = ( | ||||
|         corrected for rule in get_rules() | ||||
|         if is_rule_match(command, rule) | ||||
|         for corrected in make_corrected_commands(command, rule)) | ||||
|         if rule.is_match(command) | ||||
|         for corrected in rule.get_corrected_commands(command)) | ||||
|     return organize_commands(corrected_commands) | ||||
|   | ||||
| @@ -1,10 +1,91 @@ | ||||
| from collections import namedtuple | ||||
| from imp import load_source | ||||
| import sys | ||||
| from .conf import settings, DEFAULT_PRIORITY, ALL_ENABLED | ||||
| from .utils import compatibility_call | ||||
| from . import logs | ||||
|  | ||||
| Command = namedtuple('Command', ('script', 'stdout', 'stderr')) | ||||
|  | ||||
| Rule = namedtuple('Rule', ('name', 'match', 'get_new_command', | ||||
|                            'enabled_by_default', 'side_effect', | ||||
|                            'priority', 'requires_output')) | ||||
|  | ||||
| class Rule(object): | ||||
|     def __init__(self, name, match, get_new_command, | ||||
|                  enabled_by_default, side_effect, | ||||
|                  priority, requires_output): | ||||
|         self.name = name | ||||
|         self.match = match | ||||
|         self.get_new_command = get_new_command | ||||
|         self.enabled_by_default = enabled_by_default | ||||
|         self.side_effect = side_effect | ||||
|         self.priority = priority | ||||
|         self.requires_output = requires_output | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         if isinstance(other, Rule): | ||||
|             return (self.name, self.match, self.get_new_command, | ||||
|                     self.enabled_by_default, self.side_effect, | ||||
|                     self.priority, self.requires_output) \ | ||||
|                    == (other.name, other.match, other.get_new_command, | ||||
|                        other.enabled_by_default, other.side_effect, | ||||
|                        other.priority, other.requires_output) | ||||
|         else: | ||||
|             return False | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return 'Rule(name={}, match={}, get_new_command={}, ' \ | ||||
|                'enabled_by_default={}, side_effect={}, ' \ | ||||
|                'priority={}, requires_output)'.format( | ||||
|                     self.name, self.match, self.get_new_command, | ||||
|                     self.enabled_by_default, self.side_effect, | ||||
|                     self.priority, self.requires_output) | ||||
|  | ||||
|     @classmethod | ||||
|     def from_path(cls, path): | ||||
|         """Creates rule instance from path.""" | ||||
|         name = path.name[:-3] | ||||
|         with logs.debug_time(u'Importing rule: {};'.format(name)): | ||||
|             rule_module = load_source(name, str(path)) | ||||
|             priority = getattr(rule_module, 'priority', DEFAULT_PRIORITY) | ||||
|         return cls(name, rule_module.match, | ||||
|                    rule_module.get_new_command, | ||||
|                    getattr(rule_module, 'enabled_by_default', True), | ||||
|                    getattr(rule_module, 'side_effect', None), | ||||
|                    settings.priority.get(name, priority), | ||||
|                    getattr(rule_module, 'requires_output', True)) | ||||
|  | ||||
|     @property | ||||
|     def is_enabled(self): | ||||
|         if self.name in settings.exclude_rules: | ||||
|             return False | ||||
|         elif self.name in settings.rules: | ||||
|             return True | ||||
|         elif self.enabled_by_default and ALL_ENABLED in settings.rules: | ||||
|             return True | ||||
|         else: | ||||
|             return False | ||||
|  | ||||
|     def is_match(self, command): | ||||
|         """Returns `True` if rule matches the command.""" | ||||
|         script_only = command.stdout is None and command.stderr is None | ||||
|  | ||||
|         if script_only and self.requires_output: | ||||
|             return False | ||||
|  | ||||
|         try: | ||||
|             with logs.debug_time(u'Trying rule: {};'.format(self.name)): | ||||
|                 if compatibility_call(self.match, command): | ||||
|                     return True | ||||
|         except Exception: | ||||
|             logs.rule_failed(self, sys.exc_info()) | ||||
|  | ||||
|     def get_corrected_commands(self, command): | ||||
|         new_commands = compatibility_call(self.get_new_command, command) | ||||
|         if not isinstance(new_commands, list): | ||||
|             new_commands = (new_commands,) | ||||
|         for n, new_command in enumerate(new_commands): | ||||
|             yield CorrectedCommand(script=new_command, | ||||
|                                    side_effect=self.side_effect, | ||||
|                                    priority=(n + 1) * self.priority) | ||||
|  | ||||
|  | ||||
| class CorrectedCommand(object): | ||||
| @@ -27,11 +108,3 @@ class CorrectedCommand(object): | ||||
|     def __repr__(self): | ||||
|         return 'CorrectedCommand(script={}, side_effect={}, priority={})'.format( | ||||
|             self.script, self.side_effect, self.priority) | ||||
|  | ||||
|  | ||||
| class Settings(dict): | ||||
|     def __getattr__(self, item): | ||||
|         return self.get(item) | ||||
|  | ||||
|     def __setattr__(self, key, value): | ||||
|         self[key] = value | ||||
|   | ||||
		Reference in New Issue
	
	Block a user