mirror of https://github.com/skygpu/skynet.git
				
				
				
			Add reconnect mechanic to dgpu bus
							parent
							
								
									10e77655c6
								
							
						
					
					
						commit
						585d304f86
					
				| 
						 | 
					@ -12,7 +12,7 @@ import traceback
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from typing import List, Optional
 | 
					from typing import List, Optional
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
from contextlib import AsyncExitStack
 | 
					from contextlib import ExitStack
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import pynng
 | 
					import pynng
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
| 
						 | 
					@ -61,6 +61,51 @@ class DGPUComputeError(BaseException):
 | 
				
			||||||
    ...
 | 
					    ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ReconnectingBus:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, address: str, tls_config: Optional[TLSConfig]):
 | 
				
			||||||
 | 
					        self.address = address
 | 
				
			||||||
 | 
					        self.tls_config = tls_config
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._stack = ExitStack()
 | 
				
			||||||
 | 
					        self._sock = None
 | 
				
			||||||
 | 
					        self._closed = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def connect(self):
 | 
				
			||||||
 | 
					        self._sock = self._stack.enter_context(
 | 
				
			||||||
 | 
					            pynng.Bus0(recv_max_size=0))
 | 
				
			||||||
 | 
					        self._sock.tls_config = self.tls_config
 | 
				
			||||||
 | 
					        self._sock.dial(self.address)
 | 
				
			||||||
 | 
					        self._closed = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def arecv(self):
 | 
				
			||||||
 | 
					        while True:
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                return await self._sock.arecv()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            except pynng.exceptions.Closed:
 | 
				
			||||||
 | 
					                if self._closed:
 | 
				
			||||||
 | 
					                    raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def asend(self, msg):
 | 
				
			||||||
 | 
					        while True:
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                return await self._sock.asend(msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            except pynng.exceptions.Closed:
 | 
				
			||||||
 | 
					                if self._closed:
 | 
				
			||||||
 | 
					                    raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def close(self):
 | 
				
			||||||
 | 
					        self._stack.close()
 | 
				
			||||||
 | 
					        self._stack = ExitStack()
 | 
				
			||||||
 | 
					        self._closed = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def reconnect(self):
 | 
				
			||||||
 | 
					        self.close()
 | 
				
			||||||
 | 
					        self.connect()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def open_dgpu_node(
 | 
					async def open_dgpu_node(
 | 
				
			||||||
    cert_name: str,
 | 
					    cert_name: str,
 | 
				
			||||||
    unique_id: str,
 | 
					    unique_id: str,
 | 
				
			||||||
| 
						 | 
					@ -141,13 +186,16 @@ async def open_dgpu_node(
 | 
				
			||||||
            torch.cuda.empty_cache()
 | 
					            torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async with open_skynet_rpc(
 | 
					    async with (
 | 
				
			||||||
 | 
					        open_skynet_rpc(
 | 
				
			||||||
            unique_id,
 | 
					            unique_id,
 | 
				
			||||||
            rpc_address=rpc_address,
 | 
					            rpc_address=rpc_address,
 | 
				
			||||||
            security=security,
 | 
					            security=security,
 | 
				
			||||||
            cert_name=cert_name,
 | 
					            cert_name=cert_name,
 | 
				
			||||||
            key_name=key_name
 | 
					            key_name=key_name
 | 
				
			||||||
    ) as rpc_call:
 | 
					        ) as rpc_call,
 | 
				
			||||||
 | 
					        trio.open_nursery() as n
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        tls_config = None
 | 
					        tls_config = None
 | 
				
			||||||
        if security:
 | 
					        if security:
 | 
				
			||||||
| 
						 | 
					@ -183,9 +231,25 @@ async def open_dgpu_node(
 | 
				
			||||||
                ca_string=skynet_cert_data)
 | 
					                ca_string=skynet_cert_data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        logging.info(f'connecting to {dgpu_address}')
 | 
					        logging.info(f'connecting to {dgpu_address}')
 | 
				
			||||||
        with pynng.Bus0(recv_max_size=0) as dgpu_sock:
 | 
					
 | 
				
			||||||
            dgpu_sock.tls_config = tls_config
 | 
					        dgpu_bus = ReconnectingBus(dgpu_address, tls_config)
 | 
				
			||||||
            dgpu_sock.dial(dgpu_address)
 | 
					        dgpu_bus.connect()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        last_msg = time.time()
 | 
				
			||||||
 | 
					        async def connection_refresher(refresh_time: int = 120):
 | 
				
			||||||
 | 
					            nonlocal last_msg
 | 
				
			||||||
 | 
					            while True:
 | 
				
			||||||
 | 
					                now = time.time()
 | 
				
			||||||
 | 
					                last_msg_time_delta = now - last_msg
 | 
				
			||||||
 | 
					                logging.info(f'time since last msg: {last_msg_time_delta}')
 | 
				
			||||||
 | 
					                if last_msg_time_delta > refresh_time:
 | 
				
			||||||
 | 
					                    dgpu_bus.reconnect()
 | 
				
			||||||
 | 
					                    logging.info('reconnected!')
 | 
				
			||||||
 | 
					                    last_msg = now
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                await trio.sleep(refresh_time)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        n.start_soon(connection_refresher)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        res = await rpc_call('dgpu_online')
 | 
					        res = await rpc_call('dgpu_online')
 | 
				
			||||||
        assert 'ok' in res.result
 | 
					        assert 'ok' in res.result
 | 
				
			||||||
| 
						 | 
					@ -193,7 +257,8 @@ async def open_dgpu_node(
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            while True:
 | 
					            while True:
 | 
				
			||||||
                req = DGPUBusMessage()
 | 
					                req = DGPUBusMessage()
 | 
				
			||||||
                    req.ParseFromString(await dgpu_sock.arecv())
 | 
					                req.ParseFromString(await dgpu_bus.arecv())
 | 
				
			||||||
 | 
					                last_msg = time.time()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                if req.method == 'heartbeat':
 | 
					                if req.method == 'heartbeat':
 | 
				
			||||||
                    rep = DGPUBusMessage(
 | 
					                    rep = DGPUBusMessage(
 | 
				
			||||||
| 
						 | 
					@ -207,7 +272,8 @@ async def open_dgpu_node(
 | 
				
			||||||
                        rep.auth.cert = cert_name
 | 
					                        rep.auth.cert = cert_name
 | 
				
			||||||
                        rep.auth.sig = sign_protobuf_msg(rep, tls_key)
 | 
					                        rep.auth.sig = sign_protobuf_msg(rep, tls_key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        await dgpu_sock.asend(rep.SerializeToString())
 | 
					                    await dgpu_bus.asend(rep.SerializeToString())
 | 
				
			||||||
 | 
					                    logging.info('heartbeat reply')
 | 
				
			||||||
                    continue
 | 
					                    continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                if req.nid != unique_id:
 | 
					                if req.nid != unique_id:
 | 
				
			||||||
| 
						 | 
					@ -230,7 +296,7 @@ async def open_dgpu_node(
 | 
				
			||||||
                    ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
 | 
					                    ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                # send ack
 | 
					                # send ack
 | 
				
			||||||
                    await dgpu_sock.asend(ack_resp.SerializeToString())
 | 
					                await dgpu_bus.asend(ack_resp.SerializeToString())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                logging.info(f'sent ack, processing {req.rid}...')
 | 
					                logging.info(f'sent ack, processing {req.rid}...')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -266,14 +332,16 @@ async def open_dgpu_node(
 | 
				
			||||||
                # send final image
 | 
					                # send final image
 | 
				
			||||||
                logging.info('sending img back...')
 | 
					                logging.info('sending img back...')
 | 
				
			||||||
                raw_msg = img_resp.SerializeToString()
 | 
					                raw_msg = img_resp.SerializeToString()
 | 
				
			||||||
                    await dgpu_sock.asend(raw_msg)
 | 
					                await dgpu_bus.asend(raw_msg)
 | 
				
			||||||
                logging.info(f'sent {len(raw_msg)} bytes.')
 | 
					                logging.info(f'sent {len(raw_msg)} bytes.')
 | 
				
			||||||
                if img_resp.method == 'binary-reply':
 | 
					                if img_resp.method == 'binary-reply':
 | 
				
			||||||
                        await dgpu_sock.asend(img)
 | 
					                    await dgpu_bus.asend(img)
 | 
				
			||||||
                    logging.info(f'sent {len(img)} bytes.')
 | 
					                    logging.info(f'sent {len(img)} bytes.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except KeyboardInterrupt:
 | 
					        except KeyboardInterrupt:
 | 
				
			||||||
            logging.info('interrupt caught, stopping...')
 | 
					            logging.info('interrupt caught, stopping...')
 | 
				
			||||||
 | 
					            n.cancel_scope.cancel()
 | 
				
			||||||
 | 
					            dgpu_bus.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        finally:
 | 
					        finally:
 | 
				
			||||||
            res = await rpc_call('dgpu_offline')
 | 
					            res = await rpc_call('dgpu_offline')
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue