[web_server] Adds the ability to handle Private Network Access preflight requests (#5669)

This commit is contained in:
Daniel Baulig 2023-11-06 16:59:03 -08:00 committed by GitHub
parent a8a9c6192d
commit 708ed8f38a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 64 additions and 0 deletions

View file

@ -9,6 +9,7 @@ from esphome.const import (
CONF_ID, CONF_ID,
CONF_JS_INCLUDE, CONF_JS_INCLUDE,
CONF_JS_URL, CONF_JS_URL,
CONF_ENABLE_PRIVATE_NETWORK_ACCESS,
CONF_PORT, CONF_PORT,
CONF_AUTH, CONF_AUTH,
CONF_USERNAME, CONF_USERNAME,
@ -68,6 +69,7 @@ CONFIG_SCHEMA = cv.All(
cv.Optional(CONF_CSS_INCLUDE): cv.file_, cv.Optional(CONF_CSS_INCLUDE): cv.file_,
cv.Optional(CONF_JS_URL): cv.string, cv.Optional(CONF_JS_URL): cv.string,
cv.Optional(CONF_JS_INCLUDE): cv.file_, cv.Optional(CONF_JS_INCLUDE): cv.file_,
cv.Optional(CONF_ENABLE_PRIVATE_NETWORK_ACCESS, default=True): cv.boolean,
cv.Optional(CONF_AUTH): cv.Schema( cv.Optional(CONF_AUTH): cv.Schema(
{ {
cv.Required(CONF_USERNAME): cv.All( cv.Required(CONF_USERNAME): cv.All(
@ -158,6 +160,8 @@ async def to_code(config):
cg.add(var.set_js_url(config[CONF_JS_URL])) cg.add(var.set_js_url(config[CONF_JS_URL]))
cg.add(var.set_allow_ota(config[CONF_OTA])) cg.add(var.set_allow_ota(config[CONF_OTA]))
cg.add(var.set_expose_log(config[CONF_LOG])) cg.add(var.set_expose_log(config[CONF_LOG]))
if config[CONF_ENABLE_PRIVATE_NETWORK_ACCESS]:
cg.add_define("USE_WEBSERVER_PRIVATE_NETWORK_ACCESS")
if CONF_AUTH in config: if CONF_AUTH in config:
cg.add(paren.set_auth_username(config[CONF_AUTH][CONF_USERNAME])) cg.add(paren.set_auth_username(config[CONF_AUTH][CONF_USERNAME]))
cg.add(paren.set_auth_password(config[CONF_AUTH][CONF_PASSWORD])) cg.add(paren.set_auth_password(config[CONF_AUTH][CONF_PASSWORD]))

View file

@ -34,6 +34,13 @@ namespace web_server {
static const char *const TAG = "web_server"; static const char *const TAG = "web_server";
#ifdef USE_WEBSERVER_PRIVATE_NETWORK_ACCESS
static const char *const HEADER_PNA_NAME = "Private-Network-Access-Name";
static const char *const HEADER_PNA_ID = "Private-Network-Access-ID";
static const char *const HEADER_CORS_REQ_PNA = "Access-Control-Request-Private-Network";
static const char *const HEADER_CORS_ALLOW_PNA = "Access-Control-Allow-Private-Network";
#endif
#if USE_WEBSERVER_VERSION == 1 #if USE_WEBSERVER_VERSION == 1
void write_row(AsyncResponseStream *stream, EntityBase *obj, const std::string &klass, const std::string &action, void write_row(AsyncResponseStream *stream, EntityBase *obj, const std::string &klass, const std::string &action,
const std::function<void(AsyncResponseStream &stream, EntityBase *obj)> &action_func = nullptr) { const std::function<void(AsyncResponseStream &stream, EntityBase *obj)> &action_func = nullptr) {
@ -359,6 +366,17 @@ void WebServer::handle_index_request(AsyncWebServerRequest *request) {
} }
#endif #endif
#ifdef USE_WEBSERVER_PRIVATE_NETWORK_ACCESS
void WebServer::handle_pna_cors_request(AsyncWebServerRequest *request) {
AsyncWebServerResponse *response = request->beginResponse(200, "");
response->addHeader(HEADER_CORS_ALLOW_PNA, "true");
response->addHeader(HEADER_PNA_NAME, App.get_name().c_str());
std::string mac = get_mac_address_pretty();
response->addHeader(HEADER_PNA_ID, mac.c_str());
request->send(response);
}
#endif
#ifdef USE_WEBSERVER_CSS_INCLUDE #ifdef USE_WEBSERVER_CSS_INCLUDE
void WebServer::handle_css_request(AsyncWebServerRequest *request) { void WebServer::handle_css_request(AsyncWebServerRequest *request) {
AsyncWebServerResponse *response = AsyncWebServerResponse *response =
@ -1145,6 +1163,18 @@ bool WebServer::canHandle(AsyncWebServerRequest *request) {
return true; return true;
#endif #endif
#ifdef USE_WEBSERVER_PRIVATE_NETWORK_ACCESS
if (request->method() == HTTP_OPTIONS && request->hasHeader(HEADER_CORS_REQ_PNA)) {
#ifdef USE_ARDUINO
// Header needs to be added to interesting header list for it to not be
// nuked by the time we handle the request later.
// Only required in Arduino framework.
request->addInterestingHeader(HEADER_CORS_REQ_PNA);
#endif
return true;
}
#endif
UrlMatch match = match_url(request->url().c_str(), true); UrlMatch match = match_url(request->url().c_str(), true);
if (!match.valid) if (!match.valid)
return false; return false;
@ -1240,6 +1270,13 @@ void WebServer::handleRequest(AsyncWebServerRequest *request) {
} }
#endif #endif
#ifdef USE_WEBSERVER_PRIVATE_NETWORK_ACCESS
if (request->method() == HTTP_OPTIONS && request->hasHeader(HEADER_CORS_REQ_PNA)) {
this->handle_pna_cors_request(request);
return;
}
#endif
UrlMatch match = match_url(request->url().c_str()); UrlMatch match = match_url(request->url().c_str());
#ifdef USE_SENSOR #ifdef USE_SENSOR
if (match.domain == "sensor") { if (match.domain == "sensor") {

View file

@ -130,6 +130,11 @@ class WebServer : public Controller, public Component, public AsyncWebHandler {
void handle_js_request(AsyncWebServerRequest *request); void handle_js_request(AsyncWebServerRequest *request);
#endif #endif
#ifdef USE_WEBSERVER_PRIVATE_NETWORK_ACCESS
// Handle Private Network Access CORS OPTIONS request
void handle_pna_cors_request(AsyncWebServerRequest *request);
#endif
#ifdef USE_SENSOR #ifdef USE_SENSOR
void on_sensor_update(sensor::Sensor *obj, float state) override; void on_sensor_update(sensor::Sensor *obj, float state) override;
/// Handle a sensor request under '/sensor/<id>'. /// Handle a sensor request under '/sensor/<id>'.

View file

@ -51,6 +51,14 @@ void AsyncWebServer::begin() {
.user_ctx = this, .user_ctx = this,
}; };
httpd_register_uri_handler(this->server_, &handler_post); httpd_register_uri_handler(this->server_, &handler_post);
const httpd_uri_t handler_options = {
.uri = "",
.method = HTTP_OPTIONS,
.handler = AsyncWebServer::request_handler,
.user_ctx = this,
};
httpd_register_uri_handler(this->server_, &handler_options);
} }
} }
@ -80,6 +88,8 @@ AsyncWebServerRequest::~AsyncWebServerRequest() {
} }
} }
bool AsyncWebServerRequest::hasHeader(const char *name) const { return httpd_req_get_hdr_value_len(*this, name); }
optional<std::string> AsyncWebServerRequest::get_header(const char *name) const { optional<std::string> AsyncWebServerRequest::get_header(const char *name) const {
size_t buf_len = httpd_req_get_hdr_value_len(*this, name); size_t buf_len = httpd_req_get_hdr_value_len(*this, name);
if (buf_len == 0) { if (buf_len == 0) {
@ -305,6 +315,10 @@ AsyncEventSourceResponse::AsyncEventSourceResponse(const AsyncWebServerRequest *
httpd_resp_set_hdr(req, "Cache-Control", "no-cache"); httpd_resp_set_hdr(req, "Cache-Control", "no-cache");
httpd_resp_set_hdr(req, "Connection", "keep-alive"); httpd_resp_set_hdr(req, "Connection", "keep-alive");
for (const auto &pair : DefaultHeaders::Instance().headers_) {
httpd_resp_set_hdr(req, pair.first.c_str(), pair.second.c_str());
}
httpd_resp_send_chunk(req, CRLF_STR, CRLF_LEN); httpd_resp_send_chunk(req, CRLF_STR, CRLF_LEN);
req->sess_ctx = this; req->sess_ctx = this;

View file

@ -157,6 +157,8 @@ class AsyncWebServerRequest {
operator httpd_req_t *() const { return this->req_; } operator httpd_req_t *() const { return this->req_; }
optional<std::string> get_header(const char *name) const; optional<std::string> get_header(const char *name) const;
// NOLINTNEXTLINE(readability-identifier-naming)
bool hasHeader(const char *name) const;
protected: protected:
httpd_req_t *req_; httpd_req_t *req_;
@ -254,6 +256,7 @@ class AsyncEventSource : public AsyncWebHandler {
class DefaultHeaders { class DefaultHeaders {
friend class AsyncWebServerRequest; friend class AsyncWebServerRequest;
friend class AsyncEventSourceResponse;
public: public:
// NOLINTNEXTLINE(readability-identifier-naming) // NOLINTNEXTLINE(readability-identifier-naming)

View file

@ -227,6 +227,7 @@ CONF_ELSE = "else"
CONF_ENABLE_BTM = "enable_btm" CONF_ENABLE_BTM = "enable_btm"
CONF_ENABLE_IPV6 = "enable_ipv6" CONF_ENABLE_IPV6 = "enable_ipv6"
CONF_ENABLE_PIN = "enable_pin" CONF_ENABLE_PIN = "enable_pin"
CONF_ENABLE_PRIVATE_NETWORK_ACCESS = "enable_private_network_access"
CONF_ENABLE_RRM = "enable_rrm" CONF_ENABLE_RRM = "enable_rrm"
CONF_ENABLE_TIME = "enable_time" CONF_ENABLE_TIME = "enable_time"
CONF_ENERGY = "energy" CONF_ENERGY = "energy"