mirror of
				https://github.com/esphome/esphome.git
				synced 2025-10-30 22:53:59 +00:00 
			
		
		
		
	Add local mDNS responder for .local (#386)
* Add local mDNS responder * Fix * Use mDNS in dashboard status * Lint * Lint * Fix test * Remove hostname * Fix use enum * Lint
This commit is contained in:
		| @@ -454,7 +454,7 @@ def run_logs(config, address): | ||||
|  | ||||
|         wait_time = min(2**tries, 300) | ||||
|         if not has_connects: | ||||
|             _LOGGER.warning(u"Initial connection failed. The ESP might not be connected" | ||||
|             _LOGGER.warning(u"Initial connection failed. The ESP might not be connected " | ||||
|                             u"to WiFi yet (%s). Re-Trying in %s seconds", | ||||
|                             error, wait_time) | ||||
|         else: | ||||
|   | ||||
| @@ -3,7 +3,7 @@ import voluptuous as vol | ||||
| from esphomeyaml import pins | ||||
| from esphomeyaml.components import wifi | ||||
| import esphomeyaml.config_validation as cv | ||||
| from esphomeyaml.const import CONF_DOMAIN, CONF_HOSTNAME, CONF_ID, CONF_MANUAL_IP, CONF_TYPE, \ | ||||
| from esphomeyaml.const import CONF_DOMAIN, CONF_ID, CONF_MANUAL_IP, CONF_TYPE, \ | ||||
|     ESP_PLATFORM_ESP32 | ||||
| from esphomeyaml.cpp_generator import Pvariable, add | ||||
| from esphomeyaml.cpp_helpers import gpio_output_pin_expression | ||||
| @@ -43,7 +43,6 @@ CONFIG_SCHEMA = vol.Schema({ | ||||
|     vol.Optional(CONF_PHY_ADDR, default=0): vol.All(cv.int_, vol.Range(min=0, max=31)), | ||||
|     vol.Optional(CONF_POWER_PIN): pins.gpio_output_pin_schema, | ||||
|     vol.Optional(CONF_MANUAL_IP): wifi.STA_MANUAL_IP_SCHEMA, | ||||
|     vol.Optional(CONF_HOSTNAME): cv.hostname, | ||||
|     vol.Optional(CONF_DOMAIN, default='.local'): cv.domain_name, | ||||
| }) | ||||
|  | ||||
| @@ -63,9 +62,6 @@ def to_code(config): | ||||
|             yield | ||||
|         add(eth.set_power_pin(pin)) | ||||
|  | ||||
|     if CONF_HOSTNAME in config: | ||||
|         add(eth.set_hostname(config[CONF_HOSTNAME])) | ||||
|  | ||||
|     if CONF_MANUAL_IP in config: | ||||
|         add(eth.set_manual_ip(wifi.manual_ip(config[CONF_MANUAL_IP]))) | ||||
|  | ||||
|   | ||||
| @@ -111,6 +111,9 @@ CONFIG_SCHEMA = vol.All(vol.Schema({ | ||||
|     vol.Optional(CONF_REBOOT_TIMEOUT): cv.positive_time_period_milliseconds, | ||||
|     vol.Optional(CONF_POWER_SAVE_MODE): cv.one_of(*WIFI_POWER_SAVE_MODES, upper=True), | ||||
|     vol.Optional(CONF_FAST_CONNECT): cv.boolean, | ||||
|  | ||||
|     vol.Optional(CONF_HOSTNAME): cv.invalid("The hostname option has been removed in 1.11.0, " | ||||
|                                             "now it's always the node name.") | ||||
| }), validate) | ||||
|  | ||||
|  | ||||
| @@ -166,9 +169,6 @@ def to_code(config): | ||||
|     if CONF_AP in config: | ||||
|         add(wifi.set_ap(wifi_network(config[CONF_AP], config.get(CONF_MANUAL_IP)))) | ||||
|  | ||||
|     if CONF_HOSTNAME in config: | ||||
|         add(wifi.set_hostname(config[CONF_HOSTNAME])) | ||||
|  | ||||
|     if CONF_REBOOT_TIMEOUT in config: | ||||
|         add(wifi.set_reboot_timeout(config[CONF_REBOOT_TIMEOUT])) | ||||
|  | ||||
|   | ||||
| @@ -2,11 +2,9 @@ | ||||
| from __future__ import print_function | ||||
|  | ||||
| import codecs | ||||
| import collections | ||||
| import hmac | ||||
| import json | ||||
| import logging | ||||
| import multiprocessing | ||||
| import os | ||||
| import subprocess | ||||
| import threading | ||||
| @@ -25,7 +23,7 @@ import tornado.websocket | ||||
|  | ||||
| from esphomeyaml import const | ||||
| from esphomeyaml.__main__ import get_serial_ports | ||||
| from esphomeyaml.helpers import mkdir_p, run_system_command | ||||
| from esphomeyaml.helpers import mkdir_p | ||||
| from esphomeyaml.py_compat import IS_PY2 | ||||
| from esphomeyaml.storage_json import EsphomeyamlStorageJSON, StorageJSON, \ | ||||
|     esphomeyaml_storage_path, ext_storage_path | ||||
| @@ -34,6 +32,8 @@ from esphomeyaml.util import shlex_quote | ||||
| # pylint: disable=unused-import, wrong-import-order | ||||
| from typing import Optional  # noqa | ||||
|  | ||||
| from esphomeyaml.zeroconf import Zeroconf, DashboardStatus | ||||
|  | ||||
| _LOGGER = logging.getLogger(__name__) | ||||
| CONFIG_DIR = '' | ||||
| PASSWORD_DIGEST = '' | ||||
| @@ -315,55 +315,25 @@ class MainRequestHandler(BaseHandler): | ||||
|                     get_static_file_url=get_static_file_url) | ||||
|  | ||||
|  | ||||
| def _ping_func(filename, address): | ||||
|     if os.name == 'nt': | ||||
|         command = ['ping', '-n', '1', address] | ||||
|     else: | ||||
|         command = ['ping', '-c', '1', address] | ||||
|     rc, _, _ = run_system_command(*command) | ||||
|     return filename, rc == 0 | ||||
|  | ||||
|  | ||||
| class PingThread(threading.Thread): | ||||
|     def run(self): | ||||
|         pool = multiprocessing.Pool(processes=8) | ||||
|         zc = Zeroconf() | ||||
|  | ||||
|         def on_update(dat): | ||||
|             for key, b in dat.items(): | ||||
|                 PING_RESULT[key] = b | ||||
|  | ||||
|         stat = DashboardStatus(zc, on_update) | ||||
|         stat.start() | ||||
|         while not STOP_EVENT.is_set(): | ||||
|             # Only do pings if somebody has the dashboard open | ||||
|             entries = _list_dashboard_entries() | ||||
|             stat.request_query({entry.filename: entry.name + '.local.' for entry in entries}) | ||||
|  | ||||
|             PING_REQUEST.wait() | ||||
|             PING_REQUEST.clear() | ||||
|  | ||||
|             def callback(ret): | ||||
|                 PING_RESULT[ret[0]] = ret[1] | ||||
|  | ||||
|             entries = _list_dashboard_entries() | ||||
|             queue = collections.deque() | ||||
|             for entry in entries: | ||||
|                 if entry.address is None: | ||||
|                     PING_RESULT[entry.filename] = None | ||||
|                     continue | ||||
|  | ||||
|                 result = pool.apply_async(_ping_func, (entry.filename, entry.address), | ||||
|                                           callback=callback) | ||||
|                 queue.append(result) | ||||
|  | ||||
|             while queue: | ||||
|                 item = queue[0] | ||||
|                 if item.ready(): | ||||
|                     queue.popleft() | ||||
|                     continue | ||||
|  | ||||
|                 try: | ||||
|                     item.get(0.1) | ||||
|                 except OSError: | ||||
|                     # ping not installed | ||||
|                     pass | ||||
|                 except multiprocessing.TimeoutError: | ||||
|                     pass | ||||
|  | ||||
|                 if STOP_EVENT.is_set(): | ||||
|                     pool.terminate() | ||||
|                     return | ||||
|         stat.stop() | ||||
|         stat.join() | ||||
|         zc.close() | ||||
|  | ||||
|  | ||||
| class PingRequestHandler(BaseHandler): | ||||
|   | ||||
| @@ -7,6 +7,7 @@ import socket | ||||
| import subprocess | ||||
|  | ||||
| from esphomeyaml.py_compat import text_type, char_to_byte | ||||
| from esphomeyaml.zeroconf import Zeroconf | ||||
|  | ||||
| _LOGGER = logging.getLogger(__name__) | ||||
|  | ||||
| @@ -100,12 +101,34 @@ def is_ip_address(host): | ||||
|         return False | ||||
|  | ||||
|  | ||||
| def _resolve_with_zeroconf(host): | ||||
|     from esphomeyaml.core import EsphomeyamlError | ||||
|     try: | ||||
|         zc = Zeroconf() | ||||
|     except Exception: | ||||
|         raise EsphomeyamlError("Cannot start mDNS sockets, is this a docker container without " | ||||
|                                "host network mode?") | ||||
|     try: | ||||
|         info = zc.resolve_host(host + '.') | ||||
|     except Exception as err: | ||||
|         raise EsphomeyamlError("Error resolving mDNS hostname: {}".format(err)) | ||||
|     finally: | ||||
|         zc.close() | ||||
|     if info is None: | ||||
|         raise EsphomeyamlError("Error resolving address with mDNS: Did not respond. " | ||||
|                                "Maybe the device is offline.") | ||||
|     return info | ||||
|  | ||||
|  | ||||
| def resolve_ip_address(host): | ||||
|     from esphomeyaml.core import EsphomeyamlError | ||||
|  | ||||
|     try: | ||||
|         ip = socket.gethostbyname(host) | ||||
|     except socket.error as err: | ||||
|         from esphomeyaml.core import EsphomeyamlError | ||||
|  | ||||
|         raise EsphomeyamlError("Error resolving IP address: {}".format(err)) | ||||
|         if host.endswith('.local'): | ||||
|             ip = _resolve_with_zeroconf(host) | ||||
|         else: | ||||
|             raise EsphomeyamlError("Error resolving IP address: {}".format(err)) | ||||
|  | ||||
|     return ip | ||||
|   | ||||
| @@ -62,3 +62,10 @@ def sort_by_cmp(list_, cmp): | ||||
|         list_.sort(cmp=cmp) | ||||
|     else: | ||||
|         list_.sort(key=functools.cmp_to_key(cmp)) | ||||
|  | ||||
|  | ||||
| def indexbytes(buf, i): | ||||
|     if IS_PY3: | ||||
|         return buf[i] | ||||
|     else: | ||||
|         return ord(buf[i]) | ||||
|   | ||||
							
								
								
									
										780
									
								
								esphomeyaml/zeroconf.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										780
									
								
								esphomeyaml/zeroconf.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,780 @@ | ||||
| # Custom zeroconf implementation based on python-zeroconf | ||||
| # (https://github.com/jstasiak/python-zeroconf) that supports Python 2 | ||||
|  | ||||
| import errno | ||||
| import logging | ||||
| import select | ||||
| import socket | ||||
| import struct | ||||
| import sys | ||||
| import threading | ||||
| import time | ||||
|  | ||||
| import ifaddr | ||||
|  | ||||
| from esphomeyaml.py_compat import indexbytes, text_type | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
| # Some timing constants | ||||
|  | ||||
| _LISTENER_TIME = 200 | ||||
|  | ||||
| # Some DNS constants | ||||
|  | ||||
| _MDNS_ADDR = '224.0.0.251' | ||||
| _MDNS_PORT = 5353 | ||||
|  | ||||
| _MAX_MSG_ABSOLUTE = 8966 | ||||
|  | ||||
| _FLAGS_QR_MASK = 0x8000  # query response mask | ||||
| _FLAGS_QR_QUERY = 0x0000  # query | ||||
| _FLAGS_QR_RESPONSE = 0x8000  # response | ||||
|  | ||||
| _FLAGS_AA = 0x0400  # Authoritative answer | ||||
| _FLAGS_TC = 0x0200  # Truncated | ||||
| _FLAGS_RD = 0x0100  # Recursion desired | ||||
| _FLAGS_RA = 0x8000  # Recursion available | ||||
|  | ||||
| _FLAGS_Z = 0x0040  # Zero | ||||
| _FLAGS_AD = 0x0020  # Authentic data | ||||
| _FLAGS_CD = 0x0010  # Checking disabled | ||||
|  | ||||
| _CLASS_IN = 1 | ||||
| _CLASS_CS = 2 | ||||
| _CLASS_CH = 3 | ||||
| _CLASS_HS = 4 | ||||
| _CLASS_NONE = 254 | ||||
| _CLASS_ANY = 255 | ||||
| _CLASS_MASK = 0x7FFF | ||||
| _CLASS_UNIQUE = 0x8000 | ||||
|  | ||||
| _TYPE_A = 1 | ||||
| _TYPE_NS = 2 | ||||
| _TYPE_MD = 3 | ||||
| _TYPE_MF = 4 | ||||
| _TYPE_CNAME = 5 | ||||
| _TYPE_SOA = 6 | ||||
| _TYPE_MB = 7 | ||||
| _TYPE_MG = 8 | ||||
| _TYPE_MR = 9 | ||||
| _TYPE_NULL = 10 | ||||
| _TYPE_WKS = 11 | ||||
| _TYPE_PTR = 12 | ||||
| _TYPE_HINFO = 13 | ||||
| _TYPE_MINFO = 14 | ||||
| _TYPE_MX = 15 | ||||
| _TYPE_TXT = 16 | ||||
| _TYPE_AAAA = 28 | ||||
| _TYPE_SRV = 33 | ||||
| _TYPE_ANY = 255 | ||||
|  | ||||
| # Mapping constants to names | ||||
| int2byte = struct.Struct(">B").pack | ||||
|  | ||||
|  | ||||
| # Exceptions | ||||
| class Error(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class IncomingDecodeError(Error): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| # pylint: disable=no-init | ||||
| class QuietLogger(object): | ||||
|     _seen_logs = {} | ||||
|  | ||||
|     @classmethod | ||||
|     def log_exception_warning(cls, logger_data=None): | ||||
|         exc_info = sys.exc_info() | ||||
|         exc_str = str(exc_info[1]) | ||||
|         if exc_str not in cls._seen_logs: | ||||
|             # log at warning level the first time this is seen | ||||
|             cls._seen_logs[exc_str] = exc_info | ||||
|             logger = log.warning | ||||
|         else: | ||||
|             logger = log.debug | ||||
|         if logger_data is not None: | ||||
|             logger(*logger_data) | ||||
|         logger('Exception occurred:', exc_info=True) | ||||
|  | ||||
|     @classmethod | ||||
|     def log_warning_once(cls, *args): | ||||
|         msg_str = args[0] | ||||
|         if msg_str not in cls._seen_logs: | ||||
|             cls._seen_logs[msg_str] = 0 | ||||
|             logger = log.warning | ||||
|         else: | ||||
|             logger = log.debug | ||||
|         cls._seen_logs[msg_str] += 1 | ||||
|         logger(*args) | ||||
|  | ||||
|  | ||||
| class DNSEntry(object): | ||||
|     """A DNS entry""" | ||||
|  | ||||
|     def __init__(self, name, type_, class_): | ||||
|         self.key = name.lower() | ||||
|         self.name = name | ||||
|         self.type = type_ | ||||
|         self.class_ = class_ & _CLASS_MASK | ||||
|         self.unique = (class_ & _CLASS_UNIQUE) != 0 | ||||
|  | ||||
|  | ||||
| class DNSQuestion(DNSEntry): | ||||
|     """A DNS question entry""" | ||||
|  | ||||
|     def __init__(self, name, type_, class_): | ||||
|         DNSEntry.__init__(self, name, type_, class_) | ||||
|  | ||||
|     def answered_by(self, rec): | ||||
|         """Returns true if the question is answered by the record""" | ||||
|         return (self.class_ == rec.class_ and | ||||
|                 (self.type == rec.type or self.type == _TYPE_ANY) and | ||||
|                 self.name == rec.name) | ||||
|  | ||||
|  | ||||
| class DNSRecord(DNSEntry): | ||||
|     """A DNS record - like a DNS entry, but has a TTL""" | ||||
|  | ||||
|     def __init__(self, name, type_, class_, ttl): | ||||
|         DNSEntry.__init__(self, name, type_, class_) | ||||
|         self.ttl = 15 | ||||
|         self.created = time.time() | ||||
|  | ||||
|     def write(self, out): | ||||
|         """Abstract method""" | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def is_expired(self, now): | ||||
|         return self.created + self.ttl <= now | ||||
|  | ||||
|     def is_removable(self, now): | ||||
|         return self.created + self.ttl * 2 <= now | ||||
|  | ||||
|  | ||||
| class DNSAddress(DNSRecord): | ||||
|     """A DNS address record""" | ||||
|  | ||||
|     def __init__(self, name, type_, class_, ttl, address): | ||||
|         DNSRecord.__init__(self, name, type_, class_, ttl) | ||||
|         self.address = address | ||||
|  | ||||
|     def write(self, out): | ||||
|         """Used in constructing an outgoing packet""" | ||||
|         out.write_string(self.address) | ||||
|  | ||||
|  | ||||
| class DNSText(DNSRecord): | ||||
|     """A DNS text record""" | ||||
|  | ||||
|     def __init__(self, name, type_, class_, ttl, text): | ||||
|         assert isinstance(text, (bytes, type(None))) | ||||
|         DNSRecord.__init__(self, name, type_, class_, ttl) | ||||
|         self.text = text | ||||
|  | ||||
|     def write(self, out): | ||||
|         """Used in constructing an outgoing packet""" | ||||
|         out.write_string(self.text) | ||||
|  | ||||
|  | ||||
| class DNSIncoming(QuietLogger): | ||||
|     """Object representation of an incoming DNS packet""" | ||||
|  | ||||
|     def __init__(self, data): | ||||
|         """Constructor from string holding bytes of packet""" | ||||
|         self.offset = 0 | ||||
|         self.data = data | ||||
|         self.questions = [] | ||||
|         self.answers = [] | ||||
|         self.id = 0 | ||||
|         self.flags = 0  # type: int | ||||
|         self.num_questions = 0 | ||||
|         self.num_answers = 0 | ||||
|         self.num_authorities = 0 | ||||
|         self.num_additionals = 0 | ||||
|         self.valid = False | ||||
|  | ||||
|         try: | ||||
|             self.read_header() | ||||
|             self.read_questions() | ||||
|             self.read_others() | ||||
|             self.valid = True | ||||
|  | ||||
|         except (IndexError, struct.error, IncomingDecodeError): | ||||
|             self.log_exception_warning(( | ||||
|                 'Choked at offset %d while unpacking %r', self.offset, data)) | ||||
|  | ||||
|     def unpack(self, format_): | ||||
|         length = struct.calcsize(format_) | ||||
|         info = struct.unpack( | ||||
|             format_, self.data[self.offset:self.offset + length]) | ||||
|         self.offset += length | ||||
|         return info | ||||
|  | ||||
|     def read_header(self): | ||||
|         """Reads header portion of packet""" | ||||
|         (self.id, self.flags, self.num_questions, self.num_answers, | ||||
|          self.num_authorities, self.num_additionals) = self.unpack(b'!6H') | ||||
|  | ||||
|     def read_questions(self): | ||||
|         """Reads questions section of packet""" | ||||
|         for _ in range(self.num_questions): | ||||
|             name = self.read_name() | ||||
|             type_, class_ = self.unpack(b'!HH') | ||||
|  | ||||
|             question = DNSQuestion(name, type_, class_) | ||||
|             self.questions.append(question) | ||||
|  | ||||
|     def read_character_string(self): | ||||
|         """Reads a character string from the packet""" | ||||
|         length = self.data[self.offset] | ||||
|         self.offset += 1 | ||||
|         return self.read_string(length) | ||||
|  | ||||
|     def read_string(self, length): | ||||
|         """Reads a string of a given length from the packet""" | ||||
|         info = self.data[self.offset:self.offset + length] | ||||
|         self.offset += length | ||||
|         return info | ||||
|  | ||||
|     def read_unsigned_short(self): | ||||
|         """Reads an unsigned short from the packet""" | ||||
|         return self.unpack(b'!H')[0] | ||||
|  | ||||
|     def read_others(self): | ||||
|         """Reads the answers, authorities and additionals section of the | ||||
|         packet""" | ||||
|         n = self.num_answers + self.num_authorities + self.num_additionals | ||||
|         for _ in range(n): | ||||
|             domain = self.read_name() | ||||
|             type_, class_, ttl, length = self.unpack(b'!HHiH') | ||||
|  | ||||
|             rec = None | ||||
|             if type_ == _TYPE_A: | ||||
|                 rec = DNSAddress( | ||||
|                     domain, type_, class_, ttl, self.read_string(4)) | ||||
|             elif type_ == _TYPE_TXT: | ||||
|                 rec = DNSText( | ||||
|                     domain, type_, class_, ttl, self.read_string(length)) | ||||
|             elif type_ == _TYPE_AAAA: | ||||
|                 rec = DNSAddress( | ||||
|                     domain, type_, class_, ttl, self.read_string(16)) | ||||
|             else: | ||||
|                 # Try to ignore types we don't know about | ||||
|                 # Skip the payload for the resource record so the next | ||||
|                 # records can be parsed correctly | ||||
|                 self.offset += length | ||||
|  | ||||
|             if rec is not None: | ||||
|                 self.answers.append(rec) | ||||
|  | ||||
|     def is_query(self): | ||||
|         """Returns true if this is a query""" | ||||
|         return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY | ||||
|  | ||||
|     def is_response(self): | ||||
|         """Returns true if this is a response""" | ||||
|         return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE | ||||
|  | ||||
|     def read_utf(self, offset, length): | ||||
|         """Reads a UTF-8 string of a given length from the packet""" | ||||
|         return text_type(self.data[offset:offset + length], 'utf-8', 'replace') | ||||
|  | ||||
|     def read_name(self): | ||||
|         """Reads a domain name from the packet""" | ||||
|         result = '' | ||||
|         off = self.offset | ||||
|         next_ = -1 | ||||
|         first = off | ||||
|  | ||||
|         while True: | ||||
|             length = indexbytes(self.data, off) | ||||
|             off += 1 | ||||
|             if length == 0: | ||||
|                 break | ||||
|             t = length & 0xC0 | ||||
|             if t == 0x00: | ||||
|                 result = ''.join((result, self.read_utf(off, length) + '.')) | ||||
|                 off += length | ||||
|             elif t == 0xC0: | ||||
|                 if next_ < 0: | ||||
|                     next_ = off + 1 | ||||
|                 off = ((length & 0x3F) << 8) | indexbytes(self.data, off) | ||||
|                 if off >= first: | ||||
|                     raise IncomingDecodeError( | ||||
|                         "Bad domain name (circular) at %s" % (off,)) | ||||
|                 first = off | ||||
|             else: | ||||
|                 raise IncomingDecodeError("Bad domain name at %s" % (off,)) | ||||
|  | ||||
|         if next_ >= 0: | ||||
|             self.offset = next_ | ||||
|         else: | ||||
|             self.offset = off | ||||
|  | ||||
|         return result | ||||
|  | ||||
|  | ||||
| class DNSOutgoing(object): | ||||
|     """Object representation of an outgoing packet""" | ||||
|  | ||||
|     def __init__(self, flags): | ||||
|         self.finished = False | ||||
|         self.id = 0 | ||||
|         self.flags = flags | ||||
|         self.names = {} | ||||
|         self.data = [] | ||||
|         self.size = 12 | ||||
|         self.state = False | ||||
|  | ||||
|         self.questions = [] | ||||
|         self.answers = [] | ||||
|  | ||||
|     def add_question(self, record): | ||||
|         """Adds a question""" | ||||
|         self.questions.append(record) | ||||
|  | ||||
|     def pack(self, format_, value): | ||||
|         self.data.append(struct.pack(format_, value)) | ||||
|         self.size += struct.calcsize(format_) | ||||
|  | ||||
|     def write_byte(self, value): | ||||
|         """Writes a single byte to the packet""" | ||||
|         self.pack(b'!c', int2byte(value)) | ||||
|  | ||||
|     def insert_short(self, index, value): | ||||
|         """Inserts an unsigned short in a certain position in the packet""" | ||||
|         self.data.insert(index, struct.pack(b'!H', value)) | ||||
|         self.size += 2 | ||||
|  | ||||
|     def write_short(self, value): | ||||
|         """Writes an unsigned short to the packet""" | ||||
|         self.pack(b'!H', value) | ||||
|  | ||||
|     def write_int(self, value): | ||||
|         """Writes an unsigned integer to the packet""" | ||||
|         self.pack(b'!I', int(value)) | ||||
|  | ||||
|     def write_string(self, value): | ||||
|         """Writes a string to the packet""" | ||||
|         assert isinstance(value, bytes) | ||||
|         self.data.append(value) | ||||
|         self.size += len(value) | ||||
|  | ||||
|     def write_utf(self, s): | ||||
|         """Writes a UTF-8 string of a given length to the packet""" | ||||
|         utfstr = s.encode('utf-8') | ||||
|         length = len(utfstr) | ||||
|         self.write_byte(length) | ||||
|         self.write_string(utfstr) | ||||
|  | ||||
|     def write_character_string(self, value): | ||||
|         assert isinstance(value, bytes) | ||||
|         length = len(value) | ||||
|         self.write_byte(length) | ||||
|         self.write_string(value) | ||||
|  | ||||
|     def write_name(self, name): | ||||
|         # split name into each label | ||||
|         parts = name.split('.') | ||||
|         if not parts[-1]: | ||||
|             parts.pop() | ||||
|  | ||||
|         # construct each suffix | ||||
|         name_suffices = ['.'.join(parts[i:]) for i in range(len(parts))] | ||||
|  | ||||
|         # look for an existing name or suffix | ||||
|         for count, sub_name in enumerate(name_suffices): | ||||
|             if sub_name in self.names: | ||||
|                 break | ||||
|         else: | ||||
|             count = len(name_suffices) | ||||
|  | ||||
|         # note the new names we are saving into the packet | ||||
|         name_length = len(name.encode('utf-8')) | ||||
|         for suffix in name_suffices[:count]: | ||||
|             self.names[suffix] = self.size + name_length - len(suffix.encode('utf-8')) - 1 | ||||
|  | ||||
|         # write the new names out. | ||||
|         for part in parts[:count]: | ||||
|             self.write_utf(part) | ||||
|  | ||||
|         # if we wrote part of the name, create a pointer to the rest | ||||
|         if count != len(name_suffices): | ||||
|             # Found substring in packet, create pointer | ||||
|             index = self.names[name_suffices[count]] | ||||
|             self.write_byte((index >> 8) | 0xC0) | ||||
|             self.write_byte(index & 0xFF) | ||||
|         else: | ||||
|             # this is the end of a name | ||||
|             self.write_byte(0) | ||||
|  | ||||
|     def write_question(self, question): | ||||
|         self.write_name(question.name) | ||||
|         self.write_short(question.type) | ||||
|         self.write_short(question.class_) | ||||
|  | ||||
|     def packet(self): | ||||
|         if not self.state: | ||||
|             for question in self.questions: | ||||
|                 self.write_question(question) | ||||
|             self.state = True | ||||
|  | ||||
|             self.insert_short(0, 0)  # num additionals | ||||
|             self.insert_short(0, 0)  # num authorities | ||||
|             self.insert_short(0, 0)  # num answers | ||||
|             self.insert_short(0, len(self.questions)) | ||||
|             self.insert_short(0, self.flags)  # _FLAGS_QR_QUERY | ||||
|             self.insert_short(0, 0) | ||||
|         return b''.join(self.data) | ||||
|  | ||||
|  | ||||
| class Engine(threading.Thread): | ||||
|     def __init__(self, zc): | ||||
|         threading.Thread.__init__(self, name='zeroconf-Engine') | ||||
|         self.daemon = True | ||||
|         self.zc = zc | ||||
|         self.readers = {} | ||||
|         self.timeout = 5 | ||||
|         self.condition = threading.Condition() | ||||
|         self.start() | ||||
|  | ||||
|     def run(self): | ||||
|         while not self.zc.done: | ||||
|             # pylint: disable=len-as-condition | ||||
|             with self.condition: | ||||
|                 rs = self.readers.keys() | ||||
|                 if len(rs) == 0: | ||||
|                     # No sockets to manage, but we wait for the timeout | ||||
|                     # or addition of a socket | ||||
|                     self.condition.wait(self.timeout) | ||||
|  | ||||
|             if len(rs) != 0: | ||||
|                 try: | ||||
|                     rr, _, _ = select.select(rs, [], [], self.timeout) | ||||
|                     if not self.zc.done: | ||||
|                         for socket_ in rr: | ||||
|                             reader = self.readers.get(socket_) | ||||
|                             if reader: | ||||
|                                 reader.handle_read(socket_) | ||||
|  | ||||
|                 except (select.error, socket.error) as e: | ||||
|                     # If the socket was closed by another thread, during | ||||
|                     # shutdown, ignore it and exit | ||||
|                     if e.args[0] != socket.EBADF or not self.zc.done: | ||||
|                         raise | ||||
|  | ||||
|     def add_reader(self, reader, socket_): | ||||
|         with self.condition: | ||||
|             self.readers[socket_] = reader | ||||
|             self.condition.notify() | ||||
|  | ||||
|     def del_reader(self, socket_): | ||||
|         with self.condition: | ||||
|             del self.readers[socket_] | ||||
|             self.condition.notify() | ||||
|  | ||||
|  | ||||
| class Listener(QuietLogger): | ||||
|     def __init__(self, zc): | ||||
|         self.zc = zc | ||||
|         self.data = None | ||||
|  | ||||
|     def handle_read(self, socket_): | ||||
|         try: | ||||
|             data, (addr, port) = socket_.recvfrom(_MAX_MSG_ABSOLUTE) | ||||
|         except Exception:  # pylint: disable=broad-except | ||||
|             self.log_exception_warning() | ||||
|             return | ||||
|  | ||||
|         log.debug('Received from %r:%r: %r ', addr, port, data) | ||||
|  | ||||
|         self.data = data | ||||
|         msg = DNSIncoming(data) | ||||
|         if not msg.valid or msg.is_query(): | ||||
|             pass | ||||
|         else: | ||||
|             self.zc.handle_response(msg) | ||||
|  | ||||
|  | ||||
| class RecordUpdateListener(object): | ||||
|     def update_record(self, zc, now, record): | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|  | ||||
| class HostResolver(RecordUpdateListener): | ||||
|     def __init__(self, name): | ||||
|         self.name = name | ||||
|         self.address = None | ||||
|  | ||||
|     def update_record(self, zc, now, record): | ||||
|         if record is None: | ||||
|             return | ||||
|         if record.type == _TYPE_A: | ||||
|             assert isinstance(record, DNSAddress) | ||||
|             if record.name == self.name: | ||||
|                 self.address = record.address | ||||
|  | ||||
|     def request(self, zc, timeout): | ||||
|         now = time.time() | ||||
|         delay = 0.2 | ||||
|         next_ = now + delay | ||||
|         last = now + timeout | ||||
|  | ||||
|         try: | ||||
|             zc.add_listener(self) | ||||
|             while self.address is None: | ||||
|                 if last <= now: | ||||
|                     # Timeout | ||||
|                     return False | ||||
|                 if next_ <= now: | ||||
|                     out = DNSOutgoing(_FLAGS_QR_QUERY) | ||||
|                     out.add_question( | ||||
|                         DNSQuestion(self.name, _TYPE_A, _CLASS_IN)) | ||||
|                     zc.send(out) | ||||
|                     next_ = now + delay | ||||
|                     delay *= 2 | ||||
|  | ||||
|                 zc.wait(min(next_, last) - now) | ||||
|                 now = time.time() | ||||
|         finally: | ||||
|             zc.remove_listener(self) | ||||
|  | ||||
|         return True | ||||
|  | ||||
|  | ||||
| class DashboardStatus(RecordUpdateListener, threading.Thread): | ||||
|     def __init__(self, zc, on_update): | ||||
|         threading.Thread.__init__(self) | ||||
|         self.zc = zc | ||||
|         self.query_hosts = set() | ||||
|         self.key_to_host = {} | ||||
|         self.cache = {} | ||||
|         self.stop_event = threading.Event() | ||||
|         self.query_event = threading.Event() | ||||
|         self.on_update = on_update | ||||
|  | ||||
|     def update_record(self, zc, now, record): | ||||
|         if record is None: | ||||
|             return | ||||
|         if record.type in (_TYPE_A, _TYPE_AAAA, _TYPE_TXT): | ||||
|             assert isinstance(record, DNSEntry) | ||||
|             if record.name in self.query_hosts: | ||||
|                 self.cache.setdefault(record.name, []).insert(0, record) | ||||
|             self.purge_cache() | ||||
|  | ||||
|     def purge_cache(self): | ||||
|         new_cache = {} | ||||
|         for host, records in self.cache.items(): | ||||
|             if host not in self.query_hosts: | ||||
|                 continue | ||||
|             new_records = [rec for rec in records if not rec.is_removable(time.time())] | ||||
|             if new_records: | ||||
|                 new_cache[host] = new_records | ||||
|         self.cache = new_cache | ||||
|         self.on_update({key: self.host_status(key) for key in self.key_to_host}) | ||||
|  | ||||
|     def request_query(self, hosts): | ||||
|         self.query_hosts = set(host for host in hosts.values()) | ||||
|         self.key_to_host = hosts | ||||
|         self.query_event.set() | ||||
|  | ||||
|     def stop(self): | ||||
|         self.stop_event.set() | ||||
|         self.query_event.set() | ||||
|  | ||||
|     def host_status(self, key): | ||||
|         return self.key_to_host.get(key) in self.cache | ||||
|  | ||||
|     def run(self): | ||||
|         self.zc.add_listener(self) | ||||
|         while not self.stop_event.is_set(): | ||||
|             self.purge_cache() | ||||
|             for host in self.query_hosts: | ||||
|                 if all(record.is_expired(time.time()) for record in self.cache.get(host, [])): | ||||
|                     out = DNSOutgoing(_FLAGS_QR_QUERY) | ||||
|                     out.add_question( | ||||
|                         DNSQuestion(host, _TYPE_A, _CLASS_IN)) | ||||
|                     self.zc.send(out) | ||||
|             self.query_event.wait() | ||||
|             self.query_event.clear() | ||||
|         self.zc.remove_listener(self) | ||||
|  | ||||
|  | ||||
| def get_all_addresses(): | ||||
|     return list(set( | ||||
|         addr.ip | ||||
|         for iface in ifaddr.get_adapters() | ||||
|         for addr in iface.ips | ||||
|         if addr.is_IPv4 and addr.network_prefix != 32  # Host only netmask 255.255.255.255 | ||||
|     )) | ||||
|  | ||||
|  | ||||
| def new_socket(): | ||||
|     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) | ||||
|     s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) | ||||
|  | ||||
|     # SO_REUSEADDR should be equivalent to SO_REUSEPORT for | ||||
|     # multicast UDP sockets (p 731, "TCP/IP Illustrated, | ||||
|     # Volume 2"), but some BSD-derived systems require | ||||
|     # SO_REUSEPORT to be specified explicitly.  Also, not all | ||||
|     # versions of Python have SO_REUSEPORT available. | ||||
|     # Catch OSError and socket.error for kernel versions <3.9 because lacking | ||||
|     # SO_REUSEPORT support. | ||||
|     try: | ||||
|         reuseport = socket.SO_REUSEPORT | ||||
|     except AttributeError: | ||||
|         pass | ||||
|     else: | ||||
|         try: | ||||
|             s.setsockopt(socket.SOL_SOCKET, reuseport, 1) | ||||
|         except (OSError, socket.error) as err: | ||||
|             # OSError on python 3, socket.error on python 2 | ||||
|             if err.errno != errno.ENOPROTOOPT: | ||||
|                 raise | ||||
|  | ||||
|     # OpenBSD needs the ttl and loop values for the IP_MULTICAST_TTL and | ||||
|     # IP_MULTICAST_LOOP socket options as an unsigned char. | ||||
|     ttl = struct.pack(b'B', 255) | ||||
|     s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl) | ||||
|     loop = struct.pack(b'B', 1) | ||||
|     s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, loop) | ||||
|  | ||||
|     s.bind(('', _MDNS_PORT)) | ||||
|     return s | ||||
|  | ||||
|  | ||||
| class Zeroconf(QuietLogger): | ||||
|     def __init__(self): | ||||
|         # hook for threads | ||||
|         self._GLOBAL_DONE = False | ||||
|  | ||||
|         self._listen_socket = new_socket() | ||||
|         interfaces = get_all_addresses() | ||||
|  | ||||
|         self._respond_sockets = [] | ||||
|  | ||||
|         for i in interfaces: | ||||
|             try: | ||||
|                 _value = socket.inet_aton(_MDNS_ADDR) + socket.inet_aton(i) | ||||
|                 self._listen_socket.setsockopt( | ||||
|                     socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, _value) | ||||
|             except socket.error as e: | ||||
|                 _errno = e.args[0] | ||||
|                 if _errno == errno.EADDRINUSE: | ||||
|                     log.info( | ||||
|                         'Address in use when adding %s to multicast group, ' | ||||
|                         'it is expected to happen on some systems', i, | ||||
|                     ) | ||||
|                 elif _errno == errno.EADDRNOTAVAIL: | ||||
|                     log.info( | ||||
|                         'Address not available when adding %s to multicast ' | ||||
|                         'group, it is expected to happen on some systems', i, | ||||
|                     ) | ||||
|                     continue | ||||
|                 elif _errno == errno.EINVAL: | ||||
|                     log.info( | ||||
|                         'Interface of %s does not support multicast, ' | ||||
|                         'it is expected in WSL', i | ||||
|                     ) | ||||
|                     continue | ||||
|  | ||||
|                 else: | ||||
|                     raise | ||||
|  | ||||
|             respond_socket = new_socket() | ||||
|             respond_socket.setsockopt( | ||||
|                 socket.IPPROTO_IP, socket.IP_MULTICAST_IF, socket.inet_aton(i)) | ||||
|  | ||||
|             self._respond_sockets.append(respond_socket) | ||||
|  | ||||
|         self.listeners = [] | ||||
|  | ||||
|         self.condition = threading.Condition() | ||||
|  | ||||
|         self.engine = Engine(self) | ||||
|         self.listener = Listener(self) | ||||
|         self.engine.add_reader(self.listener, self._listen_socket) | ||||
|  | ||||
|     @property | ||||
|     def done(self): | ||||
|         return self._GLOBAL_DONE | ||||
|  | ||||
|     def wait(self, timeout): | ||||
|         """Calling thread waits for a given number of milliseconds or | ||||
|         until notified.""" | ||||
|         with self.condition: | ||||
|             self.condition.wait(timeout) | ||||
|  | ||||
|     def notify_all(self): | ||||
|         """Notifies all waiting threads""" | ||||
|         with self.condition: | ||||
|             self.condition.notify_all() | ||||
|  | ||||
|     def resolve_host(self, host, timeout=3.0): | ||||
|         info = HostResolver(host) | ||||
|         if info.request(self, timeout): | ||||
|             return socket.inet_ntoa(info.address) | ||||
|         return None | ||||
|  | ||||
|     def add_listener(self, listener): | ||||
|         self.listeners.append(listener) | ||||
|         self.notify_all() | ||||
|  | ||||
|     def remove_listener(self, listener): | ||||
|         """Removes a listener.""" | ||||
|         try: | ||||
|             self.listeners.remove(listener) | ||||
|             self.notify_all() | ||||
|         except Exception as e:  # pylint: disable=broad-except | ||||
|             log.exception('Unknown error, possibly benign: %r', e) | ||||
|  | ||||
|     def update_record(self, now, rec): | ||||
|         """Used to notify listeners of new information that has updated | ||||
|         a record.""" | ||||
|         for listener in self.listeners: | ||||
|             listener.update_record(self, now, rec) | ||||
|         self.notify_all() | ||||
|  | ||||
|     def handle_response(self, msg): | ||||
|         """Deal with incoming response packets.  All answers | ||||
|         are held in the cache, and listeners are notified.""" | ||||
|         now = time.time() | ||||
|         for record in msg.answers: | ||||
|             self.update_record(now, record) | ||||
|  | ||||
|     def send(self, out): | ||||
|         """Sends an outgoing packet.""" | ||||
|         packet = out.packet() | ||||
|         log.debug('Sending %r (%d bytes) as %r...', out, len(packet), packet) | ||||
|         for s in self._respond_sockets: | ||||
|             if self._GLOBAL_DONE: | ||||
|                 return | ||||
|             try: | ||||
|                 bytes_sent = s.sendto(packet, 0, (_MDNS_ADDR, _MDNS_PORT)) | ||||
|             except Exception:  # pylint: disable=broad-except | ||||
|                 # on send errors, log the exception and keep going | ||||
|                 self.log_exception_warning() | ||||
|             else: | ||||
|                 if bytes_sent != len(packet): | ||||
|                     self.log_warning_once( | ||||
|                         '!!! sent %d out of %d bytes to %r' % ( | ||||
|                             bytes_sent, len(packet), s)) | ||||
|  | ||||
|     def close(self): | ||||
|         """Ends the background threads, and prevent this instance from | ||||
|         servicing further queries.""" | ||||
|         if not self._GLOBAL_DONE: | ||||
|             self._GLOBAL_DONE = True | ||||
|             # shutdown recv socket and thread | ||||
|             self.engine.del_reader(self._listen_socket) | ||||
|             self._listen_socket.close() | ||||
|             self.engine.join() | ||||
|  | ||||
|             # shutdown the rest | ||||
|             self.notify_all() | ||||
|             for s in self._respond_sockets: | ||||
|                 s.close() | ||||
		Reference in New Issue
	
	Block a user