Add reconnect mechanic to dgpu bus

pull/4/head
Guillermo Rodriguez 2023-01-08 07:16:43 -03:00
parent 10e77655c6
commit 585d304f86
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
1 changed file with 153 additions and 85 deletions

View File

@ -12,7 +12,7 @@ import traceback
from typing import List, Optional
from pathlib import Path
from contextlib import AsyncExitStack
from contextlib import ExitStack
import pynng
import torch
@ -61,6 +61,51 @@ class DGPUComputeError(BaseException):
...
class ReconnectingBus:
    '''A ``pynng.Bus0`` dialer that can be torn down and re-dialed in place.

    The live socket is owned by an ``ExitStack`` so that ``close()`` releases
    it.  ``arecv()`` / ``asend()`` always operate on the current socket and
    transparently retry when the socket they were waiting on was replaced
    underneath them (i.e. a ``Closed`` error arrives while the bus is marked
    open again).

    Can also be used as a context manager: entering connects, exiting closes.
    '''

    def __init__(self, address: str, tls_config: 'Optional[TLSConfig]'):
        '''Store connection settings; no socket is opened until ``connect()``.

        :param address: address handed to ``dial()`` on every (re)connect.
        :param tls_config: value assigned to the socket's ``tls_config``
            attribute before dialing (may be ``None``).
        '''
        self.address = address
        self.tls_config = tls_config

        # owns the currently open socket, if any
        self._stack = ExitStack()
        self._sock = None
        # True whenever there is no live, dialed socket
        self._closed = True

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def connect(self) -> None:
        '''Open a fresh socket and dial ``self.address``.

        Any socket left over from a previous ``connect()`` is released first,
        so calling this twice never keeps two sockets open.  If creating,
        configuring or dialing the new socket raises, that socket is closed
        again, the bus stays marked closed and the error propagates.
        '''
        self.close()

        # Build the socket inside a temporary stack: if anything below
        # raises, the ``with`` closes the half-open socket for us.
        with ExitStack() as pending:
            sock = pending.enter_context(
                pynng.Bus0(recv_max_size=0))
            sock.tls_config = self.tls_config
            sock.dial(self.address)

            # Dial succeeded: move ownership of the socket to the instance
            # so leaving the ``with`` block does not close it.
            self._stack = pending.pop_all()

        self._sock = sock
        self._closed = False

    async def arecv(self):
        '''Receive one message from the current socket.

        If the socket is closed while waiting but the bus is marked open
        (which is the state after a ``reconnect()``), the receive is retried
        on whatever ``self._sock`` is now.

        :raises pynng.exceptions.Closed: when the bus itself is closed.
        '''
        while True:
            try:
                return await self._sock.arecv()

            except pynng.exceptions.Closed:
                if self._closed:
                    raise
                # otherwise: socket was swapped out, loop and use the new one

    async def asend(self, msg):
        '''Send ``msg`` on the current socket.

        Uses the same retry rule as ``arecv()``: a ``Closed`` error is only
        propagated when the bus itself is closed.

        :raises pynng.exceptions.Closed: when the bus itself is closed.
        '''
        while True:
            try:
                return await self._sock.asend(msg)

            except pynng.exceptions.Closed:
                if self._closed:
                    raise
                # otherwise: socket was swapped out, loop and use the new one

    def close(self) -> None:
        '''Release the current socket (if any) and mark the bus closed.

        Safe to call on a bus that was never connected or is already closed.
        '''
        # Update our own state first so it stays consistent even if tearing
        # down the socket raises.
        self._closed = True
        stack, self._stack = self._stack, ExitStack()
        stack.close()

    def reconnect(self) -> None:
        '''Drop the current socket and dial ``self.address`` again.'''
        self.close()
        self.connect()
