diff --git a/tests/rules/test_git_diff_staged.py b/tests/rules/test_git_diff_staged.py index a4c62a37..364cb267 100644 --- a/tests/rules/test_git_diff_staged.py +++ b/tests/rules/test_git_diff_staged.py @@ -3,7 +3,9 @@ from thefuck.rules.git_diff_staged import match, get_new_command from tests.utils import Command -@pytest.mark.parametrize('command', [Command(script='git diff')]) +@pytest.mark.parametrize('command', [ + Command(script='git diff foo'), + Command(script='git diff')]) def test_match(command): assert match(command, None) @@ -18,6 +20,7 @@ def test_not_match(command): @pytest.mark.parametrize('command, new_command', [ - (Command('git diff'), 'git diff --staged')]) + (Command('git diff'), 'git diff --staged'), + (Command('git diff foo'), 'git diff --staged foo')]) def test_get_new_command(command, new_command): assert get_new_command(command, None) == new_command diff --git a/thefuck/rules/git_diff_staged.py b/thefuck/rules/git_diff_staged.py index a35234e4..c879cf40 100644 --- a/thefuck/rules/git_diff_staged.py +++ b/thefuck/rules/git_diff_staged.py @@ -10,4 +10,4 @@ def match(command, settings): @utils.git_support def get_new_command(command, settings): - return '{} --staged'.format(command.script) + return command.script.replace(' diff', ' diff --staged')