Merge branch 'write_read_executor' into integration

This commit is contained in:
J. Nick Koston 2023-11-18 09:22:29 -06:00
commit 6a869836ac
No known key found for this signature in database
6 changed files with 145 additions and 17 deletions

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import hmac import hmac
import os import os
from pathlib import Path from pathlib import Path
from typing import Any
from esphome.core import CORE from esphome.core import CORE
from esphome.helpers import get_bool_env from esphome.helpers import get_bool_env
@ -69,7 +70,8 @@ class DashboardSettings:
# Compare password in constant running time (to prevent timing attacks) # Compare password in constant running time (to prevent timing attacks)
return hmac.compare_digest(self.password_hash, password_hash(password)) return hmac.compare_digest(self.password_hash, password_hash(password))
def rel_path(self, *args): def rel_path(self, *args: Any) -> str:
"""Return a path relative to the ESPHome config folder."""
joined_path = os.path.join(self.config_dir, *args) joined_path = os.path.join(self.config_dir, *args)
# Raises ValueError if not relative to ESPHome config folder # Raises ValueError if not relative to ESPHome config folder
Path(joined_path).resolve().relative_to(self.absolute_config_dir) Path(joined_path).resolve().relative_to(self.absolute_config_dir)

View file

@ -0,0 +1,55 @@
import logging
import os
import tempfile
from pathlib import Path
_LOGGER = logging.getLogger(__name__)
def write_utf8_file(
filename: Path,
utf8_str: str,
private: bool = False,
) -> None:
"""Write a file and rename it into place.
Writes all or nothing.
"""
write_file(filename, utf8_str.encode("utf-8"), private)
# from https://github.com/home-assistant/core/blob/dev/homeassistant/util/file.py
def write_file(
filename: Path,
utf8_data: bytes,
private: bool = False,
) -> None:
"""Write a file and rename it into place.
Writes all or nothing.
"""
tmp_filename = ""
try:
# Modern versions of Python tempfile create this file with mode 0o600
with tempfile.NamedTemporaryFile(
mode="wb", dir=os.path.dirname(filename), delete=False
) as fdesc:
fdesc.write(utf8_data)
tmp_filename = fdesc.name
if not private:
os.fchmod(fdesc.fileno(), 0o644)
os.replace(tmp_filename, filename)
finally:
if os.path.exists(tmp_filename):
try:
os.remove(tmp_filename)
except OSError as err:
# If we are cleaning up then something else went wrong, so
# we should suppress likely follow-on errors in the cleanup
_LOGGER.error(
"File replacement cleanup failed for %s while saving %s: %s",
tmp_filename,
filename,
err,
)

View file

@ -38,6 +38,7 @@ from esphome.yaml_util import FastestAvailableSafeLoader
from .core import DASHBOARD from .core import DASHBOARD
from .entries import EntryState, entry_state_to_bool from .entries import EntryState, entry_state_to_bool
from .util.file import write_file
from .util.subprocess import async_run_system_command from .util.subprocess import async_run_system_command
from .util.text import friendly_name_slugify from .util.text import friendly_name_slugify
@ -525,9 +526,19 @@ class DownloadListRequestHandler(BaseHandler):
class DownloadBinaryRequestHandler(BaseHandler): class DownloadBinaryRequestHandler(BaseHandler):
def _load_file(self, path: str, compressed: bool) -> bytes:
"""Load a file from disk and compress it if requested."""
with open(path, "rb") as f:
data = f.read()
if compressed:
return gzip.compress(data, 9)
return data
@authenticated @authenticated
@bind_config @bind_config
async def get(self, configuration=None): async def get(self, configuration: str | None = None):
"""Download a binary file."""
loop = asyncio.get_running_loop()
compressed = self.get_argument("compressed", "0") == "1" compressed = self.get_argument("compressed", "0") == "1"
storage_path = ext_storage_path(configuration) storage_path = ext_storage_path(configuration)
@ -584,11 +595,8 @@ class DownloadBinaryRequestHandler(BaseHandler):
self.send_error(404) self.send_error(404)
return return
with open(path, "rb") as f: data = await loop.run_in_executor(None, self._load_file, path, compressed)
data = f.read() self.write(data)
if compressed:
data = gzip.compress(data, 9)
self.write(data)
self.finish() self.finish()
@ -747,22 +755,32 @@ class InfoRequestHandler(BaseHandler):
class EditRequestHandler(BaseHandler): class EditRequestHandler(BaseHandler):
@authenticated @authenticated
@bind_config @bind_config
def get(self, configuration=None): async def get(self, configuration: str | None = None):
"""Get the content of a file."""
loop = asyncio.get_running_loop()
filename = settings.rel_path(configuration) filename = settings.rel_path(configuration)
content = "" content = await loop.run_in_executor(None, self._read_file, filename)
if os.path.isfile(filename):
with open(file=filename, encoding="utf-8") as f:
content = f.read()
self.write(content) self.write(content)
def _read_file(self, filename: str) -> bytes:
"""Read a file and return the content as bytes."""
with open(file=filename, encoding="utf-8") as f:
return f.read()
def _write_file(self, filename: str, content: bytes) -> None:
"""Write a file with the given content."""
write_file(filename, content)
@authenticated @authenticated
@bind_config @bind_config
async def post(self, configuration=None): async def post(self, configuration: str | None = None):
# Atomic write """Write the content of a file."""
loop = asyncio.get_running_loop()
config_file = settings.rel_path(configuration) config_file = settings.rel_path(configuration)
with open(file=config_file, mode="wb") as f: await loop.run_in_executor(
f.write(self.request.body) None, self._write_file, config_file, self.request.body
)
# Ensure the StorageJSON is updated as well
await async_run_system_command( await async_run_system_command(
[*DASHBOARD_COMMAND, "compile", "--only-generate", config_file] [*DASHBOARD_COMMAND, "compile", "--only-generate", config_file]
) )

View file

View file

View file

@ -0,0 +1,53 @@
import os
from pathlib import Path
from unittest.mock import patch
import py
import pytest
from esphome.dashboard.util.file import write_file, write_utf8_file
def test_write_utf8_file(tmp_path: Path) -> None:
write_utf8_file(tmp_path.joinpath("foo.txt"), "foo")
assert tmp_path.joinpath("foo.txt").read_text() == "foo"
with pytest.raises(OSError):
write_utf8_file(Path("/not-writable"), "bar")
def test_write_file(tmp_path: Path) -> None:
write_file(tmp_path.joinpath("foo.txt"), b"foo")
assert tmp_path.joinpath("foo.txt").read_text() == "foo"
def test_write_utf8_file_fails_at_rename(
tmpdir: py.path.local, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that if rename fails not not remove, we do not log the failed cleanup."""
test_dir = tmpdir.mkdir("files")
test_file = Path(test_dir / "test.json")
with pytest.raises(OSError), patch(
"esphome.dashboard.util.file.os.replace", side_effect=OSError
):
write_utf8_file(test_file, '{"some":"data"}', False)
assert not os.path.exists(test_file)
assert "File replacement cleanup failed" not in caplog.text
def test_write_utf8_file_fails_at_rename_and_remove(
tmpdir: py.path.local, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that if rename and remove both fail, we log the failed cleanup."""
test_dir = tmpdir.mkdir("files")
test_file = Path(test_dir / "test.json")
with pytest.raises(OSError), patch(
"esphome.dashboard.util.file.os.remove", side_effect=OSError
), patch("esphome.dashboard.util.file.os.replace", side_effect=OSError):
write_utf8_file(test_file, '{"some":"data"}', False)
assert "File replacement cleanup failed" in caplog.text