Wizard: fix colored text in input prompts (#5313)

This commit is contained in:
Kuba Szczodrzyński 2023-09-21 00:09:23 +02:00 committed by Jesse Hills
parent 807c47a076
commit d7e267eca5
No known key found for this signature in database
GPG key ID: BEAAE804EFD8E83A
3 changed files with 26 additions and 18 deletions

View file

@ -57,7 +57,7 @@ class SimpleRegistry(dict):
return decorator return decorator
def safe_print(message=""): def safe_print(message="", end="\n"):
from esphome.core import CORE from esphome.core import CORE
if CORE.dashboard: if CORE.dashboard:
@ -67,20 +67,26 @@ def safe_print(message=""):
pass pass
try: try:
print(message) print(message, end=end)
return return
except UnicodeEncodeError: except UnicodeEncodeError:
pass pass
try: try:
print(message.encode("utf-8", "backslashreplace")) print(message.encode("utf-8", "backslashreplace"), end=end)
except UnicodeEncodeError: except UnicodeEncodeError:
try: try:
print(message.encode("ascii", "backslashreplace")) print(message.encode("ascii", "backslashreplace"), end=end)
except UnicodeEncodeError: except UnicodeEncodeError:
print("Cannot print line because of invalid locale!") print("Cannot print line because of invalid locale!")
def safe_input(prompt=""):
if prompt:
safe_print(prompt, end="")
return input()
def shlex_quote(s): def shlex_quote(s):
if not s: if not s:
return "''" return "''"

View file

@ -11,7 +11,7 @@ from esphome.core import CORE
from esphome.helpers import get_bool_env, write_file from esphome.helpers import get_bool_env, write_file
from esphome.log import Fore, color from esphome.log import Fore, color
from esphome.storage_json import StorageJSON, ext_storage_path from esphome.storage_json import StorageJSON, ext_storage_path
from esphome.util import safe_print from esphome.util import safe_input, safe_print
CORE_BIG = r""" _____ ____ _____ ______ CORE_BIG = r""" _____ ____ _____ ______
/ ____/ __ \| __ \| ____| / ____/ __ \| __ \| ____|
@ -252,7 +252,7 @@ def safe_print_step(step, big):
def default_input(text, default): def default_input(text, default):
safe_print() safe_print()
safe_print(f"Press ENTER for default ({default})") safe_print(f"Press ENTER for default ({default})")
return input(text.format(default)) or default return safe_input(text.format(default)) or default
# From https://stackoverflow.com/a/518232/8924614 # From https://stackoverflow.com/a/518232/8924614
@ -306,7 +306,7 @@ def wizard(path):
) )
safe_print() safe_print()
sleep(1) sleep(1)
name = input(color(Fore.BOLD_WHITE, "(name): ")) name = safe_input(color(Fore.BOLD_WHITE, "(name): "))
while True: while True:
try: try:
@ -343,7 +343,9 @@ def wizard(path):
while True: while True:
sleep(0.5) sleep(0.5)
safe_print() safe_print()
platform = input(color(Fore.BOLD_WHITE, f"({'/'.join(wizard_platforms)}): ")) platform = safe_input(
color(Fore.BOLD_WHITE, f"({'/'.join(wizard_platforms)}): ")
)
try: try:
platform = vol.All(vol.Upper, vol.Any(*wizard_platforms))(platform.upper()) platform = vol.All(vol.Upper, vol.Any(*wizard_platforms))(platform.upper())
break break
@ -397,7 +399,7 @@ def wizard(path):
boards.append(board_id) boards.append(board_id)
while True: while True:
board = input(color(Fore.BOLD_WHITE, "(board): ")) board = safe_input(color(Fore.BOLD_WHITE, "(board): "))
try: try:
board = vol.All(vol.Lower, vol.Any(*boards))(board) board = vol.All(vol.Lower, vol.Any(*boards))(board)
break break
@ -423,7 +425,7 @@ def wizard(path):
sleep(1.5) sleep(1.5)
safe_print(f"For example \"{color(Fore.BOLD_WHITE, 'Abraham Linksys')}\".") safe_print(f"For example \"{color(Fore.BOLD_WHITE, 'Abraham Linksys')}\".")
while True: while True:
ssid = input(color(Fore.BOLD_WHITE, "(ssid): ")) ssid = safe_input(color(Fore.BOLD_WHITE, "(ssid): "))
try: try:
ssid = cv.ssid(ssid) ssid = cv.ssid(ssid)
break break
@ -449,7 +451,7 @@ def wizard(path):
safe_print() safe_print()
safe_print(f"For example \"{color(Fore.BOLD_WHITE, 'PASSWORD42')}\"") safe_print(f"For example \"{color(Fore.BOLD_WHITE, 'PASSWORD42')}\"")
sleep(0.5) sleep(0.5)
psk = input(color(Fore.BOLD_WHITE, "(PSK): ")) psk = safe_input(color(Fore.BOLD_WHITE, "(PSK): "))
safe_print( safe_print(
"Perfect! WiFi is now set up (you can create static IPs and so on later)." "Perfect! WiFi is now set up (you can create static IPs and so on later)."
) )
@ -466,7 +468,7 @@ def wizard(path):
safe_print() safe_print()
sleep(0.25) sleep(0.25)
safe_print("Press ENTER for no password") safe_print("Press ENTER for no password")
password = input(color(Fore.BOLD_WHITE, "(password): ")) password = safe_input(color(Fore.BOLD_WHITE, "(password): "))
if not wizard_write( if not wizard_write(
path=path, path=path,

View file

@ -319,7 +319,7 @@ def test_wizard_accepts_default_answers_esp8266(tmpdir, monkeypatch, wizard_answ
config_file = tmpdir.join("test.yaml") config_file = tmpdir.join("test.yaml")
input_mock = MagicMock(side_effect=wizard_answers) input_mock = MagicMock(side_effect=wizard_answers)
monkeypatch.setattr("builtins.input", input_mock) monkeypatch.setattr("builtins.input", input_mock)
monkeypatch.setattr(wz, "safe_print", lambda t=None: 0) monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0)
monkeypatch.setattr(wz, "sleep", lambda _: 0) monkeypatch.setattr(wz, "sleep", lambda _: 0)
monkeypatch.setattr(wz, "wizard_write", MagicMock()) monkeypatch.setattr(wz, "wizard_write", MagicMock())
@ -341,7 +341,7 @@ def test_wizard_accepts_default_answers_esp32(tmpdir, monkeypatch, wizard_answer
config_file = tmpdir.join("test.yaml") config_file = tmpdir.join("test.yaml")
input_mock = MagicMock(side_effect=wizard_answers) input_mock = MagicMock(side_effect=wizard_answers)
monkeypatch.setattr("builtins.input", input_mock) monkeypatch.setattr("builtins.input", input_mock)
monkeypatch.setattr(wz, "safe_print", lambda t=None: 0) monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0)
monkeypatch.setattr(wz, "sleep", lambda _: 0) monkeypatch.setattr(wz, "sleep", lambda _: 0)
monkeypatch.setattr(wz, "wizard_write", MagicMock()) monkeypatch.setattr(wz, "wizard_write", MagicMock())
@ -371,7 +371,7 @@ def test_wizard_offers_better_node_name(tmpdir, monkeypatch, wizard_answers):
config_file = tmpdir.join("test.yaml") config_file = tmpdir.join("test.yaml")
input_mock = MagicMock(side_effect=wizard_answers) input_mock = MagicMock(side_effect=wizard_answers)
monkeypatch.setattr("builtins.input", input_mock) monkeypatch.setattr("builtins.input", input_mock)
monkeypatch.setattr(wz, "safe_print", lambda t=None: 0) monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0)
monkeypatch.setattr(wz, "sleep", lambda _: 0) monkeypatch.setattr(wz, "sleep", lambda _: 0)
monkeypatch.setattr(wz, "wizard_write", MagicMock()) monkeypatch.setattr(wz, "wizard_write", MagicMock())
@ -394,7 +394,7 @@ def test_wizard_requires_correct_platform(tmpdir, monkeypatch, wizard_answers):
config_file = tmpdir.join("test.yaml") config_file = tmpdir.join("test.yaml")
input_mock = MagicMock(side_effect=wizard_answers) input_mock = MagicMock(side_effect=wizard_answers)
monkeypatch.setattr("builtins.input", input_mock) monkeypatch.setattr("builtins.input", input_mock)
monkeypatch.setattr(wz, "safe_print", lambda t=None: 0) monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0)
monkeypatch.setattr(wz, "sleep", lambda _: 0) monkeypatch.setattr(wz, "sleep", lambda _: 0)
monkeypatch.setattr(wz, "wizard_write", MagicMock()) monkeypatch.setattr(wz, "wizard_write", MagicMock())
@ -416,7 +416,7 @@ def test_wizard_requires_correct_board(tmpdir, monkeypatch, wizard_answers):
config_file = tmpdir.join("test.yaml") config_file = tmpdir.join("test.yaml")
input_mock = MagicMock(side_effect=wizard_answers) input_mock = MagicMock(side_effect=wizard_answers)
monkeypatch.setattr("builtins.input", input_mock) monkeypatch.setattr("builtins.input", input_mock)
monkeypatch.setattr(wz, "safe_print", lambda t=None: 0) monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0)
monkeypatch.setattr(wz, "sleep", lambda _: 0) monkeypatch.setattr(wz, "sleep", lambda _: 0)
monkeypatch.setattr(wz, "wizard_write", MagicMock()) monkeypatch.setattr(wz, "wizard_write", MagicMock())
@ -438,7 +438,7 @@ def test_wizard_requires_valid_ssid(tmpdir, monkeypatch, wizard_answers):
config_file = tmpdir.join("test.yaml") config_file = tmpdir.join("test.yaml")
input_mock = MagicMock(side_effect=wizard_answers) input_mock = MagicMock(side_effect=wizard_answers)
monkeypatch.setattr("builtins.input", input_mock) monkeypatch.setattr("builtins.input", input_mock)
monkeypatch.setattr(wz, "safe_print", lambda t=None: 0) monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0)
monkeypatch.setattr(wz, "sleep", lambda _: 0) monkeypatch.setattr(wz, "sleep", lambda _: 0)
monkeypatch.setattr(wz, "wizard_write", MagicMock()) monkeypatch.setattr(wz, "wizard_write", MagicMock())