mirror of https://github.com/skygpu/skynet.git
Add reconnect mechanic to dgpu bus
parent
10e77655c6
commit
585d304f86
238
skynet/dgpu.py
238
skynet/dgpu.py
|
@ -12,7 +12,7 @@ import traceback
|
|||
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
from contextlib import AsyncExitStack
|
||||
from contextlib import ExitStack
|
||||
|
||||
import pynng
|
||||
import torch
|
||||
|
@ -61,6 +61,51 @@ class DGPUComputeError(BaseException):
|
|||
...
|
||||
|
||||
|
||||
class ReconnectingBus:
|
||||
|
||||
def __init__(self, address: str, tls_config: Optional[TLSConfig]):
|
||||
self.address = address
|
||||
self.tls_config = tls_config
|
||||
|
||||
self._stack = ExitStack()
|
||||
self._sock = None
|
||||
self._closed = True
|
||||
|
||||
def connect(self):
|
||||
self._sock = self._stack.enter_context(
|
||||
pynng.Bus0(recv_max_size=0))
|
||||
self._sock.tls_config = self.tls_config
|
||||
self._sock.dial(self.address)
|
||||
self._closed = False
|
||||
|
||||
async def arecv(self):
|
||||
while True:
|
||||
try:
|
||||
return await self._sock.arecv()
|
||||
|
||||
except pynng.exceptions.Closed:
|
||||
if self._closed:
|
||||
raise
|
||||
|
||||
async def asend(self, msg):
|
||||
while True:
|
||||
try:
|
||||
return await self._sock.asend(msg)
|
||||
|
||||
except pynng.exceptions.Closed:
|
||||
if self._closed:
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
self._stack.close()
|
||||
self._stack = ExitStack()
|
||||
self._closed = True
|
||||
|
||||
def reconnect(self):
|
||||
self.close()
|
||||
self.connect()
|
||||
|
||||
|
||||
async def open_dgpu_node(
|
||||
cert_name: str,
|
||||
unique_id: str,
|
||||
|
@ -141,13 +186,16 @@ async def open_dgpu_node(
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
async with open_skynet_rpc(
|
||||
unique_id,
|
||||
rpc_address=rpc_address,
|
||||
security=security,
|
||||
cert_name=cert_name,
|
||||
key_name=key_name
|
||||
) as rpc_call:
|
||||
async with (
|
||||
open_skynet_rpc(
|
||||
unique_id,
|
||||
rpc_address=rpc_address,
|
||||
security=security,
|
||||
cert_name=cert_name,
|
||||
key_name=key_name
|
||||
) as rpc_call,
|
||||
trio.open_nursery() as n
|
||||
):
|
||||
|
||||
tls_config = None
|
||||
if security:
|
||||
|
@ -183,98 +231,118 @@ async def open_dgpu_node(
|
|||
ca_string=skynet_cert_data)
|
||||
|
||||
logging.info(f'connecting to {dgpu_address}')
|
||||
with pynng.Bus0(recv_max_size=0) as dgpu_sock:
|
||||
dgpu_sock.tls_config = tls_config
|
||||
dgpu_sock.dial(dgpu_address)
|
||||
|
||||
res = await rpc_call('dgpu_online')
|
||||
assert 'ok' in res.result
|
||||
dgpu_bus = ReconnectingBus(dgpu_address, tls_config)
|
||||
dgpu_bus.connect()
|
||||
|
||||
try:
|
||||
while True:
|
||||
req = DGPUBusMessage()
|
||||
req.ParseFromString(await dgpu_sock.arecv())
|
||||
last_msg = time.time()
|
||||
async def connection_refresher(refresh_time: int = 120):
|
||||
nonlocal last_msg
|
||||
while True:
|
||||
now = time.time()
|
||||
last_msg_time_delta = now - last_msg
|
||||
logging.info(f'time since last msg: {last_msg_time_delta}')
|
||||
if last_msg_time_delta > refresh_time:
|
||||
dgpu_bus.reconnect()
|
||||
logging.info('reconnected!')
|
||||
last_msg = now
|
||||
|
||||
if req.method == 'heartbeat':
|
||||
rep = DGPUBusMessage(
|
||||
rid=req.rid,
|
||||
nid=unique_id,
|
||||
method=req.method
|
||||
)
|
||||
rep.params.update({'time': int(time.time() * 1000)})
|
||||
await trio.sleep(refresh_time)
|
||||
|
||||
if security:
|
||||
rep.auth.cert = cert_name
|
||||
rep.auth.sig = sign_protobuf_msg(rep, tls_key)
|
||||
n.start_soon(connection_refresher)
|
||||
|
||||
await dgpu_sock.asend(rep.SerializeToString())
|
||||
continue
|
||||
res = await rpc_call('dgpu_online')
|
||||
assert 'ok' in res.result
|
||||
|
||||
if req.nid != unique_id:
|
||||
logging.info(
|
||||
f'witnessed msg {req.rid}, node involved: {req.nid}')
|
||||
continue
|
||||
try:
|
||||
while True:
|
||||
req = DGPUBusMessage()
|
||||
req.ParseFromString(await dgpu_bus.arecv())
|
||||
last_msg = time.time()
|
||||
|
||||
if req.method == 'heartbeat':
|
||||
rep = DGPUBusMessage(
|
||||
rid=req.rid,
|
||||
nid=unique_id,
|
||||
method=req.method
|
||||
)
|
||||
rep.params.update({'time': int(time.time() * 1000)})
|
||||
|
||||
if security:
|
||||
verify_protobuf_msg(req, skynet_cert)
|
||||
rep.auth.cert = cert_name
|
||||
rep.auth.sig = sign_protobuf_msg(rep, tls_key)
|
||||
|
||||
await dgpu_bus.asend(rep.SerializeToString())
|
||||
logging.info('heartbeat reply')
|
||||
continue
|
||||
|
||||
if req.nid != unique_id:
|
||||
logging.info(
|
||||
f'witnessed msg {req.rid}, node involved: {req.nid}')
|
||||
continue
|
||||
|
||||
if security:
|
||||
verify_protobuf_msg(req, skynet_cert)
|
||||
|
||||
|
||||
ack_resp = DGPUBusMessage(
|
||||
ack_resp = DGPUBusMessage(
|
||||
rid=req.rid,
|
||||
nid=req.nid
|
||||
)
|
||||
ack_resp.params.update({'ack': {}})
|
||||
|
||||
if security:
|
||||
ack_resp.auth.cert = cert_name
|
||||
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
|
||||
|
||||
# send ack
|
||||
await dgpu_bus.asend(ack_resp.SerializeToString())
|
||||
|
||||
logging.info(f'sent ack, processing {req.rid}...')
|
||||
|
||||
try:
|
||||
img_req = Text2ImageParameters(**req.params)
|
||||
if not img_req.seed:
|
||||
img_req.seed = random.randint(0, 2 ** 64)
|
||||
|
||||
img = await gpu_compute_one(img_req)
|
||||
img_resp = DGPUBusMessage(
|
||||
rid=req.rid,
|
||||
nid=req.nid,
|
||||
method='binary-reply'
|
||||
)
|
||||
img_resp.params.update({
|
||||
'len': len(img),
|
||||
'meta': img_req.to_dict()
|
||||
})
|
||||
|
||||
except DGPUComputeError as e:
|
||||
traceback.print_exception(type(e), e, e.__traceback__)
|
||||
img_resp = DGPUBusMessage(
|
||||
rid=req.rid,
|
||||
nid=req.nid
|
||||
)
|
||||
ack_resp.params.update({'ack': {}})
|
||||
|
||||
if security:
|
||||
ack_resp.auth.cert = cert_name
|
||||
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
|
||||
|
||||
# send ack
|
||||
await dgpu_sock.asend(ack_resp.SerializeToString())
|
||||
|
||||
logging.info(f'sent ack, processing {req.rid}...')
|
||||
|
||||
try:
|
||||
img_req = Text2ImageParameters(**req.params)
|
||||
if not img_req.seed:
|
||||
img_req.seed = random.randint(0, 2 ** 64)
|
||||
|
||||
img = await gpu_compute_one(img_req)
|
||||
img_resp = DGPUBusMessage(
|
||||
rid=req.rid,
|
||||
nid=req.nid,
|
||||
method='binary-reply'
|
||||
)
|
||||
img_resp.params.update({
|
||||
'len': len(img),
|
||||
'meta': img_req.to_dict()
|
||||
})
|
||||
|
||||
except DGPUComputeError as e:
|
||||
traceback.print_exception(type(e), e, e.__traceback__)
|
||||
img_resp = DGPUBusMessage(
|
||||
rid=req.rid,
|
||||
nid=req.nid
|
||||
)
|
||||
img_resp.params.update({'error': str(e)})
|
||||
img_resp.params.update({'error': str(e)})
|
||||
|
||||
|
||||
if security:
|
||||
img_resp.auth.cert = cert_name
|
||||
img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
|
||||
if security:
|
||||
img_resp.auth.cert = cert_name
|
||||
img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
|
||||
|
||||
# send final image
|
||||
logging.info('sending img back...')
|
||||
raw_msg = img_resp.SerializeToString()
|
||||
await dgpu_sock.asend(raw_msg)
|
||||
logging.info(f'sent {len(raw_msg)} bytes.')
|
||||
if img_resp.method == 'binary-reply':
|
||||
await dgpu_sock.asend(img)
|
||||
logging.info(f'sent {len(img)} bytes.')
|
||||
# send final image
|
||||
logging.info('sending img back...')
|
||||
raw_msg = img_resp.SerializeToString()
|
||||
await dgpu_bus.asend(raw_msg)
|
||||
logging.info(f'sent {len(raw_msg)} bytes.')
|
||||
if img_resp.method == 'binary-reply':
|
||||
await dgpu_bus.asend(img)
|
||||
logging.info(f'sent {len(img)} bytes.')
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logging.info('interrupt caught, stopping...')
|
||||
except KeyboardInterrupt:
|
||||
logging.info('interrupt caught, stopping...')
|
||||
n.cancel_scope.cancel()
|
||||
dgpu_bus.close()
|
||||
|
||||
finally:
|
||||
res = await rpc_call('dgpu_offline')
|
||||
assert 'ok' in res.result
|
||||
finally:
|
||||
res = await rpc_call('dgpu_offline')
|
||||
assert 'ok' in res.result
|
||||
|
|
Loading…
Reference in New Issue