From 77a30830c753660706fc6c6fe3b5e895cc9f8195 Mon Sep 17 00:00:00 2001
From: oarcher <olivier.archer@gmail.com>
Date: Wed, 14 Aug 2024 22:18:54 +0200
Subject: [PATCH] fix schemas

---
 esphome/components/modem/__init__.py          | 23 +++++--
 esphome/components/modem/sensor/__init__.py   | 40 ++++++------
 esphome/components/modem/switch/__init__.py   | 62 +++++++++++++------
 .../components/modem/text_sensor/__init__.py  |  6 +-
 4 files changed, 85 insertions(+), 46 deletions(-)

diff --git a/esphome/components/modem/__init__.py b/esphome/components/modem/__init__.py
index ff591e6bfe..bb72d6a0bc 100644
--- a/esphome/components/modem/__init__.py
+++ b/esphome/components/modem/__init__.py
@@ -32,6 +32,7 @@ AUTO_LOAD = ["network", "watchdog"]
 CONFLICTS_WITH = ["captive_portal", "ethernet"]
 
 CONF_MODEM = "modem"
+CONF_MODEM_ID = "modem_id"
 CONF_PIN_CODE = "pin_code"
 CONF_APN = "apn"
 CONF_DTR_PIN = "dtr_pin"
@@ -67,10 +68,14 @@ ModemOnDisconnectTrigger = modem_ns.class_(
     "ModemOnDisconnectTrigger", automation.Trigger.template()
 )
 
+MODEM_COMPONENT_SCHEMA = cv.Schema(
+    {cv.GenerateID(CONF_MODEM_ID): cv.use_id(ModemComponent)}
+)
+
 CONFIG_SCHEMA = cv.All(
     cv.Schema(
         {
-            cv.GenerateID(): cv.declare_id(ModemComponent),
+            cv.GenerateID(CONF_ID): cv.declare_id(ModemComponent),
             cv.Required(CONF_TX_PIN): pins.internal_gpio_output_pin_schema,
             cv.Required(CONF_RX_PIN): pins.internal_gpio_output_pin_schema,
             cv.Required(CONF_MODEL): cv.one_of(*MODEM_MODELS, upper=True),
@@ -101,7 +106,9 @@ CONFIG_SCHEMA = cv.All(
                 }
             ),
         }
-    ).extend(cv.COMPONENT_SCHEMA),
+    )
+    .extend(MODEM_COMPONENT_SCHEMA)
+    .extend(cv.COMPONENT_SCHEMA),
     cv.require_framework_version(
         esp_idf=cv.Version(4, 0, 0),  # 5.2.0 OK
     ),
