mirror of https://github.com/skygpu/skynet.git
				
				
				
Refactoring tui to be functional style

parent 12b32a7188
commit 5a3a43b3c6
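Before this commit, the node entrypoint constructed a WorkerMonitor and threaded it through NetConnector, WorkerDaemon and compute_one as an optional `tui` argument, guarding every call site with `if tui:`. After it, skynet.dgpu.tui owns a module-level singleton and exposes init_tui(), maybe_update_tui() and maybe_update_tui_async(); call sites hand over a callback instead of holding a reference. A minimal runnable sketch of the pattern (the names _tui and maybe_update_tui come from the diff below; the numbers are arbitrary):

    _tui = None          # module-level singleton, set by init_tui()

    def maybe_update_tui(fn):
        # run the callback only when a TUI exists; callers need no None checks
        if _tui:
            fn(_tui)

    # call sites shrink from `if self._tui: self._tui.set_progress(...)` to:
    maybe_update_tui(lambda tui: tui.set_progress(3, done=28))   # no-op here

The first hunks below are the node entrypoint (likely skynet/dgpu/__init__.py, inferred from the open_dgpu_node definition; the mirror does not preserve file paths):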
@@ -1,36 +1,17 @@
 import logging
-import warnings
 
 import trio
+import urwid
 
 from hypercorn.config import Config
 from hypercorn.trio import serve
 from quart_trio import QuartTrio as Quart
 
-from skynet.dgpu.tui import WorkerMonitor
+from skynet.dgpu.tui import init_tui
 from skynet.dgpu.daemon import WorkerDaemon
 from skynet.dgpu.network import NetConnector
 
 
-def setup_logging_for_tui(level):
-    warnings.filterwarnings("ignore")
-
-    logger = logging.getLogger()
-    logger.setLevel(level)
-
-    fh = logging.FileHandler('dgpu.log')
-    fh.setLevel(level)
-
-    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
-    fh.setFormatter(formatter)
-
-    logger.addHandler(fh)
-
-    for handler in logger.handlers:
-        if isinstance(handler, logging.StreamHandler):
-            logger.removeHandler(handler)
-
-
 async def open_dgpu_node(config: dict) -> None:
     '''
     Open a top level "GPU mgmt daemon", keep the
@@ -43,11 +24,10 @@ async def open_dgpu_node(config: dict) -> None:
 
     tui = None
     if config['tui']:
-        setup_logging_for_tui(logging.INFO)
-        tui = WorkerMonitor()
+        tui = init_tui()
 
-    conn = NetConnector(config, tui=tui)
-    daemon = WorkerDaemon(conn, config, tui=tui)
+    conn = NetConnector(config)
+    daemon = WorkerDaemon(conn, config)
 
     api: Quart|None = None
     if 'api_bind' in config:
@@ -71,5 +51,5 @@ async def open_dgpu_node(config: dict) -> None:
             # block until cancelled
             await daemon.serve_forever()
 
-        except *urwid.ExitMainLoop in ex_group:
+        except *urwid.ExitMainLoop:
             ...
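Note the fix in the last hunk: `except *urwid.ExitMainLoop in ex_group:` does not bind the matched subgroup; the expression `urwid.ExitMainLoop in ex_group` would itself be evaluated as the match target. The PEP 654 form names only the exception type, with `as` available when the subgroup is needed. A standalone illustration (Python 3.11+, no skynet or urwid dependency):

    class ExitMainLoop(Exception):
        ...

    try:
        raise ExceptionGroup('demo', [ExitMainLoop(), ValueError('boom')])
    except* ExitMainLoop:
        print('TUI exited cleanly')            # matches the ExitMainLoop member
    except* ValueError as eg:
        print('remaining errors:', eg.exceptions)

The next hunks touch the compute path (presumably skynet/dgpu/compute.py, given the `from skynet.dgpu.compute import maybe_load_model, compute_one` import above):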
@@ -12,7 +12,7 @@ from contextlib import contextmanager as cm
 import trio
 import torch
 
-from skynet.dgpu.tui import WorkerMonitor
+from skynet.dgpu.tui import maybe_update_tui
 from skynet.dgpu.errors import (
     DGPUComputeError,
     DGPUInferenceCancelled,
@@ -108,8 +108,7 @@ def compute_one(
     method: str,
     params: dict,
     inputs: list[bytes] = [],
-    should_cancel = None,
-    tui: WorkerMonitor | None = None
+    should_cancel = None
 ):
     if method == 'diffuse':
         method = 'txt2img'
@@ -130,8 +129,7 @@ def compute_one(
         if not isinstance(step, int):
             step = args[1]
 
-        if tui:
-            tui.set_progress(step, done=total_steps)
+        maybe_update_tui(lambda tui: tui.set_progress(step, done=total_steps))
 
         if should_cancel:
             should_raise = trio.from_thread.run(should_cancel, request_id)
@@ -142,8 +140,7 @@ def compute_one(
 
         return {}
 
-    if tui:
-        tui.set_status(f'Request #{request_id}')
+    maybe_update_tui(lambda tui: tui.set_status(f'Request #{request_id}'))
 
     inference_step_wakeup(0)
 
@@ -210,7 +207,6 @@ def compute_one(
     except BaseException as err:
         raise DGPUComputeError(str(err)) from err
 
-    if tui:
-        tui.set_status('')
+    maybe_update_tui(lambda tui: tui.set_status(''))
 
     return output_hash, output
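In compute_one the per-step guard becomes a lambda handed to maybe_update_tui. One subtlety: the lambda closes over `step`, but since maybe_update_tui invokes it synchronously there is no late-binding hazard; each tick reports the value current at call time. A runnable sketch (FakeTui is a hypothetical stand-in for WorkerMonitor):

    updates = []

    class FakeTui:
        def set_progress(self, step, done=None):
            updates.append((step, done))

    _tui = FakeTui()

    def maybe_update_tui(fn):
        if _tui:
            fn(_tui)

    for step in range(3):          # stands in for the inference step callback
        maybe_update_tui(lambda tui: tui.set_progress(step, done=3))

    assert updates == [(0, 3), (1, 3), (2, 3)]

The following hunks are the worker daemon (likely skynet/dgpu/daemon.py, given `from skynet.dgpu.daemon import WorkerDaemon` earlier):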
@@ -17,7 +17,7 @@ from skynet.constants import (
 from skynet.dgpu.errors import (
     DGPUComputeError,
 )
-from skynet.dgpu.tui import WorkerMonitor
+from skynet.dgpu.tui import maybe_update_tui, maybe_update_tui_async
 from skynet.dgpu.compute import maybe_load_model, compute_one
 from skynet.dgpu.network import NetConnector
 
@@ -41,11 +41,9 @@ class WorkerDaemon:
     def __init__(
         self,
         conn: NetConnector,
-        config: dict,
-        tui: WorkerMonitor | None = None
+        config: dict
     ):
         self.conn: NetConnector = conn
-        self._tui = tui
         self.auto_withdraw = (
             config['auto_withdraw']
             if 'auto_withdraw' in config else False
@@ -152,10 +150,12 @@ class WorkerDaemon:
         return app
 
     async def _update_balance(self):
-        if self._tui:
+        async def _fn(tui):
             # update balance
             balance = await self.conn.get_worker_balance()
-            self._tui.set_header_text(new_balance=f'balance: {balance}')
+            tui.set_header_text(new_balance=f'balance: {balance}')
 
+        await maybe_update_tui_async(_fn)
+
     # TODO? this func is kinda big and maybe is better at module
     # level to reduce indentation?
@@ -258,8 +258,7 @@ class WorkerDaemon:
 
         with maybe_load_model(model, mode):
             try:
-                if self._tui:
-                    self._tui.set_progress(0, done=total_step)
+                maybe_update_tui(lambda tui: tui.set_progress(0, done=total_step))
 
                 output_type = 'png'
                 if 'output_type' in body['params']:
@@ -276,7 +275,6 @@ class WorkerDaemon:
                                 mode, body['params'],
                                 inputs=inputs,
                                 should_cancel=self.should_cancel_work,
-                                tui=self._tui
                             )
                         )
 
@@ -285,8 +283,7 @@ class WorkerDaemon:
                             f'Unsupported backend {self.backend}'
                         )
 
-                if self._tui:
-                    self._tui.set_progress(total_step)
+                maybe_update_tui(lambda tui: tui.set_progress(total_step))
 
                 self._last_generation_ts: str = datetime.now().isoformat()
                 self._last_benchmark: list[float] = self._benchmark
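_update_balance is the one async call site, which is why the commit adds maybe_update_tui_async and a named inner coroutine: a lambda body cannot contain `await`. A side effect worth noting is that the balance fetch now lives inside the callback, so a headless node never performs it. Self-contained sketch under trio (fetch_balance and the header method are hypothetical stand-ins; maybe_update_tui_async mirrors the diff):

    import trio

    _tui = None                                # headless default

    async def maybe_update_tui_async(fn):
        if _tui:
            await fn(_tui)

    async def fetch_balance() -> str:
        await trio.sleep(0)                    # pretend chain/API round-trip
        return '123.4567 GPU'

    async def update_balance():
        async def _fn(tui):
            balance = await fetch_balance()    # the await is what rules out a lambda
            tui.set_header_text(new_balance=f'balance: {balance}')

        await maybe_update_tui_async(_fn)      # no-op (and no fetch) while _tui is None

    trio.run(update_balance)

The next hunks are the network connector (presumably skynet/dgpu/network.py, given `from skynet.dgpu.network import NetConnector`):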
@@ -13,7 +13,7 @@ import outcome
 from PIL import Image
 from leap.cleos import CLEOS
 from leap.protocol import Asset
-from skynet.dgpu.tui import WorkerMonitor
+from skynet.dgpu.tui import maybe_update_tui
 from skynet.constants import (
     DEFAULT_IPFS_DOMAIN,
     GPU_CONTRACT_ABI,
@@ -58,7 +58,7 @@ class NetConnector:
     - CLEOS client
 
     '''
-    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
+    def __init__(self, config: dict):
         # TODO, why these extra instance vars for an (unsynced)
         # copy of the `config` state?
         self.account = config['account']
@@ -82,9 +82,8 @@ class NetConnector:
             self.ipfs_domain = config['ipfs_domain']
 
         self._wip_requests = {}
-        self._tui = tui
-        if self._tui:
-            self._tui.set_header_text(new_worker_name=self.account)
+
+        maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.account))
 
 
     # blockchain helpers
@@ -173,8 +172,8 @@ class NetConnector:
                 n.start_soon(
                     _run_and_save, snap['requests'], req['id'], self.get_status_by_request_id, req['id'])
 
-        if self._tui:
-            self._tui.network_update(snap)
+
+        maybe_update_tui(lambda tui: tui.network_update(snap))
 
         return snap
 
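With the `tui` kwarg gone from NetConnector too, components no longer need a monitor object to be constructed, and tests can fake the TUI by swapping the module global. A hedged sketch (my inference, not part of the commit; assumes skynet.dgpu.tui is importable in the test environment):

    from unittest.mock import MagicMock

    import skynet.dgpu.tui as tui_mod

    tui_mod._tui = MagicMock()                 # inject a fake singleton
    tui_mod.maybe_update_tui(lambda tui: tui.network_update({'requests': []}))
    tui_mod._tui.network_update.assert_called_once_with({'requests': []})
    tui_mod._tui = None                        # restore the headless default

Finally, the TUI module itself (skynet/dgpu/tui.py, the one path the imports above pin down) absorbs setup_logging_for_tui and gains the singleton plus both dispatch helpers: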
@@ -1,6 +1,9 @@
-import urwid
-import trio
 import json
+import logging
+import warnings
+
+import trio
+import urwid
 
 
 class WorkerMonitor:
@@ -163,86 +166,41 @@ class WorkerMonitor:
         self.update_requests(queue)
 
 
-# # -----------------------------------------------------------------------------
-# # Example usage
-# # -----------------------------------------------------------------------------
-#
-# async def main():
-#     # Example data
-#     example_requests = [
-#         {
-#             "id": 12,
-#             "model": "black-forest-labs/FLUX.1-schnell",
-#             "prompt": "Generate an answer about quantum entanglement.",
-#             "user": "alice123",
-#             "reward": "20.0000 GPU",
-#             "workers": ["workerA", "workerB"],
-#         },
-#         {
-#             "id": 5,
-#             "model": "some-other-model/v2.0",
-#             "prompt": "A story about dragons.",
-#             "user": "bobthebuilder",
-#             "reward": "15.0000 GPU",
-#             "workers": ["workerX"],
-#         },
-#         {
-#             "id": 99,
-#             "model": "cool-model/turbo",
-#             "prompt": "Classify sentiment in these tweets.",
-#             "user": "charlie",
-#             "reward": "25.5000 GPU",
-#             "workers": ["workerOne", "workerTwo", "workerThree"],
-#         },
-#     ]
-#
-#     ui = WorkerMonitor()
-#
-#     async def progress_task():
-#         # Fill from 0% to 100%
-#         for pct in range(101):
-#             ui.set_progress(
-#                 current=pct,
-#                 status_str=f"Request #1234 ({pct}%)"
-#             )
-#             await trio.sleep(0.05)
-#         # Reset to 0
-#         ui.set_progress(
-#             current=0,
-#             status_str="Starting again..."
-#         )
-#
-#     async def update_data_task():
-#         await trio.sleep(3)  # Wait a bit, then update requests
-#         new_data = [{
-#             "id": 101,
-#             "model": "new-model/v1.0",
-#             "prompt": "Say hi to the world.",
-#             "user": "eve",
-#             "reward": "50.0000 GPU",
-#             "workers": ["workerFresh", "workerPower"],
-#         }]
-#         ui.update_requests(new_data)
-#         ui.set_header_text(new_worker_name="NewNodeName",
-#                             new_balance="balance: 12345.6789 GPU")
-#
-#     try:
-#         async with trio.open_nursery() as nursery:
-#             # Run the TUI
-#             nursery.start_soon(ui.run_teadown_on_exit, nursery)
-#
-#             ui.update_requests(example_requests)
-#             ui.set_header_text(
-#                 new_worker_name="worker1.scd",
-#                 new_balance="balance: 12345.6789 GPU"
-#             )
-#             # Start background tasks
-#             nursery.start_soon(progress_task)
-#             nursery.start_soon(update_data_task)
-#
-#     except *KeyboardInterrupt as ex_group:
-#         ...
-#
-#
-# if __name__ == "__main__":
-#     trio.run(main)
+def setup_logging_for_tui(level):
+    warnings.filterwarnings("ignore")
+
+    logger = logging.getLogger()
+    logger.setLevel(level)
+
+    fh = logging.FileHandler('dgpu.log')
+    fh.setLevel(level)
+
+    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+    fh.setFormatter(formatter)
+
+    logger.addHandler(fh)
+
+    for handler in logger.handlers:
+        if isinstance(handler, logging.StreamHandler):
+            logger.removeHandler(handler)
+
+
+_tui = None
+def init_tui():
+    global _tui
+    assert not _tui
+    setup_logging_for_tui(logging.INFO)
+    _tui = WorkerMonitor()
+    return _tui
+
+
+def maybe_update_tui(fn):
+    global _tui
+    if _tui:
+        fn(_tui)
+
+
+async def maybe_update_tui_async(fn):
+    global _tui
+    if _tui:
+        await fn(_tui)
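Taken together, the new module API reads as below. init_tui() both constructs the singleton (asserting it is only created once per process) and reroutes root logging into dgpu.log, stripping any StreamHandler so stray log lines cannot corrupt the urwid screen. A hedged usage sketch (assumes a terminal where urwid can start; set_status is one of the WorkerMonitor methods used in compute_one above):

    from skynet.dgpu.tui import init_tui, maybe_update_tui

    maybe_update_tui(lambda tui: tui.set_status('dropped'))      # no-op: no TUI yet

    tui = init_tui()                                             # builds the singleton
    maybe_update_tui(lambda tui: tui.set_status('Request #42'))  # now dispatches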