Ensure connections exist when needed

This commit is contained in:
Richard Mitchell
2019-04-18 13:14:08 +01:00
parent 8e97b42616
commit e57935ba7c

View File

@@ -413,6 +413,14 @@ def _call_async(fn, *args):
loop.call_soon(wrapper, fn, *args) loop.call_soon(wrapper, fn, *args)
def async_connection_needed(fn):
def wrapper(self, *args, **kwargs):
await self.async_connect()
return await fn(*args, **kwargs)
return wrapper
class TuyaDevice: class TuyaDevice:
"""Represents a generic Tuya device.""" """Represents a generic Tuya device."""
@@ -442,7 +450,7 @@ class TuyaDevice:
Message.PING_COMMAND: [self._async_pong_received], Message.PING_COMMAND: [self._async_pong_received],
} }
self._dps = {} self._dps = {}
self._disconnected = True self._connected = False
def __repr__(self): def __repr__(self):
return "{}({}, {}, {}, {})".format( return "{}({}, {}, {}, {})".format(
@@ -464,7 +472,7 @@ class TuyaDevice:
sock.connect((self.host, self.port)) sock.connect((self.host, self.port))
except socket.timeout as e: except socket.timeout as e:
raise ConnectionTimeoutException("Connection timed out") from e raise ConnectionTimeoutException("Connection timed out") from e
self._disconnected = False self._connected = True
self.reader, self.writer = await asyncio.open_connection(sock=sock) self.reader, self.writer = await asyncio.open_connection(sock=sock)
asyncio.ensure_future(self._async_handle_message()) asyncio.ensure_future(self._async_handle_message())
@@ -473,7 +481,7 @@ class TuyaDevice:
async def async_disconnect(self): async def async_disconnect(self):
_LOGGER.debug("Disconnected from {}".format(self)) _LOGGER.debug("Disconnected from {}".format(self))
self._disconnected = True self._connected = False
self.last_pong = 0 self.last_pong = 0
if self.writer is not None: if self.writer is not None:
self.writer.close() self.writer.close()
@@ -547,6 +555,7 @@ class TuyaDevice:
asyncio.ensure_future(self._async_handle_message()) asyncio.ensure_future(self._async_handle_message())
@async_connection_needed
async def _async_send(self, message, retries=4): async def _async_send(self, message, retries=4):
_LOGGER.debug("Sending to {}: {}".format(self, message)) _LOGGER.debug("Sending to {}: {}".format(self, message))
try: try: