mirror of https://github.com/skygpu/skynet.git
Add reconnect mechanic to dgpu bus
parent
10e77655c6
commit
585d304f86
238
skynet/dgpu.py
238
skynet/dgpu.py
|
@ -12,7 +12,7 @@ import traceback
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import ExitStack
|
||||||
|
|
||||||
import pynng
|
import pynng
|
||||||
import torch
|
import torch
|
||||||
|
@ -61,6 +61,51 @@ class DGPUComputeError(BaseException):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class ReconnectingBus:
|
||||||
|
|
||||||
|
def __init__(self, address: str, tls_config: Optional[TLSConfig]):
|
||||||
|
self.address = address
|
||||||
|
self.tls_config = tls_config
|
||||||
|
|
||||||
|
self._stack = ExitStack()
|
||||||
|
self._sock = None
|
||||||
|
self._closed = True
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
self._sock = self._stack.enter_context(
|
||||||
|
pynng.Bus0(recv_max_size=0))
|
||||||
|
self._sock.tls_config = self.tls_config
|
||||||
|
self._sock.dial(self.address)
|
||||||
|
self._closed = False
|
||||||
|
|
||||||
|
async def arecv(self):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return await self._sock.arecv()
|
||||||
|
|
||||||
|
except pynng.exceptions.Closed:
|
||||||
|
if self._closed:
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def asend(self, msg):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return await self._sock.asend(msg)
|
||||||
|
|
||||||
|
except pynng.exceptions.Closed:
|
||||||
|
if self._closed:
|
||||||
|
raise
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self._stack.close()
|
||||||
|
self._stack = ExitStack()
|
||||||
|
self._closed = True
|
||||||
|
|
||||||
|
def reconnect(self):
|
||||||
|
self.close()
|
||||||
|
self.connect()
|
||||||
|
|
||||||
|
|
||||||
async def open_dgpu_node(
|
async def open_dgpu_node(
|
||||||
cert_name: str,
|
cert_name: str,
|
||||||
unique_id: str,
|
unique_id: str,
|
||||||
|
@ -141,13 +186,16 @@ async def open_dgpu_node(
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
async with open_skynet_rpc(
|
async with (
|
||||||
unique_id,
|
open_skynet_rpc(
|
||||||
rpc_address=rpc_address,
|
unique_id,
|
||||||
security=security,
|
rpc_address=rpc_address,
|
||||||
cert_name=cert_name,
|
security=security,
|
||||||
key_name=key_name
|
cert_name=cert_name,
|
||||||
) as rpc_call:
|
key_name=key_name
|
||||||
|
) as rpc_call,
|
||||||
|
trio.open_nursery() as n
|
||||||
|
):
|
||||||
|
|
||||||
tls_config = None
|
tls_config = None
|
||||||
if security:
|
if security:
|
||||||
|
@ -183,98 +231,118 @@ async def open_dgpu_node(
|
||||||
ca_string=skynet_cert_data)
|
ca_string=skynet_cert_data)
|
||||||
|
|
||||||
logging.info(f'connecting to {dgpu_address}')
|
logging.info(f'connecting to {dgpu_address}')
|
||||||
with pynng.Bus0(recv_max_size=0) as dgpu_sock:
|
|
||||||
dgpu_sock.tls_config = tls_config
|
|
||||||
dgpu_sock.dial(dgpu_address)
|
|
||||||
|
|
||||||
res = await rpc_call('dgpu_online')
|
dgpu_bus = ReconnectingBus(dgpu_address, tls_config)
|
||||||
assert 'ok' in res.result
|
dgpu_bus.connect()
|
||||||
|
|
||||||
try:
|
last_msg = time.time()
|
||||||
while True:
|
async def connection_refresher(refresh_time: int = 120):
|
||||||
req = DGPUBusMessage()
|
nonlocal last_msg
|
||||||
req.ParseFromString(await dgpu_sock.arecv())
|
while True:
|
||||||
|
now = time.time()
|
||||||
|
last_msg_time_delta = now - last_msg
|
||||||
|
logging.info(f'time since last msg: {last_msg_time_delta}')
|
||||||
|
if last_msg_time_delta > refresh_time:
|
||||||
|
dgpu_bus.reconnect()
|
||||||
|
logging.info('reconnected!')
|
||||||
|
last_msg = now
|
||||||
|
|
||||||
if req.method == 'heartbeat':
|
await trio.sleep(refresh_time)
|
||||||
rep = DGPUBusMessage(
|
|
||||||
rid=req.rid,
|
|
||||||
nid=unique_id,
|
|
||||||
method=req.method
|
|
||||||
)
|
|
||||||
rep.params.update({'time': int(time.time() * 1000)})
|
|
||||||
|
|
||||||
if security:
|
n.start_soon(connection_refresher)
|
||||||
rep.auth.cert = cert_name
|
|
||||||
rep.auth.sig = sign_protobuf_msg(rep, tls_key)
|
|
||||||
|
|
||||||
await dgpu_sock.asend(rep.SerializeToString())
|
res = await rpc_call('dgpu_online')
|
||||||
continue
|
assert 'ok' in res.result
|
||||||
|
|
||||||
if req.nid != unique_id:
|
try:
|
||||||
logging.info(
|
while True:
|
||||||
f'witnessed msg {req.rid}, node involved: {req.nid}')
|
req = DGPUBusMessage()
|
||||||
continue
|
req.ParseFromString(await dgpu_bus.arecv())
|
||||||
|
last_msg = time.time()
|
||||||
|
|
||||||
|
if req.method == 'heartbeat':
|
||||||
|
rep = DGPUBusMessage(
|
||||||
|
rid=req.rid,
|
||||||
|
nid=unique_id,
|
||||||
|
method=req.method
|
||||||
|
)
|
||||||
|
rep.params.update({'time': int(time.time() * 1000)})
|
||||||
|
|
||||||
if security:
|
if security:
|
||||||
verify_protobuf_msg(req, skynet_cert)
|
rep.auth.cert = cert_name
|
||||||
|
rep.auth.sig = sign_protobuf_msg(rep, tls_key)
|
||||||
|
|
||||||
|
await dgpu_bus.asend(rep.SerializeToString())
|
||||||
|
logging.info('heartbeat reply')
|
||||||
|
continue
|
||||||
|
|
||||||
|
if req.nid != unique_id:
|
||||||
|
logging.info(
|
||||||
|
f'witnessed msg {req.rid}, node involved: {req.nid}')
|
||||||
|
continue
|
||||||
|
|
||||||
|
if security:
|
||||||
|
verify_protobuf_msg(req, skynet_cert)
|
||||||
|
|
||||||
|
|
||||||
ack_resp = DGPUBusMessage(
|
ack_resp = DGPUBusMessage(
|
||||||
|
rid=req.rid,
|
||||||
|
nid=req.nid
|
||||||
|
)
|
||||||
|
ack_resp.params.update({'ack': {}})
|
||||||
|
|
||||||
|
if security:
|
||||||
|
ack_resp.auth.cert = cert_name
|
||||||
|
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
|
||||||
|
|
||||||
|
# send ack
|
||||||
|
await dgpu_bus.asend(ack_resp.SerializeToString())
|
||||||
|
|
||||||
|
logging.info(f'sent ack, processing {req.rid}...')
|
||||||
|
|
||||||
|
try:
|
||||||
|
img_req = Text2ImageParameters(**req.params)
|
||||||
|
if not img_req.seed:
|
||||||
|
img_req.seed = random.randint(0, 2 ** 64)
|
||||||
|
|
||||||
|
img = await gpu_compute_one(img_req)
|
||||||
|
img_resp = DGPUBusMessage(
|
||||||
|
rid=req.rid,
|
||||||
|
nid=req.nid,
|
||||||
|
method='binary-reply'
|
||||||
|
)
|
||||||
|
img_resp.params.update({
|
||||||
|
'len': len(img),
|
||||||
|
'meta': img_req.to_dict()
|
||||||
|
})
|
||||||
|
|
||||||
|
except DGPUComputeError as e:
|
||||||
|
traceback.print_exception(type(e), e, e.__traceback__)
|
||||||
|
img_resp = DGPUBusMessage(
|
||||||
rid=req.rid,
|
rid=req.rid,
|
||||||
nid=req.nid
|
nid=req.nid
|
||||||
)
|
)
|
||||||
ack_resp.params.update({'ack': {}})
|
img_resp.params.update({'error': str(e)})
|
||||||
|
|
||||||
if security:
|
|
||||||
ack_resp.auth.cert = cert_name
|
|
||||||
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
|
|
||||||
|
|
||||||
# send ack
|
|
||||||
await dgpu_sock.asend(ack_resp.SerializeToString())
|
|
||||||
|
|
||||||
logging.info(f'sent ack, processing {req.rid}...')
|
|
||||||
|
|
||||||
try:
|
|
||||||
img_req = Text2ImageParameters(**req.params)
|
|
||||||
if not img_req.seed:
|
|
||||||
img_req.seed = random.randint(0, 2 ** 64)
|
|
||||||
|
|
||||||
img = await gpu_compute_one(img_req)
|
|
||||||
img_resp = DGPUBusMessage(
|
|
||||||
rid=req.rid,
|
|
||||||
nid=req.nid,
|
|
||||||
method='binary-reply'
|
|
||||||
)
|
|
||||||
img_resp.params.update({
|
|
||||||
'len': len(img),
|
|
||||||
'meta': img_req.to_dict()
|
|
||||||
})
|
|
||||||
|
|
||||||
except DGPUComputeError as e:
|
|
||||||
traceback.print_exception(type(e), e, e.__traceback__)
|
|
||||||
img_resp = DGPUBusMessage(
|
|
||||||
rid=req.rid,
|
|
||||||
nid=req.nid
|
|
||||||
)
|
|
||||||
img_resp.params.update({'error': str(e)})
|
|
||||||
|
|
||||||
|
|
||||||
if security:
|
if security:
|
||||||
img_resp.auth.cert = cert_name
|
img_resp.auth.cert = cert_name
|
||||||
img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
|
img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
|
||||||
|
|
||||||
# send final image
|
# send final image
|
||||||
logging.info('sending img back...')
|
logging.info('sending img back...')
|
||||||
raw_msg = img_resp.SerializeToString()
|
raw_msg = img_resp.SerializeToString()
|
||||||
await dgpu_sock.asend(raw_msg)
|
await dgpu_bus.asend(raw_msg)
|
||||||
logging.info(f'sent {len(raw_msg)} bytes.')
|
logging.info(f'sent {len(raw_msg)} bytes.')
|
||||||
if img_resp.method == 'binary-reply':
|
if img_resp.method == 'binary-reply':
|
||||||
await dgpu_sock.asend(img)
|
await dgpu_bus.asend(img)
|
||||||
logging.info(f'sent {len(img)} bytes.')
|
logging.info(f'sent {len(img)} bytes.')
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logging.info('interrupt caught, stopping...')
|
logging.info('interrupt caught, stopping...')
|
||||||
|
n.cancel_scope.cancel()
|
||||||
|
dgpu_bus.close()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
res = await rpc_call('dgpu_offline')
|
res = await rpc_call('dgpu_offline')
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
|
|
Loading…
Reference in New Issue