mirror of https://github.com/skygpu/skynet.git
				
				
				
commit 5a3a43b3c6
parent 12b32a7188

    Refactoring tui to be functional style
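This commit stops threading an optional WorkerMonitor instance through
NetConnector, WorkerDaemon and compute_one, and instead keeps a single
module-level TUI handle inside skynet.dgpu.tui behind three helpers:
init_tui(), maybe_update_tui() and maybe_update_tui_async(). Below is a
minimal, self-contained sketch of that pattern (the helper names match
the diff; the WorkerMonitor stub is illustrative, the real urwid-based
class lives in skynet/dgpu/tui.py):

    class WorkerMonitor:
        # stub standing in for the real urwid-based monitor
        def set_status(self, text: str):
            print(f'status: {text}')

    _tui = None  # module-level singleton, stays None when the TUI is disabled

    def init_tui():
        global _tui
        assert not _tui  # constructed at most once per process
        _tui = WorkerMonitor()
        return _tui

    def maybe_update_tui(fn):
        # run a UI callback only if the TUI exists; call sites drop their
        # `if self._tui:` branches and `tui=...` constructor params
        if _tui:
            fn(_tui)

    maybe_update_tui(lambda tui: tui.set_status('hello'))  # no-op, TUI off
    init_tui()
    maybe_update_tui(lambda tui: tui.set_status('hello'))  # prints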
--- a/skynet/dgpu/__init__.py
+++ b/skynet/dgpu/__init__.py
@@ -1,36 +1,17 @@
 import logging
-import warnings
 
 import trio
 import urwid
 
 from hypercorn.config import Config
 from hypercorn.trio import serve
 from quart_trio import QuartTrio as Quart
 
-from skynet.dgpu.tui import WorkerMonitor
+from skynet.dgpu.tui import init_tui
 from skynet.dgpu.daemon import WorkerDaemon
 from skynet.dgpu.network import NetConnector
 
 
-def setup_logging_for_tui(level):
-    warnings.filterwarnings("ignore")
-
-    logger = logging.getLogger()
-    logger.setLevel(level)
-
-    fh = logging.FileHandler('dgpu.log')
-    fh.setLevel(level)
-
-    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
-    fh.setFormatter(formatter)
-
-    logger.addHandler(fh)
-
-    for handler in logger.handlers:
-        if isinstance(handler, logging.StreamHandler):
-            logger.removeHandler(handler)
-
 async def open_dgpu_node(config: dict) -> None:
     '''
     Open a top level "GPU mgmt daemon", keep the
@@ -43,11 +24,10 @@ async def open_dgpu_node(config: dict) -> None:
 
     tui = None
     if config['tui']:
-        setup_logging_for_tui(logging.INFO)
-        tui = WorkerMonitor()
+        tui = init_tui()
 
-    conn = NetConnector(config, tui=tui)
-    daemon = WorkerDaemon(conn, config, tui=tui)
+    conn = NetConnector(config)
+    daemon = WorkerDaemon(conn, config)
 
     api: Quart|None = None
     if 'api_bind' in config:
@@ -71,5 +51,5 @@ async def open_dgpu_node(config: dict) -> None:
             # block until cancelled
             await daemon.serve_forever()
 
-        except *urwid.ExitMainLoop as ex_group:
+        except *urwid.ExitMainLoop:
             ...
--- a/skynet/dgpu/compute.py
+++ b/skynet/dgpu/compute.py
@@ -12,7 +12,7 @@ from contextlib import contextmanager as cm
 import trio
 import torch
 
-from skynet.dgpu.tui import WorkerMonitor
+from skynet.dgpu.tui import maybe_update_tui
 from skynet.dgpu.errors import (
     DGPUComputeError,
     DGPUInferenceCancelled,
@@ -108,8 +108,7 @@ def compute_one(
     method: str,
     params: dict,
     inputs: list[bytes] = [],
-    should_cancel = None,
-    tui: WorkerMonitor | None = None
+    should_cancel = None
 ):
     if method == 'diffuse':
         method = 'txt2img'
@@ -130,8 +129,7 @@ def compute_one(
         if not isinstance(step, int):
             step = args[1]
 
-        if tui:
-            tui.set_progress(step, done=total_steps)
+        maybe_update_tui(lambda tui: tui.set_progress(step, done=total_steps))
 
         if should_cancel:
            should_raise = trio.from_thread.run(should_cancel, request_id)
@@ -142,8 +140,7 @@ def compute_one(
 
         return {}
 
-    if tui:
-        tui.set_status(f'Request #{request_id}')
+    maybe_update_tui(lambda tui: tui.set_status(f'Request #{request_id}'))
 
     inference_step_wakeup(0)
 
@@ -210,7 +207,6 @@ def compute_one(
     except BaseException as err:
         raise DGPUComputeError(str(err)) from err
 
-    if tui:
-        tui.set_status('')
+    maybe_update_tui(lambda tui: tui.set_status(''))
 
     return output_hash, output
--- a/skynet/dgpu/daemon.py
+++ b/skynet/dgpu/daemon.py
@@ -17,7 +17,7 @@ from skynet.constants import (
 from skynet.dgpu.errors import (
     DGPUComputeError,
 )
-from skynet.dgpu.tui import WorkerMonitor
+from skynet.dgpu.tui import maybe_update_tui, maybe_update_tui_async
 from skynet.dgpu.compute import maybe_load_model, compute_one
 from skynet.dgpu.network import NetConnector
 
@@ -41,11 +41,9 @@ class WorkerDaemon:
     def __init__(
         self,
         conn: NetConnector,
-        config: dict,
-        tui: WorkerMonitor | None = None
+        config: dict
     ):
         self.conn: NetConnector = conn
-        self._tui = tui
         self.auto_withdraw = (
             config['auto_withdraw']
             if 'auto_withdraw' in config else False
@@ -152,10 +150,12 @@ class WorkerDaemon:
         return app
 
     async def _update_balance(self):
-        if self._tui:
+        async def _fn(tui):
             # update balance
             balance = await self.conn.get_worker_balance()
-            self._tui.set_header_text(new_balance=f'balance: {balance}')
+            tui.set_header_text(new_balance=f'balance: {balance}')
+
+        await maybe_update_tui_async(_fn)
 
     # TODO? this func is kinda big and maybe is better at module
     # level to reduce indentation?
@@ -258,8 +258,7 @@ class WorkerDaemon:
 
         with maybe_load_model(model, mode):
             try:
-                if self._tui:
-                    self._tui.set_progress(0, done=total_step)
+                maybe_update_tui(lambda tui: tui.set_progress(0, done=total_step))
 
                 output_type = 'png'
                 if 'output_type' in body['params']:
@@ -276,7 +275,6 @@ class WorkerDaemon:
                                 mode, body['params'],
                                 inputs=inputs,
                                 should_cancel=self.should_cancel_work,
-                                tui=self._tui
                             )
                         )
 
@@ -285,8 +283,7 @@ class WorkerDaemon:
                             f'Unsupported backend {self.backend}'
                         )
 
-                if self._tui:
-                    self._tui.set_progress(total_step)
+                maybe_update_tui(lambda tui: tui.set_progress(total_step))
 
                 self._last_generation_ts: str = datetime.now().isoformat()
                 self._last_benchmark: list[float] = self._benchmark
--- a/skynet/dgpu/network.py
+++ b/skynet/dgpu/network.py
@@ -13,7 +13,7 @@ import outcome
 from PIL import Image
 from leap.cleos import CLEOS
 from leap.protocol import Asset
-from skynet.dgpu.tui import WorkerMonitor
+from skynet.dgpu.tui import maybe_update_tui
 from skynet.constants import (
     DEFAULT_IPFS_DOMAIN,
     GPU_CONTRACT_ABI,
@@ -58,7 +58,7 @@ class NetConnector:
     - CLEOS client
 
     '''
-    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
+    def __init__(self, config: dict):
         # TODO, why these extra instance vars for an (unsynced)
         # copy of the `config` state?
         self.account = config['account']
@@ -82,9 +82,8 @@ class NetConnector:
             self.ipfs_domain = config['ipfs_domain']
 
         self._wip_requests = {}
-        self._tui = tui
-        if self._tui:
-            self._tui.set_header_text(new_worker_name=self.account)
+
+        maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.account))
 
 
     # blockchain helpers
@@ -173,8 +172,8 @@ class NetConnector:
                 n.start_soon(
                     _run_and_save, snap['requests'], req['id'], self.get_status_by_request_id, req['id'])
 
-        if self._tui:
-            self._tui.network_update(snap)
+
+        maybe_update_tui(lambda tui: tui.network_update(snap))
 
         return snap
 
--- a/skynet/dgpu/tui.py
+++ b/skynet/dgpu/tui.py
@@ -1,6 +1,9 @@
-import urwid
-import trio
 import json
+import logging
+import warnings
+
+import trio
+import urwid
 
 
 class WorkerMonitor:
@@ -163,86 +166,41 @@ class WorkerMonitor:
         self.update_requests(queue)
 
 
-# # -----------------------------------------------------------------------------
-# # Example usage
-# # -----------------------------------------------------------------------------
-#
-# async def main():
-#     # Example data
-#     example_requests = [
-#         {
-#             "id": 12,
-#             "model": "black-forest-labs/FLUX.1-schnell",
-#             "prompt": "Generate an answer about quantum entanglement.",
-#             "user": "alice123",
-#             "reward": "20.0000 GPU",
-#             "workers": ["workerA", "workerB"],
-#         },
-#         {
-#             "id": 5,
-#             "model": "some-other-model/v2.0",
-#             "prompt": "A story about dragons.",
-#             "user": "bobthebuilder",
-#             "reward": "15.0000 GPU",
-#             "workers": ["workerX"],
-#         },
-#         {
-#             "id": 99,
-#             "model": "cool-model/turbo",
-#             "prompt": "Classify sentiment in these tweets.",
-#             "user": "charlie",
-#             "reward": "25.5000 GPU",
-#             "workers": ["workerOne", "workerTwo", "workerThree"],
-#         },
-#     ]
-#
-#     ui = WorkerMonitor()
-#
-#     async def progress_task():
-#         # Fill from 0% to 100%
-#         for pct in range(101):
-#             ui.set_progress(
-#                 current=pct,
-#                 status_str=f"Request #1234 ({pct}%)"
-#             )
-#             await trio.sleep(0.05)
-#         # Reset to 0
-#         ui.set_progress(
-#             current=0,
-#             status_str="Starting again..."
-#         )
-#
-#     async def update_data_task():
-#         await trio.sleep(3)  # Wait a bit, then update requests
-#         new_data = [{
-#             "id": 101,
-#             "model": "new-model/v1.0",
-#             "prompt": "Say hi to the world.",
-#             "user": "eve",
-#             "reward": "50.0000 GPU",
-#             "workers": ["workerFresh", "workerPower"],
-#         }]
-#         ui.update_requests(new_data)
-#         ui.set_header_text(new_worker_name="NewNodeName",
-#                             new_balance="balance: 12345.6789 GPU")
-#
-#     try:
-#         async with trio.open_nursery() as nursery:
-#             # Run the TUI
-#             nursery.start_soon(ui.run_teadown_on_exit, nursery)
-#
-#             ui.update_requests(example_requests)
-#             ui.set_header_text(
-#                 new_worker_name="worker1.scd",
-#                 new_balance="balance: 12345.6789 GPU"
-#             )
-#             # Start background tasks
-#             nursery.start_soon(progress_task)
-#             nursery.start_soon(update_data_task)
-#
-#     except *KeyboardInterrupt as ex_group:
-#         ...
-#
-#
-# if __name__ == "__main__":
-#     trio.run(main)
+def setup_logging_for_tui(level):
+    warnings.filterwarnings("ignore")
+
+    logger = logging.getLogger()
+    logger.setLevel(level)
+
+    fh = logging.FileHandler('dgpu.log')
+    fh.setLevel(level)
+
+    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+    fh.setFormatter(formatter)
+
+    logger.addHandler(fh)
+
+    for handler in logger.handlers:
+        if isinstance(handler, logging.StreamHandler):
+            logger.removeHandler(handler)
+
+
+_tui = None
+def init_tui():
+    global _tui
+    assert not _tui
+    setup_logging_for_tui(logging.INFO)
+    _tui = WorkerMonitor()
+    return _tui
+
+
+def maybe_update_tui(fn):
+    global _tui
+    if _tui:
+        fn(_tui)
+
+
+async def maybe_update_tui_async(fn):
+    global _tui
+    if _tui:
+        await fn(_tui)
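The async helper follows the same shape for callbacks that need to
await. A condensed, runnable variant of the _update_balance() hunk
above (the WorkerMonitor stub and the hard-coded balance are
illustrative; they stand in for the real urwid class and for
`await conn.get_worker_balance()`):

    import trio

    class WorkerMonitor:
        # stub for the real urwid widget tree
        def set_header_text(self, new_balance=None):
            print(new_balance)

    _tui = WorkerMonitor()  # pretend init_tui() already ran

    async def maybe_update_tui_async(fn):
        # awaitable twin of maybe_update_tui(); the callback may do IO
        if _tui:
            await fn(_tui)

    async def _update_balance():
        async def _fn(tui):
            balance = '100.0000 GPU'  # stands in for the connector call
            tui.set_header_text(new_balance=f'balance: {balance}')

        await maybe_update_tui_async(_fn)

    trio.run(_update_balance)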