async def open_dgpu_node(
cert_name: str,
unique_id: str,
@ -141,13 +186,16 @@ async def open_dgpu_node(
torch.cuda.empty_cache()
async with open_skynet_rpc(
unique_id,
rpc_address=rpc_address,
security=security,
cert_name=cert_name,
key_name=key_name
) as rpc_call:
async with (
open_skynet_rpc(
unique_id,
rpc_address=rpc_address,
security=security,
cert_name=cert_name,
key_name=key_name
) as rpc_call,
trio.open_nursery() as n
):
tls_config = None
if security:
@ -183,98 +231,118 @@ async def open_dgpu_node(
ca_string=skynet_cert_data)
logging.info(f'connecting to {dgpu_address}')
with pynng.Bus0(recv_max_size=0) as dgpu_sock:
dgpu_sock.tls_config = tls_config
dgpu_sock.dial(dgpu_address)
res = await rpc_call('dgpu_online')
assert 'ok' in res.result
dgpu_bus = ReconnectingBus(dgpu_address, tls_config)
dgpu_bus.connect()
try:
while True:
req = DGPUBusMessage()
req.ParseFromString(await dgpu_sock.arecv())
last_msg = time.time()
async def connection_refresher(refresh_time: int = 120):
nonlocal last_msg
while True:
now = time.time()
last_msg_time_delta = now - last_msg
logging.info(f'time since last msg: {last_msg_time_delta}')
if last_msg_time_delta > refresh_time:
dgpu_bus.reconnect()
logging.info('reconnected!')
last_msg = now
if req.method == 'heartbeat':
rep = DGPUBusMessage(
rid=req.rid,
nid=unique_id,
method=req.method
)
rep.params.update({'time': int(time.time() * 1000)})
await trio.sleep(refresh_time)
if security:
rep.auth.cert = cert_name
rep.auth.sig = sign_protobuf_msg(rep, tls_key)
n.start_soon(connection_refresher)
await dgpu_sock.asend(rep.SerializeToString())
continue
res = await rpc_call('dgpu_online')
assert 'ok' in res.result
if req.nid != unique_id:
logging.info(
f'witnessed msg {req.rid}, node involved: {req.nid}')
continue
try:
while True:
req = DGPUBusMessage()
req.ParseFromString(await dgpu_bus.arecv())
last_msg = time.time()
if req.method == 'heartbeat':
rep = DGPUBusMessage(
rid=req.rid,
nid=unique_id,
method=req.method
)
rep.params.update({'time': int(time.time() * 1000)})
if security:
verify_protobuf_msg(req, skynet_cert)
rep.auth.cert = cert_name
rep.auth.sig = sign_protobuf_msg(rep, tls_key)
await dgpu_bus.asend(rep.SerializeToString())
logging.info('heartbeat reply')
continue
if req.nid != unique_id:
logging.info(
f'witnessed msg {req.rid}, node involved: {req.nid}')
continue
if security:
verify_protobuf_msg(req, skynet_cert)
ack_resp = DGPUBusMessage(
ack_resp = DGPUBusMessage(
rid=req.rid,
nid=req.nid
)
ack_resp.params.update({'ack': {}})
if security:
ack_resp.auth.cert = cert_name
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
# send ack
await dgpu_bus.asend(ack_resp.SerializeToString())
logging.info(f'sent ack, processing {req.rid}...')
try:
img_req = Text2ImageParameters(**req.params)
if not img_req.seed:
img_req.seed = random.randint(0, 2 ** 64)
img = await gpu_compute_one(img_req)
img_resp = DGPUBusMessage(
rid=req.rid,
nid=req.nid,
method='binary-reply'
)
img_resp.params.update({
'len': len(img),
'meta': img_req.to_dict()
})
except DGPUComputeError as e:
traceback.print_exception(type(e), e, e.__traceback__)
img_resp = DGPUBusMessage(
rid=req.rid,
nid=req.nid
)
ack_resp.params.update({'ack': {}})
if security:
ack_resp.auth.cert = cert_name
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
# send ack
await dgpu_sock.asend(ack_resp.SerializeToString())
logging.info(f'sent ack, processing {req.rid}...')
try:
img_req = Text2ImageParameters(**req.params)
if not img_req.seed:
img_req.seed = random.randint(0, 2 ** 64)
img = await gpu_compute_one(img_req)
img_resp = DGPUBusMessage(
rid=req.rid,
nid=req.nid,
method='binary-reply'
)
img_resp.params.update({
'len': len(img),
'meta': img_req.to_dict()
})
except DGPUComputeError as e:
traceback.print_exception(type(e), e, e.__traceback__)
img_resp = DGPUBusMessage(
rid=req.rid,
nid=req.nid
)
img_resp.params.update({'error': str(e)})
img_resp.params.update({'error': str(e)})
if security:
img_resp.auth.cert = cert_name
img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
if security:
img_resp.auth.cert = cert_name
img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
# send final image
logging.info('sending img back...')
raw_msg = img_resp.SerializeToString()
await dgpu_sock.asend(raw_msg)
logging.info(f'sent {len(raw_msg)} bytes.')
if img_resp.method == 'binary-reply':
await dgpu_sock.asend(img)
logging.info(f'sent {len(img)} bytes.')
# send final image
logging.info('sending img back...')
raw_msg = img_resp.SerializeToString()
await dgpu_bus.asend(raw_msg)
logging.info(f'sent {len(raw_msg)} bytes.')
if img_resp.method == 'binary-reply':
await dgpu_bus.asend(img)
logging.info(f'sent {len(img)} bytes.')
except KeyboardInterrupt:
logging.info('interrupt caught, stopping...')
except KeyboardInterrupt:
logging.info('interrupt caught, stopping...')
n.cancel_scope.cancel()
dgpu_bus.close()
finally:
res = await rpc_call('dgpu_offline')
assert 'ok' in res.result
finally:
res = await rpc_call('dgpu_offline')
assert 'ok' in res.result