mirror of https://github.com/skygpu/skynet.git
Add reconnect mechanic to dgpu bus
parent
10e77655c6
commit
585d304f86
|
@ -12,7 +12,7 @@ import traceback
|
|||
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
from contextlib import AsyncExitStack
|
||||
from contextlib import ExitStack
|
||||
|
||||
import pynng
|
||||
import torch
|
||||
|
@ -61,6 +61,51 @@ class DGPUComputeError(BaseException):
|
|||
...
|
||||
|
||||
|
||||
class ReconnectingBus:
    """Wrapper around a ``pynng.Bus0`` socket that can be re-dialed in place.

    The wrapper owns at most one live socket at a time.  ``reconnect()``
    closes the current socket and dials a fresh one against the same
    address; coroutines blocked in ``arecv()`` / ``asend()`` on the old
    socket then see ``pynng.exceptions.Closed`` and transparently retry on
    the replacement socket.

    Attributes:
        address: URL passed to ``dial()`` on every (re)connect.
        tls_config: value assigned to the socket's ``tls_config`` before
            dialing; may be ``None``.
    """

    def __init__(self, address: str, tls_config: Optional['TLSConfig']):
        self.address = address
        self.tls_config = tls_config

        # the stack owns the socket's context so close() releases it
        self._stack = ExitStack()
        self._sock = None
        # True until connect() succeeds, and again after close()
        self._closed = True

    def connect(self):
        """Open a new ``Bus0`` socket and dial ``self.address``.

        Any socket left over from a previous ``connect()`` is closed first
        so it cannot leak.  If creating, configuring or dialing the socket
        raises, the new socket is closed again, the bus stays in the closed
        state and the exception propagates.
        """
        self.close()

        stack = ExitStack()
        try:
            sock = stack.enter_context(pynng.Bus0(recv_max_size=0))
            sock.tls_config = self.tls_config
            sock.dial(self.address)

        except BaseException:
            # do not keep a half-configured socket around
            stack.close()
            raise

        self._stack = stack
        self._sock = sock
        self._closed = False

    async def _retry_on_swap(self, method_name: str, *args):
        """Await ``self._sock.<method_name>(*args)``, retrying after a swap.

        ``pynng.exceptions.Closed`` is only treated as recoverable when the
        bus is open *and* ``self._sock`` is no longer the socket the call
        was issued on (i.e. a reconnect replaced it).  Otherwise the error
        is re-raised; retrying the very same closed socket would just spin.
        """
        while True:
            sock = self._sock
            try:
                return await getattr(sock, method_name)(*args)

            except pynng.exceptions.Closed:
                if self._closed or self._sock is sock:
                    raise

    async def arecv(self):
        """Receive one message, surviving a ``reconnect()`` while waiting.

        Raises:
            pynng.exceptions.Closed: if the bus was closed for good.
        """
        return await self._retry_on_swap('arecv')

    async def asend(self, msg):
        """Send ``msg``, surviving a ``reconnect()`` while waiting.

        Raises:
            pynng.exceptions.Closed: if the bus was closed for good.
        """
        return await self._retry_on_swap('asend', msg)

    def close(self):
        """Close the current socket (if any) and mark the bus as closed."""
        self._stack.close()
        self._stack = ExitStack()
        self._closed = True

    def reconnect(self):
        """Drop the current socket and dial a fresh one to the same address."""
        self.close()
        self.connect()
|
||||
|
||||
|
||||
async def open_dgpu_node(
|
||||
cert_name: str,
|
||||
unique_id: str,
|
||||
|
@ -141,13 +186,16 @@ async def open_dgpu_node(
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
async with open_skynet_rpc(
|
||||
async with (
|
||||
open_skynet_rpc(
|
||||
unique_id,
|
||||
rpc_address=rpc_address,
|
||||
security=security,
|
||||
cert_name=cert_name,
|
||||
key_name=key_name
|
||||
) as rpc_call:
|
||||
) as rpc_call,
|
||||
trio.open_nursery() as n
|
||||
):
|
||||
|
||||
tls_config = None
|
||||
if security:
|
||||
|
@ -183,9 +231,25 @@ async def open_dgpu_node(
|
|||
ca_string=skynet_cert_data)
|
||||
|
||||
logging.info(f'connecting to {dgpu_address}')
|
||||
with pynng.Bus0(recv_max_size=0) as dgpu_sock:
|
||||
dgpu_sock.tls_config = tls_config
|
||||
dgpu_sock.dial(dgpu_address)
|
||||
|
||||
dgpu_bus = ReconnectingBus(dgpu_address, tls_config)
|
||||
dgpu_bus.connect()
|
||||
|
||||
last_msg = time.time()
|
||||
async def connection_refresher(refresh_time: int = 120):
|
||||
nonlocal last_msg
|
||||
while True:
|
||||
now = time.time()
|
||||
last_msg_time_delta = now - last_msg
|
||||
logging.info(f'time since last msg: {last_msg_time_delta}')
|
||||
if last_msg_time_delta > refresh_time:
|
||||
dgpu_bus.reconnect()
|
||||
logging.info('reconnected!')
|
||||
last_msg = now
|
||||
|
||||
await trio.sleep(refresh_time)
|
||||
|
||||
n.start_soon(connection_refresher)
|
||||
|
||||
res = await rpc_call('dgpu_online')
|
||||
assert 'ok' in res.result
|
||||
|
@ -193,7 +257,8 @@ async def open_dgpu_node(
|
|||
try:
|
||||
while True:
|
||||
req = DGPUBusMessage()
|
||||
req.ParseFromString(await dgpu_sock.arecv())
|
||||
req.ParseFromString(await dgpu_bus.arecv())
|
||||
last_msg = time.time()
|
||||
|
||||
if req.method == 'heartbeat':
|
||||
rep = DGPUBusMessage(
|
||||
|
@ -207,7 +272,8 @@ async def open_dgpu_node(
|
|||
rep.auth.cert = cert_name
|
||||
rep.auth.sig = sign_protobuf_msg(rep, tls_key)
|
||||
|
||||
await dgpu_sock.asend(rep.SerializeToString())
|
||||
await dgpu_bus.asend(rep.SerializeToString())
|
||||
logging.info('heartbeat reply')
|
||||
continue
|
||||
|
||||
if req.nid != unique_id:
|
||||
|
@ -230,7 +296,7 @@ async def open_dgpu_node(
|
|||
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
|
||||
|
||||
# send ack
|
||||
await dgpu_sock.asend(ack_resp.SerializeToString())
|
||||
await dgpu_bus.asend(ack_resp.SerializeToString())
|
||||
|
||||
logging.info(f'sent ack, processing {req.rid}...')
|
||||
|
||||
|
@ -266,14 +332,16 @@ async def open_dgpu_node(
|
|||
# send final image
|
||||
logging.info('sending img back...')
|
||||
raw_msg = img_resp.SerializeToString()
|
||||
await dgpu_sock.asend(raw_msg)
|
||||
await dgpu_bus.asend(raw_msg)
|
||||
logging.info(f'sent {len(raw_msg)} bytes.')
|
||||
if img_resp.method == 'binary-reply':
|
||||
await dgpu_sock.asend(img)
|
||||
await dgpu_bus.asend(img)
|
||||
logging.info(f'sent {len(img)} bytes.')
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logging.info('interrupt caught, stopping...')
|
||||
n.cancel_scope.cancel()
|
||||
dgpu_bus.close()
|
||||
|
||||
finally:
|
||||
res = await rpc_call('dgpu_offline')
|
||||
|
|
Loading…
Reference in New Issue