mirror of https://github.com/skygpu/skynet.git
				
				
				
			Add reconnect mechanic to dgpu bus
							parent
							
								
									10e77655c6
								
							
						
					
					
						commit
						585d304f86
					
				| 
						 | 
				
			
			@ -12,7 +12,7 @@ import traceback
 | 
			
		|||
 | 
			
		||||
from typing import List, Optional
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from contextlib import AsyncExitStack
 | 
			
		||||
from contextlib import ExitStack
 | 
			
		||||
 | 
			
		||||
import pynng
 | 
			
		||||
import torch
 | 
			
		||||
| 
						 | 
				
			
			@ -61,6 +61,51 @@ class DGPUComputeError(BaseException):
 | 
			
		|||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReconnectingBus:
 | 
			
		||||
 | 
			
		||||
    def __init__(self, address: str, tls_config: Optional[TLSConfig]):
 | 
			
		||||
        self.address = address
 | 
			
		||||
        self.tls_config = tls_config
 | 
			
		||||
 | 
			
		||||
        self._stack = ExitStack()
 | 
			
		||||
        self._sock = None
 | 
			
		||||
        self._closed = True
 | 
			
		||||
 | 
			
		||||
    def connect(self):
 | 
			
		||||
        self._sock = self._stack.enter_context(
 | 
			
		||||
            pynng.Bus0(recv_max_size=0))
 | 
			
		||||
        self._sock.tls_config = self.tls_config
 | 
			
		||||
        self._sock.dial(self.address)
 | 
			
		||||
        self._closed = False
 | 
			
		||||
 | 
			
		||||
    async def arecv(self):
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
                return await self._sock.arecv()
 | 
			
		||||
 | 
			
		||||
            except pynng.exceptions.Closed:
 | 
			
		||||
                if self._closed:
 | 
			
		||||
                    raise
 | 
			
		||||
 | 
			
		||||
    async def asend(self, msg):
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
                return await self._sock.asend(msg)
 | 
			
		||||
 | 
			
		||||
            except pynng.exceptions.Closed:
 | 
			
		||||
                if self._closed:
 | 
			
		||||
                    raise
 | 
			
		||||
 | 
			
		||||
    def close(self):
 | 
			
		||||
        self._stack.close()
 | 
			
		||||
        self._stack = ExitStack()
 | 
			
		||||
        self._closed = True
 | 
			
		||||
 | 
			
		||||
    def reconnect(self):
 | 
			
		||||
        self.close()
 | 
			
		||||
        self.connect()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def open_dgpu_node(
 | 
			
		||||
    cert_name: str,
 | 
			
		||||
    unique_id: str,
 | 
			
		||||
| 
						 | 
				
			
			@ -141,13 +186,16 @@ async def open_dgpu_node(
 | 
			
		|||
            torch.cuda.empty_cache()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    async with open_skynet_rpc(
 | 
			
		||||
    async with (
 | 
			
		||||
        open_skynet_rpc(
 | 
			
		||||
            unique_id,
 | 
			
		||||
            rpc_address=rpc_address,
 | 
			
		||||
            security=security,
 | 
			
		||||
            cert_name=cert_name,
 | 
			
		||||
            key_name=key_name
 | 
			
		||||
    ) as rpc_call:
 | 
			
		||||
        ) as rpc_call,
 | 
			
		||||
        trio.open_nursery() as n
 | 
			
		||||
    ):
 | 
			
		||||
 | 
			
		||||
        tls_config = None
 | 
			
		||||
        if security:
 | 
			
		||||
| 
						 | 
				
			
			@ -183,9 +231,25 @@ async def open_dgpu_node(
 | 
			
		|||
                ca_string=skynet_cert_data)
 | 
			
		||||
 | 
			
		||||
        logging.info(f'connecting to {dgpu_address}')
 | 
			
		||||
        with pynng.Bus0(recv_max_size=0) as dgpu_sock:
 | 
			
		||||
            dgpu_sock.tls_config = tls_config
 | 
			
		||||
            dgpu_sock.dial(dgpu_address)
 | 
			
		||||
 | 
			
		||||
        dgpu_bus = ReconnectingBus(dgpu_address, tls_config)
 | 
			
		||||
        dgpu_bus.connect()
 | 
			
		||||
 | 
			
		||||
        last_msg = time.time()
 | 
			
		||||
        async def connection_refresher(refresh_time: int = 120):
 | 
			
		||||
            nonlocal last_msg
 | 
			
		||||
            while True:
 | 
			
		||||
                now = time.time()
 | 
			
		||||
                last_msg_time_delta = now - last_msg
 | 
			
		||||
                logging.info(f'time since last msg: {last_msg_time_delta}')
 | 
			
		||||
                if last_msg_time_delta > refresh_time:
 | 
			
		||||
                    dgpu_bus.reconnect()
 | 
			
		||||
                    logging.info('reconnected!')
 | 
			
		||||
                    last_msg = now
 | 
			
		||||
 | 
			
		||||
                await trio.sleep(refresh_time)
 | 
			
		||||
 | 
			
		||||
        n.start_soon(connection_refresher)
 | 
			
		||||
 | 
			
		||||
        res = await rpc_call('dgpu_online')
 | 
			
		||||
        assert 'ok' in res.result
 | 
			
		||||
| 
						 | 
				
			
			@ -193,7 +257,8 @@ async def open_dgpu_node(
 | 
			
		|||
        try:
 | 
			
		||||
            while True:
 | 
			
		||||
                req = DGPUBusMessage()
 | 
			
		||||
                    req.ParseFromString(await dgpu_sock.arecv())
 | 
			
		||||
                req.ParseFromString(await dgpu_bus.arecv())
 | 
			
		||||
                last_msg = time.time()
 | 
			
		||||
 | 
			
		||||
                if req.method == 'heartbeat':
 | 
			
		||||
                    rep = DGPUBusMessage(
 | 
			
		||||
| 
						 | 
				
			
			@ -207,7 +272,8 @@ async def open_dgpu_node(
 | 
			
		|||
                        rep.auth.cert = cert_name
 | 
			
		||||
                        rep.auth.sig = sign_protobuf_msg(rep, tls_key)
 | 
			
		||||
 | 
			
		||||
                        await dgpu_sock.asend(rep.SerializeToString())
 | 
			
		||||
                    await dgpu_bus.asend(rep.SerializeToString())
 | 
			
		||||
                    logging.info('heartbeat reply')
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
                if req.nid != unique_id:
 | 
			
		||||
| 
						 | 
				
			
			@ -230,7 +296,7 @@ async def open_dgpu_node(
 | 
			
		|||
                    ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
 | 
			
		||||
 | 
			
		||||
                # send ack
 | 
			
		||||
                    await dgpu_sock.asend(ack_resp.SerializeToString())
 | 
			
		||||
                await dgpu_bus.asend(ack_resp.SerializeToString())
 | 
			
		||||
 | 
			
		||||
                logging.info(f'sent ack, processing {req.rid}...')
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -266,14 +332,16 @@ async def open_dgpu_node(
 | 
			
		|||
                # send final image
 | 
			
		||||
                logging.info('sending img back...')
 | 
			
		||||
                raw_msg = img_resp.SerializeToString()
 | 
			
		||||
                    await dgpu_sock.asend(raw_msg)
 | 
			
		||||
                await dgpu_bus.asend(raw_msg)
 | 
			
		||||
                logging.info(f'sent {len(raw_msg)} bytes.')
 | 
			
		||||
                if img_resp.method == 'binary-reply':
 | 
			
		||||
                        await dgpu_sock.asend(img)
 | 
			
		||||
                    await dgpu_bus.asend(img)
 | 
			
		||||
                    logging.info(f'sent {len(img)} bytes.')
 | 
			
		||||
 | 
			
		||||
        except KeyboardInterrupt:
 | 
			
		||||
            logging.info('interrupt caught, stopping...')
 | 
			
		||||
            n.cancel_scope.cancel()
 | 
			
		||||
            dgpu_bus.close()
 | 
			
		||||
 | 
			
		||||
        finally:
 | 
			
		||||
            res = await rpc_call('dgpu_offline')
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue