craftbeerpi4-pione/venv3/lib/python3.7/site-packages/asyncio_mqtt/client.py
2021-03-03 23:49:41 +01:00

354 lines
15 KiB
Python

# SPDX-License-Identifier: BSD-3-Clause
import asyncio
import logging
import socket
from contextlib import contextmanager, suppress
try:
from contextlib import asynccontextmanager
except ImportError:
from async_generator import asynccontextmanager
import paho.mqtt.client as mqtt
from .error import MqttError, MqttCodeError, MqttConnectError
# Module-wide logger named 'mqtt'. It is handed to paho-mqtt when the caller
# passes no logger to Client, and is also used for this module's own warning
# and error messages. Its level is set to WARNING here, so lower-severity
# records are dropped unless the application reconfigures the logger.
MQTT_LOGGER = logging.getLogger('mqtt')
MQTT_LOGGER.setLevel(logging.WARNING)
class Client:
    """Asyncio front end for a paho-mqtt client.

    paho-mqtt's callback API is bridged onto the event loop: socket readiness
    is registered with the loop (``add_reader``/``add_writer``), ``loop_misc``
    is driven by a background task, and the paho callbacks resolve the
    futures/events that the coroutine methods of this class await.

    The class is an async context manager: entering connects to the broker,
    exiting disconnects from it.
    """

    def __init__(self, hostname, port=1883, *, username=None, password=None,
                 logger=None, client_id=None, tls_context=None, protocol=None,
                 will=None, clean_session=None, transport="tcp", keepalive=60):
        """Configure the underlying paho client; no connection is made yet.

        Args:
            hostname: Broker host passed to paho's ``connect``.
            port: Broker port passed to paho's ``connect``.
            username: Optional user name. Credentials are applied whenever a
                user name is given, even if ``password`` is None.
            password: Optional password that goes with ``username``.
            logger: Logger handed to paho's ``enable_logger``; defaults to
                the module-level ``MQTT_LOGGER``.
            client_id: Forwarded to ``mqtt.Client``.
            tls_context: If given, passed to paho's ``tls_set_context``.
            protocol: Forwarded to ``mqtt.Client``; defaults to
                ``mqtt.MQTTv311``.
            will: Optional object exposing ``topic``, ``payload``, ``qos``,
                ``retain`` and ``properties``; registered via ``will_set``.
            clean_session: Forwarded to ``mqtt.Client``.
            transport: Forwarded to ``mqtt.Client``.
            keepalive: Keepalive value passed to paho's ``connect``
                (this used to be hard-coded to 60).
        """
        self._hostname = hostname
        self._port = port
        self._keepalive = keepalive
        self._loop = asyncio.get_event_loop()
        # Resolved by _on_connect / _on_disconnect respectively.
        self._connected = self._loop.create_future()
        self._disconnected = self._loop.create_future()
        self._pending_calls = {}  # Pending subscribe, unsubscribe, and publish calls
        self._pending_calls_threshold = 10
        self._misc_task = None

        if protocol is None:
            protocol = mqtt.MQTTv311

        self._client = mqtt.Client(client_id=client_id, protocol=protocol,
                                   clean_session=clean_session, transport=transport)
        self._client.on_connect = self._on_connect
        self._client.on_disconnect = self._on_disconnect
        self._client.on_subscribe = self._on_subscribe
        self._client.on_unsubscribe = self._on_unsubscribe
        self._client.on_message = None
        self._client.on_publish = self._on_publish
        # Callbacks for custom event loop
        self._client.on_socket_open = self._on_socket_open
        self._client.on_socket_close = self._on_socket_close
        self._client.on_socket_register_write = self._on_socket_register_write
        self._client.on_socket_unregister_write = self._on_socket_unregister_write

        if logger is None:
            logger = MQTT_LOGGER
        self._client.enable_logger(logger)

        # A user name without a password is valid for paho, so only the user
        # name decides whether credentials are configured.
        if username is not None:
            self._client.username_pw_set(username=username, password=password)

        if tls_context is not None:
            self._client.tls_set_context(tls_context)

        if will is not None:
            self._client.will_set(
                will.topic,
                will.payload,
                will.qos,
                will.retain,
                will.properties
            )

    @property
    def id(self):
        """Return the client ID.

        Note that paho-mqtt stores the client ID as `bytes` internally.
        We assume that the client ID is a UTF8-encoded string and decode
        it first.
        """
        return self._client._client_id.decode()

    async def connect(self, *, timeout=10):
        """Connect to the broker and wait for its acknowledgement.

        Args:
            timeout: Seconds to wait for the connect callback.

        Raises:
            MqttError: If paho's ``connect`` fails with a socket, OS or
                websocket error, or if the acknowledgement times out.
            MqttConnectError: If the broker answers with a return code other
                than ``CONNACK_ACCEPTED`` (set by ``_on_connect``).
        """
        try:
            # paho's connect() is a blocking call, so it runs in the loop's
            # default executor. self._loop is used (rather than
            # asyncio.get_running_loop) to match the rest of the class and to
            # keep working on Python 3.6.
            await self._loop.run_in_executor(
                None, self._client.connect, self._hostname, self._port, self._keepalive)
            # paho's Client.socket() returns non-None after the call to connect.
            client_socket = self._client.socket()
            if not isinstance(client_socket, mqtt.WebsocketWrapper):
                # Plain sockets only: set the send buffer size to 2048 bytes.
                client_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)
        # paho.mqtt.Client.connect may raise one of several exceptions.
        # We convert all of them to the common MqttError for user convenience.
        # See: https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L1770
        except (socket.error, OSError, mqtt.WebsocketConnectionError) as error:
            raise MqttError(str(error)) from error
        await self._wait_for(self._connected, timeout=timeout)

    async def disconnect(self, *, timeout=10):
        """Ask paho to disconnect and wait for the disconnect callback.

        Raises:
            MqttCodeError: If paho rejects the request, or if the disconnect
                callback reported a non-success return code.
            MqttError: If the acknowledgement times out.
        """
        rc = self._client.disconnect()
        # Early out on error
        if rc != mqtt.MQTT_ERR_SUCCESS:
            raise MqttCodeError(rc, "Could not disconnect")
        # Wait for acknowledgement
        await self._wait_for(self._disconnected, timeout=timeout)

    async def force_disconnect(self):
        """Mark the client as disconnected without talking to the broker.

        Safe to call repeatedly. It is also safe after ``disconnect`` timed
        out, in which case ``asyncio.wait_for`` has already cancelled the
        ``_disconnected`` future and setting a result would raise
        ``asyncio.InvalidStateError``.
        """
        if not self._disconnected.done():
            self._disconnected.set_result(None)

    async def subscribe(self, *args, timeout=10, **kwargs):
        """Subscribe via paho and return the granted QoS from the callback.

        Positional and keyword arguments are forwarded to paho's
        ``subscribe``.

        Raises:
            MqttCodeError: If paho reports an error for the request.
            MqttError: If the acknowledgement times out.
        """
        result, mid = self._client.subscribe(*args, **kwargs)
        # Early out on error
        if result != mqtt.MQTT_ERR_SUCCESS:
            raise MqttCodeError(result, 'Could not subscribe to topic')
        # Create future for when the on_subscribe callback is called
        cb_result = self._loop.create_future()
        with self._pending_call(mid, cb_result):
            # Wait for cb_result
            return await self._wait_for(cb_result, timeout=timeout)

    async def unsubscribe(self, *args, timeout=10):
        """Unsubscribe via paho and wait for the acknowledgement.

        Positional arguments are forwarded to paho's ``unsubscribe``.

        Raises:
            MqttCodeError: If paho reports an error for the request.
            MqttError: If the acknowledgement times out.
        """
        result, mid = self._client.unsubscribe(*args)
        # Early out on error
        if result != mqtt.MQTT_ERR_SUCCESS:
            raise MqttCodeError(result, 'Could not unsubscribe from topic')
        # Create event for when the on_unsubscribe callback is called
        confirmation = asyncio.Event()
        with self._pending_call(mid, confirmation):
            # Wait for confirmation
            await self._wait_for(confirmation.wait(), timeout=timeout)

    async def publish(self, *args, timeout=10, **kwargs):
        """Publish via paho and wait until the message counts as published.

        Positional and keyword arguments are forwarded to paho's ``publish``.

        Raises:
            MqttCodeError: If paho reports an error for the request.
            MqttError: If the confirmation times out.
        """
        info = self._client.publish(*args, **kwargs)  # [2]
        # Early out on error
        if info.rc != mqtt.MQTT_ERR_SUCCESS:
            raise MqttCodeError(info.rc, 'Could not publish message')
        # Early out on immediate success
        if info.is_published():
            return
        # Create event for when the on_publish callback is called
        confirmation = asyncio.Event()
        with self._pending_call(info.mid, confirmation):
            # Wait for confirmation
            await self._wait_for(confirmation.wait(), timeout=timeout)

    @asynccontextmanager
    async def filtered_messages(self, topic_filter, *, queue_maxsize=0):
        """Return async generator of messages that match the given filter.

        Use queue_maxsize to restrict the queue size. If the queue is full,
        incoming messages will be discarded (and a warning is logged).
        If queue_maxsize is less than or equal to zero, the queue size is infinite.

        Example use:
            async with client.filtered_messages('floors/+/humidity') as messages:
                async for message in messages:
                    print(f'Humidity reading: {message.payload.decode()}')
        """
        cb, generator = self._cb_and_generator(log_context=f'topic_filter="{topic_filter}"',
                                               queue_maxsize=queue_maxsize)
        try:
            self._client.message_callback_add(topic_filter, cb)
            # Back to the caller (run whatever is inside the with statement)
            yield generator
        finally:
            # We are exiting the with statement. Remove the topic filter.
            self._client.message_callback_remove(topic_filter)

    @asynccontextmanager
    async def unfiltered_messages(self, *, queue_maxsize=0):
        """Return async generator of all messages that are not caught in filters.

        Raises:
            RuntimeError: If another unfiltered generator is already active
                (paho's ``on_message`` slot is occupied).
        """
        # Early out
        if self._client.on_message is not None:
            # TODO: This restriction can easily be removed.
            raise RuntimeError('Only a single unfiltered_messages generator can be used at a time.')
        cb, generator = self._cb_and_generator(log_context='unfiltered',
                                               queue_maxsize=queue_maxsize)
        try:
            self._client.on_message = cb
            # Back to the caller (run whatever is inside the with statement)
            yield generator
        finally:
            # We are exiting the with statement. Unset the callback.
            self._client.on_message = None

    def _cb_and_generator(self, *, log_context, queue_maxsize=0):
        """Build a paho message callback and the generator it feeds.

        Args:
            log_context: Text put in front of the "queue is full" warning.
            queue_maxsize: ``maxsize`` of the intermediate ``asyncio.Queue``.

        Returns:
            A ``(callback, async_generator)`` pair. The callback enqueues
            messages (discarding them with a warning when the queue is full);
            the generator yields them until the client is disconnected.
        """
        # Queue to hold the incoming messages
        messages = asyncio.Queue(maxsize=queue_maxsize)

        # Callback for the underlying API
        def _put_in_queue(client, userdata, msg):
            try:
                messages.put_nowait(msg)
            except asyncio.QueueFull:
                MQTT_LOGGER.warning(f'[{log_context}] Message queue is full. Discarding message.')

        # The generator that we give to the caller
        async def _message_generator():
            # Forward all messages from the queue
            while True:
                # Wait until we either:
                #  1. Receive a message
                #  2. Disconnect from the broker
                get = self._loop.create_task(messages.get())
                try:
                    done, _ = await asyncio.wait(
                        (get, self._disconnected),
                        return_when=asyncio.FIRST_COMPLETED
                    )
                except asyncio.CancelledError:
                    # If the asyncio.wait is cancelled, we must make sure
                    # to also cancel the underlying tasks.
                    get.cancel()
                    raise
                if get in done:
                    # We received a message. Return the result.
                    yield get.result()
                else:
                    # We got disconnected from the broker. Cancel the "get" task.
                    get.cancel()
                    # Stop the generator with the following exception
                    raise MqttError("Disconnected during message iteration")

        return _put_in_queue, _message_generator()

    async def _wait_for(self, *args, **kwargs):
        """Call ``asyncio.wait_for`` and turn a timeout into ``MqttError``.

        Note that ``asyncio.wait_for`` cancels the awaited future when the
        timeout expires; code that inspects such a future afterwards must be
        prepared for it to be cancelled.
        """
        try:
            return await asyncio.wait_for(*args, **kwargs)
        except asyncio.TimeoutError as error:
            raise MqttError('Operation timed out') from error

    @contextmanager
    def _pending_call(self, mid, value):
        """Register ``value`` under message ID ``mid`` for the with-block.

        Raises:
            RuntimeError: If ``mid`` is already registered.
        """
        if mid in self._pending_calls:
            raise RuntimeError(f'There already exists a pending call for message ID "{mid}"')
        self._pending_calls[mid] = value  # [1]
        try:
            # Log a warning if there is a concerning number of pending calls
            pending = len(self._pending_calls)
            if pending > self._pending_calls_threshold:
                MQTT_LOGGER.warning(f'There are {pending} pending publish calls.')
            # Back to the caller (run whatever is inside the with statement)
            yield
        finally:
            # The normal procedure is:
            #  * We add the item at [1]
            #  * A callback will remove the item
            #
            # However, if the callback doesn't get called (e.g., due to a
            # network error) we still need to remove the item from the dict.
            self._pending_calls.pop(mid, None)

    def _on_connect(self, client, userdata, flags, rc, properties=None):
        """paho connect callback: resolve ``_connected`` from the return code."""
        # Return early if already connected. Sometimes, paho-mqtt calls _on_connect
        # multiple times. Maybe because we receive multiple CONNACK messages
        # from the server. In any case, we return early so that we don't set
        # self._connected twice (as it raises an asyncio.InvalidStateError).
        if self._connected.done():
            return
        if rc == mqtt.CONNACK_ACCEPTED:
            self._connected.set_result(rc)
        else:
            self._connected.set_exception(MqttConnectError(rc))

    def _on_disconnect(self, client, userdata, rc, properties=None):
        """paho disconnect callback: resolve ``_disconnected`` from the return code."""
        # Return early if the disconnect is already acknowledged.
        # Sometimes (e.g., due to timeouts), paho-mqtt calls _on_disconnect
        # twice. We return early to avoid setting self._disconnected twice
        # (as it raises an asyncio.InvalidStateError).
        if self._disconnected.done():
            return
        # Return early if we are not connected yet. This avoids calling
        # `_disconnected.set_exception` with an exception that will never
        # be retrieved (since `__aexit__` won't get called if `__aenter__`
        # fails). In turn, this avoids asyncio debug messages like the
        # following:
        #
        #   "[asyncio] Future exception was never retrieved"
        #
        # See also: https://docs.python.org/3/library/asyncio-dev.html#detect-never-retrieved-exceptions
        #
        # A connect timeout cancels `_connected`; calling `.exception()` on a
        # cancelled future raises CancelledError, so that case is checked
        # first and treated as "never connected".
        if (not self._connected.done()
                or self._connected.cancelled()
                or self._connected.exception() is not None):
            return
        if rc == mqtt.MQTT_ERR_SUCCESS:
            self._disconnected.set_result(rc)
        else:
            self._disconnected.set_exception(MqttCodeError(rc, 'Unexpected disconnect'))

    def _on_subscribe(self, client, userdata, mid, granted_qos, properties=None):
        """paho subscribe callback: hand ``granted_qos`` to the pending future."""
        try:
            self._pending_calls.pop(mid).set_result(granted_qos)
        except KeyError:
            MQTT_LOGGER.error(f'Unexpected message ID "{mid}" in on_subscribe callback')

    def _on_unsubscribe(self, client, userdata, mid, properties=None, reasonCodes=None):
        """paho unsubscribe callback: set the pending confirmation event."""
        try:
            self._pending_calls.pop(mid).set()
        except KeyError:
            MQTT_LOGGER.error(f'Unexpected message ID "{mid}" in on_unsubscribe callback')

    def _on_publish(self, client, userdata, mid):
        """paho publish callback: set the pending confirmation event, if any."""
        try:
            self._pending_calls.pop(mid).set()
        except KeyError:
            # Do nothing since [2] may call on_publish before it even returns.
            # That is, the message may already be published before we even get a
            # chance to set up the 'pending_call' logic.
            pass

    def _on_socket_open(self, client, userdata, sock):
        """Register the socket for reads and start the ``loop_misc`` task."""
        def cb():
            client.loop_read()
        self._loop.add_reader(sock.fileno(), cb)
        self._misc_task = self._loop.create_task(self._misc_loop())

    def _on_socket_close(self, client, userdata, sock):
        """Unregister the socket reader and stop the ``loop_misc`` task."""
        fileno = sock.fileno()
        # A socket that is already closed reports -1, which is not a valid
        # descriptor to hand to the event loop.
        if fileno > -1:
            self._loop.remove_reader(fileno)
        # The task only exists once _on_socket_open has run. Task.cancel()
        # merely requests cancellation and does not raise CancelledError.
        if self._misc_task is not None and not self._misc_task.done():
            self._misc_task.cancel()

    def _on_socket_register_write(self, client, userdata, sock):
        """Have the loop call paho's ``loop_write`` when the socket is writable."""
        def cb():
            client.loop_write()
        self._loop.add_writer(sock, cb)

    def _on_socket_unregister_write(self, client, userdata, sock):
        """Stop watching the socket for writability."""
        self._loop.remove_writer(sock)

    async def _misc_loop(self):
        """Call paho's ``loop_misc`` once per second until it reports an error."""
        while self._client.loop_misc() == mqtt.MQTT_ERR_SUCCESS:
            await asyncio.sleep(1)

    async def __aenter__(self):
        """Connect to the broker."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Disconnect from the broker."""
        # Early out if already disconnected...
        if self._disconnected.done():
            # ...by returning if an earlier disconnect() timed out. In that
            # case the future was cancelled and holds neither a result nor an
            # exception; `.exception()` would raise CancelledError.
            if self._disconnected.cancelled():
                return
            disc_exc = self._disconnected.exception()
            if disc_exc is not None:
                # ...by raising the error that caused the disconnect
                raise disc_exc
            # ...by returning since the disconnect was intentional
            return
        # Try to gracefully disconnect from the broker
        try:
            await self.disconnect()
        except MqttError as error:
            # We tried to be graceful. Now there is no mercy.
            MQTT_LOGGER.warning(f'Could not gracefully disconnect due to "{error}". Forcing disconnection.')
            await self.force_disconnect()
# TODO: Turn this into a (frozen) dataclass (available from Python 3.7)
# once Python 3.6 support is dropped.
class Will:
    """Plain container for the arguments of a will message.

    The constructor stores its five arguments unchanged under attributes of
    the same name: ``topic``, ``payload``, ``qos``, ``retain`` and
    ``properties``.
    """

    def __init__(self, topic, payload=None, qos=0, retain=False, properties=None):
        self.topic, self.payload = topic, payload
        self.qos, self.retain = qos, retain
        self.properties = properties