diff --git a/tests/rules/test_git_branch_delete_checked_out.py b/tests/rules/test_git_branch_delete_checked_out.py index 2bf551bf..98e4adb9 100644 --- a/tests/rules/test_git_branch_delete_checked_out.py +++ b/tests/rules/test_git_branch_delete_checked_out.py @@ -1,4 +1,7 @@ +from mock import patch + import pytest + from thefuck.rules.git_branch_delete_checked_out import match, get_new_command from thefuck.types import Command @@ -25,5 +28,18 @@ def test_not_match(script): ("git branch -D foo", "git checkout master && git branch -D foo"), ], ) -def test_get_new_command(script, new_command, output): - assert get_new_command(Command(script, output)) == new_command +def test_get_new_command_deletion_flag(script, new_command, output): + with patch('thefuck.rules.git_branch_delete_checked_out.get_sp_stdout', return_value='master'): + assert get_new_command(Command(script, output)) == new_command + + +@pytest.mark.parametrize( + "script, default_branch", + [ + ("git branch -d foo", "main"), + ("git branch -d foo", "bar"), + ], +) +def test_get_new_command_default_branch(script, default_branch, output): + with patch('thefuck.rules.git_branch_delete_checked_out.get_sp_stdout', return_value=default_branch): + assert get_new_command(Command(script, output)) == "git checkout {default_branch} && git branch -D foo".format(default_branch=default_branch) diff --git a/thefuck/rules/git_branch_delete_checked_out.py b/thefuck/rules/git_branch_delete_checked_out.py index eadc2d4d..5bee5e48 100644 --- a/thefuck/rules/git_branch_delete_checked_out.py +++ b/thefuck/rules/git_branch_delete_checked_out.py @@ -1,8 +1,13 @@ +import subprocess as sp + from thefuck.shells import shell from thefuck.specific.git import git_support from thefuck.utils import replace_argument +STDOUT_INDEX = 0 + + @git_support def match(command): return ( @@ -12,8 +17,14 @@ def match(command): ) +def get_sp_stdout(command): + return sp.Popen(command, stdout=sp.PIPE, shell=True).communicate()[STDOUT_INDEX].strip().decode("utf-8") + + @git_support def get_new_command(command): - return shell.and_("git checkout master", "{}").format( + remote_name = get_sp_stdout("git remote") + default_branch_name = get_sp_stdout("git remote show {remote} | sed -n '/HEAD branch/s/.*: //p'".format(remote=remote_name)) + return shell.and_("git checkout {default_branch}".format(default_branch=default_branch_name), "{}").format( replace_argument(command.script, "-d", "-D") )