Add reconnect mechanic to dgpu bus

pull/4/head
Guillermo Rodriguez 2023-01-08 07:16:43 -03:00
parent 10e77655c6
commit 585d304f86
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
1 changed files with 153 additions and 85 deletions

View File

@ -12,7 +12,7 @@ import traceback
from typing import List, Optional from typing import List, Optional
from pathlib import Path from pathlib import Path
from contextlib import AsyncExitStack from contextlib import ExitStack
import pynng import pynng
import torch import torch
@ -61,6 +61,51 @@ class DGPUComputeError(BaseException):
... ...
class ReconnectingBus:
def __init__(self, address: str, tls_config: Optional[TLSConfig]):
self.address = address
self.tls_config = tls_config
self._stack = ExitStack()
self._sock = None
self._closed = True
def connect(self):
self._sock = self._stack.enter_context(
pynng.Bus0(recv_max_size=0))
self._sock.tls_config = self.tls_config
self._sock.dial(self.address)
self._closed = False
async def arecv(self):
while True:
try:
return await self._sock.arecv()
except pynng.exceptions.Closed:
if self._closed:
raise
async def asend(self, msg):
while True:
try:
return await self._sock.asend(msg)
except pynng.exceptions.Closed:
if self._closed:
raise
def close(self):
self._stack.close()
self._stack = ExitStack()
self._closed = True
def reconnect(self):
self.close()
self.connect()
async def open_dgpu_node( async def open_dgpu_node(
cert_name: str, cert_name: str,
unique_id: str, unique_id: str,
@ -141,13 +186,16 @@ async def open_dgpu_node(
torch.cuda.empty_cache() torch.cuda.empty_cache()
async with open_skynet_rpc( async with (
unique_id, open_skynet_rpc(
rpc_address=rpc_address, unique_id,
security=security, rpc_address=rpc_address,
cert_name=cert_name, security=security,
key_name=key_name cert_name=cert_name,
) as rpc_call: key_name=key_name
) as rpc_call,
trio.open_nursery() as n
):
tls_config = None tls_config = None
if security: if security:
@ -183,98 +231,118 @@ async def open_dgpu_node(
ca_string=skynet_cert_data) ca_string=skynet_cert_data)
logging.info(f'connecting to {dgpu_address}') logging.info(f'connecting to {dgpu_address}')
with pynng.Bus0(recv_max_size=0) as dgpu_sock:
dgpu_sock.tls_config = tls_config
dgpu_sock.dial(dgpu_address)
res = await rpc_call('dgpu_online') dgpu_bus = ReconnectingBus(dgpu_address, tls_config)
assert 'ok' in res.result dgpu_bus.connect()
try: last_msg = time.time()
while True: async def connection_refresher(refresh_time: int = 120):
req = DGPUBusMessage() nonlocal last_msg
req.ParseFromString(await dgpu_sock.arecv()) while True:
now = time.time()
last_msg_time_delta = now - last_msg
logging.info(f'time since last msg: {last_msg_time_delta}')
if last_msg_time_delta > refresh_time:
dgpu_bus.reconnect()
logging.info('reconnected!')
last_msg = now
if req.method == 'heartbeat': await trio.sleep(refresh_time)
rep = DGPUBusMessage(
rid=req.rid,
nid=unique_id,
method=req.method
)
rep.params.update({'time': int(time.time() * 1000)})
if security: n.start_soon(connection_refresher)
rep.auth.cert = cert_name
rep.auth.sig = sign_protobuf_msg(rep, tls_key)
await dgpu_sock.asend(rep.SerializeToString()) res = await rpc_call('dgpu_online')
continue assert 'ok' in res.result
if req.nid != unique_id: try:
logging.info( while True:
f'witnessed msg {req.rid}, node involved: {req.nid}') req = DGPUBusMessage()
continue req.ParseFromString(await dgpu_bus.arecv())
last_msg = time.time()
if req.method == 'heartbeat':
rep = DGPUBusMessage(
rid=req.rid,
nid=unique_id,
method=req.method
)
rep.params.update({'time': int(time.time() * 1000)})
if security: if security:
verify_protobuf_msg(req, skynet_cert) rep.auth.cert = cert_name
rep.auth.sig = sign_protobuf_msg(rep, tls_key)
await dgpu_bus.asend(rep.SerializeToString())
logging.info('heartbeat reply')
continue
if req.nid != unique_id:
logging.info(
f'witnessed msg {req.rid}, node involved: {req.nid}')
continue
if security:
verify_protobuf_msg(req, skynet_cert)
ack_resp = DGPUBusMessage( ack_resp = DGPUBusMessage(
rid=req.rid,
nid=req.nid
)
ack_resp.params.update({'ack': {}})
if security:
ack_resp.auth.cert = cert_name
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
# send ack
await dgpu_bus.asend(ack_resp.SerializeToString())
logging.info(f'sent ack, processing {req.rid}...')
try:
img_req = Text2ImageParameters(**req.params)
if not img_req.seed:
img_req.seed = random.randint(0, 2 ** 64)
img = await gpu_compute_one(img_req)
img_resp = DGPUBusMessage(
rid=req.rid,
nid=req.nid,
method='binary-reply'
)
img_resp.params.update({
'len': len(img),
'meta': img_req.to_dict()
})
except DGPUComputeError as e:
traceback.print_exception(type(e), e, e.__traceback__)
img_resp = DGPUBusMessage(
rid=req.rid, rid=req.rid,
nid=req.nid nid=req.nid
) )
ack_resp.params.update({'ack': {}}) img_resp.params.update({'error': str(e)})
if security:
ack_resp.auth.cert = cert_name
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
# send ack
await dgpu_sock.asend(ack_resp.SerializeToString())
logging.info(f'sent ack, processing {req.rid}...')
try:
img_req = Text2ImageParameters(**req.params)
if not img_req.seed:
img_req.seed = random.randint(0, 2 ** 64)
img = await gpu_compute_one(img_req)
img_resp = DGPUBusMessage(
rid=req.rid,
nid=req.nid,
method='binary-reply'
)
img_resp.params.update({
'len': len(img),
'meta': img_req.to_dict()
})
except DGPUComputeError as e:
traceback.print_exception(type(e), e, e.__traceback__)
img_resp = DGPUBusMessage(
rid=req.rid,
nid=req.nid
)
img_resp.params.update({'error': str(e)})
if security: if security:
img_resp.auth.cert = cert_name img_resp.auth.cert = cert_name
img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key) img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
# send final image # send final image
logging.info('sending img back...') logging.info('sending img back...')
raw_msg = img_resp.SerializeToString() raw_msg = img_resp.SerializeToString()
await dgpu_sock.asend(raw_msg) await dgpu_bus.asend(raw_msg)
logging.info(f'sent {len(raw_msg)} bytes.') logging.info(f'sent {len(raw_msg)} bytes.')
if img_resp.method == 'binary-reply': if img_resp.method == 'binary-reply':
await dgpu_sock.asend(img) await dgpu_bus.asend(img)
logging.info(f'sent {len(img)} bytes.') logging.info(f'sent {len(img)} bytes.')
except KeyboardInterrupt: except KeyboardInterrupt:
logging.info('interrupt caught, stopping...') logging.info('interrupt caught, stopping...')
n.cancel_scope.cancel()
dgpu_bus.close()
finally: finally:
res = await rpc_call('dgpu_offline') res = await rpc_call('dgpu_offline')
assert 'ok' in res.result assert 'ok' in res.result