@@ -109,10 +116,14 @@ CONFIG_SCHEMA = cv.All(
 
 
 def final_validate_platform(config):
-    if not fv.full_config.get().data.get(KEY_MODEM_CMUX, None):
-        raise cv.Invalid(
-            f"'{CONF_MODEM}' platform require '{CONF_ENABLE_CMUX}' to be 'true'."
-        )
+    # to be called by platform components
+    if modem_config := fv.full_config.get().get(CONF_MODEM, None):
+        if not modem_config.get(CONF_ENABLE_CMUX, None):
+            raise cv.Invalid(
+                f"'{CONF_MODEM}' platform require '{CONF_ENABLE_CMUX}' to be 'true'."
+            )
+    else:
+        raise cv.Invalid("'{CONF_MODEM}' component required.")
     return config
 
 
diff --git a/esphome/components/modem/sensor/__init__.py b/esphome/components/modem/sensor/__init__.py
index 06da7a1f4e..00d1a027d7 100644
--- a/esphome/components/modem/sensor/__init__.py
+++ b/esphome/components/modem/sensor/__init__.py
@@ -18,9 +18,9 @@ from esphome.const import (
     UNIT_METER,
     UNIT_PERCENT,
 )
-import esphome.final_validate as fv
 
-from .. import final_validate_platform, modem_ns, switch
+from .. import MODEM_COMPONENT_SCHEMA, final_validate_platform, modem_ns
+from ..switch import GNSS_SWITCH_SCHEMA
 
 CODEOWNERS = ["@oarcher"]
 
@@ -45,6 +45,15 @@ ICON_SIGNAL_BAR = "mdi:signal"
 ModemSensor = modem_ns.class_("ModemSensor", cg.PollingComponent)
 
 
+GNSS_SENSORS = {
+    CONF_LATITUDE,
+    CONF_LONGITUDE,
+    CONF_ALTITUDE,
+    CONF_COURSE,
+    CONF_ACCURACY,
+    CONF_SPEED,
+}
+
 CONFIG_SCHEMA = cv.All(
     cv.Schema(
         {
@@ -69,51 +78,44 @@ CONFIG_SCHEMA = cv.All(
                 accuracy_decimals=5,
                 icon=ICON_LATITUDE,
                 state_class=STATE_CLASS_MEASUREMENT,
-            ),
+            ).extend(GNSS_SWITCH_SCHEMA),
             cv.Optional(CONF_LONGITUDE): sensor.sensor_schema(
                 unit_of_measurement=UNIT_DEGREES,
                 accuracy_decimals=5,
                 icon=ICON_LONGITUDE,
                 state_class=STATE_CLASS_MEASUREMENT,
-            ),
+            ).extend(GNSS_SWITCH_SCHEMA),
             cv.Optional(CONF_ALTITUDE): sensor.sensor_schema(
                 unit_of_measurement=UNIT_METER,
                 accuracy_decimals=1,
                 icon=ICON_LOCATION_UP,
                 state_class=STATE_CLASS_MEASUREMENT,
-            ),
+            ).extend(GNSS_SWITCH_SCHEMA),
             cv.Optional(CONF_SPEED): sensor.sensor_schema(
                 unit_of_measurement=UNIT_KILOMETER_PER_HOUR,
                 accuracy_decimals=1,
                 icon=ICON_SPEED,
                 state_class=STATE_CLASS_MEASUREMENT,
-            ),
+            ).extend(GNSS_SWITCH_SCHEMA),
             cv.Optional(CONF_ACCURACY): sensor.sensor_schema(
                 unit_of_measurement=UNIT_METER,
                 accuracy_decimals=1,
                 icon=ICON_LOCATION_RADIUS,
                 state_class=STATE_CLASS_MEASUREMENT,
-            ),
+            ).extend(GNSS_SWITCH_SCHEMA),
             cv.Optional(CONF_COURSE): sensor.sensor_schema(
                 unit_of_measurement=UNIT_DEGREES,
                 accuracy_decimals=1,
                 icon=ICON_COMPASS,
                 state_class=STATE_CLASS_MEASUREMENT,
-            ),
+            ).extend(GNSS_SWITCH_SCHEMA),
         }
-    ).extend(cv.polling_component_schema("60s"))
+    )
+    .extend(MODEM_COMPONENT_SCHEMA)
+    .extend(cv.polling_component_schema("60s")),
 )
 
-
-def _final_validate_gnss(config):
-    # GNSS sensors needs GNSS switch
-    if config.get(CONF_LATITUDE, None) or config.get(CONF_LONGITUDE, None):
-        if not fv.full_config.get().data.get(switch.KEY_MODEM_GNSS, None):
-            raise cv.Invalid("Using GNSS modem sensors require GNSS modem switch.")
-    return config
-
-
-FINAL_VALIDATE_SCHEMA = cv.All(final_validate_platform, _final_validate_gnss)
+FINAL_VALIDATE_SCHEMA = cv.All(final_validate_platform)
 
 
 async def to_code(config):
diff --git a/esphome/components/modem/switch/__init__.py b/esphome/components/modem/switch/__init__.py
index 268df91a33..7ca8d6ce61 100644
--- a/esphome/components/modem/switch/__init__.py
+++ b/esphome/components/modem/switch/__init__.py
@@ -4,7 +4,13 @@ import esphome.config_validation as cv
 from esphome.const import DEVICE_CLASS_SWITCH
 import esphome.final_validate as fv
 
-from .. import KEY_MODEM_MODEL, final_validate_platform, modem_ns
+from .. import (
+    CONF_MODEL,
+    CONF_MODEM_ID,
+    MODEM_COMPONENT_SCHEMA,
+    final_validate_platform,
+    modem_ns,
+)
 
 CODEOWNERS = ["@oarcher"]
 
@@ -17,8 +23,6 @@ IS_PLATFORM_COMPONENT = True
 CONF_GNSS = "gnss"
 CONF_GNSS_COMMAND = "gnss_command"  # will be set by _final_validate_gnss
 
