diff --git a/skynet/cli.py b/skynet/cli.py
index 0ead5d3..f835d98 100755
--- a/skynet/cli.py
+++ b/skynet/cli.py
@@ -193,14 +193,14 @@ def dgpu(
     config_path: str
 ):
     import trio
-    from .dgpu import open_dgpu_node
+    from .dgpu import _dgpu_main
 
     logging.basicConfig(level=loglevel)
     config = load_skynet_toml(file_path=config_path)
     set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
 
-    trio.run(open_dgpu_node, config.dgpu)
+    trio.run(_dgpu_main, config.dgpu)
 
 
 @run.command()
diff --git a/skynet/config.py b/skynet/config.py
index ca7b745..7625dab 100755
--- a/skynet/config.py
+++ b/skynet/config.py
@@ -26,6 +26,7 @@ class DgpuConfig(msgspec.Struct):
     backend: str = 'sync-on-thread'
     api_bind: str = False
     tui: bool = False
+    poll_time: float = 0.5
 
 class TelegramConfig(msgspec.Struct):
     account: str
diff --git a/skynet/dgpu/__init__.py b/skynet/dgpu/__init__.py
index 6f7c6f7..1c1f40c 100755
--- a/skynet/dgpu/__init__.py
+++ b/skynet/dgpu/__init__.py
@@ -3,23 +3,13 @@ import logging
 import trio
 import urwid
 
-from hypercorn.config import Config as HCConfig
-from hypercorn.trio import serve
-from quart_trio import QuartTrio as Quart
-
 from skynet.config import Config
 from skynet.dgpu.tui import init_tui
-from skynet.dgpu.daemon import WorkerDaemon
+from skynet.dgpu.daemon import serve_forever
 from skynet.dgpu.network import NetConnector
 
 
-async def open_dgpu_node(config: Config) -> None:
-    '''
-    Open a top level "GPU mgmt daemon", keep the
-    `WorkerDaemon._snap: dict[str, list|dict]` table
-    and *maybe* serve a `hypercorn` web API.
-
-    '''
+async def _dgpu_main(config: Config) -> None:
     # suppress logs from httpx (logs url + status after every query)
     logging.getLogger("httpx").setLevel(logging.WARNING)
 
@@ -28,29 +18,14 @@ async def open_dgpu_node(config: Config) -> None:
     tui = init_tui()
 
     conn = NetConnector(config)
-    daemon = WorkerDaemon(conn, config)
 
-    api: Quart|None = None
-    if config.api_bind:
-        api_conf = HCConfig()
-        api_conf.bind = [config.api_bind]
-        api: Quart = await daemon.generate_api()
+    try:
+        n: trio.Nursery
+        async with trio.open_nursery() as n:
+            if tui:
+                n.start_soon(tui.run)
 
-    tn: trio.Nursery
-    async with trio.open_nursery() as tn:
-        tn.start_soon(daemon.snap_updater_task)
-        if tui:
-            tn.start_soon(tui.run)
+            await serve_forever(config, conn)
 
-        # TODO, consider a more explicit `as hypercorn_serve`
-        # to clarify?
-        if api:
-            logging.info(f'serving api @ {config["api_bind"]}')
-            tn.start_soon(serve, api, api_conf)
-
-        try:
-            # block until cancelled
-            await daemon.serve_forever()
-
-        except *urwid.ExitMainLoop:
-            ...
+    except *urwid.ExitMainLoop:
+        ...
diff --git a/skynet/dgpu/compute.py b/skynet/dgpu/compute.py
index a027dc8..5e90791 100755
--- a/skynet/dgpu/compute.py
+++ b/skynet/dgpu/compute.py
@@ -20,6 +20,7 @@ from skynet.dgpu.errors import (
 from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
 
+
 def prepare_params_for_diffuse(
     params: dict,
     mode: str,
diff --git a/skynet/dgpu/daemon.py b/skynet/dgpu/daemon.py
index 31c3d79..f3f8ef3 100755
--- a/skynet/dgpu/daemon.py
+++ b/skynet/dgpu/daemon.py
@@ -7,8 +7,6 @@ from functools import partial
 from hashlib import sha256
 
 import trio
-from quart import jsonify
-from quart_trio import QuartTrio as Quart
 
 from skynet.config import DgpuConfig as Config
 from skynet.constants import (
@@ -31,291 +29,175 @@ def convert_reward_to_int(reward_str):
     return int(int_part + decimal_part)
 
 
-class WorkerDaemon:
-    '''
-    The root "GPU daemon".
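+# refresh the balance figure shown in the TUI header; should be a no-op
+# whenever the TUI is not enabled (see `maybe_update_tui_async()`)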
+async def maybe_update_tui_balance(conn: NetConnector):
+    async def _fn(tui):
+        # update balance
+        balance = await conn.get_worker_balance()
+        tui.set_header_text(new_balance=f'balance: {balance}')
 
-    Contains/manages underlying susystems:
-    - a GPU connecto
+    await maybe_update_tui_async(_fn)
 
-    '''
-    def __init__(
-        self,
-        conn: NetConnector,
-        config: Config
+
+async def maybe_serve_one(
+    config: Config,
+    conn: NetConnector,
+    req: dict,
+):
+    rid = req['id']
+    logging.info(f'maybe serve request #{rid}')
+
+    # parse request
+    body = json.loads(req['body'])
+    model = body['params']['model']
+
+    # if model not known, ignore.
+    if (
+        model != 'RealESRGAN_x4plus'
+        and
+        model not in MODELS
     ):
-        self.config = config
-        self.conn: NetConnector = conn
+        logging.warning(f'unknown model {model}!, skip...')
+        return
 
-        self._snap = {
-            'queue': [],
-            'requests': {},
-            'results': []
-        }
+    # only handle whitelisted models
+    if (
+        len(config.model_whitelist) > 0
+        and
+        model not in config.model_whitelist
+    ):
+        logging.warning('model not whitelisted!, skip...')
+        return
 
-        self._benchmark: list[float] = []
-        self._last_benchmark: list[float]|None = None
-        self._last_generation_ts: str|None = None
+    # if blacklist contains model skip
+    if (
+        len(config.model_blacklist) > 0
+        and
+        model in config.model_blacklist
+    ):
+        logging.warning('model blacklisted!, skip...')
+        return
 
-    def _get_benchmark_speed(self) -> float:
-        '''
-        Return the (arithmetic) average work-iterations-per-second
-        fconducted by this compute worker.
+    results = [res['request_id'] for res in conn._tables['results']]
 
-        '''
-        if not self._last_benchmark:
-            return 0
+    # if worker already produced a result for this request
+    if rid in results:
+        logging.info(f'worker already submitted a result for request #{rid}, skip...')
+        return
 
-        start = self._last_benchmark[0]
-        end = self._last_benchmark[-1]
+    statuses = conn._tables['requests'][rid]
 
-        elapsed = end - start
-        its = len(self._last_benchmark)
-        speed = its / elapsed
+    # skip if workers in non_compete already on it
+    competitors = set((status['worker'] for status in statuses))
+    if bool(config.non_compete & competitors):
+        logging.info('worker in configured non_compete list already working on request, skip...')
+        return
 
-        logging.info(f'{elapsed} s total its: {its}, at {speed} it/s ')
+    # resolve the ipfs hashes into the actual data behind them
+    inputs = []
+    raw_inputs = req['binary_data'].split(',')
+    if raw_inputs:
+        logging.info(f'fetching IPFS inputs: {raw_inputs}')
 
-        return speed
+    retry = 3
+    for _input in raw_inputs:
+        if _input:
+            for r in range(retry):
+                try:
+                    # use `NetConnector` to do IO with the
+                    # storage layer to seed the compute
+                    # task.
+                    img = await conn.get_input_data(_input)
+                    inputs.append(img)
+                    logging.info(f'retrieved {_input}!')
+                    break
 
-    async def should_cancel_work(self, request_id: int):
-        self._benchmark.append(time.time())
-        logging.info('should cancel work?')
-        if request_id not in self._snap['requests']:
-            logging.info(f'request #{request_id} no longer in queue, likely its been filled by another worker, cancelling work...')
-            return True
+                except BaseException:
+                    logging.exception(
+                        f'IPFS fetch input error !?! retries left {retry - r - 1}\n'
+                    )
 
-        competitors = set([
-            status['worker']
-            for status in self._snap['requests'][request_id]
-            if status['worker'] != self.config.account
-        ])
-        logging.info(f'competitors: {competitors}')
-        should_cancel = bool(self.config.non_compete & competitors)
-        logging.info(f'cancel: {should_cancel}')
-        return should_cancel
+    # compute unique request hash used on submit
+    hash_str = (
+        str(req['nonce'])
+        +
+        req['body']
+        +
+        req['binary_data']
+    )
+    logging.debug(f'hashing: {hash_str}')
+    request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
+    logging.info(f'calculated request hash: {request_hash}')
+
+    total_step = body['params']['step']
+    model = body['params']['model']
+    mode = body['method']
+
+    # TODO: validate request
+
+    resp = await conn.begin_work(rid)
+    if not resp or 'code' in resp:
+        logging.info('begin_work error, probably being worked on already... skip.')
+        return
+
+    with maybe_load_model(model, mode):
+        try:
+            maybe_update_tui(lambda tui: tui.set_progress(0, done=total_step))
+
+            output_type = 'png'
+            if 'output_type' in body['params']:
+                output_type = body['params']['output_type']
+
+            output = None
+            output_hash = None
+            match config.backend:
+                case 'sync-on-thread':
+                    output_hash, output = await trio.to_thread.run_sync(
+                        partial(
+                            compute_one,
+                            rid,
+                            mode, body['params'],
+                            inputs=inputs,
+                            should_cancel=conn.should_cancel_work,
+                        )
+                    )
+
+                case _:
+                    raise DGPUComputeError(
+                        f'Unsupported backend {config.backend}'
+                    )
+
+            maybe_update_tui(lambda tui: tui.set_progress(total_step))
+
+            ipfs_hash = await conn.publish_on_ipfs(output, typ=output_type)
+
+            await conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
+
+            await maybe_update_tui_balance(conn)
 
-    async def snap_updater_task(self):
-        '''
-        Busy loop update the local `._snap: dict` table from
+        except BaseException as err:
+            if 'network cancel' not in str(err):
+                logging.exception('Failed to serve model request !?\n')
 
-        '''
-        while True:
-            self._snap = await self.conn.get_full_queue_snapshot()
-            await trio.sleep(1)
+            if rid in conn._tables['requests']:
+                await conn.cancel_work(rid, 'reason not provided')
 
-    # TODO, design suggestion, just make this a lazily accessed
-    # `@class_property` if we're 3.12+
-    # |_ https://docs.python.org/3/library/functools.html#functools.cached_property
-    async def generate_api(self) -> Quart:
-        '''
-        Gen a `Quart`-compat web API spec which (for now) simply
-        serves a small monitoring ep that reports,
-        - iso-time-stamp of the last served model-output
-        - the worker's average "compute-iterations-per-second"
+async def serve_forever(config: Config, conn: NetConnector):
+    await maybe_update_tui_balance(conn)
+    try:
+        async for tables in conn.iter_poll_update(config.poll_time):
+            queue = tables['queue']
 
-        '''
-        app = Quart(__name__)
-
-        @app.route('/')
-        async def health():
-            return jsonify(
-                account=self.config.account,
-                version=VERSION,
-                last_generation_ts=self._last_generation_ts,
-                last_generation_speed=self._get_benchmark_speed()
+            random.shuffle(queue)
+            queue = sorted(
+                queue,
+                key=lambda req: convert_reward_to_int(req['reward']),
+                reverse=True
             )
 
-        return app
+            if len(queue) > 0:
+                await maybe_serve_one(config, conn, queue[0])
 
-    async def _update_balance(self):
-        async def _fn(tui):
-            # update balance
-            balance = await self.conn.get_worker_balance()
-            tui.set_header_text(new_balance=f'balance: {balance}')
-
-        await maybe_update_tui_async(_fn)
-
-    # TODO? this func is kinda big and maybe is better at module
-    # level to reduce indentation?
-    # -[ ] just pass `daemon: WorkerDaemon` vs. `self`
-    async def maybe_serve_one(
-        self,
-        req: dict,
-    ):
-        rid = req['id']
-        logging.info(f'maybe serve request #{rid}')
-
-        # parse request
-        body = json.loads(req['body'])
-        model = body['params']['model']
-
-        # if model not known, ignore.
-        if (
-            model != 'RealESRGAN_x4plus'
-            and
-            model not in MODELS
-        ):
-            logging.warning(f'unknown model {model}!, skip...')
-            return False
-
-        # only handle whitelisted models
-        if (
-            len(self.config.model_whitelist) > 0
-            and
-            model not in self.config.model_whitelist
-        ):
-            logging.warning('model not whitelisted!, skip...')
-            return False
-
-        # if blacklist contains model skip
-        if (
-            len(self.config.model_blacklist) > 0
-            and
-            model in self.config.model_blacklist
-        ):
-            logging.warning('model not blacklisted!, skip...')
-            return False
-
-        results = [res['request_id'] for res in self._snap['results']]
-
-        # if worker already produced a result for this request
-        if rid in results:
-            logging.info(f'worker already submitted a result for request #{rid}, skip...')
-            return False
-
-        statuses = self._snap['requests'][rid]
-
-        # skip if workers in non_compete already on it
-        competitors = set((status['worker'] for status in statuses))
-        if bool(self.config.non_compete & competitors):
-            logging.info('worker in configured non_compete list already working on request, skip...')
-            return False
-
-        # resolve the ipfs hashes into the actual data behind them
-        inputs = []
-        raw_inputs = req['binary_data'].split(',')
-        if raw_inputs:
-            logging.info(f'fetching IPFS inputs: {raw_inputs}')
-
-        retry = 3
-        for _input in req['binary_data'].split(','):
-            if _input:
-                for r in range(retry):
-                    try:
-                        # user `GPUConnector` to IO with
-                        # storage layer to seed the compute
-                        # task.
-                        img = await self.conn.get_input_data(_input)
-                        inputs.append(img)
-                        logging.info(f'retrieved {_input}!')
-                        break
-
-                    except BaseException:
-                        logging.exception(
-                            f'IPFS fetch input error !?! retries left {retry - r - 1}\n'
-                        )
-
-        # compute unique request hash used on submit
-        hash_str = (
-            str(req['nonce'])
-            +
-            req['body']
-            +
-            req['binary_data']
-        )
-        logging.debug(f'hashing: {hash_str}')
-        request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
-        logging.info(f'calculated request hash: {request_hash}')
-
-        total_step = body['params']['step']
-        model = body['params']['model']
-        mode = body['method']
-
-        # TODO: validate request
-
-        resp = await self.conn.begin_work(rid)
-        if not resp or 'code' in resp:
-            logging.info('begin_work error, probably being worked on already... skip.')
-            return False
-
-        with maybe_load_model(model, mode):
-            try:
-                maybe_update_tui(lambda tui: tui.set_progress(0, done=total_step))
-
-                output_type = 'png'
-                if 'output_type' in body['params']:
-                    output_type = body['params']['output_type']
-
-                output = None
-                output_hash = None
-                match self.config.backend:
-                    case 'sync-on-thread':
-                        output_hash, output = await trio.to_thread.run_sync(
-                            partial(
-                                compute_one,
-                                rid,
-                                mode, body['params'],
-                                inputs=inputs,
-                                should_cancel=self.should_cancel_work,
-                            )
-                        )
-
-                    case _:
-                        raise DGPUComputeError(
-                            f'Unsupported backend {self.config.backend}'
-                        )
-
-                maybe_update_tui(lambda tui: tui.set_progress(total_step))
-
-                self._last_generation_ts: str = datetime.now().isoformat()
-                self._last_benchmark: list[float] = self._benchmark
-                self._benchmark: list[float] = []
-
-                ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
-
-                await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
-
-                await self._update_balance()
-
-
-            except BaseException as err:
-                if 'network cancel' not in str(err):
-                    logging.exception('Failed to serve model request !?\n')
-
-                if rid in self._snap['requests']:
-                    await self.conn.cancel_work(rid, 'reason not provided')
-
-            finally:
-                return True
-
-    # TODO, as per above on `.maybe_serve_one()`, it's likely a bit
-    # more *trionic* to define this all as a module level task-func
-    # which operates on a `daemon: WorkerDaemon`?
-    #
-    # -[ ] keeps tasks-as-funcs style prominent
-    # -[ ] avoids so much indentation due to methods
-    async def serve_forever(self):
-        await self._update_balance()
-        try:
-            while True:
-                queue = self._snap['queue']
-
-                random.shuffle(queue)
-                queue = sorted(
-                    queue,
-                    key=lambda req: convert_reward_to_int(req['reward']),
-                    reverse=True
-                )
-
-                for req in queue:
-                    # TODO, as mentioned above just inline this once
-                    # converted to a mod level func.
-                    if (await self.maybe_serve_one(req)):
-                        break
-
-                await trio.sleep(1)
-
-        except KeyboardInterrupt:
-            ...
+    except KeyboardInterrupt:
+        ...
diff --git a/skynet/dgpu/network.py b/skynet/dgpu/network.py
index 9076ee2..8b45197 100755
--- a/skynet/dgpu/network.py
+++ b/skynet/dgpu/network.py
@@ -3,6 +3,7 @@ import json
 import time
 import logging
 from pathlib import Path
+from typing import AsyncGenerator
 from functools import partial
 
 import trio
@@ -66,7 +67,11 @@ class NetConnector:
 
         self.ipfs_client = AsyncIPFSHTTP(config.ipfs_url)
 
-        self._wip_requests = {}
+        self._tables = {
+            'queue': [],
+            'requests': {},
+            'results': []
+        }
 
         maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.config.account))
 
@@ -132,9 +137,6 @@ class NetConnector:
             logging.info('no balance info found')
             return None
 
-    # TODO, considery making this a NON-method and instead
-    # handing in the `snap['queue']` output beforehand?
-    # -> since that call is the only usage of `self`?
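+    # NOTE: `iter_poll_update()` below wraps this snapshot call in a
+    # poll loop, caching each result in `self._tables` for other tasks
+    # (eg. `should_cancel_work()`) to read without re-querying the chain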
     async def get_full_queue_snapshot(self):
         '''
         Keep in-sync with latest (telos chain's smart-contract) table
@@ -162,6 +164,34 @@
 
         return snap
 
+    async def iter_poll_update(self, poll_time: float) -> AsyncGenerator[dict, None]:
+        '''
+        Long running task, polls gpu contract tables and yields the latest table rows.
+
+        '''
+        while True:
+            start_time = time.time()
+            self._tables = await self.get_full_queue_snapshot()
+            elapsed = time.time() - start_time
+            yield self._tables
+            await trio.sleep(max(poll_time - elapsed, 0.1))
+
+    async def should_cancel_work(self, request_id: int) -> bool:
+        logging.info('should cancel work?')
+        if request_id not in self._tables['requests']:
+            logging.info(f'request #{request_id} no longer in queue, likely it has been filled by another worker, cancelling work...')
+            return True
+
+        competitors = set([
+            status['worker']
+            for status in self._tables['requests'][request_id]
+            if status['worker'] != self.config.account
+        ])
+        logging.info(f'competitors: {competitors}')
+        should_cancel = bool(self.config.non_compete & competitors)
+        logging.info(f'cancel: {should_cancel}')
+        return should_cancel
+
     async def begin_work(self, request_id: int):
         '''
         Publish to the bc that the worker is beginning a model-computation
@@ -244,7 +274,7 @@
         result_hash: str,
         ipfs_hash: str
     ):
-        logging.info('submit_work #{request_id}')
+        logging.info(f'submit_work #{request_id}')
         return await failable(
             partial(
                 self.cleos.a_push_action,