mirror of https://github.com/skygpu/skynet.git

Begin adding TUI

commit 8b45fb5979
parent e66f8d74fd
pyproject.toml
@@ -61,6 +61,7 @@ cuda = [
     "basicsr>=1.4.2,<2",
     "realesrgan>=0.3.0,<0.4",
     "sentencepiece>=0.2.0",
+    "urwid>=2.6.16",
 ]

 [tool.uv]
skynet/dgpu/__init__.py
@@ -1,4 +1,6 @@
 import logging
+import warnings

 import trio
+import urwid

@@ -6,11 +8,31 @@ from hypercorn.config import Config
 from hypercorn.trio import serve
 from quart_trio import QuartTrio as Quart

+from skynet.dgpu.tui import WorkerMonitor
 from skynet.dgpu.compute import ModelMngr
 from skynet.dgpu.daemon import WorkerDaemon
 from skynet.dgpu.network import NetConnector


+def setup_logging_for_tui(level):
+    warnings.filterwarnings("ignore")
+
+    logger = logging.getLogger()
+    logger.setLevel(level)
+
+    fh = logging.FileHandler('dgpu.log')
+    fh.setLevel(level)
+
+    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+    fh.setFormatter(formatter)
+
+    logger.addHandler(fh)
+
+    for handler in logger.handlers:
+        if isinstance(handler, logging.StreamHandler):
+            logger.removeHandler(handler)
+
+
 async def open_dgpu_node(config: dict) -> None:
     '''
     Open a top level "GPU mgmt daemon", keep the
@@ -18,13 +40,17 @@ async def open_dgpu_node(config: dict) -> None:
     and *maybe* serve a `hypercorn` web API.

     '''

     # suppress logs from httpx (logs url + status after every query)
     logging.getLogger("httpx").setLevel(logging.WARNING)

-    conn = NetConnector(config)
-    mm = ModelMngr(config)
-    daemon = WorkerDaemon(mm, conn, config)
+    tui = None
+    if config['tui']:
+        setup_logging_for_tui(logging.INFO)
+        tui = WorkerMonitor()
+    conn = NetConnector(config, tui=tui)
+    mm = ModelMngr(config, tui=tui)
+    daemon = WorkerDaemon(mm, conn, config, tui=tui)

     api: Quart|None = None
     if 'api_bind' in config:
@@ -35,6 +61,8 @@ async def open_dgpu_node(config: dict) -> None:
     tn: trio.Nursery
     async with trio.open_nursery() as tn:
         tn.start_soon(daemon.snap_updater_task)
+        if tui:
+            tn.start_soon(tui.run)

         # TODO, consider a more explicit `as hypercorn_serve`
         # to clarify?
@@ -42,5 +70,9 @@ async def open_dgpu_node(config: dict) -> None:
             logging.info(f'serving api @ {config["api_bind"]}')
             tn.start_soon(serve, api, api_conf)

-        # block until cancelled
-        await daemon.serve_forever()
+        try:
+            # block until cancelled
+            await daemon.serve_forever()
+
+        except *urwid.ExitMainLoop as ex_group:
+            ...
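The wiring above relies on urwid's trio integration: `MainLoop.start()` is a context manager that takes over the terminal, and `TrioEventLoop.run_async()` awaits until something raises `urwid.ExitMainLoop` (here, `WorkerMonitor._exit_on_q` on q/Q). A minimal standalone sketch of the same pattern, assuming only `urwid` (with its trio extra) and `trio`; the widget and ticker task are illustrative, not part of the commit:

```python
import trio
import urwid

def exit_on_q(key):
    if key in ('q', 'Q'):
        raise urwid.ExitMainLoop()

async def ticker(text: urwid.Text):
    # stand-in for the daemon's background tasks (snap_updater_task etc.)
    for i in range(3600):
        text.set_text(f'tick {i}')
        await trio.sleep(1)

async def main():
    text = urwid.Text('starting...')
    event_loop = urwid.TrioEventLoop()
    main_loop = urwid.MainLoop(
        urwid.Filler(text),
        event_loop=event_loop,
        unhandled_input=exit_on_q,
    )
    async with trio.open_nursery() as tn:
        tn.start_soon(ticker, text)
        with main_loop.start():           # enter/leave the screen cleanly
            await event_loop.run_async()  # returns once ExitMainLoop is raised
        tn.cancel_scope.cancel()          # UI is gone: stop remaining tasks

trio.run(main)
```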
skynet/dgpu/compute.py
@@ -11,6 +11,7 @@ from hashlib import sha256
 import trio
 import torch

+from skynet.dgpu.tui import WorkerMonitor
 from skynet.dgpu.errors import (
     DGPUComputeError,
     DGPUInferenceCancelled,
@@ -72,7 +73,8 @@ class ModelMngr:
     checking load state, and unloading when no-longer-needed/finished.

     '''
-    def __init__(self, config: dict):
+    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
+        self._tui = tui
         self.cache_dir = None
         if 'hf_home' in config:
             self.cache_dir = config['hf_home']
@@ -80,8 +82,6 @@ class ModelMngr:
         self._model_name: str = ''
         self._model_mode: str = ''

-        # self.load_model(DEFAULT_INITAL_MODEL, 'txt2img')
-
     def log_debug_info(self):
         logging.debug('memory summary:')
         logging.debug('\n' + torch.cuda.memory_summary())
@@ -110,6 +110,7 @@ class ModelMngr:
     ) -> None:
         logging.info(f'loading model {name}...')
         self.unload_model()
+
         self._model = pipeline_for(
             name, mode, cache_dir=self.cache_dir)
         self._model_mode = mode
@@ -124,11 +125,19 @@ class ModelMngr:
         params: dict,
         inputs: list[bytes] = []
     ):
-        def maybe_cancel_work(step, *args, **kwargs):
+        total_steps = params['step']
+        def inference_step_wakeup(*args, **kwargs):
             '''This is a callback function that gets invoked every inference step,
             we need to raise an exception here if we need to cancel work
             '''
-            if self._should_cancel:
+            step = args[0]
+            # compat with callback_on_step_end
+            if not isinstance(step, int):
+                step = args[1]
+            if self._tui:
+                self._tui.set_progress(step, done=total_steps)
+            should_raise = trio.from_thread.run(self._should_cancel, request_id)
+            if should_raise:
                 logging.warning(f'CANCELLING work at step {step}')
@@ -136,7 +145,10 @@ class ModelMngr:

             return {}

-        maybe_cancel_work(0)
+        if self._tui:
+            self._tui.set_status(f'Request #{request_id}')
+
+        inference_step_wakeup(0)

         output_type = 'png'
         if 'output_type' in params:
@@ -157,10 +169,10 @@ class ModelMngr:
                     prompt, guidance, step, seed, upscaler, extra_params = arguments

                     if 'flux' in name.lower():
-                        extra_params['callback_on_step_end'] = maybe_cancel_work
+                        extra_params['callback_on_step_end'] = inference_step_wakeup

                     else:
-                        extra_params['callback'] = maybe_cancel_work
+                        extra_params['callback'] = inference_step_wakeup
                         extra_params['callback_steps'] = 1

                     output = self._model(
@@ -213,4 +225,7 @@ class ModelMngr:
         finally:
             torch.cuda.empty_cache()

+        if self._tui:
+            self._tui.set_status('')
+
         return output_hash, output
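The renamed callback now does double duty: it reports progress to the TUI and polls for cancellation. The signature juggling is there because diffusers invokes the legacy `callback=`/`callback_steps=` hook as `callback(step, timestep, latents)` but the newer `callback_on_step_end=` hook as `callback(pipe, step, timestep, callback_kwargs)` (and expects a kwargs dict back), so the int step may sit in `args[0]` or `args[1]`; and since the pipeline runs on a worker thread, `trio.from_thread.run` bridges back to the async `_should_cancel`. A condensed sketch of just the bridging (helper names are illustrative, not from the commit):

```python
# Condensed sketch of bridging both diffusers callback signatures:
#   legacy: pipe(..., callback=cb, callback_steps=1)
#           -> cb(step: int, timestep, latents)
#   newer:  pipe(..., callback_on_step_end=cb)
#           -> cb(pipe, step: int, timestep, callback_kwargs) -> dict
def make_step_callback(report_progress, should_cancel_sync):
    def step_cb(*args, **kwargs):
        step = args[0]
        if not isinstance(step, int):
            # callback_on_step_end passes the pipeline object first
            step = args[1]

        report_progress(step)

        if should_cancel_sync():
            raise RuntimeError(f'cancelled at step {step}')

        # the legacy hook ignores the return value; the newer hook
        # requires the callback_kwargs dict to be returned
        return {}
    return step_cb
```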
skynet/dgpu/daemon.py
@@ -17,6 +17,7 @@ from skynet.constants import (
 from skynet.dgpu.errors import (
     DGPUComputeError,
 )
+from skynet.dgpu.tui import WorkerMonitor
 from skynet.dgpu.compute import ModelMngr
 from skynet.dgpu.network import NetConnector

@@ -41,10 +42,12 @@ class WorkerDaemon:
         self,
         mm: ModelMngr,
         conn: NetConnector,
-        config: dict
+        config: dict,
+        tui: WorkerMonitor | None = None
     ):
         self.mm: ModelMngr = mm
         self.conn: NetConnector = conn
+        self._tui = tui
         self.auto_withdraw = (
             config['auto_withdraw']
             if 'auto_withdraw' in config else False
@@ -150,6 +153,12 @@ class WorkerDaemon:

         return app

+    async def _update_balance(self):
+        if self._tui:
+            # update balance
+            balance = await self.conn.get_worker_balance()
+            self._tui.set_header_text(new_balance=f'balance: {balance}')
+
     # TODO? this func is kinda big and maybe is better at module
     # level to reduce indentation?
     # -[ ] just pass `daemon: WorkerDaemon` vs. `self`
@@ -238,6 +247,8 @@ class WorkerDaemon:
         request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
         logging.info(f'calculated request hash: {request_hash}')

+        total_step = body['params']['step']
+
         # TODO: validate request

         resp = await self.conn.begin_work(rid)
@@ -246,6 +257,9 @@ class WorkerDaemon:

         else:
             try:
+                if self._tui:
+                    self._tui.set_progress(0, done=total_step)
+
                 output_type = 'png'
                 if 'output_type' in body['params']:
                     output_type = body['params']['output_type']
@@ -269,6 +283,9 @@ class WorkerDaemon:
                             f'Unsupported backend {self.backend}'
                         )

+                if self._tui:
+                    self._tui.set_progress(total_step)
+
                 self._last_generation_ts: str = datetime.now().isoformat()
                 self._last_benchmark: list[float] = self._benchmark
                 self._benchmark: list[float] = []
@@ -277,6 +294,9 @@ class WorkerDaemon:

                 await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)

+                await self._update_balance()
+
+
             except BaseException as err:
                 if 'network cancel' not in str(err):
                     logging.exception('Failed to serve model request !?\n')
@@ -294,6 +314,7 @@ class WorkerDaemon:
     # -[ ] keeps tasks-as-funcs style prominent
     # -[ ] avoids so much indentation due to methods
     async def serve_forever(self):
+        await self._update_balance()
         try:
             while True:
                 if self.auto_withdraw:
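Taken together, the daemon-side hooks give the TUI a simple per-request lifecycle. A stubbed walkthrough of the call order, assuming a `WorkerMonitor` instance `tui`; the connector stub and numbers are invented for illustration:

```python
class StubConn:
    async def get_worker_balance(self):
        return '100.0000 GPU'

async def lifecycle(tui):
    conn = StubConn()

    # on daemon startup, and again after every submit_work():
    balance = await conn.get_worker_balance()
    tui.set_header_text(new_balance=f'balance: {balance}')

    # when work on a request begins (total_step = body['params']['step']):
    tui.set_progress(0, done=28)
    # ...inference_step_wakeup() then advances it one step at a time...
    tui.set_progress(28)  # marked complete once the backend returns
```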
skynet/dgpu/network.py
@@ -13,6 +13,7 @@ import outcome
 from PIL import Image
 from leap.cleos import CLEOS
 from leap.protocol import Asset
+from skynet.dgpu.tui import WorkerMonitor
 from skynet.constants import (
     DEFAULT_IPFS_DOMAIN,
     GPU_CONTRACT_ABI,
@@ -57,7 +58,7 @@ class NetConnector:
     - CLEOS client

     '''
-    def __init__(self, config: dict):
+    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
         # TODO, why these extra instance vars for an (unsynced)
         # copy of the `config` state?
         self.account = config['account']
@@ -81,6 +82,10 @@ class NetConnector:
             self.ipfs_domain = config['ipfs_domain']

         self._wip_requests = {}
+        self._tui = tui
+        if self._tui:
+            self._tui.set_header_text(new_worker_name=self.account)
+

     # blockchain helpers

@@ -168,6 +173,9 @@ class NetConnector:
                 n.start_soon(
                     _run_and_save, snap['requests'], req['id'], self.get_status_by_request_id, req['id'])

+        if self._tui:
+            self._tui.network_update(snap)
+
         return snap

     async def begin_work(self, request_id: int):
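`network_update()` is fed the same snapshot dict that `snap_updater_task` already assembles. Reconstructed from how the TUI consumes it (not from the chain schema itself), the shape is roughly the following, with all field values invented and `tui` assumed to be a `WorkerMonitor`:

```python
import json

snap = {
    # one row per queued request; params travel as a JSON string in 'body'
    'queue': [{
        'id': 12,
        'user': 'alice123',
        'reward': '20.0000 GPU',
        'body': json.dumps({'params': {
            'model': 'black-forest-labs/FLUX.1-schnell',
            'prompt': 'an example prompt',
            'step': 28,
        }}),
    }],
    # per-request status rows, one entry per worker that picked it up
    'requests': {12: [{'worker': 'workerA'}, {'worker': 'workerB'}]},
}

tui.network_update(snap)  # rebuilds the request list in the body pane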
skynet/dgpu/tui.py (new file)
@@ -0,0 +1,248 @@
import urwid
import trio
import json


class WorkerMonitor:
    def __init__(self):
        self.requests = []
        self.header_info = {}

        self.palette = [
            ('headerbar',         'white',      'dark blue'),
            ('request_row',       'white',      'dark gray'),
            ('worker_row',        'light gray', 'black'),
            ('progress_normal',   'black',      'light gray'),
            ('progress_complete', 'black',      'dark green'),
            ('body',              'white',      'black'),
        ]

        # --- Top bar (header) ---
        worker_name = self.header_info.get('left', "unknown")
        balance     = self.header_info.get('right', "balance: unknown")

        self.worker_name_widget = urwid.Text(worker_name)
        self.balance_widget     = urwid.Text(balance, align='right')

        header = urwid.Columns([self.worker_name_widget, self.balance_widget])
        header_attr = urwid.AttrMap(header, 'headerbar')

        # --- Body (List of requests) ---
        self.body_listbox = self._create_listbox_body(self.requests)

        # --- Bottom bar (progress) ---
        self.status_text  = urwid.Text("Request: none", align='left')
        self.progress_bar = urwid.ProgressBar(
            'progress_normal',
            'progress_complete',
            current=0,
            done=100
        )

        footer_cols = urwid.Columns([
            ('fixed', 20, self.status_text),
            self.progress_bar,
        ])

        # Build the main frame
        frame = urwid.Frame(
            self.body_listbox,
            header=header_attr,
            footer=footer_cols
        )

        # Set up the main loop with Trio
        self.event_loop = urwid.TrioEventLoop()
        self.main_loop = urwid.MainLoop(
            frame,
            palette=self.palette,
            event_loop=self.event_loop,
            unhandled_input=self._exit_on_q
        )

    def _create_listbox_body(self, requests):
        """
        Build a ListBox (vertical list) of requests & workers using SimpleFocusListWalker.
        """
        widgets = self._build_request_widgets(requests)
        walker = urwid.SimpleFocusListWalker(widgets)
        return urwid.ListBox(walker)

    def _build_request_widgets(self, requests):
        """
        Build a list of Urwid widgets (one row per request + per worker).
        """
        row_widgets = []

        for req in requests:
            # Build a columns widget for the request row
            columns = urwid.Columns([
                ('fixed', 5,  urwid.Text(f"#{req['id']}")),   # e.g. "#12"
                ('weight', 3, urwid.Text(req['model'])),
                ('weight', 3, urwid.Text(req['prompt'])),
                ('fixed', 13, urwid.Text(req['user'])),
                ('fixed', 13, urwid.Text(req['reward'])),
            ], dividechars=1)

            # Wrap the columns with an attribute map for coloring
            request_row = urwid.AttrMap(columns, 'request_row')
            row_widgets.append(request_row)

            # Then add each worker in its own line below
            for w in req["workers"]:
                worker_line = urwid.Text(f"  {w}")
                worker_row  = urwid.AttrMap(worker_line, 'worker_row')
                row_widgets.append(worker_row)

            # Optional blank line after each request
            row_widgets.append(urwid.Text(""))

        return row_widgets

    def _exit_on_q(self, key):
        """Exit the TUI on 'q' or 'Q'."""
        if key in ('q', 'Q'):
            raise urwid.ExitMainLoop()

    async def run(self):
        """
        Run the TUI in an async context (Trio).
        This method blocks until the user quits (pressing q/Q).
        """
        with self.main_loop.start():
            await self.event_loop.run_async()

        raise urwid.ExitMainLoop()

    # -------------------------------------------------------------------------
    # Public Methods to Update Various Parts of the UI
    # -------------------------------------------------------------------------
    def set_status(self, status: str):
        self.status_text.set_text(status)

    def set_progress(self, current, done=None):
        """
        Update the bottom progress bar.
          - `current`: new current progress value (int).
          - `done`: max progress value (int). If None, we don’t change it.
        """
        if done is not None:
            self.progress_bar.done = done

        self.progress_bar.current = current

        pct = 0
        if self.progress_bar.done != 0:
            pct = int((self.progress_bar.current / self.progress_bar.done) * 100)

    def update_requests(self, new_requests):
        """
        Replace the data in the existing ListBox with new request widgets.
        """
        new_widgets = self._build_request_widgets(new_requests)
        self.body_listbox.body[:] = new_widgets  # replace content of the list walker

    def set_header_text(self, new_worker_name=None, new_balance=None):
        """
        Update the text in the header bar for worker name and/or balance.
        """
        if new_worker_name is not None:
            self.worker_name_widget.set_text(new_worker_name)
        if new_balance is not None:
            self.balance_widget.set_text(new_balance)

    def network_update(self, snapshot: dict):
        queue = [
            {
                **r,
                **(json.loads(r['body'])['params']),
                'workers': [s['worker'] for s in snapshot['requests'][r['id']]]
            }
            for r in snapshot['queue']
        ]
        self.update_requests(queue)


# # -----------------------------------------------------------------------------
# # Example usage
# # -----------------------------------------------------------------------------
#
# async def main():
#     # Example data
#     example_requests = [
#         {
#             "id": 12,
#             "model": "black-forest-labs/FLUX.1-schnell",
#             "prompt": "Generate an answer about quantum entanglement.",
#             "user": "alice123",
#             "reward": "20.0000 GPU",
#             "workers": ["workerA", "workerB"],
#         },
#         {
#             "id": 5,
#             "model": "some-other-model/v2.0",
#             "prompt": "A story about dragons.",
#             "user": "bobthebuilder",
#             "reward": "15.0000 GPU",
#             "workers": ["workerX"],
#         },
#         {
#             "id": 99,
#             "model": "cool-model/turbo",
#             "prompt": "Classify sentiment in these tweets.",
#             "user": "charlie",
#             "reward": "25.5000 GPU",
#             "workers": ["workerOne", "workerTwo", "workerThree"],
#         },
#     ]
#
#     ui = WorkerMonitor()
#
#     async def progress_task():
#         # Fill from 0% to 100%
#         for pct in range(101):
#             ui.set_status(f"Request #1234 ({pct}%)")
#             ui.set_progress(current=pct, done=100)
#             await trio.sleep(0.05)
#         # Reset to 0
#         ui.set_status("Starting again...")
#         ui.set_progress(current=0)
#
#     async def update_data_task():
#         await trio.sleep(3)  # Wait a bit, then update requests
#         new_data = [{
#             "id": 101,
#             "model": "new-model/v1.0",
#             "prompt": "Say hi to the world.",
#             "user": "eve",
#             "reward": "50.0000 GPU",
#             "workers": ["workerFresh", "workerPower"],
#         }]
#         ui.update_requests(new_data)
#         ui.set_header_text(new_worker_name="NewNodeName",
#                             new_balance="balance: 12345.6789 GPU")
#
#     try:
#         async with trio.open_nursery() as nursery:
#             # Run the TUI
#             nursery.start_soon(ui.run)
#
#             ui.update_requests(example_requests)
#             ui.set_header_text(
#                 new_worker_name="worker1.scd",
#                 new_balance="balance: 12345.6789 GPU"
#             )
#             # Start background tasks
#             nursery.start_soon(progress_task)
#             nursery.start_soon(update_data_task)
#
#     except *KeyboardInterrupt as ex_group:
#         ...
#
#
# if __name__ == "__main__":
#     trio.run(main)
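Beyond the commented demo above, the widget tree can be smoke-tested without ever entering the event loop, since urwid widgets render straight to a canvas. A quick check of the request rows, with arbitrary data:

```python
from skynet.dgpu.tui import WorkerMonitor

ui = WorkerMonitor()
ui.update_requests([{
    'id': 12,
    'model': 'black-forest-labs/FLUX.1-schnell',
    'prompt': 'an example prompt',
    'user': 'alice123',
    'reward': '20.0000 GPU',
    'workers': ['workerA'],
}])

# ListBox is a box widget: render at an explicit (cols, rows) size
canvas = ui.body_listbox.render((80, 24))
print(b'\n'.join(canvas.text).decode())
```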
skynet/utils.py
@@ -7,8 +7,10 @@ import logging
 import importlib

 from typing import Optional
+from contextlib import contextmanager

 import torch
+import diffusers
 import numpy as np

 from PIL import Image
@@ -74,12 +76,27 @@ def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
     return crop_image(convert_from_bytes_to_img(raw), max_w, max_h)


+class DummyPB:
+    def update(self):
+        ...
+
+@torch.compiler.disable
+@contextmanager
+def dummy_progress_bar(*args, **kwargs):
+    yield DummyPB()
+
+
+def monkey_patch_pipeline_disable_progress_bar(pipe):
+    pipe.progress_bar = dummy_progress_bar
+
+
 def pipeline_for(
     model: str,
     mode: str,
     mem_fraction: float = 1.0,
     cache_dir: str | None = None
 ) -> DiffusionPipeline:
+    diffusers.utils.logging.disable_progress_bar()

     logging.info(f'pipeline_for {model} {mode}')
     assert torch.cuda.is_available()
@@ -105,7 +122,9 @@ def pipeline_for(
         normalized_shortname = shortname.replace('-', '_')
         custom_pipeline = importlib.import_module(f'skynet.dgpu.pipes.{normalized_shortname}')
         assert custom_pipeline.__model['name'] == model
-        return custom_pipeline.pipeline_for(model, mode, mem_fraction=mem_fraction, cache_dir=cache_dir)
+        pipe = custom_pipeline.pipeline_for(model, mode, mem_fraction=mem_fraction, cache_dir=cache_dir)
+        monkey_patch_pipeline_disable_progress_bar(pipe)
+        return pipe

     except ImportError:
         # TODO, uhh why not warn/error log this?
@@ -121,7 +140,6 @@ def pipeline_for(
         logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')

     params = {
-        'safety_checker': None,
         'torch_dtype': torch.float16,
         'cache_dir': cache_dir,
         'variant': 'fp16',
@@ -130,6 +148,7 @@ def pipeline_for(
     match shortname:
         case 'stable':
             params['revision'] = 'fp16'
+            params['safety_checker'] = None

     torch.cuda.set_per_process_memory_fraction(mem_fraction)

@@ -167,6 +186,8 @@ def pipeline_for(

         pipe = pipe.to('cuda')

+    monkey_patch_pipeline_disable_progress_bar(pipe)
+
     return pipe
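The monkey-patch works because diffusers pipelines drive their own progress reporting through a `self.progress_bar(...)` context manager, a tqdm wrapper that writes to stdout and would corrupt the urwid screen. Paraphrasing the call site rather than quoting the library source, the contract the no-op replacement has to honor is:

```python
# roughly what a diffusers pipeline does internally (paraphrased):
#     with self.progress_bar(total=num_inference_steps) as pb:
#         for t in timesteps:
#             ...denoise one step...
#             pb.update()
#
# so the replacement only needs to be a context manager yielding an
# object with .update(); nothing is ever printed:
with dummy_progress_bar(total=30) as pb:
    pb.update()  # no output, the TUI stays intact
```

The `@torch.compiler.disable` on `dummy_progress_bar` presumably keeps `torch.compile` from trying to trace through the patched method; treat that as the apparent intent rather than a documented requirement.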
uv.lock (24 additions)
@@ -2262,6 +2262,7 @@ cuda = [
     { name = "torchvision" },
     { name = "transformers" },
     { name = "triton" },
+    { name = "urwid" },
     { name = "xformers" },
 ]
 dev = [
@@ -2313,6 +2314,7 @@ cuda = [
     { name = "torchvision", specifier = "==0.20.1+cu121", index = "https://download.pytorch.org/whl/cu121" },
     { name = "transformers", specifier = "==4.48.0" },
     { name = "triton", specifier = "==3.1.0", index = "https://download.pytorch.org/whl/cu121" },
+    { name = "urwid", specifier = ">=2.6.16" },
     { name = "xformers", specifier = ">=0.0.29,<0.0.30" },
 ]
 dev = [
@@ -2627,6 +2629,28 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 },
 ]

+[[package]]
+name = "urwid"
+version = "2.6.16"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "typing-extensions" },
+    { name = "wcwidth" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/98/21/ad23c9e961b2d36d57c63686a6f86768dd945d406323fb58c84f09478530/urwid-2.6.16.tar.gz", hash = "sha256:93ad239939e44c385e64aa00027878b9e5c486d59e855ec8ab5b1e1adcdb32a2", size = 848179 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/54/cb/271a4f5a1bf4208dbdc96d85b9eae744cf4e5e11ac73eda76dc98c8fd2d7/urwid-2.6.16-py3-none-any.whl", hash = "sha256:de14896c6df9eb759ed1fd93e0384a5279e51e0dde8f621e4083f7a8368c0797", size = 297196 },
+]
+
+[[package]]
+name = "wcwidth"
+version = "0.2.13"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/6c/63/53559446a878410fc5a5974feb13d31d78d752eb18aeba59c7fef1af7598/wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5", size = 101301 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 },
+]
+
 [[package]]
 name = "websocket-client"
 version = "1.8.0"