This commit is contained in:
J. Nick Koston 2023-11-18 09:13:42 -06:00
parent f060c4020c
commit d46188ecda
No known key found for this signature in database
3 changed files with 44 additions and 14 deletions

View file

@ -6,10 +6,22 @@ from pathlib import Path
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# from https://github.com/home-assistant/core/blob/dev/homeassistant/util/file.py
def write_utf8_file( def write_utf8_file(
filename: Path, filename: Path,
utf8_data: str, utf8_str: str,
private: bool = False,
) -> None:
"""Write a file and rename it into place.
Writes all or nothing.
"""
write_file(filename, utf8_str.encode("utf-8"), private)
# from https://github.com/home-assistant/core/blob/dev/homeassistant/util/file.py
def write_file(
filename: Path,
utf8_data: bytes,
private: bool = False, private: bool = False,
) -> None: ) -> None:
"""Write a file and rename it into place. """Write a file and rename it into place.
@ -21,7 +33,7 @@ def write_utf8_file(
try: try:
# Modern versions of Python tempfile create this file with mode 0o600 # Modern versions of Python tempfile create this file with mode 0o600
with tempfile.NamedTemporaryFile( with tempfile.NamedTemporaryFile(
mode="w", encoding="utf-8", dir=os.path.dirname(filename), delete=False mode="wb", dir=os.path.dirname(filename), delete=False
) as fdesc: ) as fdesc:
fdesc.write(utf8_data) fdesc.write(utf8_data)
tmp_filename = fdesc.name tmp_filename = fdesc.name

View file

@ -38,6 +38,7 @@ from esphome.yaml_util import FastestAvailableSafeLoader
from .core import DASHBOARD from .core import DASHBOARD
from .entries import EntryState, entry_state_to_bool from .entries import EntryState, entry_state_to_bool
from .util.file import write_file
from .util.subprocess import async_run_system_command from .util.subprocess import async_run_system_command
from .util.text import friendly_name_slugify from .util.text import friendly_name_slugify
@ -746,22 +747,34 @@ class InfoRequestHandler(BaseHandler):
class EditRequestHandler(BaseHandler): class EditRequestHandler(BaseHandler):
@authenticated @authenticated
@bind_config @bind_config
def get(self, configuration=None): async def get(self, configuration: str | None = None):
loop = asyncio.get_running_loop()
filename = settings.rel_path(configuration) filename = settings.rel_path(configuration)
content = "" try:
if os.path.isfile(filename): content = await loop.run_in_executor(None, self._read_file, filename)
with open(file=filename, encoding="utf-8") as f: except OSError:
content = f.read() self.send_error(404)
return
self.write(content) self.write(content)
def _read_file(self, filename: str) -> bytes:
"""Read a file and return the content as bytes."""
with open(file=filename, encoding="utf-8") as f:
return f.read()
def _write_file(self, filename: str, content: bytes) -> None:
"""Write a file with the given content."""
write_file(filename, content)
@authenticated @authenticated
@bind_config @bind_config
async def post(self, configuration=None): async def post(self, configuration: str | None = None):
# Atomic write loop = asyncio.get_running_loop()
config_file = settings.rel_path(configuration) config_file = settings.rel_path(configuration)
with open(file=config_file, mode="wb") as f: await loop.run_in_executor(
f.write(self.request.body) None, self._write_file, config_file, self.request.body
)
# Ensure the StorageJSON is updated as well
await async_run_system_command( await async_run_system_command(
[*DASHBOARD_COMMAND, "compile", "--only-generate", config_file] [*DASHBOARD_COMMAND, "compile", "--only-generate", config_file]
) )

View file

@ -5,7 +5,7 @@ from unittest.mock import patch
import py import py
import pytest import pytest
from esphome.dashboard.util.file import write_utf8_file from esphome.dashboard.util.file import write_file, write_utf8_file
def test_write_utf8_file(tmp_path: Path) -> None: def test_write_utf8_file(tmp_path: Path) -> None:
@ -16,6 +16,11 @@ def test_write_utf8_file(tmp_path: Path) -> None:
write_utf8_file(Path("/not-writable"), "bar") write_utf8_file(Path("/not-writable"), "bar")
def test_write_file(tmp_path: Path) -> None:
write_file(tmp_path.joinpath("foo.txt"), b"foo")
assert tmp_path.joinpath("foo.txt").read_text() == "foo"
def test_write_utf8_file_fails_at_rename( def test_write_utf8_file_fails_at_rename(
tmpdir: py.path.local, caplog: pytest.LogCaptureFixture tmpdir: py.path.local, caplog: pytest.LogCaptureFixture
) -> None: ) -> None: