From d9cf91210ea121c7bd46c3b9c4d48ed99022a3f5 Mon Sep 17 00:00:00 2001 From: Otto Winter Date: Sun, 10 Feb 2019 16:57:34 +0100 Subject: [PATCH] Add local mDNS responder for .local (#386) * Add local mDNS responder * Fix * Use mDNS in dashboard status * Lint * Lint * Fix test * Remove hostname * Fix use enum * Lint --- docker/Dockerfile.hassio | 2 - esphomeyaml-edge/Dockerfile | 2 - esphomeyaml/api/client.py | 2 +- esphomeyaml/components/ethernet.py | 6 +- esphomeyaml/components/wifi.py | 6 +- esphomeyaml/dashboard/dashboard.py | 62 +-- esphomeyaml/helpers.py | 29 +- esphomeyaml/py_compat.py | 7 + esphomeyaml/zeroconf.py | 780 +++++++++++++++++++++++++++++ pylintrc | 1 + requirements.txt | 1 + setup.py | 1 + tests/test1.yaml | 1 - 13 files changed, 837 insertions(+), 63 deletions(-) create mode 100644 esphomeyaml/zeroconf.py diff --git a/docker/Dockerfile.hassio b/docker/Dockerfile.hassio index 45db490975..114af92ae7 100644 --- a/docker/Dockerfile.hassio +++ b/docker/Dockerfile.hassio @@ -24,8 +24,6 @@ RUN \ python-pil \ # Git for esphomelib downloads git \ - # Ping for dashboard online/offline status - iputils-ping \ # NGINX proxy nginx \ \ diff --git a/esphomeyaml-edge/Dockerfile b/esphomeyaml-edge/Dockerfile index a3db1d24ad..670b424735 100644 --- a/esphomeyaml-edge/Dockerfile +++ b/esphomeyaml-edge/Dockerfile @@ -22,8 +22,6 @@ RUN \ python-pil \ # Git for esphomelib downloads git \ - # Ping for dashboard online/offline status - iputils-ping \ # NGINX proxy nginx \ \ diff --git a/esphomeyaml/api/client.py b/esphomeyaml/api/client.py index e340e567ab..08ded6b2b9 100644 --- a/esphomeyaml/api/client.py +++ b/esphomeyaml/api/client.py @@ -454,7 +454,7 @@ def run_logs(config, address): wait_time = min(2**tries, 300) if not has_connects: - _LOGGER.warning(u"Initial connection failed. The ESP might not be connected" + _LOGGER.warning(u"Initial connection failed. The ESP might not be connected " u"to WiFi yet (%s). Re-Trying in %s seconds", error, wait_time) else: diff --git a/esphomeyaml/components/ethernet.py b/esphomeyaml/components/ethernet.py index 51ff308bbc..1284bc0aa7 100644 --- a/esphomeyaml/components/ethernet.py +++ b/esphomeyaml/components/ethernet.py @@ -3,7 +3,7 @@ import voluptuous as vol from esphomeyaml import pins from esphomeyaml.components import wifi import esphomeyaml.config_validation as cv -from esphomeyaml.const import CONF_DOMAIN, CONF_HOSTNAME, CONF_ID, CONF_MANUAL_IP, CONF_TYPE, \ +from esphomeyaml.const import CONF_DOMAIN, CONF_ID, CONF_MANUAL_IP, CONF_TYPE, \ ESP_PLATFORM_ESP32 from esphomeyaml.cpp_generator import Pvariable, add from esphomeyaml.cpp_helpers import gpio_output_pin_expression @@ -43,7 +43,6 @@ CONFIG_SCHEMA = vol.Schema({ vol.Optional(CONF_PHY_ADDR, default=0): vol.All(cv.int_, vol.Range(min=0, max=31)), vol.Optional(CONF_POWER_PIN): pins.gpio_output_pin_schema, vol.Optional(CONF_MANUAL_IP): wifi.STA_MANUAL_IP_SCHEMA, - vol.Optional(CONF_HOSTNAME): cv.hostname, vol.Optional(CONF_DOMAIN, default='.local'): cv.domain_name, }) @@ -63,9 +62,6 @@ def to_code(config): yield add(eth.set_power_pin(pin)) - if CONF_HOSTNAME in config: - add(eth.set_hostname(config[CONF_HOSTNAME])) - if CONF_MANUAL_IP in config: add(eth.set_manual_ip(wifi.manual_ip(config[CONF_MANUAL_IP]))) diff --git a/esphomeyaml/components/wifi.py b/esphomeyaml/components/wifi.py index 5067effb84..85a20060dc 100644 --- a/esphomeyaml/components/wifi.py +++ b/esphomeyaml/components/wifi.py @@ -111,6 +111,9 @@ CONFIG_SCHEMA = vol.All(vol.Schema({ vol.Optional(CONF_REBOOT_TIMEOUT): cv.positive_time_period_milliseconds, vol.Optional(CONF_POWER_SAVE_MODE): cv.one_of(*WIFI_POWER_SAVE_MODES, upper=True), vol.Optional(CONF_FAST_CONNECT): cv.boolean, + + vol.Optional(CONF_HOSTNAME): cv.invalid("The hostname option has been removed in 1.11.0, " + "now it's always the node name.") }), validate) @@ -166,9 +169,6 @@ def to_code(config): if CONF_AP in config: add(wifi.set_ap(wifi_network(config[CONF_AP], config.get(CONF_MANUAL_IP)))) - if CONF_HOSTNAME in config: - add(wifi.set_hostname(config[CONF_HOSTNAME])) - if CONF_REBOOT_TIMEOUT in config: add(wifi.set_reboot_timeout(config[CONF_REBOOT_TIMEOUT])) diff --git a/esphomeyaml/dashboard/dashboard.py b/esphomeyaml/dashboard/dashboard.py index e9acd001c1..2b6a0a9def 100644 --- a/esphomeyaml/dashboard/dashboard.py +++ b/esphomeyaml/dashboard/dashboard.py @@ -2,11 +2,9 @@ from __future__ import print_function import codecs -import collections import hmac import json import logging -import multiprocessing import os import subprocess import threading @@ -25,7 +23,7 @@ import tornado.websocket from esphomeyaml import const from esphomeyaml.__main__ import get_serial_ports -from esphomeyaml.helpers import mkdir_p, run_system_command +from esphomeyaml.helpers import mkdir_p from esphomeyaml.py_compat import IS_PY2 from esphomeyaml.storage_json import EsphomeyamlStorageJSON, StorageJSON, \ esphomeyaml_storage_path, ext_storage_path @@ -34,6 +32,8 @@ from esphomeyaml.util import shlex_quote # pylint: disable=unused-import, wrong-import-order from typing import Optional # noqa +from esphomeyaml.zeroconf import Zeroconf, DashboardStatus + _LOGGER = logging.getLogger(__name__) CONFIG_DIR = '' PASSWORD_DIGEST = '' @@ -315,55 +315,25 @@ class MainRequestHandler(BaseHandler): get_static_file_url=get_static_file_url) -def _ping_func(filename, address): - if os.name == 'nt': - command = ['ping', '-n', '1', address] - else: - command = ['ping', '-c', '1', address] - rc, _, _ = run_system_command(*command) - return filename, rc == 0 - - class PingThread(threading.Thread): def run(self): - pool = multiprocessing.Pool(processes=8) + zc = Zeroconf() + def on_update(dat): + for key, b in dat.items(): + PING_RESULT[key] = b + + stat = DashboardStatus(zc, on_update) + stat.start() while not STOP_EVENT.is_set(): - # Only do pings if somebody has the dashboard open + entries = _list_dashboard_entries() + stat.request_query({entry.filename: entry.name + '.local.' for entry in entries}) + PING_REQUEST.wait() PING_REQUEST.clear() - - def callback(ret): - PING_RESULT[ret[0]] = ret[1] - - entries = _list_dashboard_entries() - queue = collections.deque() - for entry in entries: - if entry.address is None: - PING_RESULT[entry.filename] = None - continue - - result = pool.apply_async(_ping_func, (entry.filename, entry.address), - callback=callback) - queue.append(result) - - while queue: - item = queue[0] - if item.ready(): - queue.popleft() - continue - - try: - item.get(0.1) - except OSError: - # ping not installed - pass - except multiprocessing.TimeoutError: - pass - - if STOP_EVENT.is_set(): - pool.terminate() - return + stat.stop() + stat.join() + zc.close() class PingRequestHandler(BaseHandler): diff --git a/esphomeyaml/helpers.py b/esphomeyaml/helpers.py index b13b9768e1..0697fc820d 100644 --- a/esphomeyaml/helpers.py +++ b/esphomeyaml/helpers.py @@ -7,6 +7,7 @@ import socket import subprocess from esphomeyaml.py_compat import text_type, char_to_byte +from esphomeyaml.zeroconf import Zeroconf _LOGGER = logging.getLogger(__name__) @@ -100,12 +101,34 @@ def is_ip_address(host): return False +def _resolve_with_zeroconf(host): + from esphomeyaml.core import EsphomeyamlError + try: + zc = Zeroconf() + except Exception: + raise EsphomeyamlError("Cannot start mDNS sockets, is this a docker container without " + "host network mode?") + try: + info = zc.resolve_host(host + '.') + except Exception as err: + raise EsphomeyamlError("Error resolving mDNS hostname: {}".format(err)) + finally: + zc.close() + if info is None: + raise EsphomeyamlError("Error resolving address with mDNS: Did not respond. " + "Maybe the device is offline.") + return info + + def resolve_ip_address(host): + from esphomeyaml.core import EsphomeyamlError + try: ip = socket.gethostbyname(host) except socket.error as err: - from esphomeyaml.core import EsphomeyamlError - - raise EsphomeyamlError("Error resolving IP address: {}".format(err)) + if host.endswith('.local'): + ip = _resolve_with_zeroconf(host) + else: + raise EsphomeyamlError("Error resolving IP address: {}".format(err)) return ip diff --git a/esphomeyaml/py_compat.py b/esphomeyaml/py_compat.py index 5147749efa..16a4d27ebf 100644 --- a/esphomeyaml/py_compat.py +++ b/esphomeyaml/py_compat.py @@ -62,3 +62,10 @@ def sort_by_cmp(list_, cmp): list_.sort(cmp=cmp) else: list_.sort(key=functools.cmp_to_key(cmp)) + + +def indexbytes(buf, i): + if IS_PY3: + return buf[i] + else: + return ord(buf[i]) diff --git a/esphomeyaml/zeroconf.py b/esphomeyaml/zeroconf.py new file mode 100644 index 0000000000..67ade2a474 --- /dev/null +++ b/esphomeyaml/zeroconf.py @@ -0,0 +1,780 @@ +# Custom zeroconf implementation based on python-zeroconf +# (https://github.com/jstasiak/python-zeroconf) that supports Python 2 + +import errno +import logging +import select +import socket +import struct +import sys +import threading +import time + +import ifaddr + +from esphomeyaml.py_compat import indexbytes, text_type + +log = logging.getLogger(__name__) + +# Some timing constants + +_LISTENER_TIME = 200 + +# Some DNS constants + +_MDNS_ADDR = '224.0.0.251' +_MDNS_PORT = 5353 + +_MAX_MSG_ABSOLUTE = 8966 + +_FLAGS_QR_MASK = 0x8000 # query response mask +_FLAGS_QR_QUERY = 0x0000 # query +_FLAGS_QR_RESPONSE = 0x8000 # response + +_FLAGS_AA = 0x0400 # Authoritative answer +_FLAGS_TC = 0x0200 # Truncated +_FLAGS_RD = 0x0100 # Recursion desired +_FLAGS_RA = 0x8000 # Recursion available + +_FLAGS_Z = 0x0040 # Zero +_FLAGS_AD = 0x0020 # Authentic data +_FLAGS_CD = 0x0010 # Checking disabled + +_CLASS_IN = 1 +_CLASS_CS = 2 +_CLASS_CH = 3 +_CLASS_HS = 4 +_CLASS_NONE = 254 +_CLASS_ANY = 255 +_CLASS_MASK = 0x7FFF +_CLASS_UNIQUE = 0x8000 + +_TYPE_A = 1 +_TYPE_NS = 2 +_TYPE_MD = 3 +_TYPE_MF = 4 +_TYPE_CNAME = 5 +_TYPE_SOA = 6 +_TYPE_MB = 7 +_TYPE_MG = 8 +_TYPE_MR = 9 +_TYPE_NULL = 10 +_TYPE_WKS = 11 +_TYPE_PTR = 12 +_TYPE_HINFO = 13 +_TYPE_MINFO = 14 +_TYPE_MX = 15 +_TYPE_TXT = 16 +_TYPE_AAAA = 28 +_TYPE_SRV = 33 +_TYPE_ANY = 255 + +# Mapping constants to names +int2byte = struct.Struct(">B").pack + + +# Exceptions +class Error(Exception): + pass + + +class IncomingDecodeError(Error): + pass + + +# pylint: disable=no-init +class QuietLogger(object): + _seen_logs = {} + + @classmethod + def log_exception_warning(cls, logger_data=None): + exc_info = sys.exc_info() + exc_str = str(exc_info[1]) + if exc_str not in cls._seen_logs: + # log at warning level the first time this is seen + cls._seen_logs[exc_str] = exc_info + logger = log.warning + else: + logger = log.debug + if logger_data is not None: + logger(*logger_data) + logger('Exception occurred:', exc_info=True) + + @classmethod + def log_warning_once(cls, *args): + msg_str = args[0] + if msg_str not in cls._seen_logs: + cls._seen_logs[msg_str] = 0 + logger = log.warning + else: + logger = log.debug + cls._seen_logs[msg_str] += 1 + logger(*args) + + +class DNSEntry(object): + """A DNS entry""" + + def __init__(self, name, type_, class_): + self.key = name.lower() + self.name = name + self.type = type_ + self.class_ = class_ & _CLASS_MASK + self.unique = (class_ & _CLASS_UNIQUE) != 0 + + +class DNSQuestion(DNSEntry): + """A DNS question entry""" + + def __init__(self, name, type_, class_): + DNSEntry.__init__(self, name, type_, class_) + + def answered_by(self, rec): + """Returns true if the question is answered by the record""" + return (self.class_ == rec.class_ and + (self.type == rec.type or self.type == _TYPE_ANY) and + self.name == rec.name) + + +class DNSRecord(DNSEntry): + """A DNS record - like a DNS entry, but has a TTL""" + + def __init__(self, name, type_, class_, ttl): + DNSEntry.__init__(self, name, type_, class_) + self.ttl = 15 + self.created = time.time() + + def write(self, out): + """Abstract method""" + raise NotImplementedError + + def is_expired(self, now): + return self.created + self.ttl <= now + + def is_removable(self, now): + return self.created + self.ttl * 2 <= now + + +class DNSAddress(DNSRecord): + """A DNS address record""" + + def __init__(self, name, type_, class_, ttl, address): + DNSRecord.__init__(self, name, type_, class_, ttl) + self.address = address + + def write(self, out): + """Used in constructing an outgoing packet""" + out.write_string(self.address) + + +class DNSText(DNSRecord): + """A DNS text record""" + + def __init__(self, name, type_, class_, ttl, text): + assert isinstance(text, (bytes, type(None))) + DNSRecord.__init__(self, name, type_, class_, ttl) + self.text = text + + def write(self, out): + """Used in constructing an outgoing packet""" + out.write_string(self.text) + + +class DNSIncoming(QuietLogger): + """Object representation of an incoming DNS packet""" + + def __init__(self, data): + """Constructor from string holding bytes of packet""" + self.offset = 0 + self.data = data + self.questions = [] + self.answers = [] + self.id = 0 + self.flags = 0 # type: int + self.num_questions = 0 + self.num_answers = 0 + self.num_authorities = 0 + self.num_additionals = 0 + self.valid = False + + try: + self.read_header() + self.read_questions() + self.read_others() + self.valid = True + + except (IndexError, struct.error, IncomingDecodeError): + self.log_exception_warning(( + 'Choked at offset %d while unpacking %r', self.offset, data)) + + def unpack(self, format_): + length = struct.calcsize(format_) + info = struct.unpack( + format_, self.data[self.offset:self.offset + length]) + self.offset += length + return info + + def read_header(self): + """Reads header portion of packet""" + (self.id, self.flags, self.num_questions, self.num_answers, + self.num_authorities, self.num_additionals) = self.unpack(b'!6H') + + def read_questions(self): + """Reads questions section of packet""" + for _ in range(self.num_questions): + name = self.read_name() + type_, class_ = self.unpack(b'!HH') + + question = DNSQuestion(name, type_, class_) + self.questions.append(question) + + def read_character_string(self): + """Reads a character string from the packet""" + length = self.data[self.offset] + self.offset += 1 + return self.read_string(length) + + def read_string(self, length): + """Reads a string of a given length from the packet""" + info = self.data[self.offset:self.offset + length] + self.offset += length + return info + + def read_unsigned_short(self): + """Reads an unsigned short from the packet""" + return self.unpack(b'!H')[0] + + def read_others(self): + """Reads the answers, authorities and additionals section of the + packet""" + n = self.num_answers + self.num_authorities + self.num_additionals + for _ in range(n): + domain = self.read_name() + type_, class_, ttl, length = self.unpack(b'!HHiH') + + rec = None + if type_ == _TYPE_A: + rec = DNSAddress( + domain, type_, class_, ttl, self.read_string(4)) + elif type_ == _TYPE_TXT: + rec = DNSText( + domain, type_, class_, ttl, self.read_string(length)) + elif type_ == _TYPE_AAAA: + rec = DNSAddress( + domain, type_, class_, ttl, self.read_string(16)) + else: + # Try to ignore types we don't know about + # Skip the payload for the resource record so the next + # records can be parsed correctly + self.offset += length + + if rec is not None: + self.answers.append(rec) + + def is_query(self): + """Returns true if this is a query""" + return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY + + def is_response(self): + """Returns true if this is a response""" + return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE + + def read_utf(self, offset, length): + """Reads a UTF-8 string of a given length from the packet""" + return text_type(self.data[offset:offset + length], 'utf-8', 'replace') + + def read_name(self): + """Reads a domain name from the packet""" + result = '' + off = self.offset + next_ = -1 + first = off + + while True: + length = indexbytes(self.data, off) + off += 1 + if length == 0: + break + t = length & 0xC0 + if t == 0x00: + result = ''.join((result, self.read_utf(off, length) + '.')) + off += length + elif t == 0xC0: + if next_ < 0: + next_ = off + 1 + off = ((length & 0x3F) << 8) | indexbytes(self.data, off) + if off >= first: + raise IncomingDecodeError( + "Bad domain name (circular) at %s" % (off,)) + first = off + else: + raise IncomingDecodeError("Bad domain name at %s" % (off,)) + + if next_ >= 0: + self.offset = next_ + else: + self.offset = off + + return result + + +class DNSOutgoing(object): + """Object representation of an outgoing packet""" + + def __init__(self, flags): + self.finished = False + self.id = 0 + self.flags = flags + self.names = {} + self.data = [] + self.size = 12 + self.state = False + + self.questions = [] + self.answers = [] + + def add_question(self, record): + """Adds a question""" + self.questions.append(record) + + def pack(self, format_, value): + self.data.append(struct.pack(format_, value)) + self.size += struct.calcsize(format_) + + def write_byte(self, value): + """Writes a single byte to the packet""" + self.pack(b'!c', int2byte(value)) + + def insert_short(self, index, value): + """Inserts an unsigned short in a certain position in the packet""" + self.data.insert(index, struct.pack(b'!H', value)) + self.size += 2 + + def write_short(self, value): + """Writes an unsigned short to the packet""" + self.pack(b'!H', value) + + def write_int(self, value): + """Writes an unsigned integer to the packet""" + self.pack(b'!I', int(value)) + + def write_string(self, value): + """Writes a string to the packet""" + assert isinstance(value, bytes) + self.data.append(value) + self.size += len(value) + + def write_utf(self, s): + """Writes a UTF-8 string of a given length to the packet""" + utfstr = s.encode('utf-8') + length = len(utfstr) + self.write_byte(length) + self.write_string(utfstr) + + def write_character_string(self, value): + assert isinstance(value, bytes) + length = len(value) + self.write_byte(length) + self.write_string(value) + + def write_name(self, name): + # split name into each label + parts = name.split('.') + if not parts[-1]: + parts.pop() + + # construct each suffix + name_suffices = ['.'.join(parts[i:]) for i in range(len(parts))] + + # look for an existing name or suffix + for count, sub_name in enumerate(name_suffices): + if sub_name in self.names: + break + else: + count = len(name_suffices) + + # note the new names we are saving into the packet + name_length = len(name.encode('utf-8')) + for suffix in name_suffices[:count]: + self.names[suffix] = self.size + name_length - len(suffix.encode('utf-8')) - 1 + + # write the new names out. + for part in parts[:count]: + self.write_utf(part) + + # if we wrote part of the name, create a pointer to the rest + if count != len(name_suffices): + # Found substring in packet, create pointer + index = self.names[name_suffices[count]] + self.write_byte((index >> 8) | 0xC0) + self.write_byte(index & 0xFF) + else: + # this is the end of a name + self.write_byte(0) + + def write_question(self, question): + self.write_name(question.name) + self.write_short(question.type) + self.write_short(question.class_) + + def packet(self): + if not self.state: + for question in self.questions: + self.write_question(question) + self.state = True + + self.insert_short(0, 0) # num additionals + self.insert_short(0, 0) # num authorities + self.insert_short(0, 0) # num answers + self.insert_short(0, len(self.questions)) + self.insert_short(0, self.flags) # _FLAGS_QR_QUERY + self.insert_short(0, 0) + return b''.join(self.data) + + +class Engine(threading.Thread): + def __init__(self, zc): + threading.Thread.__init__(self, name='zeroconf-Engine') + self.daemon = True + self.zc = zc + self.readers = {} + self.timeout = 5 + self.condition = threading.Condition() + self.start() + + def run(self): + while not self.zc.done: + # pylint: disable=len-as-condition + with self.condition: + rs = self.readers.keys() + if len(rs) == 0: + # No sockets to manage, but we wait for the timeout + # or addition of a socket + self.condition.wait(self.timeout) + + if len(rs) != 0: + try: + rr, _, _ = select.select(rs, [], [], self.timeout) + if not self.zc.done: + for socket_ in rr: + reader = self.readers.get(socket_) + if reader: + reader.handle_read(socket_) + + except (select.error, socket.error) as e: + # If the socket was closed by another thread, during + # shutdown, ignore it and exit + if e.args[0] != socket.EBADF or not self.zc.done: + raise + + def add_reader(self, reader, socket_): + with self.condition: + self.readers[socket_] = reader + self.condition.notify() + + def del_reader(self, socket_): + with self.condition: + del self.readers[socket_] + self.condition.notify() + + +class Listener(QuietLogger): + def __init__(self, zc): + self.zc = zc + self.data = None + + def handle_read(self, socket_): + try: + data, (addr, port) = socket_.recvfrom(_MAX_MSG_ABSOLUTE) + except Exception: # pylint: disable=broad-except + self.log_exception_warning() + return + + log.debug('Received from %r:%r: %r ', addr, port, data) + + self.data = data + msg = DNSIncoming(data) + if not msg.valid or msg.is_query(): + pass + else: + self.zc.handle_response(msg) + + +class RecordUpdateListener(object): + def update_record(self, zc, now, record): + raise NotImplementedError() + + +class HostResolver(RecordUpdateListener): + def __init__(self, name): + self.name = name + self.address = None + + def update_record(self, zc, now, record): + if record is None: + return + if record.type == _TYPE_A: + assert isinstance(record, DNSAddress) + if record.name == self.name: + self.address = record.address + + def request(self, zc, timeout): + now = time.time() + delay = 0.2 + next_ = now + delay + last = now + timeout + + try: + zc.add_listener(self) + while self.address is None: + if last <= now: + # Timeout + return False + if next_ <= now: + out = DNSOutgoing(_FLAGS_QR_QUERY) + out.add_question( + DNSQuestion(self.name, _TYPE_A, _CLASS_IN)) + zc.send(out) + next_ = now + delay + delay *= 2 + + zc.wait(min(next_, last) - now) + now = time.time() + finally: + zc.remove_listener(self) + + return True + + +class DashboardStatus(RecordUpdateListener, threading.Thread): + def __init__(self, zc, on_update): + threading.Thread.__init__(self) + self.zc = zc + self.query_hosts = set() + self.key_to_host = {} + self.cache = {} + self.stop_event = threading.Event() + self.query_event = threading.Event() + self.on_update = on_update + + def update_record(self, zc, now, record): + if record is None: + return + if record.type in (_TYPE_A, _TYPE_AAAA, _TYPE_TXT): + assert isinstance(record, DNSEntry) + if record.name in self.query_hosts: + self.cache.setdefault(record.name, []).insert(0, record) + self.purge_cache() + + def purge_cache(self): + new_cache = {} + for host, records in self.cache.items(): + if host not in self.query_hosts: + continue + new_records = [rec for rec in records if not rec.is_removable(time.time())] + if new_records: + new_cache[host] = new_records + self.cache = new_cache + self.on_update({key: self.host_status(key) for key in self.key_to_host}) + + def request_query(self, hosts): + self.query_hosts = set(host for host in hosts.values()) + self.key_to_host = hosts + self.query_event.set() + + def stop(self): + self.stop_event.set() + self.query_event.set() + + def host_status(self, key): + return self.key_to_host.get(key) in self.cache + + def run(self): + self.zc.add_listener(self) + while not self.stop_event.is_set(): + self.purge_cache() + for host in self.query_hosts: + if all(record.is_expired(time.time()) for record in self.cache.get(host, [])): + out = DNSOutgoing(_FLAGS_QR_QUERY) + out.add_question( + DNSQuestion(host, _TYPE_A, _CLASS_IN)) + self.zc.send(out) + self.query_event.wait() + self.query_event.clear() + self.zc.remove_listener(self) + + +def get_all_addresses(): + return list(set( + addr.ip + for iface in ifaddr.get_adapters() + for addr in iface.ips + if addr.is_IPv4 and addr.network_prefix != 32 # Host only netmask 255.255.255.255 + )) + + +def new_socket(): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + # SO_REUSEADDR should be equivalent to SO_REUSEPORT for + # multicast UDP sockets (p 731, "TCP/IP Illustrated, + # Volume 2"), but some BSD-derived systems require + # SO_REUSEPORT to be specified explicitly. Also, not all + # versions of Python have SO_REUSEPORT available. + # Catch OSError and socket.error for kernel versions <3.9 because lacking + # SO_REUSEPORT support. + try: + reuseport = socket.SO_REUSEPORT + except AttributeError: + pass + else: + try: + s.setsockopt(socket.SOL_SOCKET, reuseport, 1) + except (OSError, socket.error) as err: + # OSError on python 3, socket.error on python 2 + if err.errno != errno.ENOPROTOOPT: + raise + + # OpenBSD needs the ttl and loop values for the IP_MULTICAST_TTL and + # IP_MULTICAST_LOOP socket options as an unsigned char. + ttl = struct.pack(b'B', 255) + s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl) + loop = struct.pack(b'B', 1) + s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, loop) + + s.bind(('', _MDNS_PORT)) + return s + + +class Zeroconf(QuietLogger): + def __init__(self): + # hook for threads + self._GLOBAL_DONE = False + + self._listen_socket = new_socket() + interfaces = get_all_addresses() + + self._respond_sockets = [] + + for i in interfaces: + try: + _value = socket.inet_aton(_MDNS_ADDR) + socket.inet_aton(i) + self._listen_socket.setsockopt( + socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, _value) + except socket.error as e: + _errno = e.args[0] + if _errno == errno.EADDRINUSE: + log.info( + 'Address in use when adding %s to multicast group, ' + 'it is expected to happen on some systems', i, + ) + elif _errno == errno.EADDRNOTAVAIL: + log.info( + 'Address not available when adding %s to multicast ' + 'group, it is expected to happen on some systems', i, + ) + continue + elif _errno == errno.EINVAL: + log.info( + 'Interface of %s does not support multicast, ' + 'it is expected in WSL', i + ) + continue + + else: + raise + + respond_socket = new_socket() + respond_socket.setsockopt( + socket.IPPROTO_IP, socket.IP_MULTICAST_IF, socket.inet_aton(i)) + + self._respond_sockets.append(respond_socket) + + self.listeners = [] + + self.condition = threading.Condition() + + self.engine = Engine(self) + self.listener = Listener(self) + self.engine.add_reader(self.listener, self._listen_socket) + + @property + def done(self): + return self._GLOBAL_DONE + + def wait(self, timeout): + """Calling thread waits for a given number of milliseconds or + until notified.""" + with self.condition: + self.condition.wait(timeout) + + def notify_all(self): + """Notifies all waiting threads""" + with self.condition: + self.condition.notify_all() + + def resolve_host(self, host, timeout=3.0): + info = HostResolver(host) + if info.request(self, timeout): + return socket.inet_ntoa(info.address) + return None + + def add_listener(self, listener): + self.listeners.append(listener) + self.notify_all() + + def remove_listener(self, listener): + """Removes a listener.""" + try: + self.listeners.remove(listener) + self.notify_all() + except Exception as e: # pylint: disable=broad-except + log.exception('Unknown error, possibly benign: %r', e) + + def update_record(self, now, rec): + """Used to notify listeners of new information that has updated + a record.""" + for listener in self.listeners: + listener.update_record(self, now, rec) + self.notify_all() + + def handle_response(self, msg): + """Deal with incoming response packets. All answers + are held in the cache, and listeners are notified.""" + now = time.time() + for record in msg.answers: + self.update_record(now, record) + + def send(self, out): + """Sends an outgoing packet.""" + packet = out.packet() + log.debug('Sending %r (%d bytes) as %r...', out, len(packet), packet) + for s in self._respond_sockets: + if self._GLOBAL_DONE: + return + try: + bytes_sent = s.sendto(packet, 0, (_MDNS_ADDR, _MDNS_PORT)) + except Exception: # pylint: disable=broad-except + # on send errors, log the exception and keep going + self.log_exception_warning() + else: + if bytes_sent != len(packet): + self.log_warning_once( + '!!! sent %d out of %d bytes to %r' % ( + bytes_sent, len(packet), s)) + + def close(self): + """Ends the background threads, and prevent this instance from + servicing further queries.""" + if not self._GLOBAL_DONE: + self._GLOBAL_DONE = True + # shutdown recv socket and thread + self.engine.del_reader(self._listen_socket) + self._listen_socket.close() + self.engine.join() + + # shutdown the rest + self.notify_all() + for s in self._respond_sockets: + s.close() diff --git a/pylintrc b/pylintrc index 6532088f18..0fa064d807 100644 --- a/pylintrc +++ b/pylintrc @@ -14,6 +14,7 @@ disable= too-many-statements, too-many-arguments, too-many-return-statements, + too-many-instance-attributes, duplicate-code, invalid-name, cyclic-import, diff --git a/requirements.txt b/requirements.txt index 8f5b324ba8..402b4185bc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ esptool>=2.3.1 typing>=3.0.0 protobuf>=3.4 pyserial>=3.4,<4 +ifaddr>=0.1.6 diff --git a/setup.py b/setup.py index c85a3f66d6..0453806688 100755 --- a/setup.py +++ b/setup.py @@ -32,6 +32,7 @@ REQUIRES = [ 'protobuf>=3.4', 'tzlocal>=1.4', 'pyserial>=3.4,<4', + 'ifaddr>=0.1.6', ] # If you have problems importing platformio and esptool as modules you can set diff --git a/tests/test1.yaml b/tests/test1.yaml index ec494c2763..2e6366d3be 100644 --- a/tests/test1.yaml +++ b/tests/test1.yaml @@ -35,7 +35,6 @@ wifi: subnet: 255.255.255.0 dns1: 1.1.1.1 dns2: 1.2.2.1 - hostname: myverylonghostname domain: .local reboot_timeout: 120s power_save_mode: none