diff --git a/esphome/components/script/__init__.py b/esphome/components/script/__init__.py index 9702878475..907d7bf0e3 100644 --- a/esphome/components/script/__init__.py +++ b/esphome/components/script/__init__.py @@ -2,7 +2,8 @@ import esphome.codegen as cg import esphome.config_validation as cv from esphome import automation from esphome.automation import maybe_simple_id -from esphome.const import CONF_ID, CONF_MODE +from esphome.const import CONF_ID, CONF_MODE, CONF_PARAMETERS +from esphome.core import CORE, EsphomeError CODEOWNERS = ["@esphome/core"] script_ns = cg.esphome_ns.namespace("script") @@ -16,6 +17,7 @@ RestartScript = script_ns.class_("RestartScript", Script) QueueingScript = script_ns.class_("QueueingScript", Script, cg.Component) ParallelScript = script_ns.class_("ParallelScript", Script) +CONF_SCRIPT = "script" CONF_SINGLE = "single" CONF_RESTART = "restart" CONF_QUEUED = "queued" @@ -29,6 +31,18 @@ SCRIPT_MODES = { CONF_PARALLEL: ParallelScript, } +PARAMETER_TYPE_TRANSLATIONS = { + "string": "std::string", +} + + +def get_script(script_id): + scripts = CORE.config.get(CONF_SCRIPT, {}) + for script in scripts: + if script.get(CONF_ID, None) == script_id: + return script + raise cv.Invalid(f"Script id '{script_id}' not found") + def check_max_runs(value): if CONF_MAX_RUNS not in value: @@ -47,6 +61,44 @@ def assign_declare_id(value): return value +def parameters_to_template(args): + + template_args = [] + func_args = [] + script_arg_names = [] + for name, type_ in args.items(): + array = False + if type_.endswith("[]"): + array = True + type_ = type_[:-2] + type_ = PARAMETER_TYPE_TRANSLATIONS.get(type_, type_) + if array: + type_ = f"std::vector<{type_}>" + type_ = cg.esphome_ns.namespace(type_) + template_args.append(type_) + func_args.append((type_, name)) + script_arg_names.append(name) + template = cg.TemplateArguments(*template_args) + return template, func_args + + +def validate_parameter_name(value): + value = cv.string(value) + if value != CONF_ID: + return value + raise cv.Invalid(f"Script's parameter name cannot be {CONF_ID}") + + +ALLOWED_PARAM_TYPE_CHARSET = set("abcdefghijklmnopqrstuvwxyz0123456789_:*&[]") + + +def validate_parameter_type(value): + value = cv.string_strict(value) + if set(value.lower()) <= ALLOWED_PARAM_TYPE_CHARSET: + return value + raise cv.Invalid("Parameter type contains invalid characters") + + CONFIG_SCHEMA = automation.validate_automation( { # Don't declare id as cv.declare_id yet, because the ID type @@ -56,6 +108,11 @@ CONFIG_SCHEMA = automation.validate_automation( *SCRIPT_MODES, lower=True ), cv.Optional(CONF_MAX_RUNS): cv.positive_int, + cv.Optional(CONF_PARAMETERS, default={}): cv.Schema( + { + validate_parameter_name: validate_parameter_type, + } + ), }, extra_validators=cv.All(check_max_runs, assign_declare_id), ) @@ -65,7 +122,8 @@ async def to_code(config): # Register all variables first, so that scripts can use other scripts triggers = [] for conf in config: - trigger = cg.new_Pvariable(conf[CONF_ID]) + template, func_args = parameters_to_template(conf[CONF_PARAMETERS]) + trigger = cg.new_Pvariable(conf[CONF_ID], template) # Add a human-readable name to the script cg.add(trigger.set_name(conf[CONF_ID].id)) @@ -75,10 +133,10 @@ async def to_code(config): if conf[CONF_MODE] == CONF_QUEUED: await cg.register_component(trigger, conf) - triggers.append((trigger, conf)) + triggers.append((trigger, func_args, conf)) - for trigger, conf in triggers: - await automation.build_automation(trigger, [], conf) + for trigger, func_args, conf in triggers: + await automation.build_automation(trigger, func_args, conf) @automation.register_action( @@ -87,12 +145,39 @@ async def to_code(config): maybe_simple_id( { cv.Required(CONF_ID): cv.use_id(Script), - } + cv.Optional(validate_parameter_name): cv.templatable(cv.valid), + }, ), ) async def script_execute_action_to_code(config, action_id, template_arg, args): + async def get_ordered_args(config, script_params): + config_args = config.copy() + config_args.pop(CONF_ID) + + # match script_args to the formal parameter order + script_args = [] + for type, name in script_params: + if name not in config_args: + raise EsphomeError( + f"Missing parameter: '{name}' in script.execute {config[CONF_ID]}" + ) + arg = await cg.templatable(config_args[name], args, type) + script_args.append(arg) + return script_args + + script = get_script(config[CONF_ID]) + params = script.get(CONF_PARAMETERS, []) + template, script_params = parameters_to_template(params) + script_args = await get_ordered_args(config, script_params) + + # We need to use the parent class 'Script' as the template argument + # to match the partial specialization of the ScriptExecuteAction template + template_arg = cg.TemplateArguments(Script.template(template), *template_arg) + paren = await cg.get_variable(config[CONF_ID]) - return cg.new_Pvariable(action_id, template_arg, paren) + var = cg.new_Pvariable(action_id, template_arg, paren) + cg.add(var.set_args(*script_args)) + return var @automation.register_action( @@ -101,7 +186,8 @@ async def script_execute_action_to_code(config, action_id, template_arg, args): maybe_simple_id({cv.Required(CONF_ID): cv.use_id(Script)}), ) async def script_stop_action_to_code(config, action_id, template_arg, args): - paren = await cg.get_variable(config[CONF_ID]) + full_id, paren = await cg.get_variable_with_full_id(config[CONF_ID]) + template_arg = cg.TemplateArguments(full_id.type, *template_arg) return cg.new_Pvariable(action_id, template_arg, paren) @@ -111,7 +197,8 @@ async def script_stop_action_to_code(config, action_id, template_arg, args): maybe_simple_id({cv.Required(CONF_ID): cv.use_id(Script)}), ) async def script_wait_action_to_code(config, action_id, template_arg, args): - paren = await cg.get_variable(config[CONF_ID]) + full_id, paren = await cg.get_variable_with_full_id(config[CONF_ID]) + template_arg = cg.TemplateArguments(full_id.type, *template_arg) var = cg.new_Pvariable(action_id, template_arg, paren) await cg.register_component(var, {}) return var @@ -123,5 +210,6 @@ async def script_wait_action_to_code(config, action_id, template_arg, args): automation.maybe_simple_id({cv.Required(CONF_ID): cv.use_id(Script)}), ) async def script_is_running_to_code(config, condition_id, template_arg, args): - paren = await cg.get_variable(config[CONF_ID]) + full_id, paren = await cg.get_variable_with_full_id(config[CONF_ID]) + template_arg = cg.TemplateArguments(full_id.type, *template_arg) return cg.new_Pvariable(condition_id, template_arg, paren) diff --git a/esphome/components/script/script.cpp b/esphome/components/script/script.cpp index 46bcef905b..331f7dcd65 100644 --- a/esphome/components/script/script.cpp +++ b/esphome/components/script/script.cpp @@ -6,61 +6,8 @@ namespace script { static const char *const TAG = "script"; -void SingleScript::execute() { - if (this->is_action_running()) { - ESP_LOGW(TAG, "Script '%s' is already running! (mode: single)", this->name_.c_str()); - return; - } - - this->trigger(); -} - -void RestartScript::execute() { - if (this->is_action_running()) { - ESP_LOGD(TAG, "Script '%s' restarting (mode: restart)", this->name_.c_str()); - this->stop_action(); - } - - this->trigger(); -} - -void QueueingScript::execute() { - if (this->is_action_running()) { - // num_runs_ is the number of *queued* instances, so total number of instances is - // num_runs_ + 1 - if (this->max_runs_ != 0 && this->num_runs_ + 1 >= this->max_runs_) { - ESP_LOGW(TAG, "Script '%s' maximum number of queued runs exceeded!", this->name_.c_str()); - return; - } - - ESP_LOGD(TAG, "Script '%s' queueing new instance (mode: queued)", this->name_.c_str()); - this->num_runs_++; - return; - } - - this->trigger(); - // Check if the trigger was immediate and we can continue right away. - this->loop(); -} - -void QueueingScript::stop() { - this->num_runs_ = 0; - Script::stop(); -} - -void QueueingScript::loop() { - if (this->num_runs_ != 0 && !this->is_action_running()) { - this->num_runs_--; - this->trigger(); - } -} - -void ParallelScript::execute() { - if (this->max_runs_ != 0 && this->automation_parent_->num_running() >= this->max_runs_) { - ESP_LOGW(TAG, "Script '%s' maximum number of parallel runs exceeded!", this->name_.c_str()); - return; - } - this->trigger(); +void ScriptLogger::esp_log_(int level, int line, const char *format, const char *param) { + esp_log_printf_(level, TAG, line, format, param); } } // namespace script diff --git a/esphome/components/script/script.h b/esphome/components/script/script.h index 5663d32ce8..f3a83cd6ec 100644 --- a/esphome/components/script/script.h +++ b/esphome/components/script/script.h @@ -2,27 +2,48 @@ #include "esphome/core/automation.h" #include "esphome/core/component.h" +#include "esphome/core/log.h" namespace esphome { namespace script { +class ScriptLogger { + protected: + void esp_logw_(int line, const char *format, const char *param) { + esp_log_(ESPHOME_LOG_LEVEL_WARN, line, format, param); + } + void esp_logd_(int line, const char *format, const char *param) { + esp_log_(ESPHOME_LOG_LEVEL_DEBUG, line, format, param); + } + void esp_log_(int level, int line, const char *format, const char *param); +}; + /// The abstract base class for all script types. -class Script : public Trigger<> { +template class Script : public ScriptLogger, public Trigger { public: /** Execute a new instance of this script. * * The behavior of this function when a script is already running is defined by the subtypes */ - virtual void execute() = 0; + virtual void execute(Ts...) = 0; /// Check if any instance of this script is currently running. virtual bool is_running() { return this->is_action_running(); } /// Stop all instances of this script. virtual void stop() { this->stop_action(); } + // execute this script using a tuple that contains the arguments + void execute_tuple(const std::tuple &tuple) { + this->execute_tuple_(tuple, typename gens::type()); + } + // Internal function to give scripts readable names. void set_name(const std::string &name) { name_ = name; } protected: + template void execute_tuple_(const std::tuple &tuple, seq /*unused*/) { + this->execute(std::get(tuple)...); + } + std::string name_; }; @@ -31,9 +52,16 @@ class Script : public Trigger<> { * If a new instance is executed while the previous one hasn't finished yet, * a warning is printed and the new instance is discarded. */ -class SingleScript : public Script { +template class SingleScript : public Script { public: - void execute() override; + void execute(Ts... x) override { + if (this->is_action_running()) { + this->esp_logw_(__LINE__, "Script '%s' is already running! (mode: single)", this->name_.c_str()); + return; + } + + this->trigger(x...); + } }; /** A script type that restarts scripts from the beginning when a new instance is started. @@ -41,20 +69,55 @@ class SingleScript : public Script { * If a new instance is started but another one is already running, the existing * script is stopped and the new instance starts from the beginning. */ -class RestartScript : public Script { +template class RestartScript : public Script { public: - void execute() override; + void execute(Ts... x) override { + if (this->is_action_running()) { + this->esp_logd_(__LINE__, "Script '%s' restarting (mode: restart)", this->name_.c_str()); + this->stop_action(); + } + + this->trigger(x...); + } }; /** A script type that queues new instances that are created. * * Only one instance of the script can be active at a time. */ -class QueueingScript : public Script, public Component { +template class QueueingScript : public Script, public Component { public: - void execute() override; - void stop() override; - void loop() override; + void execute(Ts... x) override { + if (this->is_action_running()) { + // num_runs_ is the number of *queued* instances, so total number of instances is + // num_runs_ + 1 + if (this->max_runs_ != 0 && this->num_runs_ + 1 >= this->max_runs_) { + this->esp_logw_(__LINE__, "Script '%s' maximum number of queued runs exceeded!", this->name_.c_str()); + return; + } + + this->esp_logd_(__LINE__, "Script '%s' queueing new instance (mode: queued)", this->name_.c_str()); + this->num_runs_++; + return; + } + + this->trigger(x...); + // Check if the trigger was immediate and we can continue right away. + this->loop(); + } + + void stop() override { + this->num_runs_ = 0; + Script::stop(); + } + + void loop() override { + if (this->num_runs_ != 0 && !this->is_action_running()) { + this->num_runs_--; + this->trigger(); + } + } + void set_max_runs(int max_runs) { max_runs_ = max_runs; } protected: @@ -67,48 +130,84 @@ class QueueingScript : public Script, public Component { * If a new instance is started while previous ones haven't finished yet, * the new one is executed in parallel to the other instances. */ -class ParallelScript : public Script { +template class ParallelScript : public Script { public: - void execute() override; + void execute(Ts... x) override { + if (this->max_runs_ != 0 && this->automation_parent_->num_running() >= this->max_runs_) { + this->esp_logw_(__LINE__, "Script '%s' maximum number of parallel runs exceeded!", this->name_.c_str()); + return; + } + this->trigger(x...); + } void set_max_runs(int max_runs) { max_runs_ = max_runs; } protected: int max_runs_ = 0; }; -template class ScriptExecuteAction : public Action { - public: - ScriptExecuteAction(Script *script) : script_(script) {} +template class ScriptExecuteAction; - void play(Ts... x) override { this->script_->execute(); } +template class ScriptExecuteAction, Ts...> : public Action { + public: + ScriptExecuteAction(Script *script) : script_(script) {} + + using Args = std::tuple...>; + + template void set_args(F... x) { args_ = Args{x...}; } + + void play(Ts... x) override { this->script_->execute_tuple(this->eval_args_(x...)); } protected: - Script *script_; + // NOTE: + // `eval_args_impl` functions evaluates `I`th the functions in `args` member. + // and then recursively calls `eval_args_impl` for the `I+1`th arg. + // if `I` = `N` all args have been stored, and nothing is done. + + template + void eval_args_impl_(std::tuple & /*unused*/, std::integral_constant /*unused*/, + std::integral_constant /*unused*/, Ts... /*unused*/) {} + + template + void eval_args_impl_(std::tuple &evaled_args, std::integral_constant /*unused*/, + std::integral_constant n, Ts... x) { + std::get(evaled_args) = std::get(args_).value(x...); // NOTE: evaluate `i`th arg, and store in tuple. + eval_args_impl_(evaled_args, std::integral_constant{}, n, + x...); // NOTE: recurse to next index. + } + + std::tuple eval_args_(Ts... x) { + std::tuple evaled_args; + eval_args_impl_(evaled_args, std::integral_constant{}, std::tuple_size{}, x...); + return evaled_args; + } + + Script *script_; + Args args_; }; -template class ScriptStopAction : public Action { +template class ScriptStopAction : public Action { public: - ScriptStopAction(Script *script) : script_(script) {} + ScriptStopAction(C *script) : script_(script) {} void play(Ts... x) override { this->script_->stop(); } protected: - Script *script_; + C *script_; }; -template class IsRunningCondition : public Condition { +template class IsRunningCondition : public Condition { public: - explicit IsRunningCondition(Script *parent) : parent_(parent) {} + explicit IsRunningCondition(C *parent) : parent_(parent) {} bool check(Ts... x) override { return this->parent_->is_running(); } protected: - Script *parent_; + C *parent_; }; -template class ScriptWaitAction : public Action, public Component { +template class ScriptWaitAction : public Action, public Component { public: - ScriptWaitAction(Script *script) : script_(script) {} + ScriptWaitAction(C *script) : script_(script) {} void play_complex(Ts... x) override { this->num_running_++; @@ -137,7 +236,7 @@ template class ScriptWaitAction : public Action, public C } protected: - Script *script_; + C *script_; std::tuple var_{}; }; diff --git a/esphome/const.py b/esphome/const.py index 56f48d2f92..c19da541eb 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -491,6 +491,7 @@ CONF_PACKAGES = "packages" CONF_PAGE_ID = "page_id" CONF_PAGES = "pages" CONF_PANASONIC = "panasonic" +CONF_PARAMETERS = "parameters" CONF_PASSWORD = "password" CONF_PATH = "path" CONF_PAYLOAD = "payload" diff --git a/tests/test2.yaml b/tests/test2.yaml index d485d3f16c..7d4cb4cbb2 100644 --- a/tests/test2.yaml +++ b/tests/test2.yaml @@ -532,6 +532,16 @@ text_sensor: ESP_LOGD("main", "The state is %s=%s", x.c_str(), id(version_sensor).state.c_str()); # yamllint enable rule:line-length - script.execute: my_script + - script.execute: + id: my_script_with_params + prefix: Running my_script_with_params + param2: 100 + param3: true + - script.execute: + id: my_script_with_params + prefix: Running my_script_with_params using lambda parameters + param2: !lambda return 200; + param3: !lambda return true; - homeassistant.service: service: notify.html5 data: @@ -597,6 +607,13 @@ script: mode: restart then: - lambda: 'ESP_LOGD("main", "Hello World!");' + - id: my_script_with_params + parameters: + prefix: string + param2: int + param3: bool + then: + - lambda: 'ESP_LOGD("main", (prefix + " Hello World!" + to_string(param2) + " " + to_string(param3)).c_str());' stepper: - platform: uln2003