import tornado.concurrent
from tornado.iostream import StreamClosedError
from tornado.tcpserver import TCPServer
from tornado.log import access_log
from tornado import gen
from logging import getLogger
import toro
from broker import MQTTConstants
from broker.access_control import NoAuthentication, Authorization
from broker.client import MQTTClient
from broker.exceptions import ConnectError
from broker.messages import Publish, Connect, Connack
from broker.connection import MQTTConnection
from broker.factory import MQTTMessageFactory
from broker.persistence import InMemoryPersistence
# Logger named 'activity.clients'; this module uses it for debug messages
# about CONNECT/CONNACK frames exchanged with clients and authorization.
client_logger = getLogger(name='activity.clients')
class MQTTServer(TCPServer):
    """
    This is the highest abstraction of the package and represents the whole
    MQTT Broker. Its main roles are handling incoming connections, keeping
    tabs on the known client sessions and dispatching messages based on
    subscription matching.
    """

    def __init__(self, authentication=None, persistence=None, clients=None,
                 ssl_options=None):
        """
        :param authentication: Object used to authenticate CONNECT requests.
            Defaults to :class:`broker.access_control.NoAuthentication`.
        :param persistence: Storage backend for client sessions and retained
            messages. Defaults to
            :class:`broker.persistence.InMemoryPersistence`.
        :param dict clients: Initial mapping of client uid to
            :class:`broker.client.MQTTClient`. A new empty dict is used when
            None.
        :param ssl_options: Passed through to
            :class:`tornado.tcpserver.TCPServer`.
        """
        super().__init__(ssl_options=ssl_options)
        self.clients = clients if clients is not None else dict()
        assert isinstance(self.clients, dict)
        # Compare against None (as done for `clients`) so that a valid but
        # falsy backend object is not silently replaced by the default.
        self.persistence = (persistence if persistence is not None
                            else InMemoryPersistence())
        self.authentication = (authentication if authentication is not None
                               else NoAuthentication())
        # Register an offline client for every uid the persistence layer
        # already holds data for.
        self.recreate_sessions(self.persistence.get_client_uids())
        self._retained_messages = RetainedMessages(
            self.persistence.get_retained_messages())
        assert isinstance(self._retained_messages, RetainedMessages)

    def recreate_sessions(self, uids):
        """
        Registers an offline :class:`broker.client.MQTTClient` (no
        connection, ``clean_session=False``) for every uid in `uids` that is
        not already known to the server.

        :param uids: Sized iterable of client uids. Each uid is converted
            with ``str()``, and that string is the uid the client is
            recreated with.
        """
        access_log.info("recreating %s sessions", len(uids))
        for uid in uids:
            client_uid = str(uid)
            # Check the same (stringified) uid the client is built with;
            # checking the raw value would never match a non-str uid and
            # would overwrite an already registered client.
            if client_uid not in self.clients:
                self.add_client(self.recreate_client(client_uid))

    def get_known_client(self, connect_msg):
        """
        Returns a known MQTTClient instance that has the same uid defined on
        the Connect message.

        .. caution::
            If the connect message defines the usage of a clean session, this
            method will clear any previous session matching this client ID
            and automatically return None.

        :param Connect connect_msg: A connect message that specifies the
            client.
        :return: The registered MQTTClient, or None.
        """
        assert isinstance(connect_msg, Connect)
        client = self.clients.get(connect_msg.client_uid)
        if client is not None:
            assert isinstance(client, MQTTClient)
            if connect_msg.clean_session:
                # Force server to remove the client, regardless of its
                # previous clean_session configuration.
                # NOTE(review): remove_client() does not close the previous
                # client's connection; MQTT 3.1.1 [MQTT-3.1.4-2] requires the
                # server to disconnect an existing client with the same
                # ClientId -- confirm this is handled elsewhere.
                self.remove_client(client)
                client = None
        return client

    def get_or_create_client(self, connection, msg, authorization):
        """
        Resolves the MQTTClient for a CONNECT request. A known session is
        reused and reconfigured; otherwise a new client is created. The last
        will carried by `msg` is then applied to the client.

        :param MQTTConnection connection: Connection the CONNECT arrived on.
        :param Connect msg: The CONNECT message.
        :param Authorization authorization: Result of :meth:`authenticate`.
        :raises ConnectError: If `authorization` does not allow connecting.
        :return: The resolved MQTTClient. A newly created client is not
            registered here (see :meth:`add_client`).
        """
        if not authorization.is_connection_allowed():
            raise ConnectError("Authentication failed uid:%s user:%s"
                               % (msg.client_uid, msg.username))
        client = self.get_known_client(msg)
        if client is None:
            client = self.create_client(connection, msg, authorization)
        else:
            self.update_client(connection, msg, authorization, client)
        # NOTE(review): configure_last_will is not defined in this class as
        # shown; confirm it is provided elsewhere (e.g. a subclass/mixin).
        self.configure_last_will(client, msg)
        return client

    def create_client(self, connection, msg, authorization):
        """
        Builds a new MQTTClient for the uid in `msg`, bound to `connection`
        and backed by that uid's storage from the server persistence. The
        client is neither started nor registered here.

        :param MQTTConnection connection: Connection of the new client.
        :param Connect msg: The CONNECT message describing the client.
        :param Authorization authorization: The client's authorization.
        :return: The new MQTTClient.
        """
        client_persistence = self.persistence.get_for_client(msg.client_uid)
        client = MQTTClient(
            server=self,
            connection=connection,
            authorization=authorization,
            uid=msg.client_uid,
            clean_session=msg.clean_session,
            keep_alive=msg.keep_alive,
            persistence=client_persistence,
        )
        access_log.info("[uid: %s] new session created", client.uid)
        return client

    def recreate_client(self, client_uid):
        """
        Builds an offline MQTTClient (no connection, ``clean_session=False``)
        for `client_uid`, backed by that uid's stored persistence.

        :param str client_uid: The client uid.
        :return: The recreated MQTTClient (not registered).
        """
        return MQTTClient(
            server=self,
            connection=None,
            uid=client_uid,
            clean_session=False,
            persistence=self.persistence.get_for_client(client_uid)
        )

    def update_client(self, connection, msg, authorization, client):
        """
        Applies the settings of a new CONNECT (clean_session, keep_alive),
        the new connection and the new authorization to a known client.

        :param MQTTConnection connection: The new connection.
        :param Connect msg: The CONNECT message.
        :param Authorization authorization: The new authorization.
        :param MQTTClient client: The known client being reconnected.
        """
        client.update_configuration(
            clean_session=msg.clean_session,
            keep_alive=msg.keep_alive
        )
        client.update_connection(connection)
        client.update_authorization(authorization)
        access_log.info("[uid: %s] Reconfigured client upon reconnection.",
                        client.uid)

    @gen.coroutine
    def handle_stream(self, stream, address):
        """
        This coroutine is called by the Tornado loop whenever it receives an
        incoming connection. The server resolves the first message sent,
        checks if it's a CONNECT frame and configures the client accordingly.

        :param IOStream stream: A :class:`tornado.iostream.IOStream` instance;
        :param tuple address: A tuple containing the ip and port of the
          connected client, ie ('127.0.0.1', 12345).
        """
        # Errors raised below are logged, cleaned up and suppressed by the
        # context manager.
        with stream_handle_context(stream) as context:
            connection = MQTTConnection(stream, address)
            msg = yield self.read_connect_message(connection)
            context.client_uid = msg.client_uid
            authorization = yield self.authenticate(msg)
            # The CONNACK is written before the client is resolved, so a
            # refused client still receives its return code.
            yield self.write_connack_message(connection, msg, authorization)
            context.client = client = self.get_or_create_client(
                connection, msg, authorization)
            client.start()
            self.add_client(client)

    @gen.coroutine
    def read_connect_message(self, connection):
        """
        Reads and parses the first frame received on `connection`.

        :param MQTTConnection connection: The freshly accepted connection.
        :raises ConnectError: If the first frame is not a CONNECT message.
        :return: The parsed :class:`broker.messages.Connect` message.
        """
        bytes_ = yield connection.read_message()
        msg = MQTTMessageFactory.make(bytes_)
        if not isinstance(msg, Connect):
            raise ConnectError('The first message is expected to be CONNECT')
        client_logger.debug("[B << C] [uid: %s] %s",
                            msg.client_uid, msg.log_info())
        return msg

    @gen.coroutine
    def authenticate(self, msg):
        """
        Authenticates the CONNECT request through the configured
        authentication object.

        :param Connect msg: The CONNECT message.
        :raises ConnectError: If the client has no id but asks for a
            persistent (non clean) session.
        :return: The :class:`broker.access_control.Authorization` result.
        """
        if not msg.client_uid and not msg.clean_session:
            raise ConnectError('Client must provide an id to connect '
                               'without clean session')
        authorization = yield self.authentication.authenticate(
            msg.client_uid,
            msg.username, msg.passwd)
        assert isinstance(authorization, Authorization)
        if authorization.is_fully_authorized():
            client_logger.debug('[uid: %s] user:%s fully authorized',
                                msg.client_uid, msg.username)
        return authorization

    @gen.coroutine
    def write_connack_message(self, connection, msg, authorization):
        """
        Sends the CONNACK reply: return code 0x04 (MQTT 3.1.1: bad user name
        or password) when `authorization` does not allow the connection,
        otherwise 0x00 (accepted) with the session-present flag computed by
        :meth:`is_session_present`.

        :param MQTTConnection connection: Connection to reply on.
        :param Connect msg: The CONNECT message being answered.
        :param Authorization authorization: Result of :meth:`authenticate`.
        """
        if not authorization.is_connection_allowed():
            ack = Connack.from_return_code(0x04)
        else:
            sp = self.is_session_present(msg)
            ack = Connack.from_return_code(0x00, session_present=sp)
        client_logger.debug("[B >> C] [uid: %s] %s",
                            msg.client_uid, ack.log_info())
        yield connection.write_message(ack)

    def is_session_present(self, msg):
        """
        Returns True when the client asked for a persistent session and its
        uid is already registered on the server.

        :param Connect msg: The CONNECT message.
        """
        return not msg.clean_session and msg.client_uid in self.clients

    def add_client(self, client):
        """
        Register a client to the Broker, replacing any client registered
        under the same uid.

        :param MQTTClient client: A :class:`broker.client.MQTTClient` instance.
        """
        assert isinstance(client, MQTTClient)
        self.clients[client.uid] = client

    def remove_client(self, client):
        """
        Removes a client from the known clients list and deletes its
        persisted data. It's safe to call this method without checking if
        the client is already known.

        If a *different* client object is currently registered under the
        same uid (for instance after a clean-session reconnection replaced
        `client`), nothing is done, so cleaning up a stale client never
        removes the live session.

        :param MQTTClient client: A :class:`broker.client.MQTTClient` instance;

        .. caution::
            It won't force client disconnection during the process, which can
            result in a lingering client in the Tornado loop.
        """
        assert isinstance(client, MQTTClient)
        registered = self.clients.get(client.uid)
        if registered is not None and registered is not client:
            return
        self.persistence.remove_client_data(client.uid)
        self.clients.pop(client.uid, None)
        access_log.info("[uid: %s] session cleaned", client.uid)

    def dispatch_message(self, client, msg, cache=None):
        """
        Dispatches a message to a client based on its subscriptions. It is
        safe to call this method without checking if the client has matching
        subscriptions.

        :param MQTTClient client: The client which will possibly receive the
          message;
        :param Publish msg: The message to be delivered.
        :param dict cache: A dict mapping QoS to the copy of `msg` prepared
          for that QoS, so each QoS copy is made at most once when the same
          dict is reused across clients. Defaults to an empty dictionary if
          None.
        """
        assert isinstance(msg, Publish)
        assert isinstance(client, MQTTClient)
        assert client.uid in self.clients
        cache = cache if cache is not None else {}
        for qos in client.get_list_of_delivery_qos(msg):
            # If the client is not connected, drop QoS 0 messages
            if client.is_connected() or \
                    qos > MQTTConstants.AT_MOST_ONCE:
                if qos not in cache:
                    msg_copy = msg.copy()
                    msg_copy.qos = qos
                    cache[qos] = msg_copy
                client.publish(cache[qos])

    def broadcast_message(self, msg):
        """
        Broadcasts a message to all clients with matching subscriptions,
        respecting the subscription QoS. One cache is shared by all clients
        so each per-QoS copy of `msg` is created only once.

        :param Publish msg: A :class:`broker.messages.Publish` instance.
        """
        assert isinstance(msg, Publish)
        cache = {}
        for client in self.clients.values():
            self.dispatch_message(client, msg, cache)

    def disconnect_client(self, client):
        """
        Disconnects a MQTT client. Can be safely called without checking if
        the client is connected.

        :param MQTTClient client: The MQTTClient to be disconnected.
        """
        assert isinstance(client, MQTTClient)
        client.disconnect()

    def disconnect_all_clients(self):
        """ Disconnect all known clients. """
        # The tuple() is needed because the dictionary could change during
        # the iteration
        for client in tuple(self.clients.values()):
            self.disconnect_client(client)

    def handle_incoming_publish(self, msg):
        """
        Handles an incoming publish. This method is normally called by the
        clients as a mechanism of notifying the server that there is a new
        message to be processed. The processing itself consists of retaining
        the message according to the `msg.retain` flag and broadcasting it
        to the subscribers.

        Note that `msg.retain` is set to False on the passed object.

        :param Publish msg: The Publish message to be processed.
        """
        if msg.retain is True:
            self._retained_messages.save(msg)
        # Broadcasted messages must always be delivered with the retain flag
        # set to false. The flag should only be used when the message is sent
        # cold.
        msg.retain = False
        self.broadcast_message(msg)

    def enqueue_retained_message(self, client, subscription_mask):
        """
        Enqueues all retained messages matching the `subscription_mask` to be
        sent to the `client`.

        :param MQTTClient client: A known MQTTClient.
        :param str subscription_mask: The subscription mask to match the
          messages against.
        """
        assert isinstance(client, MQTTClient)
        for _topic, raw_message in self._retained_messages.items():
            if raw_message is not None:
                msg_obj = Publish.from_bytes(raw_message)
                qos = client.get_matching_qos(msg_obj, subscription_mask)
                if qos is not None:
                    # creates a copy of the object to avoid reference errors
                    msg_copy = msg_obj.copy()
                    msg_copy.qos = qos
                    client.publish(msg_copy)
