diff --git a/README.md b/README.md index 950214e2..46df28be 100644 --- a/README.md +++ b/README.md @@ -163,7 +163,7 @@ using the matched rule and runs it. Rules enabled by default are as follows: * `git_add` – fix *"Did you forget to 'git add'?"*; * `git_branch_delete` – changes `git branch -d` to `git branch -D`; * `git_branch_list` – catches `git branch list` in place of `git branch` and removes created branch; -* `git_checkout` – creates the branch before checking-out; +* `git_checkout` – fixes branch name or creates new branch; * `git_diff_staged` – adds `--staged` to previous `git diff` with unexpected output; * `git_no_command` – fixes wrong git commands like `git brnch`; * `git_pull` – sets upstream before executing previous `git pull`; diff --git a/tests/rules/test_git_checkout.py b/tests/rules/test_git_checkout.py index a540b62d..2b8b2bb7 100644 --- a/tests/rules/test_git_checkout.py +++ b/tests/rules/test_git_checkout.py @@ -12,6 +12,11 @@ def did_not_match(target, did_you_forget=False): return error +@pytest.fixture +def get_branches(mocker): + return mocker.patch('thefuck.rules.git_checkout') + + @pytest.mark.parametrize('command', [ Command(script='git checkout unknown', stderr=did_not_match('unknown')), Command(script='git commit unknown', stderr=did_not_match('unknown'))]) @@ -28,10 +33,19 @@ def test_not_match(command): assert not match(command, None) -@pytest.mark.parametrize('command, new_command', [ - (Command(script='git checkout unknown', stderr=did_not_match('unknown')), +@pytest.mark.parametrize('branches, command, new_command', [ + ([], + Command(script='git checkout unknown', stderr=did_not_match('unknown')), 'git branch unknown && git checkout unknown'), - (Command('git commit unknown', stderr=did_not_match('unknown')), - 'git branch unknown && git commit unknown')]) -def test_get_new_command(command, new_command): + ([], + Command('git commit unknown', stderr=did_not_match('unknown')), + 'git branch unknown && git commit unknown'), + (['master'], + Command(script='git checkout amster', stderr=did_not_match('amster')), + 'git checkout master'), + (['master'], + Command(script='git commit amster', stderr=did_not_match('amster')), + 'git commit master')]) +def test_get_new_command(branches, command, new_command, get_branches): + get_branches.return_value = branches assert get_new_command(command, None) == new_command diff --git a/tests/test_utils.py b/tests/test_utils.py index b724bca2..b6a02fcb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -59,3 +59,7 @@ class TestGetClosest(object): def test_when_cant_match(self): assert 'status' == get_closest('st', ['status', 'reset']) + + def test_without_fallback(self): + assert get_closest('st', ['status', 'reset'], + fallback_to_first=False) is None diff --git a/thefuck/rules/git_checkout.py b/thefuck/rules/git_checkout.py index 9a84ec95..c3a3fce8 100644 --- a/thefuck/rules/git_checkout.py +++ b/thefuck/rules/git_checkout.py @@ -1,4 +1,5 @@ import re +import subprocess from thefuck import shells, utils @@ -9,11 +10,28 @@ def match(command, settings): and "Did you forget to 'git add'?" not in command.stderr) +def get_branches(): + proc = subprocess.Popen( + ['git', 'branch', '-a', '--no-color', '--no-column'], + stdout=subprocess.PIPE) + for line in proc.stdout.readlines(): + line = line.decode('utf-8') + if line.startswith('*'): + line = line.split(' ')[1] + if '/' in line: + line = line.split('/')[-1] + yield line.strip() + + @utils.git_support def get_new_command(command, settings): missing_file = re.findall( - r"error: pathspec '([^']*)' " - "did not match any file\(s\) known to git.", command.stderr)[0] - - formatme = shells.and_('git branch {}', '{}') - return formatme.format(missing_file, command.script) + r"error: pathspec '([^']*)' " + "did not match any file\(s\) known to git.", command.stderr)[0] + closest_branch = utils.get_closest(missing_file, get_branches(), + fallback_to_first=False) + if closest_branch: + return command.script.replace(missing_file, closest_branch, 1) + else: + return shells.and_('git branch {}', '{}').format( + missing_file, command.script) diff --git a/thefuck/utils.py b/thefuck/utils.py index ffedb98c..ea564d82 100644 --- a/thefuck/utils.py +++ b/thefuck/utils.py @@ -113,10 +113,11 @@ def memoize(fn): memoize.disabled = False -def get_closest(word, possibilities, n=3, cutoff=0.6): +def get_closest(word, possibilities, n=3, cutoff=0.6, fallback_to_first=True): """Returns closest match or just first from possibilities.""" possibilities = list(possibilities) try: return get_close_matches(word, possibilities, n, cutoff)[0] except IndexError: - return possibilities[0] + if fallback_to_first: + return possibilities[0]