mirror of https://github.com/skygpu/skynet.git
Add reconnect mechanic to dgpu bus
parent
10e77655c6
commit
585d304f86
|
@ -12,7 +12,7 @@ import traceback
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import ExitStack
|
||||||
|
|
||||||
import pynng
|
import pynng
|
||||||
import torch
|
import torch
|
||||||
|
@ -61,6 +61,51 @@ class DGPUComputeError(BaseException):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class ReconnectingBus:
    '''A ``pynng.Bus0`` socket that can be closed and re-dialed in place.

    The wrapper keeps the target ``address`` and ``tls_config`` so that
    ``reconnect()`` can drop the current socket and dial a fresh one
    while callers keep using the same ``arecv()`` / ``asend()`` methods.

    ``_closed`` is True whenever no dialed socket is installed: before
    the first ``connect()``, after ``close()``, and after a ``connect()``
    that failed.
    '''

    def __init__(self, address: str, tls_config: Optional['TLSConfig']):
        '''Store connection parameters; no socket is opened here.

        :param address: address handed to ``dial()`` on every (re)connect.
        :param tls_config: value assigned to the socket's ``tls_config``
            before dialing; may be ``None``.
        '''
        self.address = address
        self.tls_config = tls_config

        # owns the currently installed socket; swapped on each connect
        self._stack = ExitStack()
        self._sock = None
        self._closed = True

    def connect(self):
        '''Open a new bus socket and dial ``self.address``.

        Any socket that is still installed is closed first, so calling
        ``connect()`` twice does not leave an orphaned, still-dialed
        socket behind.

        If creating, configuring or dialing the socket raises, the
        half-built socket is closed again and the exception propagates;
        the bus then stays in the closed state.
        '''
        if not self._closed:
            self.close()

        # Build the socket on a temporary stack and only adopt it once the
        # dial succeeded; on failure the ``with`` block closes it for us.
        with ExitStack() as pending:
            sock = pending.enter_context(pynng.Bus0(recv_max_size=0))
            sock.tls_config = self.tls_config
            sock.dial(self.address)
            self._stack = pending.pop_all()

        self._sock = sock
        self._closed = False

    async def arecv(self):
        '''Receive one message from the current socket.

        A ``pynng.exceptions.Closed`` error is re-raised only when the
        bus itself is closed.  Otherwise the socket we were waiting on
        was replaced (see ``reconnect()``), so the receive is retried on
        whatever socket is installed now.
        '''
        while True:
            try:
                return await self._sock.arecv()

            except pynng.exceptions.Closed:
                if self._closed:
                    raise

    async def asend(self, msg):
        '''Send ``msg`` on the current socket.

        Uses the same retry rule as ``arecv()``: ``Closed`` propagates
        only if the bus is closed, otherwise the send is retried on the
        socket installed by the latest ``connect()``.
        '''
        while True:
            try:
                return await self._sock.asend(msg)

            except pynng.exceptions.Closed:
                if self._closed:
                    raise

    def close(self):
        '''Close the installed socket (if any) and mark the bus closed.'''
        self._stack.close()
        # a closed ExitStack is replaced so the next connect() starts clean
        self._stack = ExitStack()
        self._closed = True

    def reconnect(self):
        '''Close the current socket and dial a fresh one.'''
        self.close()
        self.connect()
