diff --git a/esphome/dashboard/util/file.py b/esphome/dashboard/util/file.py index 74deeacf87..5f3c5f5f1b 100644 --- a/esphome/dashboard/util/file.py +++ b/esphome/dashboard/util/file.py @@ -6,10 +6,22 @@ from pathlib import Path _LOGGER = logging.getLogger(__name__) -# from https://github.com/home-assistant/core/blob/dev/homeassistant/util/file.py def write_utf8_file( filename: Path, - utf8_data: str, + utf8_str: str, + private: bool = False, +) -> None: + """Write a file and rename it into place. + + Writes all or nothing. + """ + write_file(filename, utf8_str.encode("utf-8"), private) + + +# from https://github.com/home-assistant/core/blob/dev/homeassistant/util/file.py +def write_file( + filename: Path, + utf8_data: bytes, private: bool = False, ) -> None: """Write a file and rename it into place. @@ -21,7 +33,7 @@ def write_utf8_file( try: # Modern versions of Python tempfile create this file with mode 0o600 with tempfile.NamedTemporaryFile( - mode="w", encoding="utf-8", dir=os.path.dirname(filename), delete=False + mode="wb", dir=os.path.dirname(filename), delete=False ) as fdesc: fdesc.write(utf8_data) tmp_filename = fdesc.name diff --git a/esphome/dashboard/web_server.py b/esphome/dashboard/web_server.py index aa8e445816..93d836d76d 100644 --- a/esphome/dashboard/web_server.py +++ b/esphome/dashboard/web_server.py @@ -38,6 +38,7 @@ from esphome.yaml_util import FastestAvailableSafeLoader from .core import DASHBOARD from .entries import EntryState, entry_state_to_bool +from .util.file import write_file from .util.subprocess import async_run_system_command from .util.text import friendly_name_slugify @@ -746,22 +747,34 @@ class InfoRequestHandler(BaseHandler): class EditRequestHandler(BaseHandler): @authenticated @bind_config - def get(self, configuration=None): + async def get(self, configuration: str | None = None): + loop = asyncio.get_running_loop() filename = settings.rel_path(configuration) - content = "" - if os.path.isfile(filename): - with open(file=filename, encoding="utf-8") as f: - content = f.read() + try: + content = await loop.run_in_executor(None, self._read_file, filename) + except OSError: + self.send_error(404) + return self.write(content) + def _read_file(self, filename: str) -> bytes: + """Read a file and return the content as bytes.""" + with open(file=filename, encoding="utf-8") as f: + return f.read() + + def _write_file(self, filename: str, content: bytes) -> None: + """Write a file with the given content.""" + write_file(filename, content) + @authenticated @bind_config - async def post(self, configuration=None): - # Atomic write + async def post(self, configuration: str | None = None): + loop = asyncio.get_running_loop() config_file = settings.rel_path(configuration) - with open(file=config_file, mode="wb") as f: - f.write(self.request.body) - + await loop.run_in_executor( + None, self._write_file, config_file, self.request.body + ) + # Ensure the StorageJSON is updated as well await async_run_system_command( [*DASHBOARD_COMMAND, "compile", "--only-generate", config_file] ) diff --git a/tests/dashboard/util/test_file.py b/tests/dashboard/util/test_file.py index fd4860dbb3..89e6b97086 100644 --- a/tests/dashboard/util/test_file.py +++ b/tests/dashboard/util/test_file.py @@ -5,7 +5,7 @@ from unittest.mock import patch import py import pytest -from esphome.dashboard.util.file import write_utf8_file +from esphome.dashboard.util.file import write_file, write_utf8_file def test_write_utf8_file(tmp_path: Path) -> None: @@ -16,6 +16,11 @@ def test_write_utf8_file(tmp_path: Path) -> None: write_utf8_file(Path("/not-writable"), "bar") +def test_write_file(tmp_path: Path) -> None: + write_file(tmp_path.joinpath("foo.txt"), b"foo") + assert tmp_path.joinpath("foo.txt").read_text() == "foo" + + def test_write_utf8_file_fails_at_rename( tmpdir: py.path.local, caplog: pytest.LogCaptureFixture ) -> None: