diff --git a/README.md b/README.md index e67bff9d..af770da3 100644 --- a/README.md +++ b/README.md @@ -216,7 +216,11 @@ match(command: Command, settings: Settings) -> bool get_new_command(command: Command, settings: Settings) -> str | list[str] ``` -Also the rule can contain an optional function `side_effect(command: Command, settings: Settings) -> None` +Also the rule can contain an optional function + +```python +side_effect(old_command: Command, fixed_command: str, settings: Settings) -> None +``` and optional `enabled_by_default`, `requires_output` and `priority` variables. `Command` has three attributes: `script`, `stdout` and `stderr`. diff --git a/tests/rules/test_dirty_untar.py b/tests/rules/test_dirty_untar.py index 77d125d2..6043c33e 100644 --- a/tests/rules/test_dirty_untar.py +++ b/tests/rules/test_dirty_untar.py @@ -51,7 +51,7 @@ def test_match(tar_error, filename, script, fixed): @parametrize_script def test_side_effect(tar_error, filename, script, fixed): tar_error(filename) - side_effect(Command(script=script.format(filename)), None) + side_effect(Command(script=script.format(filename)), None, None) assert(os.listdir('.') == [filename]) diff --git a/tests/rules/test_dirty_unzip.py b/tests/rules/test_dirty_unzip.py index 18d34f0d..74dabf3f 100644 --- a/tests/rules/test_dirty_unzip.py +++ b/tests/rules/test_dirty_unzip.py @@ -34,7 +34,7 @@ def test_match(zip_error, script): 'unzip foo', 'unzip foo.zip']) def test_side_effect(zip_error, script): - side_effect(Command(script=script), None) + side_effect(Command(script=script), None, None) assert(os.listdir('.') == ['foo.zip']) diff --git a/tests/rules/test_ssh_known_host.py b/tests/rules/test_ssh_known_host.py index ed414bc1..7d201f3a 100644 --- a/tests/rules/test_ssh_known_host.py +++ b/tests/rules/test_ssh_known_host.py @@ -56,7 +56,7 @@ def test_match(ssh_error): def test_side_effect(ssh_error): errormsg, path, reset, known_hosts = ssh_error command = Command('ssh user@host', stderr=errormsg) - side_effect(command, None) + side_effect(command, None, None) expected = ['123.234.567.890 asdjkasjdakjsd\n', '111.222.333.444 qwepoiwqepoiss\n'] assert known_hosts(path) == expected diff --git a/thefuck/main.py b/thefuck/main.py index c30ff7cf..24dc697c 100644 --- a/thefuck/main.py +++ b/thefuck/main.py @@ -74,10 +74,10 @@ def get_command(settings, args): return types.Command(script, None, None) -def run_command(command, settings): +def run_command(old_cmd, command, settings): """Runs command from rule for passed command.""" if command.side_effect: - command.side_effect(command, settings) + command.side_effect(old_cmd, command.script, settings) shells.put_to_history(command.script) print(command.script) @@ -100,7 +100,7 @@ def fix_command(): corrected_commands = get_corrected_commands(command, user_dir, settings) selected_command = select_command(corrected_commands, settings) if selected_command: - run_command(selected_command, settings) + run_command(command, selected_command, settings) def print_alias(entry_point=True): diff --git a/thefuck/rules/dirty_untar.py b/thefuck/rules/dirty_untar.py index 3b323daf..8dc40488 100644 --- a/thefuck/rules/dirty_untar.py +++ b/thefuck/rules/dirty_untar.py @@ -35,7 +35,7 @@ def get_new_command(command, settings): .format(dir=_tar_file(command.script)[1], cmd=command.script) -def side_effect(command, settings): - with tarfile.TarFile(_tar_file(command.script)[0]) as archive: +def side_effect(old_cmd, command, settings): + with tarfile.TarFile(_tar_file(old_cmd.script)[0]) as archive: for file in archive.getnames(): os.remove(file) diff --git a/thefuck/rules/dirty_unzip.py b/thefuck/rules/dirty_unzip.py index 4f6e6bcc..738cf82f 100644 --- a/thefuck/rules/dirty_unzip.py +++ b/thefuck/rules/dirty_unzip.py @@ -30,8 +30,8 @@ def get_new_command(command, settings): return '{} -d {}'.format(command.script, _zip_file(command)[:-4]) -def side_effect(command, settings): - with zipfile.ZipFile(_zip_file(command), 'r') as archive: +def side_effect(old_cmd, command, settings): + with zipfile.ZipFile(_zip_file(old_cmd), 'r') as archive: for file in archive.namelist(): os.remove(file) diff --git a/thefuck/rules/ssh_known_hosts.py b/thefuck/rules/ssh_known_hosts.py index 4e889696..df908d00 100644 --- a/thefuck/rules/ssh_known_hosts.py +++ b/thefuck/rules/ssh_known_hosts.py @@ -26,8 +26,8 @@ def get_new_command(command, settings): return command.script -def side_effect(command, settings): - offending = offending_pattern.findall(command.stderr) +def side_effect(old_cmd, command, settings): + offending = offending_pattern.findall(old_cmd.stderr) for filepath, lineno in offending: with open(filepath, 'r') as fh: lines = fh.readlines()