1
0
mirror of https://github.com/esphome/esphome.git synced 2025-01-18 20:10:55 +00:00

Lint the script folder files (#5991)

This commit is contained in:
Jesse Hills 2023-12-22 20:03:47 +13:00 committed by GitHub
parent 676ae6b26e
commit d2d0058386
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 398 additions and 377 deletions

View File

@ -17,28 +17,22 @@ then run this script with python3 and the files
will be generated, they still need to be formatted will be generated, they still need to be formatted
""" """
import re
import os import os
import re
import sys
from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from textwrap import dedent
from subprocess import call from subprocess import call
from textwrap import dedent
# Generate with # Generate with
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto # protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
import aioesphomeapi.api_options_pb2 as pb import aioesphomeapi.api_options_pb2 as pb
import google.protobuf.descriptor_pb2 as descriptor import google.protobuf.descriptor_pb2 as descriptor
file_header = "// This file was automatically generated with a tool.\n" FILE_HEADER = """// This file was automatically generated with a tool.
file_header += "// See scripts/api_protobuf/api_protobuf.py\n" // See scripts/api_protobuf/api_protobuf.py
"""
cwd = Path(__file__).resolve().parent
root = cwd.parent.parent / "esphome" / "components" / "api"
prot = root / "api.protoc"
call(["protoc", "-o", str(prot), "-I", str(root), "api.proto"])
content = prot.read_bytes()
d = descriptor.FileDescriptorSet.FromString(content)
def indent_list(text, padding=" "): def indent_list(text, padding=" "):
@ -64,7 +58,7 @@ def camel_to_snake(name):
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
class TypeInfo: class TypeInfo(ABC):
def __init__(self, field): def __init__(self, field):
self._field = field self._field = field
@ -186,10 +180,12 @@ class TypeInfo:
def dump_content(self): def dump_content(self):
o = f'out.append(" {self.name}: ");\n' o = f'out.append(" {self.name}: ");\n'
o += self.dump(f"this->{self.field_name}") + "\n" o += self.dump(f"this->{self.field_name}") + "\n"
o += f'out.append("\\n");\n' o += 'out.append("\\n");\n'
return o return o
dump = None @abstractmethod
def dump(self, name: str):
pass
TYPE_INFO = {} TYPE_INFO = {}
@ -212,7 +208,7 @@ class DoubleType(TypeInfo):
def dump(self, name): def dump(self, name):
o = f'sprintf(buffer, "%g", {name});\n' o = f'sprintf(buffer, "%g", {name});\n'
o += f"out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -225,7 +221,7 @@ class FloatType(TypeInfo):
def dump(self, name): def dump(self, name):
o = f'sprintf(buffer, "%g", {name});\n' o = f'sprintf(buffer, "%g", {name});\n'
o += f"out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -238,7 +234,7 @@ class Int64Type(TypeInfo):
def dump(self, name): def dump(self, name):
o = f'sprintf(buffer, "%lld", {name});\n' o = f'sprintf(buffer, "%lld", {name});\n'
o += f"out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -251,7 +247,7 @@ class UInt64Type(TypeInfo):
def dump(self, name): def dump(self, name):
o = f'sprintf(buffer, "%llu", {name});\n' o = f'sprintf(buffer, "%llu", {name});\n'
o += f"out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -264,7 +260,7 @@ class Int32Type(TypeInfo):
def dump(self, name): def dump(self, name):
o = f'sprintf(buffer, "%" PRId32, {name});\n' o = f'sprintf(buffer, "%" PRId32, {name});\n'
o += f"out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -277,7 +273,7 @@ class Fixed64Type(TypeInfo):
def dump(self, name): def dump(self, name):
o = f'sprintf(buffer, "%llu", {name});\n' o = f'sprintf(buffer, "%llu", {name});\n'
o += f"out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -290,7 +286,7 @@ class Fixed32Type(TypeInfo):
def dump(self, name): def dump(self, name):
o = f'sprintf(buffer, "%" PRIu32, {name});\n' o = f'sprintf(buffer, "%" PRIu32, {name});\n'
o += f"out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -372,7 +368,7 @@ class UInt32Type(TypeInfo):
def dump(self, name): def dump(self, name):
o = f'sprintf(buffer, "%" PRIu32, {name});\n' o = f'sprintf(buffer, "%" PRIu32, {name});\n'
o += f"out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -406,7 +402,7 @@ class SFixed32Type(TypeInfo):
def dump(self, name): def dump(self, name):
o = f'sprintf(buffer, "%" PRId32, {name});\n' o = f'sprintf(buffer, "%" PRId32, {name});\n'
o += f"out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -419,7 +415,7 @@ class SFixed64Type(TypeInfo):
def dump(self, name): def dump(self, name):
o = f'sprintf(buffer, "%lld", {name});\n' o = f'sprintf(buffer, "%lld", {name});\n'
o += f"out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -432,7 +428,7 @@ class SInt32Type(TypeInfo):
def dump(self, name): def dump(self, name):
o = f'sprintf(buffer, "%" PRId32, {name});\n' o = f'sprintf(buffer, "%" PRId32, {name});\n'
o += f"out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -445,7 +441,7 @@ class SInt64Type(TypeInfo):
def dump(self, name): def dump(self, name):
o = f'sprintf(buffer, "%lld", {name});\n' o = f'sprintf(buffer, "%lld", {name});\n'
o += f"out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -527,7 +523,7 @@ class RepeatedTypeInfo(TypeInfo):
def encode_content(self): def encode_content(self):
o = f"for (auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n" o = f"for (auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n"
o += f" buffer.{self._ti.encode_func}({self.number}, it, true);\n" o += f" buffer.{self._ti.encode_func}({self.number}, it, true);\n"
o += f"}}" o += "}"
return o return o
@property @property
@ -535,10 +531,13 @@ class RepeatedTypeInfo(TypeInfo):
o = f'for (const auto {"" if self._ti_is_bool else "&"}it : this->{self.field_name}) {{\n' o = f'for (const auto {"" if self._ti_is_bool else "&"}it : this->{self.field_name}) {{\n'
o += f' out.append(" {self.name}: ");\n' o += f' out.append(" {self.name}: ");\n'
o += indent(self._ti.dump("it")) + "\n" o += indent(self._ti.dump("it")) + "\n"
o += f' out.append("\\n");\n' o += ' out.append("\\n");\n'
o += f"}}\n" o += "}\n"
return o return o
def dump(self, _: str):
pass
def build_enum_type(desc): def build_enum_type(desc):
name = desc.name name = desc.name
@ -547,17 +546,17 @@ def build_enum_type(desc):
out += f" {v.name} = {v.number},\n" out += f" {v.name} = {v.number},\n"
out += "};\n" out += "};\n"
cpp = f"#ifdef HAS_PROTO_MESSAGE_DUMP\n" cpp = "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
cpp += f"template<> const char *proto_enum_to_string<enums::{name}>(enums::{name} value) {{\n" cpp += f"template<> const char *proto_enum_to_string<enums::{name}>(enums::{name} value) {{\n"
cpp += f" switch (value) {{\n" cpp += " switch (value) {\n"
for v in desc.value: for v in desc.value:
cpp += f" case enums::{v.name}:\n" cpp += f" case enums::{v.name}:\n"
cpp += f' return "{v.name}";\n' cpp += f' return "{v.name}";\n'
cpp += f" default:\n" cpp += " default:\n"
cpp += f' return "UNKNOWN";\n' cpp += ' return "UNKNOWN";\n'
cpp += f" }}\n" cpp += " }\n"
cpp += f"}}\n" cpp += "}\n"
cpp += f"#endif\n" cpp += "#endif\n"
return out, cpp return out, cpp
@ -652,10 +651,10 @@ def build_message_type(desc):
o += f" {dump[0]} " o += f" {dump[0]} "
else: else:
o += "\n" o += "\n"
o += f" __attribute__((unused)) char buffer[64];\n" o += " __attribute__((unused)) char buffer[64];\n"
o += f' out.append("{desc.name} {{\\n");\n' o += f' out.append("{desc.name} {{\\n");\n'
o += indent("\n".join(dump)) + "\n" o += indent("\n".join(dump)) + "\n"
o += f' out.append("}}");\n' o += ' out.append("}");\n'
else: else:
o2 = f'out.append("{desc.name} {{}}");' o2 = f'out.append("{desc.name} {{}}");'
if len(o) + len(o2) + 3 < 120: if len(o) + len(o2) + 3 < 120:
@ -664,9 +663,9 @@ def build_message_type(desc):
o += "\n" o += "\n"
o += f" {o2}\n" o += f" {o2}\n"
o += "}\n" o += "}\n"
cpp += f"#ifdef HAS_PROTO_MESSAGE_DUMP\n" cpp += "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
cpp += o cpp += o
cpp += f"#endif\n" cpp += "#endif\n"
prot = "#ifdef HAS_PROTO_MESSAGE_DUMP\n" prot = "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
prot += "void dump_to(std::string &out) const override;\n" prot += "void dump_to(std::string &out) const override;\n"
prot += "#endif\n" prot += "#endif\n"
@ -684,71 +683,12 @@ def build_message_type(desc):
return out, cpp return out, cpp
file = d.file[0]
content = file_header
content += """\
#pragma once
#include "proto.h"
namespace esphome {
namespace api {
"""
cpp = file_header
cpp += """\
#include "api_pb2.h"
#include "esphome/core/log.h"
#include <cinttypes>
namespace esphome {
namespace api {
"""
content += "namespace enums {\n\n"
for enum in file.enum_type:
s, c = build_enum_type(enum)
content += s
cpp += c
content += "\n} // namespace enums\n\n"
mt = file.message_type
for m in mt:
s, c = build_message_type(m)
content += s
cpp += c
content += """\
} // namespace api
} // namespace esphome
"""
cpp += """\
} // namespace api
} // namespace esphome
"""
with open(root / "api_pb2.h", "w") as f:
f.write(content)
with open(root / "api_pb2.cpp", "w") as f:
f.write(cpp)
SOURCE_BOTH = 0 SOURCE_BOTH = 0
SOURCE_SERVER = 1 SOURCE_SERVER = 1
SOURCE_CLIENT = 2 SOURCE_CLIENT = 2
RECEIVE_CASES = {} RECEIVE_CASES = {}
class_name = "APIServerConnectionBase"
ifdefs = {} ifdefs = {}
@ -768,7 +708,6 @@ def build_service_message_type(mt):
ifdef = get_opt(mt, pb.ifdef) ifdef = get_opt(mt, pb.ifdef)
log = get_opt(mt, pb.log, True) log = get_opt(mt, pb.log, True)
nodelay = get_opt(mt, pb.no_delay, False)
hout = "" hout = ""
cout = "" cout = ""
@ -781,14 +720,14 @@ def build_service_message_type(mt):
# Generate send # Generate send
func = f"send_{snake}" func = f"send_{snake}"
hout += f"bool {func}(const {mt.name} &msg);\n" hout += f"bool {func}(const {mt.name} &msg);\n"
cout += f"bool {class_name}::{func}(const {mt.name} &msg) {{\n" cout += f"bool APIServerConnectionBase::{func}(const {mt.name} &msg) {{\n"
if log: if log:
cout += f"#ifdef HAS_PROTO_MESSAGE_DUMP\n" cout += "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
cout += f' ESP_LOGVV(TAG, "{func}: %s", msg.dump().c_str());\n' cout += f' ESP_LOGVV(TAG, "{func}: %s", msg.dump().c_str());\n'
cout += f"#endif\n" cout += "#endif\n"
# cout += f' this->set_nodelay({str(nodelay).lower()});\n' # cout += f' this->set_nodelay({str(nodelay).lower()});\n'
cout += f" return this->send_message_<{mt.name}>(msg, {id_});\n" cout += f" return this->send_message_<{mt.name}>(msg, {id_});\n"
cout += f"}}\n" cout += "}\n"
if source in (SOURCE_BOTH, SOURCE_CLIENT): if source in (SOURCE_BOTH, SOURCE_CLIENT):
# Generate receive # Generate receive
func = f"on_{snake}" func = f"on_{snake}"
@ -797,169 +736,242 @@ def build_service_message_type(mt):
if ifdef is not None: if ifdef is not None:
case += f"#ifdef {ifdef}\n" case += f"#ifdef {ifdef}\n"
case += f"{mt.name} msg;\n" case += f"{mt.name} msg;\n"
case += f"msg.decode(msg_data, msg_size);\n" case += "msg.decode(msg_data, msg_size);\n"
if log: if log:
case += f"#ifdef HAS_PROTO_MESSAGE_DUMP\n" case += "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
case += f'ESP_LOGVV(TAG, "{func}: %s", msg.dump().c_str());\n' case += f'ESP_LOGVV(TAG, "{func}: %s", msg.dump().c_str());\n'
case += f"#endif\n" case += "#endif\n"
case += f"this->{func}(msg);\n" case += f"this->{func}(msg);\n"
if ifdef is not None: if ifdef is not None:
case += f"#endif\n" case += "#endif\n"
case += "break;" case += "break;"
RECEIVE_CASES[id_] = case RECEIVE_CASES[id_] = case
if ifdef is not None: if ifdef is not None:
hout += f"#endif\n" hout += "#endif\n"
cout += f"#endif\n" cout += "#endif\n"
return hout, cout return hout, cout
hpp = file_header def main():
hpp += """\ cwd = Path(__file__).resolve().parent
#pragma once root = cwd.parent.parent / "esphome" / "components" / "api"
prot_file = root / "api.protoc"
call(["protoc", "-o", str(prot_file), "-I", str(root), "api.proto"])
proto_content = prot_file.read_bytes()
#include "api_pb2.h" # pylint: disable-next=no-member
#include "esphome/core/defines.h" d = descriptor.FileDescriptorSet.FromString(proto_content)
namespace esphome { file = d.file[0]
namespace api { content = FILE_HEADER
content += """\
#pragma once
""" #include "proto.h"
cpp = file_header namespace esphome {
cpp += """\ namespace api {
#include "api_pb2_service.h"
#include "esphome/core/log.h"
namespace esphome { """
namespace api {
static const char *const TAG = "api.service"; cpp = FILE_HEADER
cpp += """\
#include "api_pb2.h"
#include "esphome/core/log.h"
""" #include <cinttypes>
hpp += f"class {class_name} : public ProtoService {{\n" namespace esphome {
hpp += " public:\n" namespace api {
for mt in file.message_type: """
obj = build_service_message_type(mt)
if obj is None:
continue
hout, cout = obj
hpp += indent(hout) + "\n"
cpp += cout
cases = list(RECEIVE_CASES.items()) content += "namespace enums {\n\n"
cases.sort()
hpp += " protected:\n"
hpp += f" bool read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) override;\n"
out = f"bool {class_name}::read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) {{\n"
out += f" switch (msg_type) {{\n"
for i, case in cases:
c = f"case {i}: {{\n"
c += indent(case) + "\n"
c += f"}}"
out += indent(c, " ") + "\n"
out += " default:\n"
out += " return false;\n"
out += " }\n"
out += " return true;\n"
out += "}\n"
cpp += out
hpp += "};\n"
serv = file.service[0] for enum in file.enum_type:
class_name = "APIServerConnection" s, c = build_enum_type(enum)
hpp += "\n" content += s
hpp += f"class {class_name} : public {class_name}Base {{\n" cpp += c
hpp += " public:\n"
hpp_protected = ""
cpp += "\n"
m = serv.method[0] content += "\n} // namespace enums\n\n"
for m in serv.method:
func = m.name
inp = m.input_type[1:]
ret = m.output_type[1:]
is_void = ret == "void"
snake = camel_to_snake(inp)
on_func = f"on_{snake}"
needs_conn = get_opt(m, pb.needs_setup_connection, True)
needs_auth = get_opt(m, pb.needs_authentication, True)
ifdef = ifdefs.get(inp, None) mt = file.message_type
if ifdef is not None: for m in mt:
hpp += f"#ifdef {ifdef}\n" s, c = build_message_type(m)
hpp_protected += f"#ifdef {ifdef}\n" content += s
cpp += f"#ifdef {ifdef}\n" cpp += c
hpp_protected += f" void {on_func}(const {inp} &msg) override;\n" content += """\
hpp += f" virtual {ret} {func}(const {inp} &msg) = 0;\n"
cpp += f"void {class_name}::{on_func}(const {inp} &msg) {{\n"
body = ""
if needs_conn:
body += "if (!this->is_connection_setup()) {\n"
body += " this->on_no_setup_connection();\n"
body += " return;\n"
body += "}\n"
if needs_auth:
body += "if (!this->is_authenticated()) {\n"
body += " this->on_unauthenticated_access();\n"
body += " return;\n"
body += "}\n"
if is_void: } // namespace api
body += f"this->{func}(msg);\n" } // namespace esphome
else: """
body += f"{ret} ret = this->{func}(msg);\n" cpp += """\
ret_snake = camel_to_snake(ret)
body += f"if (!this->send_{ret_snake}(ret)) {{\n"
body += f" this->on_fatal_error();\n"
body += "}\n"
cpp += indent(body) + "\n" + "}\n"
if ifdef is not None: } // namespace api
hpp += f"#endif\n" } // namespace esphome
hpp_protected += f"#endif\n" """
cpp += f"#endif\n"
hpp += " protected:\n" with open(root / "api_pb2.h", "w", encoding="utf-8") as f:
hpp += hpp_protected f.write(content)
hpp += "};\n"
hpp += """\ with open(root / "api_pb2.cpp", "w", encoding="utf-8") as f:
f.write(cpp)
} // namespace api hpp = FILE_HEADER
} // namespace esphome hpp += """\
""" #pragma once
cpp += """\
} // namespace api #include "api_pb2.h"
} // namespace esphome #include "esphome/core/defines.h"
"""
with open(root / "api_pb2_service.h", "w") as f: namespace esphome {
f.write(hpp) namespace api {
with open(root / "api_pb2_service.cpp", "w") as f: """
f.write(cpp)
prot.unlink() cpp = FILE_HEADER
cpp += """\
#include "api_pb2_service.h"
#include "esphome/core/log.h"
try: namespace esphome {
import clang_format namespace api {
def exec_clang_format(path): static const char *const TAG = "api.service";
clang_format_path = os.path.join(
os.path.dirname(clang_format.__file__), "data", "bin", "clang-format"
)
call([clang_format_path, "-i", path])
exec_clang_format(root / "api_pb2_service.h") """
exec_clang_format(root / "api_pb2_service.cpp")
exec_clang_format(root / "api_pb2.h") class_name = "APIServerConnectionBase"
exec_clang_format(root / "api_pb2.cpp")
except ImportError: hpp += f"class {class_name} : public ProtoService {{\n"
pass hpp += " public:\n"
for mt in file.message_type:
obj = build_service_message_type(mt)
if obj is None:
continue
hout, cout = obj
hpp += indent(hout) + "\n"
cpp += cout
cases = list(RECEIVE_CASES.items())
cases.sort()
hpp += " protected:\n"
hpp += " bool read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) override;\n"
out = f"bool {class_name}::read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) {{\n"
out += " switch (msg_type) {\n"
for i, case in cases:
c = f"case {i}: {{\n"
c += indent(case) + "\n"
c += "}"
out += indent(c, " ") + "\n"
out += " default:\n"
out += " return false;\n"
out += " }\n"
out += " return true;\n"
out += "}\n"
cpp += out
hpp += "};\n"
serv = file.service[0]
class_name = "APIServerConnection"
hpp += "\n"
hpp += f"class {class_name} : public {class_name}Base {{\n"
hpp += " public:\n"
hpp_protected = ""
cpp += "\n"
m = serv.method[0]
for m in serv.method:
func = m.name
inp = m.input_type[1:]
ret = m.output_type[1:]
is_void = ret == "void"
snake = camel_to_snake(inp)
on_func = f"on_{snake}"
needs_conn = get_opt(m, pb.needs_setup_connection, True)
needs_auth = get_opt(m, pb.needs_authentication, True)
ifdef = ifdefs.get(inp, None)
if ifdef is not None:
hpp += f"#ifdef {ifdef}\n"
hpp_protected += f"#ifdef {ifdef}\n"
cpp += f"#ifdef {ifdef}\n"
hpp_protected += f" void {on_func}(const {inp} &msg) override;\n"
hpp += f" virtual {ret} {func}(const {inp} &msg) = 0;\n"
cpp += f"void {class_name}::{on_func}(const {inp} &msg) {{\n"
body = ""
if needs_conn:
body += "if (!this->is_connection_setup()) {\n"
body += " this->on_no_setup_connection();\n"
body += " return;\n"
body += "}\n"
if needs_auth:
body += "if (!this->is_authenticated()) {\n"
body += " this->on_unauthenticated_access();\n"
body += " return;\n"
body += "}\n"
if is_void:
body += f"this->{func}(msg);\n"
else:
body += f"{ret} ret = this->{func}(msg);\n"
ret_snake = camel_to_snake(ret)
body += f"if (!this->send_{ret_snake}(ret)) {{\n"
body += " this->on_fatal_error();\n"
body += "}\n"
cpp += indent(body) + "\n" + "}\n"
if ifdef is not None:
hpp += "#endif\n"
hpp_protected += "#endif\n"
cpp += "#endif\n"
hpp += " protected:\n"
hpp += hpp_protected
hpp += "};\n"
hpp += """\
} // namespace api
} // namespace esphome
"""
cpp += """\
} // namespace api
} // namespace esphome
"""
with open(root / "api_pb2_service.h", "w", encoding="utf-8") as f:
f.write(hpp)
with open(root / "api_pb2_service.cpp", "w", encoding="utf-8") as f:
f.write(cpp)
prot_file.unlink()
try:
import clang_format
def exec_clang_format(path):
clang_format_path = os.path.join(
os.path.dirname(clang_format.__file__), "data", "bin", "clang-format"
)
call([clang_format_path, "-i", path])
exec_clang_format(root / "api_pb2_service.h")
exec_clang_format(root / "api_pb2_service.cpp")
exec_clang_format(root / "api_pb2.h")
exec_clang_format(root / "api_pb2.cpp")
except ImportError:
pass
if __name__ == "__main__":
sys.exit(main())

View File

@ -61,6 +61,7 @@ solve_registry = []
def get_component_names(): def get_component_names():
# pylint: disable-next=redefined-outer-name,reimported
from esphome.loader import CORE_COMPONENTS_PATH from esphome.loader import CORE_COMPONENTS_PATH
component_names = ["esphome", "sensor", "esp32", "esp8266"] component_names = ["esphome", "sensor", "esp32", "esp8266"]
@ -82,9 +83,12 @@ def load_components():
components[domain] = get_component(domain) components[domain] = get_component(domain)
# pylint: disable=wrong-import-position
from esphome.const import CONF_TYPE, KEY_CORE from esphome.const import CONF_TYPE, KEY_CORE
from esphome.core import CORE from esphome.core import CORE
# pylint: enable=wrong-import-position
CORE.data[KEY_CORE] = {} CORE.data[KEY_CORE] = {}
load_components() load_components()
@ -114,7 +118,7 @@ def write_file(name, obj):
def delete_extra_files(keep_names): def delete_extra_files(keep_names):
for d in os.listdir(args.output_path): for d in os.listdir(args.output_path):
if d.endswith(".json") and not d[:-5] in keep_names: if d.endswith(".json") and d[:-5] not in keep_names:
os.remove(os.path.join(args.output_path, d)) os.remove(os.path.join(args.output_path, d))
print(f"Deleted {d}") print(f"Deleted {d}")
@ -552,11 +556,11 @@ def shrink():
s = f"{domain}.{schema_name}" s = f"{domain}.{schema_name}"
if ( if (
not s.endswith("." + S_CONFIG_SCHEMA) not s.endswith("." + S_CONFIG_SCHEMA)
and s not in referenced_schemas.keys() and s not in referenced_schemas
and not is_platform_schema(s) and not is_platform_schema(s)
): ):
print(f"Removing {s}") print(f"Removing {s}")
output[domain][S_SCHEMAS].pop(schema_name) domain_schemas[S_SCHEMAS].pop(schema_name)
def build_schema(): def build_schema():
@ -564,7 +568,7 @@ def build_schema():
# check esphome was not loaded globally (IDE auto imports) # check esphome was not loaded globally (IDE auto imports)
if len(ejs.extended_schemas) == 0: if len(ejs.extended_schemas) == 0:
raise Exception( raise LookupError(
"no data collected. Did you globally import an ESPHome component?" "no data collected. Did you globally import an ESPHome component?"
) )
@ -703,7 +707,7 @@ def convert(schema, config_var, path):
if schema_instance is schema: if schema_instance is schema:
assert S_CONFIG_VARS not in config_var assert S_CONFIG_VARS not in config_var
assert S_EXTENDS not in config_var assert S_EXTENDS not in config_var
if not S_TYPE in config_var: if S_TYPE not in config_var:
config_var[S_TYPE] = S_SCHEMA config_var[S_TYPE] = S_SCHEMA
# assert config_var[S_TYPE] == S_SCHEMA # assert config_var[S_TYPE] == S_SCHEMA
@ -765,9 +769,9 @@ def convert(schema, config_var, path):
elif schema == automation.validate_potentially_and_condition: elif schema == automation.validate_potentially_and_condition:
config_var[S_TYPE] = "registry" config_var[S_TYPE] = "registry"
config_var["registry"] = "condition" config_var["registry"] = "condition"
elif schema == cv.int_ or schema == cv.int_range: elif schema in (cv.int_, cv.int_range):
config_var[S_TYPE] = "integer" config_var[S_TYPE] = "integer"
elif schema == cv.string or schema == cv.string_strict or schema == cv.valid_name: elif schema in (cv.string, cv.string_strict, cv.valid_name):
config_var[S_TYPE] = "string" config_var[S_TYPE] = "string"
elif isinstance(schema, vol.Schema): elif isinstance(schema, vol.Schema):
@ -779,6 +783,7 @@ def convert(schema, config_var, path):
config_var |= pin_validators[repr_schema] config_var |= pin_validators[repr_schema]
config_var[S_TYPE] = "pin" config_var[S_TYPE] = "pin"
# pylint: disable-next=too-many-nested-blocks
elif repr_schema in ejs.hidden_schemas: elif repr_schema in ejs.hidden_schemas:
schema_type = ejs.hidden_schemas[repr_schema] schema_type = ejs.hidden_schemas[repr_schema]
@ -869,7 +874,7 @@ def convert(schema, config_var, path):
config_var["use_id_type"] = str(data.base) config_var["use_id_type"] = str(data.base)
config_var[S_TYPE] = "use_id" config_var[S_TYPE] = "use_id"
else: else:
raise Exception("Unknown extracted schema type") raise TypeError("Unknown extracted schema type")
elif config_var.get("key") == "GeneratedID": elif config_var.get("key") == "GeneratedID":
if path.startswith("i2c/CONFIG_SCHEMA/") and path.endswith("/id"): if path.startswith("i2c/CONFIG_SCHEMA/") and path.endswith("/id"):
config_var["id_type"] = { config_var["id_type"] = {
@ -884,7 +889,7 @@ def convert(schema, config_var, path):
elif path == "pins/esp32/val 1/id": elif path == "pins/esp32/val 1/id":
config_var["id_type"] = "pin" config_var["id_type"] = "pin"
else: else:
raise Exception("Cannot determine id_type for " + path) raise TypeError("Cannot determine id_type for " + path)
elif repr_schema in ejs.registry_schemas: elif repr_schema in ejs.registry_schemas:
solve_registry.append((ejs.registry_schemas[repr_schema], config_var)) solve_registry.append((ejs.registry_schemas[repr_schema], config_var))
@ -948,11 +953,7 @@ def convert_keys(converted, schema, path):
result["key"] = "GeneratedID" result["key"] = "GeneratedID"
elif isinstance(k, cv.Required): elif isinstance(k, cv.Required):
result["key"] = "Required" result["key"] = "Required"
elif ( elif isinstance(k, (cv.Optional, cv.Inclusive, cv.Exclusive)):
isinstance(k, cv.Optional)
or isinstance(k, cv.Inclusive)
or isinstance(k, cv.Exclusive)
):
result["key"] = "Optional" result["key"] = "Optional"
else: else:
converted["key"] = "String" converted["key"] = "String"

View File

@ -2,7 +2,6 @@
import argparse import argparse
import re import re
import subprocess
from dataclasses import dataclass from dataclasses import dataclass
import sys import sys
@ -40,12 +39,12 @@ class Version:
def sub(path, pattern, repl, expected_count=1): def sub(path, pattern, repl, expected_count=1):
with open(path) as fh: with open(path, encoding="utf-8") as fh:
content = fh.read() content = fh.read()
content, count = re.subn(pattern, repl, content, flags=re.MULTILINE) content, count = re.subn(pattern, repl, content, flags=re.MULTILINE)
if expected_count is not None: if expected_count is not None:
assert count == expected_count, f"Pattern {pattern} replacement failed!" assert count == expected_count, f"Pattern {pattern} replacement failed!"
with open(path, "w") as fh: with open(path, "w", encoding="utf-8") as fh:
fh.write(content) fh.write(content)

View File

@ -1,10 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from helpers import styled, print_error_for_file, git_ls_files, filter_changed
import argparse import argparse
import codecs import codecs
import collections import collections
import colorama
import fnmatch import fnmatch
import functools import functools
import os.path import os.path
@ -12,6 +10,9 @@ import re
import sys import sys
import time import time
import colorama
from helpers import filter_changed, git_ls_files, print_error_for_file, styled
sys.path.append(os.path.dirname(__file__)) sys.path.append(os.path.dirname(__file__))
@ -30,31 +31,6 @@ def find_all(a_str, sub):
column += len(sub) column += len(sub)
colorama.init()
parser = argparse.ArgumentParser()
parser.add_argument(
"files", nargs="*", default=[], help="files to be processed (regex on path)"
)
parser.add_argument(
"-c", "--changed", action="store_true", help="Only run on changed files"
)
parser.add_argument(
"--print-slowest", action="store_true", help="Print the slowest checks"
)
args = parser.parse_args()
EXECUTABLE_BIT = git_ls_files()
files = list(EXECUTABLE_BIT.keys())
# Match against re
file_name_re = re.compile("|".join(args.files))
files = [p for p in files if file_name_re.search(p)]
if args.changed:
files = filter_changed(files)
files.sort()
file_types = ( file_types = (
".h", ".h",
".c", ".c",
@ -86,6 +62,30 @@ ignore_types = (".ico", ".png", ".woff", ".woff2", "")
LINT_FILE_CHECKS = [] LINT_FILE_CHECKS = []
LINT_CONTENT_CHECKS = [] LINT_CONTENT_CHECKS = []
LINT_POST_CHECKS = [] LINT_POST_CHECKS = []
EXECUTABLE_BIT = {}
errors = collections.defaultdict(list)
def add_errors(fname, errs):
if not isinstance(errs, list):
errs = [errs]
for err in errs:
if err is None:
continue
try:
lineno, col, msg = err
except ValueError:
lineno = 1
col = 1
msg = err
if not isinstance(msg, str):
raise ValueError("Error is not instance of string!")
if not isinstance(lineno, int):
raise ValueError("Line number is not an int!")
if not isinstance(col, int):
raise ValueError("Column number is not an int!")
errors[fname].append((lineno, col, msg))
def run_check(lint_obj, fname, *args): def run_check(lint_obj, fname, *args):
@ -155,7 +155,7 @@ def lint_re_check(regex, **kwargs):
def decorator(func): def decorator(func):
@functools.wraps(func) @functools.wraps(func)
def new_func(fname, content): def new_func(fname, content):
errors = [] errs = []
for match in prog.finditer(content): for match in prog.finditer(content):
if "NOLINT" in match.group(0): if "NOLINT" in match.group(0):
continue continue
@ -165,8 +165,8 @@ def lint_re_check(regex, **kwargs):
err = func(fname, match) err = func(fname, match)
if err is None: if err is None:
continue continue
errors.append((lineno, col + 1, err)) errs.append((lineno, col + 1, err))
return errors return errs
return decor(new_func) return decor(new_func)
@ -182,13 +182,13 @@ def lint_content_find_check(find, only_first=False, **kwargs):
find_ = find find_ = find
if callable(find): if callable(find):
find_ = find(fname, content) find_ = find(fname, content)
errors = [] errs = []
for line, col in find_all(content, find_): for line, col in find_all(content, find_):
err = func(fname) err = func(fname)
errors.append((line + 1, col + 1, err)) errs.append((line + 1, col + 1, err))
if only_first: if only_first:
break break
return errors return errs
return decor(new_func) return decor(new_func)
@ -235,8 +235,8 @@ def lint_executable_bit(fname):
ex = EXECUTABLE_BIT[fname] ex = EXECUTABLE_BIT[fname]
if ex != 100644: if ex != 100644:
return ( return (
"File has invalid executable bit {}. If running from a windows machine please " f"File has invalid executable bit {ex}. If running from a windows machine please "
"see disabling executable bit in git.".format(ex) "see disabling executable bit in git."
) )
return None return None
@ -285,8 +285,8 @@ def lint_no_defines(fname, match):
s = highlight(f"static const uint8_t {match.group(1)} = {match.group(2)};") s = highlight(f"static const uint8_t {match.group(1)} = {match.group(2)};")
return ( return (
"#define macros for integer constants are not allowed, please use " "#define macros for integer constants are not allowed, please use "
"{} style instead (replace uint8_t with the appropriate " f"{s} style instead (replace uint8_t with the appropriate "
"datatype). See also Google style guide.".format(s) "datatype). See also Google style guide."
) )
@ -296,11 +296,11 @@ def lint_no_long_delays(fname, match):
if duration_ms < 50: if duration_ms < 50:
return None return None
return ( return (
"{} - long calls to delay() are not allowed in ESPHome because everything executes " f"{highlight(match.group(0).strip())} - long calls to delay() are not allowed "
"in one thread. Calling delay() will block the main thread and slow down ESPHome.\n" "in ESPHome because everything executes in one thread. Calling delay() will "
"block the main thread and slow down ESPHome.\n"
"If there's no way to work around the delay() and it doesn't execute often, please add " "If there's no way to work around the delay() and it doesn't execute often, please add "
"a '// NOLINT' comment to the line." "a '// NOLINT' comment to the line."
"".format(highlight(match.group(0).strip()))
) )
@ -311,28 +311,28 @@ def lint_const_ordered(fname, content):
Reason: Otherwise people add it to the end, and then that results in merge conflicts. Reason: Otherwise people add it to the end, and then that results in merge conflicts.
""" """
lines = content.splitlines() lines = content.splitlines()
errors = [] errs = []
for start in ["CONF_", "ICON_", "UNIT_"]: for start in ["CONF_", "ICON_", "UNIT_"]:
matching = [ matching = [
(i + 1, line) for i, line in enumerate(lines) if line.startswith(start) (i + 1, line) for i, line in enumerate(lines) if line.startswith(start)
] ]
ordered = list(sorted(matching, key=lambda x: x[1].replace("_", " "))) ordered = list(sorted(matching, key=lambda x: x[1].replace("_", " ")))
ordered = [(mi, ol) for (mi, _), (_, ol) in zip(matching, ordered)] ordered = [(mi, ol) for (mi, _), (_, ol) in zip(matching, ordered)]
for (mi, ml), (oi, ol) in zip(matching, ordered): for (mi, mline), (_, ol) in zip(matching, ordered):
if ml == ol: if mline == ol:
continue continue
target = next(i for i, l in ordered if l == ml) target = next(i for i, line in ordered if line == mline)
target_text = next(l for i, l in matching if target == i) target_text = next(line for i, line in matching if target == i)
errors.append( errs.append(
( (
mi, mi,
1, 1,
f"Constant {highlight(ml)} is not ordered, please make sure all " f"Constant {highlight(mline)} is not ordered, please make sure all "
f"constants are ordered. See line {mi} (should go to line {target}, " f"constants are ordered. See line {mi} (should go to line {target}, "
f"{target_text})", f"{target_text})",
) )
) )
return errors return errs
@lint_re_check(r'^\s*CONF_([A-Z_0-9a-z]+)\s+=\s+[\'"](.*?)[\'"]\s*?$', include=["*.py"]) @lint_re_check(r'^\s*CONF_([A-Z_0-9a-z]+)\s+=\s+[\'"](.*?)[\'"]\s*?$', include=["*.py"])
@ -344,15 +344,14 @@ def lint_conf_matches(fname, match):
if const_norm == value_norm: if const_norm == value_norm:
return None return None
return ( return (
"Constant {} does not match value {}! Please make sure the constant's name matches its " f"Constant {highlight('CONF_' + const)} does not match value {highlight(value)}! "
"value!" "Please make sure the constant's name matches its value!"
"".format(highlight("CONF_" + const), highlight(value))
) )
CONF_RE = r'^(CONF_[a-zA-Z0-9_]+)\s*=\s*[\'"].*?[\'"]\s*?$' CONF_RE = r'^(CONF_[a-zA-Z0-9_]+)\s*=\s*[\'"].*?[\'"]\s*?$'
with codecs.open("esphome/const.py", "r", encoding="utf-8") as f_handle: with codecs.open("esphome/const.py", "r", encoding="utf-8") as const_f_handle:
constants_content = f_handle.read() constants_content = const_f_handle.read()
CONSTANTS = [m.group(1) for m in re.finditer(CONF_RE, constants_content, re.MULTILINE)] CONSTANTS = [m.group(1) for m in re.finditer(CONF_RE, constants_content, re.MULTILINE)]
CONSTANTS_USES = collections.defaultdict(list) CONSTANTS_USES = collections.defaultdict(list)
@ -365,8 +364,8 @@ def lint_conf_from_const_py(fname, match):
CONSTANTS_USES[name].append(fname) CONSTANTS_USES[name].append(fname)
return None return None
return ( return (
"Constant {} has already been defined in const.py - please import the constant from " f"Constant {highlight(name)} has already been defined in const.py - "
"const.py directly.".format(highlight(name)) "please import the constant from const.py directly."
) )
@ -473,16 +472,15 @@ def lint_no_byte_datatype(fname, match):
@lint_post_check @lint_post_check
def lint_constants_usage(): def lint_constants_usage():
errors = [] errs = []
for constant, uses in CONSTANTS_USES.items(): for constant, uses in CONSTANTS_USES.items():
if len(uses) < 4: if len(uses) < 4:
continue continue
errors.append( errs.append(
"Constant {} is defined in {} files. Please move all definitions of the " f"Constant {highlight(constant)} is defined in {len(uses)} files. Please move all definitions of the "
"constant to const.py (Uses: {})" f"constant to const.py (Uses: {', '.join(uses)})"
"".format(highlight(constant), len(uses), ", ".join(uses))
) )
return errors return errs
def relative_cpp_search_text(fname, content): def relative_cpp_search_text(fname, content):
@ -553,7 +551,7 @@ def lint_namespace(fname, content):
return ( return (
"Invalid namespace found in C++ file. All integration C++ files should put all " "Invalid namespace found in C++ file. All integration C++ files should put all "
"functions in a separate namespace that matches the integration's name. " "functions in a separate namespace that matches the integration's name. "
"Please make sure the file contains {}".format(highlight(search)) f"Please make sure the file contains {highlight(search)}"
) )
@ -639,66 +637,73 @@ def lint_log_in_header(fname):
) )
errors = collections.defaultdict(list) def main():
colorama.init()
parser = argparse.ArgumentParser()
parser.add_argument(
"files", nargs="*", default=[], help="files to be processed (regex on path)"
)
parser.add_argument(
"-c", "--changed", action="store_true", help="Only run on changed files"
)
parser.add_argument(
"--print-slowest", action="store_true", help="Print the slowest checks"
)
args = parser.parse_args()
def add_errors(fname, errs): global EXECUTABLE_BIT
if not isinstance(errs, list): EXECUTABLE_BIT = git_ls_files()
errs = [errs] files = list(EXECUTABLE_BIT.keys())
for err in errs: # Match against re
if err is None: file_name_re = re.compile("|".join(args.files))
files = [p for p in files if file_name_re.search(p)]
if args.changed:
files = filter_changed(files)
files.sort()
for fname in files:
_, ext = os.path.splitext(fname)
run_checks(LINT_FILE_CHECKS, fname, fname)
if ext in ignore_types:
continue continue
try: try:
lineno, col, msg = err with codecs.open(fname, "r", encoding="utf-8") as f_handle:
except ValueError: content = f_handle.read()
lineno = 1 except UnicodeDecodeError:
col = 1 add_errors(
msg = err fname,
if not isinstance(msg, str): "File is not readable as UTF-8. Please set your editor to UTF-8 mode.",
raise ValueError("Error is not instance of string!") )
if not isinstance(lineno, int): continue
raise ValueError("Line number is not an int!") run_checks(LINT_CONTENT_CHECKS, fname, fname, content)
if not isinstance(col, int):
raise ValueError("Column number is not an int!")
errors[fname].append((lineno, col, msg))
run_checks(LINT_POST_CHECKS, "POST")
for fname in files: for f, errs in sorted(errors.items()):
_, ext = os.path.splitext(fname) bold = functools.partial(styled, colorama.Style.BRIGHT)
run_checks(LINT_FILE_CHECKS, fname, fname) bold_red = functools.partial(styled, (colorama.Style.BRIGHT, colorama.Fore.RED))
if ext in ignore_types: err_str = (
continue f"{bold(f'{f}:{lineno}:{col}:')} {bold_red('lint:')} {msg}\n"
try: for lineno, col, msg in errs
with codecs.open(fname, "r", encoding="utf-8") as f_handle:
content = f_handle.read()
except UnicodeDecodeError:
add_errors(
fname,
"File is not readable as UTF-8. Please set your editor to UTF-8 mode.",
) )
continue print_error_for_file(f, "\n".join(err_str))
run_checks(LINT_CONTENT_CHECKS, fname, fname, content)
run_checks(LINT_POST_CHECKS, "POST") if args.print_slowest:
lint_times = []
for lint in LINT_FILE_CHECKS + LINT_CONTENT_CHECKS + LINT_POST_CHECKS:
durations = lint.get("durations", [])
lint_times.append((sum(durations), len(durations), lint["func"].__name__))
lint_times.sort(key=lambda x: -x[0])
for i in range(min(len(lint_times), 10)):
dur, invocations, name = lint_times[i]
print(f" - '{name}' took {dur:.2f}s total (ran on {invocations} files)")
print(f"Total time measured: {sum(x[0] for x in lint_times):.2f}s")
for f, errs in sorted(errors.items()): return len(errors)
bold = functools.partial(styled, colorama.Style.BRIGHT)
bold_red = functools.partial(styled, (colorama.Style.BRIGHT, colorama.Fore.RED))
err_str = (
f"{bold(f'{f}:{lineno}:{col}:')} {bold_red('lint:')} {msg}\n"
for lineno, col, msg in errs
)
print_error_for_file(f, "\n".join(err_str))
if args.print_slowest:
lint_times = []
for lint in LINT_FILE_CHECKS + LINT_CONTENT_CHECKS + LINT_POST_CHECKS:
durations = lint.get("durations", [])
lint_times.append((sum(durations), len(durations), lint["func"].__name__))
lint_times.sort(key=lambda x: -x[0])
for i in range(min(len(lint_times), 10)):
dur, invocations, name = lint_times[i]
print(f" - '{name}' took {dur:.2f}s total (ran on {invocations} files)")
print(f"Total time measured: {sum(x[0] for x in lint_times):.2f}s")
sys.exit(len(errors)) if __name__ == "__main__":
sys.exit(main())

View File

@ -1,10 +1,11 @@
import colorama import json
import os.path import os.path
import re import re
import subprocess import subprocess
import json
from pathlib import Path from pathlib import Path
import colorama
root_path = os.path.abspath(os.path.normpath(os.path.join(__file__, "..", ".."))) root_path = os.path.abspath(os.path.normpath(os.path.join(__file__, "..", "..")))
basepath = os.path.join(root_path, "esphome") basepath = os.path.join(root_path, "esphome")
temp_folder = os.path.join(root_path, ".temp") temp_folder = os.path.join(root_path, ".temp")
@ -44,7 +45,7 @@ def build_all_include():
content = "\n".join(headers) content = "\n".join(headers)
p = Path(temp_header_file) p = Path(temp_header_file)
p.parent.mkdir(exist_ok=True) p.parent.mkdir(exist_ok=True)
p.write_text(content) p.write_text(content, encoding="utf-8")
def walk_files(path): def walk_files(path):
@ -54,14 +55,14 @@ def walk_files(path):
def get_output(*args): def get_output(*args):
proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) with subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as proc:
output, err = proc.communicate() output, _ = proc.communicate()
return output.decode("utf-8") return output.decode("utf-8")
def get_err(*args): def get_err(*args):
proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) with subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as proc:
output, err = proc.communicate() _, err = proc.communicate()
return err.decode("utf-8") return err.decode("utf-8")
@ -78,7 +79,7 @@ def changed_files():
merge_base = splitlines_no_ends(get_output(*command))[0] merge_base = splitlines_no_ends(get_output(*command))[0]
break break
# pylint: disable=bare-except # pylint: disable=bare-except
except: except: # noqa: E722
pass pass
else: else:
raise ValueError("Git not configured") raise ValueError("Git not configured")
@ -103,7 +104,7 @@ def filter_changed(files):
def filter_grep(files, value): def filter_grep(files, value):
matched = [] matched = []
for file in files: for file in files:
with open(file) as handle: with open(file, encoding="utf-8") as handle:
contents = handle.read() contents = handle.read()
if value in contents: if value in contents:
matched.append(file) matched.append(file)
@ -114,8 +115,8 @@ def git_ls_files(patterns=None):
command = ["git", "ls-files", "-s"] command = ["git", "ls-files", "-s"]
if patterns is not None: if patterns is not None:
command.extend(patterns) command.extend(patterns)
proc = subprocess.Popen(command, stdout=subprocess.PIPE) with subprocess.Popen(command, stdout=subprocess.PIPE) as proc:
output, err = proc.communicate() output, _ = proc.communicate()
lines = [x.split() for x in output.decode("utf-8").splitlines()] lines = [x.split() for x in output.decode("utf-8").splitlines()]
return {s[3].strip(): int(s[0]) for s in lines} return {s[3].strip(): int(s[0]) for s in lines}

View File

@ -2,6 +2,7 @@
import re import re
# pylint: disable=import-error
from homeassistant.components.binary_sensor import BinarySensorDeviceClass from homeassistant.components.binary_sensor import BinarySensorDeviceClass
from homeassistant.components.button import ButtonDeviceClass from homeassistant.components.button import ButtonDeviceClass
from homeassistant.components.cover import CoverDeviceClass from homeassistant.components.cover import CoverDeviceClass
@ -9,6 +10,8 @@ from homeassistant.components.number import NumberDeviceClass
from homeassistant.components.sensor import SensorDeviceClass from homeassistant.components.sensor import SensorDeviceClass
from homeassistant.components.switch import SwitchDeviceClass from homeassistant.components.switch import SwitchDeviceClass
# pylint: enable=import-error
BLOCKLIST = ( BLOCKLIST = (
# requires special support on HA side # requires special support on HA side
"enum", "enum",
@ -25,10 +28,10 @@ DOMAINS = {
def sub(path, pattern, repl): def sub(path, pattern, repl):
with open(path) as handle: with open(path, encoding="utf-8") as handle:
content = handle.read() content = handle.read()
content = re.sub(pattern, repl, content, flags=re.MULTILINE) content = re.sub(pattern, repl, content, flags=re.MULTILINE)
with open(path, "w") as handle: with open(path, "w", encoding="utf-8") as handle:
handle.write(content) handle.write(content)