mirror of https://github.com/skygpu/skynet.git
Factor out WorkerDaemon into plain module-level functions; make the poller an async generator and move it, along with should_cancel_work, into NetConnector
parent ea3b35904c
commit 149d9f9f33
@@ -193,14 +193,14 @@ def dgpu(
     config_path: str
 ):
     import trio
-    from .dgpu import open_dgpu_node
+    from .dgpu import _dgpu_main

     logging.basicConfig(level=loglevel)

     config = load_skynet_toml(file_path=config_path)
     set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)

-    trio.run(open_dgpu_node, config.dgpu)
+    trio.run(_dgpu_main, config.dgpu)


 @run.command()
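Note on the entrypoint swap: trio.run forwards extra positional arguments to the async function it runs, which is why the CLI can hand the parsed config.dgpu straight to _dgpu_main with no wrapper. A minimal sketch of the same pattern (the dict stands in for the real DgpuConfig):

import trio

async def _dgpu_main(config) -> None:
    # the real daemon wires up NetConnector, the TUI and the poll loop here
    print(f'booting dgpu node with {config}')

trio.run(_dgpu_main, {'poll_time': 0.5})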
@@ -26,6 +26,7 @@ class DgpuConfig(msgspec.Struct):
     backend: str = 'sync-on-thread'
     api_bind: str = False
     tui: bool = False
+    poll_time: float = 0.5

 class TelegramConfig(msgspec.Struct):
     account: str
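The new poll_time field controls how often serve_forever polls the contract tables (see the NetConnector.iter_poll_update hunk below). A sketch of how a [dgpu] TOML table could decode into the struct, assuming the loader uses msgspec.convert (the real load_skynet_toml may differ, and the fields shown are a subset):

import tomllib
import msgspec

class DgpuConfig(msgspec.Struct):
    backend: str = 'sync-on-thread'
    tui: bool = False
    poll_time: float = 0.5  # seconds between contract-table polls

raw = tomllib.loads('''
[dgpu]
poll_time = 1.0
''')
config = msgspec.convert(raw['dgpu'], DgpuConfig)
assert config.poll_time == 1.0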
@@ -3,23 +3,13 @@ import logging
 import trio
 import urwid

-from hypercorn.config import Config as HCConfig
-from hypercorn.trio import serve
-from quart_trio import QuartTrio as Quart
-
 from skynet.config import Config
 from skynet.dgpu.tui import init_tui
-from skynet.dgpu.daemon import WorkerDaemon
+from skynet.dgpu.daemon import serve_forever
 from skynet.dgpu.network import NetConnector


-async def open_dgpu_node(config: Config) -> None:
-    '''
-    Open a top level "GPU mgmt daemon", keep the
-    `WorkerDaemon._snap: dict[str, list|dict]` table
-    and *maybe* serve a `hypercorn` web API.
-
-    '''
+async def _dgpu_main(config: Config) -> None:
     # suppress logs from httpx (logs url + status after every query)
     logging.getLogger("httpx").setLevel(logging.WARNING)
@@ -28,29 +18,14 @@ async def open_dgpu_node(config: Config) -> None:
     tui = init_tui()

     conn = NetConnector(config)
-    daemon = WorkerDaemon(conn, config)
-
-    api: Quart|None = None
-    if config.api_bind:
-        api_conf = HCConfig()
-        api_conf.bind = [config.api_bind]
-        api: Quart = await daemon.generate_api()

     try:
-        n: trio.Nursery
-        async with trio.open_nursery() as n:
-            if tui:
-                n.start_soon(tui.run)
-
-            tn: trio.Nursery
-            async with trio.open_nursery() as tn:
-                tn.start_soon(daemon.snap_updater_task)
-
-                # TODO, consider a more explicit `as hypercorn_serve`
-                # to clarify?
-                if api:
-                    logging.info(f'serving api @ {config["api_bind"]}')
-                    tn.start_soon(serve, api, api_conf)
-
-                # block until cancelled
-                await daemon.serve_forever()
+        tn: trio.Nursery
+        async with trio.open_nursery() as tn:
+            if tui:
+                tn.start_soon(tui.run)
+
+            await serve_forever(config, conn)

     except *urwid.ExitMainLoop:
         ...
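Why except* here: exceptions raised in a child task escape a trio nursery wrapped in an ExceptionGroup (the strict behavior, default in recent trio on Python 3.11+), so urwid's exit signal has to be matched with the group form. A self-contained sketch, assuming trio and urwid are installed:

import trio
import urwid  # ExitMainLoop is urwid's standard 'quit the TUI' signal

async def tui_task():
    # stand-in for the urwid main loop ending on user exit
    raise urwid.ExitMainLoop()

async def main():
    try:
        async with trio.open_nursery() as tn:
            tn.start_soon(tui_task)
    except* urwid.ExitMainLoop:
        # a bare `except urwid.ExitMainLoop` would not match the group
        print('TUI exited, shutting down cleanly')

trio.run(main)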
@@ -20,6 +20,7 @@ from skynet.dgpu.errors import (

 from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for


 def prepare_params_for_diffuse(
     params: dict,
     mode: str,
@@ -7,8 +7,6 @@ from functools import partial
 from hashlib import sha256

 import trio
-from quart import jsonify
-from quart_trio import QuartTrio as Quart

 from skynet.config import DgpuConfig as Config
 from skynet.constants import (
@@ -31,291 +29,175 @@ def convert_reward_to_int(reward_str):
     return int(int_part + decimal_part)


-class WorkerDaemon:
-    '''
-    The root "GPU daemon".
-
-    Contains/manages underlying susystems:
-    - a GPU connecto
-
-    '''
-    def __init__(
-        self,
-        conn: NetConnector,
-        config: Config
-    ):
-        self.config = config
-        self.conn: NetConnector = conn
-
-        self._snap = {
-            'queue': [],
-            'requests': {},
-            'results': []
-        }
-
-        self._benchmark: list[float] = []
-        self._last_benchmark: list[float]|None = None
-        self._last_generation_ts: str|None = None
-
-    def _get_benchmark_speed(self) -> float:
-        '''
-        Return the (arithmetic) average work-iterations-per-second
-        fconducted by this compute worker.
-
-        '''
-        if not self._last_benchmark:
-            return 0
-
-        start = self._last_benchmark[0]
-        end = self._last_benchmark[-1]
-
-        elapsed = end - start
-        its = len(self._last_benchmark)
-        speed = its / elapsed
-
-        logging.info(f'{elapsed} s total its: {its}, at {speed} it/s ')
-
-        return speed
-
-    async def should_cancel_work(self, request_id: int):
-        self._benchmark.append(time.time())
-        logging.info('should cancel work?')
-        if request_id not in self._snap['requests']:
-            logging.info(f'request #{request_id} no longer in queue, likely its been filled by another worker, cancelling work...')
-            return True
-
-        competitors = set([
-            status['worker']
-            for status in self._snap['requests'][request_id]
-            if status['worker'] != self.config.account
-        ])
-        logging.info(f'competitors: {competitors}')
-        should_cancel = bool(self.config.non_compete & competitors)
-        logging.info(f'cancel: {should_cancel}')
-        return should_cancel
-
-    async def snap_updater_task(self):
-        '''
-        Busy loop update the local `._snap: dict` table from
-
-        '''
-        while True:
-            self._snap = await self.conn.get_full_queue_snapshot()
-            await trio.sleep(1)
-
-    # TODO, design suggestion, just make this a lazily accessed
-    # `@class_property` if we're 3.12+
-    # |_ https://docs.python.org/3/library/functools.html#functools.cached_property
-    async def generate_api(self) -> Quart:
-        '''
-        Gen a `Quart`-compat web API spec which (for now) simply
-        serves a small monitoring ep that reports,
-
-        - iso-time-stamp of the last served model-output
-        - the worker's average "compute-iterations-per-second"
-
-        '''
-        app = Quart(__name__)
-
-        @app.route('/')
-        async def health():
-            return jsonify(
-                account=self.config.account,
-                version=VERSION,
-                last_generation_ts=self._last_generation_ts,
-                last_generation_speed=self._get_benchmark_speed()
-            )
-
-        return app
-
-    async def _update_balance(self):
-        async def _fn(tui):
-            # update balance
-            balance = await self.conn.get_worker_balance()
-            tui.set_header_text(new_balance=f'balance: {balance}')
-
-        await maybe_update_tui_async(_fn)
-
-    # TODO? this func is kinda big and maybe is better at module
-    # level to reduce indentation?
-    # -[ ] just pass `daemon: WorkerDaemon` vs. `self`
-    async def maybe_serve_one(
-        self,
-        req: dict,
-    ):
-        rid = req['id']
-        logging.info(f'maybe serve request #{rid}')
-
-        # parse request
-        body = json.loads(req['body'])
-        model = body['params']['model']
-
-        # if model not known, ignore.
-        if (
-            model != 'RealESRGAN_x4plus'
-            and
-            model not in MODELS
-        ):
-            logging.warning(f'unknown model {model}!, skip...')
-            return False
-
-        # only handle whitelisted models
-        if (
-            len(self.config.model_whitelist) > 0
-            and
-            model not in self.config.model_whitelist
-        ):
-            logging.warning('model not whitelisted!, skip...')
-            return False
-
-        # if blacklist contains model skip
-        if (
-            len(self.config.model_blacklist) > 0
-            and
-            model in self.config.model_blacklist
-        ):
-            logging.warning('model not blacklisted!, skip...')
-            return False
-
-        results = [res['request_id'] for res in self._snap['results']]
-
-        # if worker already produced a result for this request
-        if rid in results:
-            logging.info(f'worker already submitted a result for request #{rid}, skip...')
-            return False
-
-        statuses = self._snap['requests'][rid]
-
-        # skip if workers in non_compete already on it
-        competitors = set((status['worker'] for status in statuses))
-        if bool(self.config.non_compete & competitors):
-            logging.info('worker in configured non_compete list already working on request, skip...')
-            return False
-
-        # resolve the ipfs hashes into the actual data behind them
-        inputs = []
-        raw_inputs = req['binary_data'].split(',')
-        if raw_inputs:
-            logging.info(f'fetching IPFS inputs: {raw_inputs}')
-
-        retry = 3
-        for _input in req['binary_data'].split(','):
-            if _input:
-                for r in range(retry):
-                    try:
-                        # user `GPUConnector` to IO with
-                        # storage layer to seed the compute
-                        # task.
-                        img = await self.conn.get_input_data(_input)
-                        inputs.append(img)
-                        logging.info(f'retrieved {_input}!')
-                        break
-
-                    except BaseException:
-                        logging.exception(
-                            f'IPFS fetch input error !?! retries left {retry - r - 1}\n'
-                        )
-
-        # compute unique request hash used on submit
-        hash_str = (
-            str(req['nonce'])
-            +
-            req['body']
-            +
-            req['binary_data']
-        )
-        logging.debug(f'hashing: {hash_str}')
-        request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
-        logging.info(f'calculated request hash: {request_hash}')
-
-        total_step = body['params']['step']
-        model = body['params']['model']
-        mode = body['method']
-
-        # TODO: validate request
-
-        resp = await self.conn.begin_work(rid)
-        if not resp or 'code' in resp:
-            logging.info('begin_work error, probably being worked on already... skip.')
-            return False
-
-        with maybe_load_model(model, mode):
-            try:
-                maybe_update_tui(lambda tui: tui.set_progress(0, done=total_step))
-
-                output_type = 'png'
-                if 'output_type' in body['params']:
-                    output_type = body['params']['output_type']
-
-                output = None
-                output_hash = None
-                match self.config.backend:
-                    case 'sync-on-thread':
-                        output_hash, output = await trio.to_thread.run_sync(
-                            partial(
-                                compute_one,
-                                rid,
-                                mode, body['params'],
-                                inputs=inputs,
-                                should_cancel=self.should_cancel_work,
-                            )
-                        )
-
-                    case _:
-                        raise DGPUComputeError(
-                            f'Unsupported backend {self.config.backend}'
-                        )
-
-                maybe_update_tui(lambda tui: tui.set_progress(total_step))
-
-                self._last_generation_ts: str = datetime.now().isoformat()
-                self._last_benchmark: list[float] = self._benchmark
-                self._benchmark: list[float] = []
-
-                ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
-
-                await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
-
-                await self._update_balance()
-
-            except BaseException as err:
-                if 'network cancel' not in str(err):
-                    logging.exception('Failed to serve model request !?\n')
-
-                if rid in self._snap['requests']:
-                    await self.conn.cancel_work(rid, 'reason not provided')
-
-            finally:
-                return True
-
-    # TODO, as per above on `.maybe_serve_one()`, it's likely a bit
-    # more *trionic* to define this all as a module level task-func
-    # which operates on a `daemon: WorkerDaemon`?
-    #
-    # -[ ] keeps tasks-as-funcs style prominent
-    # -[ ] avoids so much indentation due to methods
-    async def serve_forever(self):
-        await self._update_balance()
-        try:
-            while True:
-                queue = self._snap['queue']
-
-                random.shuffle(queue)
-                queue = sorted(
-                    queue,
-                    key=lambda req: convert_reward_to_int(req['reward']),
-                    reverse=True
-                )
-
-                for req in queue:
-                    # TODO, as mentioned above just inline this once
-                    # converted to a mod level func.
-                    if (await self.maybe_serve_one(req)):
-                        break
-
-                await trio.sleep(1)
-
-        except KeyboardInterrupt:
-            ...
+async def maybe_update_tui_balance(conn: NetConnector):
+    async def _fn(tui):
+        # update balance
+        balance = await conn.get_worker_balance()
+        tui.set_header_text(new_balance=f'balance: {balance}')
+
+    await maybe_update_tui_async(_fn)
+
+
+async def maybe_serve_one(
+    config: Config,
+    conn: NetConnector,
+    req: dict,
+):
+    rid = req['id']
+    logging.info(f'maybe serve request #{rid}')
+
+    # parse request
+    body = json.loads(req['body'])
+    model = body['params']['model']
+
+    # if model not known, ignore.
+    if (
+        model != 'RealESRGAN_x4plus'
+        and
+        model not in MODELS
+    ):
+        logging.warning(f'unknown model {model}!, skip...')
+        return
+
+    # only handle whitelisted models
+    if (
+        len(config.model_whitelist) > 0
+        and
+        model not in config.model_whitelist
+    ):
+        logging.warning('model not whitelisted!, skip...')
+        return
+
+    # if blacklist contains model skip
+    if (
+        len(config.model_blacklist) > 0
+        and
+        model in config.model_blacklist
+    ):
+        logging.warning('model not blacklisted!, skip...')
+        return
+
+    results = [res['request_id'] for res in conn._tables['results']]
+
+    # if worker already produced a result for this request
+    if rid in results:
+        logging.info(f'worker already submitted a result for request #{rid}, skip...')
+        return
+
+    statuses = conn._tables['requests'][rid]
+
+    # skip if workers in non_compete already on it
+    competitors = set((status['worker'] for status in statuses))
+    if bool(config.non_compete & competitors):
+        logging.info('worker in configured non_compete list already working on request, skip...')
+        return
+
+    # resolve the ipfs hashes into the actual data behind them
+    inputs = []
+    raw_inputs = req['binary_data'].split(',')
+    if raw_inputs:
+        logging.info(f'fetching IPFS inputs: {raw_inputs}')
+
+    retry = 3
+    for _input in req['binary_data'].split(','):
+        if _input:
+            for r in range(retry):
+                try:
+                    # user `GPUConnector` to IO with
+                    # storage layer to seed the compute
+                    # task.
+                    img = await conn.get_input_data(_input)
+                    inputs.append(img)
+                    logging.info(f'retrieved {_input}!')
+                    break
+
+                except BaseException:
+                    logging.exception(
+                        f'IPFS fetch input error !?! retries left {retry - r - 1}\n'
+                    )
+
+    # compute unique request hash used on submit
+    hash_str = (
+        str(req['nonce'])
+        +
+        req['body']
+        +
+        req['binary_data']
+    )
+    logging.debug(f'hashing: {hash_str}')
+    request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
+    logging.info(f'calculated request hash: {request_hash}')
+
+    total_step = body['params']['step']
+    model = body['params']['model']
+    mode = body['method']
+
+    # TODO: validate request
+
+    resp = await conn.begin_work(rid)
+    if not resp or 'code' in resp:
+        logging.info('begin_work error, probably being worked on already... skip.')
+        return
+
+    with maybe_load_model(model, mode):
+        try:
+            maybe_update_tui(lambda tui: tui.set_progress(0, done=total_step))
+
+            output_type = 'png'
+            if 'output_type' in body['params']:
+                output_type = body['params']['output_type']
+
+            output = None
+            output_hash = None
+            match config.backend:
+                case 'sync-on-thread':
+                    output_hash, output = await trio.to_thread.run_sync(
+                        partial(
+                            compute_one,
+                            rid,
+                            mode, body['params'],
+                            inputs=inputs,
+                            should_cancel=conn.should_cancel_work,
+                        )
+                    )
+
+                case _:
+                    raise DGPUComputeError(
+                        f'Unsupported backend {config.backend}'
+                    )
+
+            maybe_update_tui(lambda tui: tui.set_progress(total_step))
+
+            ipfs_hash = await conn.publish_on_ipfs(output, typ=output_type)
+
+            await conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
+
+            await maybe_update_tui_balance(conn)
+
+        except BaseException as err:
+            if 'network cancel' not in str(err):
+                logging.exception('Failed to serve model request !?\n')
+
+            if rid in conn._tables['requests']:
+                await conn.cancel_work(rid, 'reason not provided')
+
+
+async def serve_forever(config: Config, conn: NetConnector):
+    await maybe_update_tui_balance(conn)
+    try:
+        async for tables in conn.iter_poll_update(config.poll_time):
+            queue = tables['queue']
+
+            random.shuffle(queue)
+            queue = sorted(
+                queue,
+                key=lambda req: convert_reward_to_int(req['reward']),
+                reverse=True
+            )
+
+            if len(queue) > 0:
+                await maybe_serve_one(config, conn, queue[0])
+
+    except KeyboardInterrupt:
+        ...
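The shuffle-then-sort in the new serve_forever is a deliberate trick: random.shuffle randomizes order first, and because Python's sorted is stable, equal-reward requests keep that random order while higher rewards still float to the front, so identically-configured workers don't all grab the same request. Worked example (the reward parsing is an assumption reconstructed from the `return` shown above):

import random

def convert_reward_to_int(reward_str):
    # assumed parsing: '2.0000 GPU' -> '2' + '0000' -> 20000
    int_part, decimal_part = reward_str.split(' ')[0].split('.')
    return int(int_part + decimal_part)

queue = [
    {'id': 1, 'reward': '1.0000 GPU'},
    {'id': 2, 'reward': '2.0000 GPU'},
    {'id': 3, 'reward': '2.0000 GPU'},
]
random.shuffle(queue)  # randomize tie order first...
queue = sorted(        # ...then stable-sort by reward, descending
    queue,
    key=lambda req: convert_reward_to_int(req['reward']),
    reverse=True,
)
# ids 2 and 3 come first in random order; id 1 is always last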
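The request hash concatenates the row's nonce, body and binary payload before hashing, so any worker (or verifier) holding the same queue row derives the same digest to match a submission against. Minimal reproduction of the scheme exactly as shown in the hunk:

from hashlib import sha256

req = {
    'nonce': 0,
    'body': '{"method": "diffuse", "params": {}}',
    'binary_data': '',
}
hash_str = str(req['nonce']) + req['body'] + req['binary_data']
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
print(request_hash)  # deterministic for identical row contents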
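The non-compete check is plain set algebra: intersect the configured account set with the workers already reporting a status on the request; a non-empty intersection means back off. Sketch with toy data:

non_compete = {'worker1', 'worker3'}  # from config
statuses = [
    {'worker': 'worker3'},
    {'worker': 'worker9'},
]
competitors = set(status['worker'] for status in statuses)
if non_compete & competitors:  # truthy when any overlap exists
    print('skip: a non-compete peer is already on this request')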
@@ -3,6 +3,7 @@ import json
 import time
 import logging
 from pathlib import Path
+from typing import AsyncGenerator
 from functools import partial

 import trio
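The AsyncGenerator import backs the annotation on the new iter_poll_update below; its two type parameters are the yield type and the send type. Tiny illustration:

from typing import AsyncGenerator

async def ticks(n: int) -> AsyncGenerator[int, None]:
    # yields ints, accepts nothing via .asend()
    for i in range(n):
        yield i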
@@ -66,7 +67,11 @@ class NetConnector:
         self.ipfs_client = AsyncIPFSHTTP(config.ipfs_url)

-        self._wip_requests = {}
+        self._tables = {
+            'queue': [],
+            'requests': {},
+            'results': []
+        }

         maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.config.account))
@@ -132,9 +137,6 @@ class NetConnector:
             logging.info('no balance info found')
             return None

-    # TODO, considery making this a NON-method and instead
-    # handing in the `snap['queue']` output beforehand?
-    # -> since that call is the only usage of `self`?
     async def get_full_queue_snapshot(self):
         '''
         Keep in-sync with latest (telos chain's smart-contract) table
@@ -162,6 +164,34 @@ class NetConnector:

         return snap

+    async def iter_poll_update(self, poll_time: float) -> AsyncGenerator[dict, None]:
+        '''
+        Long running task, polls gpu contract tables, yields latest table rows
+
+        '''
+        while True:
+            start_time = time.time()
+            self._tables = await self.get_full_queue_snapshot()
+            elapsed = time.time() - start_time
+            yield self._tables
+            await trio.sleep(max(poll_time - elapsed, 0.1))
+
+    async def should_cancel_work(self, request_id: int) -> bool:
+        logging.info('should cancel work?')
+        if request_id not in self._tables['requests']:
+            logging.info(f'request #{request_id} no longer in queue, likely its been filled by another worker, cancelling work...')
+            return True
+
+        competitors = set([
+            status['worker']
+            for status in self._tables['requests'][request_id]
+            if status['worker'] != self.config.account
+        ])
+        logging.info(f'competitors: {competitors}')
+        should_cancel = bool(self.config.non_compete & competitors)
+        logging.info(f'cancel: {should_cancel}')
+        return should_cancel
+
     async def begin_work(self, request_id: int):
         '''
         Publish to the bc that the worker is beginning a model-computation
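iter_poll_update subtracts the snapshot-fetch time from the sleep so the poll cadence stays close to poll_time regardless of how slow the chain query is (with a 0.1 s floor). A runnable stand-in showing the consumer side used by serve_forever:

import time
import trio

async def iter_poll_update(poll_time: float):
    # stand-in for NetConnector.iter_poll_update: fetch, yield,
    # then sleep only the *remainder* of the poll window
    while True:
        start_time = time.time()
        tables = {'queue': [], 'requests': {}, 'results': []}  # fake snapshot
        elapsed = time.time() - start_time
        yield tables
        await trio.sleep(max(poll_time - elapsed, 0.1))

async def main():
    async for tables in iter_poll_update(0.5):
        print(f"{len(tables['queue'])} requests queued")
        break  # demo: one poll only

trio.run(main)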
@@ -244,7 +274,7 @@ class NetConnector:
         result_hash: str,
         ipfs_hash: str
     ):
-        logging.info('submit_work #{request_id}')
+        logging.info(f'submit_work #{request_id}')
         return await failable(
             partial(
                 self.cleos.a_push_action,
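The last hunk is a one-character logging fix: without the f prefix the braces are logged literally.

request_id = 42
'submit_work #{request_id}'   # -> 'submit_work #{request_id}'
f'submit_work #{request_id}'  # -> 'submit_work #42'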