NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. Indentation and blank context lines are
reconstructed from the Python structure and the hunk line counts; re-check
them against skynet/dgpu.py before applying. New-side hunk ranges are
recomputed for the review changes, and the blob id on the `index` line
predates those changes.

diff --git a/skynet/dgpu.py b/skynet/dgpu.py
index 3075ee2..4bb4f10 100644
--- a/skynet/dgpu.py
+++ b/skynet/dgpu.py
@@ -12,7 +12,7 @@ import traceback
 
 from typing import List, Optional
 from pathlib import Path
-from contextlib import AsyncExitStack
+from contextlib import ExitStack
 
 import pynng
 import torch
@@ -61,6 +61,70 @@ class DGPUComputeError(BaseException):
     ...
 
 
+class ReconnectingBus:
+    '''Dialing ``pynng.Bus0`` wrapper whose socket can be swapped out
+    from under tasks that are blocked sending or receiving on it.
+
+    ``connect`` opens and dials a new socket, ``close`` tears the current
+    one down and ``reconnect`` does both. ``arecv``/``asend`` retry on the
+    new socket when the old one is closed by a ``reconnect``; once the
+    bus is closed for good they re-raise ``pynng.exceptions.Closed``.
+
+    '''
+
+    def __init__(self, address: str, tls_config: Optional[TLSConfig]):
+        self.address = address
+        self.tls_config = tls_config
+
+        # exit stack owning the current socket, replaced on every close()
+        self._stack = ExitStack()
+        self._sock = None
+        # True whenever no socket is open (before connect(), after close())
+        self._closed = True
+
+    def connect(self):
+        '''Open a new socket, apply ``tls_config`` and dial ``address``.'''
+        self._sock = self._stack.enter_context(
+            pynng.Bus0(recv_max_size=0))
+        self._sock.tls_config = self.tls_config
+        self._sock.dial(self.address)
+        self._closed = False
+
+    async def arecv(self):
+        '''Receive one message from the current socket.'''
+        while True:
+            try:
+                return await self._sock.arecv()
+
+            except pynng.exceptions.Closed:
+                # closed by close(): propagate. closed by reconnect():
+                # the flag is already reset, retry on the new socket
+                if self._closed:
+                    raise
+
+    async def asend(self, msg):
+        '''Send ``msg`` through the current socket.'''
+        while True:
+            try:
+                return await self._sock.asend(msg)
+
+            except pynng.exceptions.Closed:
+                # same retry rule as arecv()
+                if self._closed:
+                    raise
+
+    def close(self):
+        '''Close the current socket and mark the bus as closed.'''
+        self._stack.close()
+        self._stack = ExitStack()
+        self._closed = True
+
+    def reconnect(self):
+        '''Drop the current socket and dial a fresh one.'''
+        self.close()
+        self.connect()
+
+
 async def open_dgpu_node(
     cert_name: str,
     unique_id: str,
@@ -141,13 +205,16 @@ async def open_dgpu_node(
             torch.cuda.empty_cache()
 
 
-    async with open_skynet_rpc(
-        unique_id,
-        rpc_address=rpc_address,
-        security=security,
-        cert_name=cert_name,
-        key_name=key_name
-    ) as rpc_call:
+    async with (
+        open_skynet_rpc(
+            unique_id,
+            rpc_address=rpc_address,
+            security=security,
+            cert_name=cert_name,
+            key_name=key_name
+        ) as rpc_call,
+        trio.open_nursery() as n
+    ):
 
         tls_config = None
         if security:
@@ -183,96 +250,128 @@ async def open_dgpu_node(
                 ca_string=skynet_cert_data)
 
         logging.info(f'connecting to {dgpu_address}')
-        with pynng.Bus0(recv_max_size=0) as dgpu_sock:
-            dgpu_sock.tls_config = tls_config
-            dgpu_sock.dial(dgpu_address)
 
-            res = await rpc_call('dgpu_online')
-            assert 'ok' in res.result
+        dgpu_bus = ReconnectingBus(dgpu_address, tls_config)
+        dgpu_bus.connect()
 
-            try:
-                while True:
-                    req = DGPUBusMessage()
-                    req.ParseFromString(await dgpu_sock.arecv())
+        # time of the last received bus message, shared with the refresher
+        last_msg = time.time()
+
+        async def connection_refresher(refresh_time: int = 120):
+            '''Every ``refresh_time`` seconds, re-dial the bus if no
+            message arrived for longer than that.
+
+            '''
+            nonlocal last_msg
+            while True:
+                now = time.time()
+                last_msg_time_delta = now - last_msg
+                logging.info(f'time since last msg: {last_msg_time_delta}')
+                if last_msg_time_delta > refresh_time:
+                    dgpu_bus.reconnect()
+                    logging.info('reconnected!')
+                    last_msg = now
 
-                    if req.method == 'heartbeat':
-                        rep = DGPUBusMessage(
-                            rid=req.rid,
-                            nid=unique_id,
-                            method=req.method
-                        )
-                        rep.params.update({'time': int(time.time() * 1000)})
+                await trio.sleep(refresh_time)
 
-                        if security:
-                            rep.auth.cert = cert_name
-                            rep.auth.sig = sign_protobuf_msg(rep, tls_key)
+        n.start_soon(connection_refresher)
 
-                        await dgpu_sock.asend(rep.SerializeToString())
-                        continue
+        res = await rpc_call('dgpu_online')
+        assert 'ok' in res.result
 
-                    if req.nid != unique_id:
-                        logging.info(
-                            f'witnessed msg {req.rid}, node involved: {req.nid}')
-                        continue
+        try:
+            while True:
+                req = DGPUBusMessage()
+                req.ParseFromString(await dgpu_bus.arecv())
+                last_msg = time.time()
+
+                if req.method == 'heartbeat':
+                    rep = DGPUBusMessage(
+                        rid=req.rid,
+                        nid=unique_id,
+                        method=req.method
+                    )
+                    rep.params.update({'time': int(time.time() * 1000)})
 
                     if security:
-                        verify_protobuf_msg(req, skynet_cert)
+                        rep.auth.cert = cert_name
+                        rep.auth.sig = sign_protobuf_msg(rep, tls_key)
+
+                    await dgpu_bus.asend(rep.SerializeToString())
+                    logging.info('heartbeat reply')
+                    continue
+
+                if req.nid != unique_id:
+                    logging.info(
+                        f'witnessed msg {req.rid}, node involved: {req.nid}')
+                    continue
+
+                if security:
+                    verify_protobuf_msg(req, skynet_cert)
 
-                    ack_resp = DGPUBusMessage(
+                ack_resp = DGPUBusMessage(
+                    rid=req.rid,
+                    nid=req.nid
+                )
+                ack_resp.params.update({'ack': {}})
+
+                if security:
+                    ack_resp.auth.cert = cert_name
+                    ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
+
+                # send ack
+                await dgpu_bus.asend(ack_resp.SerializeToString())
+
+                logging.info(f'sent ack, processing {req.rid}...')
+
+                try:
+                    img_req = Text2ImageParameters(**req.params)
+                    if not img_req.seed:
+                        img_req.seed = random.randint(0, 2 ** 64)
+
+                    img = await gpu_compute_one(img_req)
+                    img_resp = DGPUBusMessage(
+                        rid=req.rid,
+                        nid=req.nid,
+                        method='binary-reply'
+                    )
+                    img_resp.params.update({
+                        'len': len(img),
+                        'meta': img_req.to_dict()
+                    })
+
+                except DGPUComputeError as e:
+                    traceback.print_exception(type(e), e, e.__traceback__)
+                    img_resp = DGPUBusMessage(
                         rid=req.rid,
                         nid=req.nid
                     )
-                    ack_resp.params.update({'ack': {}})
-
-                    if security:
-                        ack_resp.auth.cert = cert_name
-                        ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
-
-                    # send ack
-                    await dgpu_sock.asend(ack_resp.SerializeToString())
-
-                    logging.info(f'sent ack, processing {req.rid}...')
-
-                    try:
-                        img_req = Text2ImageParameters(**req.params)
-                        if not img_req.seed:
-                            img_req.seed = random.randint(0, 2 ** 64)
-
-                        img = await gpu_compute_one(img_req)
-                        img_resp = DGPUBusMessage(
-                            rid=req.rid,
-                            nid=req.nid,
-                            method='binary-reply'
-                        )
-                        img_resp.params.update({
-                            'len': len(img),
-                            'meta': img_req.to_dict()
-                        })
-
-                    except DGPUComputeError as e:
-                        traceback.print_exception(type(e), e, e.__traceback__)
-                        img_resp = DGPUBusMessage(
-                            rid=req.rid,
-                            nid=req.nid
-                        )
-                        img_resp.params.update({'error': str(e)})
+                    img_resp.params.update({'error': str(e)})
 
-                    if security:
-                        img_resp.auth.cert = cert_name
-                        img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
+                if security:
+                    img_resp.auth.cert = cert_name
+                    img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
 
-                    # send final image
-                    logging.info('sending img back...')
-                    raw_msg = img_resp.SerializeToString()
-                    await dgpu_sock.asend(raw_msg)
-                    logging.info(f'sent {len(raw_msg)} bytes.')
-                    if img_resp.method == 'binary-reply':
-                        await dgpu_sock.asend(img)
-                        logging.info(f'sent {len(img)} bytes.')
+                # send final image
+                logging.info('sending img back...')
+                raw_msg = img_resp.SerializeToString()
+                await dgpu_bus.asend(raw_msg)
+                logging.info(f'sent {len(raw_msg)} bytes.')
+                if img_resp.method == 'binary-reply':
+                    await dgpu_bus.asend(img)
+                    logging.info(f'sent {len(img)} bytes.')
 
-            except KeyboardInterrupt:
-                logging.info('interrupt caught, stopping...')
+        except KeyboardInterrupt:
+            logging.info('interrupt caught, stopping...')
 
-            finally:
-                res = await rpc_call('dgpu_offline')
-                assert 'ok' in res.result
+        finally:
+            try:
+                # announce offline *before* stopping the nursery: this code
+                # runs inside the nursery's cancel scope, so cancelling it
+                # first makes the rpc await below raise trio.Cancelled
+                res = await rpc_call('dgpu_offline')
+                assert 'ok' in res.result
+
+            finally:
+                # on every exit path (not just ctrl-c): stop the
+                # refresher task and release the bus socket
+                n.cancel_scope.cancel()
+                dgpu_bus.close()