Factor out WorkerDaemon and split it into module-level functions; make the poller an async generator and move it, along with should_cancel, to NetConnector

pull/47/head
Guillermo Rodriguez 2025-02-07 16:41:50 -03:00
parent ea3b35904c
commit 149d9f9f33
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
6 changed files with 199 additions and 310 deletions

View File

@ -193,14 +193,14 @@ def dgpu(
config_path: str config_path: str
): ):
import trio import trio
from .dgpu import open_dgpu_node from .dgpu import _dgpu_main
logging.basicConfig(level=loglevel) logging.basicConfig(level=loglevel)
config = load_skynet_toml(file_path=config_path) config = load_skynet_toml(file_path=config_path)
set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home) set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
trio.run(open_dgpu_node, config.dgpu) trio.run(_dgpu_main, config.dgpu)
@run.command() @run.command()

View File

@ -26,6 +26,7 @@ class DgpuConfig(msgspec.Struct):
backend: str = 'sync-on-thread' backend: str = 'sync-on-thread'
api_bind: str = False api_bind: str = False
tui: bool = False tui: bool = False
poll_time: float = 0.5
class TelegramConfig(msgspec.Struct): class TelegramConfig(msgspec.Struct):
account: str account: str

View File

@ -3,23 +3,13 @@ import logging
import trio import trio
import urwid import urwid
from hypercorn.config import Config as HCConfig
from hypercorn.trio import serve
from quart_trio import QuartTrio as Quart
from skynet.config import Config from skynet.config import Config
from skynet.dgpu.tui import init_tui from skynet.dgpu.tui import init_tui
from skynet.dgpu.daemon import WorkerDaemon from skynet.dgpu.daemon import serve_forever
from skynet.dgpu.network import NetConnector from skynet.dgpu.network import NetConnector
async def open_dgpu_node(config: Config) -> None: async def _dgpu_main(config: Config) -> None:
'''
Open a top level "GPU mgmt daemon", keep the
`WorkerDaemon._snap: dict[str, list|dict]` table
and *maybe* serve a `hypercorn` web API.
'''
# suppress logs from httpx (logs url + status after every query) # suppress logs from httpx (logs url + status after every query)
logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpx").setLevel(logging.WARNING)
@ -28,29 +18,14 @@ async def open_dgpu_node(config: Config) -> None:
tui = init_tui() tui = init_tui()
conn = NetConnector(config) conn = NetConnector(config)
daemon = WorkerDaemon(conn, config)
api: Quart|None = None try:
if config.api_bind: n: trio.Nursery
api_conf = HCConfig() async with trio.open_nursery() as n:
api_conf.bind = [config.api_bind] if tui:
api: Quart = await daemon.generate_api() n.start_soon(tui.run)
tn: trio.Nursery await serve_forever(config, conn)
async with trio.open_nursery() as tn:
tn.start_soon(daemon.snap_updater_task)
if tui:
tn.start_soon(tui.run)
# TODO, consider a more explicit `as hypercorn_serve` except *urwid.ExitMainLoop:
# to clarify? ...
if api:
logging.info(f'serving api @ {config["api_bind"]}')
tn.start_soon(serve, api, api_conf)
try:
# block until cancelled
await daemon.serve_forever()
except *urwid.ExitMainLoop:
...

View File

