mirror of https://github.com/skygpu/skynet.git
Refactoring tui to be functional style
parent
cd028d15e7
commit
d8f243df9b
|
@ -1,36 +1,17 @@
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
|
import urwid
|
||||||
|
|
||||||
from hypercorn.config import Config
|
from hypercorn.config import Config
|
||||||
from hypercorn.trio import serve
|
from hypercorn.trio import serve
|
||||||
from quart_trio import QuartTrio as Quart
|
from quart_trio import QuartTrio as Quart
|
||||||
|
|
||||||
from skynet.dgpu.tui import WorkerMonitor
|
from skynet.dgpu.tui import init_tui
|
||||||
from skynet.dgpu.daemon import WorkerDaemon
|
from skynet.dgpu.daemon import WorkerDaemon
|
||||||
from skynet.dgpu.network import NetConnector
|
from skynet.dgpu.network import NetConnector
|
||||||
|
|
||||||
|
|
||||||
def setup_logging_for_tui(level):
|
|
||||||
warnings.filterwarnings("ignore")
|
|
||||||
|
|
||||||
logger = logging.getLogger()
|
|
||||||
logger.setLevel(level)
|
|
||||||
|
|
||||||
fh = logging.FileHandler('dgpu.log')
|
|
||||||
fh.setLevel(level)
|
|
||||||
|
|
||||||
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
|
||||||
fh.setFormatter(formatter)
|
|
||||||
|
|
||||||
logger.addHandler(fh)
|
|
||||||
|
|
||||||
for handler in logger.handlers:
|
|
||||||
if isinstance(handler, logging.StreamHandler):
|
|
||||||
logger.removeHandler(handler)
|
|
||||||
|
|
||||||
|
|
||||||
async def open_dgpu_node(config: dict) -> None:
|
async def open_dgpu_node(config: dict) -> None:
|
||||||
'''
|
'''
|
||||||
Open a top level "GPU mgmt daemon", keep the
|
Open a top level "GPU mgmt daemon", keep the
|
||||||
|
@ -43,11 +24,10 @@ async def open_dgpu_node(config: dict) -> None:
|
||||||
|
|
||||||
tui = None
|
tui = None
|
||||||
if config['tui']:
|
if config['tui']:
|
||||||
setup_logging_for_tui(logging.INFO)
|
tui = init_tui()
|
||||||
tui = WorkerMonitor()
|
|
||||||
|
|
||||||
conn = NetConnector(config, tui=tui)
|
conn = NetConnector(config)
|
||||||
daemon = WorkerDaemon(conn, config, tui=tui)
|
daemon = WorkerDaemon(conn, config)
|
||||||
|
|
||||||
api: Quart|None = None
|
api: Quart|None = None
|
||||||
if 'api_bind' in config:
|
if 'api_bind' in config:
|
||||||
|
@ -71,5 +51,5 @@ async def open_dgpu_node(config: dict) -> None:
|
||||||
# block until cancelled
|
# block until cancelled
|
||||||
await daemon.serve_forever()
|
await daemon.serve_forever()
|
||||||
|
|
||||||
except *urwid.ExitMainLoop in ex_group:
|
except *urwid.ExitMainLoop:
|
||||||
...
|
...
|
||||||
|
|
|
@ -12,7 +12,7 @@ from contextlib import contextmanager as cm
|
||||||
import trio
|
import trio
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from skynet.dgpu.tui import WorkerMonitor
|
from skynet.dgpu.tui import maybe_update_tui
|
||||||
from skynet.dgpu.errors import (
|
from skynet.dgpu.errors import (
|
||||||
DGPUComputeError,
|
DGPUComputeError,
|
||||||
DGPUInferenceCancelled,
|
DGPUInferenceCancelled,
|
||||||
|
@ -108,8 +108,7 @@ def compute_one(
|
||||||
method: str,
|
method: str,
|
||||||
params: dict,
|
params: dict,
|
||||||
inputs: list[bytes] = [],
|
inputs: list[bytes] = [],
|
||||||
should_cancel = None,
|
should_cancel = None
|
||||||
tui: WorkerMonitor | None = None
|
|
||||||
):
|
):
|
||||||
if method == 'diffuse':
|
if method == 'diffuse':
|
||||||
method = 'txt2img'
|
method = 'txt2img'
|
||||||
|
@ -130,8 +129,7 @@ def compute_one(
|
||||||
if not isinstance(step, int):
|
if not isinstance(step, int):
|
||||||
step = args[1]
|
step = args[1]
|
||||||
|
|
||||||
if tui:
|
maybe_update_tui(lambda tui: tui.set_progress(step, done=total_steps))
|
||||||
tui.set_progress(step, done=total_steps)
|
|
||||||
|
|
||||||
if should_cancel:
|
if should_cancel:
|
||||||
should_raise = trio.from_thread.run(should_cancel, request_id)
|
should_raise = trio.from_thread.run(should_cancel, request_id)
|
||||||
|
@ -142,8 +140,7 @@ def compute_one(
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
if tui:
|
maybe_update_tui(lambda tui: tui.set_status(f'Request #{request_id}'))
|
||||||
tui.set_status(f'Request #{request_id}')
|
|
||||||
|
|
||||||
inference_step_wakeup(0)
|
inference_step_wakeup(0)
|
||||||
|
|
||||||
|
@ -210,7 +207,6 @@ def compute_one(
|
||||||
except BaseException as err:
|
except BaseException as err:
|
||||||
raise DGPUComputeError(str(err)) from err
|
raise DGPUComputeError(str(err)) from err
|
||||||
|
|
||||||
if tui:
|
maybe_update_tui(lambda tui: tui.set_status(''))
|
||||||
tui.set_status('')
|
|
||||||
|
|
||||||
return output_hash, output
|
return output_hash, output
|
||||||
|
|
|
@ -17,7 +17,7 @@ from skynet.constants import (
|
||||||
from skynet.dgpu.errors import (
|
from skynet.dgpu.errors import (
|
||||||
DGPUComputeError,
|
DGPUComputeError,
|
||||||
)
|
)
|
||||||
from skynet.dgpu.tui import WorkerMonitor
|
from skynet.dgpu.tui import maybe_update_tui, maybe_update_tui_async
|
||||||
from skynet.dgpu.compute import maybe_load_model, compute_one
|
from skynet.dgpu.compute import maybe_load_model, compute_one
|
||||||
from skynet.dgpu.network import NetConnector
|
from skynet.dgpu.network import NetConnector
|
||||||
|
|
||||||
|
@ -41,11 +41,9 @@ class WorkerDaemon:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
conn: NetConnector,
|
conn: NetConnector,
|
||||||
config: dict,
|
config: dict
|
||||||
tui: WorkerMonitor | None = None
|
|
||||||
):
|
):
|
||||||
self.conn: NetConnector = conn
|
self.conn: NetConnector = conn
|
||||||
self._tui = tui
|
|
||||||
self.auto_withdraw = (
|
self.auto_withdraw = (
|
||||||
config['auto_withdraw']
|
config['auto_withdraw']
|
||||||
if 'auto_withdraw' in config else False
|
if 'auto_withdraw' in config else False
|
||||||
|
@ -152,10 +150,12 @@ class WorkerDaemon:
|
||||||
return app
|
return app
|
||||||
|
|
||||||
async def _update_balance(self):
|
async def _update_balance(self):
|
||||||
if self._tui:
|
async def _fn(tui):
|
||||||
# update balance
|
# update balance
|
||||||
balance = await self.conn.get_worker_balance()
|
balance = await self.conn.get_worker_balance()
|
||||||
self._tui.set_header_text(new_balance=f'balance: {balance}')
|
tui.set_header_text(new_balance=f'balance: {balance}')
|
||||||
|
|
||||||
|
await maybe_update_tui_async(_fn)
|
||||||
|
|
||||||
# TODO? this func is kinda big and maybe is better at module
|
# TODO? this func is kinda big and maybe is better at module
|
||||||
# level to reduce indentation?
|
# level to reduce indentation?
|
||||||
|
@ -258,8 +258,7 @@ class WorkerDaemon:
|
||||||
|
|
||||||
with maybe_load_model(model, mode):
|
with maybe_load_model(model, mode):
|
||||||
try:
|
try:
|
||||||
if self._tui:
|
maybe_update_tui(lambda tui: tui.set_progress(0, done=total_step))
|
||||||
self._tui.set_progress(0, done=total_step)
|
|
||||||
|
|
||||||
output_type = 'png'
|
output_type = 'png'
|
||||||
if 'output_type' in body['params']:
|
if 'output_type' in body['params']:
|
||||||
|
@ -276,7 +275,6 @@ class WorkerDaemon:
|
||||||
mode, body['params'],
|
mode, body['params'],
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
should_cancel=self.should_cancel_work,
|
should_cancel=self.should_cancel_work,
|
||||||
tui=self._tui
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -285,8 +283,7 @@ class WorkerDaemon:
|
||||||
f'Unsupported backend {self.backend}'
|
f'Unsupported backend {self.backend}'
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._tui:
|
maybe_update_tui(lambda tui: tui.set_progress(total_step))
|
||||||
self._tui.set_progress(total_step)
|
|
||||||
|
|
||||||
self._last_generation_ts: str = datetime.now().isoformat()
|
self._last_generation_ts: str = datetime.now().isoformat()
|
||||||
self._last_benchmark: list[float] = self._benchmark
|
self._last_benchmark: list[float] = self._benchmark
|
||||||
|
|
|
@ -13,7 +13,7 @@ import outcome
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from leap.cleos import CLEOS
|
from leap.cleos import CLEOS
|
||||||
from leap.protocol import Asset
|
from leap.protocol import Asset
|
||||||
from skynet.dgpu.tui import WorkerMonitor
|
from skynet.dgpu.tui import maybe_update_tui
|
||||||
from skynet.constants import (
|
from skynet.constants import (
|
||||||
DEFAULT_IPFS_DOMAIN,
|
DEFAULT_IPFS_DOMAIN,
|
||||||
GPU_CONTRACT_ABI,
|
GPU_CONTRACT_ABI,
|
||||||
|
@ -58,7 +58,7 @@ class NetConnector:
|
||||||
- CLEOS client
|
- CLEOS client
|
||||||
|
|
||||||
'''
|
'''
|
||||||
def __init__(self, config: dict, tui: WorkerMonitor | None = None):
|
def __init__(self, config: dict):
|
||||||
# TODO, why these extra instance vars for an (unsynced)
|
# TODO, why these extra instance vars for an (unsynced)
|
||||||
# copy of the `config` state?
|
# copy of the `config` state?
|
||||||
self.account = config['account']
|
self.account = config['account']
|
||||||
|
@ -82,9 +82,8 @@ class NetConnector:
|
||||||
self.ipfs_domain = config['ipfs_domain']
|
self.ipfs_domain = config['ipfs_domain']
|
||||||
|
|
||||||
self._wip_requests = {}
|
self._wip_requests = {}
|
||||||
self._tui = tui
|
|
||||||
if self._tui:
|
maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.account))
|
||||||
self._tui.set_header_text(new_worker_name=self.account)
|
|
||||||
|
|
||||||
|
|
||||||
# blockchain helpers
|
# blockchain helpers
|
||||||
|
@ -168,8 +167,8 @@ class NetConnector:
|
||||||
n.start_soon(
|
n.start_soon(
|
||||||
_run_and_save, snap['requests'], req['id'], self.get_status_by_request_id, req['id'])
|
_run_and_save, snap['requests'], req['id'], self.get_status_by_request_id, req['id'])
|
||||||
|
|
||||||
if self._tui:
|
|
||||||
self._tui.network_update(snap)
|
maybe_update_tui(lambda tui: tui.network_update(snap))
|
||||||
|
|
||||||
return snap
|
return snap
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
import urwid
|
|
||||||
import trio
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import trio
|
||||||
|
import urwid
|
||||||
|
|
||||||
|
|
||||||
class WorkerMonitor:
|
class WorkerMonitor:
|
||||||
|
@ -163,86 +166,41 @@ class WorkerMonitor:
|
||||||
self.update_requests(queue)
|
self.update_requests(queue)
|
||||||
|
|
||||||
|
|
||||||
# # -----------------------------------------------------------------------------
|
def setup_logging_for_tui(level):
|
||||||
# # Example usage
|
warnings.filterwarnings("ignore")
|
||||||
# # -----------------------------------------------------------------------------
|
|
||||||
#
|
logger = logging.getLogger()
|
||||||
# async def main():
|
logger.setLevel(level)
|
||||||
# # Example data
|
|
||||||
# example_requests = [
|
fh = logging.FileHandler('dgpu.log')
|
||||||
# {
|
fh.setLevel(level)
|
||||||
# "id": 12,
|
|
||||||
# "model": "black-forest-labs/FLUX.1-schnell",
|
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||||
# "prompt": "Generate an answer about quantum entanglement.",
|
fh.setFormatter(formatter)
|
||||||
# "user": "alice123",
|
|
||||||
# "reward": "20.0000 GPU",
|
logger.addHandler(fh)
|
||||||
# "workers": ["workerA", "workerB"],
|
|
||||||
# },
|
for handler in logger.handlers:
|
||||||
# {
|
if isinstance(handler, logging.StreamHandler):
|
||||||
# "id": 5,
|
logger.removeHandler(handler)
|
||||||
# "model": "some-other-model/v2.0",
|
|
||||||
# "prompt": "A story about dragons.",
|
|
||||||
# "user": "bobthebuilder",
|
_tui = None
|
||||||
# "reward": "15.0000 GPU",
|
def init_tui():
|
||||||
# "workers": ["workerX"],
|
global _tui
|
||||||
# },
|
assert not _tui
|
||||||
# {
|
setup_logging_for_tui(logging.INFO)
|
||||||
# "id": 99,
|
_tui = WorkerMonitor()
|
||||||
# "model": "cool-model/turbo",
|
return _tui
|
||||||
# "prompt": "Classify sentiment in these tweets.",
|
|
||||||
# "user": "charlie",
|
|
||||||
# "reward": "25.5000 GPU",
|
def maybe_update_tui(fn):
|
||||||
# "workers": ["workerOne", "workerTwo", "workerThree"],
|
global _tui
|
||||||
# },
|
if _tui:
|
||||||
# ]
|
fn(_tui)
|
||||||
#
|
|
||||||
# ui = WorkerMonitor()
|
|
||||||
#
|
async def maybe_update_tui_async(fn):
|
||||||
# async def progress_task():
|
global _tui
|
||||||
# # Fill from 0% to 100%
|
if _tui:
|
||||||
# for pct in range(101):
|
await fn(_tui)
|
||||||
# ui.set_progress(
|
|
||||||
# current=pct,
|
|
||||||
# status_str=f"Request #1234 ({pct}%)"
|
|
||||||
# )
|
|
||||||
# await trio.sleep(0.05)
|
|
||||||
# # Reset to 0
|
|
||||||
# ui.set_progress(
|
|
||||||
# current=0,
|
|
||||||
# status_str="Starting again..."
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# async def update_data_task():
|
|
||||||
# await trio.sleep(3) # Wait a bit, then update requests
|
|
||||||
# new_data = [{
|
|
||||||
# "id": 101,
|
|
||||||
# "model": "new-model/v1.0",
|
|
||||||
# "prompt": "Say hi to the world.",
|
|
||||||
# "user": "eve",
|
|
||||||
# "reward": "50.0000 GPU",
|
|
||||||
# "workers": ["workerFresh", "workerPower"],
|
|
||||||
# }]
|
|
||||||
# ui.update_requests(new_data)
|
|
||||||
# ui.set_header_text(new_worker_name="NewNodeName",
|
|
||||||
# new_balance="balance: 12345.6789 GPU")
|
|
||||||
#
|
|
||||||
# try:
|
|
||||||
# async with trio.open_nursery() as nursery:
|
|
||||||
# # Run the TUI
|
|
||||||
# nursery.start_soon(ui.run_teadown_on_exit, nursery)
|
|
||||||
#
|
|
||||||
# ui.update_requests(example_requests)
|
|
||||||
# ui.set_header_text(
|
|
||||||
# new_worker_name="worker1.scd",
|
|
||||||
# new_balance="balance: 12345.6789 GPU"
|
|
||||||
# )
|
|
||||||
# # Start background tasks
|
|
||||||
# nursery.start_soon(progress_task)
|
|
||||||
# nursery.start_soon(update_data_task)
|
|
||||||
#
|
|
||||||
# except *KeyboardInterrupt as ex_group:
|
|
||||||
# ...
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# trio.run(main)
|
|
||||||
|
|
Loading…
Reference in New Issue