class stream_handle_context():
    """
    Context manager guarding the handling of one incoming stream.

    The code inside the ``with`` block fills in ``client_uid`` (used in log
    lines) and ``client`` as they become known. When the block raises, the
    error is logged according to its type, the bound client is disconnected
    (or, if none was bound yet, the raw stream is closed) and the exception
    is suppressed.

    :param stream: The stream being handled; must provide ``close()``.
    """

    def __init__(self, stream):
        self.stream = stream
        self.client = None
        self.client_uid = '?'

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            return False

        # issubclass() instead of ==, so subclasses of the expected
        # exceptions also take the matching branch.
        if issubclass(exc_type, toro.Timeout):
            access_log.debug("[uid: %s] connection timeout",
                             self.client_uid)
        elif issubclass(exc_type, StreamClosedError):
            access_log.warning('[uid: %s] stream closed unexpectedly',
                               self.client_uid)
        elif issubclass(exc_type, ConnectError):
            self.stream.close()
            # Python 3 exceptions only have a ``message`` attribute if the
            # class defines one; fall back to the exception itself.
            access_log.info('[uid: %s] connection refused: %s',
                            self.client_uid,
                            getattr(exc_val, 'message', exc_val))
        elif issubclass(exc_type, Exception):
            # Any other error is unexpected: log it with its traceback rather
            # than swallowing it silently.
            access_log.error('[uid: %s] error handling stream',
                             self.client_uid,
                             exc_info=(exc_type, exc_val, exc_tb))

        if self.client is not None:
            self.client.disconnect()
        else:
            # No client owns the stream yet; close it so a failed handshake
            # (timeout, unexpected error) does not leave the socket open.
            # Closing an already closed IOStream is a no-op.
            self.stream.close()
        return True  # suppress the raised exception
class RetainedMessages:
    """
    Retained PUBLISH messages kept by the broker, keyed by topic.

    Wraps the mapping handed over at construction time and stores the
    message's ``raw_data`` under its topic. A publish with an empty payload
    removes the entry for its topic instead of storing it.

    :param retained_messages: Mutable mapping of topic -> raw message data.
    """

    def __init__(self, retained_messages):
        self._messages = retained_messages

    def save(self, msg):
        """
        Stores `msg` as the retained message of its topic, or drops the
        topic's retained message when `msg` has an empty payload.

        :param Publish msg: The retained publish to record.
        """
        assert isinstance(msg, Publish)
        payload_is_empty = len(msg.payload) == 0
        if not payload_is_empty:
            self._messages[msg.topic] = msg.raw_data
            return
        # Empty payload: clear whatever was retained for this topic, if any.
        if msg.topic in self._messages:
            del self._messages[msg.topic]

    def items(self):
        """Returns the ``items()`` view of the underlying mapping."""
        return self._messages.items()