Add native API user-defined services (#453)

This commit is contained in:
Otto Winter 2019-02-26 19:38:28 +01:00 committed by GitHub
parent 3b00cfd6c4
commit 311e837196
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 5 deletions

View file

@ -1,13 +1,15 @@
import voluptuous as vol import voluptuous as vol
from esphome import automation
from esphome.automation import ACTION_REGISTRY from esphome.automation import ACTION_REGISTRY
import esphome.config_validation as cv import esphome.config_validation as cv
from esphome.const import CONF_DATA, CONF_DATA_TEMPLATE, CONF_ID, CONF_PASSWORD, CONF_PORT, \ from esphome.const import CONF_DATA, CONF_DATA_TEMPLATE, CONF_ID, CONF_PASSWORD, CONF_PORT, \
CONF_REBOOT_TIMEOUT, CONF_SERVICE, CONF_VARIABLES CONF_REBOOT_TIMEOUT, CONF_SERVICE, CONF_VARIABLES, CONF_SERVICES, CONF_TRIGGER_ID
from esphome.core import CORE from esphome.core import CORE
from esphome.cpp_generator import Pvariable, add, get_variable, process_lambda from esphome.cpp_generator import Pvariable, add, get_variable, process_lambda
from esphome.cpp_helpers import setup_component from esphome.cpp_helpers import setup_component
from esphome.cpp_types import Action, App, Component, StoringController, esphome_ns from esphome.cpp_types import Action, App, Component, StoringController, esphome_ns, Trigger, bool_, \
int32, float_, std_string
api_ns = esphome_ns.namespace('api') api_ns = esphome_ns.namespace('api')
APIServer = api_ns.class_('APIServer', Component, StoringController) APIServer = api_ns.class_('APIServer', Component, StoringController)
@ -15,11 +17,35 @@ HomeAssistantServiceCallAction = api_ns.class_('HomeAssistantServiceCallAction',
KeyValuePair = api_ns.class_('KeyValuePair') KeyValuePair = api_ns.class_('KeyValuePair')
TemplatableKeyValuePair = api_ns.class_('TemplatableKeyValuePair') TemplatableKeyValuePair = api_ns.class_('TemplatableKeyValuePair')
UserService = api_ns.class_('UserService', Trigger)
ServiceTypeArgument = api_ns.class_('ServiceTypeArgument')
ServiceArgType = api_ns.enum('ServiceArgType')
SERVICE_ARG_TYPES = {
'bool': ServiceArgType.SERVICE_ARG_TYPE_BOOL,
'int': ServiceArgType.SERVICE_ARG_TYPE_INT,
'float': ServiceArgType.SERVICE_ARG_TYPE_FLOAT,
'string': ServiceArgType.SERVICE_ARG_TYPE_STRING,
}
SERVICE_ARG_NATIVE_TYPES = {
'bool': bool_,
'int': int32,
'float': float_,
'string': std_string,
}
CONFIG_SCHEMA = cv.Schema({ CONFIG_SCHEMA = cv.Schema({
cv.GenerateID(): cv.declare_variable_id(APIServer), cv.GenerateID(): cv.declare_variable_id(APIServer),
vol.Optional(CONF_PORT, default=6053): cv.port, vol.Optional(CONF_PORT, default=6053): cv.port,
vol.Optional(CONF_PASSWORD, default=''): cv.string_strict, vol.Optional(CONF_PASSWORD, default=''): cv.string_strict,
vol.Optional(CONF_REBOOT_TIMEOUT): cv.positive_time_period_milliseconds, vol.Optional(CONF_REBOOT_TIMEOUT): cv.positive_time_period_milliseconds,
vol.Optional(CONF_SERVICES): automation.validate_automation({
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_variable_id(UserService),
vol.Required(CONF_SERVICE): cv.valid_name,
vol.Optional(CONF_VARIABLES, default={}): cv.Schema({
cv.validate_id_name: cv.one_of(*SERVICE_ARG_TYPES, lower=True),
}),
}),
}).extend(cv.COMPONENT_SCHEMA.schema) }).extend(cv.COMPONENT_SCHEMA.schema)
@ -34,6 +60,21 @@ def to_code(config):
if CONF_REBOOT_TIMEOUT in config: if CONF_REBOOT_TIMEOUT in config:
add(api.set_reboot_timeout(config[CONF_REBOOT_TIMEOUT])) add(api.set_reboot_timeout(config[CONF_REBOOT_TIMEOUT]))
for conf in config.get(CONF_SERVICES, []):
template_args = []
func_args = []
service_type_args = []
for name, var_ in conf[CONF_VARIABLES].items():
native = SERVICE_ARG_NATIVE_TYPES[var_]
template_args.append(native)
func_args.append((native, name))
service_type_args.append(ServiceTypeArgument(name, SERVICE_ARG_TYPES[var_]))
func = api.make_user_service_trigger.template(*template_args)
rhs = func(conf[CONF_SERVICE], service_type_args)
type_ = UserService.template(*template_args)
trigger = Pvariable(conf[CONF_TRIGGER_ID], rhs, type=type_)
automation.build_automations(trigger, func_args, conf)
setup_component(api, config) setup_component(api, config)

View file

@ -162,7 +162,7 @@ def int_(value):
hex_int = vol.Coerce(hex_int_) hex_int = vol.Coerce(hex_int_)
def variable_id_str_(value): def validate_id_name(value):
value = string(value) value = string(value)
if not value: if not value:
raise vol.Invalid("ID must not be empty") raise vol.Invalid("ID must not be empty")
@ -185,7 +185,7 @@ def use_variable_id(type):
if value is None: if value is None:
return core.ID(None, is_declaration=False, type=type) return core.ID(None, is_declaration=False, type=type)
return core.ID(variable_id_str_(value), is_declaration=False, type=type) return core.ID(validate_id_name(value), is_declaration=False, type=type)
return validator return validator
@ -195,7 +195,7 @@ def declare_variable_id(type):
if value is None: if value is None:
return core.ID(None, is_declaration=True, type=type) return core.ID(None, is_declaration=True, type=type)
return core.ID(variable_id_str_(value), is_declaration=True, type=type) return core.ID(validate_id_name(value), is_declaration=True, type=type)
return validator return validator

View file

@ -23,6 +23,7 @@ CONF_ARDUINO_VERSION = 'arduino_version'
CONF_LOCAL = 'local' CONF_LOCAL = 'local'
CONF_REPOSITORY = 'repository' CONF_REPOSITORY = 'repository'
CONF_COMMIT = 'commit' CONF_COMMIT = 'commit'
CONF_SERVICES = 'services'
CONF_TAG = 'tag' CONF_TAG = 'tag'
CONF_BRANCH = 'branch' CONF_BRANCH = 'branch'
CONF_LOGGER = 'logger' CONF_LOGGER = 'logger'