diff --git a/esphome/__main__.py b/esphome/__main__.py index 4bd611d50c..3edccc2bdb 100644 --- a/esphome/__main__.py +++ b/esphome/__main__.py @@ -9,6 +9,7 @@ import os import re import sys import time +from typing import Protocol import argcomplete @@ -56,6 +57,20 @@ from esphome.util import ( _LOGGER = logging.getLogger(__name__) +class ArgsProtocol(Protocol): + device: list[str] | None + reset: bool + username: str | None + password: str | None + client_id: str | None + topic: str | None + file: str | None + no_logs: bool + only_generate: bool + show_secrets: bool + dashboard: bool + + def choose_prompt(options, purpose: str = None): if not options: raise EsphomeError( @@ -344,7 +359,7 @@ def check_permissions(port: str): ) -def upload_program(config: ConfigType, args, host: str) -> int | str: +def upload_program(config: ConfigType, args: ArgsProtocol, host: str) -> int | str: try: module = importlib.import_module("esphome.components." + CORE.target_platform) if getattr(module, "upload_program")(config, args, host): @@ -405,7 +420,7 @@ def upload_program(config: ConfigType, args, host: str) -> int | str: return espota2.run_ota(host, remote_port, password, CORE.firmware_bin) -def show_logs(config: ConfigType, args, devices: list[str]) -> int | None: +def show_logs(config: ConfigType, args: ArgsProtocol, devices: list[str]) -> int | None: if "logger" not in config: raise EsphomeError("Logger is not configured!") @@ -437,7 +452,7 @@ def show_logs(config: ConfigType, args, devices: list[str]) -> int | None: raise EsphomeError("No remote or local logging method configured (api/mqtt/logger)") -def clean_mqtt(config, args): +def clean_mqtt(config: ConfigType, args: ArgsProtocol) -> int | None: from esphome import mqtt return mqtt.clear_topic( @@ -445,13 +460,13 @@ def clean_mqtt(config, args): ) -def command_wizard(args): +def command_wizard(args: ArgsProtocol) -> int | None: from esphome import wizard return wizard.wizard(args.configuration) -def command_config(args, config): +def command_config(args: ArgsProtocol, config: ConfigType) -> int | None: if not CORE.verbose: config = strip_default_ids(config) output = yaml_util.dump(config, args.show_secrets) @@ -466,7 +481,7 @@ def command_config(args, config): return 0 -def command_vscode(args): +def command_vscode(args: ArgsProtocol) -> int | None: from esphome import vscode logging.disable(logging.INFO) @@ -474,7 +489,7 @@ def command_vscode(args): vscode.read_config(args) -def command_compile(args, config): +def command_compile(args: ArgsProtocol, config: ConfigType) -> int | None: exit_code = write_cpp(config) if exit_code != 0: return exit_code @@ -488,7 +503,7 @@ def command_compile(args, config): return 0 -def command_upload(args, config) -> int: +def command_upload(args: ArgsProtocol, config: ConfigType) -> int | None: devices: list[str] = args.device or [] if not devices: # No devices specified, use the interactive chooser @@ -516,7 +531,7 @@ def command_upload(args, config) -> int: return exit_code -def command_discover(args, config): +def command_discover(args: ArgsProtocol, config: ConfigType) -> int | None: if "mqtt" in config: from esphome import mqtt @@ -525,7 +540,7 @@ def command_discover(args, config): raise EsphomeError("No discover method configured (mqtt)") -def command_logs(args, config) -> int | None: +def command_logs(args: ArgsProtocol, config: ConfigType) -> int | None: # No devices specified, use the interactive chooser devices = args.device or [ choose_upload_log_host( @@ -540,7 +555,7 @@ def command_logs(args, config) -> int | None: return show_logs(config, args, devices) -def command_run(args, config): +def command_run(args: ArgsProtocol, config: ConfigType) -> int | None: exit_code = write_cpp(config) if exit_code != 0: return exit_code @@ -599,22 +614,22 @@ def command_run(args, config): return show_logs(config, args, [port]) -def command_clean_mqtt(args, config): +def command_clean_mqtt(args: ArgsProtocol, config: ConfigType) -> int | None: return clean_mqtt(config, args) -def command_mqtt_fingerprint(args, config): +def command_mqtt_fingerprint(args: ArgsProtocol, config: ConfigType) -> int | None: from esphome import mqtt return mqtt.get_fingerprint(config) -def command_version(args): +def command_version(args: ArgsProtocol) -> int | None: safe_print(f"Version: {const.__version__}") return 0 -def command_clean(args, config): +def command_clean(args: ArgsProtocol, config: ConfigType) -> int | None: try: writer.clean_build() except OSError as err: @@ -624,13 +639,13 @@ def command_clean(args, config): return 0 -def command_dashboard(args): +def command_dashboard(args: ArgsProtocol) -> int | None: from esphome.dashboard import dashboard return dashboard.start_dashboard(args) -def command_update_all(args): +def command_update_all(args: ArgsProtocol) -> int | None: import click success = {}