from collections import deque
from contextlib import ContextDecorator
from datetime import timedelta
import re
from logging import getLogger
from tornado import gen
from tornado import ioloop
from tornado.concurrent import Future
from tornado.iostream import StreamClosedError
from toro import Event
import toro
from broker.access_control import Authorization
from broker.actions import IncomingAction, OutgoingAction
from broker.concurency import CancelledException, DummyFuture
from broker.persistence import InMemoryClientPersistence, OutgoingPublishesBase, PacketIdsDepletedError
from broker.util import MQTTUtils
from broker.factory import MQTTMessageFactory, MQTTMessageMakeError, IncomingActionFactory, TypeFactoryError, \
OutgoingActionFactory
from broker.messages import Publish, BaseMQTTMessage
from broker.connection import MQTTConnection, MQTTConnectionClosed
[docs]class MQTTClient():
"""
Objects of this class encapsulate and abstract all aspects of a given
client. A MQTTClient object may refer to a live, connected, client or a
known client which albeit disconnected had the clean_sessions flag set to
false and, thus, is kept by the server as an end point for routed messages.
One may call :meth:`self.is_connected` to check whether is there a connected
client or not.
:param MQTTServer server: The server which the client is bound to;
:param MQTTConnection connection: The connection to be used;
:param str uid: A string used as the client's id;
:param bool clean_session: The clean session flag, as per MQTT Protocol;
:param int keep_alive: The keep alive interval, in seconds.
:param ClientPersistenceBase persistence: An object that provides persistence
"""
def __init__(self, server, connection, authorization=None,
uid=None, clean_session=False,
keep_alive=60, persistence=None):
self.uid = uid
self.logger = getLogger('activity.clients')
self.persistence = persistence or InMemoryClientPersistence(uid)
self.subscriptions = ClientSubscriptions(persistence.subscriptions)
self._connected = Event()
self.last_will = None
self.connection = None
self.clean_session = None
self.keep_alive = None
self.server = server
self.authorization = authorization or Authorization.no_restrictions()
# Queue of the packets ready to be delivered
self.outgoing_queue = OutgoingQueue(self.persistence.outgoing_publishes)
self.update_configuration(clean_session, keep_alive)
self.update_connection(connection)
@property
def incoming_packet_ids(self):
return self.persistence.incoming_packet_ids
@property
def redelivery_deadline(self):
s = self.keep_alive if self.keep_alive > 0 else 60
return timedelta(seconds=s)
[docs] def update_connection(self, connection):
"""
Updates the client's connection by disconnecting the previous one and
configuring the new one's keep alive time according to ``keep_alive``.
:param MQTTConnection connection: The new connection to be used;
"""
if self.is_connected():
self.disconnect()
if connection is not None:
assert isinstance(connection, MQTTConnection)
self.connection = connection
self.connection.set_timeout(self.keep_alive)
self.connection.set_close_callback(self._on_stream_close)
self.connection.set_timeout_callback(self._on_connection_timeout)
if not self.connection.closed():
self._connected.set()
[docs] def update_configuration(self, clean_session=False, keep_alive=60):
"""
Updates the internal attributes.
:param bool clean_session: A flag indicating whether this session should
be brand new or attempt to reuse the last known session for a client
with the same :attr:`self.uid` as this.
:param int keep_alive: Connection's keep alive setting, in seconds.
"""
self.clean_session = clean_session
self.keep_alive = keep_alive
def update_authorization(self, authorization):
self.authorization = authorization
if not self.clean_session:
self.unsubscribe_denied_topics()
[docs] def start(self):
"""
Starts the client fetching, processing and dispatching routines. Should
be called after object instantiation or a :meth:`self.update_connection`
call.
The following coroutines are started:
* :meth:`self._process_incoming_messages`
* :meth:`self._process_outgoing_messages`
"""
self.logger.debug("[uid: %s] starting co-routines" % self.uid)
# put connection object in `_process_incoming_messages` scope
# so corner cases can be handled appropriately
self._process_incoming_packets(self.connection)
self._process_outgoing_packets(self.connection)
@property
[docs] def connected(self):
"""
An :class:`toro.Event` instance that is set whenever the client is
connected and clear on disconnection. It's safe to wait on this property
before stream related operations.
"""
return self._connected
[docs] def is_connected(self):
"""
A shorthand for :meth:connected.is_set().
"""
return self._connected.is_set()
[docs] def handle_last_will(self):
"""
Checks if a client has a pending last will message and dispatches it for
server processing.
"""
if self.last_will is not None:
self.dispatch_to_server(self.last_will)
self.logger.debug("[uid: %s] Dispatched last will message %s" %
(self.uid, self.last_will.log_info()))
def _get_action(self, msg, factory):
try:
action = factory.make(msg)
action.bind_client(self)
return action
except KeyError:
self.logger.exception("[uid: %s] Wrong message type %s" %
(self.uid, type(msg)))
self.logger.debug(
"[uid: %s] Will be disconnect due to invalid message" %
self.uid
)
self.disconnect()
@gen.coroutine
[docs] def _process_incoming_packets(self, connection):
"""
This coroutinte fetches the message raw data from
:attr:`self.incoming_queue`, parses it into the corresponding message
object (an instance of one of the
:class:`broker.messages.BaseMQTTMessage` subclasses) and passes it to
the :attr:`self.incoming_transaction_manager` to be processed.
It is started by calling :meth:`self.start()` and stops upon client
disconnection.
"""
while connection.is_readable:
with client_process_context(self, connection):
msg = yield connection.read_message()
msg_obj = MQTTMessageFactory.make(msg)
self.logger.debug("[B << C] [uid: %s] %s" %
(self.uid, msg_obj.log_info()))
action = self._get_action(msg_obj, IncomingActionFactory)
assert isinstance(action, IncomingAction)
action.run()
self.logger.debug("[uid: %s] stopping _process_incoming_messages"
% self.uid)
@gen.coroutine
[docs] def _process_outgoing_packets(self, connection):
"""
This coroutinte fetches the message raw data from
:attr:`self.outgoing_queue`, parses it into the corresponding message
object (an instance of one of the
:class:`broker.messages.BaseMQTTMessage` subclasses) and passes
it to the :attr:`self.outgoing_transaction_manager` to be processed.
It is started by calling :meth:`self.start()` and stops upon client
disconnection.
"""
self.outgoing_queue.clear()
self.outgoing_queue.retry_pending()
while not connection.closed():
with client_process_context(self, connection):
msg = yield self.outgoing_queue.get()
assert isinstance(msg, BaseMQTTMessage)
action = self._get_action(msg, OutgoingActionFactory)
assert isinstance(action, OutgoingAction)
self.logger.debug("[B >> C] [uid: %s] %s" %
(self.uid, msg.log_info()))
yield self.write(action.get_data())
action.post_write()
self.logger.debug("[uid: %s] stopping _process_outgoing_messages" % self.uid)
[docs] def publish(self, msg):
"""
Puts a publish packet on the :attr:`self.outgoing_queue` to be sent
to the client.
:param msg: The message to be set or a iterable of its bytes.
:type msg: Publish
"""
try:
self.outgoing_queue.put_publish(msg)
except PacketIdsDepletedError:
self.logger.error('[uid: %s] Packet IDs depleted' % self.uid)
[docs] def send_packet(self, packet):
"""
Puts a packet on the :attr:`self.outgoing_queue` to be sent to the
client.
"""
self.outgoing_queue.put(packet)
@gen.coroutine
[docs] def write(self, msg):
"""
Writes a MQTT Message to the client. If the client isn't connected,
waits for the :attr:`self.connected` event to be set.
:param MQTT Message msg: The message to be send. It must be a instance
of :class:`broker.messages.BaseMQTTMessage` or it's subclasses.
"""
yield self.connected.wait()
yield self.connection.write_message(msg)
[docs] def dispatch_to_server(self, pub_msg):
"""
Dispatches a Publish message to the server for further processing, ie.
delivering it to the appropriate subscribers.
:param Publish pub_msg: A :class:`broker.messages.Publish` instance.
"""
assert isinstance(pub_msg, Publish)
if self.authorization.is_publish_allowed(pub_msg.topic):
self.server.handle_incoming_publish(pub_msg)
else:
self.logger.warn("[uid: %s] is not allowed to publish on %s" %
(self.uid, pub_msg.topic))
[docs] def subscribe(self, subscription_mask, qos):
"""
Subscribes the client to a topic or wildcarded mask at the informed QoS
level. Calling this method also signalizes the server to enqueue the
matching retained messages.
When called for a (`subscripition_mask`, `qos`) pair for which the
client has already a subscription it will silently ignore the command
and return a suback.
:param string subscription_mask: A MQTT valid topic or wildcarded mask;
:param int qos: A valid QoS level (0, 1 or 2).
:rtype: int
:return: The granted QoS level (0, 1 or 2) or 0x80 for failed
subscriptions.
"""
if qos not in [0, 1, 2]:
self.logger.warn('client tried to subscribe with invalid qos %s' % qos)
return 0x80
new_subscription = subscription_mask not in self.subscriptions or \
self.subscriptions.qos(subscription_mask) != qos
if not self.authorization.is_subscription_allowed(subscription_mask):
self.logger.warn("[uid: %s] is not allowed to subscribe on %s" %
(self.uid, subscription_mask))
del self.subscriptions[subscription_mask]
qos = 0x80
elif new_subscription:
ereg = MQTTUtils.convert_to_ereg(subscription_mask)
if ereg is not None:
self.subscriptions.add(subscription_mask, qos, re.compile(ereg))
self.server.enqueue_retained_message(self, subscription_mask)
else:
qos = 0x80
return qos
[docs] def unsubscribe(self, topics):
"""
Unsubscribes the client from each topic in ``topics``. Safely
ignores topics which the client is not subscribed to.
:param iterable topics: An iterable of MQTT valid topic strings.
"""
for topic in topics:
del self.subscriptions[topic]
def unsubscribe_denied_topics(self):
for topic in self.subscriptions.masks:
if not self.authorization.is_subscription_allowed(topic):
self.unsubscribe(topic)
[docs] def disconnect(self):
"""
Closes the socket and disconnects the client. If
:attr:`self.clean_session` is set, ensures that the incoming and
outgoing queues are cleared and calls the server client removing
routine.
.. hint::
It's safe to call this function without checking whether the
connection is open or not.
"""
if self.is_connected():
self.logger.debug("[uid: %s] disconnecting client" % self.uid)
self.connection.close()
self._connected.clear()
self.outgoing_queue.clear()
if self.clean_session:
self.logger.debug("[uid: %s] cleaning session" % self.uid)
self.server.remove_client(self)
[docs] def get_list_of_delivery_qos(self, msg):
"""
Matches the ``msg.topic`` against all the current subscriptions and
returns a list containing the QoS level for each matched subscription.
:param Publish msg: A MQTT valid message.
:rtype: tuple
:return: A list of QoS levels, ie [0, 0, 1, 2, 0, 2]
"""
# TODO CACHING THESE RESULTS PER CLIENT
qos_list = []
for mask in self.subscriptions.masks:
qos = self.get_matching_qos(msg, mask)
if qos is not None:
qos_list.append(qos)
return qos_list
[docs] def get_matching_qos(self, msg, subscriptions_mask):
"""
Matches the ``msg.topic`` against a single subscription defined by the
subscription mask and returns the QoS level on which the message should
be delivered.
:param Publish msg: Message to be analysed;
:param subscriptions_mask: A subscription mask that identifies one of
the client's subscriptions.
:return: QoS Level or None, in case it doesn't match.
"""
assert isinstance(msg, Publish)
qos, pattern = self.subscriptions[subscriptions_mask]
if pattern.match(msg.topic) is not None:
self.logger.debug("[uid: %s] %s matched by %s, MAXQOS: %d"
% (self.uid, msg.topic, subscriptions_mask, qos))
return min(qos, msg.qos)
return None
[docs] def _on_connection_timeout(self, connection):
"""
Callback called when the connection times out. Ensures clearing the
:attr:`self.connected` event and processing the :meth:`self.disconnect`
method.
"""
self.logger.debug("[uid: %s] Connection timeout" % self.uid)
self.handle_last_will()
self.disconnect()
connection.set_close_callback(None)
connection.set_timeout_callback(None)
[docs] def _on_stream_close(self, connection):
"""
Callback called when the stream closes. Ensures clearing the
:attr:`self.connected` event and processing the :meth:`self.disconnect`
method.
"""
self.logger.debug("[uid: %s] Stream closed %s" %
(self.uid, self.connection._address))
try:
if self.connection.closed_due_error():
self.handle_last_will()
self.disconnect()
except MQTTConnectionClosed:
# already closed / disconnected
pass
connection.set_close_callback(None)
connection.set_timeout_callback(None)
class client_process_context(ContextDecorator):
def __init__(self, client, connection):
assert isinstance(client, MQTTClient)
self.client = client
self.connection = connection
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type == StreamClosedError:
self.client.disconnect()
elif exc_type == toro.Timeout:
self.client.logger.debug("[uid: %s] _process_incoming_messages: timed out" %
(self.client.uid, ))
self.client._on_connection_timeout(self.connection)
elif exc_type == TypeFactoryError:
self.client.logger.debug('[uid: %s] %s' % (self.client.uid, exc_val.message))
self.client.disconnect()
elif exc_type in [CancelledException, MQTTConnectionClosed, MQTTMessageMakeError]:
pass
elif isinstance(exc_val, Exception):
self.client.logger.exception(
"[uid: %s] Unhandled exception on client_process_context: %s" %
(self.client.uid, str(exc_val))
)
return True # suppress the raised exception
class ClientSubscriptions():
"""
Encapsulates subscription persistence access and mask regex caching.
"""
def __init__(self, subscriptions):
self._subscriptions = subscriptions
self._re_cache = dict()
def add(self, mask, qos, pattern=None):
self._re_cache[mask] = pattern or re.compile(MQTTUtils.convert_to_ereg(mask))
self._subscriptions[mask] = qos
def __contains__(self, item):
return item in self._subscriptions
def __getitem__(self, mask):
assert self.__contains__(mask)
return self.qos(mask), self._get_regex(mask)
def qos(self, mask):
return self._subscriptions[mask]
def _get_regex(self, mask):
compiled = self._re_cache.get(mask, None)
if not compiled:
ereg = MQTTUtils.convert_to_ereg(mask)
compiled = re.compile(ereg)
self._re_cache[mask] = compiled
return compiled
@property
def masks(self):
return self._subscriptions.keys()
def __delitem__(self, mask):
if mask in self._subscriptions:
del self._subscriptions[mask]
if mask in self._re_cache:
del self._re_cache[mask]
class OutgoingQueue():
"""
This class controls packets to be delivered to the remote client.
It encapsulates the logic to send packets and start new publish flows.
"""
def __init__(self, outgoing_publishes):
self.max_inflight = 1
assert isinstance(outgoing_publishes, OutgoingPublishesBase)
self.publishes = outgoing_publishes
self.packets = deque()
self.retrial_handles = dict()
self.future = DummyFuture()
def retry_pending(self):
for packet in self.publishes.get_all_inflight():
self.retrial_handles[packet.id] = None
self.put(packet)
def put(self, packet):
"""
Puts a packet to the outgoing queue.
No checks are done on the provided packets.
"""
assert isinstance(packet, BaseMQTTMessage)
if not self.future.done():
self.future.set_result(packet)
else:
self.packets.append(packet)
def put_publish(self, packet):
"""
Puts a publish packet to the outgoing queue.
If the QoS level is 0 it is only placed on the outgoing queue.
Otherwise the packet is persisted and scheduled for publishing.
"""
assert isinstance(packet, Publish)
if packet.qos == 0:
self.put(packet)
else:
self.publishes.insert(packet)
self._start_next_flow()
def set_sent(self, packet_id):
if self.publishes.is_inflight(packet_id):
self.publishes.set_sent(packet_id)
def is_sent(self, packet_id):
return self.publishes.is_inflight(packet_id) and \
self.publishes.is_sent(packet_id)
def is_pubconf(self, packet_id):
return self.publishes.is_inflight(packet_id) and \
self.publishes.is_pubconf(packet_id)
def set_pubconf(self, packet_id):
if self.publishes.is_inflight(packet_id):
self.publishes.set_pubconf(packet_id)
def get(self):
"""
Gets the next packet to be sent to the remote client.
:return: a `Future`
"""
if not self.future.done():
self.future.set_exception(CancelledException('Only one future supported'))
self.future = Future()
# try to start next publish flow first,
# otherwise the outgoing packets would have to deplete before
# any publish flows could start
started = self._start_next_flow()
if not started and self.packets:
self.future.set_result(self.packets.popleft())
return self.future
def clear(self):
"""
Clears the in-memory state of the outgoing queue. The data in
persistence is left as is. If there is any pending `Future`
waiting for result, it is cancelled.
"""
for packet_id in self.retrial_handles.keys():
self.cancel_retrial(packet_id)
self.packets.clear()
self.retrial_handles.clear()
if not self.future.done():
msg = 'Outgoing queue was cleansed'
self.future.set_exception(CancelledException(msg))
self.future = DummyFuture()
def flow_completed(self, packet_id):
"""
Removes the publish corresponding to the `packet-id` from the inflight
set and from persistence.
"""
if packet_id:
self.publishes.remove(packet_id)
if packet_id in self.retrial_handles:
self.cancel_retrial(packet_id)
del self.retrial_handles[packet_id]
self._start_next_flow()
def _start_next_flow(self):
if self.publishes.inflight_len < self.max_inflight:
packet = self.publishes.get_next()
if packet:
self.retrial_handles[packet.id] = None
self.put(packet)
return True
return False
def _retry_flow(self, packet_id):
if self.publishes.is_inflight(packet_id):
self.retrial_handles[packet_id] = None
self.put(self.publishes.get_inflight(packet_id))
def set_retrial(self, packet_id, deadline):
handle = self.retrial_handles.get(packet_id)
if handle:
self.cancel_retrial(packet_id)
callback = lambda: self._retry(packet_id)
handle = ioloop.IOLoop.current().add_timeout(deadline, callback)
self.retrial_handles[packet_id] = handle
def cancel_retrial(self, packet_id):
handler = self.retrial_handles.get(packet_id)
if handler is not None:
ioloop.IOLoop.current().remove_timeout(handler)
self.retrial_handles[packet_id] = None
def _retry(self, packet_id):
if packet_id in self.retrial_handles:
self._retry_flow(packet_id)