Factor out WorkerDaemon: split it into module-level functions, turn the poller into an async generator, and move it, along with should_cancel, into NetConnector

pull/47/head
Guillermo Rodriguez 2025-02-07 16:41:50 -03:00
parent ea3b35904c
commit 149d9f9f33
6 changed files with 199 additions and 310 deletions
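
Taken together: the `WorkerDaemon` class is dissolved into module-level task functions, and the table poller plus `should_cancel_work` move into `NetConnector`. A minimal sketch of the post-refactor call chain, assembled from the hunks below (the `load_skynet_toml` import path and the config file name are assumptions; the other names are as they appear in the diff):

    import trio
    from skynet.config import load_skynet_toml   # import path assumed
    from skynet.dgpu import _dgpu_main            # replaces open_dgpu_node

    # _dgpu_main builds a NetConnector and hands it to the module-level
    # serve_forever(config, conn), which drives:
    #   async for tables in conn.iter_poll_update(config.poll_time):
    #       await maybe_serve_one(config, conn, queue[0])

    config = load_skynet_toml(file_path='skynet.toml')   # file name is illustrative
    trio.run(_dgpu_main, config.dgpu)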

View File

@ -193,14 +193,14 @@ def dgpu(
config_path: str
):
import trio
from .dgpu import open_dgpu_node
from .dgpu import _dgpu_main
logging.basicConfig(level=loglevel)
config = load_skynet_toml(file_path=config_path)
set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
trio.run(open_dgpu_node, config.dgpu)
trio.run(_dgpu_main, config.dgpu)
@run.command()

View File

@ -26,6 +26,7 @@ class DgpuConfig(msgspec.Struct):
backend: str = 'sync-on-thread'
api_bind: str = False
tui: bool = False
poll_time: float = 0.5
class TelegramConfig(msgspec.Struct):
account: str
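
The new `poll_time` field sets the target period, in seconds, between contract-table polls; it is consumed by the `iter_poll_update` generator added to `NetConnector` later in this commit. A minimal standalone sketch of that cadence logic (the function name and `fetch` callback are illustrative; only the timing math mirrors the hunk):

    import time
    import trio

    async def poll_loop(fetch, poll_time: float = 0.5):
        # Yield a fresh table snapshot roughly every `poll_time` seconds,
        # sleeping whatever is left of the window after each fetch
        # (never less than 0.1s, so a slow fetch doesn't busy-loop).
        while True:
            start = time.time()
            yield await fetch()
            elapsed = time.time() - start
            await trio.sleep(max(poll_time - elapsed, 0.1))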

View File

@ -3,23 +3,13 @@ import logging
import trio
import urwid
from hypercorn.config import Config as HCConfig
from hypercorn.trio import serve
from quart_trio import QuartTrio as Quart
from skynet.config import Config
from skynet.dgpu.tui import init_tui
from skynet.dgpu.daemon import WorkerDaemon
from skynet.dgpu.daemon import serve_forever
from skynet.dgpu.network import NetConnector
async def open_dgpu_node(config: Config) -> None:
'''
Open a top level "GPU mgmt daemon", keep the
`WorkerDaemon._snap: dict[str, list|dict]` table
and *maybe* serve a `hypercorn` web API.
'''
async def _dgpu_main(config: Config) -> None:
# suppress logs from httpx (logs url + status after every query)
logging.getLogger("httpx").setLevel(logging.WARNING)
@ -28,29 +18,14 @@ async def open_dgpu_node(config: Config) -> None:
tui = init_tui()
conn = NetConnector(config)
daemon = WorkerDaemon(conn, config)
api: Quart|None = None
if config.api_bind:
api_conf = HCConfig()
api_conf.bind = [config.api_bind]
api: Quart = await daemon.generate_api()
try:
n: trio.Nursery
async with trio.open_nursery() as n:
if tui:
n.start_soon(tui.run)
tn: trio.Nursery
async with trio.open_nursery() as tn:
tn.start_soon(daemon.snap_updater_task)
if tui:
tn.start_soon(tui.run)
await serve_forever(config, conn)
# TODO, consider a more explicit `as hypercorn_serve`
# to clarify?
if api:
logging.info(f'serving api @ {config["api_bind"]}')
tn.start_soon(serve, api, api_conf)
try:
# block until cancelled
await daemon.serve_forever()
except *urwid.ExitMainLoop:
...
except *urwid.ExitMainLoop:
...
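
Because the removed class-based wiring and the new wiring are interleaved above, here is a best-effort reading of `_dgpu_main` after this commit, reconstructed from the added lines of the two hunks (treat it as a sketch: imports and the lines between the hunks are assumed or omitted):

    async def _dgpu_main(config: Config) -> None:
        # suppress logs from httpx (logs url + status after every query)
        logging.getLogger("httpx").setLevel(logging.WARNING)

        tui = init_tui()   # possibly gated on config.tui in the real file
        conn = NetConnector(config)

        try:
            n: trio.Nursery
            async with trio.open_nursery() as n:
                if tui:
                    n.start_soon(tui.run)

                # block until cancelled
                await serve_forever(config, conn)

        except *urwid.ExitMainLoop:
            ...

The api/hypercorn branch and the `snap_updater_task` nursery task from the old `open_dgpu_node` are gone; polling now lives behind `NetConnector.iter_poll_update`.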

View File

@ -20,6 +20,7 @@ from skynet.dgpu.errors import (
from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
def prepare_params_for_diffuse(
params: dict,
mode: str,

View File

@ -7,8 +7,6 @@ from functools import partial
from hashlib import sha256
import trio
from quart import jsonify
from quart_trio import QuartTrio as Quart
from skynet.config import DgpuConfig as Config
from skynet.constants import (
@ -31,291 +29,175 @@ def convert_reward_to_int(reward_str):
return int(int_part + decimal_part)
class WorkerDaemon:
'''
The root "GPU daemon".
async def maybe_update_tui_balance(conn: NetConnector):
async def _fn(tui):
# update balance
balance = await conn.get_worker_balance()
tui.set_header_text(new_balance=f'balance: {balance}')
Contains/manages underlying susystems:
- a GPU connecto
await maybe_update_tui_async(_fn)
'''
def __init__(
self,
conn: NetConnector,
config: Config
async def maybe_serve_one(
config: Config,
conn: NetConnector,
req: dict,
):
rid = req['id']
logging.info(f'maybe serve request #{rid}')
# parse request
body = json.loads(req['body'])
model = body['params']['model']
# if model not known, ignore.
if (
model != 'RealESRGAN_x4plus'
and
model not in MODELS
):
self.config = config
self.conn: NetConnector = conn
logging.warning(f'unknown model {model}!, skip...')
return
self._snap = {
'queue': [],
'requests': {},
'results': []
}
# only handle whitelisted models
if (
len(config.model_whitelist) > 0
and
model not in config.model_whitelist
):
logging.warning('model not whitelisted!, skip...')
return
self._benchmark: list[float] = []
self._last_benchmark: list[float]|None = None
self._last_generation_ts: str|None = None
# if blacklist contains model skip
if (
len(config.model_blacklist) > 0
and
model in config.model_blacklist
):
logging.warning('model not blacklisted!, skip...')
return
def _get_benchmark_speed(self) -> float:
'''
Return the (arithmetic) average work-iterations-per-second
fconducted by this compute worker.
results = [res['request_id'] for res in conn._tables['results']]
'''
if not self._last_benchmark:
return 0
# if worker already produced a result for this request
if rid in results:
logging.info(f'worker already submitted a result for request #{rid}, skip...')
return
start = self._last_benchmark[0]
end = self._last_benchmark[-1]
statuses = conn._tables['requests'][rid]
elapsed = end - start
its = len(self._last_benchmark)
speed = its / elapsed
# skip if workers in non_compete already on it
competitors = set((status['worker'] for status in statuses))
if bool(config.non_compete & competitors):
logging.info('worker in configured non_compete list already working on request, skip...')
return
logging.info(f'{elapsed} s total its: {its}, at {speed} it/s ')
# resolve the ipfs hashes into the actual data behind them
inputs = []
raw_inputs = req['binary_data'].split(',')
if raw_inputs:
logging.info(f'fetching IPFS inputs: {raw_inputs}')
return speed
retry = 3
for _input in req['binary_data'].split(','):
if _input:
for r in range(retry):
try:
# user `GPUConnector` to IO with
# storage layer to seed the compute
# task.
img = await conn.get_input_data(_input)
inputs.append(img)
logging.info(f'retrieved {_input}!')
break
async def should_cancel_work(self, request_id: int):
self._benchmark.append(time.time())
logging.info('should cancel work?')
if request_id not in self._snap['requests']:
logging.info(f'request #{request_id} no longer in queue, likely its been filled by another worker, cancelling work...')
return True
except BaseException:
logging.exception(
f'IPFS fetch input error !?! retries left {retry - r - 1}\n'
)
competitors = set([
status['worker']
for status in self._snap['requests'][request_id]
if status['worker'] != self.config.account
])
logging.info(f'competitors: {competitors}')
should_cancel = bool(self.config.non_compete & competitors)
logging.info(f'cancel: {should_cancel}')
return should_cancel
# compute unique request hash used on submit
hash_str = (
str(req['nonce'])
+
req['body']
+
req['binary_data']
)
logging.debug(f'hashing: {hash_str}')
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
logging.info(f'calculated request hash: {request_hash}')
total_step = body['params']['step']
model = body['params']['model']
mode = body['method']
# TODO: validate request
resp = await conn.begin_work(rid)
if not resp or 'code' in resp:
logging.info('begin_work error, probably being worked on already... skip.')
return
with maybe_load_model(model, mode):
try:
maybe_update_tui(lambda tui: tui.set_progress(0, done=total_step))
output_type = 'png'
if 'output_type' in body['params']:
output_type = body['params']['output_type']
output = None
output_hash = None
match config.backend:
case 'sync-on-thread':
output_hash, output = await trio.to_thread.run_sync(
partial(
compute_one,
rid,
mode, body['params'],
inputs=inputs,
should_cancel=conn.should_cancel_work,
)
)
case _:
raise DGPUComputeError(
f'Unsupported backend {config.backend}'
)
maybe_update_tui(lambda tui: tui.set_progress(total_step))
ipfs_hash = await conn.publish_on_ipfs(output, typ=output_type)
await conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
await maybe_update_tui_balance(conn)
async def snap_updater_task(self):
'''
Busy loop update the local `._snap: dict` table from
except BaseException as err:
if 'network cancel' not in str(err):
logging.exception('Failed to serve model request !?\n')
'''
while True:
self._snap = await self.conn.get_full_queue_snapshot()
await trio.sleep(1)
if rid in conn._tables['requests']:
await conn.cancel_work(rid, 'reason not provided')
# TODO, design suggestion, just make this a lazily accessed
# `@class_property` if we're 3.12+
# |_ https://docs.python.org/3/library/functools.html#functools.cached_property
async def generate_api(self) -> Quart:
'''
Gen a `Quart`-compat web API spec which (for now) simply
serves a small monitoring ep that reports,
- iso-time-stamp of the last served model-output
- the worker's average "compute-iterations-per-second"
async def serve_forever(config: Config, conn: NetConnector):
await maybe_update_tui_balance(conn)
try:
async for tables in conn.iter_poll_update(config.poll_time):
queue = tables['queue']
'''
app = Quart(__name__)
@app.route('/')
async def health():
return jsonify(
account=self.config.account,
version=VERSION,
last_generation_ts=self._last_generation_ts,
last_generation_speed=self._get_benchmark_speed()
random.shuffle(queue)
queue = sorted(
queue,
key=lambda req: convert_reward_to_int(req['reward']),
reverse=True
)
return app
if len(queue) > 0:
await maybe_serve_one(config, conn, queue[0])
async def _update_balance(self):
async def _fn(tui):
# update balance
balance = await self.conn.get_worker_balance()
tui.set_header_text(new_balance=f'balance: {balance}')
await maybe_update_tui_async(_fn)
# TODO? this func is kinda big and maybe is better at module
# level to reduce indentation?
# -[ ] just pass `daemon: WorkerDaemon` vs. `self`
async def maybe_serve_one(
self,
req: dict,
):
rid = req['id']
logging.info(f'maybe serve request #{rid}')
# parse request
body = json.loads(req['body'])
model = body['params']['model']
# if model not known, ignore.
if (
model != 'RealESRGAN_x4plus'
and
model not in MODELS
):
logging.warning(f'unknown model {model}!, skip...')
return False
# only handle whitelisted models
if (
len(self.config.model_whitelist) > 0
and
model not in self.config.model_whitelist
):
logging.warning('model not whitelisted!, skip...')
return False
# if blacklist contains model skip
if (
len(self.config.model_blacklist) > 0
and
model in self.config.model_blacklist
):
logging.warning('model not blacklisted!, skip...')
return False
results = [res['request_id'] for res in self._snap['results']]
# if worker already produced a result for this request
if rid in results:
logging.info(f'worker already submitted a result for request #{rid}, skip...')
return False
statuses = self._snap['requests'][rid]
# skip if workers in non_compete already on it
competitors = set((status['worker'] for status in statuses))
if bool(self.config.non_compete & competitors):
logging.info('worker in configured non_compete list already working on request, skip...')
return False
# resolve the ipfs hashes into the actual data behind them
inputs = []
raw_inputs = req['binary_data'].split(',')
if raw_inputs:
logging.info(f'fetching IPFS inputs: {raw_inputs}')
retry = 3
for _input in req['binary_data'].split(','):
if _input:
for r in range(retry):
try:
# user `GPUConnector` to IO with
# storage layer to seed the compute
# task.
img = await self.conn.get_input_data(_input)
inputs.append(img)
logging.info(f'retrieved {_input}!')
break
except BaseException:
logging.exception(
f'IPFS fetch input error !?! retries left {retry - r - 1}\n'
)
# compute unique request hash used on submit
hash_str = (
str(req['nonce'])
+
req['body']
+
req['binary_data']
)
logging.debug(f'hashing: {hash_str}')
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
logging.info(f'calculated request hash: {request_hash}')
total_step = body['params']['step']
model = body['params']['model']
mode = body['method']
# TODO: validate request
resp = await self.conn.begin_work(rid)
if not resp or 'code' in resp:
logging.info('begin_work error, probably being worked on already... skip.')
return False
with maybe_load_model(model, mode):
try:
maybe_update_tui(lambda tui: tui.set_progress(0, done=total_step))
output_type = 'png'
if 'output_type' in body['params']:
output_type = body['params']['output_type']
output = None
output_hash = None
match self.config.backend:
case 'sync-on-thread':
output_hash, output = await trio.to_thread.run_sync(
partial(
compute_one,
rid,
mode, body['params'],
inputs=inputs,
should_cancel=self.should_cancel_work,
)
)
case _:
raise DGPUComputeError(
f'Unsupported backend {self.config.backend}'
)
maybe_update_tui(lambda tui: tui.set_progress(total_step))
self._last_generation_ts: str = datetime.now().isoformat()
self._last_benchmark: list[float] = self._benchmark
self._benchmark: list[float] = []
ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
await self._update_balance()
except BaseException as err:
if 'network cancel' not in str(err):
logging.exception('Failed to serve model request !?\n')
if rid in self._snap['requests']:
await self.conn.cancel_work(rid, 'reason not provided')
finally:
return True
# TODO, as per above on `.maybe_serve_one()`, it's likely a bit
# more *trionic* to define this all as a module level task-func
# which operates on a `daemon: WorkerDaemon`?
#
# -[ ] keeps tasks-as-funcs style prominent
# -[ ] avoids so much indentation due to methods
async def serve_forever(self):
await self._update_balance()
try:
while True:
queue = self._snap['queue']
random.shuffle(queue)
queue = sorted(
queue,
key=lambda req: convert_reward_to_int(req['reward']),
reverse=True
)
for req in queue:
# TODO, as mentioned above just inline this once
# converted to a mod level func.
if (await self.maybe_serve_one(req)):
break
await trio.sleep(1)
except KeyboardInterrupt:
...
except KeyboardInterrupt:
...
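
The same reading for the daemon side: after this commit the serve loop and the request handler are module-level functions operating on a `NetConnector` rather than methods on `WorkerDaemon`. A sketch of the new loop, reconstructed from the added lines above (imports such as `random` and the helpers shown earlier in the hunk are assumed to be in scope):

    async def serve_forever(config: Config, conn: NetConnector):
        await maybe_update_tui_balance(conn)

        try:
            async for tables in conn.iter_poll_update(config.poll_time):
                queue = tables['queue']

                # highest-reward requests first; shuffle before the stable
                # sort so ties are broken randomly between competing workers
                random.shuffle(queue)
                queue = sorted(
                    queue,
                    key=lambda req: convert_reward_to_int(req['reward']),
                    reverse=True
                )

                if len(queue) > 0:
                    await maybe_serve_one(config, conn, queue[0])

        except KeyboardInterrupt:
            ...

Note the behavioral shift visible in the diff: the old method walked the whole queue until `maybe_serve_one` returned True, while the new loop attempts only the top-reward request on each poll cycle.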

View File

@ -3,6 +3,7 @@ import json
import time
import logging
from pathlib import Path
from typing import AsyncGenerator
from functools import partial
import trio
@ -66,7 +67,11 @@ class NetConnector:
self.ipfs_client = AsyncIPFSHTTP(config.ipfs_url)
self._wip_requests = {}
self._tables = {
'queue': [],
'requests': {},
'results': []
}
maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.config.account))
@ -132,9 +137,6 @@ class NetConnector:
logging.info('no balance info found')
return None
# TODO, considery making this a NON-method and instead
# handing in the `snap['queue']` output beforehand?
# -> since that call is the only usage of `self`?
async def get_full_queue_snapshot(self):
'''
Keep in-sync with latest (telos chain's smart-contract) table
@ -162,6 +164,34 @@ class NetConnector:
return snap
async def iter_poll_update(self, poll_time: float) -> AsyncGenerator[dict, None]:
'''
Long running task, polls gpu contract tables, yields latest table rows
'''
while True:
start_time = time.time()
self._tables = await self.get_full_queue_snapshot()
elapsed = time.time() - start_time
yield self._tables
await trio.sleep(max(poll_time - elapsed, 0.1))
async def should_cancel_work(self, request_id: int) -> bool:
logging.info('should cancel work?')
if request_id not in self._tables['requests']:
logging.info(f'request #{request_id} no longer in queue, likely its been filled by another worker, cancelling work...')
return True
competitors = set([
status['worker']
for status in self._tables['requests'][request_id]
if status['worker'] != self.config.account
])
logging.info(f'competitors: {competitors}')
should_cancel = bool(self.config.non_compete & competitors)
logging.info(f'cancel: {should_cancel}')
return should_cancel
async def begin_work(self, request_id: int):
'''
Publish to the bc that the worker is beginning a model-computation
@ -244,7 +274,7 @@ class NetConnector:
result_hash: str,
ipfs_hash: str
):
logging.info('submit_work #{request_id}')
logging.info(f'submit_work #{request_id}')
return await failable(
partial(
self.cleos.a_push_action,