Add simple heartbeat mechanic

pull/3/head
Guillermo Rodriguez 2023-01-07 06:59:50 -03:00
parent 83465aadaf
commit 1b42f288bc
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
3 changed files with 57 additions and 23 deletions

View File

@ -1,5 +1,6 @@
#!/usr/bin/python
import time
import json
import uuid
import zlib
@ -333,25 +334,30 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
result = {}
if req.method == 'skynet_shutdown':
raise SkynetShutdownRequested
match req.method:
case 'skynet_shutdown':
raise SkynetShutdownRequested
elif req.method == 'dgpu_online':
connect_node(req.uid)
case 'dgpu_online':
connect_node(req.uid)
elif req.method == 'dgpu_offline':
disconnect_node(req.uid)
case 'dgpu_offline':
disconnect_node(req.uid)
elif req.method == 'dgpu_workers':
result = len(nodes)
case 'dgpu_workers':
result = len(nodes)
elif req.method == 'dgpu_next':
result = next_worker
case 'dgpu_next':
result = next_worker
else:
n.start_soon(
handle_user_request, ctx, req)
continue
case 'heartbeat':
logging.info('beat')
result = {'time': time.time()}
case _:
n.start_soon(
handle_user_request, ctx, req)
continue
resp = SkynetRPCResponse()
resp.result.update({'ok': result})

View File

@ -2,10 +2,9 @@
import gc
import io
import trio
import time
import json
import uuid
import base64
import random
import logging
import traceback
@ -14,6 +13,7 @@ from typing import List, Optional
from pathlib import Path
from contextlib import AsyncExitStack
import trio
import pynng
import torch
@ -141,13 +141,16 @@ async def open_dgpu_node(
torch.cuda.empty_cache()
async with open_skynet_rpc(
unique_id,
rpc_address=rpc_address,
security=security,
cert_name=cert_name,
key_name=key_name
) as rpc_call:
async with (
open_skynet_rpc(
unique_id,
rpc_address=rpc_address,
security=security,
cert_name=cert_name,
key_name=key_name
) as rpc_call,
trio.open_nursery() as n
):
tls_config = None
if security:
@ -182,6 +185,14 @@ async def open_dgpu_node(
own_cert_string=tls_cert_data,
ca_string=skynet_cert_data)
async def heartbeat_service():
while True:
await trio.sleep(60)
before = time.time()
res = await rpc_call('heartbeat')
now = res.result['ok']['time']
logging.info(f'heartbeat ping: {int((now - before) * 1000)}')
logging.info(f'connecting to {dgpu_address}')
with pynng.Bus0(recv_max_size=0) as dgpu_sock:
dgpu_sock.tls_config = tls_config
@ -190,6 +201,8 @@ async def open_dgpu_node(
res = await rpc_call('dgpu_online')
assert 'ok' in res.result
n.start_soon(heartbeat_service)
try:
while True:
req = DGPUBusMessage()

View File

@ -306,3 +306,18 @@ async def test_dgpu_timeout_while_processing(dgpu_workers):
ec, out = dgpu_workers[0].exec_run(
['pkill', '-TERM', '-f', 'skynet'])
assert ec == 0
@pytest.mark.parametrize(
'dgpu_workers', [(1, ['midj'])], indirect=True)
async def test_dgpu_heartbeat(dgpu_workers):
'''
'''
async with open_skynet_rpc(
'test-ctx',
security=True,
cert_name='whitelist/testing',
key_name='testing'
) as test_rpc:
await wait_for_dgpus(test_rpc, 1)
await trio.sleep(120)