mirror of https://github.com/skygpu/skynet.git
Remove rpc heartbeat service and add it to gpu bus
parent
9c7293f84f
commit
10e77655c6
|
@ -48,6 +48,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
nodes = OrderedDict()
|
nodes = OrderedDict()
|
||||||
wip_reqs = {}
|
wip_reqs = {}
|
||||||
fin_reqs = {}
|
fin_reqs = {}
|
||||||
|
heartbeats = {}
|
||||||
next_worker: Optional[int] = None
|
next_worker: Optional[int] = None
|
||||||
security = len(tls_whitelist) > 0
|
security = len(tls_whitelist) > 0
|
||||||
|
|
||||||
|
@ -116,8 +117,23 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
|
|
||||||
return nid
|
return nid
|
||||||
|
|
||||||
async def dgpu_image_streamer():
|
async def dgpu_heartbeat_service():
|
||||||
nonlocal wip_reqs, fin_reqs
|
nonlocal heartbeats
|
||||||
|
while True:
|
||||||
|
await trio.sleep(60)
|
||||||
|
rid = uuid.uuid4().hex
|
||||||
|
beat_msg = DGPUBusMessage(
|
||||||
|
rid=rid,
|
||||||
|
nid='',
|
||||||
|
method='heartbeat'
|
||||||
|
)
|
||||||
|
heartbeats.clear()
|
||||||
|
heartbeats[rid] = int(time.time() * 1000)
|
||||||
|
await dgpu_bus.asend(beat_msg.SerializeToString())
|
||||||
|
logging.info('sent heartbeat')
|
||||||
|
|
||||||
|
async def dgpu_bus_streamer():
|
||||||
|
nonlocal wip_reqs, fin_reqs, heartbeats
|
||||||
while True:
|
while True:
|
||||||
raw_msg = await dgpu_bus.arecv()
|
raw_msg = await dgpu_bus.arecv()
|
||||||
logging.info(f'streamer got {len(raw_msg)} bytes.')
|
logging.info(f'streamer got {len(raw_msg)} bytes.')
|
||||||
|
@ -129,6 +145,12 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
|
|
||||||
rid = msg.rid
|
rid = msg.rid
|
||||||
|
|
||||||
|
if msg.method == 'heartbeat':
|
||||||
|
sent_time = heartbeats[rid]
|
||||||
|
delta = msg.params['time'] - sent_time
|
||||||
|
logging.info(f'got heartbeat reply from {msg.nid}, ping: {delta}')
|
||||||
|
continue
|
||||||
|
|
||||||
if rid not in wip_reqs:
|
if rid not in wip_reqs:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -372,7 +394,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
|
|
||||||
|
|
||||||
async with trio.open_nursery() as n:
|
async with trio.open_nursery() as n:
|
||||||
n.start_soon(dgpu_image_streamer)
|
n.start_soon(dgpu_bus_streamer)
|
||||||
|
n.start_soon(dgpu_heartbeat_service)
|
||||||
n.start_soon(request_service, n)
|
n.start_soon(request_service, n)
|
||||||
logging.info('starting rpc service')
|
logging.info('starting rpc service')
|
||||||
yield
|
yield
|
||||||
|
|
|
@ -2,9 +2,10 @@
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import io
|
import io
|
||||||
import time
|
import trio
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
|
import time
|
||||||
import random
|
import random
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
|
@ -13,7 +14,6 @@ from typing import List, Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
|
|
||||||
import trio
|
|
||||||
import pynng
|
import pynng
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -141,16 +141,13 @@ async def open_dgpu_node(
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
async with (
|
async with open_skynet_rpc(
|
||||||
open_skynet_rpc(
|
unique_id,
|
||||||
unique_id,
|
rpc_address=rpc_address,
|
||||||
rpc_address=rpc_address,
|
security=security,
|
||||||
security=security,
|
cert_name=cert_name,
|
||||||
cert_name=cert_name,
|
key_name=key_name
|
||||||
key_name=key_name
|
) as rpc_call:
|
||||||
) as rpc_call,
|
|
||||||
trio.open_nursery() as n
|
|
||||||
):
|
|
||||||
|
|
||||||
tls_config = None
|
tls_config = None
|
||||||
if security:
|
if security:
|
||||||
|
@ -185,14 +182,6 @@ async def open_dgpu_node(
|
||||||
own_cert_string=tls_cert_data,
|
own_cert_string=tls_cert_data,
|
||||||
ca_string=skynet_cert_data)
|
ca_string=skynet_cert_data)
|
||||||
|
|
||||||
async def heartbeat_service():
|
|
||||||
while True:
|
|
||||||
await trio.sleep(60)
|
|
||||||
before = time.time()
|
|
||||||
res = await rpc_call('heartbeat')
|
|
||||||
now = res.result['ok']['time']
|
|
||||||
logging.info(f'heartbeat ping: {int((now - before) * 1000)}')
|
|
||||||
|
|
||||||
logging.info(f'connecting to {dgpu_address}')
|
logging.info(f'connecting to {dgpu_address}')
|
||||||
with pynng.Bus0(recv_max_size=0) as dgpu_sock:
|
with pynng.Bus0(recv_max_size=0) as dgpu_sock:
|
||||||
dgpu_sock.tls_config = tls_config
|
dgpu_sock.tls_config = tls_config
|
||||||
|
@ -201,13 +190,26 @@ async def open_dgpu_node(
|
||||||
res = await rpc_call('dgpu_online')
|
res = await rpc_call('dgpu_online')
|
||||||
assert 'ok' in res.result
|
assert 'ok' in res.result
|
||||||
|
|
||||||
n.start_soon(heartbeat_service)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
req = DGPUBusMessage()
|
req = DGPUBusMessage()
|
||||||
req.ParseFromString(await dgpu_sock.arecv())
|
req.ParseFromString(await dgpu_sock.arecv())
|
||||||
|
|
||||||
|
if req.method == 'heartbeat':
|
||||||
|
rep = DGPUBusMessage(
|
||||||
|
rid=req.rid,
|
||||||
|
nid=unique_id,
|
||||||
|
method=req.method
|
||||||
|
)
|
||||||
|
rep.params.update({'time': int(time.time() * 1000)})
|
||||||
|
|
||||||
|
if security:
|
||||||
|
rep.auth.cert = cert_name
|
||||||
|
rep.auth.sig = sign_protobuf_msg(rep, tls_key)
|
||||||
|
|
||||||
|
await dgpu_sock.asend(rep.SerializeToString())
|
||||||
|
continue
|
||||||
|
|
||||||
if req.nid != unique_id:
|
if req.nid != unique_id:
|
||||||
logging.info(
|
logging.info(
|
||||||
f'witnessed msg {req.rid}, node involved: {req.nid}')
|
f'witnessed msg {req.rid}, node involved: {req.nid}')
|
||||||
|
@ -216,6 +218,7 @@ async def open_dgpu_node(
|
||||||
if security:
|
if security:
|
||||||
verify_protobuf_msg(req, skynet_cert)
|
verify_protobuf_msg(req, skynet_cert)
|
||||||
|
|
||||||
|
|
||||||
ack_resp = DGPUBusMessage(
|
ack_resp = DGPUBusMessage(
|
||||||
rid=req.rid,
|
rid=req.rid,
|
||||||
nid=req.nid
|
nid=req.nid
|
||||||
|
|
Loading…
Reference in New Issue