mirror of https://github.com/skygpu/skynet.git
commit
9c7293f84f
|
@ -1,5 +1,6 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import time
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
import zlib
|
import zlib
|
||||||
|
@ -333,22 +334,27 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
|
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
if req.method == 'skynet_shutdown':
|
match req.method:
|
||||||
|
case 'skynet_shutdown':
|
||||||
raise SkynetShutdownRequested
|
raise SkynetShutdownRequested
|
||||||
|
|
||||||
elif req.method == 'dgpu_online':
|
case 'dgpu_online':
|
||||||
connect_node(req.uid)
|
connect_node(req.uid)
|
||||||
|
|
||||||
elif req.method == 'dgpu_offline':
|
case 'dgpu_offline':
|
||||||
disconnect_node(req.uid)
|
disconnect_node(req.uid)
|
||||||
|
|
||||||
elif req.method == 'dgpu_workers':
|
case 'dgpu_workers':
|
||||||
result = len(nodes)
|
result = len(nodes)
|
||||||
|
|
||||||
elif req.method == 'dgpu_next':
|
case 'dgpu_next':
|
||||||
result = next_worker
|
result = next_worker
|
||||||
|
|
||||||
else:
|
case 'heartbeat':
|
||||||
|
logging.info('beat')
|
||||||
|
result = {'time': time.time()}
|
||||||
|
|
||||||
|
case _:
|
||||||
n.start_soon(
|
n.start_soon(
|
||||||
handle_user_request, ctx, req)
|
handle_user_request, ctx, req)
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -2,10 +2,9 @@
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import io
|
import io
|
||||||
import trio
|
import time
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
import base64
|
|
||||||
import random
|
import random
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
|
@ -14,6 +13,7 @@ from typing import List, Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
|
|
||||||
|
import trio
|
||||||
import pynng
|
import pynng
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -141,13 +141,16 @@ async def open_dgpu_node(
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
async with open_skynet_rpc(
|
async with (
|
||||||
|
open_skynet_rpc(
|
||||||
unique_id,
|
unique_id,
|
||||||
rpc_address=rpc_address,
|
rpc_address=rpc_address,
|
||||||
security=security,
|
security=security,
|
||||||
cert_name=cert_name,
|
cert_name=cert_name,
|
||||||
key_name=key_name
|
key_name=key_name
|
||||||
) as rpc_call:
|
) as rpc_call,
|
||||||
|
trio.open_nursery() as n
|
||||||
|
):
|
||||||
|
|
||||||
tls_config = None
|
tls_config = None
|
||||||
if security:
|
if security:
|
||||||
|
@ -182,6 +185,14 @@ async def open_dgpu_node(
|
||||||
own_cert_string=tls_cert_data,
|
own_cert_string=tls_cert_data,
|
||||||
ca_string=skynet_cert_data)
|
ca_string=skynet_cert_data)
|
||||||
|
|
||||||
|
async def heartbeat_service():
|
||||||
|
while True:
|
||||||
|
await trio.sleep(60)
|
||||||
|
before = time.time()
|
||||||
|
res = await rpc_call('heartbeat')
|
||||||
|
now = res.result['ok']['time']
|
||||||
|
logging.info(f'heartbeat ping: {int((now - before) * 1000)}')
|
||||||
|
|
||||||
logging.info(f'connecting to {dgpu_address}')
|
logging.info(f'connecting to {dgpu_address}')
|
||||||
with pynng.Bus0(recv_max_size=0) as dgpu_sock:
|
with pynng.Bus0(recv_max_size=0) as dgpu_sock:
|
||||||
dgpu_sock.tls_config = tls_config
|
dgpu_sock.tls_config = tls_config
|
||||||
|
@ -190,6 +201,8 @@ async def open_dgpu_node(
|
||||||
res = await rpc_call('dgpu_online')
|
res = await rpc_call('dgpu_online')
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
|
|
||||||
|
n.start_soon(heartbeat_service)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
req = DGPUBusMessage()
|
req = DGPUBusMessage()
|
||||||
|
|
|
@ -306,3 +306,18 @@ async def test_dgpu_timeout_while_processing(dgpu_workers):
|
||||||
ec, out = dgpu_workers[0].exec_run(
|
ec, out = dgpu_workers[0].exec_run(
|
||||||
['pkill', '-TERM', '-f', 'skynet'])
|
['pkill', '-TERM', '-f', 'skynet'])
|
||||||
assert ec == 0
|
assert ec == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
||||||
|
async def test_dgpu_heartbeat(dgpu_workers):
|
||||||
|
'''
|
||||||
|
'''
|
||||||
|
async with open_skynet_rpc(
|
||||||
|
'test-ctx',
|
||||||
|
security=True,
|
||||||
|
cert_name='whitelist/testing',
|
||||||
|
key_name='testing'
|
||||||
|
) as test_rpc:
|
||||||
|
await wait_for_dgpus(test_rpc, 1)
|
||||||
|
await trio.sleep(120)
|
||||||
|
|
Loading…
Reference in New Issue