@ -20,6 +20,7 @@ from skynet.dgpu.errors import (
from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
def prepare_params_for_diffuse( def prepare_params_for_diffuse(
params: dict, params: dict,
mode: str, mode: str,

View File

@ -7,8 +7,6 @@ from functools import partial
from hashlib import sha256 from hashlib import sha256
import trio import trio
from quart import jsonify
from quart_trio import QuartTrio as Quart
from skynet.config import DgpuConfig as Config from skynet.config import DgpuConfig as Config
from skynet.constants import ( from skynet.constants import (
@ -31,291 +29,175 @@ def convert_reward_to_int(reward_str):
return int(int_part + decimal_part) return int(int_part + decimal_part)
class WorkerDaemon: async def maybe_update_tui_balance(conn: NetConnector):
''' async def _fn(tui):
The root "GPU daemon". # update balance
balance = await conn.get_worker_balance()
tui.set_header_text(new_balance=f'balance: {balance}')
Contains/manages underlying susystems: await maybe_update_tui_async(_fn)
- a GPU connecto
'''
def __init__( async def maybe_serve_one(
self, config: Config,
conn: NetConnector, conn: NetConnector,
config: Config req: dict,
):
rid = req['id']
logging.info(f'maybe serve request #{rid}')
# parse request
body = json.loads(req['body'])
model = body['params']['model']
# if model not known, ignore.
if (
model != 'RealESRGAN_x4plus'
and
model not in MODELS
): ):
self.config = config logging.warning(f'unknown model {model}!, skip...')
self.conn: NetConnector = conn return
self._snap = { # only handle whitelisted models
'queue': [], if (
'requests': {}, len(config.model_whitelist) > 0
'results': [] and
} model not in config.model_whitelist
):
logging.warning('model not whitelisted!, skip...')
return
self._benchmark: list[float] = [] # if blacklist contains model skip
self._last_benchmark: list[float]|None = None if (
self._last_generation_ts: str|None = None len(config.model_blacklist) > 0
and
model in config.model_blacklist
):
logging.warning('model not blacklisted!, skip...')
return
def _get_benchmark_speed(self) -> float: results = [res['request_id'] for res in conn._tables['results']]
'''
Return the (arithmetic) average work-iterations-per-second
conducted by this compute worker.
''' # if worker already produced a result for this request
if not self._last_benchmark: if rid in results:
return 0 logging.info(f'worker already submitted a result for request #{rid}, skip...')
return
start = self._last_benchmark[0] statuses = conn._tables['requests'][rid]
end = self._last_benchmark[-1]
elapsed = end - start # skip if workers in non_compete already on it
its = len(self._last_benchmark) competitors = set((status['worker'] for status in statuses))
speed = its / elapsed if bool(config.non_compete & competitors):
logging.info('worker in configured non_compete list already working on request, skip...')
return
logging.info(f'{elapsed} s total its: {its}, at {speed} it/s ') # resolve the ipfs hashes into the actual data behind them
inputs = []
raw_inputs = req['binary_data'].split(',')
if raw_inputs:
logging.info(f'fetching IPFS inputs: {raw_inputs}')
return speed retry = 3
for _input in req['binary_data'].split(','):
if _input:
for r in range(retry):
try:
# user `GPUConnector` to IO with
# storage layer to seed the compute
# task.
img = await conn.get_input_data(_input)
inputs.append(img)
logging.info(f'retrieved {_input}!')
break
async def should_cancel_work(self, request_id: int): except BaseException:
self._benchmark.append(time.time()) logging.exception(
logging.info('should cancel work?') f'IPFS fetch input error !?! retries left {retry - r - 1}\n'
if request_id not in self._snap['requests']: )
logging.info(f'request #{request_id} no longer in queue, likely its been filled by another worker, cancelling work...')
return True
competitors = set([ # compute unique request hash used on submit
status['worker'] hash_str = (
for status in self._snap['requests'][request_id] str(req['nonce'])
if status['worker'] != self.config.account +
]) req['body']
logging.info(f'competitors: {competitors}') +
should_cancel = bool(self.config.non_compete & competitors) req['binary_data']
logging.info(f'cancel: {should_cancel}') )
return should_cancel logging.debug(f'hashing: {hash_str}')
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
logging.info(f'calculated request hash: {request_hash}')
total_step = body['params']['step']
model = body['params']['model']
mode = body['method']
# TODO: validate request
resp = await conn.begin_work(rid)
if not resp or 'code' in resp:
logging.info('begin_work error, probably being worked on already... skip.')
return
with maybe_load_model(model, mode):
try:
maybe_update_tui(lambda tui: tui.set_progress(0, done=total_step))
output_type = 'png'
if 'output_type' in body['params']:
output_type = body['params']['output_type']
output = None
output_hash = None
match config.backend:
case 'sync-on-thread':
output_hash, output = await trio.to_thread.run_sync(
partial(
compute_one,
rid,
mode, body['params'],
inputs=inputs,
should_cancel=conn.should_cancel_work,
)
)
case _:
raise DGPUComputeError(
f'Unsupported backend {config.backend}'
)
maybe_update_tui(lambda tui: tui.set_progress(total_step))
ipfs_hash = await conn.publish_on_ipfs(output, typ=output_type)
await conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
await maybe_update_tui_balance(conn)
async def snap_updater_task(self): except BaseException as err:
''' if 'network cancel' not in str(err):
Busy loop update the local `._snap: dict` table from logging.exception('Failed to serve model request !?\n')
''' if rid in conn._tables['requests']:
while True: await conn.cancel_work(rid, 'reason not provided')
self._snap = await self.conn.get_full_queue_snapshot()
await trio.sleep(1)
# TODO, design suggestion, just make this a lazily accessed
# `@class_property` if we're 3.12+
# |_ https://docs.python.org/3/library/functools.html#functools.cached_property
async def generate_api(self) -> Quart:
'''
Gen a `Quart`-compat web API spec which (for now) simply
serves a small monitoring ep that reports,
- iso-time-stamp of the last served model-output async def serve_forever(config: Config, conn: NetConnector):
- the worker's average "compute-iterations-per-second" await maybe_update_tui_balance(conn)
try:
async for tables in conn.iter_poll_update(config.poll_time):
queue = tables['queue']
''' random.shuffle(queue)
app = Quart(__name__) queue = sorted(
queue,
@app.route('/') key=lambda req: convert_reward_to_int(req['reward']),
async def health(): reverse=True
return jsonify(
account=self.config.account,
version=VERSION,
last_generation_ts=self._last_generation_ts,
last_generation_speed=self._get_benchmark_speed()
) )
return app if len(queue) > 0:
await maybe_serve_one(config, conn, queue[0])
async def _update_balance(self): except KeyboardInterrupt:
async def _fn(tui): ...
# update balance
balance = await self.conn.get_worker_balance()
tui.set_header_text(new_balance=f'balance: {balance}')
await maybe_update_tui_async(_fn)
# TODO? this func is kinda big and maybe is better at module
# level to reduce indentation?
# -[ ] just pass `daemon: WorkerDaemon` vs. `self`
# Try to serve a single queued request `req` end to end: filter it,
# fetch its inputs, announce work, run the compute and submit the result.
#
# Returns False when the request is skipped before any compute is
# attempted (unknown / non-whitelisted / blacklisted model, result already
# submitted, a non_compete worker already on it, or `begin_work` failing).
# Once the compute section is entered the trailing `finally` returns True.
async def maybe_serve_one(
self,
req: dict,
):
rid = req['id']
logging.info(f'maybe serve request #{rid}')
# parse request
# `req['body']` is a JSON string; the model name lives under params.model
body = json.loads(req['body'])
model = body['params']['model']
# if model not known, ignore.
# ('RealESRGAN_x4plus' is accepted even when it is not listed in MODELS)
if (
model != 'RealESRGAN_x4plus'
and
model not in MODELS
):
logging.warning(f'unknown model {model}!, skip...')
return False
# only handle whitelisted models
# (an empty whitelist disables this filter)
if (
len(self.config.model_whitelist) > 0
and
model not in self.config.model_whitelist
):
logging.warning('model not whitelisted!, skip...')
return False
# if blacklist contains model skip
# NOTE(review): this branch runs when the model IS in the blacklist, yet
# the message below says "not blacklisted" -- wording looks inverted.
if (
len(self.config.model_blacklist) > 0
and
model in self.config.model_blacklist
):
logging.warning('model not blacklisted!, skip...')
return False
# request ids found in the cached 'results' table
results = [res['request_id'] for res in self._snap['results']]
# if worker already produced a result for this request
if rid in results:
logging.info(f'worker already submitted a result for request #{rid}, skip...')
return False
# NOTE(review): direct indexing raises KeyError when `rid` is missing from
# the cached 'requests' table -- confirm that cannot happen here.
statuses = self._snap['requests'][rid]
# skip if workers in non_compete already on it
competitors = set((status['worker'] for status in statuses))
if bool(self.config.non_compete & competitors):
logging.info('worker in configured non_compete list already working on request, skip...')
return False
# resolve the ipfs hashes into the actual data behind them
# (`binary_data` is split on ','; empty entries are skipped in the loop)
inputs = []
raw_inputs = req['binary_data'].split(',')
# NOTE(review): `str.split(',')` never returns an empty list, so this
# condition is always true and the message is logged even with no inputs.
if raw_inputs:
logging.info(f'fetching IPFS inputs: {raw_inputs}')
# up to `retry` fetch attempts per input
retry = 3
for _input in req['binary_data'].split(','):
if _input:
for r in range(retry):
try:
# use `GPUConnector` to IO with
# storage layer to seed the compute
# task.
img = await self.conn.get_input_data(_input)
inputs.append(img)
logging.info(f'retrieved {_input}!')
break
# NOTE(review): catching BaseException here also catches cancellation and
# KeyboardInterrupt; when every attempt fails the input is simply left
# out of `inputs` and processing continues.
except BaseException:
logging.exception(
f'IPFS fetch input error !?! retries left {retry - r - 1}\n'
)
# compute unique request hash used on submit
# (sha256 over nonce + raw body + raw binary_data, utf-8 encoded)
hash_str = (
str(req['nonce'])
+
req['body']
+
req['binary_data']
)
logging.debug(f'hashing: {hash_str}')
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
logging.info(f'calculated request hash: {request_hash}')
total_step = body['params']['step']
model = body['params']['model']
mode = body['method']
# TODO: validate request
# announce the work through the connector; a falsy response or one
# containing a 'code' key is treated as failure
resp = await self.conn.begin_work(rid)
if not resp or 'code' in resp:
logging.info('begin_work error, probably being worked on already... skip.')
return False
with maybe_load_model(model, mode):
try:
# reset the TUI progress display for this job
maybe_update_tui(lambda tui: tui.set_progress(0, done=total_step))
# output format defaults to 'png' unless the request params override it
output_type = 'png'
if 'output_type' in body['params']:
output_type = body['params']['output_type']
output = None
output_hash = None
# dispatch on the configured backend; only 'sync-on-thread' is handled
match self.config.backend:
case 'sync-on-thread':
# run the blocking compute in a worker thread, handing it
# `should_cancel_work` as its cancellation callback
output_hash, output = await trio.to_thread.run_sync(
partial(
compute_one,
rid,
mode, body['params'],
inputs=inputs,
should_cancel=self.should_cancel_work,
)
)
case _:
raise DGPUComputeError(
f'Unsupported backend {self.config.backend}'
)
maybe_update_tui(lambda tui: tui.set_progress(total_step))
# record the completion timestamp and rotate the benchmark samples
self._last_generation_ts: str = datetime.now().isoformat()
self._last_benchmark: list[float] = self._benchmark
self._benchmark: list[float] = []
# publish the output, submit the result, then refresh the shown balance
ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
await self._update_balance()
# errors whose text contains 'network cancel' are not logged as failures
except BaseException as err:
if 'network cancel' not in str(err):
logging.exception('Failed to serve model request !?\n')
# cancel on the connector only while the request is still in the
# cached 'requests' table
if rid in self._snap['requests']:
await self.conn.cancel_work(rid, 'reason not provided')
# NOTE(review): a `return` inside `finally` discards any exception still
# propagating at that point (including cancellation) -- confirm intended.
finally:
return True
# TODO, as per above on `.maybe_serve_one()`, it's likely a bit
# more *trionic* to define this all as a module level task-func
# which operates on a `daemon: WorkerDaemon`?
#
# -[ ] keeps tasks-as-funcs style prominent
# -[ ] avoids so much indentation due to methods
async def serve_forever(self):
    '''
    Worker main loop.

    Refreshes the displayed balance once, then repeats forever: take
    the cached `self._snap['queue']`, order it by descending reward
    and offer requests one at a time to `.maybe_serve_one()` until
    one call returns a truthy value, then sleep one second.

    A `KeyboardInterrupt` raised inside the loop ends it silently.

    '''
    def _reward_of(req):
        return convert_reward_to_int(req['reward'])

    await self._update_balance()

    try:
        while True:
            pending = self._snap['queue']

            # shuffled in place before the (stable) sort; presumably so
            # that equally rewarded requests end up in random order.
            random.shuffle(pending)
            by_reward = sorted(pending, key=_reward_of, reverse=True)

            for candidate in by_reward:
                # TODO, as mentioned above just inline this once
                # converted to a mod level func.
                served = await self.maybe_serve_one(candidate)
                if served:
                    break

            await trio.sleep(1)

    except KeyboardInterrupt:
        ...

View File

@ -3,6 +3,7 @@ import json
import time import time
import logging import logging
from pathlib import Path from pathlib import Path
from typing import AsyncGenerator
from functools import partial from functools import partial
import trio import trio
@ -66,7 +67,11 @@ class NetConnector:
self.ipfs_client = AsyncIPFSHTTP(config.ipfs_url) self.ipfs_client = AsyncIPFSHTTP(config.ipfs_url)
self._wip_requests = {} self._tables = {
'queue': [],
'requests': {},
'results': []
}
maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.config.account)) maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.config.account))
@ -132,9 +137,6 @@ class NetConnector:
logging.info('no balance info found') logging.info('no balance info found')
return None return None
# TODO, consider making this a NON-method and instead
# handing in the `snap['queue']` output beforehand?
# -> since that call is the only usage of `self`?
async def get_full_queue_snapshot(self): async def get_full_queue_snapshot(self):
''' '''
Keep in-sync with latest (telos chain's smart-contract) table Keep in-sync with latest (telos chain's smart-contract) table
@ -162,6 +164,34 @@ class NetConnector:
return snap return snap
async def iter_poll_update(self, poll_time: float) -> AsyncGenerator[dict, None]:
    '''
    Long running task: polls the gpu contract tables and yields the
    latest table rows after every refresh.

    Each round stores the result of `.get_full_queue_snapshot()` on
    `self._tables`, yields that same object, then sleeps so that
    rounds are spaced roughly `poll_time` seconds apart.

    Only the time spent fetching the snapshot is deducted from the
    pause; time the consumer spends on a yielded snapshot is not.
    The pause is never shorter than 0.1 seconds.

    '''
    while True:
        # use a monotonic clock: wall-clock adjustments must not be
        # able to shrink or stretch the computed pause.
        start_time = time.monotonic()
        self._tables = await self.get_full_queue_snapshot()
        elapsed = time.monotonic() - start_time

        yield self._tables

        await trio.sleep(max(poll_time - elapsed, 0.1))
async def should_cancel_work(self, request_id: int) -> bool:
    '''
    Tell whether in-progress work on `request_id` should be abandoned.

    Only the locally cached `self._tables['requests']` table is
    consulted. Returns `True` when,

    - `request_id` is no longer a key of that table, or
    - a worker other than `self.config.account` that has a status
      row on the request is listed in `self.config.non_compete`,

    and `False` otherwise.

    '''
    logging.info('should cancel work?')

    # look the table up once and reuse it below
    requests = self._tables['requests']
    if request_id not in requests:
        logging.info(f'request #{request_id} no longer in queue, likely its been filled by another worker, cancelling work...')
        return True

    # every *other* worker holding a status row on this request
    competitors = {
        status['worker']
        for status in requests[request_id]
        if status['worker'] != self.config.account
    }
    logging.info(f'competitors: {competitors}')

    should_cancel = bool(self.config.non_compete & competitors)
    logging.info(f'cancel: {should_cancel}')
    return should_cancel
async def begin_work(self, request_id: int): async def begin_work(self, request_id: int):
''' '''
Publish to the bc that the worker is beginning a model-computation Publish to the bc that the worker is beginning a model-computation
@ -244,7 +274,7 @@ class NetConnector:
result_hash: str, result_hash: str,
ipfs_hash: str ipfs_hash: str
): ):
logging.info('submit_work #{request_id}') logging.info(f'submit_work #{request_id}')
return await failable( return await failable(
partial( partial(
self.cleos.a_push_action, self.cleos.a_push_action,