mirror of https://github.com/skygpu/skynet.git
				
				
				
			Add reconnect mechanic to dgpu bus
							parent
							
								
									10e77655c6
								
							
						
					
					
						commit
						585d304f86
					
				
							
								
								
									
										238
									
								
								skynet/dgpu.py
								
								
								
								
							
							
						
						
									
										238
									
								
								skynet/dgpu.py
								
								
								
								
							| 
						 | 
				
			
			@ -12,7 +12,7 @@ import traceback
 | 
			
		|||
 | 
			
		||||
from typing import List, Optional
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from contextlib import AsyncExitStack
 | 
			
		||||
from contextlib import ExitStack
 | 
			
		||||
 | 
			
		||||
import pynng
 | 
			
		||||
import torch
 | 
			
		||||
| 
						 | 
				
			
			@ -61,6 +61,51 @@ class DGPUComputeError(BaseException):
 | 
			
		|||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReconnectingBus:
 | 
			
		||||
 | 
			
		||||
    def __init__(self, address: str, tls_config: Optional[TLSConfig]):
 | 
			
		||||
        self.address = address
 | 
			
		||||
        self.tls_config = tls_config
 | 
			
		||||
 | 
			
		||||
        self._stack = ExitStack()
 | 
			
		||||
        self._sock = None
 | 
			
		||||
        self._closed = True
 | 
			
		||||
 | 
			
		||||
    def connect(self):
 | 
			
		||||
        self._sock = self._stack.enter_context(
 | 
			
		||||
            pynng.Bus0(recv_max_size=0))
 | 
			
		||||
        self._sock.tls_config = self.tls_config
 | 
			
		||||
        self._sock.dial(self.address)
 | 
			
		||||
        self._closed = False
 | 
			
		||||
 | 
			
		||||
    async def arecv(self):
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
                return await self._sock.arecv()
 | 
			
		||||
 | 
			
		||||
            except pynng.exceptions.Closed:
 | 
			
		||||
                if self._closed:
 | 
			
		||||
                    raise
 | 
			
		||||
 | 
			
		||||
    async def asend(self, msg):
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
                return await self._sock.asend(msg)
 | 
			
		||||
 | 
			
		||||
            except pynng.exceptions.Closed:
 | 
			
		||||
                if self._closed:
 | 
			
		||||
                    raise
 | 
			
		||||
 | 
			
		||||
    def close(self):
 | 
			
		||||
        self._stack.close()
 | 
			
		||||
        self._stack = ExitStack()
 | 
			
		||||
        self._closed = True
 | 
			
		||||
 | 
			
		||||
    def reconnect(self):
 | 
			
		||||
        self.close()
 | 
			
		||||
        self.connect()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def open_dgpu_node(
 | 
			
		||||
    cert_name: str,
 | 
			
		||||
    unique_id: str,
 | 
			
		||||
| 
						 | 
				
			
			@ -141,13 +186,16 @@ async def open_dgpu_node(
 | 
			
		|||
            torch.cuda.empty_cache()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    async with open_skynet_rpc(
 | 
			
		||||
        unique_id,
 | 
			
		||||
        rpc_address=rpc_address,
 | 
			
		||||
        security=security,
 | 
			
		||||
        cert_name=cert_name,
 | 
			
		||||
        key_name=key_name
 | 
			
		||||
    ) as rpc_call:
 | 
			
		||||
    async with (
 | 
			
		||||
        open_skynet_rpc(
 | 
			
		||||
            unique_id,
 | 
			
		||||
            rpc_address=rpc_address,
 | 
			
		||||
            security=security,
 | 
			
		||||
            cert_name=cert_name,
 | 
			
		||||
            key_name=key_name
 | 
			
		||||
        ) as rpc_call,
 | 
			
		||||
        trio.open_nursery() as n
 | 
			
		||||
    ):
 | 
			
		||||
 | 
			
		||||
        tls_config = None
 | 
			
		||||
        if security:
 | 
			
		||||
| 
						 | 
				
			
			@ -183,98 +231,118 @@ async def open_dgpu_node(
 | 
			
		|||
                ca_string=skynet_cert_data)
 | 
			
		||||
 | 
			
		||||
        logging.info(f'connecting to {dgpu_address}')
 | 
			
		||||
        with pynng.Bus0(recv_max_size=0) as dgpu_sock:
 | 
			
		||||
            dgpu_sock.tls_config = tls_config
 | 
			
		||||
            dgpu_sock.dial(dgpu_address)
 | 
			
		||||
 | 
			
		||||
            res = await rpc_call('dgpu_online')
 | 
			
		||||
            assert 'ok' in res.result
 | 
			
		||||
        dgpu_bus = ReconnectingBus(dgpu_address, tls_config)
 | 
			
		||||
        dgpu_bus.connect()
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                while True:
 | 
			
		||||
                    req = DGPUBusMessage()
 | 
			
		||||
                    req.ParseFromString(await dgpu_sock.arecv())
 | 
			
		||||
        last_msg = time.time()
 | 
			
		||||
        async def connection_refresher(refresh_time: int = 120):
 | 
			
		||||
            nonlocal last_msg
 | 
			
		||||
            while True:
 | 
			
		||||
                now = time.time()
 | 
			
		||||
                last_msg_time_delta = now - last_msg
 | 
			
		||||
                logging.info(f'time since last msg: {last_msg_time_delta}')
 | 
			
		||||
                if last_msg_time_delta > refresh_time:
 | 
			
		||||
                    dgpu_bus.reconnect()
 | 
			
		||||
                    logging.info('reconnected!')
 | 
			
		||||
                    last_msg = now
 | 
			
		||||
 | 
			
		||||
                    if req.method == 'heartbeat':
 | 
			
		||||
                        rep = DGPUBusMessage(
 | 
			
		||||
                            rid=req.rid,
 | 
			
		||||
                            nid=unique_id,
 | 
			
		||||
                            method=req.method
 | 
			
		||||
                        )
 | 
			
		||||
                        rep.params.update({'time': int(time.time() * 1000)})
 | 
			
		||||
                await trio.sleep(refresh_time)
 | 
			
		||||
 | 
			
		||||
                        if security:
 | 
			
		||||
                            rep.auth.cert = cert_name
 | 
			
		||||
                            rep.auth.sig = sign_protobuf_msg(rep, tls_key)
 | 
			
		||||
        n.start_soon(connection_refresher)
 | 
			
		||||
 | 
			
		||||
                        await dgpu_sock.asend(rep.SerializeToString())
 | 
			
		||||
                        continue
 | 
			
		||||
        res = await rpc_call('dgpu_online')
 | 
			
		||||
        assert 'ok' in res.result
 | 
			
		||||
 | 
			
		||||
                    if req.nid != unique_id:
 | 
			
		||||
                        logging.info(
 | 
			
		||||
                            f'witnessed msg {req.rid}, node involved: {req.nid}')
 | 
			
		||||
                        continue
 | 
			
		||||
        try:
 | 
			
		||||
            while True:
 | 
			
		||||
                req = DGPUBusMessage()
 | 
			
		||||
                req.ParseFromString(await dgpu_bus.arecv())
 | 
			
		||||
                last_msg = time.time()
 | 
			
		||||
 | 
			
		||||
                if req.method == 'heartbeat':
 | 
			
		||||
                    rep = DGPUBusMessage(
 | 
			
		||||
                        rid=req.rid,
 | 
			
		||||
                        nid=unique_id,
 | 
			
		||||
                        method=req.method
 | 
			
		||||
                    )
 | 
			
		||||
                    rep.params.update({'time': int(time.time() * 1000)})
 | 
			
		||||
 | 
			
		||||
                    if security:
 | 
			
		||||
                        verify_protobuf_msg(req, skynet_cert)
 | 
			
		||||
                        rep.auth.cert = cert_name
 | 
			
		||||
                        rep.auth.sig = sign_protobuf_msg(rep, tls_key)
 | 
			
		||||
 | 
			
		||||
                    await dgpu_bus.asend(rep.SerializeToString())
 | 
			
		||||
                    logging.info('heartbeat reply')
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
                if req.nid != unique_id:
 | 
			
		||||
                    logging.info(
 | 
			
		||||
                        f'witnessed msg {req.rid}, node involved: {req.nid}')
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
                if security:
 | 
			
		||||
                    verify_protobuf_msg(req, skynet_cert)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
                    ack_resp = DGPUBusMessage(
 | 
			
		||||
                ack_resp = DGPUBusMessage(
 | 
			
		||||
                    rid=req.rid,
 | 
			
		||||
                    nid=req.nid
 | 
			
		||||
                )
 | 
			
		||||
                ack_resp.params.update({'ack': {}})
 | 
			
		||||
 | 
			
		||||
                if security:
 | 
			
		||||
                    ack_resp.auth.cert = cert_name
 | 
			
		||||
                    ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
 | 
			
		||||
 | 
			
		||||
                # send ack
 | 
			
		||||
                await dgpu_bus.asend(ack_resp.SerializeToString())
 | 
			
		||||
 | 
			
		||||
                logging.info(f'sent ack, processing {req.rid}...')
 | 
			
		||||
 | 
			
		||||
                try:
 | 
			
		||||
                    img_req = Text2ImageParameters(**req.params)
 | 
			
		||||
                    if not img_req.seed:
 | 
			
		||||
                        img_req.seed = random.randint(0, 2 ** 64)
 | 
			
		||||
 | 
			
		||||
                    img = await gpu_compute_one(img_req)
 | 
			
		||||
                    img_resp = DGPUBusMessage(
 | 
			
		||||
                        rid=req.rid,
 | 
			
		||||
                        nid=req.nid,
 | 
			
		||||
                        method='binary-reply'
 | 
			
		||||
                    )
 | 
			
		||||
                    img_resp.params.update({
 | 
			
		||||
                        'len': len(img),
 | 
			
		||||
                        'meta': img_req.to_dict()
 | 
			
		||||
                    })
 | 
			
		||||
 | 
			
		||||
                except DGPUComputeError as e:
 | 
			
		||||
                    traceback.print_exception(type(e), e, e.__traceback__)
 | 
			
		||||
                    img_resp = DGPUBusMessage(
 | 
			
		||||
                        rid=req.rid,
 | 
			
		||||
                        nid=req.nid
 | 
			
		||||
                    )
 | 
			
		||||
                    ack_resp.params.update({'ack': {}})
 | 
			
		||||
 | 
			
		||||
                    if security:
 | 
			
		||||
                        ack_resp.auth.cert = cert_name
 | 
			
		||||
                        ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
 | 
			
		||||
 | 
			
		||||
                    # send ack
 | 
			
		||||
                    await dgpu_sock.asend(ack_resp.SerializeToString())
 | 
			
		||||
 | 
			
		||||
                    logging.info(f'sent ack, processing {req.rid}...')
 | 
			
		||||
 | 
			
		||||
                    try:
 | 
			
		||||
                        img_req = Text2ImageParameters(**req.params)
 | 
			
		||||
                        if not img_req.seed:
 | 
			
		||||
                            img_req.seed = random.randint(0, 2 ** 64)
 | 
			
		||||
 | 
			
		||||
                        img = await gpu_compute_one(img_req)
 | 
			
		||||
                        img_resp = DGPUBusMessage(
 | 
			
		||||
                            rid=req.rid,
 | 
			
		||||
                            nid=req.nid,
 | 
			
		||||
                            method='binary-reply'
 | 
			
		||||
                        )
 | 
			
		||||
                        img_resp.params.update({
 | 
			
		||||
                            'len': len(img),
 | 
			
		||||
                            'meta': img_req.to_dict()
 | 
			
		||||
                        })
 | 
			
		||||
 | 
			
		||||
                    except DGPUComputeError as e:
 | 
			
		||||
                        traceback.print_exception(type(e), e, e.__traceback__)
 | 
			
		||||
                        img_resp = DGPUBusMessage(
 | 
			
		||||
                            rid=req.rid,
 | 
			
		||||
                            nid=req.nid
 | 
			
		||||
                        )
 | 
			
		||||
                        img_resp.params.update({'error': str(e)})
 | 
			
		||||
                    img_resp.params.update({'error': str(e)})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
                    if security:
 | 
			
		||||
                        img_resp.auth.cert = cert_name
 | 
			
		||||
                        img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
 | 
			
		||||
                if security:
 | 
			
		||||
                    img_resp.auth.cert = cert_name
 | 
			
		||||
                    img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
 | 
			
		||||
 | 
			
		||||
                    # send final image
 | 
			
		||||
                    logging.info('sending img back...')
 | 
			
		||||
                    raw_msg = img_resp.SerializeToString()
 | 
			
		||||
                    await dgpu_sock.asend(raw_msg)
 | 
			
		||||
                    logging.info(f'sent {len(raw_msg)} bytes.')
 | 
			
		||||
                    if img_resp.method == 'binary-reply':
 | 
			
		||||
                        await dgpu_sock.asend(img)
 | 
			
		||||
                        logging.info(f'sent {len(img)} bytes.')
 | 
			
		||||
                # send final image
 | 
			
		||||
                logging.info('sending img back...')
 | 
			
		||||
                raw_msg = img_resp.SerializeToString()
 | 
			
		||||
                await dgpu_bus.asend(raw_msg)
 | 
			
		||||
                logging.info(f'sent {len(raw_msg)} bytes.')
 | 
			
		||||
                if img_resp.method == 'binary-reply':
 | 
			
		||||
                    await dgpu_bus.asend(img)
 | 
			
		||||
                    logging.info(f'sent {len(img)} bytes.')
 | 
			
		||||
 | 
			
		||||
            except KeyboardInterrupt:
 | 
			
		||||
                logging.info('interrupt caught, stopping...')
 | 
			
		||||
        except KeyboardInterrupt:
 | 
			
		||||
            logging.info('interrupt caught, stopping...')
 | 
			
		||||
            n.cancel_scope.cancel()
 | 
			
		||||
            dgpu_bus.close()
 | 
			
		||||
 | 
			
		||||
            finally:
 | 
			
		||||
                res = await rpc_call('dgpu_offline')
 | 
			
		||||
                assert 'ok' in res.result
 | 
			
		||||
        finally:
 | 
			
		||||
            res = await rpc_call('dgpu_offline')
 | 
			
		||||
            assert 'ok' in res.result
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue