Add simple heartbeat mechanic

pull/3/head
Guillermo Rodriguez 2023-01-07 06:59:50 -03:00
parent 83465aadaf
commit 1b42f288bc
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
3 changed files with 57 additions and 23 deletions

View File

@ -1,5 +1,6 @@
#!/usr/bin/python #!/usr/bin/python
import time
import json import json
import uuid import uuid
import zlib import zlib
@ -333,22 +334,27 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
result = {} result = {}
if req.method == 'skynet_shutdown': match req.method:
case 'skynet_shutdown':
raise SkynetShutdownRequested raise SkynetShutdownRequested
elif req.method == 'dgpu_online': case 'dgpu_online':
connect_node(req.uid) connect_node(req.uid)
elif req.method == 'dgpu_offline': case 'dgpu_offline':
disconnect_node(req.uid) disconnect_node(req.uid)
elif req.method == 'dgpu_workers': case 'dgpu_workers':
result = len(nodes) result = len(nodes)
elif req.method == 'dgpu_next': case 'dgpu_next':
result = next_worker result = next_worker
else: case 'heartbeat':
logging.info('beat')
result = {'time': time.time()}
case _:
n.start_soon( n.start_soon(
handle_user_request, ctx, req) handle_user_request, ctx, req)
continue continue

View File

@ -2,10 +2,9 @@
import gc import gc
import io import io
import trio import time
import json import json
import uuid import uuid
import base64
import random import random
import logging import logging
import traceback import traceback
@ -14,6 +13,7 @@ from typing import List, Optional
from pathlib import Path from pathlib import Path
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
import trio
import pynng import pynng
import torch import torch
@ -141,13 +141,16 @@ async def open_dgpu_node(
torch.cuda.empty_cache() torch.cuda.empty_cache()
async with open_skynet_rpc( async with (
open_skynet_rpc(
unique_id, unique_id,
rpc_address=rpc_address, rpc_address=rpc_address,
security=security, security=security,
cert_name=cert_name, cert_name=cert_name,
key_name=key_name key_name=key_name
) as rpc_call: ) as rpc_call,
trio.open_nursery() as n
):
tls_config = None tls_config = None
if security: if security:
@ -182,6 +185,14 @@ async def open_dgpu_node(
own_cert_string=tls_cert_data, own_cert_string=tls_cert_data,
ca_string=skynet_cert_data) ca_string=skynet_cert_data)
async def heartbeat_service():
    '''
    Ping the skynet RPC service once a minute and log the latency.

    Loops until cancelled: sleeps 60 seconds, issues a ``heartbeat`` rpc
    and logs the round-trip time in whole milliseconds.

    Raises whatever ``rpc_call`` raises, and ``KeyError`` when the reply
    does not carry ``result['ok']['time']``.
    '''
    while True:
        await trio.sleep(60)

        # Time the round trip on the local performance counter. Subtracting
        # a local ``time.time()`` from the timestamp reported by the server
        # compares two different wall clocks, so any skew between the hosts
        # (or an NTP step) distorts the figure and can even make it negative.
        started = time.perf_counter()
        res = await rpc_call('heartbeat')
        elapsed = time.perf_counter() - started

        # Still read the server timestamp so an error or malformed reply
        # fails loudly here, exactly as it did before.
        server_time = res.result['ok']['time']
        logging.debug(f'heartbeat server time: {server_time}')

        logging.info(f'heartbeat ping: {int(elapsed * 1000)}')
logging.info(f'connecting to {dgpu_address}') logging.info(f'connecting to {dgpu_address}')
with pynng.Bus0(recv_max_size=0) as dgpu_sock: with pynng.Bus0(recv_max_size=0) as dgpu_sock:
dgpu_sock.tls_config = tls_config dgpu_sock.tls_config = tls_config
@ -190,6 +201,8 @@ async def open_dgpu_node(
res = await rpc_call('dgpu_online') res = await rpc_call('dgpu_online')
assert 'ok' in res.result assert 'ok' in res.result
n.start_soon(heartbeat_service)
try: try:
while True: while True:
req = DGPUBusMessage() req = DGPUBusMessage()

View File

@ -306,3 +306,18 @@ async def test_dgpu_timeout_while_processing(dgpu_workers):
ec, out = dgpu_workers[0].exec_run( ec, out = dgpu_workers[0].exec_run(
['pkill', '-TERM', '-f', 'skynet']) ['pkill', '-TERM', '-f', 'skynet'])
assert ec == 0 assert ec == 0
@pytest.mark.parametrize(
    'dgpu_workers', [(1, ['midj'])], indirect=True)
async def test_dgpu_heartbeat(dgpu_workers):
    '''
    Keep an rpc context open for two minutes while one dgpu worker runs.

    The ``dgpu_workers`` fixture is parametrized indirectly with
    ``(1, ['midj'])``. The test opens an rpc context named ``test-ctx``
    with security enabled, calls ``wait_for_dgpus(test_rpc, 1)`` and then
    sleeps for 120 seconds.

    There are no assertions: the test passes as long as nothing raises
    during that window.

    NOTE(review): the 120 second sleep is presumably meant to outlast at
    least one worker heartbeat cycle -- confirm against the worker's
    heartbeat interval.
    '''
    async with open_skynet_rpc(
        'test-ctx',
        security=True,
        cert_name='whitelist/testing',
        key_name='testing'
    ) as test_rpc:
        # presumably blocks until one worker has registered -- verify
        # against the helper's definition
        await wait_for_dgpus(test_rpc, 1)
        # idle with the rpc context still open; nothing is checked afterwards
        await trio.sleep(120)