Merge pull request #3 from guilledk/heartbeat

Add simple heartbeat mechanic
pull/4/head
Guillermo Rodriguez 2023-01-07 09:09:03 -03:00 committed by GitHub
commit 9c7293f84f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 57 additions and 23 deletions

View File

@ -1,5 +1,6 @@
#!/usr/bin/python #!/usr/bin/python
import time
import json import json
import uuid import uuid
import zlib import zlib
@ -333,25 +334,30 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
result = {} result = {}
if req.method == 'skynet_shutdown': match req.method:
raise SkynetShutdownRequested case 'skynet_shutdown':
raise SkynetShutdownRequested
elif req.method == 'dgpu_online': case 'dgpu_online':
connect_node(req.uid) connect_node(req.uid)
elif req.method == 'dgpu_offline': case 'dgpu_offline':
disconnect_node(req.uid) disconnect_node(req.uid)
elif req.method == 'dgpu_workers': case 'dgpu_workers':
result = len(nodes) result = len(nodes)
elif req.method == 'dgpu_next': case 'dgpu_next':
result = next_worker result = next_worker
else: case 'heartbeat':
n.start_soon( logging.info('beat')
handle_user_request, ctx, req) result = {'time': time.time()}
continue
case _:
n.start_soon(
handle_user_request, ctx, req)
continue
resp = SkynetRPCResponse() resp = SkynetRPCResponse()
resp.result.update({'ok': result}) resp.result.update({'ok': result})

View File

@ -2,10 +2,9 @@
import gc import gc
import io import io
import trio import time
import json import json
import uuid import uuid
import base64
import random import random
import logging import logging
import traceback import traceback
@ -14,6 +13,7 @@ from typing import List, Optional
from pathlib import Path from pathlib import Path
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
import trio
import pynng import pynng
import torch import torch
@ -141,13 +141,16 @@ async def open_dgpu_node(
torch.cuda.empty_cache() torch.cuda.empty_cache()
async with open_skynet_rpc( async with (
unique_id, open_skynet_rpc(
rpc_address=rpc_address, unique_id,
security=security, rpc_address=rpc_address,
cert_name=cert_name, security=security,
key_name=key_name cert_name=cert_name,
) as rpc_call: key_name=key_name
) as rpc_call,
trio.open_nursery() as n
):
tls_config = None tls_config = None
if security: if security:
@ -182,6 +185,14 @@ async def open_dgpu_node(
own_cert_string=tls_cert_data, own_cert_string=tls_cert_data,
ca_string=skynet_cert_data) ca_string=skynet_cert_data)
async def heartbeat_service():
    '''Send a 'heartbeat' rpc call every 60 seconds and log its latency.

    Loops forever; it only stops when the surrounding task is cancelled
    or the rpc call raises.

    The logged value is the round trip time of the call in whole
    milliseconds, measured on the local monotonic clock. The timestamp
    carried in the reply is deliberately not used for the measurement:
    subtracting a local wall-clock reading from another host's wall
    clock folds any offset between the two clocks into the result,
    which can even come out negative.
    '''
    while True:
        await trio.sleep(60)

        started = time.perf_counter()
        res = await rpc_call('heartbeat')
        elapsed = time.perf_counter() - started

        # Still subscript the reply so that a response without an
        # 'ok' -> 'time' entry raises here, as it did before, instead
        # of being logged as a healthy beat.
        _server_time = res.result['ok']['time']

        logging.info(f'heartbeat ping: {int(elapsed * 1000)}')

logging.info(f'connecting to {dgpu_address}') logging.info(f'connecting to {dgpu_address}')
with pynng.Bus0(recv_max_size=0) as dgpu_sock: with pynng.Bus0(recv_max_size=0) as dgpu_sock:
dgpu_sock.tls_config = tls_config dgpu_sock.tls_config = tls_config
@ -190,6 +201,8 @@ async def open_dgpu_node(
res = await rpc_call('dgpu_online') res = await rpc_call('dgpu_online')
assert 'ok' in res.result assert 'ok' in res.result
n.start_soon(heartbeat_service)
try: try:
while True: while True:
req = DGPUBusMessage() req = DGPUBusMessage()

View File

@ -306,3 +306,18 @@ async def test_dgpu_timeout_while_processing(dgpu_workers):
ec, out = dgpu_workers[0].exec_run( ec, out = dgpu_workers[0].exec_run(
['pkill', '-TERM', '-f', 'skynet']) ['pkill', '-TERM', '-f', 'skynet'])
assert ec == 0 assert ec == 0
@pytest.mark.parametrize(
    'dgpu_workers', [(1, ['midj'])], indirect=True)
async def test_dgpu_heartbeat(dgpu_workers):
    '''Keep a single dgpu worker connected for two minutes.

    Opens a secured rpc client, waits until one dgpu is reported and
    then just sleeps for 120 seconds. The test makes no assertions of
    its own; it only passes if nothing raises during that window.
    NOTE(review): presumably the window is meant to cover the worker's
    heartbeat interval - confirm against the dgpu node code.
    '''
    rpc_options = {
        'security': True,
        'cert_name': 'whitelist/testing',
        'key_name': 'testing',
    }
    async with open_skynet_rpc('test-ctx', **rpc_options) as rpc:
        await wait_for_dgpus(rpc, 1)
        await trio.sleep(120)