From 708ed8f38a6f9c0e71f903af9c9314b7f0a89310 Mon Sep 17 00:00:00 2001 From: Daniel Baulig Date: Mon, 6 Nov 2023 16:59:03 -0800 Subject: [PATCH] [web_server] Adds the ability to handle Private Network Access preflight requests (#5669) --- esphome/components/web_server/__init__.py | 4 ++ esphome/components/web_server/web_server.cpp | 37 +++++++++++++++++++ esphome/components/web_server/web_server.h | 5 +++ .../web_server_idf/web_server_idf.cpp | 14 +++++++ .../web_server_idf/web_server_idf.h | 3 ++ esphome/const.py | 1 + 6 files changed, 64 insertions(+) diff --git a/esphome/components/web_server/__init__.py b/esphome/components/web_server/__init__.py index b8698438e2..2708b5d06e 100644 --- a/esphome/components/web_server/__init__.py +++ b/esphome/components/web_server/__init__.py @@ -9,6 +9,7 @@ from esphome.const import ( CONF_ID, CONF_JS_INCLUDE, CONF_JS_URL, + CONF_ENABLE_PRIVATE_NETWORK_ACCESS, CONF_PORT, CONF_AUTH, CONF_USERNAME, @@ -68,6 +69,7 @@ CONFIG_SCHEMA = cv.All( cv.Optional(CONF_CSS_INCLUDE): cv.file_, cv.Optional(CONF_JS_URL): cv.string, cv.Optional(CONF_JS_INCLUDE): cv.file_, + cv.Optional(CONF_ENABLE_PRIVATE_NETWORK_ACCESS, default=True): cv.boolean, cv.Optional(CONF_AUTH): cv.Schema( { cv.Required(CONF_USERNAME): cv.All( @@ -158,6 +160,8 @@ async def to_code(config): cg.add(var.set_js_url(config[CONF_JS_URL])) cg.add(var.set_allow_ota(config[CONF_OTA])) cg.add(var.set_expose_log(config[CONF_LOG])) + if config[CONF_ENABLE_PRIVATE_NETWORK_ACCESS]: + cg.add_define("USE_WEBSERVER_PRIVATE_NETWORK_ACCESS") if CONF_AUTH in config: cg.add(paren.set_auth_username(config[CONF_AUTH][CONF_USERNAME])) cg.add(paren.set_auth_password(config[CONF_AUTH][CONF_PASSWORD])) diff --git a/esphome/components/web_server/web_server.cpp b/esphome/components/web_server/web_server.cpp index ccc86e5e53..0d72e274cd 100644 --- a/esphome/components/web_server/web_server.cpp +++ b/esphome/components/web_server/web_server.cpp @@ -34,6 +34,13 @@ namespace web_server { static const char *const TAG = "web_server"; +#ifdef USE_WEBSERVER_PRIVATE_NETWORK_ACCESS +static const char *const HEADER_PNA_NAME = "Private-Network-Access-Name"; +static const char *const HEADER_PNA_ID = "Private-Network-Access-ID"; +static const char *const HEADER_CORS_REQ_PNA = "Access-Control-Request-Private-Network"; +static const char *const HEADER_CORS_ALLOW_PNA = "Access-Control-Allow-Private-Network"; +#endif + #if USE_WEBSERVER_VERSION == 1 void write_row(AsyncResponseStream *stream, EntityBase *obj, const std::string &klass, const std::string &action, const std::function &action_func = nullptr) { @@ -359,6 +366,17 @@ void WebServer::handle_index_request(AsyncWebServerRequest *request) { } #endif +#ifdef USE_WEBSERVER_PRIVATE_NETWORK_ACCESS +void WebServer::handle_pna_cors_request(AsyncWebServerRequest *request) { + AsyncWebServerResponse *response = request->beginResponse(200, ""); + response->addHeader(HEADER_CORS_ALLOW_PNA, "true"); + response->addHeader(HEADER_PNA_NAME, App.get_name().c_str()); + std::string mac = get_mac_address_pretty(); + response->addHeader(HEADER_PNA_ID, mac.c_str()); + request->send(response); +} +#endif + #ifdef USE_WEBSERVER_CSS_INCLUDE void WebServer::handle_css_request(AsyncWebServerRequest *request) { AsyncWebServerResponse *response = @@ -1145,6 +1163,18 @@ bool WebServer::canHandle(AsyncWebServerRequest *request) { return true; #endif +#ifdef USE_WEBSERVER_PRIVATE_NETWORK_ACCESS + if (request->method() == HTTP_OPTIONS && request->hasHeader(HEADER_CORS_REQ_PNA)) { +#ifdef USE_ARDUINO + // Header needs to be added to interesting header list for it to not be + // nuked by the time we handle the request later. + // Only required in Arduino framework. + request->addInterestingHeader(HEADER_CORS_REQ_PNA); +#endif + return true; + } +#endif + UrlMatch match = match_url(request->url().c_str(), true); if (!match.valid) return false; @@ -1240,6 +1270,13 @@ void WebServer::handleRequest(AsyncWebServerRequest *request) { } #endif +#ifdef USE_WEBSERVER_PRIVATE_NETWORK_ACCESS + if (request->method() == HTTP_OPTIONS && request->hasHeader(HEADER_CORS_REQ_PNA)) { + this->handle_pna_cors_request(request); + return; + } +#endif + UrlMatch match = match_url(request->url().c_str()); #ifdef USE_SENSOR if (match.domain == "sensor") { diff --git a/esphome/components/web_server/web_server.h b/esphome/components/web_server/web_server.h index 45b99d4eba..465e231984 100644 --- a/esphome/components/web_server/web_server.h +++ b/esphome/components/web_server/web_server.h @@ -130,6 +130,11 @@ class WebServer : public Controller, public Component, public AsyncWebHandler { void handle_js_request(AsyncWebServerRequest *request); #endif +#ifdef USE_WEBSERVER_PRIVATE_NETWORK_ACCESS + // Handle Private Network Access CORS OPTIONS request + void handle_pna_cors_request(AsyncWebServerRequest *request); +#endif + #ifdef USE_SENSOR void on_sensor_update(sensor::Sensor *obj, float state) override; /// Handle a sensor request under '/sensor/'. diff --git a/esphome/components/web_server_idf/web_server_idf.cpp b/esphome/components/web_server_idf/web_server_idf.cpp index 444e682460..8e67f3f169 100644 --- a/esphome/components/web_server_idf/web_server_idf.cpp +++ b/esphome/components/web_server_idf/web_server_idf.cpp @@ -51,6 +51,14 @@ void AsyncWebServer::begin() { .user_ctx = this, }; httpd_register_uri_handler(this->server_, &handler_post); + + const httpd_uri_t handler_options = { + .uri = "", + .method = HTTP_OPTIONS, + .handler = AsyncWebServer::request_handler, + .user_ctx = this, + }; + httpd_register_uri_handler(this->server_, &handler_options); } } @@ -80,6 +88,8 @@ AsyncWebServerRequest::~AsyncWebServerRequest() { } } +bool AsyncWebServerRequest::hasHeader(const char *name) const { return httpd_req_get_hdr_value_len(*this, name); } + optional AsyncWebServerRequest::get_header(const char *name) const { size_t buf_len = httpd_req_get_hdr_value_len(*this, name); if (buf_len == 0) { @@ -305,6 +315,10 @@ AsyncEventSourceResponse::AsyncEventSourceResponse(const AsyncWebServerRequest * httpd_resp_set_hdr(req, "Cache-Control", "no-cache"); httpd_resp_set_hdr(req, "Connection", "keep-alive"); + for (const auto &pair : DefaultHeaders::Instance().headers_) { + httpd_resp_set_hdr(req, pair.first.c_str(), pair.second.c_str()); + } + httpd_resp_send_chunk(req, CRLF_STR, CRLF_LEN); req->sess_ctx = this; diff --git a/esphome/components/web_server_idf/web_server_idf.h b/esphome/components/web_server_idf/web_server_idf.h index f3cecca16f..bc64e5231e 100644 --- a/esphome/components/web_server_idf/web_server_idf.h +++ b/esphome/components/web_server_idf/web_server_idf.h @@ -157,6 +157,8 @@ class AsyncWebServerRequest { operator httpd_req_t *() const { return this->req_; } optional get_header(const char *name) const; + // NOLINTNEXTLINE(readability-identifier-naming) + bool hasHeader(const char *name) const; protected: httpd_req_t *req_; @@ -254,6 +256,7 @@ class AsyncEventSource : public AsyncWebHandler { class DefaultHeaders { friend class AsyncWebServerRequest; + friend class AsyncEventSourceResponse; public: // NOLINTNEXTLINE(readability-identifier-naming) diff --git a/esphome/const.py b/esphome/const.py index 9457958863..c2fa9951ff 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -227,6 +227,7 @@ CONF_ELSE = "else" CONF_ENABLE_BTM = "enable_btm" CONF_ENABLE_IPV6 = "enable_ipv6" CONF_ENABLE_PIN = "enable_pin" +CONF_ENABLE_PRIVATE_NETWORK_ACCESS = "enable_private_network_access" CONF_ENABLE_RRM = "enable_rrm" CONF_ENABLE_TIME = "enable_time" CONF_ENERGY = "energy"