Remove rpc heartbeat service and add it to gpu bus

pull/4/head
Guillermo Rodriguez 2023-01-07 13:01:03 -03:00
parent 9c7293f84f
commit 10e77655c6
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
2 changed files with 51 additions and 25 deletions

View File

@ -48,6 +48,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
nodes = OrderedDict() nodes = OrderedDict()
wip_reqs = {} wip_reqs = {}
fin_reqs = {} fin_reqs = {}
heartbeats = {}
next_worker: Optional[int] = None next_worker: Optional[int] = None
security = len(tls_whitelist) > 0 security = len(tls_whitelist) > 0
@ -116,8 +117,23 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
return nid return nid
async def dgpu_image_streamer(): async def dgpu_heartbeat_service():
nonlocal wip_reqs, fin_reqs nonlocal heartbeats
while True:
await trio.sleep(60)
rid = uuid.uuid4().hex
beat_msg = DGPUBusMessage(
rid=rid,
nid='',
method='heartbeat'
)
heartbeats.clear()
heartbeats[rid] = int(time.time() * 1000)
await dgpu_bus.asend(beat_msg.SerializeToString())
logging.info('sent heartbeat')
async def dgpu_bus_streamer():
nonlocal wip_reqs, fin_reqs, heartbeats
while True: while True:
raw_msg = await dgpu_bus.arecv() raw_msg = await dgpu_bus.arecv()
logging.info(f'streamer got {len(raw_msg)} bytes.') logging.info(f'streamer got {len(raw_msg)} bytes.')
@ -129,6 +145,12 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
rid = msg.rid rid = msg.rid
if msg.method == 'heartbeat':
sent_time = heartbeats[rid]
delta = msg.params['time'] - sent_time
logging.info(f'got heartbeat reply from {msg.nid}, ping: {delta}')
continue
if rid not in wip_reqs: if rid not in wip_reqs:
continue continue
@ -372,7 +394,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
async with trio.open_nursery() as n: async with trio.open_nursery() as n:
n.start_soon(dgpu_image_streamer) n.start_soon(dgpu_bus_streamer)
n.start_soon(dgpu_heartbeat_service)
n.start_soon(request_service, n) n.start_soon(request_service, n)
logging.info('starting rpc service') logging.info('starting rpc service')
yield yield

View File

@ -2,9 +2,10 @@
import gc import gc
import io import io
import time import trio
import json import json
import uuid import uuid
import time
import random import random
import logging import logging
import traceback import traceback
@ -13,7 +14,6 @@ from typing import List, Optional
from pathlib import Path from pathlib import Path
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
import trio
import pynng import pynng
import torch import torch
@ -141,16 +141,13 @@ async def open_dgpu_node(
torch.cuda.empty_cache() torch.cuda.empty_cache()
async with ( async with open_skynet_rpc(
open_skynet_rpc(
unique_id, unique_id,
rpc_address=rpc_address, rpc_address=rpc_address,
security=security, security=security,
cert_name=cert_name, cert_name=cert_name,
key_name=key_name key_name=key_name
) as rpc_call, ) as rpc_call:
trio.open_nursery() as n
):
tls_config = None tls_config = None
if security: if security:
@ -185,14 +182,6 @@ async def open_dgpu_node(
own_cert_string=tls_cert_data, own_cert_string=tls_cert_data,
ca_string=skynet_cert_data) ca_string=skynet_cert_data)
async def heartbeat_service():
while True:
await trio.sleep(60)
before = time.time()
res = await rpc_call('heartbeat')
now = res.result['ok']['time']
logging.info(f'heartbeat ping: {int((now - before) * 1000)}')
logging.info(f'connecting to {dgpu_address}') logging.info(f'connecting to {dgpu_address}')
with pynng.Bus0(recv_max_size=0) as dgpu_sock: with pynng.Bus0(recv_max_size=0) as dgpu_sock:
dgpu_sock.tls_config = tls_config dgpu_sock.tls_config = tls_config
@ -201,13 +190,26 @@ async def open_dgpu_node(
res = await rpc_call('dgpu_online') res = await rpc_call('dgpu_online')
assert 'ok' in res.result assert 'ok' in res.result
n.start_soon(heartbeat_service)
try: try:
while True: while True:
req = DGPUBusMessage() req = DGPUBusMessage()
req.ParseFromString(await dgpu_sock.arecv()) req.ParseFromString(await dgpu_sock.arecv())
if req.method == 'heartbeat':
rep = DGPUBusMessage(
rid=req.rid,
nid=unique_id,
method=req.method
)
rep.params.update({'time': int(time.time() * 1000)})
if security:
rep.auth.cert = cert_name
rep.auth.sig = sign_protobuf_msg(rep, tls_key)
await dgpu_sock.asend(rep.SerializeToString())
continue
if req.nid != unique_id: if req.nid != unique_id:
logging.info( logging.info(
f'witnessed msg {req.rid}, node involved: {req.nid}') f'witnessed msg {req.rid}, node involved: {req.nid}')
@ -216,6 +218,7 @@ async def open_dgpu_node(
if security: if security:
verify_protobuf_msg(req, skynet_cert) verify_protobuf_msg(req, skynet_cert)
ack_resp = DGPUBusMessage( ack_resp = DGPUBusMessage(
rid=req.rid, rid=req.rid,
nid=req.nid nid=req.nid