mirror of
https://github.com/esphome/esphome.git
synced 2025-09-26 15:12:21 +01:00
Merge branch 'auth_connection_checks_dry' into integration
This commit is contained in:
@@ -2721,6 +2721,10 @@ static const char *const TAG = "api.service";
|
||||
hpp_protected = ""
|
||||
cpp += "\n"
|
||||
|
||||
# Build a mapping of message input types to their authentication requirements
|
||||
message_auth_map: dict[str, bool] = {}
|
||||
message_conn_map: dict[str, bool] = {}
|
||||
|
||||
m = serv.method[0]
|
||||
for m in serv.method:
|
||||
func = m.name
|
||||
@@ -2732,6 +2736,10 @@ static const char *const TAG = "api.service";
|
||||
needs_conn = get_opt(m, pb.needs_setup_connection, True)
|
||||
needs_auth = get_opt(m, pb.needs_authentication, True)
|
||||
|
||||
# Store authentication requirements for message types
|
||||
message_auth_map[inp] = needs_auth
|
||||
message_conn_map[inp] = needs_conn
|
||||
|
||||
ifdef = message_ifdef_map.get(inp, ifdefs.get(inp))
|
||||
|
||||
if ifdef is not None:
|
||||
@@ -2749,33 +2757,14 @@ static const char *const TAG = "api.service";
|
||||
|
||||
cpp += f"void {class_name}::{on_func}(const {inp} &msg) {{\n"
|
||||
|
||||
# Start with authentication/connection check if needed
|
||||
if needs_auth or needs_conn:
|
||||
# Determine which check to use
|
||||
if needs_auth:
|
||||
check_func = "this->check_authenticated_()"
|
||||
else:
|
||||
check_func = "this->check_connection_setup_()"
|
||||
|
||||
if is_void:
|
||||
# For void methods, just wrap with auth check
|
||||
body = f"if ({check_func}) {{\n"
|
||||
body += f" this->{func}(msg);\n"
|
||||
body += "}\n"
|
||||
else:
|
||||
# For non-void methods, combine auth check and send response check
|
||||
body = f"if ({check_func} && !this->send_{func}_response(msg)) {{\n"
|
||||
body += " this->on_fatal_error();\n"
|
||||
body += "}\n"
|
||||
# No authentication check here - it's done in read_message
|
||||
body = ""
|
||||
if is_void:
|
||||
body += f"this->{func}(msg);\n"
|
||||
else:
|
||||
# No auth check needed, just call the handler
|
||||
body = ""
|
||||
if is_void:
|
||||
body += f"this->{func}(msg);\n"
|
||||
else:
|
||||
body += f"if (!this->send_{func}_response(msg)) {{\n"
|
||||
body += " this->on_fatal_error();\n"
|
||||
body += "}\n"
|
||||
body += f"if (!this->send_{func}_response(msg)) {{\n"
|
||||
body += " this->on_fatal_error();\n"
|
||||
body += "}\n"
|
||||
|
||||
cpp += indent(body) + "\n" + "}\n"
|
||||
|
||||
@@ -2784,6 +2773,65 @@ static const char *const TAG = "api.service";
|
||||
hpp_protected += "#endif\n"
|
||||
cpp += "#endif\n"
|
||||
|
||||
# Generate optimized read_message with authentication checking
|
||||
# Categorize messages by their authentication requirements
|
||||
no_conn_ids: set[int] = set()
|
||||
conn_only_ids: set[int] = set()
|
||||
|
||||
for id_, (_, _, case_msg_name) in cases:
|
||||
if case_msg_name in message_auth_map:
|
||||
needs_auth = message_auth_map[case_msg_name]
|
||||
needs_conn = message_conn_map[case_msg_name]
|
||||
|
||||
if not needs_conn:
|
||||
no_conn_ids.add(id_)
|
||||
elif not needs_auth:
|
||||
conn_only_ids.add(id_)
|
||||
|
||||
# Generate override if we have messages that skip checks
|
||||
if no_conn_ids or conn_only_ids:
|
||||
# Helper to generate case statements with ifdefs
|
||||
def generate_cases(ids: set[int], comment: str) -> str:
|
||||
result = ""
|
||||
for id_ in sorted(ids):
|
||||
_, ifdef, msg_name = RECEIVE_CASES[id_]
|
||||
if ifdef:
|
||||
result += f"#ifdef {ifdef}\n"
|
||||
result += f" case {msg_name}::MESSAGE_TYPE: {comment}\n"
|
||||
if ifdef:
|
||||
result += "#endif\n"
|
||||
return result
|
||||
|
||||
hpp_protected += " void read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) override;\n"
|
||||
|
||||
cpp += f"\nvoid {class_name}::read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) {{\n"
|
||||
cpp += " // Check authentication/connection requirements for messages\n"
|
||||
cpp += " switch (msg_type) {\n"
|
||||
|
||||
# Messages that don't need any checks
|
||||
if no_conn_ids:
|
||||
cpp += generate_cases(no_conn_ids, "// No setup required")
|
||||
cpp += " break; // Skip all checks for these messages\n"
|
||||
|
||||
# Messages that only need connection setup
|
||||
if conn_only_ids:
|
||||
cpp += generate_cases(conn_only_ids, "// Connection setup only")
|
||||
cpp += " if (!this->check_connection_setup_()) {\n"
|
||||
cpp += " return; // Connection not setup\n"
|
||||
cpp += " }\n"
|
||||
cpp += " break;\n"
|
||||
|
||||
cpp += " default:\n"
|
||||
cpp += " // All other messages require authentication (which includes connection check)\n"
|
||||
cpp += " if (!this->check_authenticated_()) {\n"
|
||||
cpp += " return; // Authentication failed\n"
|
||||
cpp += " }\n"
|
||||
cpp += " break;\n"
|
||||
cpp += " }\n\n"
|
||||
cpp += " // Call base implementation to process the message\n"
|
||||
cpp += f" {class_name}Base::read_message(msg_size, msg_type, msg_data);\n"
|
||||
cpp += "}\n"
|
||||
|
||||
hpp += " protected:\n"
|
||||
hpp += hpp_protected
|
||||
hpp += "};\n"
|
||||
|
Reference in New Issue
Block a user