Add reconnect mechanic to dgpu bus

pull/4/head
Guillermo Rodriguez 2023-01-08 07:16:43 -03:00
parent 10e77655c6
commit 585d304f86
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
1 changed files with 153 additions and 85 deletions

View File

@ -12,7 +12,7 @@ import traceback
from typing import List, Optional from typing import List, Optional
from pathlib import Path from pathlib import Path
from contextlib import AsyncExitStack from contextlib import ExitStack
import pynng import pynng
import torch import torch
@ -61,6 +61,51 @@ class DGPUComputeError(BaseException):
... ...
class ReconnectingBus:
def __init__(self, address: str, tls_config: Optional[TLSConfig]):
self.address = address
self.tls_config = tls_config
self._stack = ExitStack()
self._sock = None
self._closed = True
def connect(self):
self._sock = self._stack.enter_context(
pynng.Bus0(recv_max_size=0))
self._sock.tls_config = self.tls_config
self._sock.dial(self.address)
self._closed = False
async def arecv(self):
while True:
try:
return await self._sock.arecv()
except pynng.exceptions.Closed:
if self._closed:
raise
async def asend(self, msg):
while True:
try:
return await self._sock.asend(msg)
except pynng.exceptions.Closed:
if self._closed:
raise
def close(self):
self._stack.close()
self._stack = ExitStack()
self._closed = True
def reconnect(self):
self.close()
self.connect()
async def open_dgpu_node( async def open_dgpu_node(
cert_name: str, cert_name: str,
unique_id: str, unique_id: str,
@ -141,13 +186,16 @@ async def open_dgpu_node(
torch.cuda.empty_cache() torch.cuda.empty_cache()
async with open_skynet_rpc( async with (
open_skynet_rpc(
unique_id, unique_id,
rpc_address=rpc_address, rpc_address=rpc_address,
security=security, security=security,
cert_name=cert_name, cert_name=cert_name,
key_name=key_name key_name=key_name
) as rpc_call: ) as rpc_call,
trio.open_nursery() as n
):
tls_config = None tls_config = None
if security: if security:
@ -183,9 +231,25 @@ async def open_dgpu_node(
ca_string=skynet_cert_data) ca_string=skynet_cert_data)
logging.info(f'connecting to {dgpu_address}') logging.info(f'connecting to {dgpu_address}')
with pynng.Bus0(recv_max_size=0) as dgpu_sock:
dgpu_sock.tls_config = tls_config dgpu_bus = ReconnectingBus(dgpu_address, tls_config)
dgpu_sock.dial(dgpu_address) dgpu_bus.connect()
last_msg = time.time()
async def connection_refresher(refresh_time: int = 120):
nonlocal last_msg
while True:
now = time.time()
last_msg_time_delta = now - last_msg
logging.info(f'time since last msg: {last_msg_time_delta}')
if last_msg_time_delta > refresh_time:
dgpu_bus.reconnect()
logging.info('reconnected!')
last_msg = now
await trio.sleep(refresh_time)
n.start_soon(connection_refresher)
res = await rpc_call('dgpu_online') res = await rpc_call('dgpu_online')
assert 'ok' in res.result assert 'ok' in res.result
@ -193,7 +257,8 @@ async def open_dgpu_node(
try: try:
while True: while True:
req = DGPUBusMessage() req = DGPUBusMessage()
req.ParseFromString(await dgpu_sock.arecv()) req.ParseFromString(await dgpu_bus.arecv())
last_msg = time.time()
if req.method == 'heartbeat': if req.method == 'heartbeat':
rep = DGPUBusMessage( rep = DGPUBusMessage(
@ -207,7 +272,8 @@ async def open_dgpu_node(
rep.auth.cert = cert_name rep.auth.cert = cert_name
rep.auth.sig = sign_protobuf_msg(rep, tls_key) rep.auth.sig = sign_protobuf_msg(rep, tls_key)
await dgpu_sock.asend(rep.SerializeToString()) await dgpu_bus.asend(rep.SerializeToString())
logging.info('heartbeat reply')
continue continue
if req.nid != unique_id: if req.nid != unique_id:
@ -230,7 +296,7 @@ async def open_dgpu_node(
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key) ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
# send ack # send ack
await dgpu_sock.asend(ack_resp.SerializeToString()) await dgpu_bus.asend(ack_resp.SerializeToString())
logging.info(f'sent ack, processing {req.rid}...') logging.info(f'sent ack, processing {req.rid}...')
@ -266,14 +332,16 @@ async def open_dgpu_node(
# send final image # send final image
logging.info('sending img back...') logging.info('sending img back...')
raw_msg = img_resp.SerializeToString() raw_msg = img_resp.SerializeToString()
await dgpu_sock.asend(raw_msg) await dgpu_bus.asend(raw_msg)
logging.info(f'sent {len(raw_msg)} bytes.') logging.info(f'sent {len(raw_msg)} bytes.')
if img_resp.method == 'binary-reply': if img_resp.method == 'binary-reply':
await dgpu_sock.asend(img) await dgpu_bus.asend(img)
logging.info(f'sent {len(img)} bytes.') logging.info(f'sent {len(img)} bytes.')
except KeyboardInterrupt: except KeyboardInterrupt:
logging.info('interrupt caught, stopping...') logging.info('interrupt caught, stopping...')
n.cancel_scope.cancel()
dgpu_bus.close()
finally: finally:
res = await rpc_call('dgpu_offline') res = await rpc_call('dgpu_offline')