diff --git a/esphome/components/wireguard/__init__.py b/esphome/components/wireguard/__init__.py index acb5f690ec..b59a6011cd 100644 --- a/esphome/components/wireguard/__init__.py +++ b/esphome/components/wireguard/__init__.py @@ -10,6 +10,7 @@ from esphome.const import ( ) from esphome.components import time from esphome.core import TimePeriod +from esphome import automation CONF_NETMASK = "netmask" CONF_PRIVATE_KEY = "private_key" @@ -30,6 +31,16 @@ _WG_KEY_REGEX = re.compile(r"^[A-Za-z0-9+/]{42}[AEIMQUYcgkosw480]=$") wireguard_ns = cg.esphome_ns.namespace("wireguard") Wireguard = wireguard_ns.class_("Wireguard", cg.Component, cg.PollingComponent) +WireguardPeerOnlineCondition = wireguard_ns.class_( + "WireguardPeerOnlineCondition", automation.Condition +) +WireguardEnabledCondition = wireguard_ns.class_( + "WireguardEnabledCondition", automation.Condition +) +WireguardEnableAction = wireguard_ns.class_("WireguardEnableAction", automation.Action) +WireguardDisableAction = wireguard_ns.class_( + "WireguardDisableAction", automation.Action +) def _wireguard_key(value): @@ -112,3 +123,47 @@ async def to_code(config): cg.add_library("droscy/esp_wireguard", "0.3.2") await cg.register_component(var, config) + + +@automation.register_condition( + "wireguard.peer_online", + WireguardPeerOnlineCondition, + cv.Schema({cv.GenerateID(): cv.use_id(Wireguard)}), +) +async def wireguard_peer_up_to_code(config, condition_id, template_arg, args): + var = cg.new_Pvariable(condition_id, template_arg) + await cg.register_parented(var, config[CONF_ID]) + return var + + +@automation.register_condition( + "wireguard.enabled", + WireguardEnabledCondition, + cv.Schema({cv.GenerateID(): cv.use_id(Wireguard)}), +) +async def wireguard_enabled_to_code(config, condition_id, template_arg, args): + var = cg.new_Pvariable(condition_id, template_arg) + await cg.register_parented(var, config[CONF_ID]) + return var + + +@automation.register_action( + "wireguard.enable", + WireguardEnableAction, + cv.Schema({cv.GenerateID(): cv.use_id(Wireguard)}), +) +async def wireguard_enable_to_code(config, action_id, template_arg, args): + var = cg.new_Pvariable(action_id, template_arg) + await cg.register_parented(var, config[CONF_ID]) + return var + + +@automation.register_action( + "wireguard.disable", + WireguardDisableAction, + cv.Schema({cv.GenerateID(): cv.use_id(Wireguard)}), +) +async def wireguard_disable_to_code(config, action_id, template_arg, args): + var = cg.new_Pvariable(action_id, template_arg) + await cg.register_parented(var, config[CONF_ID]) + return var diff --git a/esphome/components/wireguard/binary_sensor.py b/esphome/components/wireguard/binary_sensor.py index 14ff2b0159..bf60aaa1d6 100644 --- a/esphome/components/wireguard/binary_sensor.py +++ b/esphome/components/wireguard/binary_sensor.py @@ -4,11 +4,13 @@ from esphome.components import binary_sensor from esphome.const import ( CONF_STATUS, DEVICE_CLASS_CONNECTIVITY, + ENTITY_CATEGORY_DIAGNOSTIC, ) from . import Wireguard CONF_WIREGUARD_ID = "wireguard_id" +CONF_ENABLED = "enabled" DEPENDENCIES = ["wireguard"] @@ -17,6 +19,9 @@ CONFIG_SCHEMA = { cv.Optional(CONF_STATUS): binary_sensor.binary_sensor_schema( device_class=DEVICE_CLASS_CONNECTIVITY, ), + cv.Optional(CONF_ENABLED): binary_sensor.binary_sensor_schema( + entity_category=ENTITY_CATEGORY_DIAGNOSTIC, + ), } @@ -26,3 +31,7 @@ async def to_code(config): if status_config := config.get(CONF_STATUS): sens = await binary_sensor.new_binary_sensor(status_config) cg.add(parent.set_status_sensor(sens)) + + if enabled_config := config.get(CONF_ENABLED): + sens = await binary_sensor.new_binary_sensor(enabled_config) + cg.add(parent.set_enabled_sensor(sens)) diff --git a/esphome/components/wireguard/wireguard.cpp b/esphome/components/wireguard/wireguard.cpp index f89a5ebbad..cca30d4310 100644 --- a/esphome/components/wireguard/wireguard.cpp +++ b/esphome/components/wireguard/wireguard.cpp @@ -48,6 +48,8 @@ void Wireguard::setup() { if (this->preshared_key_.length() > 0) this->wg_config_.preshared_key = this->preshared_key_.c_str(); + this->publish_enabled_state(); + this->wg_initialized_ = esp_wireguard_init(&(this->wg_config_), &(this->wg_ctx_)); if (this->wg_initialized_ == ESP_OK) { @@ -68,6 +70,10 @@ void Wireguard::setup() { } void Wireguard::loop() { + if (!this->enabled_) { + return; + } + if ((this->wg_initialized_ == ESP_OK) && (this->wg_connected_ == ESP_OK) && (!network::is_connected())) { ESP_LOGV(TAG, "local network connection has been lost, stopping WireGuard..."); this->stop_connection_(); @@ -79,8 +85,9 @@ void Wireguard::update() { time_t lhs = this->get_latest_handshake(); bool lhs_updated = (lhs > this->latest_saved_handshake_); - ESP_LOGV(TAG, "handshake: latest=%.0f, saved=%.0f, updated=%d", (double) lhs, (double) this->latest_saved_handshake_, - (int) lhs_updated); + ESP_LOGV(TAG, "enabled=%d, connected=%d, peer_up=%d, handshake: current=%.0f latest=%.0f updated=%d", + (int) this->enabled_, (int) (this->wg_connected_ == ESP_OK), (int) peer_up, (double) lhs, + (double) this->latest_saved_handshake_, (int) lhs_updated); if (lhs_updated) { this->latest_saved_handshake_ = lhs; @@ -102,13 +109,13 @@ void Wireguard::update() { if (this->wg_peer_offline_time_ == 0) { ESP_LOGW(TAG, LOGMSG_PEER_STATUS, LOGMSG_OFFLINE, latest_handshake.c_str()); this->wg_peer_offline_time_ = millis(); - } else { + } else if (this->enabled_) { ESP_LOGD(TAG, LOGMSG_PEER_STATUS, LOGMSG_OFFLINE, latest_handshake.c_str()); this->start_connection_(); } // check reboot timeout every time the peer is down - if (this->reboot_timeout_ > 0) { + if (this->enabled_ && this->reboot_timeout_ > 0) { if (millis() - this->wg_peer_offline_time_ > this->reboot_timeout_) { ESP_LOGE(TAG, "WireGuard remote peer is unreachable, rebooting..."); App.reboot(); @@ -154,7 +161,7 @@ void Wireguard::dump_config() { void Wireguard::on_shutdown() { this->stop_connection_(); } -bool Wireguard::can_proceed() { return (this->proceed_allowed_ || this->is_peer_up()); } +bool Wireguard::can_proceed() { return (this->proceed_allowed_ || this->is_peer_up() || !this->enabled_); } bool Wireguard::is_peer_up() const { return (this->wg_initialized_ == ESP_OK) && (this->wg_connected_ == ESP_OK) && @@ -187,6 +194,7 @@ void Wireguard::set_srctime(time::RealTimeClock *srctime) { this->srctime_ = src #ifdef USE_BINARY_SENSOR void Wireguard::set_status_sensor(binary_sensor::BinarySensor *sensor) { this->status_sensor_ = sensor; } +void Wireguard::set_enabled_sensor(binary_sensor::BinarySensor *sensor) { this->enabled_sensor_ = sensor; } #endif #ifdef USE_SENSOR @@ -199,7 +207,35 @@ void Wireguard::set_address_sensor(text_sensor::TextSensor *sensor) { this->addr void Wireguard::disable_auto_proceed() { this->proceed_allowed_ = false; } +void Wireguard::enable() { + this->enabled_ = true; + ESP_LOGI(TAG, "WireGuard enabled"); + this->publish_enabled_state(); +} + +void Wireguard::disable() { + this->enabled_ = false; + this->defer(std::bind(&Wireguard::stop_connection_, this)); // defer to avoid blocking running loop + ESP_LOGI(TAG, "WireGuard disabled"); + this->publish_enabled_state(); +} + +void Wireguard::publish_enabled_state() { +#ifdef USE_BINARY_SENSOR + if (this->enabled_sensor_ != nullptr) { + this->enabled_sensor_->publish_state(this->enabled_); + } +#endif +} + +bool Wireguard::is_enabled() { return this->enabled_; } + void Wireguard::start_connection_() { + if (!this->enabled_) { + ESP_LOGV(TAG, "WireGuard is disabled, cannot start connection"); + return; + } + if (this->wg_initialized_ != ESP_OK) { ESP_LOGE(TAG, "cannot start WireGuard, initialization in error with code %d", this->wg_initialized_); return; diff --git a/esphome/components/wireguard/wireguard.h b/esphome/components/wireguard/wireguard.h index c47d9e6603..7753a8dfc2 100644 --- a/esphome/components/wireguard/wireguard.h +++ b/esphome/components/wireguard/wireguard.h @@ -26,6 +26,7 @@ namespace esphome { namespace wireguard { +/// Main Wireguard component class. class Wireguard : public PollingComponent { public: void setup() override; @@ -53,6 +54,7 @@ class Wireguard : public PollingComponent { #ifdef USE_BINARY_SENSOR void set_status_sensor(binary_sensor::BinarySensor *sensor); + void set_enabled_sensor(binary_sensor::BinarySensor *sensor); #endif #ifdef USE_SENSOR @@ -66,6 +68,18 @@ class Wireguard : public PollingComponent { /// Block the setup step until peer is connected. void disable_auto_proceed(); + /// Enable the WireGuard component. + void enable(); + + /// Stop any running connection and disable the WireGuard component. + void disable(); + + /// Publish the enabled state if the enabled binary sensor is configured. + void publish_enabled_state(); + + /// Return if the WireGuard component is or is not enabled. + bool is_enabled(); + bool is_peer_up() const; time_t get_latest_handshake() const; @@ -87,6 +101,7 @@ class Wireguard : public PollingComponent { #ifdef USE_BINARY_SENSOR binary_sensor::BinarySensor *status_sensor_ = nullptr; + binary_sensor::BinarySensor *enabled_sensor_ = nullptr; #endif #ifdef USE_SENSOR @@ -100,6 +115,9 @@ class Wireguard : public PollingComponent { /// Set to false to block the setup step until peer is connected. bool proceed_allowed_ = true; + /// When false the wireguard link will not be established + bool enabled_ = true; + wireguard_config_t wg_config_ = ESP_WIREGUARD_CONFIG_DEFAULT(); wireguard_ctx_t wg_ctx_ = ESP_WIREGUARD_CONTEXT_DEFAULT(); @@ -128,6 +146,30 @@ void resume_wdt(); /// Strip most part of the key only for secure printing std::string mask_key(const std::string &key); +/// Condition to check if remote peer is online. +template class WireguardPeerOnlineCondition : public Condition, public Parented { + public: + bool check(Ts... x) override { return this->parent_->is_peer_up(); } +}; + +/// Condition to check if Wireguard component is enabled. +template class WireguardEnabledCondition : public Condition, public Parented { + public: + bool check(Ts... x) override { return this->parent_->is_enabled(); } +}; + +/// Action to enable Wireguard component. +template class WireguardEnableAction : public Action, public Parented { + public: + void play(Ts... x) override { this->parent_->enable(); } +}; + +/// Action to disable Wireguard component. +template class WireguardDisableAction : public Action, public Parented { + public: + void play(Ts... x) override { this->parent_->disable(); } +}; + } // namespace wireguard } // namespace esphome diff --git a/tests/test10.yaml b/tests/test10.yaml index dda7601048..7e3a685b36 100644 --- a/tests/test10.yaml +++ b/tests/test10.yaml @@ -44,6 +44,8 @@ binary_sensor: - platform: wireguard status: name: 'WireGuard Status' + enabled: + name: 'WireGuard Enabled' sensor: - platform: wireguard @@ -54,3 +56,26 @@ text_sensor: - platform: wireguard address: name: 'WireGuard Address' + +button: + - platform: template + name: 'Toggle WireGuard' + entity_category: config + on_press: + - if: + condition: wireguard.enabled + then: + - wireguard.disable: + else: + - wireguard.enable: + + - platform: template + name: 'Log WireGuard status' + entity_category: config + on_press: + - if: + condition: wireguard.peer_online + then: + - logger.log: 'wireguard remote peer is online' + else: + - logger.log: 'wireguard remote peer is offline'