mirror of https://github.com/skygpu/skynet.git

Refactoring tui to be functional style

parent cd028d15e7
commit d8f243df9b
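In short: instead of threading a `WorkerMonitor` instance through `NetConnector`, `WorkerDaemon` and `compute_one` as a `tui=` parameter, the monitor now lives as a module-level singleton inside `skynet.dgpu.tui`, created once via `init_tui()` and reached through the `maybe_update_tui` / `maybe_update_tui_async` helpers; `setup_logging_for_tui` moves into that module as well. The call-site pattern changes like this (a minimal sketch distilled from the hunks below):

    # before: every component held an optional handle and guarded each call
    if self._tui:
        self._tui.set_progress(step, done=total_steps)

    # after: pass a callback; it only runs if init_tui() was ever called
    maybe_update_tui(lambda tui: tui.set_progress(step, done=total_steps))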
skynet/dgpu/__init__.py

@@ -1,36 +1,17 @@
 import logging
 import warnings

 import trio
 import urwid

 from hypercorn.config import Config
 from hypercorn.trio import serve
 from quart_trio import QuartTrio as Quart
-from skynet.dgpu.tui import WorkerMonitor
+from skynet.dgpu.tui import init_tui
 from skynet.dgpu.daemon import WorkerDaemon
 from skynet.dgpu.network import NetConnector


-def setup_logging_for_tui(level):
-    warnings.filterwarnings("ignore")
-
-    logger = logging.getLogger()
-    logger.setLevel(level)
-
-    fh = logging.FileHandler('dgpu.log')
-    fh.setLevel(level)
-
-    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
-    fh.setFormatter(formatter)
-
-    logger.addHandler(fh)
-
-    for handler in logger.handlers:
-        if isinstance(handler, logging.StreamHandler):
-            logger.removeHandler(handler)
-
-
 async def open_dgpu_node(config: dict) -> None:
     '''
     Open a top level "GPU mgmt daemon", keep the

@@ -43,11 +24,10 @@ async def open_dgpu_node(config: dict) -> None:

     tui = None
     if config['tui']:
-        setup_logging_for_tui(logging.INFO)
-        tui = WorkerMonitor()
+        tui = init_tui()

-    conn = NetConnector(config, tui=tui)
-    daemon = WorkerDaemon(conn, config, tui=tui)
+    conn = NetConnector(config)
+    daemon = WorkerDaemon(conn, config)

     api: Quart|None = None
     if 'api_bind' in config:

@@ -71,5 +51,5 @@ async def open_dgpu_node(config: dict) -> None:
            # block until cancelled
            await daemon.serve_forever()

-    except *urwid.ExitMainLoop in ex_group:
+    except *urwid.ExitMainLoop:
        ...

skynet/dgpu/compute.py

@@ -12,7 +12,7 @@ from contextlib import contextmanager as cm
 import trio
 import torch

-from skynet.dgpu.tui import WorkerMonitor
+from skynet.dgpu.tui import maybe_update_tui
 from skynet.dgpu.errors import (
     DGPUComputeError,
     DGPUInferenceCancelled,

@@ -108,8 +108,7 @@ def compute_one(
     method: str,
     params: dict,
     inputs: list[bytes] = [],
-    should_cancel = None,
-    tui: WorkerMonitor | None = None
+    should_cancel = None
 ):
     if method == 'diffuse':
         method = 'txt2img'

@@ -130,8 +129,7 @@ def compute_one(
        if not isinstance(step, int):
            step = args[1]

-        if tui:
-            tui.set_progress(step, done=total_steps)
+        maybe_update_tui(lambda tui: tui.set_progress(step, done=total_steps))

        if should_cancel:
            should_raise = trio.from_thread.run(should_cancel, request_id)

@@ -142,8 +140,7 @@ def compute_one(

            return {}

-    if tui:
-        tui.set_status(f'Request #{request_id}')
+    maybe_update_tui(lambda tui: tui.set_status(f'Request #{request_id}'))

    inference_step_wakeup(0)

@@ -210,7 +207,6 @@ def compute_one(
    except BaseException as err:
        raise DGPUComputeError(str(err)) from err

-    if tui:
-        tui.set_status('')
+    maybe_update_tui(lambda tui: tui.set_status(''))

    return output_hash, output

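Note that `maybe_update_tui` runs its callback synchronously, so the lambdas above see the current values of locals like `step` and `total_steps`, with no late-binding surprises. A minimal sketch of the helper's gating behaviour (names from this diff; the `set_status` argument is illustrative):

    from skynet.dgpu.tui import init_tui, maybe_update_tui

    maybe_update_tui(lambda tui: tui.set_status('idle'))  # no-op, singleton not created yet
    init_tui()                                            # builds the WorkerMonitor singleton
    maybe_update_tui(lambda tui: tui.set_status('idle'))  # now forwards to the monitor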
skynet/dgpu/daemon.py

@@ -17,7 +17,7 @@ from skynet.constants import (
 from skynet.dgpu.errors import (
     DGPUComputeError,
 )
-from skynet.dgpu.tui import WorkerMonitor
+from skynet.dgpu.tui import maybe_update_tui, maybe_update_tui_async
 from skynet.dgpu.compute import maybe_load_model, compute_one
 from skynet.dgpu.network import NetConnector

@@ -41,11 +41,9 @@ class WorkerDaemon:
     def __init__(
         self,
         conn: NetConnector,
-        config: dict,
-        tui: WorkerMonitor | None = None
+        config: dict
     ):
         self.conn: NetConnector = conn
-        self._tui = tui
         self.auto_withdraw = (
             config['auto_withdraw']
             if 'auto_withdraw' in config else False

@@ -152,10 +150,12 @@ class WorkerDaemon:
         return app

     async def _update_balance(self):
-        if self._tui:
+        async def _fn(tui):
             # update balance
             balance = await self.conn.get_worker_balance()
-            self._tui.set_header_text(new_balance=f'balance: {balance}')
+            tui.set_header_text(new_balance=f'balance: {balance}')
+
+        await maybe_update_tui_async(_fn)

     # TODO? this func is kinda big and maybe is better at module
     # level to reduce indentation?

@@ -258,8 +258,7 @@ class WorkerDaemon:

         with maybe_load_model(model, mode):
             try:
-                if self._tui:
-                    self._tui.set_progress(0, done=total_step)
+                maybe_update_tui(lambda tui: tui.set_progress(0, done=total_step))

                 output_type = 'png'
                 if 'output_type' in body['params']:

@@ -276,7 +275,6 @@ class WorkerDaemon:
                        mode, body['params'],
                        inputs=inputs,
                        should_cancel=self.should_cancel_work,
-                        tui=self._tui
                    )
                )

@@ -285,8 +283,7 @@ class WorkerDaemon:
                    f'Unsupported backend {self.backend}'
                )

-            if self._tui:
-                self._tui.set_progress(total_step)
+            maybe_update_tui(lambda tui: tui.set_progress(total_step))

            self._last_generation_ts: str = datetime.now().isoformat()
            self._last_benchmark: list[float] = self._benchmark

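The async variant exists because a `lambda` cannot contain `await`: `_update_balance` has to fetch the balance from the chain before it can touch the header widget, hence the named inner coroutine. Restated as a sketch:

    # an async updater may await I/O before writing to the TUI;
    # maybe_update_tui_async skips it entirely when no TUI was initialized
    async def _fn(tui):
        balance = await self.conn.get_worker_balance()
        tui.set_header_text(new_balance=f'balance: {balance}')

    await maybe_update_tui_async(_fn)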
skynet/dgpu/network.py

@@ -13,7 +13,7 @@ import outcome
 from PIL import Image
 from leap.cleos import CLEOS
 from leap.protocol import Asset
-from skynet.dgpu.tui import WorkerMonitor
+from skynet.dgpu.tui import maybe_update_tui
 from skynet.constants import (
     DEFAULT_IPFS_DOMAIN,
     GPU_CONTRACT_ABI,

@@ -58,7 +58,7 @@ class NetConnector:
     - CLEOS client

     '''
-    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
+    def __init__(self, config: dict):
         # TODO, why these extra instance vars for an (unsynced)
         # copy of the `config` state?
         self.account = config['account']

@@ -82,9 +82,8 @@ class NetConnector:
         self.ipfs_domain = config['ipfs_domain']

         self._wip_requests = {}
-        self._tui = tui
-        if self._tui:
-            self._tui.set_header_text(new_worker_name=self.account)
+
+        maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.account))

     # blockchain helpers

@@ -168,8 +167,8 @@ class NetConnector:
                n.start_soon(
                    _run_and_save, snap['requests'], req['id'], self.get_status_by_request_id, req['id'])

-        if self._tui:
-            self._tui.network_update(snap)
+        maybe_update_tui(lambda tui: tui.network_update(snap))
+
        return snap

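Design-wise this inverts the dependency: `NetConnector` no longer stores a `self._tui` handle or checks it before each update; the header text set at construction and the per-snapshot `network_update` both go through the module-level helper, so the connector behaves identically whether or not a TUI was ever initialized.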
skynet/dgpu/tui.py

@@ -1,6 +1,9 @@
-import urwid
-import trio
 import json
+import logging
+import warnings
+
+import trio
+import urwid


 class WorkerMonitor:

@@ -163,86 +166,41 @@ class WorkerMonitor:
         self.update_requests(queue)


-# # -----------------------------------------------------------------------------
-# # Example usage
-# # -----------------------------------------------------------------------------
-#
-# async def main():
-#     # Example data
-#     example_requests = [
-#         {
-#             "id": 12,
-#             "model": "black-forest-labs/FLUX.1-schnell",
-#             "prompt": "Generate an answer about quantum entanglement.",
-#             "user": "alice123",
-#             "reward": "20.0000 GPU",
-#             "workers": ["workerA", "workerB"],
-#         },
-#         {
-#             "id": 5,
-#             "model": "some-other-model/v2.0",
-#             "prompt": "A story about dragons.",
-#             "user": "bobthebuilder",
-#             "reward": "15.0000 GPU",
-#             "workers": ["workerX"],
-#         },
-#         {
-#             "id": 99,
-#             "model": "cool-model/turbo",
-#             "prompt": "Classify sentiment in these tweets.",
-#             "user": "charlie",
-#             "reward": "25.5000 GPU",
-#             "workers": ["workerOne", "workerTwo", "workerThree"],
-#         },
-#     ]
-#
-#     ui = WorkerMonitor()
-#
-#     async def progress_task():
-#         # Fill from 0% to 100%
-#         for pct in range(101):
-#             ui.set_progress(
-#                 current=pct,
-#                 status_str=f"Request #1234 ({pct}%)"
-#             )
-#             await trio.sleep(0.05)
-#         # Reset to 0
-#         ui.set_progress(
-#             current=0,
-#             status_str="Starting again..."
-#         )
-#
-#     async def update_data_task():
-#         await trio.sleep(3)  # Wait a bit, then update requests
-#         new_data = [{
-#             "id": 101,
-#             "model": "new-model/v1.0",
-#             "prompt": "Say hi to the world.",
-#             "user": "eve",
-#             "reward": "50.0000 GPU",
-#             "workers": ["workerFresh", "workerPower"],
-#         }]
-#         ui.update_requests(new_data)
-#         ui.set_header_text(new_worker_name="NewNodeName",
-#                            new_balance="balance: 12345.6789 GPU")
-#
-#     try:
-#         async with trio.open_nursery() as nursery:
-#             # Run the TUI
-#             nursery.start_soon(ui.run_teadown_on_exit, nursery)
-#
-#             ui.update_requests(example_requests)
-#             ui.set_header_text(
-#                 new_worker_name="worker1.scd",
-#                 new_balance="balance: 12345.6789 GPU"
-#             )
-#             # Start background tasks
-#             nursery.start_soon(progress_task)
-#             nursery.start_soon(update_data_task)
-#
-#     except *KeyboardInterrupt as ex_group:
-#         ...
-#
-#
-# if __name__ == "__main__":
-#     trio.run(main)
+def setup_logging_for_tui(level):
+    warnings.filterwarnings("ignore")
+
+    logger = logging.getLogger()
+    logger.setLevel(level)
+
+    fh = logging.FileHandler('dgpu.log')
+    fh.setLevel(level)
+
+    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+    fh.setFormatter(formatter)
+
+    logger.addHandler(fh)
+
+    for handler in logger.handlers:
+        if isinstance(handler, logging.StreamHandler):
+            logger.removeHandler(handler)
+
+
+_tui = None
+def init_tui():
+    global _tui
+    assert not _tui
+    setup_logging_for_tui(logging.INFO)
+    _tui = WorkerMonitor()
+    return _tui
+
+
+def maybe_update_tui(fn):
+    global _tui
+    if _tui:
+        fn(_tui)
+
+
+async def maybe_update_tui_async(fn):
+    global _tui
+    if _tui:
+        await fn(_tui)
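Taken together, a minimal end-to-end sketch of the new functional-style API (only `init_tui`, `maybe_update_tui` and `maybe_update_tui_async` come from this diff; the `trio.run` wrapper and the header strings are illustrative):

    import trio
    from skynet.dgpu.tui import init_tui, maybe_update_tui, maybe_update_tui_async

    async def main():
        init_tui()  # builds the singleton and reroutes logging to dgpu.log

        maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name='worker1.scd'))

        async def _fn(tui):
            tui.set_header_text(new_balance='balance: 0.0000 GPU')

        await maybe_update_tui_async(_fn)

    trio.run(main)

Note that `init_tui()` asserts the singleton is unset, so it can only be called once per process.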