Factor out WorkerDaemon: split it into module-level functions, make the poller an async generator, and move it, along with should_cancel, into NetConnector

pull/47/head
Guillermo Rodriguez 2025-02-07 16:41:50 -03:00
parent ea3b35904c
commit 149d9f9f33
6 changed files with 199 additions and 310 deletions

View File

@@ -193,14 +193,14 @@ def dgpu(
     config_path: str
 ):
     import trio
-    from .dgpu import open_dgpu_node
+    from .dgpu import _dgpu_main

     logging.basicConfig(level=loglevel)
     config = load_skynet_toml(file_path=config_path)
     set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
-    trio.run(open_dgpu_node, config.dgpu)
+    trio.run(_dgpu_main, config.dgpu)


 @run.command()

View File

@@ -26,6 +26,7 @@ class DgpuConfig(msgspec.Struct):
     backend: str = 'sync-on-thread'
     api_bind: str = False
     tui: bool = False
+    poll_time: float = 0.5


 class TelegramConfig(msgspec.Struct):
     account: str

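Note: `poll_time` is the knob consumed by the new polling loop added below (seconds between chain-table snapshots, default 0.5). A sketch of setting it from the worker's TOML config; the `[dgpu]` section name is an assumption inferred from the `config.dgpu` accessor used in the CLI:

    # skynet.toml (section name assumed)
    [dgpu]
    poll_time = 0.25  # snapshot the queue four times per second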
View File

@@ -3,23 +3,13 @@ import logging
 import trio
 import urwid

-from hypercorn.config import Config as HCConfig
-from hypercorn.trio import serve
-from quart_trio import QuartTrio as Quart
-
 from skynet.config import Config
 from skynet.dgpu.tui import init_tui
-from skynet.dgpu.daemon import WorkerDaemon
+from skynet.dgpu.daemon import serve_forever
 from skynet.dgpu.network import NetConnector


-async def open_dgpu_node(config: Config) -> None:
-    '''
-    Open a top level "GPU mgmt daemon", keep the
-    `WorkerDaemon._snap: dict[str, list|dict]` table
-    and *maybe* serve a `hypercorn` web API.
-    '''
+async def _dgpu_main(config: Config) -> None:
     # suppress logs from httpx (logs url + status after every query)
     logging.getLogger("httpx").setLevel(logging.WARNING)

@@ -28,29 +18,14 @@ async def open_dgpu_node(config: Config) -> None:
     tui = init_tui()
     conn = NetConnector(config)
-    daemon = WorkerDaemon(conn, config)
-
-    api: Quart|None = None
-    if config.api_bind:
-        api_conf = HCConfig()
-        api_conf.bind = [config.api_bind]
-        api: Quart = await daemon.generate_api()
-
-    tn: trio.Nursery
-    async with trio.open_nursery() as tn:
-        tn.start_soon(daemon.snap_updater_task)
-        if tui:
-            tn.start_soon(tui.run)
-
-        # TODO, consider a more explicit `as hypercorn_serve`
-        # to clarify?
-        if api:
-            logging.info(f'serving api @ {config["api_bind"]}')
-            tn.start_soon(serve, api, api_conf)

     try:
-        # block until cancelled
-        await daemon.serve_forever()
+        n: trio.Nursery
+        async with trio.open_nursery() as n:
+            if tui:
+                n.start_soon(tui.run)
+            await serve_forever(config, conn)

     except *urwid.ExitMainLoop:
         ...

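Note: the `except *urwid.ExitMainLoop` clause needs Python 3.11+. A trio nursery (under strict exception-group semantics) re-raises child-task failures wrapped in an `ExceptionGroup`, so the `ExitMainLoop` raised when the TUI task quits can only be caught with `except*`. A minimal, repo-independent illustration of the form:

    # `except*` matches members *inside* an exception group (Python 3.11+);
    # a bare `except ValueError` would not catch the grouped exception.
    try:
        raise ExceptionGroup('demo', [ValueError('boom')])
    except* ValueError:
        ...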
View File

@@ -20,6 +20,7 @@ from skynet.dgpu.errors import (
 from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for


 def prepare_params_for_diffuse(
     params: dict,
     mode: str,

View File

@@ -7,8 +7,6 @@ from functools import partial
 from hashlib import sha256

 import trio
-from quart import jsonify
-from quart_trio import QuartTrio as Quart

 from skynet.config import DgpuConfig as Config
 from skynet.constants import (
@@ -31,117 +29,18 @@ def convert_reward_to_int(reward_str):
     return int(int_part + decimal_part)


-class WorkerDaemon:
-    '''
-    The root "GPU daemon".
-
-    Contains/manages underlying susystems:
-    - a GPU connecto
-    '''
-    def __init__(
-        self,
-        conn: NetConnector,
-        config: Config
-    ):
-        self.config = config
-        self.conn: NetConnector = conn
-
-        self._snap = {
-            'queue': [],
-            'requests': {},
-            'results': []
-        }
-
-        self._benchmark: list[float] = []
-        self._last_benchmark: list[float]|None = None
-        self._last_generation_ts: str|None = None
-
-    def _get_benchmark_speed(self) -> float:
-        '''
-        Return the (arithmetic) average work-iterations-per-second
-        fconducted by this compute worker.
-        '''
-        if not self._last_benchmark:
-            return 0
-
-        start = self._last_benchmark[0]
-        end = self._last_benchmark[-1]
-
-        elapsed = end - start
-        its = len(self._last_benchmark)
-        speed = its / elapsed
-
-        logging.info(f'{elapsed} s total its: {its}, at {speed} it/s ')
-
-        return speed
-
-    async def should_cancel_work(self, request_id: int):
-        self._benchmark.append(time.time())
-        logging.info('should cancel work?')
-        if request_id not in self._snap['requests']:
-            logging.info(f'request #{request_id} no longer in queue, likely its been filled by another worker, cancelling work...')
-            return True
-
-        competitors = set([
-            status['worker']
-            for status in self._snap['requests'][request_id]
-            if status['worker'] != self.config.account
-        ])
-        logging.info(f'competitors: {competitors}')
-        should_cancel = bool(self.config.non_compete & competitors)
-        logging.info(f'cancel: {should_cancel}')
-        return should_cancel
-
-    async def snap_updater_task(self):
-        '''
-        Busy loop update the local `._snap: dict` table from
-        '''
-        while True:
-            self._snap = await self.conn.get_full_queue_snapshot()
-            await trio.sleep(1)
-
-    # TODO, design suggestion, just make this a lazily accessed
-    # `@class_property` if we're 3.12+
-    # |_ https://docs.python.org/3/library/functools.html#functools.cached_property
-    async def generate_api(self) -> Quart:
-        '''
-        Gen a `Quart`-compat web API spec which (for now) simply
-        serves a small monitoring ep that reports,
-        - iso-time-stamp of the last served model-output
-        - the worker's average "compute-iterations-per-second"
-        '''
-        app = Quart(__name__)
-
-        @app.route('/')
-        async def health():
-            return jsonify(
-                account=self.config.account,
-                version=VERSION,
-                last_generation_ts=self._last_generation_ts,
-                last_generation_speed=self._get_benchmark_speed()
-            )
-
-        return app
-
-    async def _update_balance(self):
-        async def _fn(tui):
-            # update balance
-            balance = await self.conn.get_worker_balance()
-            tui.set_header_text(new_balance=f'balance: {balance}')
-
-        await maybe_update_tui_async(_fn)
+async def maybe_update_tui_balance(conn: NetConnector):
+    async def _fn(tui):
+        # update balance
+        balance = await conn.get_worker_balance()
+        tui.set_header_text(new_balance=f'balance: {balance}')
+
+    await maybe_update_tui_async(_fn)


-    # TODO? this func is kinda big and maybe is better at module
-    # level to reduce indentation?
-    # -[ ] just pass `daemon: WorkerDaemon` vs. `self`
-    async def maybe_serve_one(
-        self,
+async def maybe_serve_one(
+    config: Config,
+    conn: NetConnector,
     req: dict,
 ):
     rid = req['id']
@@ -158,40 +57,40 @@ class WorkerDaemon:
         model not in MODELS
     ):
         logging.warning(f'unknown model {model}!, skip...')
-        return False
+        return

     # only handle whitelisted models
     if (
-        len(self.config.model_whitelist) > 0
+        len(config.model_whitelist) > 0
         and
-        model not in self.config.model_whitelist
+        model not in config.model_whitelist
     ):
         logging.warning('model not whitelisted!, skip...')
-        return False
+        return

     # if blacklist contains model skip
     if (
-        len(self.config.model_blacklist) > 0
+        len(config.model_blacklist) > 0
         and
-        model in self.config.model_blacklist
+        model in config.model_blacklist
     ):
         logging.warning('model not blacklisted!, skip...')
-        return False
+        return

-    results = [res['request_id'] for res in self._snap['results']]
+    results = [res['request_id'] for res in conn._tables['results']]

     # if worker already produced a result for this request
     if rid in results:
         logging.info(f'worker already submitted a result for request #{rid}, skip...')
-        return False
+        return

-    statuses = self._snap['requests'][rid]
+    statuses = conn._tables['requests'][rid]

     # skip if workers in non_compete already on it
     competitors = set((status['worker'] for status in statuses))
-    if bool(self.config.non_compete & competitors):
+    if bool(config.non_compete & competitors):
         logging.info('worker in configured non_compete list already working on request, skip...')
-        return False
+        return

     # resolve the ipfs hashes into the actual data behind them
     inputs = []
@@ -207,7 +106,7 @@ class WorkerDaemon:
             # user `GPUConnector` to IO with
             # storage layer to seed the compute
            # task.
-            img = await self.conn.get_input_data(_input)
+            img = await conn.get_input_data(_input)
             inputs.append(img)
             logging.info(f'retrieved {_input}!')
             break
@@ -235,10 +134,10 @@ class WorkerDaemon:
     # TODO: validate request

-    resp = await self.conn.begin_work(rid)
+    resp = await conn.begin_work(rid)
     if not resp or 'code' in resp:
         logging.info('begin_work error, probably being worked on already... skip.')
-        return False
+        return

     with maybe_load_model(model, mode):
         try:
@@ -250,7 +149,7 @@ class WorkerDaemon:
             output = None
             output_hash = None
-            match self.config.backend:
+            match config.backend:
                 case 'sync-on-thread':
                     output_hash, output = await trio.to_thread.run_sync(
                         partial(
@@ -258,49 +157,37 @@ class WorkerDaemon:
                             rid,
                             mode, body['params'],
                             inputs=inputs,
-                            should_cancel=self.should_cancel_work,
+                            should_cancel=conn.should_cancel_work,
                         )
                     )

                 case _:
                     raise DGPUComputeError(
-                        f'Unsupported backend {self.config.backend}'
+                        f'Unsupported backend {config.backend}'
                     )

             maybe_update_tui(lambda tui: tui.set_progress(total_step))

-            self._last_generation_ts: str = datetime.now().isoformat()
-            self._last_benchmark: list[float] = self._benchmark
-            self._benchmark: list[float] = []
-
-            ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
+            ipfs_hash = await conn.publish_on_ipfs(output, typ=output_type)

-            await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
+            await conn.submit_work(rid, request_hash, output_hash, ipfs_hash)

-            await self._update_balance()
+            await maybe_update_tui_balance(conn)

         except BaseException as err:
             if 'network cancel' not in str(err):
                 logging.exception('Failed to serve model request !?\n')

-            if rid in self._snap['requests']:
-                await self.conn.cancel_work(rid, 'reason not provided')
-
-        finally:
-            return True
+            if rid in conn._tables['requests']:
+                await conn.cancel_work(rid, 'reason not provided')


-    # TODO, as per above on `.maybe_serve_one()`, it's likely a bit
-    # more *trionic* to define this all as a module level task-func
-    # which operates on a `daemon: WorkerDaemon`?
-    #
-    # -[ ] keeps tasks-as-funcs style prominent
-    # -[ ] avoids so much indentation due to methods
-    async def serve_forever(self):
-        await self._update_balance()
+async def serve_forever(config: Config, conn: NetConnector):
+    await maybe_update_tui_balance(conn)

     try:
-        while True:
-            queue = self._snap['queue']
+        async for tables in conn.iter_poll_update(config.poll_time):
+            queue = tables['queue']

             random.shuffle(queue)
             queue = sorted(
@@ -309,13 +196,8 @@ class WorkerDaemon:
                 reverse=True
             )

-            for req in queue:
-                # TODO, as mentioned above just inline this once
-                # converted to a mod level func.
-                if (await self.maybe_serve_one(req)):
-                    break
-
-            await trio.sleep(1)
+            if len(queue) > 0:
+                await maybe_serve_one(config, conn, queue[0])

     except KeyboardInterrupt:
         ...

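Note: `should_cancel=conn.should_cancel_work` hands an async method to code that runs on a worker thread via `trio.to_thread.run_sync`; calling back into trio from such a thread is typically done with `trio.from_thread.run`. A minimal sketch of that pattern only, with hypothetical names (`compute_one`, `steps`), not the repo's actual compute code:

    import trio

    def compute_one(rid: int, should_cancel, steps: int = 30):
        # runs inside trio.to_thread.run_sync(); re-enter the trio loop
        # from this worker thread to await the async cancel check
        for step in range(steps):
            if trio.from_thread.run(should_cancel, rid):
                raise RuntimeError('network cancel')
            ...  # one model iteration per step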
View File

@@ -3,6 +3,7 @@ import json
 import time
 import logging
 from pathlib import Path
+from typing import AsyncGenerator
 from functools import partial

 import trio
@@ -66,7 +67,11 @@ class NetConnector:
         self.ipfs_client = AsyncIPFSHTTP(config.ipfs_url)

-        self._wip_requests = {}
+        self._tables = {
+            'queue': [],
+            'requests': {},
+            'results': []
+        }

         maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.config.account))
@@ -132,9 +137,6 @@ class NetConnector:
             logging.info('no balance info found')
             return None

-    # TODO, considery making this a NON-method and instead
-    # handing in the `snap['queue']` output beforehand?
-    # -> since that call is the only usage of `self`?
     async def get_full_queue_snapshot(self):
         '''
         Keep in-sync with latest (telos chain's smart-contract) table
@@ -162,6 +164,34 @@ class NetConnector:
         return snap

+    async def iter_poll_update(self, poll_time: float) -> AsyncGenerator[dict, None]:
+        '''
+        Long running task, polls gpu contract tables, yields latest table rows
+        '''
+        while True:
+            start_time = time.time()
+            self._tables = await self.get_full_queue_snapshot()
+            elapsed = time.time() - start_time
+            yield self._tables
+            await trio.sleep(max(poll_time - elapsed, 0.1))
+
+    async def should_cancel_work(self, request_id: int) -> bool:
+        logging.info('should cancel work?')
+        if request_id not in self._tables['requests']:
+            logging.info(f'request #{request_id} no longer in queue, likely it\'s been filled by another worker, cancelling work...')
+            return True
+
+        competitors = set([
+            status['worker']
+            for status in self._tables['requests'][request_id]
+            if status['worker'] != self.config.account
+        ])
+        logging.info(f'competitors: {competitors}')
+        should_cancel = bool(self.config.non_compete & competitors)
+        logging.info(f'cancel: {should_cancel}')
+        return should_cancel
+
     async def begin_work(self, request_id: int):
         '''
         Publish to the bc that the worker is beginning a model-computation
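Note: the cadence logic in `iter_poll_update` above subtracts the snapshot-fetch time from `poll_time`, so iterations start roughly every `poll_time` seconds, with a 0.1 s floor that keeps a slow fetch from starving other tasks. The same pattern isolated as a standalone sketch (`fetch` is a hypothetical async callable):

    import time
    import trio

    async def poll(fetch, poll_time: float):
        while True:
            start = time.time()
            yield await fetch()  # one table snapshot per iteration
            # sleep only the remainder of the period, never less than 0.1s
            await trio.sleep(max(poll_time - (time.time() - start), 0.1))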
@@ -244,7 +274,7 @@ class NetConnector:
             result_hash: str,
             ipfs_hash: str
     ):
-        logging.info('submit_work #{request_id}')
+        logging.info(f'submit_work #{request_id}')

         return await failable(
             partial(
                 self.cleos.a_push_action,