mirror of https://github.com/skygpu/skynet.git
				
				
				
			Add simple heartbeat mechanic
							parent
							
								
									83465aadaf
								
							
						
					
					
						commit
						1b42f288bc
					
				| 
						 | 
					@ -1,5 +1,6 @@
 | 
				
			||||||
#!/usr/bin/python
 | 
					#!/usr/bin/python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
import uuid
 | 
					import uuid
 | 
				
			||||||
import zlib
 | 
					import zlib
 | 
				
			||||||
| 
						 | 
					@ -333,22 +334,27 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            result = {}
 | 
					            result = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if req.method == 'skynet_shutdown':
 | 
					            match req.method:
 | 
				
			||||||
 | 
					                case 'skynet_shutdown':
 | 
				
			||||||
                    raise SkynetShutdownRequested
 | 
					                    raise SkynetShutdownRequested
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            elif req.method == 'dgpu_online':
 | 
					                case 'dgpu_online':
 | 
				
			||||||
                    connect_node(req.uid)
 | 
					                    connect_node(req.uid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            elif req.method == 'dgpu_offline':
 | 
					                case 'dgpu_offline':
 | 
				
			||||||
                    disconnect_node(req.uid)
 | 
					                    disconnect_node(req.uid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            elif req.method == 'dgpu_workers':
 | 
					                case 'dgpu_workers':
 | 
				
			||||||
                    result = len(nodes)
 | 
					                    result = len(nodes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            elif req.method == 'dgpu_next':
 | 
					                case 'dgpu_next':
 | 
				
			||||||
                    result = next_worker
 | 
					                    result = next_worker
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            else:
 | 
					                case 'heartbeat':
 | 
				
			||||||
 | 
					                    logging.info('beat')
 | 
				
			||||||
 | 
					                    result = {'time': time.time()}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                case _:
 | 
				
			||||||
                    n.start_soon(
 | 
					                    n.start_soon(
 | 
				
			||||||
                        handle_user_request, ctx, req)
 | 
					                        handle_user_request, ctx, req)
 | 
				
			||||||
                    continue
 | 
					                    continue
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,10 +2,9 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import gc
 | 
					import gc
 | 
				
			||||||
import io
 | 
					import io
 | 
				
			||||||
import trio
 | 
					import time
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
import uuid
 | 
					import uuid
 | 
				
			||||||
import base64
 | 
					 | 
				
			||||||
import random
 | 
					import random
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
import traceback
 | 
					import traceback
 | 
				
			||||||
| 
						 | 
					@ -14,6 +13,7 @@ from typing import List, Optional
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
from contextlib import AsyncExitStack
 | 
					from contextlib import AsyncExitStack
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import trio
 | 
				
			||||||
import pynng
 | 
					import pynng
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -141,13 +141,16 @@ async def open_dgpu_node(
 | 
				
			||||||
            torch.cuda.empty_cache()
 | 
					            torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async with open_skynet_rpc(
 | 
					    async with (
 | 
				
			||||||
 | 
					        open_skynet_rpc(
 | 
				
			||||||
            unique_id,
 | 
					            unique_id,
 | 
				
			||||||
            rpc_address=rpc_address,
 | 
					            rpc_address=rpc_address,
 | 
				
			||||||
            security=security,
 | 
					            security=security,
 | 
				
			||||||
            cert_name=cert_name,
 | 
					            cert_name=cert_name,
 | 
				
			||||||
            key_name=key_name
 | 
					            key_name=key_name
 | 
				
			||||||
    ) as rpc_call:
 | 
					        ) as rpc_call,
 | 
				
			||||||
 | 
					        trio.open_nursery() as n
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        tls_config = None
 | 
					        tls_config = None
 | 
				
			||||||
        if security:
 | 
					        if security:
 | 
				
			||||||
| 
						 | 
					@ -182,6 +185,14 @@ async def open_dgpu_node(
 | 
				
			||||||
                own_cert_string=tls_cert_data,
 | 
					                own_cert_string=tls_cert_data,
 | 
				
			||||||
                ca_string=skynet_cert_data)
 | 
					                ca_string=skynet_cert_data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        async def heartbeat_service():
 | 
				
			||||||
 | 
					            while True:
 | 
				
			||||||
 | 
					                await trio.sleep(60)
 | 
				
			||||||
 | 
					                before = time.time()
 | 
				
			||||||
 | 
					                res = await rpc_call('heartbeat')
 | 
				
			||||||
 | 
					                now = res.result['ok']['time']
 | 
				
			||||||
 | 
					                logging.info(f'heartbeat ping: {int((now - before) * 1000)}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        logging.info(f'connecting to {dgpu_address}')
 | 
					        logging.info(f'connecting to {dgpu_address}')
 | 
				
			||||||
        with pynng.Bus0(recv_max_size=0) as dgpu_sock:
 | 
					        with pynng.Bus0(recv_max_size=0) as dgpu_sock:
 | 
				
			||||||
            dgpu_sock.tls_config = tls_config
 | 
					            dgpu_sock.tls_config = tls_config
 | 
				
			||||||
| 
						 | 
					@ -190,6 +201,8 @@ async def open_dgpu_node(
 | 
				
			||||||
            res = await rpc_call('dgpu_online')
 | 
					            res = await rpc_call('dgpu_online')
 | 
				
			||||||
            assert 'ok' in res.result
 | 
					            assert 'ok' in res.result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            n.start_soon(heartbeat_service)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                while True:
 | 
					                while True:
 | 
				
			||||||
                    req = DGPUBusMessage()
 | 
					                    req = DGPUBusMessage()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -306,3 +306,18 @@ async def test_dgpu_timeout_while_processing(dgpu_workers):
 | 
				
			||||||
            ec, out = dgpu_workers[0].exec_run(
 | 
					            ec, out = dgpu_workers[0].exec_run(
 | 
				
			||||||
                ['pkill', '-TERM', '-f', 'skynet'])
 | 
					                ['pkill', '-TERM', '-f', 'skynet'])
 | 
				
			||||||
            assert ec == 0
 | 
					            assert ec == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    'dgpu_workers', [(1, ['midj'])], indirect=True)
 | 
				
			||||||
 | 
					async def test_dgpu_heartbeat(dgpu_workers):
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
 | 
					    async with open_skynet_rpc(
 | 
				
			||||||
 | 
					        'test-ctx',
 | 
				
			||||||
 | 
					        security=True,
 | 
				
			||||||
 | 
					        cert_name='whitelist/testing',
 | 
				
			||||||
 | 
					        key_name='testing'
 | 
				
			||||||
 | 
					    ) as test_rpc:
 | 
				
			||||||
 | 
					        await wait_for_dgpus(test_rpc, 1)
 | 
				
			||||||
 | 
					        await trio.sleep(120)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue