Remove rpc heartbeat service and add it to gpu bus

pull/4/head
Guillermo Rodriguez 2023-01-07 13:01:03 -03:00
parent 9c7293f84f
commit 10e77655c6
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
2 changed files with 51 additions and 25 deletions

View File

@ -48,6 +48,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
nodes = OrderedDict()
wip_reqs = {}
fin_reqs = {}
heartbeats = {}
next_worker: Optional[int] = None
security = len(tls_whitelist) > 0
@ -116,8 +117,23 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
return nid
async def dgpu_image_streamer():
nonlocal wip_reqs, fin_reqs
async def dgpu_heartbeat_service():
nonlocal heartbeats
while True:
await trio.sleep(60)
rid = uuid.uuid4().hex
beat_msg = DGPUBusMessage(
rid=rid,
nid='',
method='heartbeat'
)
heartbeats.clear()
heartbeats[rid] = int(time.time() * 1000)
await dgpu_bus.asend(beat_msg.SerializeToString())
logging.info('sent heartbeat')
async def dgpu_bus_streamer():
nonlocal wip_reqs, fin_reqs, heartbeats
while True:
raw_msg = await dgpu_bus.arecv()
logging.info(f'streamer got {len(raw_msg)} bytes.')
@ -129,6 +145,12 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
rid = msg.rid
if msg.method == 'heartbeat':
sent_time = heartbeats[rid]
delta = msg.params['time'] - sent_time
logging.info(f'got heartbeat reply from {msg.nid}, ping: {delta}')
continue
if rid not in wip_reqs:
continue
@ -372,7 +394,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
async with trio.open_nursery() as n:
n.start_soon(dgpu_image_streamer)
n.start_soon(dgpu_bus_streamer)
n.start_soon(dgpu_heartbeat_service)
n.start_soon(request_service, n)
logging.info('starting rpc service')
yield

View File

@ -2,9 +2,10 @@
import gc
import io
import time
import trio
import json
import uuid
import time
import random
import logging
import traceback
@ -13,7 +14,6 @@ from typing import List, Optional
from pathlib import Path
from contextlib import AsyncExitStack
import trio
import pynng
import torch
@ -141,16 +141,13 @@ async def open_dgpu_node(
torch.cuda.empty_cache()
async with (
open_skynet_rpc(
async with open_skynet_rpc(
unique_id,
rpc_address=rpc_address,
security=security,
cert_name=cert_name,
key_name=key_name
) as rpc_call,
trio.open_nursery() as n
):
) as rpc_call:
tls_config = None
if security:
@ -185,14 +182,6 @@ async def open_dgpu_node(
own_cert_string=tls_cert_data,
ca_string=skynet_cert_data)
async def heartbeat_service():
while True:
await trio.sleep(60)
before = time.time()
res = await rpc_call('heartbeat')
now = res.result['ok']['time']
logging.info(f'heartbeat ping: {int((now - before) * 1000)}')
logging.info(f'connecting to {dgpu_address}')
with pynng.Bus0(recv_max_size=0) as dgpu_sock:
dgpu_sock.tls_config = tls_config
@ -201,13 +190,26 @@ async def open_dgpu_node(
res = await rpc_call('dgpu_online')
assert 'ok' in res.result
n.start_soon(heartbeat_service)
try:
while True:
req = DGPUBusMessage()
req.ParseFromString(await dgpu_sock.arecv())
if req.method == 'heartbeat':
rep = DGPUBusMessage(
rid=req.rid,
nid=unique_id,
method=req.method
)
rep.params.update({'time': int(time.time() * 1000)})
if security:
rep.auth.cert = cert_name
rep.auth.sig = sign_protobuf_msg(rep, tls_key)
await dgpu_sock.asend(rep.SerializeToString())
continue
if req.nid != unique_id:
logging.info(
f'witnessed msg {req.rid}, node involved: {req.nid}')
@ -216,6 +218,7 @@ async def open_dgpu_node(
if security:
verify_protobuf_msg(req, skynet_cert)
ack_resp = DGPUBusMessage(
rid=req.rid,
nid=req.nid