diff --git a/esphome/components/web_server/__init__.py b/esphome/components/web_server/__init__.py index 7f17767657..240ba7c8a0 100644 --- a/esphome/components/web_server/__init__.py +++ b/esphome/components/web_server/__init__.py @@ -34,8 +34,8 @@ CONFIG_SCHEMA = cv.Schema( cv.Optional(CONF_JS_INCLUDE): cv.file_, cv.Optional(CONF_AUTH): cv.Schema( { - cv.Required(CONF_USERNAME): cv.string_strict, - cv.Required(CONF_PASSWORD): cv.string_strict, + cv.Required(CONF_USERNAME): cv.All(cv.string_strict, cv.Length(min=1)), + cv.Required(CONF_PASSWORD): cv.All(cv.string_strict, cv.Length(min=1)), } ), cv.GenerateID(CONF_WEB_SERVER_BASE_ID): cv.use_id( @@ -57,8 +57,8 @@ async def to_code(config): cg.add(var.set_css_url(config[CONF_CSS_URL])) cg.add(var.set_js_url(config[CONF_JS_URL])) if CONF_AUTH in config: - cg.add(var.set_username(config[CONF_AUTH][CONF_USERNAME])) - cg.add(var.set_password(config[CONF_AUTH][CONF_PASSWORD])) + cg.add(paren.set_auth_username(config[CONF_AUTH][CONF_USERNAME])) + cg.add(paren.set_auth_password(config[CONF_AUTH][CONF_PASSWORD])) if CONF_CSS_INCLUDE in config: cg.add_define("WEBSERVER_CSS_INCLUDE") path = CORE.relative_config_path(config[CONF_CSS_INCLUDE]) diff --git a/esphome/components/web_server/web_server.cpp b/esphome/components/web_server/web_server.cpp index dc97bcd5c2..e19a54931a 100644 --- a/esphome/components/web_server/web_server.cpp +++ b/esphome/components/web_server/web_server.cpp @@ -1,8 +1,8 @@ #include "web_server.h" -#include "esphome/core/log.h" -#include "esphome/core/application.h" -#include "esphome/core/util.h" #include "esphome/components/json/json_util.h" +#include "esphome/core/application.h" +#include "esphome/core/log.h" +#include "esphome/core/util.h" #include "StreamString.h" @@ -151,9 +151,6 @@ void WebServer::setup() { void WebServer::dump_config() { ESP_LOGCONFIG(TAG, "Web Server:"); ESP_LOGCONFIG(TAG, " Address: %s:%u", network_get_address().c_str(), this->base_->get_port()); - if (this->using_auth()) { - ESP_LOGCONFIG(TAG, " Basic authentication enabled"); - } } float WebServer::get_setup_priority() const { return setup_priority::WIFI - 1.0f; } @@ -728,10 +725,6 @@ bool WebServer::canHandle(AsyncWebServerRequest *request) { return false; } void WebServer::handleRequest(AsyncWebServerRequest *request) { - if (this->using_auth() && !request->authenticate(this->username_, this->password_)) { - return request->requestAuthentication(); - } - if (request->url() == "/") { this->handle_index_request(request); return; diff --git a/esphome/components/web_server/web_server.h b/esphome/components/web_server/web_server.h index 54d7356ac9..4e9224ee26 100644 --- a/esphome/components/web_server/web_server.h +++ b/esphome/components/web_server/web_server.h @@ -30,10 +30,6 @@ class WebServer : public Controller, public Component, public AsyncWebHandler { public: WebServer(web_server_base::WebServerBase *base) : base_(base) {} - void set_username(const char *username) { username_ = username; } - - void set_password(const char *password) { password_ = password; } - /** Set the URL to the CSS that's sent to each client. Defaults to * https://esphome.io/_static/webserver-v1.min.css * @@ -83,8 +79,6 @@ class WebServer : public Controller, public Component, public AsyncWebHandler { void handle_js_request(AsyncWebServerRequest *request); #endif - bool using_auth() { return username_ != nullptr && password_ != nullptr; } - #ifdef USE_SENSOR void on_sensor_update(sensor::Sensor *obj, float state) override; /// Handle a sensor request under '/sensor/'. @@ -182,8 +176,6 @@ class WebServer : public Controller, public Component, public AsyncWebHandler { protected: web_server_base::WebServerBase *base_; AsyncEventSource events_{"/events"}; - const char *username_{nullptr}; - const char *password_{nullptr}; const char *css_url_{nullptr}; const char *css_include_{nullptr}; const char *js_url_{nullptr}; diff --git a/esphome/components/web_server_base/web_server_base.cpp b/esphome/components/web_server_base/web_server_base.cpp index 85711704b9..832456dc83 100644 --- a/esphome/components/web_server_base/web_server_base.cpp +++ b/esphome/components/web_server_base/web_server_base.cpp @@ -15,6 +15,17 @@ namespace web_server_base { static const char *const TAG = "web_server_base"; +void WebServerBase::add_handler(AsyncWebHandler *handler) { + // remove all handlers + + if (!credentials_.username.empty()) { + handler = new internal::AuthMiddlewareHandler(handler, &credentials_); + } + this->handlers_.push_back(handler); + if (this->server_ != nullptr) + this->server_->addHandler(handler); +} + void report_ota_error() { StreamString ss; Update.printError(ss); diff --git a/esphome/components/web_server_base/web_server_base.h b/esphome/components/web_server_base/web_server_base.h index b6024ceafa..1bfec13fc5 100644 --- a/esphome/components/web_server_base/web_server_base.h +++ b/esphome/components/web_server_base/web_server_base.h @@ -7,6 +7,68 @@ namespace esphome { namespace web_server_base { +namespace internal { + +class MiddlewareHandler : public AsyncWebHandler { + public: + MiddlewareHandler(AsyncWebHandler *next) : next_(next) {} + + bool canHandle(AsyncWebServerRequest *request) override { return next_->canHandle(request); } + void handleRequest(AsyncWebServerRequest *request) override { next_->handleRequest(request); } + void handleUpload(AsyncWebServerRequest *request, const String &filename, size_t index, uint8_t *data, size_t len, + bool final) override { + next_->handleUpload(request, filename, index, data, len, final); + } + void handleBody(AsyncWebServerRequest *request, uint8_t *data, size_t len, size_t index, size_t total) override { + next_->handleBody(request, data, len, index, total); + } + bool isRequestHandlerTrivial() override { return next_->isRequestHandlerTrivial(); } + + protected: + AsyncWebHandler *next_; +}; + +struct Credentials { + std::string username; + std::string password; +}; + +class AuthMiddlewareHandler : public MiddlewareHandler { + public: + AuthMiddlewareHandler(AsyncWebHandler *next, Credentials *credentials) + : MiddlewareHandler(next), credentials_(credentials) {} + + bool check_auth(AsyncWebServerRequest *request) { + bool success = request->authenticate(credentials_->username.c_str(), credentials_->password.c_str()); + if (!success) { + request->requestAuthentication(); + } + return success; + } + + void handleRequest(AsyncWebServerRequest *request) override { + if (!check_auth(request)) + return; + MiddlewareHandler::handleRequest(request); + } + void handleUpload(AsyncWebServerRequest *request, const String &filename, size_t index, uint8_t *data, size_t len, + bool final) override { + if (!check_auth(request)) + return; + MiddlewareHandler::handleUpload(request, filename, index, data, len, final); + } + void handleBody(AsyncWebServerRequest *request, uint8_t *data, size_t len, size_t index, size_t total) override { + if (!check_auth(request)) + return; + MiddlewareHandler::handleBody(request, data, len, index, total); + } + + protected: + Credentials *credentials_; +}; + +} // namespace internal + class WebServerBase : public Component { public: void init() { @@ -32,13 +94,10 @@ class WebServerBase : public Component { AsyncWebServer *get_server() const { return server_; } float get_setup_priority() const override; - void add_handler(AsyncWebHandler *handler) { - // remove all handlers + void set_auth_username(std::string auth_username) { credentials_.username = auth_username; } + void set_auth_password(std::string auth_password) { credentials_.password = auth_password; } - this->handlers_.push_back(handler); - if (this->server_ != nullptr) - this->server_->addHandler(handler); - } + void add_handler(AsyncWebHandler *handler); void add_ota_handler(); @@ -52,6 +111,7 @@ class WebServerBase : public Component { uint16_t port_{80}; AsyncWebServer *server_{nullptr}; std::vector handlers_; + internal::Credentials credentials_; }; class OTARequestHandler : public AsyncWebHandler {