mirror of https://github.com/skygpu/skynet.git

Begin adding TUI

parent 93ee65087f
commit b3dc7c1074
pyproject.toml

@@ -61,6 +61,7 @@ cuda = [
     "basicsr>=1.4.2,<2",
     "realesrgan>=0.3.0,<0.4",
     "sentencepiece>=0.2.0",
+    "urwid>=2.6.16",
 ]

 [tool.uv]
skynet/dgpu/__init__.py

@@ -1,4 +1,5 @@
 import logging
+import warnings

 import trio

@@ -6,11 +7,31 @@ from hypercorn.config import Config
 from hypercorn.trio import serve
 from quart_trio import QuartTrio as Quart

+from skynet.dgpu.tui import WorkerMonitor
 from skynet.dgpu.compute import ModelMngr
 from skynet.dgpu.daemon import WorkerDaemon
 from skynet.dgpu.network import NetConnector


+def setup_logging_for_tui(level):
+    warnings.filterwarnings("ignore")
+
+    logger = logging.getLogger()
+    logger.setLevel(level)
+
+    fh = logging.FileHandler('dgpu.log')
+    fh.setLevel(level)
+
+    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+    fh.setFormatter(formatter)
+
+    logger.addHandler(fh)
+
+    for handler in logger.handlers:
+        if isinstance(handler, logging.StreamHandler):
+            logger.removeHandler(handler)
+
+
 async def open_dgpu_node(config: dict) -> None:
     '''
     Open a top level "GPU mgmt daemon", keep the
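A note on the `setup_logging_for_tui` hunk above: once urwid owns the terminal, any `StreamHandler` writing to stdout/stderr would scribble over the TUI, so root-logger output is redirected to `dgpu.log`. Two caveats are worth flagging: `logging.FileHandler` is itself a `StreamHandler` subclass, so the isinstance check also matches the file handler just added, and removing handlers while iterating `logger.handlers` mutates the list mid-loop. A safer variant, assuming the intent is to keep the file handler and drop only console handlers (a sketch, not what the commit ships):

    import logging

    logger = logging.getLogger()
    # iterate over a copy, and match console handlers by exact type so the
    # freshly added FileHandler (a StreamHandler subclass) survives:
    for handler in list(logger.handlers):
        if type(handler) is logging.StreamHandler:
            logger.removeHandler(handler)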
@@ -18,13 +39,17 @@ async def open_dgpu_node(config: dict) -> None:
     and *maybe* serve a `hypercorn` web API.

     '''
     # suppress logs from httpx (logs url + status after every query)
     logging.getLogger("httpx").setLevel(logging.WARNING)

-    conn = NetConnector(config)
-    mm = ModelMngr(config)
-    daemon = WorkerDaemon(mm, conn, config)
+    tui = None
+    if config['tui']:
+        setup_logging_for_tui(logging.INFO)
+        tui = WorkerMonitor()
+
+    conn = NetConnector(config, tui=tui)
+    mm = ModelMngr(config, tui=tui)
+    daemon = WorkerDaemon(mm, conn, config, tui=tui)

     api: Quart|None = None
     if 'api_bind' in config:

@@ -35,6 +60,8 @@ async def open_dgpu_node(config: dict) -> None:
     tn: trio.Nursery
     async with trio.open_nursery() as tn:
         tn.start_soon(daemon.snap_updater_task)
+        if tui:
+            tn.start_soon(tui.run)

         # TODO, consider a more explicit `as hypercorn_serve`
         # to clarify?

@@ -42,5 +69,9 @@ async def open_dgpu_node(config: dict) -> None:
             logging.info(f'serving api @ {config["api_bind"]}')
             tn.start_soon(serve, api, api_conf)

-        # block until cancelled
-        await daemon.serve_forever()
+        try:
+            # block until cancelled
+            await daemon.serve_forever()
+
+        except* urwid.ExitMainLoop:
+            ...
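The teardown in the last hunk leans on Python 3.11 `except*` semantics: a trio nursery re-raises child-task failures wrapped in an `ExceptionGroup`, so the `urwid.ExitMainLoop` raised when the user quits the TUI task surfaces inside a group that a bare `except` would miss. (As rendered, this module also still needs an `import urwid` for the handler to resolve.) A standalone sketch of the mechanism, independent of the commit:

    import trio

    class Quit(Exception):
        ...

    async def ui_task():
        raise Quit()

    async def main():
        try:
            async with trio.open_nursery() as tn:
                tn.start_soon(ui_task)
        except* Quit:
            pass  # a bare `except Quit:` would not match the ExceptionGroup

    trio.run(main)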
skynet/dgpu/compute.py

@@ -11,6 +11,7 @@ from hashlib import sha256
 import trio
 import torch

+from skynet.dgpu.tui import WorkerMonitor
 from skynet.dgpu.errors import (
     DGPUComputeError,
     DGPUInferenceCancelled,

@@ -72,7 +73,8 @@ class ModelMngr:
     checking load state, and unloading when no-longer-needed/finished.

     '''
-    def __init__(self, config: dict):
+    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
+        self._tui = tui
         self.cache_dir = None
         if 'hf_home' in config:
             self.cache_dir = config['hf_home']

@@ -80,8 +82,6 @@ class ModelMngr:
         self._model_name: str = ''
         self._model_mode: str = ''

-        # self.load_model(DEFAULT_INITAL_MODEL, 'txt2img')
-
     def log_debug_info(self):
         logging.debug('memory summary:')
         logging.debug('\n' + torch.cuda.memory_summary())

@@ -110,6 +110,7 @@ class ModelMngr:
     ) -> None:
         logging.info(f'loading model {name}...')
         self.unload_model()

         self._model = pipeline_for(
             name, mode, cache_dir=self.cache_dir)
         self._model_mode = mode
@@ -124,19 +125,30 @@ class ModelMngr:
         params: dict,
         inputs: list[bytes] = []
     ):
-        def maybe_cancel_work(step, *args, **kwargs):
+        total_steps = params['step']
+        def inference_step_wakeup(*args, **kwargs):
             '''This is a callback function that gets invoked every inference step,
             we need to raise an exception here if we need to cancel work
             '''
-            if self._should_cancel:
-                should_raise = trio.from_thread.run(self._should_cancel, request_id)
-                if should_raise:
-                    logging.warning(f'CANCELLING work at step {step}')
-                    raise DGPUInferenceCancelled('network cancel')
+            step = args[0]
+            # compat with callback_on_step_end
+            if not isinstance(step, int):
+                step = args[1]
+
+            if self._tui:
+                self._tui.set_progress(step, done=total_steps)
+
+            should_raise = trio.from_thread.run(self._should_cancel, request_id)
+            if should_raise:
+                logging.warning(f'CANCELLING work at step {step}')
+                raise DGPUInferenceCancelled('network cancel')

             return {}

-        maybe_cancel_work(0)
+        if self._tui:
+            self._tui.set_status(f'Request #{request_id}')
+
+        inference_step_wakeup(0)

         output_type = 'png'
         if 'output_type' in params:
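The `isinstance` shim above exists because diffusers supports two callback conventions: the legacy `callback(step, timestep, latents)` receives the step first, while the newer `callback_on_step_end(pipe, step, timestep, callback_kwargs)` receives the pipeline object first and the step second. A standalone illustration of the argument handling (toy code, not from the commit):

    def step_from_args(*args):
        step = args[0]
        # callback_on_step_end passes the pipeline first, step second
        if not isinstance(step, int):
            step = args[1]
        return step

    assert step_from_args(5, 999, None) == 5           # legacy callback(step, timestep, latents)
    assert step_from_args(object(), 5, 999, {}) == 5   # callback_on_step_end(pipe, step, ts, kwargs)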
@@ -157,10 +169,10 @@ class ModelMngr:
         prompt, guidance, step, seed, upscaler, extra_params = arguments

         if 'flux' in name.lower():
-            extra_params['callback_on_step_end'] = maybe_cancel_work
+            extra_params['callback_on_step_end'] = inference_step_wakeup

         else:
-            extra_params['callback'] = maybe_cancel_work
+            extra_params['callback'] = inference_step_wakeup
             extra_params['callback_steps'] = 1

         output = self._model(

@@ -213,4 +225,7 @@ class ModelMngr:
         finally:
             torch.cuda.empty_cache()

+            if self._tui:
+                self._tui.set_status('')
+
         return output_hash, output
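One subtlety in `inference_step_wakeup`: it calls `trio.from_thread.run(...)`, which only works from a trio-managed worker thread, implying the pipeline call itself is offloaded (e.g. via `trio.to_thread.run_sync`). The callback cannot await async state directly, so it hops back into the trio event loop to run the `_should_cancel` check. A minimal self-contained sketch of that pattern, under that assumption:

    import trio

    async def should_cancel(request_id: int) -> bool:
        return False  # stand-in for the real network-side cancellation check

    def blocking_inference():
        # runs inside the worker thread; re-enter the trio loop to query state
        if trio.from_thread.run(should_cancel, 0):
            raise RuntimeError('cancelled')

    async def main():
        await trio.to_thread.run_sync(blocking_inference)

    trio.run(main)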
skynet/dgpu/daemon.py

@@ -17,6 +17,7 @@ from skynet.constants import (
 from skynet.dgpu.errors import (
     DGPUComputeError,
 )
+from skynet.dgpu.tui import WorkerMonitor
 from skynet.dgpu.compute import ModelMngr
 from skynet.dgpu.network import NetConnector

@@ -41,10 +42,12 @@ class WorkerDaemon:
         self,
         mm: ModelMngr,
         conn: NetConnector,
-        config: dict
+        config: dict,
+        tui: WorkerMonitor | None = None
     ):
         self.mm: ModelMngr = mm
         self.conn: NetConnector = conn
+        self._tui = tui
         self.auto_withdraw = (
             config['auto_withdraw']
             if 'auto_withdraw' in config else False

@@ -150,6 +153,12 @@ class WorkerDaemon:

         return app

+    async def _update_balance(self):
+        if self._tui:
+            # update balance
+            balance = await self.conn.get_worker_balance()
+            self._tui.set_header_text(new_balance=f'balance: {balance}')
+
     # TODO? this func is kinda big and maybe is better at module
     # level to reduce indentation?
     # -[ ] just pass `daemon: WorkerDaemon` vs. `self`
@@ -238,6 +247,8 @@ class WorkerDaemon:
         request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
         logging.info(f'calculated request hash: {request_hash}')

+        total_step = body['params']['step']
+
         # TODO: validate request

         resp = await self.conn.begin_work(rid)

@@ -246,6 +257,9 @@ class WorkerDaemon:

         else:
             try:
+                if self._tui:
+                    self._tui.set_progress(0, done=total_step)
+
                 output_type = 'png'
                 if 'output_type' in body['params']:
                     output_type = body['params']['output_type']

@@ -269,6 +283,9 @@ class WorkerDaemon:
                         f'Unsupported backend {self.backend}'
                     )

+                if self._tui:
+                    self._tui.set_progress(total_step)
+
                 self._last_generation_ts: str = datetime.now().isoformat()
                 self._last_benchmark: list[float] = self._benchmark
                 self._benchmark: list[float] = []

@@ -277,6 +294,9 @@ class WorkerDaemon:

                 await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)

+                await self._update_balance()
+
             except BaseException as err:
                 if 'network cancel' not in str(err):
                     logging.exception('Failed to serve model request !?\n')

@@ -294,6 +314,7 @@ class WorkerDaemon:
     # -[ ] keeps tasks-as-funcs style prominent
     # -[ ] avoids so much indentation due to methods
     async def serve_forever(self):
+        await self._update_balance()
         try:
             while True:
                 if self.auto_withdraw:
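Across `ModelMngr`, `WorkerDaemon`, and `NetConnector` the monitor is threaded through as `tui: WorkerMonitor | None = None` and every call site is guarded with `if self._tui:`, which keeps headless runs unchanged. An alternative that would drop the guards is a null-object stand-in; a sketch of that design option (not what the commit does):

    class NullMonitor:
        '''No-op stand-in matching the WorkerMonitor surface used here.'''
        def set_status(self, *a, **kw): ...
        def set_progress(self, *a, **kw): ...
        def set_header_text(self, *a, **kw): ...
        def network_update(self, *a, **kw): ...

    # tui = WorkerMonitor() if config['tui'] else NullMonitor()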
skynet/dgpu/network.py

@@ -13,6 +13,7 @@ import outcome
 from PIL import Image
 from leap.cleos import CLEOS
 from leap.protocol import Asset
+from skynet.dgpu.tui import WorkerMonitor
 from skynet.constants import (
     DEFAULT_IPFS_DOMAIN,
     GPU_CONTRACT_ABI,

@@ -57,7 +58,7 @@ class NetConnector:
     - CLEOS client

     '''
-    def __init__(self, config: dict):
+    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
         # TODO, why these extra instance vars for an (unsynced)
         # copy of the `config` state?
         self.account = config['account']

@@ -81,6 +82,10 @@ class NetConnector:
         self.ipfs_domain = config['ipfs_domain']

         self._wip_requests = {}
+        self._tui = tui
+        if self._tui:
+            self._tui.set_header_text(new_worker_name=self.account)
+

     # blockchain helpers

@@ -163,6 +168,9 @@ class NetConnector:
             n.start_soon(
                 _run_and_save, snap['requests'], req['id'], self.get_status_by_request_id, req['id'])

+        if self._tui:
+            self._tui.network_update(snap)
+
         return snap

     async def begin_work(self, request_id: int):
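The `snap` dict handed to `WorkerMonitor.network_update` (defined in the new tui.py below) pairs the on-chain queue with per-request status rows. Based on how `network_update` consumes it, the shape is roughly the following; the field values here are invented for illustration:

    snap = {
        'queue': [
            {
                'id': 12,
                'user': 'alice123',
                'reward': '20.0000 GPU',
                # JSON body whose 'params' get merged into the row
                'body': '{"params": {"model": "flux", "prompt": "a cat", "step": 28}}',
            },
        ],
        'requests': {
            12: [{'worker': 'workerA'}, {'worker': 'workerB'}],
        },
    }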
skynet/dgpu/tui.py (new file)

@@ -0,0 +1,248 @@
+import urwid
+import trio
+import json
+
+
+class WorkerMonitor:
+    def __init__(self):
+        self.requests = []
+        self.header_info = {}
+
+        self.palette = [
+            ('headerbar', 'white', 'dark blue'),
+            ('request_row', 'white', 'dark gray'),
+            ('worker_row', 'light gray', 'black'),
+            ('progress_normal', 'black', 'light gray'),
+            ('progress_complete', 'black', 'dark green'),
+            ('body', 'white', 'black'),
+        ]
+
+        # --- Top bar (header) ---
+        worker_name = self.header_info.get('left', "unknown")
+        balance = self.header_info.get('right', "balance: unknown")
+
+        self.worker_name_widget = urwid.Text(worker_name)
+        self.balance_widget = urwid.Text(balance, align='right')
+
+        header = urwid.Columns([self.worker_name_widget, self.balance_widget])
+        header_attr = urwid.AttrMap(header, 'headerbar')
+
+        # --- Body (List of requests) ---
+        self.body_listbox = self._create_listbox_body(self.requests)
+
+        # --- Bottom bar (progress) ---
+        self.status_text = urwid.Text("Request: none", align='left')
+        self.progress_bar = urwid.ProgressBar(
+            'progress_normal',
+            'progress_complete',
+            current=0,
+            done=100
+        )
+
+        footer_cols = urwid.Columns([
+            ('fixed', 20, self.status_text),
+            self.progress_bar,
+        ])
+
+        # Build the main frame
+        frame = urwid.Frame(
+            self.body_listbox,
+            header=header_attr,
+            footer=footer_cols
+        )
+
+        # Set up the main loop with Trio
+        self.event_loop = urwid.TrioEventLoop()
+        self.main_loop = urwid.MainLoop(
+            frame,
+            palette=self.palette,
+            event_loop=self.event_loop,
+            unhandled_input=self._exit_on_q
+        )
+
+    def _create_listbox_body(self, requests):
+        """
+        Build a ListBox (vertical list) of requests & workers using SimpleFocusListWalker.
+        """
+        widgets = self._build_request_widgets(requests)
+        walker = urwid.SimpleFocusListWalker(widgets)
+        return urwid.ListBox(walker)
+
+    def _build_request_widgets(self, requests):
+        """
+        Build a list of Urwid widgets (one row per request + per worker).
+        """
+        row_widgets = []
+
+        for req in requests:
+            # Build a columns widget for the request row
+            columns = urwid.Columns([
+                ('fixed', 5, urwid.Text(f"#{req['id']}")),  # e.g. "#12"
+                ('weight', 3, urwid.Text(req['model'])),
+                ('weight', 3, urwid.Text(req['prompt'])),
+                ('fixed', 13, urwid.Text(req['user'])),
+                ('fixed', 13, urwid.Text(req['reward'])),
+            ], dividechars=1)
+
+            # Wrap the columns with an attribute map for coloring
+            request_row = urwid.AttrMap(columns, 'request_row')
+            row_widgets.append(request_row)
+
+            # Then add each worker in its own line below
+            for w in req["workers"]:
+                worker_line = urwid.Text(f"  {w}")
+                worker_row = urwid.AttrMap(worker_line, 'worker_row')
+                row_widgets.append(worker_row)
+
+            # Optional blank line after each request
+            row_widgets.append(urwid.Text(""))
+
+        return row_widgets
+
+    def _exit_on_q(self, key):
+        """Exit the TUI on 'q' or 'Q'."""
+        if key in ('q', 'Q'):
+            raise urwid.ExitMainLoop()
+
+    async def run(self):
+        """
+        Run the TUI in an async context (Trio).
+        This method blocks until the user quits (pressing q/Q).
+        """
+        with self.main_loop.start():
+            await self.event_loop.run_async()
+
+        raise urwid.ExitMainLoop()
+
+    # -------------------------------------------------------------------------
+    # Public Methods to Update Various Parts of the UI
+    # -------------------------------------------------------------------------
+    def set_status(self, status: str):
+        self.status_text.set_text(status)
+
+    def set_progress(self, current, done=None):
+        """
+        Update the bottom progress bar.
+        - `current`: new current progress value (int).
+        - `done`: max progress value (int). If None, we don't change it.
+        """
+        if done is not None:
+            self.progress_bar.done = done
+
+        self.progress_bar.current = current
+
+        # percentage derived from current/done (not yet surfaced in the UI)
+        pct = 0
+        if self.progress_bar.done != 0:
+            pct = int((self.progress_bar.current / self.progress_bar.done) * 100)
+
+    def update_requests(self, new_requests):
+        """
+        Replace the data in the existing ListBox with new request widgets.
+        """
+        new_widgets = self._build_request_widgets(new_requests)
+        self.body_listbox.body[:] = new_widgets  # replace content of the list walker
+
+    def set_header_text(self, new_worker_name=None, new_balance=None):
+        """
+        Update the text in the header bar for worker name and/or balance.
+        """
+        if new_worker_name is not None:
+            self.worker_name_widget.set_text(new_worker_name)
+        if new_balance is not None:
+            self.balance_widget.set_text(new_balance)
+
+    def network_update(self, snapshot: dict):
+        queue = [
+            {
+                **r,
+                **(json.loads(r['body'])['params']),
+                'workers': [s['worker'] for s in snapshot['requests'][r['id']]]
+            }
+            for r in snapshot['queue']
+        ]
+        self.update_requests(queue)
+
+
+# # -----------------------------------------------------------------------------
+# # Example usage
+# # -----------------------------------------------------------------------------
+#
+# async def main():
+#     # Example data
+#     example_requests = [
+#         {
+#             "id": 12,
+#             "model": "black-forest-labs/FLUX.1-schnell",
+#             "prompt": "Generate an answer about quantum entanglement.",
+#             "user": "alice123",
+#             "reward": "20.0000 GPU",
+#             "workers": ["workerA", "workerB"],
+#         },
+#         {
+#             "id": 5,
+#             "model": "some-other-model/v2.0",
+#             "prompt": "A story about dragons.",
+#             "user": "bobthebuilder",
+#             "reward": "15.0000 GPU",
+#             "workers": ["workerX"],
+#         },
+#         {
+#             "id": 99,
+#             "model": "cool-model/turbo",
+#             "prompt": "Classify sentiment in these tweets.",
+#             "user": "charlie",
+#             "reward": "25.5000 GPU",
+#             "workers": ["workerOne", "workerTwo", "workerThree"],
+#         },
+#     ]
+#
+#     ui = WorkerMonitor()
+#
+#     async def progress_task():
+#         # Fill from 0% to 100%
+#         for pct in range(101):
+#             ui.set_progress(pct)
+#             ui.set_status(f"Request #1234 ({pct}%)")
+#             await trio.sleep(0.05)
+#         # Reset to 0
+#         ui.set_progress(0)
+#         ui.set_status("Starting again...")
+#
+#     async def update_data_task():
+#         await trio.sleep(3)  # Wait a bit, then update requests
+#         new_data = [{
+#             "id": 101,
+#             "model": "new-model/v1.0",
+#             "prompt": "Say hi to the world.",
+#             "user": "eve",
+#             "reward": "50.0000 GPU",
+#             "workers": ["workerFresh", "workerPower"],
+#         }]
+#         ui.update_requests(new_data)
+#         ui.set_header_text(new_worker_name="NewNodeName",
+#                            new_balance="balance: 12345.6789 GPU")
+#
+#     try:
+#         async with trio.open_nursery() as nursery:
+#             # Run the TUI
+#             nursery.start_soon(ui.run_teardown_on_exit, nursery)
+#
+#             ui.update_requests(example_requests)
+#             ui.set_header_text(
+#                 new_worker_name="worker1.scd",
+#                 new_balance="balance: 12345.6789 GPU"
+#             )
+#             # Start background tasks
+#             nursery.start_soon(progress_task)
+#             nursery.start_soon(update_data_task)
+#
+#     except *KeyboardInterrupt as ex_group:
+#         ...
+#
+#
+# if __name__ == "__main__":
+#     trio.run(main)
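The new `WorkerMonitor` drives urwid with `urwid.TrioEventLoop`: the `MainLoop.start()` context manager paints the screen (and restores the terminal on exit), then the coroutine parks on `event_loop.run_async()` so redraws happen inside the daemon's existing trio nursery. A minimal self-contained check of that wiring, assuming urwid >= 2.6 (press q to quit):

    import trio
    import urwid

    def exit_on_q(key):
        if key in ('q', 'Q'):
            raise urwid.ExitMainLoop()

    async def main():
        ev = urwid.TrioEventLoop()
        loop = urwid.MainLoop(
            urwid.Filler(urwid.Text('hello from urwid under trio')),
            event_loop=ev,
            unhandled_input=exit_on_q,
        )
        # run_async() returns once ExitMainLoop is raised
        with loop.start():
            await ev.run_async()

    trio.run(main)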
@@ -7,8 +7,10 @@ import logging
 import importlib

 from typing import Optional
+from contextlib import contextmanager

 import torch
+import diffusers
 import numpy as np

 from PIL import Image

@@ -74,12 +76,27 @@ def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
     return crop_image(convert_from_bytes_to_img(raw), max_w, max_h)


+class DummyPB:
+    def update(self):
+        ...
+
+@torch.compiler.disable
+@contextmanager
+def dummy_progress_bar(*args, **kwargs):
+    yield DummyPB()
+
+
+def monkey_patch_pipeline_disable_progress_bar(pipe):
+    pipe.progress_bar = dummy_progress_bar
+
+
 def pipeline_for(
     model: str,
     mode: str,
     mem_fraction: float = 1.0,
     cache_dir: str | None = None
 ) -> DiffusionPipeline:
+    diffusers.utils.logging.disable_progress_bar()
+
     logging.info(f'pipeline_for {model} {mode}')
     assert torch.cuda.is_available()

@@ -105,7 +122,9 @@ def pipeline_for(
         normalized_shortname = shortname.replace('-', '_')
         custom_pipeline = importlib.import_module(f'skynet.dgpu.pipes.{normalized_shortname}')
         assert custom_pipeline.__model['name'] == model
-        return custom_pipeline.pipeline_for(model, mode, mem_fraction=mem_fraction, cache_dir=cache_dir)
+        pipe = custom_pipeline.pipeline_for(model, mode, mem_fraction=mem_fraction, cache_dir=cache_dir)
+        monkey_patch_pipeline_disable_progress_bar(pipe)
+        return pipe

     except ImportError:
         # TODO, uhh why not warn/error log this?

@@ -121,7 +140,6 @@ def pipeline_for(
     logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')

     params = {
-        'safety_checker': None,
         'torch_dtype': torch.float16,
         'cache_dir': cache_dir,
         'variant': 'fp16',

@@ -130,6 +148,7 @@ def pipeline_for(
     match shortname:
         case 'stable':
             params['revision'] = 'fp16'
+            params['safety_checker'] = None

     torch.cuda.set_per_process_memory_fraction(mem_fraction)

@@ -167,6 +186,8 @@ def pipeline_for(
     pipe = pipe.to('cuda')

+    monkey_patch_pipeline_disable_progress_bar(pipe)
+
     return pipe
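Both progress-bar measures in this file serve the TUI: `diffusers.utils.logging.disable_progress_bar()` flips the library-level default off, and `monkey_patch_pipeline_disable_progress_bar` replaces a pipeline's `progress_bar` context manager, which normally wraps tqdm and writes to stderr, scribbling over the urwid screen, with a no-op. A standalone sketch of why the dummy satisfies the call sites, assuming pipelines use it as `with self.progress_bar(...) as pb: pb.update()`:

    from contextlib import contextmanager

    class DummyPB:
        def update(self):  # tqdm-style hook the pipeline may call per step
            ...

    @contextmanager
    def dummy_progress_bar(*args, **kwargs):
        yield DummyPB()

    # mimic a diffusers-style call site:
    with dummy_progress_bar(total=28) as pb:
        for _ in range(28):
            pb.update()  # no terminal output, so the TUI stays intact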
uv.lock

@@ -2261,6 +2261,7 @@ cuda = [
     { name = "torchvision" },
     { name = "transformers" },
     { name = "triton" },
+    { name = "urwid" },
     { name = "xformers" },
 ]
 dev = [

@@ -2312,6 +2313,7 @@ cuda = [
     { name = "torchvision", specifier = "==0.20.1+cu121", index = "https://download.pytorch.org/whl/cu121" },
     { name = "transformers", specifier = "==4.48.0" },
     { name = "triton", specifier = "==3.1.0", index = "https://download.pytorch.org/whl/cu121" },
+    { name = "urwid", specifier = ">=2.6.16" },
     { name = "xformers", specifier = ">=0.0.29,<0.0.30" },
 ]
 dev = [

@@ -2626,6 +2628,28 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 },
 ]

+[[package]]
+name = "urwid"
+version = "2.6.16"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "typing-extensions" },
+    { name = "wcwidth" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/98/21/ad23c9e961b2d36d57c63686a6f86768dd945d406323fb58c84f09478530/urwid-2.6.16.tar.gz", hash = "sha256:93ad239939e44c385e64aa00027878b9e5c486d59e855ec8ab5b1e1adcdb32a2", size = 848179 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/54/cb/271a4f5a1bf4208dbdc96d85b9eae744cf4e5e11ac73eda76dc98c8fd2d7/urwid-2.6.16-py3-none-any.whl", hash = "sha256:de14896c6df9eb759ed1fd93e0384a5279e51e0dde8f621e4083f7a8368c0797", size = 297196 },
+]
+
+[[package]]
+name = "wcwidth"
+version = "0.2.13"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/6c/63/53559446a878410fc5a5974feb13d31d78d752eb18aeba59c7fef1af7598/wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5", size = 101301 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 },
+]
+
 [[package]]
 name = "websocket-client"
 version = "1.8.0"