mirror of
https://github.com/esphome/esphome.git
synced 2025-10-16 16:53:47 +01:00
Merge branch 'integration' of https://github.com/esphome/esphome into integration
This commit is contained in:
@@ -866,7 +866,7 @@ message ListEntitiesServicesResponse {
|
|||||||
|
|
||||||
string name = 1;
|
string name = 1;
|
||||||
fixed32 key = 2;
|
fixed32 key = 2;
|
||||||
repeated ListEntitiesServicesArgument args = 3;
|
repeated ListEntitiesServicesArgument args = 3 [(fixed_vector) = true];
|
||||||
}
|
}
|
||||||
message ExecuteServiceArgument {
|
message ExecuteServiceArgument {
|
||||||
option (ifdef) = "USE_API_SERVICES";
|
option (ifdef) = "USE_API_SERVICES";
|
||||||
|
@@ -1263,7 +1263,7 @@ class ListEntitiesServicesResponse final : public ProtoMessage {
|
|||||||
StringRef name_ref_{};
|
StringRef name_ref_{};
|
||||||
void set_name(const StringRef &ref) { this->name_ref_ = ref; }
|
void set_name(const StringRef &ref) { this->name_ref_ = ref; }
|
||||||
uint32_t key{0};
|
uint32_t key{0};
|
||||||
std::vector<ListEntitiesServicesArgument> args{};
|
FixedVector<ListEntitiesServicesArgument> args{};
|
||||||
void encode(ProtoWriteBuffer buffer) const override;
|
void encode(ProtoWriteBuffer buffer) const override;
|
||||||
void calculate_size(ProtoSize &size) const override;
|
void calculate_size(ProtoSize &size) const override;
|
||||||
#ifdef HAS_PROTO_MESSAGE_DUMP
|
#ifdef HAS_PROTO_MESSAGE_DUMP
|
||||||
|
@@ -35,9 +35,9 @@ template<typename... Ts> class UserServiceBase : public UserServiceDescriptor {
|
|||||||
msg.set_name(StringRef(this->name_));
|
msg.set_name(StringRef(this->name_));
|
||||||
msg.key = this->key_;
|
msg.key = this->key_;
|
||||||
std::array<enums::ServiceArgType, sizeof...(Ts)> arg_types = {to_service_arg_type<Ts>()...};
|
std::array<enums::ServiceArgType, sizeof...(Ts)> arg_types = {to_service_arg_type<Ts>()...};
|
||||||
|
msg.args.init(sizeof...(Ts));
|
||||||
for (size_t i = 0; i < sizeof...(Ts); i++) {
|
for (size_t i = 0; i < sizeof...(Ts); i++) {
|
||||||
msg.args.emplace_back();
|
auto &arg = msg.args.emplace_back();
|
||||||
auto &arg = msg.args.back();
|
|
||||||
arg.type = arg_types[i];
|
arg.type = arg_types[i];
|
||||||
arg.set_name(StringRef(this->arg_names_[i]));
|
arg.set_name(StringRef(this->arg_names_[i]));
|
||||||
}
|
}
|
||||||
|
@@ -108,8 +108,13 @@ class BTLoggers(Enum):
|
|||||||
"""ESP32 WiFi provisioning over Bluetooth"""
|
"""ESP32 WiFi provisioning over Bluetooth"""
|
||||||
|
|
||||||
|
|
||||||
# Set to track which loggers are needed by components
|
# Key for storing required loggers in CORE.data
|
||||||
_required_loggers: set[BTLoggers] = set()
|
ESP32_BLE_REQUIRED_LOGGERS_KEY = "esp32_ble_required_loggers"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_required_loggers() -> set[BTLoggers]:
|
||||||
|
"""Get the set of required Bluetooth loggers from CORE.data."""
|
||||||
|
return CORE.data.setdefault(ESP32_BLE_REQUIRED_LOGGERS_KEY, set())
|
||||||
|
|
||||||
|
|
||||||
# Dataclass for handler registration counts
|
# Dataclass for handler registration counts
|
||||||
@@ -170,12 +175,13 @@ def register_bt_logger(*loggers: BTLoggers) -> None:
|
|||||||
Args:
|
Args:
|
||||||
*loggers: One or more BTLoggers enum members
|
*loggers: One or more BTLoggers enum members
|
||||||
"""
|
"""
|
||||||
|
required_loggers = _get_required_loggers()
|
||||||
for logger in loggers:
|
for logger in loggers:
|
||||||
if not isinstance(logger, BTLoggers):
|
if not isinstance(logger, BTLoggers):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Logger must be a BTLoggers enum member, got {type(logger)}"
|
f"Logger must be a BTLoggers enum member, got {type(logger)}"
|
||||||
)
|
)
|
||||||
_required_loggers.add(logger)
|
required_loggers.add(logger)
|
||||||
|
|
||||||
|
|
||||||
CONF_BLE_ID = "ble_id"
|
CONF_BLE_ID = "ble_id"
|
||||||
@@ -488,8 +494,9 @@ async def to_code(config):
|
|||||||
# Apply logger settings if log disabling is enabled
|
# Apply logger settings if log disabling is enabled
|
||||||
if config.get(CONF_DISABLE_BT_LOGS, False):
|
if config.get(CONF_DISABLE_BT_LOGS, False):
|
||||||
# Disable all Bluetooth loggers that are not required
|
# Disable all Bluetooth loggers that are not required
|
||||||
|
required_loggers = _get_required_loggers()
|
||||||
for logger in BTLoggers:
|
for logger in BTLoggers:
|
||||||
if logger not in _required_loggers:
|
if logger not in required_loggers:
|
||||||
add_idf_sdkconfig_option(f"{logger.value}_NONE", True)
|
add_idf_sdkconfig_option(f"{logger.value}_NONE", True)
|
||||||
|
|
||||||
# Set BLE connection establishment timeout to match aioesphomeapi/bleak-retry-connector
|
# Set BLE connection establishment timeout to match aioesphomeapi/bleak-retry-connector
|
||||||
|
@@ -60,11 +60,21 @@ class RegistrationCounts:
|
|||||||
clients: int = 0
|
clients: int = 0
|
||||||
|
|
||||||
|
|
||||||
# Set to track which features are needed by components
|
# CORE.data keys for state management
|
||||||
_required_features: set[BLEFeatures] = set()
|
ESP32_BLE_TRACKER_REQUIRED_FEATURES_KEY = "esp32_ble_tracker_required_features"
|
||||||
|
ESP32_BLE_TRACKER_REGISTRATION_COUNTS_KEY = "esp32_ble_tracker_registration_counts"
|
||||||
|
|
||||||
# Track registration counts for StaticVector sizing
|
|
||||||
_registration_counts = RegistrationCounts()
|
def _get_required_features() -> set[BLEFeatures]:
|
||||||
|
"""Get the set of required BLE features from CORE.data."""
|
||||||
|
return CORE.data.setdefault(ESP32_BLE_TRACKER_REQUIRED_FEATURES_KEY, set())
|
||||||
|
|
||||||
|
|
||||||
|
def _get_registration_counts() -> RegistrationCounts:
|
||||||
|
"""Get the registration counts from CORE.data."""
|
||||||
|
return CORE.data.setdefault(
|
||||||
|
ESP32_BLE_TRACKER_REGISTRATION_COUNTS_KEY, RegistrationCounts()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_ble_features(features: set[BLEFeatures]) -> None:
|
def register_ble_features(features: set[BLEFeatures]) -> None:
|
||||||
@@ -73,7 +83,7 @@ def register_ble_features(features: set[BLEFeatures]) -> None:
|
|||||||
Args:
|
Args:
|
||||||
features: Set of BLEFeatures enum members
|
features: Set of BLEFeatures enum members
|
||||||
"""
|
"""
|
||||||
_required_features.update(features)
|
_get_required_features().update(features)
|
||||||
|
|
||||||
|
|
||||||
esp32_ble_tracker_ns = cg.esphome_ns.namespace("esp32_ble_tracker")
|
esp32_ble_tracker_ns = cg.esphome_ns.namespace("esp32_ble_tracker")
|
||||||
@@ -267,15 +277,17 @@ async def to_code(config):
|
|||||||
):
|
):
|
||||||
register_ble_features({BLEFeatures.ESP_BT_DEVICE})
|
register_ble_features({BLEFeatures.ESP_BT_DEVICE})
|
||||||
|
|
||||||
|
registration_counts = _get_registration_counts()
|
||||||
|
|
||||||
for conf in config.get(CONF_ON_BLE_ADVERTISE, []):
|
for conf in config.get(CONF_ON_BLE_ADVERTISE, []):
|
||||||
_registration_counts.listeners += 1
|
registration_counts.listeners += 1
|
||||||
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
|
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
|
||||||
if CONF_MAC_ADDRESS in conf:
|
if CONF_MAC_ADDRESS in conf:
|
||||||
addr_list = [it.as_hex for it in conf[CONF_MAC_ADDRESS]]
|
addr_list = [it.as_hex for it in conf[CONF_MAC_ADDRESS]]
|
||||||
cg.add(trigger.set_addresses(addr_list))
|
cg.add(trigger.set_addresses(addr_list))
|
||||||
await automation.build_automation(trigger, [(ESPBTDeviceConstRef, "x")], conf)
|
await automation.build_automation(trigger, [(ESPBTDeviceConstRef, "x")], conf)
|
||||||
for conf in config.get(CONF_ON_BLE_SERVICE_DATA_ADVERTISE, []):
|
for conf in config.get(CONF_ON_BLE_SERVICE_DATA_ADVERTISE, []):
|
||||||
_registration_counts.listeners += 1
|
registration_counts.listeners += 1
|
||||||
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
|
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
|
||||||
if len(conf[CONF_SERVICE_UUID]) == len(bt_uuid16_format):
|
if len(conf[CONF_SERVICE_UUID]) == len(bt_uuid16_format):
|
||||||
cg.add(trigger.set_service_uuid16(as_hex(conf[CONF_SERVICE_UUID])))
|
cg.add(trigger.set_service_uuid16(as_hex(conf[CONF_SERVICE_UUID])))
|
||||||
@@ -288,7 +300,7 @@ async def to_code(config):
|
|||||||
cg.add(trigger.set_address(conf[CONF_MAC_ADDRESS].as_hex))
|
cg.add(trigger.set_address(conf[CONF_MAC_ADDRESS].as_hex))
|
||||||
await automation.build_automation(trigger, [(adv_data_t_const_ref, "x")], conf)
|
await automation.build_automation(trigger, [(adv_data_t_const_ref, "x")], conf)
|
||||||
for conf in config.get(CONF_ON_BLE_MANUFACTURER_DATA_ADVERTISE, []):
|
for conf in config.get(CONF_ON_BLE_MANUFACTURER_DATA_ADVERTISE, []):
|
||||||
_registration_counts.listeners += 1
|
registration_counts.listeners += 1
|
||||||
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
|
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
|
||||||
if len(conf[CONF_MANUFACTURER_ID]) == len(bt_uuid16_format):
|
if len(conf[CONF_MANUFACTURER_ID]) == len(bt_uuid16_format):
|
||||||
cg.add(trigger.set_manufacturer_uuid16(as_hex(conf[CONF_MANUFACTURER_ID])))
|
cg.add(trigger.set_manufacturer_uuid16(as_hex(conf[CONF_MANUFACTURER_ID])))
|
||||||
@@ -301,7 +313,7 @@ async def to_code(config):
|
|||||||
cg.add(trigger.set_address(conf[CONF_MAC_ADDRESS].as_hex))
|
cg.add(trigger.set_address(conf[CONF_MAC_ADDRESS].as_hex))
|
||||||
await automation.build_automation(trigger, [(adv_data_t_const_ref, "x")], conf)
|
await automation.build_automation(trigger, [(adv_data_t_const_ref, "x")], conf)
|
||||||
for conf in config.get(CONF_ON_SCAN_END, []):
|
for conf in config.get(CONF_ON_SCAN_END, []):
|
||||||
_registration_counts.listeners += 1
|
registration_counts.listeners += 1
|
||||||
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
|
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
|
||||||
await automation.build_automation(trigger, [], conf)
|
await automation.build_automation(trigger, [], conf)
|
||||||
|
|
||||||
@@ -331,19 +343,21 @@ async def to_code(config):
|
|||||||
@coroutine_with_priority(CoroPriority.FINAL)
|
@coroutine_with_priority(CoroPriority.FINAL)
|
||||||
async def _add_ble_features():
|
async def _add_ble_features():
|
||||||
# Add feature-specific defines based on what's needed
|
# Add feature-specific defines based on what's needed
|
||||||
if BLEFeatures.ESP_BT_DEVICE in _required_features:
|
required_features = _get_required_features()
|
||||||
|
if BLEFeatures.ESP_BT_DEVICE in required_features:
|
||||||
cg.add_define("USE_ESP32_BLE_DEVICE")
|
cg.add_define("USE_ESP32_BLE_DEVICE")
|
||||||
cg.add_define("USE_ESP32_BLE_UUID")
|
cg.add_define("USE_ESP32_BLE_UUID")
|
||||||
|
|
||||||
# Add defines for StaticVector sizing based on registration counts
|
# Add defines for StaticVector sizing based on registration counts
|
||||||
# Only define if count > 0 to avoid allocating unnecessary memory
|
# Only define if count > 0 to avoid allocating unnecessary memory
|
||||||
if _registration_counts.listeners > 0:
|
registration_counts = _get_registration_counts()
|
||||||
|
if registration_counts.listeners > 0:
|
||||||
cg.add_define(
|
cg.add_define(
|
||||||
"ESPHOME_ESP32_BLE_TRACKER_LISTENER_COUNT", _registration_counts.listeners
|
"ESPHOME_ESP32_BLE_TRACKER_LISTENER_COUNT", registration_counts.listeners
|
||||||
)
|
)
|
||||||
if _registration_counts.clients > 0:
|
if registration_counts.clients > 0:
|
||||||
cg.add_define(
|
cg.add_define(
|
||||||
"ESPHOME_ESP32_BLE_TRACKER_CLIENT_COUNT", _registration_counts.clients
|
"ESPHOME_ESP32_BLE_TRACKER_CLIENT_COUNT", registration_counts.clients
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -395,7 +409,7 @@ async def register_ble_device(
|
|||||||
var: cg.SafeExpType, config: ConfigType
|
var: cg.SafeExpType, config: ConfigType
|
||||||
) -> cg.SafeExpType:
|
) -> cg.SafeExpType:
|
||||||
register_ble_features({BLEFeatures.ESP_BT_DEVICE})
|
register_ble_features({BLEFeatures.ESP_BT_DEVICE})
|
||||||
_registration_counts.listeners += 1
|
_get_registration_counts().listeners += 1
|
||||||
paren = await cg.get_variable(config[CONF_ESP32_BLE_ID])
|
paren = await cg.get_variable(config[CONF_ESP32_BLE_ID])
|
||||||
cg.add(paren.register_listener(var))
|
cg.add(paren.register_listener(var))
|
||||||
return var
|
return var
|
||||||
@@ -403,7 +417,7 @@ async def register_ble_device(
|
|||||||
|
|
||||||
async def register_client(var: cg.SafeExpType, config: ConfigType) -> cg.SafeExpType:
|
async def register_client(var: cg.SafeExpType, config: ConfigType) -> cg.SafeExpType:
|
||||||
register_ble_features({BLEFeatures.ESP_BT_DEVICE})
|
register_ble_features({BLEFeatures.ESP_BT_DEVICE})
|
||||||
_registration_counts.clients += 1
|
_get_registration_counts().clients += 1
|
||||||
paren = await cg.get_variable(config[CONF_ESP32_BLE_ID])
|
paren = await cg.get_variable(config[CONF_ESP32_BLE_ID])
|
||||||
cg.add(paren.register_client(var))
|
cg.add(paren.register_client(var))
|
||||||
return var
|
return var
|
||||||
@@ -417,7 +431,7 @@ async def register_raw_ble_device(
|
|||||||
This does NOT register the ESP_BT_DEVICE feature, meaning ESPBTDevice
|
This does NOT register the ESP_BT_DEVICE feature, meaning ESPBTDevice
|
||||||
will not be compiled in if this is the only registration method used.
|
will not be compiled in if this is the only registration method used.
|
||||||
"""
|
"""
|
||||||
_registration_counts.listeners += 1
|
_get_registration_counts().listeners += 1
|
||||||
paren = await cg.get_variable(config[CONF_ESP32_BLE_ID])
|
paren = await cg.get_variable(config[CONF_ESP32_BLE_ID])
|
||||||
cg.add(paren.register_listener(var))
|
cg.add(paren.register_listener(var))
|
||||||
return var
|
return var
|
||||||
@@ -431,7 +445,7 @@ async def register_raw_client(
|
|||||||
This does NOT register the ESP_BT_DEVICE feature, meaning ESPBTDevice
|
This does NOT register the ESP_BT_DEVICE feature, meaning ESPBTDevice
|
||||||
will not be compiled in if this is the only registration method used.
|
will not be compiled in if this is the only registration method used.
|
||||||
"""
|
"""
|
||||||
_registration_counts.clients += 1
|
_get_registration_counts().clients += 1
|
||||||
paren = await cg.get_variable(config[CONF_ESP32_BLE_ID])
|
paren = await cg.get_variable(config[CONF_ESP32_BLE_ID])
|
||||||
cg.add(paren.register_client(var))
|
cg.add(paren.register_client(var))
|
||||||
return var
|
return var
|
||||||
|
@@ -143,7 +143,18 @@ def validate_mclk_divisible_by_3(config):
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
_use_legacy_driver = None
|
# Key for storing legacy driver setting in CORE.data
|
||||||
|
I2S_USE_LEGACY_DRIVER_KEY = "i2s_use_legacy_driver"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_use_legacy_driver():
|
||||||
|
"""Get the legacy driver setting from CORE.data."""
|
||||||
|
return CORE.data.get(I2S_USE_LEGACY_DRIVER_KEY)
|
||||||
|
|
||||||
|
|
||||||
|
def _set_use_legacy_driver(value: bool) -> None:
|
||||||
|
"""Set the legacy driver setting in CORE.data."""
|
||||||
|
CORE.data[I2S_USE_LEGACY_DRIVER_KEY] = value
|
||||||
|
|
||||||
|
|
||||||
def i2s_audio_component_schema(
|
def i2s_audio_component_schema(
|
||||||
@@ -209,17 +220,15 @@ async def register_i2s_audio_component(var, config):
|
|||||||
|
|
||||||
|
|
||||||
def validate_use_legacy(value):
|
def validate_use_legacy(value):
|
||||||
global _use_legacy_driver # noqa: PLW0603
|
|
||||||
if CONF_USE_LEGACY in value:
|
if CONF_USE_LEGACY in value:
|
||||||
if (_use_legacy_driver is not None) and (
|
existing_value = _get_use_legacy_driver()
|
||||||
_use_legacy_driver != value[CONF_USE_LEGACY]
|
if (existing_value is not None) and (existing_value != value[CONF_USE_LEGACY]):
|
||||||
):
|
|
||||||
raise cv.Invalid(
|
raise cv.Invalid(
|
||||||
f"All i2s_audio components must set {CONF_USE_LEGACY} to the same value."
|
f"All i2s_audio components must set {CONF_USE_LEGACY} to the same value."
|
||||||
)
|
)
|
||||||
if (not value[CONF_USE_LEGACY]) and (CORE.using_arduino):
|
if (not value[CONF_USE_LEGACY]) and (CORE.using_arduino):
|
||||||
raise cv.Invalid("Arduino supports only the legacy i2s driver")
|
raise cv.Invalid("Arduino supports only the legacy i2s driver")
|
||||||
_use_legacy_driver = value[CONF_USE_LEGACY]
|
_set_use_legacy_driver(value[CONF_USE_LEGACY])
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
@@ -249,7 +258,8 @@ def _final_validate(_):
|
|||||||
|
|
||||||
|
|
||||||
def use_legacy():
|
def use_legacy():
|
||||||
return not (CORE.using_esp_idf and not _use_legacy_driver)
|
legacy_driver = _get_use_legacy_driver()
|
||||||
|
return not (CORE.using_esp_idf and not legacy_driver)
|
||||||
|
|
||||||
|
|
||||||
FINAL_VALIDATE_SCHEMA = _final_validate
|
FINAL_VALIDATE_SCHEMA = _final_validate
|
||||||
|
@@ -5,6 +5,7 @@ import hashlib
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import re
|
import re
|
||||||
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
|
||||||
@@ -55,6 +56,7 @@ def clone_or_update(
|
|||||||
username: str = None,
|
username: str = None,
|
||||||
password: str = None,
|
password: str = None,
|
||||||
submodules: list[str] | None = None,
|
submodules: list[str] | None = None,
|
||||||
|
_recover_broken: bool = True,
|
||||||
) -> tuple[Path, Callable[[], None] | None]:
|
) -> tuple[Path, Callable[[], None] | None]:
|
||||||
key = f"{url}@{ref}"
|
key = f"{url}@{ref}"
|
||||||
|
|
||||||
@@ -80,7 +82,7 @@ def clone_or_update(
|
|||||||
|
|
||||||
if submodules is not None:
|
if submodules is not None:
|
||||||
_LOGGER.info(
|
_LOGGER.info(
|
||||||
"Initialising submodules (%s) for %s", ", ".join(submodules), key
|
"Initializing submodules (%s) for %s", ", ".join(submodules), key
|
||||||
)
|
)
|
||||||
run_git_command(
|
run_git_command(
|
||||||
["git", "submodule", "update", "--init"] + submodules, str(repo_dir)
|
["git", "submodule", "update", "--init"] + submodules, str(repo_dir)
|
||||||
@@ -99,6 +101,9 @@ def clone_or_update(
|
|||||||
file_timestamp = Path(repo_dir / ".git" / "HEAD")
|
file_timestamp = Path(repo_dir / ".git" / "HEAD")
|
||||||
age = datetime.now() - datetime.fromtimestamp(file_timestamp.stat().st_mtime)
|
age = datetime.now() - datetime.fromtimestamp(file_timestamp.stat().st_mtime)
|
||||||
if refresh is None or age.total_seconds() > refresh.total_seconds:
|
if refresh is None or age.total_seconds() > refresh.total_seconds:
|
||||||
|
# Try to update the repository, recovering from broken state if needed
|
||||||
|
old_sha: str | None = None
|
||||||
|
try:
|
||||||
old_sha = run_git_command(["git", "rev-parse", "HEAD"], str(repo_dir))
|
old_sha = run_git_command(["git", "rev-parse", "HEAD"], str(repo_dir))
|
||||||
_LOGGER.info("Updating %s", key)
|
_LOGGER.info("Updating %s", key)
|
||||||
_LOGGER.debug("Location: %s", repo_dir)
|
_LOGGER.debug("Location: %s", repo_dir)
|
||||||
@@ -113,6 +118,30 @@ def clone_or_update(
|
|||||||
run_git_command(cmd, str(repo_dir))
|
run_git_command(cmd, str(repo_dir))
|
||||||
# Hard reset to FETCH_HEAD (short-lived git ref corresponding to most recent fetch)
|
# Hard reset to FETCH_HEAD (short-lived git ref corresponding to most recent fetch)
|
||||||
run_git_command(["git", "reset", "--hard", "FETCH_HEAD"], str(repo_dir))
|
run_git_command(["git", "reset", "--hard", "FETCH_HEAD"], str(repo_dir))
|
||||||
|
except cv.Invalid as err:
|
||||||
|
# Repository is in a broken state or update failed
|
||||||
|
# Only attempt recovery once to prevent infinite recursion
|
||||||
|
if not _recover_broken:
|
||||||
|
raise
|
||||||
|
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Repository %s has issues (%s), removing and re-cloning",
|
||||||
|
key,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
shutil.rmtree(repo_dir)
|
||||||
|
# Recursively call clone_or_update to re-clone
|
||||||
|
# Set _recover_broken=False to prevent infinite recursion
|
||||||
|
return clone_or_update(
|
||||||
|
url=url,
|
||||||
|
ref=ref,
|
||||||
|
refresh=refresh,
|
||||||
|
domain=domain,
|
||||||
|
username=username,
|
||||||
|
password=password,
|
||||||
|
submodules=submodules,
|
||||||
|
_recover_broken=False,
|
||||||
|
)
|
||||||
|
|
||||||
if submodules is not None:
|
if submodules is not None:
|
||||||
_LOGGER.info(
|
_LOGGER.info(
|
||||||
|
@@ -6,7 +6,10 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from esphome import git
|
from esphome import git
|
||||||
|
import esphome.config_validation as cv
|
||||||
from esphome.core import CORE, TimePeriodSeconds
|
from esphome.core import CORE, TimePeriodSeconds
|
||||||
|
|
||||||
|
|
||||||
@@ -244,3 +247,160 @@ def test_clone_or_update_with_none_refresh_always_updates(
|
|||||||
if len(call[0]) > 0 and "fetch" in call[0][0]
|
if len(call[0]) > 0 and "fetch" in call[0][0]
|
||||||
]
|
]
|
||||||
assert len(fetch_calls) > 0
|
assert len(fetch_calls) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("fail_command", "error_message"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"rev-parse",
|
||||||
|
"ambiguous argument 'HEAD': unknown revision or path not in the working tree.",
|
||||||
|
),
|
||||||
|
("stash", "fatal: unable to write new index file"),
|
||||||
|
(
|
||||||
|
"fetch",
|
||||||
|
"fatal: unable to access 'https://github.com/test/repo/': Could not resolve host",
|
||||||
|
),
|
||||||
|
("reset", "fatal: Could not reset index file to revision 'FETCH_HEAD'"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_clone_or_update_recovers_from_git_failures(
|
||||||
|
tmp_path: Path, mock_run_git_command: Mock, fail_command: str, error_message: str
|
||||||
|
) -> None:
|
||||||
|
"""Test that repos are re-cloned when various git commands fail."""
|
||||||
|
# Set up CORE.config_path so data_dir uses tmp_path
|
||||||
|
CORE.config_path = tmp_path / "test.yaml"
|
||||||
|
|
||||||
|
url = "https://github.com/test/repo"
|
||||||
|
ref = "main"
|
||||||
|
key = f"{url}@{ref}"
|
||||||
|
domain = "test"
|
||||||
|
|
||||||
|
h = hashlib.new("sha256")
|
||||||
|
h.update(key.encode())
|
||||||
|
repo_dir = tmp_path / ".esphome" / domain / h.hexdigest()[:8]
|
||||||
|
|
||||||
|
# Create repo directory
|
||||||
|
repo_dir.mkdir(parents=True)
|
||||||
|
git_dir = repo_dir / ".git"
|
||||||
|
git_dir.mkdir()
|
||||||
|
|
||||||
|
fetch_head = git_dir / "FETCH_HEAD"
|
||||||
|
fetch_head.write_text("test")
|
||||||
|
old_time = datetime.now() - timedelta(days=2)
|
||||||
|
fetch_head.touch()
|
||||||
|
os.utime(fetch_head, (old_time.timestamp(), old_time.timestamp()))
|
||||||
|
|
||||||
|
# Track command call counts to make first call fail, subsequent calls succeed
|
||||||
|
call_counts: dict[str, int] = {}
|
||||||
|
|
||||||
|
def git_command_side_effect(cmd: list[str], cwd: str | None = None) -> str:
|
||||||
|
# Determine which command this is
|
||||||
|
cmd_type = None
|
||||||
|
if "rev-parse" in cmd:
|
||||||
|
cmd_type = "rev-parse"
|
||||||
|
elif "stash" in cmd:
|
||||||
|
cmd_type = "stash"
|
||||||
|
elif "fetch" in cmd:
|
||||||
|
cmd_type = "fetch"
|
||||||
|
elif "reset" in cmd:
|
||||||
|
cmd_type = "reset"
|
||||||
|
elif "clone" in cmd:
|
||||||
|
cmd_type = "clone"
|
||||||
|
|
||||||
|
# Track call count for this command type
|
||||||
|
if cmd_type:
|
||||||
|
call_counts[cmd_type] = call_counts.get(cmd_type, 0) + 1
|
||||||
|
|
||||||
|
# Fail on first call to the specified command, succeed on subsequent calls
|
||||||
|
if cmd_type == fail_command and call_counts[cmd_type] == 1:
|
||||||
|
raise cv.Invalid(error_message)
|
||||||
|
|
||||||
|
# Default successful responses
|
||||||
|
if cmd_type == "rev-parse":
|
||||||
|
return "abc123"
|
||||||
|
return ""
|
||||||
|
|
||||||
|
mock_run_git_command.side_effect = git_command_side_effect
|
||||||
|
|
||||||
|
refresh = TimePeriodSeconds(days=1)
|
||||||
|
result_dir, revert = git.clone_or_update(
|
||||||
|
url=url,
|
||||||
|
ref=ref,
|
||||||
|
refresh=refresh,
|
||||||
|
domain=domain,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify recovery happened
|
||||||
|
call_list = mock_run_git_command.call_args_list
|
||||||
|
|
||||||
|
# Should have attempted the failing command
|
||||||
|
assert any(fail_command in str(c) for c in call_list)
|
||||||
|
|
||||||
|
# Should have called clone for recovery
|
||||||
|
assert any("clone" in str(c) for c in call_list)
|
||||||
|
|
||||||
|
# Verify the repo directory path is returned
|
||||||
|
assert result_dir == repo_dir
|
||||||
|
|
||||||
|
|
||||||
|
def test_clone_or_update_fails_when_recovery_also_fails(
|
||||||
|
tmp_path: Path, mock_run_git_command: Mock
|
||||||
|
) -> None:
|
||||||
|
"""Test that we don't infinitely recurse when recovery also fails."""
|
||||||
|
# Set up CORE.config_path so data_dir uses tmp_path
|
||||||
|
CORE.config_path = tmp_path / "test.yaml"
|
||||||
|
|
||||||
|
url = "https://github.com/test/repo"
|
||||||
|
ref = "main"
|
||||||
|
key = f"{url}@{ref}"
|
||||||
|
domain = "test"
|
||||||
|
|
||||||
|
h = hashlib.new("sha256")
|
||||||
|
h.update(key.encode())
|
||||||
|
repo_dir = tmp_path / ".esphome" / domain / h.hexdigest()[:8]
|
||||||
|
|
||||||
|
# Create repo directory
|
||||||
|
repo_dir.mkdir(parents=True)
|
||||||
|
git_dir = repo_dir / ".git"
|
||||||
|
git_dir.mkdir()
|
||||||
|
|
||||||
|
fetch_head = git_dir / "FETCH_HEAD"
|
||||||
|
fetch_head.write_text("test")
|
||||||
|
old_time = datetime.now() - timedelta(days=2)
|
||||||
|
fetch_head.touch()
|
||||||
|
os.utime(fetch_head, (old_time.timestamp(), old_time.timestamp()))
|
||||||
|
|
||||||
|
# Mock git command to fail on clone (simulating network failure during recovery)
|
||||||
|
def git_command_side_effect(cmd: list[str], cwd: str | None = None) -> str:
|
||||||
|
if "rev-parse" in cmd:
|
||||||
|
# First time fails (broken repo)
|
||||||
|
raise cv.Invalid(
|
||||||
|
"ambiguous argument 'HEAD': unknown revision or path not in the working tree."
|
||||||
|
)
|
||||||
|
if "clone" in cmd:
|
||||||
|
# Clone also fails (recovery fails)
|
||||||
|
raise cv.Invalid("fatal: unable to access repository")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
mock_run_git_command.side_effect = git_command_side_effect
|
||||||
|
|
||||||
|
refresh = TimePeriodSeconds(days=1)
|
||||||
|
|
||||||
|
# Should raise after one recovery attempt fails
|
||||||
|
with pytest.raises(cv.Invalid, match="fatal: unable to access repository"):
|
||||||
|
git.clone_or_update(
|
||||||
|
url=url,
|
||||||
|
ref=ref,
|
||||||
|
refresh=refresh,
|
||||||
|
domain=domain,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify we only tried to clone once (no infinite recursion)
|
||||||
|
call_list = mock_run_git_command.call_args_list
|
||||||
|
clone_calls = [c for c in call_list if "clone" in c[0][0]]
|
||||||
|
# Should have exactly one clone call (the recovery attempt that failed)
|
||||||
|
assert len(clone_calls) == 1
|
||||||
|
# Should have tried rev-parse once (which failed and triggered recovery)
|
||||||
|
rev_parse_calls = [c for c in call_list if "rev-parse" in c[0][0]]
|
||||||
|
assert len(rev_parse_calls) == 1
|
||||||
|
Reference in New Issue
Block a user