From 1b42f288bca5a36461de4645fed9208e3b11708e Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Sat, 7 Jan 2023 06:59:50 -0300 Subject: [PATCH] Add simple heartbeat mechanic --- skynet/brain.py | 34 ++++++++++++++++++++-------------- skynet/dgpu.py | 31 ++++++++++++++++++++++--------- tests/test_dgpu.py | 15 +++++++++++++++ 3 files changed, 57 insertions(+), 23 deletions(-) diff --git a/skynet/brain.py b/skynet/brain.py index ab28df0..9c41803 100644 --- a/skynet/brain.py +++ b/skynet/brain.py @@ -1,5 +1,6 @@ #!/usr/bin/python +import time import json import uuid import zlib @@ -333,25 +334,30 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): result = {} - if req.method == 'skynet_shutdown': - raise SkynetShutdownRequested + match req.method: + case 'skynet_shutdown': + raise SkynetShutdownRequested - elif req.method == 'dgpu_online': - connect_node(req.uid) + case 'dgpu_online': + connect_node(req.uid) - elif req.method == 'dgpu_offline': - disconnect_node(req.uid) + case 'dgpu_offline': + disconnect_node(req.uid) - elif req.method == 'dgpu_workers': - result = len(nodes) + case 'dgpu_workers': + result = len(nodes) - elif req.method == 'dgpu_next': - result = next_worker + case 'dgpu_next': + result = next_worker - else: - n.start_soon( - handle_user_request, ctx, req) - continue + case 'heartbeat': + logging.info('beat') + result = {'time': time.time()} + + case _: + n.start_soon( + handle_user_request, ctx, req) + continue resp = SkynetRPCResponse() resp.result.update({'ok': result}) diff --git a/skynet/dgpu.py b/skynet/dgpu.py index 975c78b..6edcce5 100644 --- a/skynet/dgpu.py +++ b/skynet/dgpu.py @@ -2,10 +2,9 @@ import gc import io -import trio +import time import json import uuid -import base64 import random import logging import traceback @@ -14,6 +13,7 @@ from typing import List, Optional from pathlib import Path from contextlib import AsyncExitStack +import trio import pynng import torch @@ -141,13 +141,16 @@ async def open_dgpu_node( torch.cuda.empty_cache() - async with open_skynet_rpc( - unique_id, - rpc_address=rpc_address, - security=security, - cert_name=cert_name, - key_name=key_name - ) as rpc_call: + async with ( + open_skynet_rpc( + unique_id, + rpc_address=rpc_address, + security=security, + cert_name=cert_name, + key_name=key_name + ) as rpc_call, + trio.open_nursery() as n + ): tls_config = None if security: @@ -182,6 +185,14 @@ async def open_dgpu_node( own_cert_string=tls_cert_data, ca_string=skynet_cert_data) + async def heartbeat_service(): + while True: + await trio.sleep(60) + before = time.time() + res = await rpc_call('heartbeat') + now = res.result['ok']['time'] + logging.info(f'heartbeat ping: {int((now - before) * 1000)}') + logging.info(f'connecting to {dgpu_address}') with pynng.Bus0(recv_max_size=0) as dgpu_sock: dgpu_sock.tls_config = tls_config @@ -190,6 +201,8 @@ async def open_dgpu_node( res = await rpc_call('dgpu_online') assert 'ok' in res.result + n.start_soon(heartbeat_service) + try: while True: req = DGPUBusMessage() diff --git a/tests/test_dgpu.py b/tests/test_dgpu.py index 4699156..7d6ef07 100644 --- a/tests/test_dgpu.py +++ b/tests/test_dgpu.py @@ -306,3 +306,18 @@ async def test_dgpu_timeout_while_processing(dgpu_workers): ec, out = dgpu_workers[0].exec_run( ['pkill', '-TERM', '-f', 'skynet']) assert ec == 0 + + +@pytest.mark.parametrize( + 'dgpu_workers', [(1, ['midj'])], indirect=True) +async def test_dgpu_heartbeat(dgpu_workers): + ''' + ''' + async with open_skynet_rpc( + 'test-ctx', + security=True, + cert_name='whitelist/testing', + key_name='testing' + ) as test_rpc: + await wait_for_dgpus(test_rpc, 1) + await trio.sleep(120)