Actions to enable and disable WireGuard connection (#5690)

This commit is contained in:
Simone Rossetto 2024-01-11 06:09:42 +01:00 committed by GitHub
parent 082d9fcf0e
commit d616025fed
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 172 additions and 5 deletions

View file

@ -10,6 +10,7 @@ from esphome.const import (
)
from esphome.components import time
from esphome.core import TimePeriod
from esphome import automation
CONF_NETMASK = "netmask"
CONF_PRIVATE_KEY = "private_key"
@ -30,6 +31,16 @@ _WG_KEY_REGEX = re.compile(r"^[A-Za-z0-9+/]{42}[AEIMQUYcgkosw480]=$")
wireguard_ns = cg.esphome_ns.namespace("wireguard")
Wireguard = wireguard_ns.class_("Wireguard", cg.Component, cg.PollingComponent)
WireguardPeerOnlineCondition = wireguard_ns.class_(
"WireguardPeerOnlineCondition", automation.Condition
)
WireguardEnabledCondition = wireguard_ns.class_(
"WireguardEnabledCondition", automation.Condition
)
WireguardEnableAction = wireguard_ns.class_("WireguardEnableAction", automation.Action)
WireguardDisableAction = wireguard_ns.class_(
"WireguardDisableAction", automation.Action
)
def _wireguard_key(value):
@ -112,3 +123,47 @@ async def to_code(config):
cg.add_library("droscy/esp_wireguard", "0.3.2")
await cg.register_component(var, config)
@automation.register_condition(
"wireguard.peer_online",
WireguardPeerOnlineCondition,
cv.Schema({cv.GenerateID(): cv.use_id(Wireguard)}),
)
async def wireguard_peer_up_to_code(config, condition_id, template_arg, args):
var = cg.new_Pvariable(condition_id, template_arg)
await cg.register_parented(var, config[CONF_ID])
return var
@automation.register_condition(
"wireguard.enabled",
WireguardEnabledCondition,
cv.Schema({cv.GenerateID(): cv.use_id(Wireguard)}),
)
async def wireguard_enabled_to_code(config, condition_id, template_arg, args):
var = cg.new_Pvariable(condition_id, template_arg)
await cg.register_parented(var, config[CONF_ID])
return var
@automation.register_action(
"wireguard.enable",
WireguardEnableAction,
cv.Schema({cv.GenerateID(): cv.use_id(Wireguard)}),
)
async def wireguard_enable_to_code(config, action_id, template_arg, args):
var = cg.new_Pvariable(action_id, template_arg)
await cg.register_parented(var, config[CONF_ID])
return var
@automation.register_action(
"wireguard.disable",
WireguardDisableAction,
cv.Schema({cv.GenerateID(): cv.use_id(Wireguard)}),
)
async def wireguard_disable_to_code(config, action_id, template_arg, args):
var = cg.new_Pvariable(action_id, template_arg)
await cg.register_parented(var, config[CONF_ID])
return var

View file

@ -4,11 +4,13 @@ from esphome.components import binary_sensor
from esphome.const import (
CONF_STATUS,
DEVICE_CLASS_CONNECTIVITY,
ENTITY_CATEGORY_DIAGNOSTIC,
)
from . import Wireguard
CONF_WIREGUARD_ID = "wireguard_id"
CONF_ENABLED = "enabled"
DEPENDENCIES = ["wireguard"]
@ -17,6 +19,9 @@ CONFIG_SCHEMA = {
cv.Optional(CONF_STATUS): binary_sensor.binary_sensor_schema(
device_class=DEVICE_CLASS_CONNECTIVITY,
),
cv.Optional(CONF_ENABLED): binary_sensor.binary_sensor_schema(
entity_category=ENTITY_CATEGORY_DIAGNOSTIC,
),
}
@ -26,3 +31,7 @@ async def to_code(config):
if status_config := config.get(CONF_STATUS):
sens = await binary_sensor.new_binary_sensor(status_config)
cg.add(parent.set_status_sensor(sens))
if enabled_config := config.get(CONF_ENABLED):
sens = await binary_sensor.new_binary_sensor(enabled_config)
cg.add(parent.set_enabled_sensor(sens))

View file

@ -48,6 +48,8 @@ void Wireguard::setup() {
if (this->preshared_key_.length() > 0)
this->wg_config_.preshared_key = this->preshared_key_.c_str();
this->publish_enabled_state();
this->wg_initialized_ = esp_wireguard_init(&(this->wg_config_), &(this->wg_ctx_));
if (this->wg_initialized_ == ESP_OK) {
@ -68,6 +70,10 @@ void Wireguard::setup() {
}
void Wireguard::loop() {
if (!this->enabled_) {
return;
}
if ((this->wg_initialized_ == ESP_OK) && (this->wg_connected_ == ESP_OK) && (!network::is_connected())) {
ESP_LOGV(TAG, "local network connection has been lost, stopping WireGuard...");
this->stop_connection_();
@ -79,8 +85,9 @@ void Wireguard::update() {
time_t lhs = this->get_latest_handshake();
bool lhs_updated = (lhs > this->latest_saved_handshake_);
ESP_LOGV(TAG, "handshake: latest=%.0f, saved=%.0f, updated=%d", (double) lhs, (double) this->latest_saved_handshake_,
(int) lhs_updated);
ESP_LOGV(TAG, "enabled=%d, connected=%d, peer_up=%d, handshake: current=%.0f latest=%.0f updated=%d",
(int) this->enabled_, (int) (this->wg_connected_ == ESP_OK), (int) peer_up, (double) lhs,
(double) this->latest_saved_handshake_, (int) lhs_updated);
if (lhs_updated) {
this->latest_saved_handshake_ = lhs;
@ -102,13 +109,13 @@ void Wireguard::update() {
if (this->wg_peer_offline_time_ == 0) {
ESP_LOGW(TAG, LOGMSG_PEER_STATUS, LOGMSG_OFFLINE, latest_handshake.c_str());
this->wg_peer_offline_time_ = millis();
} else {
} else if (this->enabled_) {
ESP_LOGD(TAG, LOGMSG_PEER_STATUS, LOGMSG_OFFLINE, latest_handshake.c_str());
this->start_connection_();
}
// check reboot timeout every time the peer is down
if (this->reboot_timeout_ > 0) {
if (this->enabled_ && this->reboot_timeout_ > 0) {
if (millis() - this->wg_peer_offline_time_ > this->reboot_timeout_) {
ESP_LOGE(TAG, "WireGuard remote peer is unreachable, rebooting...");
App.reboot();
@ -154,7 +161,7 @@ void Wireguard::dump_config() {
void Wireguard::on_shutdown() { this->stop_connection_(); }
bool Wireguard::can_proceed() { return (this->proceed_allowed_ || this->is_peer_up()); }
bool Wireguard::can_proceed() { return (this->proceed_allowed_ || this->is_peer_up() || !this->enabled_); }
bool Wireguard::is_peer_up() const {
return (this->wg_initialized_ == ESP_OK) && (this->wg_connected_ == ESP_OK) &&
@ -187,6 +194,7 @@ void Wireguard::set_srctime(time::RealTimeClock *srctime) { this->srctime_ = src
#ifdef USE_BINARY_SENSOR
void Wireguard::set_status_sensor(binary_sensor::BinarySensor *sensor) { this->status_sensor_ = sensor; }
void Wireguard::set_enabled_sensor(binary_sensor::BinarySensor *sensor) { this->enabled_sensor_ = sensor; }
#endif
#ifdef USE_SENSOR
@ -199,7 +207,35 @@ void Wireguard::set_address_sensor(text_sensor::TextSensor *sensor) { this->addr
void Wireguard::disable_auto_proceed() { this->proceed_allowed_ = false; }
void Wireguard::enable() {
this->enabled_ = true;
ESP_LOGI(TAG, "WireGuard enabled");
this->publish_enabled_state();
}
void Wireguard::disable() {
this->enabled_ = false;
this->defer(std::bind(&Wireguard::stop_connection_, this)); // defer to avoid blocking running loop
ESP_LOGI(TAG, "WireGuard disabled");
this->publish_enabled_state();
}
void Wireguard::publish_enabled_state() {
#ifdef USE_BINARY_SENSOR
if (this->enabled_sensor_ != nullptr) {
this->enabled_sensor_->publish_state(this->enabled_);
}
#endif
}
bool Wireguard::is_enabled() { return this->enabled_; }
void Wireguard::start_connection_() {
if (!this->enabled_) {
ESP_LOGV(TAG, "WireGuard is disabled, cannot start connection");
return;
}
if (this->wg_initialized_ != ESP_OK) {
ESP_LOGE(TAG, "cannot start WireGuard, initialization in error with code %d", this->wg_initialized_);
return;

View file

@ -26,6 +26,7 @@
namespace esphome {
namespace wireguard {
/// Main Wireguard component class.
class Wireguard : public PollingComponent {
public:
void setup() override;
@ -53,6 +54,7 @@ class Wireguard : public PollingComponent {
#ifdef USE_BINARY_SENSOR
void set_status_sensor(binary_sensor::BinarySensor *sensor);
void set_enabled_sensor(binary_sensor::BinarySensor *sensor);
#endif
#ifdef USE_SENSOR
@ -66,6 +68,18 @@ class Wireguard : public PollingComponent {
/// Block the setup step until peer is connected.
void disable_auto_proceed();
/// Enable the WireGuard component.
void enable();
/// Stop any running connection and disable the WireGuard component.
void disable();
/// Publish the enabled state if the enabled binary sensor is configured.
void publish_enabled_state();
/// Return if the WireGuard component is or is not enabled.
bool is_enabled();
bool is_peer_up() const;
time_t get_latest_handshake() const;
@ -87,6 +101,7 @@ class Wireguard : public PollingComponent {
#ifdef USE_BINARY_SENSOR
binary_sensor::BinarySensor *status_sensor_ = nullptr;
binary_sensor::BinarySensor *enabled_sensor_ = nullptr;
#endif
#ifdef USE_SENSOR
@ -100,6 +115,9 @@ class Wireguard : public PollingComponent {
/// Set to false to block the setup step until peer is connected.
bool proceed_allowed_ = true;
/// When false the wireguard link will not be established
bool enabled_ = true;
wireguard_config_t wg_config_ = ESP_WIREGUARD_CONFIG_DEFAULT();
wireguard_ctx_t wg_ctx_ = ESP_WIREGUARD_CONTEXT_DEFAULT();
@ -128,6 +146,30 @@ void resume_wdt();
/// Strip most part of the key only for secure printing
std::string mask_key(const std::string &key);
/// Condition to check if remote peer is online.
template<typename... Ts> class WireguardPeerOnlineCondition : public Condition<Ts...>, public Parented<Wireguard> {
public:
bool check(Ts... x) override { return this->parent_->is_peer_up(); }
};
/// Condition to check if Wireguard component is enabled.
template<typename... Ts> class WireguardEnabledCondition : public Condition<Ts...>, public Parented<Wireguard> {
public:
bool check(Ts... x) override { return this->parent_->is_enabled(); }
};
/// Action to enable Wireguard component.
template<typename... Ts> class WireguardEnableAction : public Action<Ts...>, public Parented<Wireguard> {
public:
void play(Ts... x) override { this->parent_->enable(); }
};
/// Action to disable Wireguard component.
template<typename... Ts> class WireguardDisableAction : public Action<Ts...>, public Parented<Wireguard> {
public:
void play(Ts... x) override { this->parent_->disable(); }
};
} // namespace wireguard
} // namespace esphome

View file

@ -44,6 +44,8 @@ binary_sensor:
- platform: wireguard
status:
name: 'WireGuard Status'
enabled:
name: 'WireGuard Enabled'
sensor:
- platform: wireguard
@ -54,3 +56,26 @@ text_sensor:
- platform: wireguard
address:
name: 'WireGuard Address'
button:
- platform: template
name: 'Toggle WireGuard'
entity_category: config
on_press:
- if:
condition: wireguard.enabled
then:
- wireguard.disable:
else:
- wireguard.enable:
- platform: template
name: 'Log WireGuard status'
entity_category: config
on_press:
- if:
condition: wireguard.peer_online
then:
- logger.log: 'wireguard remote peer is online'
else:
- logger.log: 'wireguard remote peer is offline'