-KEY_MODEM_GNSS = "modem_gnss"
-
 ICON_SATELLITE = "mdi:satellite-variant"
 
 GnssSwitch = modem_ns.class_("GnssSwitch", switch.Switch, cg.Component)
@@ -26,26 +30,46 @@ GnssSwitch = modem_ns.class_("GnssSwitch", switch.Switch, cg.Component)
 # SIM70xx doesn't support AT+CGNSSINFO, so gnss is not available
 MODEM_MODELS_GNSS_COMMAND = {"SIM7600": "AT+CGPS", "SIM7670": "AT+CGNSSPWR"}
 
-CONFIG_SCHEMA = cv.Schema(
-    {
-        cv.Optional(CONF_GNSS): switch.switch_schema(
-            GnssSwitch,
-            block_inverted=True,
-            device_class=DEVICE_CLASS_SWITCH,
-            icon=ICON_SATELLITE,
-        ),
-    }
-).extend(cv.COMPONENT_SCHEMA)
+
+CONF_GNSS_SWITCH_ID = "gnss_switch_id"
+
+GNSS_SWITCH_SCHEMA = cv.Schema(
+    {cv.GenerateID(CONF_GNSS_SWITCH_ID): cv.use_id(GnssSwitch)}
+)
+
+CONFIG_SCHEMA = (
+    cv.Schema(
+        {
+            cv.Optional(CONF_GNSS): switch.switch_schema(
+                GnssSwitch,
+                block_inverted=True,
+                device_class=DEVICE_CLASS_SWITCH,
+                icon=ICON_SATELLITE,
+            ),
+        }
+    )
+    .extend(GNSS_SWITCH_SCHEMA)
+    .extend(MODEM_COMPONENT_SCHEMA)
+    .extend(cv.COMPONENT_SCHEMA)
+)
 
 
 def _final_validate_gnss(config):
+    # get modem model from modem config, and add CONF_GNSS_COMMAND to config
     if config.get(CONF_GNSS, None):
-        full_config = fv.full_config.get()
-        modem_model = full_config.data.get(KEY_MODEM_MODEL, None)
-        if modem_model not in MODEM_MODELS_GNSS_COMMAND:
-            raise cv.Invalid(f"GNSS not supported for modem '{modem_model}'.")
-        config[CONF_GNSS_COMMAND] = MODEM_MODELS_GNSS_COMMAND[modem_model]
-        full_config.data[KEY_MODEM_GNSS] = True
+        fconf = fv.full_config.get()
+        modem_path = fconf.get_path_for_id(config[CONF_MODEM_ID])[:-1]
+        modem_config = fconf.get_config_for_path(modem_path)
+        if modem_model := modem_config.get(CONF_MODEL, None):
+            if modem_model not in MODEM_MODELS_GNSS_COMMAND:
+                raise cv.Invalid(
+                    f"GNSS not supported for modem '{modem_model}'. Supported models are %s",
+                    ", ".join(MODEM_MODELS_GNSS_COMMAND.keys()),
+                )
+
+            # is it allowed to add config option?
+            config[CONF_GNSS_COMMAND] = MODEM_MODELS_GNSS_COMMAND[modem_model]
+
     return config
 
 
diff --git a/esphome/components/modem/text_sensor/__init__.py b/esphome/components/modem/text_sensor/__init__.py
index 9a3a87dd0f..e466ad0550 100644
--- a/esphome/components/modem/text_sensor/__init__.py
+++ b/esphome/components/modem/text_sensor/__init__.py
@@ -3,7 +3,7 @@ from esphome.components import text_sensor
 import esphome.config_validation as cv
 from esphome.const import CONF_ID, DEVICE_CLASS_EMPTY
 
-from .. import final_validate_platform, modem_ns
+from .. import MODEM_COMPONENT_SCHEMA, final_validate_platform, modem_ns
 
 CODEOWNERS = ["@oarcher"]
 
@@ -25,7 +25,9 @@ CONFIG_SCHEMA = cv.All(
                 device_class=DEVICE_CLASS_EMPTY,
             ),
         }
-    ).extend(cv.polling_component_schema("60s"))
+    )
+    .extend(MODEM_COMPONENT_SCHEMA)
+    .extend(cv.polling_component_schema("60s"))
 )