mirror of
https://github.com/esphome/esphome.git
synced 2024-11-14 02:58:11 +01:00
Merge branch 'write_read_executor' into integration
This commit is contained in:
commit
6a869836ac
6 changed files with 145 additions and 17 deletions
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||||
import hmac
|
import hmac
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from esphome.core import CORE
|
from esphome.core import CORE
|
||||||
from esphome.helpers import get_bool_env
|
from esphome.helpers import get_bool_env
|
||||||
|
@ -69,7 +70,8 @@ class DashboardSettings:
|
||||||
# Compare password in constant running time (to prevent timing attacks)
|
# Compare password in constant running time (to prevent timing attacks)
|
||||||
return hmac.compare_digest(self.password_hash, password_hash(password))
|
return hmac.compare_digest(self.password_hash, password_hash(password))
|
||||||
|
|
||||||
def rel_path(self, *args):
|
def rel_path(self, *args: Any) -> str:
|
||||||
|
"""Return a path relative to the ESPHome config folder."""
|
||||||
joined_path = os.path.join(self.config_dir, *args)
|
joined_path = os.path.join(self.config_dir, *args)
|
||||||
# Raises ValueError if not relative to ESPHome config folder
|
# Raises ValueError if not relative to ESPHome config folder
|
||||||
Path(joined_path).resolve().relative_to(self.absolute_config_dir)
|
Path(joined_path).resolve().relative_to(self.absolute_config_dir)
|
||||||
|
|
55
esphome/dashboard/util/file.py
Normal file
55
esphome/dashboard/util/file.py
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def write_utf8_file(
|
||||||
|
filename: Path,
|
||||||
|
utf8_str: str,
|
||||||
|
private: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Write a file and rename it into place.
|
||||||
|
|
||||||
|
Writes all or nothing.
|
||||||
|
"""
|
||||||
|
write_file(filename, utf8_str.encode("utf-8"), private)
|
||||||
|
|
||||||
|
|
||||||
|
# from https://github.com/home-assistant/core/blob/dev/homeassistant/util/file.py
|
||||||
|
def write_file(
|
||||||
|
filename: Path,
|
||||||
|
utf8_data: bytes,
|
||||||
|
private: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Write a file and rename it into place.
|
||||||
|
|
||||||
|
Writes all or nothing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tmp_filename = ""
|
||||||
|
try:
|
||||||
|
# Modern versions of Python tempfile create this file with mode 0o600
|
||||||
|
with tempfile.NamedTemporaryFile(
|
||||||
|
mode="wb", dir=os.path.dirname(filename), delete=False
|
||||||
|
) as fdesc:
|
||||||
|
fdesc.write(utf8_data)
|
||||||
|
tmp_filename = fdesc.name
|
||||||
|
if not private:
|
||||||
|
os.fchmod(fdesc.fileno(), 0o644)
|
||||||
|
os.replace(tmp_filename, filename)
|
||||||
|
finally:
|
||||||
|
if os.path.exists(tmp_filename):
|
||||||
|
try:
|
||||||
|
os.remove(tmp_filename)
|
||||||
|
except OSError as err:
|
||||||
|
# If we are cleaning up then something else went wrong, so
|
||||||
|
# we should suppress likely follow-on errors in the cleanup
|
||||||
|
_LOGGER.error(
|
||||||
|
"File replacement cleanup failed for %s while saving %s: %s",
|
||||||
|
tmp_filename,
|
||||||
|
filename,
|
||||||
|
err,
|
||||||
|
)
|
|
@ -38,6 +38,7 @@ from esphome.yaml_util import FastestAvailableSafeLoader
|
||||||
|
|
||||||
from .core import DASHBOARD
|
from .core import DASHBOARD
|
||||||
from .entries import EntryState, entry_state_to_bool
|
from .entries import EntryState, entry_state_to_bool
|
||||||
|
from .util.file import write_file
|
||||||
from .util.subprocess import async_run_system_command
|
from .util.subprocess import async_run_system_command
|
||||||
from .util.text import friendly_name_slugify
|
from .util.text import friendly_name_slugify
|
||||||
|
|
||||||
|
@ -525,9 +526,19 @@ class DownloadListRequestHandler(BaseHandler):
|
||||||
|
|
||||||
|
|
||||||
class DownloadBinaryRequestHandler(BaseHandler):
|
class DownloadBinaryRequestHandler(BaseHandler):
|
||||||
|
def _load_file(self, path: str, compressed: bool) -> bytes:
|
||||||
|
"""Load a file from disk and compress it if requested."""
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
data = f.read()
|
||||||
|
if compressed:
|
||||||
|
return gzip.compress(data, 9)
|
||||||
|
return data
|
||||||
|
|
||||||
@authenticated
|
@authenticated
|
||||||
@bind_config
|
@bind_config
|
||||||
async def get(self, configuration=None):
|
async def get(self, configuration: str | None = None):
|
||||||
|
"""Download a binary file."""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
compressed = self.get_argument("compressed", "0") == "1"
|
compressed = self.get_argument("compressed", "0") == "1"
|
||||||
|
|
||||||
storage_path = ext_storage_path(configuration)
|
storage_path = ext_storage_path(configuration)
|
||||||
|
@ -584,11 +595,8 @@ class DownloadBinaryRequestHandler(BaseHandler):
|
||||||
self.send_error(404)
|
self.send_error(404)
|
||||||
return
|
return
|
||||||
|
|
||||||
with open(path, "rb") as f:
|
data = await loop.run_in_executor(None, self._load_file, path, compressed)
|
||||||
data = f.read()
|
self.write(data)
|
||||||
if compressed:
|
|
||||||
data = gzip.compress(data, 9)
|
|
||||||
self.write(data)
|
|
||||||
|
|
||||||
self.finish()
|
self.finish()
|
||||||
|
|
||||||
|
@ -747,22 +755,32 @@ class InfoRequestHandler(BaseHandler):
|
||||||
class EditRequestHandler(BaseHandler):
|
class EditRequestHandler(BaseHandler):
|
||||||
@authenticated
|
@authenticated
|
||||||
@bind_config
|
@bind_config
|
||||||
def get(self, configuration=None):
|
async def get(self, configuration: str | None = None):
|
||||||
|
"""Get the content of a file."""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
filename = settings.rel_path(configuration)
|
filename = settings.rel_path(configuration)
|
||||||
content = ""
|
content = await loop.run_in_executor(None, self._read_file, filename)
|
||||||
if os.path.isfile(filename):
|
|
||||||
with open(file=filename, encoding="utf-8") as f:
|
|
||||||
content = f.read()
|
|
||||||
self.write(content)
|
self.write(content)
|
||||||
|
|
||||||
|
def _read_file(self, filename: str) -> bytes:
|
||||||
|
"""Read a file and return the content as bytes."""
|
||||||
|
with open(file=filename, encoding="utf-8") as f:
|
||||||
|
return f.read()
|
||||||
|
|
||||||
|
def _write_file(self, filename: str, content: bytes) -> None:
|
||||||
|
"""Write a file with the given content."""
|
||||||
|
write_file(filename, content)
|
||||||
|
|
||||||
@authenticated
|
@authenticated
|
||||||
@bind_config
|
@bind_config
|
||||||
async def post(self, configuration=None):
|
async def post(self, configuration: str | None = None):
|
||||||
# Atomic write
|
"""Write the content of a file."""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
config_file = settings.rel_path(configuration)
|
config_file = settings.rel_path(configuration)
|
||||||
with open(file=config_file, mode="wb") as f:
|
await loop.run_in_executor(
|
||||||
f.write(self.request.body)
|
None, self._write_file, config_file, self.request.body
|
||||||
|
)
|
||||||
|
# Ensure the StorageJSON is updated as well
|
||||||
await async_run_system_command(
|
await async_run_system_command(
|
||||||
[*DASHBOARD_COMMAND, "compile", "--only-generate", config_file]
|
[*DASHBOARD_COMMAND, "compile", "--only-generate", config_file]
|
||||||
)
|
)
|
||||||
|
|
0
tests/dashboard/__init__.py
Normal file
0
tests/dashboard/__init__.py
Normal file
0
tests/dashboard/util/__init__.py
Normal file
0
tests/dashboard/util/__init__.py
Normal file
53
tests/dashboard/util/test_file.py
Normal file
53
tests/dashboard/util/test_file.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import py
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from esphome.dashboard.util.file import write_file, write_utf8_file
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_utf8_file(tmp_path: Path) -> None:
|
||||||
|
write_utf8_file(tmp_path.joinpath("foo.txt"), "foo")
|
||||||
|
assert tmp_path.joinpath("foo.txt").read_text() == "foo"
|
||||||
|
|
||||||
|
with pytest.raises(OSError):
|
||||||
|
write_utf8_file(Path("/not-writable"), "bar")
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_file(tmp_path: Path) -> None:
|
||||||
|
write_file(tmp_path.joinpath("foo.txt"), b"foo")
|
||||||
|
assert tmp_path.joinpath("foo.txt").read_text() == "foo"
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_utf8_file_fails_at_rename(
|
||||||
|
tmpdir: py.path.local, caplog: pytest.LogCaptureFixture
|
||||||
|
) -> None:
|
||||||
|
"""Test that if rename fails not not remove, we do not log the failed cleanup."""
|
||||||
|
test_dir = tmpdir.mkdir("files")
|
||||||
|
test_file = Path(test_dir / "test.json")
|
||||||
|
|
||||||
|
with pytest.raises(OSError), patch(
|
||||||
|
"esphome.dashboard.util.file.os.replace", side_effect=OSError
|
||||||
|
):
|
||||||
|
write_utf8_file(test_file, '{"some":"data"}', False)
|
||||||
|
|
||||||
|
assert not os.path.exists(test_file)
|
||||||
|
|
||||||
|
assert "File replacement cleanup failed" not in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_utf8_file_fails_at_rename_and_remove(
|
||||||
|
tmpdir: py.path.local, caplog: pytest.LogCaptureFixture
|
||||||
|
) -> None:
|
||||||
|
"""Test that if rename and remove both fail, we log the failed cleanup."""
|
||||||
|
test_dir = tmpdir.mkdir("files")
|
||||||
|
test_file = Path(test_dir / "test.json")
|
||||||
|
|
||||||
|
with pytest.raises(OSError), patch(
|
||||||
|
"esphome.dashboard.util.file.os.remove", side_effect=OSError
|
||||||
|
), patch("esphome.dashboard.util.file.os.replace", side_effect=OSError):
|
||||||
|
write_utf8_file(test_file, '{"some":"data"}', False)
|
||||||
|
|
||||||
|
assert "File replacement cleanup failed" in caplog.text
|
Loading…
Reference in a new issue