|
||||||
|
|
||||||
|
|
||||||
async def open_dgpu_node(
|
async def open_dgpu_node(
|
||||||
cert_name: str,
|
cert_name: str,
|
||||||
unique_id: str,
|
unique_id: str,
|
||||||
|
@ -141,13 +186,16 @@ async def open_dgpu_node(
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
async with open_skynet_rpc(
|
async with (
|
||||||
|
open_skynet_rpc(
|
||||||
unique_id,
|
unique_id,
|
||||||
rpc_address=rpc_address,
|
rpc_address=rpc_address,
|
||||||
security=security,
|
security=security,
|
||||||
cert_name=cert_name,
|
cert_name=cert_name,
|
||||||
key_name=key_name
|
key_name=key_name
|
||||||
) as rpc_call:
|
) as rpc_call,
|
||||||
|
trio.open_nursery() as n
|
||||||
|
):
|
||||||
|
|
||||||
tls_config = None
|
tls_config = None
|
||||||
if security:
|
if security:
|
||||||
|
@ -183,9 +231,25 @@ async def open_dgpu_node(
|
||||||
ca_string=skynet_cert_data)
|
ca_string=skynet_cert_data)
|
||||||
|
|
||||||
logging.info(f'connecting to {dgpu_address}')
|
logging.info(f'connecting to {dgpu_address}')
|
||||||
with pynng.Bus0(recv_max_size=0) as dgpu_sock:
|
|
||||||
dgpu_sock.tls_config = tls_config
|
dgpu_bus = ReconnectingBus(dgpu_address, tls_config)
|
||||||
dgpu_sock.dial(dgpu_address)
|
dgpu_bus.connect()
|
||||||
|
|
||||||
|
last_msg = time.time()
|
||||||
|
async def connection_refresher(refresh_time: int = 120):
|
||||||
|
nonlocal last_msg
|
||||||
|
while True:
|
||||||
|
now = time.time()
|
||||||
|
last_msg_time_delta = now - last_msg
|
||||||
|
logging.info(f'time since last msg: {last_msg_time_delta}')
|
||||||
|
if last_msg_time_delta > refresh_time:
|
||||||
|
dgpu_bus.reconnect()
|
||||||
|
logging.info('reconnected!')
|
||||||
|
last_msg = now
|
||||||
|
|
||||||
|
await trio.sleep(refresh_time)
|
||||||
|
|
||||||
|
n.start_soon(connection_refresher)
|
||||||
|
|
||||||
res = await rpc_call('dgpu_online')
|
res = await rpc_call('dgpu_online')
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
|
@ -193,7 +257,8 @@ async def open_dgpu_node(
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
req = DGPUBusMessage()
|
req = DGPUBusMessage()
|
||||||
req.ParseFromString(await dgpu_sock.arecv())
|
req.ParseFromString(await dgpu_bus.arecv())
|
||||||
|
last_msg = time.time()
|
||||||
|
|
||||||
if req.method == 'heartbeat':
|
if req.method == 'heartbeat':
|
||||||
rep = DGPUBusMessage(
|
rep = DGPUBusMessage(
|
||||||
|
@ -207,7 +272,8 @@ async def open_dgpu_node(
|
||||||
rep.auth.cert = cert_name
|
rep.auth.cert = cert_name
|
||||||
rep.auth.sig = sign_protobuf_msg(rep, tls_key)
|
rep.auth.sig = sign_protobuf_msg(rep, tls_key)
|
||||||
|
|
||||||
await dgpu_sock.asend(rep.SerializeToString())
|
await dgpu_bus.asend(rep.SerializeToString())
|
||||||
|
logging.info('heartbeat reply')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if req.nid != unique_id:
|
if req.nid != unique_id:
|
||||||
|
@ -230,7 +296,7 @@ async def open_dgpu_node(
|
||||||
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
|
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
|
||||||
|
|
||||||
# send ack
|
# send ack
|
||||||
await dgpu_sock.asend(ack_resp.SerializeToString())
|
await dgpu_bus.asend(ack_resp.SerializeToString())
|
||||||
|
|
||||||
logging.info(f'sent ack, processing {req.rid}...')
|
logging.info(f'sent ack, processing {req.rid}...')
|
||||||
|
|
||||||
|
@ -266,14 +332,16 @@ async def open_dgpu_node(
|
||||||
# send final image
|
# send final image
|
||||||
logging.info('sending img back...')
|
logging.info('sending img back...')
|
||||||
raw_msg = img_resp.SerializeToString()
|
raw_msg = img_resp.SerializeToString()
|
||||||
await dgpu_sock.asend(raw_msg)
|
await dgpu_bus.asend(raw_msg)
|
||||||
logging.info(f'sent {len(raw_msg)} bytes.')
|
logging.info(f'sent {len(raw_msg)} bytes.')
|
||||||
if img_resp.method == 'binary-reply':
|
if img_resp.method == 'binary-reply':
|
||||||
await dgpu_sock.asend(img)
|
await dgpu_bus.asend(img)
|
||||||
logging.info(f'sent {len(img)} bytes.')
|
logging.info(f'sent {len(img)} bytes.')
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logging.info('interrupt caught, stopping...')
|
logging.info('interrupt caught, stopping...')
|
||||||
|
n.cancel_scope.cancel()
|
||||||
|
dgpu_bus.close()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
res = await rpc_call('dgpu_offline')
|
res = await rpc_call('dgpu_offline')
|
||||||
|
|
Loading…
Reference in New Issue