diff --git a/skynet/brain.py b/skynet/brain.py index 9c41803..d649483 100644 --- a/skynet/brain.py +++ b/skynet/brain.py @@ -48,6 +48,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): nodes = OrderedDict() wip_reqs = {} fin_reqs = {} + heartbeats = {} next_worker: Optional[int] = None security = len(tls_whitelist) > 0 @@ -116,8 +117,23 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): return nid - async def dgpu_image_streamer(): - nonlocal wip_reqs, fin_reqs + async def dgpu_heartbeat_service(): + nonlocal heartbeats + while True: + await trio.sleep(60) + rid = uuid.uuid4().hex + beat_msg = DGPUBusMessage( + rid=rid, + nid='', + method='heartbeat' + ) + heartbeats.clear() + heartbeats[rid] = int(time.time() * 1000) + await dgpu_bus.asend(beat_msg.SerializeToString()) + logging.info('sent heartbeat') + + async def dgpu_bus_streamer(): + nonlocal wip_reqs, fin_reqs, heartbeats while True: raw_msg = await dgpu_bus.arecv() logging.info(f'streamer got {len(raw_msg)} bytes.') @@ -129,6 +145,12 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): rid = msg.rid + if msg.method == 'heartbeat': + sent_time = heartbeats[rid] + delta = msg.params['time'] - sent_time + logging.info(f'got heartbeat reply from {msg.nid}, ping: {delta}') + continue + if rid not in wip_reqs: continue @@ -372,7 +394,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): async with trio.open_nursery() as n: - n.start_soon(dgpu_image_streamer) + n.start_soon(dgpu_bus_streamer) + n.start_soon(dgpu_heartbeat_service) n.start_soon(request_service, n) logging.info('starting rpc service') yield diff --git a/skynet/dgpu.py b/skynet/dgpu.py index 6edcce5..3075ee2 100644 --- a/skynet/dgpu.py +++ b/skynet/dgpu.py @@ -2,9 +2,10 @@ import gc import io -import time +import trio import json import uuid +import time import random import logging import traceback @@ -13,7 +14,6 @@ from typing import List, Optional from pathlib import Path from contextlib import AsyncExitStack -import trio import pynng import torch @@ -141,16 +141,13 @@ async def open_dgpu_node( torch.cuda.empty_cache() - async with ( - open_skynet_rpc( - unique_id, - rpc_address=rpc_address, - security=security, - cert_name=cert_name, - key_name=key_name - ) as rpc_call, - trio.open_nursery() as n - ): + async with open_skynet_rpc( + unique_id, + rpc_address=rpc_address, + security=security, + cert_name=cert_name, + key_name=key_name + ) as rpc_call: tls_config = None if security: @@ -185,14 +182,6 @@ async def open_dgpu_node( own_cert_string=tls_cert_data, ca_string=skynet_cert_data) - async def heartbeat_service(): - while True: - await trio.sleep(60) - before = time.time() - res = await rpc_call('heartbeat') - now = res.result['ok']['time'] - logging.info(f'heartbeat ping: {int((now - before) * 1000)}') - logging.info(f'connecting to {dgpu_address}') with pynng.Bus0(recv_max_size=0) as dgpu_sock: dgpu_sock.tls_config = tls_config @@ -201,13 +190,26 @@ async def open_dgpu_node( res = await rpc_call('dgpu_online') assert 'ok' in res.result - n.start_soon(heartbeat_service) - try: while True: req = DGPUBusMessage() req.ParseFromString(await dgpu_sock.arecv()) + if req.method == 'heartbeat': + rep = DGPUBusMessage( + rid=req.rid, + nid=unique_id, + method=req.method + ) + rep.params.update({'time': int(time.time() * 1000)}) + + if security: + rep.auth.cert = cert_name + rep.auth.sig = sign_protobuf_msg(rep, tls_key) + + await dgpu_sock.asend(rep.SerializeToString()) + continue + if req.nid != unique_id: logging.info( f'witnessed msg {req.rid}, node involved: {req.nid}') @@ -216,6 +218,7 @@ async def open_dgpu_node( if security: verify_protobuf_msg(req, skynet_cert) + ack_resp = DGPUBusMessage( rid=req.rid, nid=req.nid