esphome/esphome/util.py
Quentin Smith 19900b004b
Fix type annotation on extract_registry_entry_config (#3623)
Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com>
2022-10-12 15:15:03 +13:00

306 lines
9 KiB
Python

from typing import Union
import collections
import io
import logging
import os
import re
import subprocess
import sys
from pathlib import Path
from esphome import const
_LOGGER = logging.getLogger(__name__)
class RegistryEntry:
def __init__(self, name, fun, type_id, schema):
self.name = name
self.fun = fun
self.type_id = type_id
self.raw_schema = schema
@property
def coroutine_fun(self):
from esphome.core import coroutine
return coroutine(self.fun)
@property
def schema(self):
from esphome.config_validation import Schema
return Schema(self.raw_schema)
class Registry(dict[str, RegistryEntry]):
def __init__(self, base_schema=None, type_id_key=None):
super().__init__()
self.base_schema = base_schema or {}
self.type_id_key = type_id_key
def register(self, name, type_id, schema):
def decorator(fun):
self[name] = RegistryEntry(name, fun, type_id, schema)
return fun
return decorator
class SimpleRegistry(dict):
def register(self, name, data):
def decorator(fun):
self[name] = (fun, data)
return fun
return decorator
def safe_print(message=""):
from esphome.core import CORE
if CORE.dashboard:
try:
message = message.replace("\033", "\\033")
except UnicodeEncodeError:
pass
try:
print(message)
return
except UnicodeEncodeError:
pass
try:
print(message.encode("utf-8", "backslashreplace"))
except UnicodeEncodeError:
try:
print(message.encode("ascii", "backslashreplace"))
except UnicodeEncodeError:
print("Cannot print line because of invalid locale!")
def shlex_quote(s):
if not s:
return "''"
if re.search(r"[^\w@%+=:,./-]", s) is None:
return s
return "'" + s.replace("'", "'\"'\"'") + "'"
ANSI_ESCAPE = re.compile(r"\033[@-_][0-?]*[ -/]*[@-~]")
class RedirectText:
def __init__(self, out, filter_lines=None):
self._out = out
if filter_lines is None:
self._filter_pattern = None
else:
pattern = r"|".join(r"(?:" + pattern + r")" for pattern in filter_lines)
self._filter_pattern = re.compile(pattern)
self._line_buffer = ""
def __getattr__(self, item):
return getattr(self._out, item)
def _write_color_replace(self, s):
from esphome.core import CORE
if CORE.dashboard:
# With the dashboard, we must create a little hack to make color output
# work. The shell we create in the dashboard is not a tty, so python removes
# all color codes from the resulting stream. We just convert them to something
# we can easily recognize later here.
s = s.replace("\033", "\\033")
self._out.write(s)
def write(self, s):
# s is usually a str already (self._out is of type TextIOWrapper)
# However, s is sometimes also a bytes object in python3. Let's make sure it's a
# str
# If the conversion fails, we will create an exception, which is okay because we won't
# be able to print it anyway.
if not isinstance(s, str):
s = s.decode()
if self._filter_pattern is not None:
self._line_buffer += s
lines = self._line_buffer.splitlines(True)
for line in lines:
if "\n" not in line and "\r" not in line:
# Not a complete line, set line buffer
self._line_buffer = line
break
self._line_buffer = ""
line_without_ansi = ANSI_ESCAPE.sub("", line)
line_without_end = line_without_ansi.rstrip()
if self._filter_pattern.match(line_without_end) is not None:
# Filter pattern matched, ignore the line
continue
self._write_color_replace(line)
else:
self._write_color_replace(s)
# write() returns the number of characters written
# Let's print the number of characters of the original string in order to not confuse
# any caller.
return len(s)
def isatty(self):
return True
def run_external_command(
func, *cmd, capture_stdout: bool = False, filter_lines: str = None
) -> Union[int, str]:
"""
Run a function from an external package that acts like a main method.
Temporarily replaces stdin/stderr/stdout, sys.argv and sys.exit handler during the run.
:param func: Function to execute
:param cmd: Command to run as (eg first element of sys.argv)
:param capture_stdout: Capture text from stdout and return that.
:param filter_lines: Regular expression used to filter captured output.
:return: str if `capture_stdout` is set else int exit code.
"""
def mock_exit(return_code):
raise SystemExit(return_code)
orig_argv = sys.argv
orig_exit = sys.exit # mock sys.exit
full_cmd = " ".join(shlex_quote(x) for x in cmd)
_LOGGER.debug("Running: %s", full_cmd)
orig_stdout = sys.stdout
sys.stdout = RedirectText(sys.stdout, filter_lines=filter_lines)
orig_stderr = sys.stderr
sys.stderr = RedirectText(sys.stderr, filter_lines=filter_lines)
if capture_stdout:
cap_stdout = sys.stdout = io.StringIO()
try:
sys.argv = list(cmd)
sys.exit = mock_exit
return func() or 0
except KeyboardInterrupt: # pylint: disable=try-except-raise
raise
except SystemExit as err:
return err.args[0]
except Exception as err: # pylint: disable=broad-except
_LOGGER.error("Running command failed: %s", err)
_LOGGER.error("Please try running %s locally.", full_cmd)
return 1
finally:
sys.argv = orig_argv
sys.exit = orig_exit
sys.stdout = orig_stdout
sys.stderr = orig_stderr
if capture_stdout:
# pylint: disable=lost-exception
return cap_stdout.getvalue()
def run_external_process(*cmd, **kwargs):
full_cmd = " ".join(shlex_quote(x) for x in cmd)
_LOGGER.debug("Running: %s", full_cmd)
filter_lines = kwargs.get("filter_lines")
capture_stdout = kwargs.get("capture_stdout", False)
if capture_stdout:
sub_stdout = subprocess.PIPE
else:
sub_stdout = RedirectText(sys.stdout, filter_lines=filter_lines)
sub_stderr = RedirectText(sys.stderr, filter_lines=filter_lines)
try:
proc = subprocess.run(
cmd, stdout=sub_stdout, stderr=sub_stderr, encoding="utf-8", check=False
)
return proc.stdout if capture_stdout else proc.returncode
except KeyboardInterrupt: # pylint: disable=try-except-raise
raise
except Exception as err: # pylint: disable=broad-except
_LOGGER.error("Running command failed: %s", err)
_LOGGER.error("Please try running %s locally.", full_cmd)
return 1
def is_dev_esphome_version():
return "dev" in const.__version__
def parse_esphome_version() -> tuple[int, int, int]:
match = re.match(r"^(\d+).(\d+).(\d+)(-dev\d*|b\d*)?$", const.__version__)
if match is None:
raise ValueError(f"Failed to parse ESPHome version '{const.__version__}'")
return int(match.group(1)), int(match.group(2)), int(match.group(3))
# Custom OrderedDict with nicer repr method for debugging
class OrderedDict(collections.OrderedDict):
def __repr__(self):
return dict(self).__repr__()
def list_yaml_files(folders):
files = filter_yaml_files(
[os.path.join(folder, p) for folder in folders for p in os.listdir(folder)]
)
files.sort()
return files
def filter_yaml_files(files):
return [
f
for f in files
if (
os.path.splitext(f)[1] in (".yaml", ".yml")
and os.path.basename(f) not in ("secrets.yaml", "secrets.yml")
and not os.path.basename(f).startswith(".")
)
]
class SerialPort:
def __init__(self, path: str, description: str):
self.path = path
self.description = description
# from https://github.com/pyserial/pyserial/blob/master/serial/tools/list_ports.py
def get_serial_ports() -> list[SerialPort]:
from serial.tools.list_ports import comports
result = []
for port, desc, info in comports(include_links=True):
if not port:
continue
if "VID:PID" in info:
result.append(SerialPort(path=port, description=desc))
# Also add objects in /dev/serial/by-id/
# ref: https://github.com/esphome/issues/issues/1346
by_id_path = Path("/dev/serial/by-id")
if sys.platform.lower().startswith("linux") and by_id_path.exists():
from serial.tools.list_ports_linux import SysFS
for path in by_id_path.glob("*"):
device = SysFS(path)
if device.subsystem == "platform":
result.append(SerialPort(path=str(path), description=info[1]))
result.sort(key=lambda x: x.path)
return result