1
0
mirror of https://github.com/esphome/esphome.git synced 2025-10-16 16:53:47 +01:00

Merge branch 'integration' of https://github.com/esphome/esphome into integration

This commit is contained in:
J. Nick Koston
2025-10-15 09:44:42 -10:00
8 changed files with 268 additions and 48 deletions

View File

@@ -866,7 +866,7 @@ message ListEntitiesServicesResponse {
string name = 1; string name = 1;
fixed32 key = 2; fixed32 key = 2;
repeated ListEntitiesServicesArgument args = 3; repeated ListEntitiesServicesArgument args = 3 [(fixed_vector) = true];
} }
message ExecuteServiceArgument { message ExecuteServiceArgument {
option (ifdef) = "USE_API_SERVICES"; option (ifdef) = "USE_API_SERVICES";

View File

@@ -1263,7 +1263,7 @@ class ListEntitiesServicesResponse final : public ProtoMessage {
StringRef name_ref_{}; StringRef name_ref_{};
void set_name(const StringRef &ref) { this->name_ref_ = ref; } void set_name(const StringRef &ref) { this->name_ref_ = ref; }
uint32_t key{0}; uint32_t key{0};
std::vector<ListEntitiesServicesArgument> args{}; FixedVector<ListEntitiesServicesArgument> args{};
void encode(ProtoWriteBuffer buffer) const override; void encode(ProtoWriteBuffer buffer) const override;
void calculate_size(ProtoSize &size) const override; void calculate_size(ProtoSize &size) const override;
#ifdef HAS_PROTO_MESSAGE_DUMP #ifdef HAS_PROTO_MESSAGE_DUMP

View File

@@ -35,9 +35,9 @@ template<typename... Ts> class UserServiceBase : public UserServiceDescriptor {
msg.set_name(StringRef(this->name_)); msg.set_name(StringRef(this->name_));
msg.key = this->key_; msg.key = this->key_;
std::array<enums::ServiceArgType, sizeof...(Ts)> arg_types = {to_service_arg_type<Ts>()...}; std::array<enums::ServiceArgType, sizeof...(Ts)> arg_types = {to_service_arg_type<Ts>()...};
msg.args.init(sizeof...(Ts));
for (size_t i = 0; i < sizeof...(Ts); i++) { for (size_t i = 0; i < sizeof...(Ts); i++) {
msg.args.emplace_back(); auto &arg = msg.args.emplace_back();
auto &arg = msg.args.back();
arg.type = arg_types[i]; arg.type = arg_types[i];
arg.set_name(StringRef(this->arg_names_[i])); arg.set_name(StringRef(this->arg_names_[i]));
} }

View File

@@ -108,8 +108,13 @@ class BTLoggers(Enum):
"""ESP32 WiFi provisioning over Bluetooth""" """ESP32 WiFi provisioning over Bluetooth"""
# Set to track which loggers are needed by components # Key for storing required loggers in CORE.data
_required_loggers: set[BTLoggers] = set() ESP32_BLE_REQUIRED_LOGGERS_KEY = "esp32_ble_required_loggers"
def _get_required_loggers() -> set[BTLoggers]:
"""Get the set of required Bluetooth loggers from CORE.data."""
return CORE.data.setdefault(ESP32_BLE_REQUIRED_LOGGERS_KEY, set())
# Dataclass for handler registration counts # Dataclass for handler registration counts
@@ -170,12 +175,13 @@ def register_bt_logger(*loggers: BTLoggers) -> None:
Args: Args:
*loggers: One or more BTLoggers enum members *loggers: One or more BTLoggers enum members
""" """
required_loggers = _get_required_loggers()
for logger in loggers: for logger in loggers:
if not isinstance(logger, BTLoggers): if not isinstance(logger, BTLoggers):
raise TypeError( raise TypeError(
f"Logger must be a BTLoggers enum member, got {type(logger)}" f"Logger must be a BTLoggers enum member, got {type(logger)}"
) )
_required_loggers.add(logger) required_loggers.add(logger)
CONF_BLE_ID = "ble_id" CONF_BLE_ID = "ble_id"
@@ -488,8 +494,9 @@ async def to_code(config):
# Apply logger settings if log disabling is enabled # Apply logger settings if log disabling is enabled
if config.get(CONF_DISABLE_BT_LOGS, False): if config.get(CONF_DISABLE_BT_LOGS, False):
# Disable all Bluetooth loggers that are not required # Disable all Bluetooth loggers that are not required
required_loggers = _get_required_loggers()
for logger in BTLoggers: for logger in BTLoggers:
if logger not in _required_loggers: if logger not in required_loggers:
add_idf_sdkconfig_option(f"{logger.value}_NONE", True) add_idf_sdkconfig_option(f"{logger.value}_NONE", True)
# Set BLE connection establishment timeout to match aioesphomeapi/bleak-retry-connector # Set BLE connection establishment timeout to match aioesphomeapi/bleak-retry-connector

View File

@@ -60,11 +60,21 @@ class RegistrationCounts:
clients: int = 0 clients: int = 0
# Set to track which features are needed by components # CORE.data keys for state management
_required_features: set[BLEFeatures] = set() ESP32_BLE_TRACKER_REQUIRED_FEATURES_KEY = "esp32_ble_tracker_required_features"
ESP32_BLE_TRACKER_REGISTRATION_COUNTS_KEY = "esp32_ble_tracker_registration_counts"
# Track registration counts for StaticVector sizing
_registration_counts = RegistrationCounts() def _get_required_features() -> set[BLEFeatures]:
"""Get the set of required BLE features from CORE.data."""
return CORE.data.setdefault(ESP32_BLE_TRACKER_REQUIRED_FEATURES_KEY, set())
def _get_registration_counts() -> RegistrationCounts:
"""Get the registration counts from CORE.data."""
return CORE.data.setdefault(
ESP32_BLE_TRACKER_REGISTRATION_COUNTS_KEY, RegistrationCounts()
)
def register_ble_features(features: set[BLEFeatures]) -> None: def register_ble_features(features: set[BLEFeatures]) -> None:
@@ -73,7 +83,7 @@ def register_ble_features(features: set[BLEFeatures]) -> None:
Args: Args:
features: Set of BLEFeatures enum members features: Set of BLEFeatures enum members
""" """
_required_features.update(features) _get_required_features().update(features)
esp32_ble_tracker_ns = cg.esphome_ns.namespace("esp32_ble_tracker") esp32_ble_tracker_ns = cg.esphome_ns.namespace("esp32_ble_tracker")
@@ -267,15 +277,17 @@ async def to_code(config):
): ):
register_ble_features({BLEFeatures.ESP_BT_DEVICE}) register_ble_features({BLEFeatures.ESP_BT_DEVICE})
registration_counts = _get_registration_counts()
for conf in config.get(CONF_ON_BLE_ADVERTISE, []): for conf in config.get(CONF_ON_BLE_ADVERTISE, []):
_registration_counts.listeners += 1 registration_counts.listeners += 1
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
if CONF_MAC_ADDRESS in conf: if CONF_MAC_ADDRESS in conf:
addr_list = [it.as_hex for it in conf[CONF_MAC_ADDRESS]] addr_list = [it.as_hex for it in conf[CONF_MAC_ADDRESS]]
cg.add(trigger.set_addresses(addr_list)) cg.add(trigger.set_addresses(addr_list))
await automation.build_automation(trigger, [(ESPBTDeviceConstRef, "x")], conf) await automation.build_automation(trigger, [(ESPBTDeviceConstRef, "x")], conf)
for conf in config.get(CONF_ON_BLE_SERVICE_DATA_ADVERTISE, []): for conf in config.get(CONF_ON_BLE_SERVICE_DATA_ADVERTISE, []):
_registration_counts.listeners += 1 registration_counts.listeners += 1
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
if len(conf[CONF_SERVICE_UUID]) == len(bt_uuid16_format): if len(conf[CONF_SERVICE_UUID]) == len(bt_uuid16_format):
cg.add(trigger.set_service_uuid16(as_hex(conf[CONF_SERVICE_UUID]))) cg.add(trigger.set_service_uuid16(as_hex(conf[CONF_SERVICE_UUID])))
@@ -288,7 +300,7 @@ async def to_code(config):
cg.add(trigger.set_address(conf[CONF_MAC_ADDRESS].as_hex)) cg.add(trigger.set_address(conf[CONF_MAC_ADDRESS].as_hex))
await automation.build_automation(trigger, [(adv_data_t_const_ref, "x")], conf) await automation.build_automation(trigger, [(adv_data_t_const_ref, "x")], conf)
for conf in config.get(CONF_ON_BLE_MANUFACTURER_DATA_ADVERTISE, []): for conf in config.get(CONF_ON_BLE_MANUFACTURER_DATA_ADVERTISE, []):
_registration_counts.listeners += 1 registration_counts.listeners += 1
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
if len(conf[CONF_MANUFACTURER_ID]) == len(bt_uuid16_format): if len(conf[CONF_MANUFACTURER_ID]) == len(bt_uuid16_format):
cg.add(trigger.set_manufacturer_uuid16(as_hex(conf[CONF_MANUFACTURER_ID]))) cg.add(trigger.set_manufacturer_uuid16(as_hex(conf[CONF_MANUFACTURER_ID])))
@@ -301,7 +313,7 @@ async def to_code(config):
cg.add(trigger.set_address(conf[CONF_MAC_ADDRESS].as_hex)) cg.add(trigger.set_address(conf[CONF_MAC_ADDRESS].as_hex))
await automation.build_automation(trigger, [(adv_data_t_const_ref, "x")], conf) await automation.build_automation(trigger, [(adv_data_t_const_ref, "x")], conf)
for conf in config.get(CONF_ON_SCAN_END, []): for conf in config.get(CONF_ON_SCAN_END, []):
_registration_counts.listeners += 1 registration_counts.listeners += 1
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [], conf) await automation.build_automation(trigger, [], conf)
@@ -331,19 +343,21 @@ async def to_code(config):
@coroutine_with_priority(CoroPriority.FINAL) @coroutine_with_priority(CoroPriority.FINAL)
async def _add_ble_features(): async def _add_ble_features():
# Add feature-specific defines based on what's needed # Add feature-specific defines based on what's needed
if BLEFeatures.ESP_BT_DEVICE in _required_features: required_features = _get_required_features()
if BLEFeatures.ESP_BT_DEVICE in required_features:
cg.add_define("USE_ESP32_BLE_DEVICE") cg.add_define("USE_ESP32_BLE_DEVICE")
cg.add_define("USE_ESP32_BLE_UUID") cg.add_define("USE_ESP32_BLE_UUID")
# Add defines for StaticVector sizing based on registration counts # Add defines for StaticVector sizing based on registration counts
# Only define if count > 0 to avoid allocating unnecessary memory # Only define if count > 0 to avoid allocating unnecessary memory
if _registration_counts.listeners > 0: registration_counts = _get_registration_counts()
if registration_counts.listeners > 0:
cg.add_define( cg.add_define(
"ESPHOME_ESP32_BLE_TRACKER_LISTENER_COUNT", _registration_counts.listeners "ESPHOME_ESP32_BLE_TRACKER_LISTENER_COUNT", registration_counts.listeners
) )
if _registration_counts.clients > 0: if registration_counts.clients > 0:
cg.add_define( cg.add_define(
"ESPHOME_ESP32_BLE_TRACKER_CLIENT_COUNT", _registration_counts.clients "ESPHOME_ESP32_BLE_TRACKER_CLIENT_COUNT", registration_counts.clients
) )
@@ -395,7 +409,7 @@ async def register_ble_device(
var: cg.SafeExpType, config: ConfigType var: cg.SafeExpType, config: ConfigType
) -> cg.SafeExpType: ) -> cg.SafeExpType:
register_ble_features({BLEFeatures.ESP_BT_DEVICE}) register_ble_features({BLEFeatures.ESP_BT_DEVICE})
_registration_counts.listeners += 1 _get_registration_counts().listeners += 1
paren = await cg.get_variable(config[CONF_ESP32_BLE_ID]) paren = await cg.get_variable(config[CONF_ESP32_BLE_ID])
cg.add(paren.register_listener(var)) cg.add(paren.register_listener(var))
return var return var
@@ -403,7 +417,7 @@ async def register_ble_device(
async def register_client(var: cg.SafeExpType, config: ConfigType) -> cg.SafeExpType: async def register_client(var: cg.SafeExpType, config: ConfigType) -> cg.SafeExpType:
register_ble_features({BLEFeatures.ESP_BT_DEVICE}) register_ble_features({BLEFeatures.ESP_BT_DEVICE})
_registration_counts.clients += 1 _get_registration_counts().clients += 1
paren = await cg.get_variable(config[CONF_ESP32_BLE_ID]) paren = await cg.get_variable(config[CONF_ESP32_BLE_ID])
cg.add(paren.register_client(var)) cg.add(paren.register_client(var))
return var return var
@@ -417,7 +431,7 @@ async def register_raw_ble_device(
This does NOT register the ESP_BT_DEVICE feature, meaning ESPBTDevice This does NOT register the ESP_BT_DEVICE feature, meaning ESPBTDevice
will not be compiled in if this is the only registration method used. will not be compiled in if this is the only registration method used.
""" """
_registration_counts.listeners += 1 _get_registration_counts().listeners += 1
paren = await cg.get_variable(config[CONF_ESP32_BLE_ID]) paren = await cg.get_variable(config[CONF_ESP32_BLE_ID])
cg.add(paren.register_listener(var)) cg.add(paren.register_listener(var))
return var return var
@@ -431,7 +445,7 @@ async def register_raw_client(
This does NOT register the ESP_BT_DEVICE feature, meaning ESPBTDevice This does NOT register the ESP_BT_DEVICE feature, meaning ESPBTDevice
will not be compiled in if this is the only registration method used. will not be compiled in if this is the only registration method used.
""" """
_registration_counts.clients += 1 _get_registration_counts().clients += 1
paren = await cg.get_variable(config[CONF_ESP32_BLE_ID]) paren = await cg.get_variable(config[CONF_ESP32_BLE_ID])
cg.add(paren.register_client(var)) cg.add(paren.register_client(var))
return var return var

View File

@@ -143,7 +143,18 @@ def validate_mclk_divisible_by_3(config):
return config return config
_use_legacy_driver = None # Key for storing legacy driver setting in CORE.data
I2S_USE_LEGACY_DRIVER_KEY = "i2s_use_legacy_driver"
def _get_use_legacy_driver():
"""Get the legacy driver setting from CORE.data."""
return CORE.data.get(I2S_USE_LEGACY_DRIVER_KEY)
def _set_use_legacy_driver(value: bool) -> None:
"""Set the legacy driver setting in CORE.data."""
CORE.data[I2S_USE_LEGACY_DRIVER_KEY] = value
def i2s_audio_component_schema( def i2s_audio_component_schema(
@@ -209,17 +220,15 @@ async def register_i2s_audio_component(var, config):
def validate_use_legacy(value): def validate_use_legacy(value):
global _use_legacy_driver # noqa: PLW0603
if CONF_USE_LEGACY in value: if CONF_USE_LEGACY in value:
if (_use_legacy_driver is not None) and ( existing_value = _get_use_legacy_driver()
_use_legacy_driver != value[CONF_USE_LEGACY] if (existing_value is not None) and (existing_value != value[CONF_USE_LEGACY]):
):
raise cv.Invalid( raise cv.Invalid(
f"All i2s_audio components must set {CONF_USE_LEGACY} to the same value." f"All i2s_audio components must set {CONF_USE_LEGACY} to the same value."
) )
if (not value[CONF_USE_LEGACY]) and (CORE.using_arduino): if (not value[CONF_USE_LEGACY]) and (CORE.using_arduino):
raise cv.Invalid("Arduino supports only the legacy i2s driver") raise cv.Invalid("Arduino supports only the legacy i2s driver")
_use_legacy_driver = value[CONF_USE_LEGACY] _set_use_legacy_driver(value[CONF_USE_LEGACY])
return value return value
@@ -249,7 +258,8 @@ def _final_validate(_):
def use_legacy(): def use_legacy():
return not (CORE.using_esp_idf and not _use_legacy_driver) legacy_driver = _get_use_legacy_driver()
return not (CORE.using_esp_idf and not legacy_driver)
FINAL_VALIDATE_SCHEMA = _final_validate FINAL_VALIDATE_SCHEMA = _final_validate

View File

@@ -5,6 +5,7 @@ import hashlib
import logging import logging
from pathlib import Path from pathlib import Path
import re import re
import shutil
import subprocess import subprocess
import urllib.parse import urllib.parse
@@ -55,6 +56,7 @@ def clone_or_update(
username: str = None, username: str = None,
password: str = None, password: str = None,
submodules: list[str] | None = None, submodules: list[str] | None = None,
_recover_broken: bool = True,
) -> tuple[Path, Callable[[], None] | None]: ) -> tuple[Path, Callable[[], None] | None]:
key = f"{url}@{ref}" key = f"{url}@{ref}"
@@ -80,7 +82,7 @@ def clone_or_update(
if submodules is not None: if submodules is not None:
_LOGGER.info( _LOGGER.info(
"Initialising submodules (%s) for %s", ", ".join(submodules), key "Initializing submodules (%s) for %s", ", ".join(submodules), key
) )
run_git_command( run_git_command(
["git", "submodule", "update", "--init"] + submodules, str(repo_dir) ["git", "submodule", "update", "--init"] + submodules, str(repo_dir)
@@ -99,20 +101,47 @@ def clone_or_update(
file_timestamp = Path(repo_dir / ".git" / "HEAD") file_timestamp = Path(repo_dir / ".git" / "HEAD")
age = datetime.now() - datetime.fromtimestamp(file_timestamp.stat().st_mtime) age = datetime.now() - datetime.fromtimestamp(file_timestamp.stat().st_mtime)
if refresh is None or age.total_seconds() > refresh.total_seconds: if refresh is None or age.total_seconds() > refresh.total_seconds:
old_sha = run_git_command(["git", "rev-parse", "HEAD"], str(repo_dir)) # Try to update the repository, recovering from broken state if needed
_LOGGER.info("Updating %s", key) old_sha: str | None = None
_LOGGER.debug("Location: %s", repo_dir) try:
# Stash local changes (if any) old_sha = run_git_command(["git", "rev-parse", "HEAD"], str(repo_dir))
run_git_command( _LOGGER.info("Updating %s", key)
["git", "stash", "push", "--include-untracked"], str(repo_dir) _LOGGER.debug("Location: %s", repo_dir)
) # Stash local changes (if any)
# Fetch remote ref run_git_command(
cmd = ["git", "fetch", "--", "origin"] ["git", "stash", "push", "--include-untracked"], str(repo_dir)
if ref is not None: )
cmd.append(ref) # Fetch remote ref
run_git_command(cmd, str(repo_dir)) cmd = ["git", "fetch", "--", "origin"]
# Hard reset to FETCH_HEAD (short-lived git ref corresponding to most recent fetch) if ref is not None:
run_git_command(["git", "reset", "--hard", "FETCH_HEAD"], str(repo_dir)) cmd.append(ref)
run_git_command(cmd, str(repo_dir))
# Hard reset to FETCH_HEAD (short-lived git ref corresponding to most recent fetch)
run_git_command(["git", "reset", "--hard", "FETCH_HEAD"], str(repo_dir))
except cv.Invalid as err:
# Repository is in a broken state or update failed
# Only attempt recovery once to prevent infinite recursion
if not _recover_broken:
raise
_LOGGER.warning(
"Repository %s has issues (%s), removing and re-cloning",
key,
err,
)
shutil.rmtree(repo_dir)
# Recursively call clone_or_update to re-clone
# Set _recover_broken=False to prevent infinite recursion
return clone_or_update(
url=url,
ref=ref,
refresh=refresh,
domain=domain,
username=username,
password=password,
submodules=submodules,
_recover_broken=False,
)
if submodules is not None: if submodules is not None:
_LOGGER.info( _LOGGER.info(

View File

@@ -6,7 +6,10 @@ import os
from pathlib import Path from pathlib import Path
from unittest.mock import Mock from unittest.mock import Mock
import pytest
from esphome import git from esphome import git
import esphome.config_validation as cv
from esphome.core import CORE, TimePeriodSeconds from esphome.core import CORE, TimePeriodSeconds
@@ -244,3 +247,160 @@ def test_clone_or_update_with_none_refresh_always_updates(
if len(call[0]) > 0 and "fetch" in call[0][0] if len(call[0]) > 0 and "fetch" in call[0][0]
] ]
assert len(fetch_calls) > 0 assert len(fetch_calls) > 0
@pytest.mark.parametrize(
("fail_command", "error_message"),
[
(
"rev-parse",
"ambiguous argument 'HEAD': unknown revision or path not in the working tree.",
),
("stash", "fatal: unable to write new index file"),
(
"fetch",
"fatal: unable to access 'https://github.com/test/repo/': Could not resolve host",
),
("reset", "fatal: Could not reset index file to revision 'FETCH_HEAD'"),
],
)
def test_clone_or_update_recovers_from_git_failures(
tmp_path: Path, mock_run_git_command: Mock, fail_command: str, error_message: str
) -> None:
"""Test that repos are re-cloned when various git commands fail."""
# Set up CORE.config_path so data_dir uses tmp_path
CORE.config_path = tmp_path / "test.yaml"
url = "https://github.com/test/repo"
ref = "main"
key = f"{url}@{ref}"
domain = "test"
h = hashlib.new("sha256")
h.update(key.encode())
repo_dir = tmp_path / ".esphome" / domain / h.hexdigest()[:8]
# Create repo directory
repo_dir.mkdir(parents=True)
git_dir = repo_dir / ".git"
git_dir.mkdir()
fetch_head = git_dir / "FETCH_HEAD"
fetch_head.write_text("test")
old_time = datetime.now() - timedelta(days=2)
fetch_head.touch()
os.utime(fetch_head, (old_time.timestamp(), old_time.timestamp()))
# Track command call counts to make first call fail, subsequent calls succeed
call_counts: dict[str, int] = {}
def git_command_side_effect(cmd: list[str], cwd: str | None = None) -> str:
# Determine which command this is
cmd_type = None
if "rev-parse" in cmd:
cmd_type = "rev-parse"
elif "stash" in cmd:
cmd_type = "stash"
elif "fetch" in cmd:
cmd_type = "fetch"
elif "reset" in cmd:
cmd_type = "reset"
elif "clone" in cmd:
cmd_type = "clone"
# Track call count for this command type
if cmd_type:
call_counts[cmd_type] = call_counts.get(cmd_type, 0) + 1
# Fail on first call to the specified command, succeed on subsequent calls
if cmd_type == fail_command and call_counts[cmd_type] == 1:
raise cv.Invalid(error_message)
# Default successful responses
if cmd_type == "rev-parse":
return "abc123"
return ""
mock_run_git_command.side_effect = git_command_side_effect
refresh = TimePeriodSeconds(days=1)
result_dir, revert = git.clone_or_update(
url=url,
ref=ref,
refresh=refresh,
domain=domain,
)
# Verify recovery happened
call_list = mock_run_git_command.call_args_list
# Should have attempted the failing command
assert any(fail_command in str(c) for c in call_list)
# Should have called clone for recovery
assert any("clone" in str(c) for c in call_list)
# Verify the repo directory path is returned
assert result_dir == repo_dir
def test_clone_or_update_fails_when_recovery_also_fails(
tmp_path: Path, mock_run_git_command: Mock
) -> None:
"""Test that we don't infinitely recurse when recovery also fails."""
# Set up CORE.config_path so data_dir uses tmp_path
CORE.config_path = tmp_path / "test.yaml"
url = "https://github.com/test/repo"
ref = "main"
key = f"{url}@{ref}"
domain = "test"
h = hashlib.new("sha256")
h.update(key.encode())
repo_dir = tmp_path / ".esphome" / domain / h.hexdigest()[:8]
# Create repo directory
repo_dir.mkdir(parents=True)
git_dir = repo_dir / ".git"
git_dir.mkdir()
fetch_head = git_dir / "FETCH_HEAD"
fetch_head.write_text("test")
old_time = datetime.now() - timedelta(days=2)
fetch_head.touch()
os.utime(fetch_head, (old_time.timestamp(), old_time.timestamp()))
# Mock git command to fail on clone (simulating network failure during recovery)
def git_command_side_effect(cmd: list[str], cwd: str | None = None) -> str:
if "rev-parse" in cmd:
# First time fails (broken repo)
raise cv.Invalid(
"ambiguous argument 'HEAD': unknown revision or path not in the working tree."
)
if "clone" in cmd:
# Clone also fails (recovery fails)
raise cv.Invalid("fatal: unable to access repository")
return ""
mock_run_git_command.side_effect = git_command_side_effect
refresh = TimePeriodSeconds(days=1)
# Should raise after one recovery attempt fails
with pytest.raises(cv.Invalid, match="fatal: unable to access repository"):
git.clone_or_update(
url=url,
ref=ref,
refresh=refresh,
domain=domain,
)
# Verify we only tried to clone once (no infinite recursion)
call_list = mock_run_git_command.call_args_list
clone_calls = [c for c in call_list if "clone" in c[0][0]]
# Should have exactly one clone call (the recovery attempt that failed)
assert len(clone_calls) == 1
# Should have tried rev-parse once (which failed and triggered recovery)
rev_parse_calls = [c for c in call_list if "rev-parse" in c[0][0]]
assert len(rev_parse_calls) == 1