mirror of https://github.com/skygpu/skynet.git
				
				
				
			Add simple heartbeat mechanic
							parent
							
								
									83465aadaf
								
							
						
					
					
						commit
						1b42f288bc
					
				| 
						 | 
				
			
			@ -1,5 +1,6 @@
 | 
			
		|||
#!/usr/bin/python
 | 
			
		||||
 | 
			
		||||
import time
 | 
			
		||||
import json
 | 
			
		||||
import uuid
 | 
			
		||||
import zlib
 | 
			
		||||
| 
						 | 
				
			
			@ -333,22 +334,27 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
 | 
			
		|||
 | 
			
		||||
            result = {}
 | 
			
		||||
 | 
			
		||||
            if req.method == 'skynet_shutdown':
 | 
			
		||||
            match req.method:
 | 
			
		||||
                case 'skynet_shutdown':
 | 
			
		||||
                    raise SkynetShutdownRequested
 | 
			
		||||
 | 
			
		||||
            elif req.method == 'dgpu_online':
 | 
			
		||||
                case 'dgpu_online':
 | 
			
		||||
                    connect_node(req.uid)
 | 
			
		||||
 | 
			
		||||
            elif req.method == 'dgpu_offline':
 | 
			
		||||
                case 'dgpu_offline':
 | 
			
		||||
                    disconnect_node(req.uid)
 | 
			
		||||
 | 
			
		||||
            elif req.method == 'dgpu_workers':
 | 
			
		||||
                case 'dgpu_workers':
 | 
			
		||||
                    result = len(nodes)
 | 
			
		||||
 | 
			
		||||
            elif req.method == 'dgpu_next':
 | 
			
		||||
                case 'dgpu_next':
 | 
			
		||||
                    result = next_worker
 | 
			
		||||
 | 
			
		||||
            else:
 | 
			
		||||
                case 'heartbeat':
 | 
			
		||||
                    logging.info('beat')
 | 
			
		||||
                    result = {'time': time.time()}
 | 
			
		||||
 | 
			
		||||
                case _:
 | 
			
		||||
                    n.start_soon(
 | 
			
		||||
                        handle_user_request, ctx, req)
 | 
			
		||||
                    continue
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,10 +2,9 @@
 | 
			
		|||
 | 
			
		||||
import gc
 | 
			
		||||
import io
 | 
			
		||||
import trio
 | 
			
		||||
import time
 | 
			
		||||
import json
 | 
			
		||||
import uuid
 | 
			
		||||
import base64
 | 
			
		||||
import random
 | 
			
		||||
import logging
 | 
			
		||||
import traceback
 | 
			
		||||
| 
						 | 
				
			
			@ -14,6 +13,7 @@ from typing import List, Optional
 | 
			
		|||
from pathlib import Path
 | 
			
		||||
from contextlib import AsyncExitStack
 | 
			
		||||
 | 
			
		||||
import trio
 | 
			
		||||
import pynng
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -141,13 +141,16 @@ async def open_dgpu_node(
 | 
			
		|||
            torch.cuda.empty_cache()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    async with open_skynet_rpc(
 | 
			
		||||
    async with (
 | 
			
		||||
        open_skynet_rpc(
 | 
			
		||||
            unique_id,
 | 
			
		||||
            rpc_address=rpc_address,
 | 
			
		||||
            security=security,
 | 
			
		||||
            cert_name=cert_name,
 | 
			
		||||
            key_name=key_name
 | 
			
		||||
    ) as rpc_call:
 | 
			
		||||
        ) as rpc_call,
 | 
			
		||||
        trio.open_nursery() as n
 | 
			
		||||
    ):
 | 
			
		||||
 | 
			
		||||
        tls_config = None
 | 
			
		||||
        if security:
 | 
			
		||||
| 
						 | 
				
			
			@ -182,6 +185,14 @@ async def open_dgpu_node(
 | 
			
		|||
                own_cert_string=tls_cert_data,
 | 
			
		||||
                ca_string=skynet_cert_data)
 | 
			
		||||
 | 
			
		||||
        async def heartbeat_service():
 | 
			
		||||
            while True:
 | 
			
		||||
                await trio.sleep(60)
 | 
			
		||||
                before = time.time()
 | 
			
		||||
                res = await rpc_call('heartbeat')
 | 
			
		||||
                now = res.result['ok']['time']
 | 
			
		||||
                logging.info(f'heartbeat ping: {int((now - before) * 1000)}')
 | 
			
		||||
 | 
			
		||||
        logging.info(f'connecting to {dgpu_address}')
 | 
			
		||||
        with pynng.Bus0(recv_max_size=0) as dgpu_sock:
 | 
			
		||||
            dgpu_sock.tls_config = tls_config
 | 
			
		||||
| 
						 | 
				
			
			@ -190,6 +201,8 @@ async def open_dgpu_node(
 | 
			
		|||
            res = await rpc_call('dgpu_online')
 | 
			
		||||
            assert 'ok' in res.result
 | 
			
		||||
 | 
			
		||||
            n.start_soon(heartbeat_service)
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                while True:
 | 
			
		||||
                    req = DGPUBusMessage()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -306,3 +306,18 @@ async def test_dgpu_timeout_while_processing(dgpu_workers):
 | 
			
		|||
            ec, out = dgpu_workers[0].exec_run(
 | 
			
		||||
                ['pkill', '-TERM', '-f', 'skynet'])
 | 
			
		||||
            assert ec == 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize(
 | 
			
		||||
    'dgpu_workers', [(1, ['midj'])], indirect=True)
 | 
			
		||||
async def test_dgpu_heartbeat(dgpu_workers):
 | 
			
		||||
    '''
 | 
			
		||||
    '''
 | 
			
		||||
    async with open_skynet_rpc(
 | 
			
		||||
        'test-ctx',
 | 
			
		||||
        security=True,
 | 
			
		||||
        cert_name='whitelist/testing',
 | 
			
		||||
        key_name='testing'
 | 
			
		||||
    ) as test_rpc:
 | 
			
		||||
        await wait_for_dgpus(test_rpc, 1)
 | 
			
		||||
        await trio.sleep(120)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue