From 42f6095960718bc33b40e21ab3dbec50dc347a05 Mon Sep 17 00:00:00 2001
From: Pietro <pxpert@gmail.com>
Date: Sun, 13 Oct 2024 11:24:17 +0200
Subject: [PATCH] [core][esp32_rmt_led_strip] Migrate ExternalRAMAllocator to
 RAMAllocator

And add psram flag to esp32_rmt_led_strip
Co-authored-by: guillempages <guillempages@users.noreply.github.com>
Co-authored-by: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com>
---
 .../esp32_rmt_led_strip/led_strip.cpp         |  4 +-
 .../esp32_rmt_led_strip/led_strip.h           |  2 +
 .../components/esp32_rmt_led_strip/light.py   |  4 +-
 esphome/core/helpers.h                        | 44 ++++++++++++-------
 4 files changed, 35 insertions(+), 19 deletions(-)

diff --git a/esphome/components/esp32_rmt_led_strip/led_strip.cpp b/esphome/components/esp32_rmt_led_strip/led_strip.cpp
index 71ab099de5..c2209f7a6c 100644
--- a/esphome/components/esp32_rmt_led_strip/led_strip.cpp
+++ b/esphome/components/esp32_rmt_led_strip/led_strip.cpp
@@ -22,7 +22,7 @@ void ESP32RMTLEDStripLightOutput::setup() {
 
   size_t buffer_size = this->get_buffer_size_();
 
-  ExternalRAMAllocator<uint8_t> allocator(ExternalRAMAllocator<uint8_t>::ALLOW_FAILURE);
+  RAMAllocator<uint8_t> allocator(this->use_psram_ ? 0 : RAMAllocator<uint8_t>::ALLOC_INTERNAL);
   this->buf_ = allocator.allocate(buffer_size);
   if (this->buf_ == nullptr) {
     ESP_LOGE(TAG, "Cannot allocate LED buffer!");
@@ -37,7 +37,7 @@ void ESP32RMTLEDStripLightOutput::setup() {
     return;
   }
 
-  ExternalRAMAllocator<rmt_item32_t> rmt_allocator(ExternalRAMAllocator<rmt_item32_t>::ALLOW_FAILURE);
+  RAMAllocator<rmt_item32_t> rmt_allocator(this->use_psram_ ? 0 : RAMAllocator<rmt_item32_t>::ALLOC_INTERNAL);
   this->rmt_buf_ = rmt_allocator.allocate(buffer_size * 8 +
                                           1);  // 8 bits per byte, 1 rmt_item32_t per bit + 1 rmt_item32_t for reset
 
diff --git a/esphome/components/esp32_rmt_led_strip/led_strip.h b/esphome/components/esp32_rmt_led_strip/led_strip.h
index 43215cf12b..d21bd86e75 100644
--- a/esphome/components/esp32_rmt_led_strip/led_strip.h
+++ b/esphome/components/esp32_rmt_led_strip/led_strip.h
@@ -45,6 +45,7 @@ class ESP32RMTLEDStripLightOutput : public light::AddressableLight {
   void set_num_leds(uint16_t num_leds) { this->num_leds_ = num_leds; }
   void set_is_rgbw(bool is_rgbw) { this->is_rgbw_ = is_rgbw; }
   void set_is_wrgb(bool is_wrgb) { this->is_wrgb_ = is_wrgb; }
+  void set_use_psram(bool use_psram) { this->use_psram_ = use_psram; }
 
   /// Set a maximum refresh rate in µs as some lights do not like being updated too often.
   void set_max_refresh_rate(uint32_t interval_us) { this->max_refresh_rate_ = interval_us; }
@@ -75,6 +76,7 @@ class ESP32RMTLEDStripLightOutput : public light::AddressableLight {
   uint16_t num_leds_;
   bool is_rgbw_;
   bool is_wrgb_;
+  bool use_psram_;
 
   rmt_item32_t bit0_, bit1_, reset_;
   RGBOrder rgb_order_;
diff --git a/esphome/components/esp32_rmt_led_strip/light.py b/esphome/components/esp32_rmt_led_strip/light.py
index 1e3c2d4f72..79f339e248 100644
--- a/esphome/components/esp32_rmt_led_strip/light.py
+++ b/esphome/components/esp32_rmt_led_strip/light.py
@@ -55,7 +55,7 @@ CHIPSETS = {
     "SM16703": LEDStripTimings(300, 900, 900, 300, 0, 0),
 }
 
-
+CONF_USE_PSRAM = "use_psram"
 CONF_IS_WRGB = "is_wrgb"
 CONF_BIT0_HIGH = "bit0_high"
 CONF_BIT0_LOW = "bit0_low"
@@ -77,6 +77,7 @@ CONFIG_SCHEMA = cv.All(
             cv.Optional(CONF_CHIPSET): cv.one_of(*CHIPSETS, upper=True),
             cv.Optional(CONF_IS_RGBW, default=False): cv.boolean,
             cv.Optional(CONF_IS_WRGB, default=False): cv.boolean,
+            cv.Optional(CONF_USE_PSRAM, default=True): cv.boolean,
             cv.Inclusive(
                 CONF_BIT0_HIGH,
                 "custom",
@@ -145,6 +146,7 @@ async def to_code(config):
     cg.add(var.set_rgb_order(config[CONF_RGB_ORDER]))
     cg.add(var.set_is_rgbw(config[CONF_IS_RGBW]))
     cg.add(var.set_is_wrgb(config[CONF_IS_WRGB]))
+    cg.add(var.set_use_psram(config[CONF_USE_PSRAM]))
 
     cg.add(
         var.set_rmt_channel(
diff --git a/esphome/core/helpers.h b/esphome/core/helpers.h
index 7df4b84230..7f6fe9bfdc 100644
--- a/esphome/core/helpers.h
+++ b/esphome/core/helpers.h
@@ -651,35 +651,45 @@ void delay_microseconds_safe(uint32_t us);
 /// @name Memory management
 ///@{
 
-/** An STL allocator that uses SPI RAM.
+/** An STL allocator that uses SPI or internal RAM.
+ * Returns `nullptr` in case no memory is available.
  *
- * By setting flags, it can be configured to don't try main memory if SPI RAM is full or unavailable, and to return
- * `nulllptr` instead of aborting when no memory is available.
+ * By setting flags, it can be configured to:
+ * - perform external allocation falling back to main memory if SPI RAM is full or unavailable
+ * - perform external allocation only
+ * - perform internal allocation only
  */
-template<class T> class ExternalRAMAllocator {
+template<class T> class RAMAllocator {
  public:
   using value_type = T;
 
   enum Flags {
-    NONE = 0,
-    REFUSE_INTERNAL = 1 << 0,  ///< Refuse falling back to internal memory when external RAM is full or unavailable.
-    ALLOW_FAILURE = 1 << 1,    ///< Don't abort when memory allocation fails.
+    NONE = 0,                 // Perform external allocation and fall back to internal memory
+    ALLOC_EXTERNAL = 1 << 0,  // Perform external allocation only.
+    ALLOC_INTERNAL = 1 << 1,  // Perform internal allocation only.
+    ALLOW_FAILURE = 1 << 2,   // Does nothing. Kept for compatibility.
   };
 
-  ExternalRAMAllocator() = default;
-  ExternalRAMAllocator(Flags flags) : flags_{flags} {}
-  template<class U> constexpr ExternalRAMAllocator(const ExternalRAMAllocator<U> &other) : flags_{other.flags_} {}
+  RAMAllocator() = default;
+  RAMAllocator(uint8_t flags) : flags_{flags} {}
+  template<class U> constexpr RAMAllocator(const RAMAllocator<U> &other) : flags_{other.flags_} {}
 
   T *allocate(size_t n) {
     size_t size = n * sizeof(T);
     T *ptr = nullptr;
 #ifdef USE_ESP32
-    ptr = static_cast<T *>(heap_caps_malloc(size, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT));
-#endif
-    if (ptr == nullptr && (this->flags_ & Flags::REFUSE_INTERNAL) == 0)
+    // External allocation by default or if explicitely requested
+    if ((this->flags_ & Flags::ALLOC_EXTERNAL) || ((this->flags_ & Flags::ALLOC_INTERNAL) == 0)) {
+      ptr = static_cast<T *>(heap_caps_malloc(size, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT));
+    }
+    // Fallback to internal allocation if explicitely requested or no flag is specified
+    if (ptr == nullptr && ((this->flags_ & Flags::ALLOC_INTERNAL) || (this->flags_ & Flags::ALLOC_EXTERNAL) == 0)) {
       ptr = static_cast<T *>(malloc(size));  // NOLINT(cppcoreguidelines-owning-memory,cppcoreguidelines-no-malloc)
-    if (ptr == nullptr && (this->flags_ & Flags::ALLOW_FAILURE) == 0)
-      abort();
+    }
+#else
+    // Ignore ALLOC_EXTERNAL/ALLOC_INTERNAL flags if external allocation is not supported
+    ptr = static_cast<T *>(malloc(size));  // NOLINT(cppcoreguidelines-owning-memory,cppcoreguidelines-no-malloc)
+#endif
     return ptr;
   }
 
@@ -688,9 +698,11 @@ template<class T> class ExternalRAMAllocator {
   }
 
  private:
-  Flags flags_{Flags::ALLOW_FAILURE};
+  uint8_t flags_{Flags::ALLOW_FAILURE};
 };
 
+template<class T> using ExternalRAMAllocator = RAMAllocator<T>;
+
 /// @}
 
 /// @name Internal functions