diff --git a/pyproject.toml b/pyproject.toml
index 73884a7..94a9740 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -61,6 +61,7 @@ cuda = [
     "basicsr>=1.4.2,<2",
     "realesrgan>=0.3.0,<0.4",
     "sentencepiece>=0.2.0",
+    "urwid>=2.6.16",
 ]
 
 [tool.uv]
diff --git a/skynet/dgpu/__init__.py b/skynet/dgpu/__init__.py
index 4371f83..59af61c 100755
--- a/skynet/dgpu/__init__.py
+++ b/skynet/dgpu/__init__.py
@@ -1,4 +1,6 @@
 import logging
+import warnings
 
 import trio
+import urwid
 
@@ -6,11 +8,33 @@ from hypercorn.config import Config
 from hypercorn.trio import serve
 from quart_trio import QuartTrio as Quart
 
+from skynet.dgpu.tui import WorkerMonitor
 from skynet.dgpu.compute import ModelMngr
 from skynet.dgpu.daemon import WorkerDaemon
 from skynet.dgpu.network import NetConnector
 
 
+def setup_logging_for_tui(level):
+    # route all logs to a file and drop any stream handlers, so that
+    # stdout/stderr stay clean for the urwid interface
+    warnings.filterwarnings("ignore")
+
+    logger = logging.getLogger()
+    logger.setLevel(level)
+
+    fh = logging.FileHandler('dgpu.log')
+    fh.setLevel(level)
+
+    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+    fh.setFormatter(formatter)
+
+    logger.addHandler(fh)
+
+    for handler in logger.handlers:
+        if isinstance(handler, logging.StreamHandler):
+            logger.removeHandler(handler)
+
+
 async def open_dgpu_node(config: dict) -> None:
     '''
     Open a top level "GPU mgmt daemon", keep the
@@ -18,13 +42,17 @@ async def open_dgpu_node(config: dict) -> None:
     and *maybe* serve a `hypercorn` web API.
 
     '''
-    # suppress logs from httpx (logs url + status after every query)
     logging.getLogger("httpx").setLevel(logging.WARNING)
 
-    conn = NetConnector(config)
-    mm = ModelMngr(config)
-    daemon = WorkerDaemon(mm, conn, config)
+    tui = None
+    if config.get('tui', False):
+        setup_logging_for_tui(logging.INFO)
+        tui = WorkerMonitor()
+
+    conn = NetConnector(config, tui=tui)
+    mm = ModelMngr(config, tui=tui)
+    daemon = WorkerDaemon(mm, conn, config, tui=tui)
 
     api: Quart|None = None
     if 'api_bind' in config:
@@ -35,6 +63,8 @@ async def open_dgpu_node(config: dict) -> None:
     tn: trio.Nursery
     async with trio.open_nursery() as tn:
         tn.start_soon(daemon.snap_updater_task)
+        if tui:
+            tn.start_soon(tui.run)
 
         # TODO, consider a more explicit `as hypercorn_serve`
         # to clarify?
@@ -42,5 +72,9 @@ async def open_dgpu_node(config: dict) -> None:
             logging.info(f'serving api @ {config["api_bind"]}')
             tn.start_soon(serve, api, api_conf)
 
-        # block until cancelled
-        await daemon.serve_forever()
+        try:
+            # block until cancelled
+            await daemon.serve_forever()
+
+        except* urwid.ExitMainLoop:
+            ...
diff --git a/skynet/dgpu/compute.py b/skynet/dgpu/compute.py
index 56403a1..d0e8689 100755
--- a/skynet/dgpu/compute.py
+++ b/skynet/dgpu/compute.py
@@ -11,6 +11,7 @@ from hashlib import sha256
 
 import trio
 import torch
