From b32b918936cc1cdc796de27f2e022b743fa6c421 Mon Sep 17 00:00:00 2001
From: Jesse Hills <3060199+jesserockz@users.noreply.github.com>
Date: Wed, 1 Dec 2021 04:18:21 +1300
Subject: [PATCH] Button device class (#2835)

---
 esphome/components/api/api.proto              |  1 +
 esphome/components/api/api_connection.cpp     |  1 +
 esphome/components/api/api_pb2.cpp            |  9 ++++++++
 esphome/components/api/api_pb2.h              |  1 +
 esphome/components/button/__init__.py         | 23 +++++++++++++++++++
 esphome/components/button/button.cpp          |  7 ++++++
 esphome/components/button/button.h            |  7 ++++++
 esphome/components/mqtt/mqtt_button.cpp       |  5 ++++
 esphome/components/mqtt/mqtt_button.h         |  2 +-
 esphome/components/restart/button/__init__.py |  6 +++--
 esphome/const.py                              |  6 ++++-
 11 files changed, 64 insertions(+), 4 deletions(-)

diff --git a/esphome/components/api/api.proto b/esphome/components/api/api.proto
index eaad4b8d07..3e2c806135 100644
--- a/esphome/components/api/api.proto
+++ b/esphome/components/api/api.proto
@@ -960,6 +960,7 @@ message ListEntitiesButtonResponse {
   string icon = 5;
   bool disabled_by_default = 6;
   EntityCategory entity_category = 7;
+  string device_class = 8;
 }
 message ButtonCommandRequest {
   option (id) = 62;
diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp
index 22b896a788..8367afc042 100644
--- a/esphome/components/api/api_connection.cpp
+++ b/esphome/components/api/api_connection.cpp
@@ -684,6 +684,7 @@ bool APIConnection::send_button_info(button::Button *button) {
   msg.icon = button->get_icon();
   msg.disabled_by_default = button->is_disabled_by_default();
   msg.entity_category = static_cast<enums::EntityCategory>(button->get_entity_category());
+  msg.device_class = button->get_device_class();
   return this->send_list_entities_button_response(msg);
 }
 void APIConnection::button_command(const ButtonCommandRequest &msg) {
diff --git a/esphome/components/api/api_pb2.cpp b/esphome/components/api/api_pb2.cpp
index 169349d995..b6974de08e 100644
--- a/esphome/components/api/api_pb2.cpp
+++ b/esphome/components/api/api_pb2.cpp
@@ -4179,6 +4179,10 @@ bool ListEntitiesButtonResponse::decode_length(uint32_t field_id, ProtoLengthDel
       this->icon = value.as_string();
       return true;
     }
+    case 8: {
+      this->device_class = value.as_string();
+      return true;
+    }
     default:
       return false;
   }
@@ -4201,6 +4205,7 @@ void ListEntitiesButtonResponse::encode(ProtoWriteBuffer buffer) const {
   buffer.encode_string(5, this->icon);
   buffer.encode_bool(6, this->disabled_by_default);
   buffer.encode_enum<enums::EntityCategory>(7, this->entity_category);
+  buffer.encode_string(8, this->device_class);
 }
 #ifdef HAS_PROTO_MESSAGE_DUMP
 void ListEntitiesButtonResponse::dump_to(std::string &out) const {
@@ -4234,6 +4239,10 @@ void ListEntitiesButtonResponse::dump_to(std::string &out) const {
   out.append("  entity_category: ");
   out.append(proto_enum_to_string<enums::EntityCategory>(this->entity_category));
   out.append("\n");
+
+  out.append("  device_class: ");
+  out.append("'").append(this->device_class).append("'");
+  out.append("\n");
   out.append("}");
 }
 #endif
diff --git a/esphome/components/api/api_pb2.h b/esphome/components/api/api_pb2.h
index 82fd8de687..4d1f658910 100644
--- a/esphome/components/api/api_pb2.h
+++ b/esphome/components/api/api_pb2.h
@@ -1050,6 +1050,7 @@ class ListEntitiesButtonResponse : public ProtoMessage {
   std::string icon{};
   bool disabled_by_default{false};
   enums::EntityCategory entity_category{};
+  std::string device_class{};
   void encode(ProtoWriteBuffer buffer) const override;
 #ifdef HAS_PROTO_MESSAGE_DUMP
   void dump_to(std::string &out) const override;
diff --git a/esphome/components/button/__init__.py b/esphome/components/button/__init__.py
index 495a85b6b4..1e248ddf07 100644
--- a/esphome/components/button/__init__.py
+++ b/esphome/components/button/__init__.py
@@ -4,12 +4,15 @@ from esphome import automation
 from esphome.automation import maybe_simple_id
 from esphome.components import mqtt
 from esphome.const import (
+    CONF_DEVICE_CLASS,
     CONF_ENTITY_CATEGORY,
     CONF_ICON,
     CONF_ID,
     CONF_ON_PRESS,
     CONF_TRIGGER_ID,
     CONF_MQTT_ID,
+    DEVICE_CLASS_RESTART,
+    DEVICE_CLASS_UPDATE,
 )
 from esphome.core import CORE, coroutine_with_priority
 from esphome.cpp_helpers import setup_entity
@@ -17,6 +20,11 @@ from esphome.cpp_helpers import setup_entity
 CODEOWNERS = ["@esphome/core"]
 IS_PLATFORM_COMPONENT = True
 
+DEVICE_CLASSES = [
+    DEVICE_CLASS_RESTART,
+    DEVICE_CLASS_UPDATE,
+]
+
 button_ns = cg.esphome_ns.namespace("button")
 Button = button_ns.class_("Button", cg.EntityBase)
 ButtonPtr = Button.operator("ptr")
@@ -27,10 +35,13 @@ ButtonPressTrigger = button_ns.class_(
     "ButtonPressTrigger", automation.Trigger.template()
 )
 
+validate_device_class = cv.one_of(*DEVICE_CLASSES, lower=True, space="_")
+
 
 BUTTON_SCHEMA = cv.ENTITY_BASE_SCHEMA.extend(cv.MQTT_COMMAND_COMPONENT_SCHEMA).extend(
     {
         cv.OnlyWith(CONF_MQTT_ID, "mqtt"): cv.declare_id(mqtt.MQTTButtonComponent),
+        cv.Optional(CONF_DEVICE_CLASS): validate_device_class,
         cv.Optional(CONF_ON_PRESS): automation.validate_automation(
             {
                 cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(ButtonPressTrigger),
@@ -45,6 +56,7 @@ _UNDEF = object()
 def button_schema(
     icon: str = _UNDEF,
     entity_category: str = _UNDEF,
+    device_class: str = _UNDEF,
 ) -> cv.Schema:
     schema = BUTTON_SCHEMA
     if icon is not _UNDEF:
@@ -57,6 +69,14 @@ def button_schema(
                 ): cv.entity_category
             }
         )
+    if device_class is not _UNDEF:
+        schema = schema.extend(
+            {
+                cv.Optional(
+                    CONF_DEVICE_CLASS, default=device_class
+                ): validate_device_class
+            }
+        )
     return schema
 
 
@@ -67,6 +87,9 @@ async def setup_button_core_(var, config):
         trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
         await automation.build_automation(trigger, [], conf)
 
+    if CONF_DEVICE_CLASS in config:
+        cg.add(var.set_device_class(config[CONF_DEVICE_CLASS]))
+
     if CONF_MQTT_ID in config:
         mqtt_ = cg.new_Pvariable(config[CONF_MQTT_ID], var)
         await mqtt.register_mqtt_component(mqtt_, config)
diff --git a/esphome/components/button/button.cpp b/esphome/components/button/button.cpp
index fe39b1e458..d57b46e9aa 100644
--- a/esphome/components/button/button.cpp
+++ b/esphome/components/button/button.cpp
@@ -17,5 +17,12 @@ void Button::press() {
 void Button::add_on_press_callback(std::function<void()> &&callback) { this->press_callback_.add(std::move(callback)); }
 uint32_t Button::hash_base() { return 1495763804UL; }
 
+void Button::set_device_class(const std::string &device_class) { this->device_class_ = device_class; }
+std::string Button::get_device_class() {
+  if (this->device_class_.has_value())
+    return *this->device_class_;
+  return "";
+}
+
 }  // namespace button
 }  // namespace esphome
diff --git a/esphome/components/button/button.h b/esphome/components/button/button.h
index 954afa0ab9..b21a96b8e1 100644
--- a/esphome/components/button/button.h
+++ b/esphome/components/button/button.h
@@ -36,6 +36,12 @@ class Button : public EntityBase {
    */
   void add_on_press_callback(std::function<void()> &&callback);
 
+  /// Set the Home Assistant device class (see button::device_class).
+  void set_device_class(const std::string &device_class);
+
+  /// Get the device class for this button.
+  std::string get_device_class();
+
  protected:
   /** You should implement this virtual method if you want to create your own button.
    */
@@ -44,6 +50,7 @@ class Button : public EntityBase {
   uint32_t hash_base() override;
 
   CallbackManager<void()> press_callback_{};
+  optional<std::string> device_class_{};
 };
 
 }  // namespace button
diff --git a/esphome/components/mqtt/mqtt_button.cpp b/esphome/components/mqtt/mqtt_button.cpp
index 5a0d14b648..25ff327cf9 100644
--- a/esphome/components/mqtt/mqtt_button.cpp
+++ b/esphome/components/mqtt/mqtt_button.cpp
@@ -30,6 +30,11 @@ void MQTTButtonComponent::dump_config() {
   LOG_MQTT_COMPONENT(true, true);
 }
 
+void MQTTButtonComponent::send_discovery(JsonObject &root, mqtt::SendDiscoveryConfig &config) {
+  if (!this->button_->get_device_class().empty())
+    root[MQTT_DEVICE_CLASS] = this->button_->get_device_class();
+}
+
 std::string MQTTButtonComponent::component_type() const { return "button"; }
 const EntityBase *MQTTButtonComponent::get_entity() const { return this->button_; }
 
diff --git a/esphome/components/mqtt/mqtt_button.h b/esphome/components/mqtt/mqtt_button.h
index a7e60db380..66e4b2609f 100644
--- a/esphome/components/mqtt/mqtt_button.h
+++ b/esphome/components/mqtt/mqtt_button.h
@@ -23,7 +23,7 @@ class MQTTButtonComponent : public mqtt::MQTTComponent {
   /// Buttons do not send a state so just return true.
   bool send_initial_state() override { return true; }
 
-  void send_discovery(JsonObject &root, mqtt::SendDiscoveryConfig &config) override {}
+  void send_discovery(JsonObject &root, mqtt::SendDiscoveryConfig &config) override;
 
  protected:
   /// "button" component type.
diff --git a/esphome/components/restart/button/__init__.py b/esphome/components/restart/button/__init__.py
index 257a8e35f7..1a0e9cdc3d 100644
--- a/esphome/components/restart/button/__init__.py
+++ b/esphome/components/restart/button/__init__.py
@@ -3,15 +3,17 @@ import esphome.config_validation as cv
 from esphome.components import button
 from esphome.const import (
     CONF_ID,
+    DEVICE_CLASS_RESTART,
     ENTITY_CATEGORY_CONFIG,
-    ICON_RESTART,
 )
 
 restart_ns = cg.esphome_ns.namespace("restart")
 RestartButton = restart_ns.class_("RestartButton", button.Button, cg.Component)
 
 CONFIG_SCHEMA = (
-    button.button_schema(icon=ICON_RESTART, entity_category=ENTITY_CATEGORY_CONFIG)
+    button.button_schema(
+        device_class=DEVICE_CLASS_RESTART, entity_category=ENTITY_CATEGORY_CONFIG
+    )
     .extend({cv.GenerateID(): cv.declare_id(RestartButton)})
     .extend(cv.COMPONENT_SCHEMA)
 )
diff --git a/esphome/const.py b/esphome/const.py
index 3510e500f5..740e38cf44 100644
--- a/esphome/const.py
+++ b/esphome/const.py
@@ -865,7 +865,6 @@ DEVICE_CLASS_SAFETY = "safety"
 DEVICE_CLASS_SMOKE = "smoke"
 DEVICE_CLASS_SOUND = "sound"
 DEVICE_CLASS_TAMPER = "tamper"
-DEVICE_CLASS_UPDATE = "update"
 DEVICE_CLASS_VIBRATION = "vibration"
 DEVICE_CLASS_WINDOW = "window"
 # device classes of both binary_sensor and sensor component
@@ -897,6 +896,11 @@ DEVICE_CLASS_TEMPERATURE = "temperature"
 DEVICE_CLASS_TIMESTAMP = "timestamp"
 DEVICE_CLASS_VOLATILE_ORGANIC_COMPOUNDS = "volatile_organic_compounds"
 DEVICE_CLASS_VOLTAGE = "voltage"
+# device classes of both binary_sensor and button component
+DEVICE_CLASS_UPDATE = "update"
+# device classes of button component
+DEVICE_CLASS_RESTART = "restart"
+
 
 # state classes
 STATE_CLASS_NONE = ""