Add reconnect mechanic to dgpu bus

pull/4/head
Guillermo Rodriguez 2023-01-08 07:16:43 -03:00
parent 10e77655c6
commit 585d304f86
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
1 changed files with 153 additions and 85 deletions

View File

@ -12,7 +12,7 @@ import traceback
from typing import List, Optional
from pathlib import Path
from contextlib import AsyncExitStack
from contextlib import ExitStack
import pynng
import torch
@ -61,6 +61,51 @@ class DGPUComputeError(BaseException):
...
class ReconnectingBus:
def __init__(self, address: str, tls_config: Optional[TLSConfig]):
self.address = address
self.tls_config = tls_config
self._stack = ExitStack()
self._sock = None
self._closed = True
def connect(self):
self._sock = self._stack.enter_context(
pynng.Bus0(recv_max_size=0))
self._sock.tls_config = self.tls_config
self._sock.dial(self.address)
self._closed = False
async def arecv(self):
while True:
try:
return await self._sock.arecv()
except pynng.exceptions.Closed:
if self._closed:
raise
async def asend(self, msg):
while True:
try:
return await self._sock.asend(msg)
except pynng.exceptions.Closed:
if self._closed:
raise
def close(self):
self._stack.close()
self._stack = ExitStack()
self._closed = True
def reconnect(self):
self.close()
self.connect()
async def open_dgpu_node(
cert_name: str,
unique_id: str,
@ -141,13 +186,16 @@ async def open_dgpu_node(
torch.cuda.empty_cache()
async with open_skynet_rpc(
async with (
open_skynet_rpc(
unique_id,
rpc_address=rpc_address,
security=security,
cert_name=cert_name,
key_name=key_name
) as rpc_call:
) as rpc_call,
trio.open_nursery() as n
):
tls_config = None
if security:
@ -183,9 +231,25 @@ async def open_dgpu_node(
ca_string=skynet_cert_data)
logging.info(f'connecting to {dgpu_address}')
with pynng.Bus0(recv_max_size=0) as dgpu_sock:
dgpu_sock.tls_config = tls_config
dgpu_sock.dial(dgpu_address)
dgpu_bus = ReconnectingBus(dgpu_address, tls_config)
dgpu_bus.connect()
last_msg = time.time()
async def connection_refresher(refresh_time: int = 120):
nonlocal last_msg
while True:
now = time.time()
last_msg_time_delta = now - last_msg
logging.info(f'time since last msg: {last_msg_time_delta}')
if last_msg_time_delta > refresh_time:
dgpu_bus.reconnect()
logging.info('reconnected!')
last_msg = now
await trio.sleep(refresh_time)
n.start_soon(connection_refresher)
res = await rpc_call('dgpu_online')
assert 'ok' in res.result
@ -193,7 +257,8 @@ async def open_dgpu_node(
try:
while True:
req = DGPUBusMessage()
req.ParseFromString(await dgpu_sock.arecv())
req.ParseFromString(await dgpu_bus.arecv())
last_msg = time.time()
if req.method == 'heartbeat':
rep = DGPUBusMessage(
@ -207,7 +272,8 @@ async def open_dgpu_node(
rep.auth.cert = cert_name
rep.auth.sig = sign_protobuf_msg(rep, tls_key)
await dgpu_sock.asend(rep.SerializeToString())
await dgpu_bus.asend(rep.SerializeToString())
logging.info('heartbeat reply')
continue
if req.nid != unique_id:
@ -230,7 +296,7 @@ async def open_dgpu_node(
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
# send ack
await dgpu_sock.asend(ack_resp.SerializeToString())
await dgpu_bus.asend(ack_resp.SerializeToString())
logging.info(f'sent ack, processing {req.rid}...')
@ -266,14 +332,16 @@ async def open_dgpu_node(
# send final image
logging.info('sending img back...')
raw_msg = img_resp.SerializeToString()
await dgpu_sock.asend(raw_msg)
await dgpu_bus.asend(raw_msg)
logging.info(f'sent {len(raw_msg)} bytes.')
if img_resp.method == 'binary-reply':
await dgpu_sock.asend(img)
await dgpu_bus.asend(img)
logging.info(f'sent {len(img)} bytes.')
except KeyboardInterrupt:
logging.info('interrupt caught, stopping...')
n.cancel_scope.cancel()
dgpu_bus.close()
finally:
res = await rpc_call('dgpu_offline')