+from skynet.dgpu.tui import WorkerMonitor
 from skynet.dgpu.errors import (
     DGPUComputeError,
     DGPUInferenceCancelled,
@@ -72,7 +73,8 @@ class ModelMngr:
     checking load state, and unloading when no-longer-needed/finished.
 
     '''
-    def __init__(self, config: dict):
+    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
+        self._tui = tui
         self.cache_dir = None
         if 'hf_home' in config:
             self.cache_dir = config['hf_home']
@@ -80,8 +82,6 @@ class ModelMngr:
         self._model_name: str = ''
         self._model_mode: str = ''
 
-        # self.load_model(DEFAULT_INITAL_MODEL, 'txt2img')
-
     def log_debug_info(self):
         logging.debug('memory summary:')
         logging.debug('\n' + torch.cuda.memory_summary())
@@ -110,6 +110,7 @@ class ModelMngr:
     ) -> None:
         logging.info(f'loading model {name}...')
         self.unload_model()
+
         self._model = pipeline_for(
             name, mode, cache_dir=self.cache_dir)
         self._model_mode = mode
@@ -124,19 +125,30 @@ class ModelMngr:
         params: dict,
         inputs: list[bytes] = []
     ):
-        def maybe_cancel_work(step, *args, **kwargs):
+        total_steps = params['step']
+        def inference_step_wakeup(*args, **kwargs):
             '''This is a callback function that gets invoked every inference step,
             we need to raise an exception here if we need to cancel work
             '''
-            if self._should_cancel:
-                should_raise = trio.from_thread.run(self._should_cancel, request_id)
-                if should_raise:
-                    logging.warning(f'CANCELLING work at step {step}')
-                    raise DGPUInferenceCancelled('network cancel')
+            step = args[0]
+            # compat with callback_on_step_end
+            if not isinstance(step, int):
+                step = args[1]
+
+            if self._tui:
+                self._tui.set_progress(step, done=total_steps)
+
+            should_raise = trio.from_thread.run(self._should_cancel, request_id)
+            if should_raise:
+                logging.warning(f'CANCELLING work at step {step}')
+                raise DGPUInferenceCancelled('network cancel')
 
             return {}
 
-        maybe_cancel_work(0)
+        if self._tui:
+            self._tui.set_status(f'Request #{request_id}')
+
+        inference_step_wakeup(0)
 
         output_type = 'png'
         if 'output_type' in params:
@@ -157,10 +169,10 @@ class ModelMngr:
                 prompt, guidance, step, seed, upscaler, extra_params = arguments
 
                 if 'flux' in name.lower():
-                    extra_params['callback_on_step_end'] = maybe_cancel_work
+                    extra_params['callback_on_step_end'] = inference_step_wakeup
 
                 else:
-                    extra_params['callback'] = maybe_cancel_work
+                    extra_params['callback'] = inference_step_wakeup
                     extra_params['callback_steps'] = 1
 
                 output = self._model(
@@ -213,4 +225,7 @@ class ModelMngr:
         finally:
             torch.cuda.empty_cache()
 
+        if self._tui:
+            self._tui.set_status('')
+
         return output_hash, output
diff --git a/skynet/dgpu/daemon.py b/skynet/dgpu/daemon.py
index bfcab79..98d3eda 100755
--- a/skynet/dgpu/daemon.py
+++ b/skynet/dgpu/daemon.py
@@ -17,6 +17,7 @@ from skynet.constants import (
 from skynet.dgpu.errors import (
     DGPUComputeError,
 )
+from skynet.dgpu.tui import WorkerMonitor
 from skynet.dgpu.compute import ModelMngr
 from skynet.dgpu.network import NetConnector
 
@@ -41,10 +42,12 @@ class WorkerDaemon:
         self,
         mm: ModelMngr,
         conn: NetConnector,
-        config: dict
+        config: dict,
+        tui: WorkerMonitor | None = None
     ):
         self.mm: ModelMngr = mm
         self.conn: NetConnector = conn
+        self._tui = tui
         self.auto_withdraw = (
             config['auto_withdraw']
             if 'auto_withdraw' in config else False
@@ -150,6 +153,12 @@ class WorkerDaemon:
 
         return app
 
+    async def _update_balance(self):
+        if self._tui:
+            # update balance
+            balance = await self.conn.get_worker_balance()
+            self._tui.set_header_text(new_balance=f'balance: {balance}')
+
     # TODO? this func is kinda big and maybe is better at module
     # level to reduce indentation?
    # -[ ] just pass `daemon: WorkerDaemon` vs. `self`
@@ -238,6 +247,8 @@ class WorkerDaemon:
         request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
         logging.info(f'calculated request hash: {request_hash}')
 
+        total_step = body['params']['step']
+
         # TODO: validate request
 
         resp = await self.conn.begin_work(rid)
@@ -246,6 +257,9 @@ class WorkerDaemon:
 
         else:
             try:
+                if self._tui:
+                    self._tui.set_progress(0, done=total_step)
+
                 output_type = 'png'
                 if 'output_type' in body['params']:
                     output_type = body['params']['output_type']
@@ -269,6 +283,9 @@ class WorkerDaemon:
                         f'Unsupported backend {self.backend}'
                     )
 
+                if self._tui:
+                    self._tui.set_progress(total_step)
+
                 self._last_generation_ts: str = datetime.now().isoformat()
                 self._last_benchmark: list[float] = self._benchmark
                 self._benchmark: list[float] = []
@@ -277,6 +294,9 @@ class WorkerDaemon:
 
                 await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
 
+                await self._update_balance()
+
+
             except BaseException as err:
                 if 'network cancel' not in str(err):
                     logging.exception('Failed to serve model request !?\n')
@@ -294,6 +314,7 @@ class WorkerDaemon:
     # -[ ] keeps tasks-as-funcs style prominent
     # -[ ] avoids so much indentation due to methods
     async def serve_forever(self):
+        await self._update_balance()
         try:
             while True:
                 if self.auto_withdraw:
diff --git a/skynet/dgpu/network.py b/skynet/dgpu/network.py
index d3c573a..6efd871 100755
--- a/skynet/dgpu/network.py
+++ b/skynet/dgpu/network.py
@@ -13,6 +13,7 @@ import outcome
 from PIL import Image
 from leap.cleos import CLEOS
 from leap.protocol import Asset
+from skynet.dgpu.tui import WorkerMonitor
 from skynet.constants import (
     DEFAULT_IPFS_DOMAIN,
     GPU_CONTRACT_ABI,
@@ -57,7 +58,7 @@ class NetConnector:
     - CLEOS client
 
     '''
-    def __init__(self, config: dict):
+    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
         # TODO, why these extra instance vars for an (unsynced)
        # copy of the `config` state?
         self.account = config['account']
@@ -81,6 +82,10 @@ class NetConnector:
         self.ipfs_domain = config['ipfs_domain']
 
         self._wip_requests = {}
 
+        self._tui = tui
+        if self._tui:
+            self._tui.set_header_text(new_worker_name=self.account)
+
     # blockchain helpers
 
@@ -163,6 +168,9 @@ class NetConnector:
             n.start_soon(
                 _run_and_save, snap['requests'], req['id'],
                 self.get_status_by_request_id, req['id'])
 
+        if self._tui:
+            self._tui.network_update(snap)
+
         return snap
 
     async def begin_work(self, request_id: int):
diff --git a/skynet/dgpu/tui.py b/skynet/dgpu/tui.py
new file mode 100644
index 0000000..6530b58
--- /dev/null
+++ b/skynet/dgpu/tui.py
@@ -0,0 +1,243 @@
+import urwid
+import trio
+import json
+
+
+class WorkerMonitor:
+    def __init__(self):
+        self.requests = []
+        self.header_info = {}
+
+        self.palette = [
+            ('headerbar', 'white', 'dark blue'),
+            ('request_row', 'white', 'dark gray'),
+            ('worker_row', 'light gray', 'black'),
+            ('progress_normal', 'black', 'light gray'),
+            ('progress_complete', 'black', 'dark green'),
+            ('body', 'white', 'black'),
+        ]
+
+        # --- Top bar (header) ---
+        worker_name = self.header_info.get('left', "unknown")
+        balance = self.header_info.get('right', "balance: unknown")
+
+        self.worker_name_widget = urwid.Text(worker_name)
+        self.balance_widget = urwid.Text(balance, align='right')
+
+        header = urwid.Columns([self.worker_name_widget, self.balance_widget])
+        header_attr = urwid.AttrMap(header, 'headerbar')
+
+        # --- Body (List of requests) ---
+        self.body_listbox = self._create_listbox_body(self.requests)
+
+        # --- Bottom bar (progress) ---
+        self.status_text = urwid.Text("Request: none", align='left')
+        self.progress_bar = urwid.ProgressBar(
+            'progress_normal',
+            'progress_complete',
+            current=0,
+            done=100
+        )
+
+        footer_cols = urwid.Columns([
+            ('fixed', 20, self.status_text),
+            self.progress_bar,
+        ])
+
+        # Build the main frame
+        frame = urwid.Frame(
+            self.body_listbox,
+            header=header_attr,
+            footer=footer_cols
+        )
+
+        # Set up the main loop with Trio
+        self.event_loop = urwid.TrioEventLoop()
+        self.main_loop = urwid.MainLoop(
+            frame,
+            palette=self.palette,
+            event_loop=self.event_loop,
+            unhandled_input=self._exit_on_q
+        )
+
+    def _create_listbox_body(self, requests):
+        """
+        Build a ListBox (vertical list) of requests & workers using SimpleFocusListWalker.
+        """
+        widgets = self._build_request_widgets(requests)
+        walker = urwid.SimpleFocusListWalker(widgets)
+        return urwid.ListBox(walker)
+
+    def _build_request_widgets(self, requests):
+        """
+        Build a list of Urwid widgets (one row per request + per worker).
+        """
+        row_widgets = []
+
+        for req in requests:
+            # Build a columns widget for the request row
+            columns = urwid.Columns([
+                ('fixed', 5, urwid.Text(f"#{req['id']}")),  # e.g. "#12"
"#12" + ('weight', 3, urwid.Text(req['model'])), + ('weight', 3, urwid.Text(req['prompt'])), + ('fixed', 13, urwid.Text(req['user'])), + ('fixed', 13, urwid.Text(req['reward'])), + ], dividechars=1) + + # Wrap the columns with an attribute map for coloring + request_row = urwid.AttrMap(columns, 'request_row') + row_widgets.append(request_row) + + # Then add each worker in its own line below + for w in req["workers"]: + worker_line = urwid.Text(f" {w}") + worker_row = urwid.AttrMap(worker_line, 'worker_row') + row_widgets.append(worker_row) + + # Optional blank line after each request + row_widgets.append(urwid.Text("")) + + return row_widgets + + def _exit_on_q(self, key): + """Exit the TUI on 'q' or 'Q'.""" + if key in ('q', 'Q'): + raise urwid.ExitMainLoop() + + async def run(self): + """ + Run the TUI in an async context (Trio). + This method blocks until the user quits (pressing q/Q). + """ + with self.main_loop.start(): + await self.event_loop.run_async() + + raise urwid.ExitMainLoop() + + # ------------------------------------------------------------------------- + # Public Methods to Update Various Parts of the UI + # ------------------------------------------------------------------------- + def set_status(self, status: str): + self.status_text.set_text(status) + + def set_progress(self, current, done=None): + """ + Update the bottom progress bar. + - `current`: new current progress value (int). + - `done`: max progress value (int). If None, we don’t change it. + """ + if done is not None: + self.progress_bar.done = done + + self.progress_bar.current = current + + pct = 0 + if self.progress_bar.done != 0: + pct = int((self.progress_bar.current / self.progress_bar.done) * 100) + + def update_requests(self, new_requests): + """ + Replace the data in the existing ListBox with new request widgets. + """ + new_widgets = self._build_request_widgets(new_requests) + self.body_listbox.body[:] = new_widgets # replace content of the list walker + + def set_header_text(self, new_worker_name=None, new_balance=None): + """ + Update the text in the header bar for worker name and/or balance. 
+        """
+        if new_worker_name is not None:
+            self.worker_name_widget.set_text(new_worker_name)
+        if new_balance is not None:
+            self.balance_widget.set_text(new_balance)
+
+    def network_update(self, snapshot: dict):
+        # join each queued request with its decoded body params and the
+        # workers that have reported a status row for it
+        queue = [
+            {
+                **r,
+                **(json.loads(r['body'])['params']),
+                'workers': [s['worker'] for s in snapshot['requests'][r['id']]]
+            }
+            for r in snapshot['queue']
+        ]
+        self.update_requests(queue)
+
+
+# # -----------------------------------------------------------------------------
+# # Example usage
+# # -----------------------------------------------------------------------------
+#
+# async def main():
+#     # Example data
+#     example_requests = [
+#         {
+#             "id": 12,
+#             "model": "black-forest-labs/FLUX.1-schnell",
+#             "prompt": "Generate an answer about quantum entanglement.",
+#             "user": "alice123",
+#             "reward": "20.0000 GPU",
+#             "workers": ["workerA", "workerB"],
+#         },
+#         {
+#             "id": 5,
+#             "model": "some-other-model/v2.0",
+#             "prompt": "A story about dragons.",
+#             "user": "bobthebuilder",
+#             "reward": "15.0000 GPU",
+#             "workers": ["workerX"],
+#         },
+#         {
+#             "id": 99,
+#             "model": "cool-model/turbo",
+#             "prompt": "Classify sentiment in these tweets.",
+#             "user": "charlie",
+#             "reward": "25.5000 GPU",
+#             "workers": ["workerOne", "workerTwo", "workerThree"],
+#         },
+#     ]
+#
+#     ui = WorkerMonitor()
+#
+#     async def progress_task():
+#         # Fill from 0% to 100%
+#         for pct in range(101):
+#             ui.set_status(f"Request #1234 ({pct}%)")
+#             ui.set_progress(pct, done=100)
+#             await trio.sleep(0.05)
+#         # Reset to 0
+#         ui.set_status("Starting again...")
+#         ui.set_progress(0)
+#
+#     async def update_data_task():
+#         await trio.sleep(3)  # Wait a bit, then update requests
+#         new_data = [{
+#             "id": 101,
+#             "model": "new-model/v1.0",
+#             "prompt": "Say hi to the world.",
+#             "user": "eve",
+#             "reward": "50.0000 GPU",
+#             "workers": ["workerFresh", "workerPower"],
+#         }]
+#         ui.update_requests(new_data)
+#         ui.set_header_text(new_worker_name="NewNodeName",
+#                            new_balance="balance: 12345.6789 GPU")
+#
+#     try:
+#         async with trio.open_nursery() as nursery:
+#             # Run the TUI
+#             nursery.start_soon(ui.run)
+#
+#             ui.update_requests(example_requests)
+#             ui.set_header_text(
+#                 new_worker_name="worker1.scd",
+#                 new_balance="balance: 12345.6789 GPU"
+#             )
+#             # Start background tasks
+#             nursery.start_soon(progress_task)
+#             nursery.start_soon(update_data_task)
+#
+#     except* KeyboardInterrupt:
+#         ...
+#
+#
+# if __name__ == "__main__":
+#     trio.run(main)
diff --git a/skynet/utils.py b/skynet/utils.py
index f29bea2..ce029bd 100755
--- a/skynet/utils.py
+++ b/skynet/utils.py
@@ -7,8 +7,10 @@ import logging
 import importlib
 
 from typing import Optional
+from contextlib import contextmanager
 
 import torch
+import diffusers
 import numpy as np
 from PIL import Image
@@ -74,12 +76,27 @@ def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
     return crop_image(convert_from_bytes_to_img(raw), max_w, max_h)
 
 
+class DummyPB:
+    def update(self):
+        ...
+
+@torch.compiler.disable
+@contextmanager
+def dummy_progress_bar(*args, **kwargs):
+    yield DummyPB()
+
+
+def monkey_patch_pipeline_disable_progress_bar(pipe):
+    pipe.progress_bar = dummy_progress_bar
+
+
 def pipeline_for(
     model: str,
     mode: str,
     mem_fraction: float = 1.0,
     cache_dir: str | None = None
 ) -> DiffusionPipeline:
+    diffusers.utils.logging.disable_progress_bar()
     logging.info(f'pipeline_for {model} {mode}')
 
     assert torch.cuda.is_available()
@@ -105,7 +122,9 @@ def pipeline_for(
         normalized_shortname = shortname.replace('-', '_')
         custom_pipeline = importlib.import_module(f'skynet.dgpu.pipes.{normalized_shortname}')
         assert custom_pipeline.__model['name'] == model
-        return custom_pipeline.pipeline_for(model, mode, mem_fraction=mem_fraction, cache_dir=cache_dir)
+        pipe = custom_pipeline.pipeline_for(model, mode, mem_fraction=mem_fraction, cache_dir=cache_dir)
+        monkey_patch_pipeline_disable_progress_bar(pipe)
+        return pipe
 
     except ImportError:
         # TODO, uhh why not warn/error log this?
@@ -121,7 +140,6 @@ def pipeline_for(
         logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
 
     params = {
-        'safety_checker': None,
         'torch_dtype': torch.float16,
         'cache_dir': cache_dir,
         'variant': 'fp16',
@@ -130,6 +148,7 @@ def pipeline_for(
     match shortname:
         case 'stable':
             params['revision'] = 'fp16'
+            params['safety_checker'] = None
 
     torch.cuda.set_per_process_memory_fraction(mem_fraction)
 
@@ -167,6 +186,8 @@ def pipeline_for(
 
         pipe = pipe.to('cuda')
 
+    monkey_patch_pipeline_disable_progress_bar(pipe)
+
     return pipe
 
 
diff --git a/uv.lock b/uv.lock
index 932568c..17a748a 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2261,6 +2261,7 @@ cuda = [
     { name = "torchvision" },
     { name = "transformers" },
     { name = "triton" },
+    { name = "urwid" },
     { name = "xformers" },
 ]
 dev = [
@@ -2312,6 +2313,7 @@ cuda = [
     { name = "torchvision", specifier = "==0.20.1+cu121", index = "https://download.pytorch.org/whl/cu121" },
     { name = "transformers", specifier = "==4.48.0" },
     { name = "triton", specifier = "==3.1.0", index = "https://download.pytorch.org/whl/cu121" },
+    { name = "urwid", specifier = ">=2.6.16" },
     { name = "xformers", specifier = ">=0.0.29,<0.0.30" },
 ]
 dev = [
@@ -2626,6 +2628,28 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 },
 ]
 
+[[package]]
+name = "urwid"
+version = "2.6.16"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "typing-extensions" },
+    { name = "wcwidth" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/98/21/ad23c9e961b2d36d57c63686a6f86768dd945d406323fb58c84f09478530/urwid-2.6.16.tar.gz", hash = "sha256:93ad239939e44c385e64aa00027878b9e5c486d59e855ec8ab5b1e1adcdb32a2", size = 848179 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/54/cb/271a4f5a1bf4208dbdc96d85b9eae744cf4e5e11ac73eda76dc98c8fd2d7/urwid-2.6.16-py3-none-any.whl", hash = "sha256:de14896c6df9eb759ed1fd93e0384a5279e51e0dde8f621e4083f7a8368c0797", size = 297196 },
+]
+
+[[package]]
+name = "wcwidth"
+version = "0.2.13"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/6c/63/53559446a878410fc5a5974feb13d31d78d752eb18aeba59c7fef1af7598/wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5", size = 101301 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 },
"https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 }, +] + [[package]] name = "websocket-client" version = "1.8.0"