diff --git a/tests/test_corrector.py b/tests/test_corrector.py index 7142aa1e..2c58a017 100644 --- a/tests/test_corrector.py +++ b/tests/test_corrector.py @@ -41,40 +41,37 @@ class TestGetRules(object): rules) -class TestGetMatchedRules(object): - def test_no_match(self): - assert list(corrector.get_matched_rules( - Command('ls'), [Rule('', lambda *_: False)], - Mock(no_colors=True))) == [] +class TestIsRuleMatch(object): + def test_no_match(self, settings): + assert not corrector.is_rule_match( + Command('ls'), Rule('', lambda *_: False), settings) - def test_match(self): + def test_match(self, settings): rule = Rule('', lambda x, _: x.script == 'cd ..') - assert list(corrector.get_matched_rules( - Command('cd ..'), [rule], Mock(no_colors=True))) == [rule] + assert corrector.is_rule_match(Command('cd ..'), rule, settings) - def test_when_rule_failed(self, capsys): - all(corrector.get_matched_rules( - Command('ls'), [Rule('test', Mock(side_effect=OSError('Denied')), - requires_output=False)], - Mock(no_colors=True, debug=False))) + def test_when_rule_failed(self, capsys, settings): + rule = Rule('test', Mock(side_effect=OSError('Denied')), + requires_output=False) + assert not corrector.is_rule_match( + Command('ls'), rule, settings) assert capsys.readouterr()[1].split('\n')[0] == '[WARN] Rule test:' -class TestGetCorrectedCommands(object): +class TestMakeCorrectedCommands(object): def test_with_rule_returns_list(self): rule = Rule(get_new_command=lambda x, _: [x.script + '!', x.script + '@'], priority=100) - assert list(make_corrected_commands(Command(script='test'), [rule], None)) \ + assert list(make_corrected_commands(Command(script='test'), rule, None)) \ == [CorrectedCommand(script='test!', priority=100), CorrectedCommand(script='test@', priority=200)] def test_with_rule_returns_command(self): rule = Rule(get_new_command=lambda x, _: x.script + '!', priority=100) - assert list(make_corrected_commands(Command(script='test'), [rule], None)) \ + assert list(make_corrected_commands(Command(script='test'), rule, None)) \ == [CorrectedCommand(script='test!', priority=100)] - def test_get_corrected_commands(mocker): command = Command('test', 'test', 'test') rules = [Rule(match=lambda *_: False), diff --git a/tests/test_types.py b/tests/test_types.py index c17f7a58..e8028aa1 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -23,7 +23,6 @@ class TestSortedCorrectedCommandsSequence(object): should_realise = False def gen(): - nonlocal should_realise yield CorrectedCommand('git commit') yield CorrectedCommand('git branch', priority=200) assert should_realise diff --git a/thefuck/corrector.py b/thefuck/corrector.py index a5216c5c..6fcaf42a 100644 --- a/thefuck/corrector.py +++ b/thefuck/corrector.py @@ -38,36 +38,35 @@ def get_rules(user_dir, settings): key=lambda rule: rule.priority) -def get_matched_rules(command, rules, settings): +def is_rule_match(command, rule, settings): """Returns first matched rule for command.""" script_only = command.stdout is None and command.stderr is None - for rule in rules: - if script_only and rule.requires_output: - continue + if script_only and rule.requires_output: + return False - try: - with logs.debug_time(u'Trying rule: {};'.format(rule.name), - settings): - if rule.match(command, settings): - yield rule - except Exception: - logs.rule_failed(rule, sys.exc_info(), settings) + try: + with logs.debug_time(u'Trying rule: {};'.format(rule.name), + settings): + if rule.match(command, settings): + return True + except Exception: + logs.rule_failed(rule, sys.exc_info(), settings) -def make_corrected_commands(command, rules, settings): - for rule in rules: - new_commands = rule.get_new_command(command, settings) - if not isinstance(new_commands, list): - new_commands = (new_commands,) - for n, new_command in enumerate(new_commands): - yield CorrectedCommand(script=new_command, - side_effect=rule.side_effect, - priority=(n + 1) * rule.priority) +def make_corrected_commands(command, rule, settings): + new_commands = rule.get_new_command(command, settings) + if not isinstance(new_commands, list): + new_commands = (new_commands,) + for n, new_command in enumerate(new_commands): + yield types.CorrectedCommand(script=new_command, + side_effect=rule.side_effect, + priority=(n + 1) * rule.priority) def get_corrected_commands(command, user_dir, settings): - rules = get_rules(user_dir, settings) - matched = get_matched_rules(command, rules, settings) - corrected_commands = make_corrected_commands(command, matched, settings) + corrected_commands = ( + corrected for rule in get_rules(user_dir, settings) + if is_rule_match(command, rule, settings) + for corrected in make_corrected_commands(command, rule, settings)) return types.SortedCorrectedCommandsSequence(corrected_commands, settings)