skynet/skynet/dgpu/daemon.py

import json
import logging
import random
import time
from datetime import datetime
from functools import partial
from hashlib import sha256

import trio
from quart import jsonify
from quart_trio import QuartTrio as Quart

from skynet.constants import (
    MODELS,
    VERSION,
)
from skynet.dgpu.errors import (
    DGPUComputeError,
)
from skynet.dgpu.tui import WorkerMonitor
from skynet.dgpu.compute import maybe_load_model, compute_one
from skynet.dgpu.network import NetConnector


def convert_reward_to_int(reward_str):
    int_part, decimal_part = (
        reward_str.split('.')[0],
        reward_str.split('.')[1].split(' ')[0]
    )
    return int(int_part + decimal_part)
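
# Illustrative usage only: assuming the '<int>.<decimals> <symbol>' reward
# string format implied above (the '1.0000 GPU' style values are made up),
# the digits are simply concatenated into a fixed-point integer,
#
#   convert_reward_to_int('1.0000 GPU')   # -> 10000
#   convert_reward_to_int('0.0500 GPU')   # -> 500
#
# which is enough for the reward ordering done in `serve_forever()` below.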


class WorkerDaemon:
    '''
    The root "GPU daemon".

    Contains/manages underlying subsystems:
    - a GPU connector
    '''
    def __init__(
        self,
        conn: NetConnector,
        config: dict,
        tui: WorkerMonitor | None = None
    ):
        self.conn: NetConnector = conn
        self._tui = tui
        self.auto_withdraw = (
            config['auto_withdraw']
            if 'auto_withdraw' in config else False
        )

        self.account: str = config['account']

        self.non_compete = set()
        if 'non_compete' in config:
            self.non_compete = set(config['non_compete'])

        self.model_whitelist = set()
        if 'model_whitelist' in config:
            self.model_whitelist = set(config['model_whitelist'])

        self.model_blacklist = set()
        if 'model_blacklist' in config:
            self.model_blacklist = set(config['model_blacklist'])

        self.backend = 'sync-on-thread'
        if 'backend' in config:
            self.backend = config['backend']

        self._snap = {
            'queue': [],
            'requests': {},
            'results': []
        }

        self._benchmark: list[float] = []
        self._last_benchmark: list[float]|None = None
        self._last_generation_ts: str|None = None
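
    # NOTE: illustrative `config` shape only, keys taken from `__init__`
    # above, values are made up:
    #
    #   config = {
    #       'account': 'testworkeracct',      # required, worker account name
    #       'auto_withdraw': True,            # defaults to False
    #       'non_compete': ['some_account'],  # optional, accounts not to compete with
    #       'model_whitelist': [],            # optional, only serve these models
    #       'model_blacklist': [],            # optional, never serve these models
    #       'backend': 'sync-on-thread',      # optional, this is the default
    #   }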

    def _get_benchmark_speed(self) -> float:
        '''
        Return the (arithmetic) average work-iterations-per-second
        conducted by this compute worker.
        '''
        if not self._last_benchmark:
            return 0

        start = self._last_benchmark[0]
        end = self._last_benchmark[-1]

        elapsed = end - start
        its = len(self._last_benchmark)

        # avoid division by zero when only a single sample was recorded
        if elapsed == 0:
            return 0

        speed = its / elapsed

        logging.info(f'{elapsed} s total its: {its}, at {speed} it/s')

        return speed

    async def should_cancel_work(self, request_id: int):
        self._benchmark.append(time.time())
        logging.info('should cancel work?')

        if request_id not in self._snap['requests']:
            logging.info(
                f'request #{request_id} no longer in queue, '
                'it was likely filled by another worker, cancelling work...'
            )
            return True

        competitors = set([
            status['worker']
            for status in self._snap['requests'][request_id]
            if status['worker'] != self.account
        ])
        logging.info(f'competitors: {competitors}')

        should_cancel = bool(self.non_compete & competitors)
        logging.info(f'cancel: {should_cancel}')

        return should_cancel

    async def snap_updater_task(self):
        '''
        Busy loop that refreshes the local `._snap: dict` table from the
        network connector's full queue snapshot.
        '''
        while True:
            self._snap = await self.conn.get_full_queue_snapshot()
            await trio.sleep(1)

    # TODO, design suggestion, just make this a lazily accessed
    # `@class_property` if we're 3.12+
    # |_ https://docs.python.org/3/library/functools.html#functools.cached_property
    async def generate_api(self) -> Quart:
        '''
        Generate a `Quart`-compatible web API spec which (for now) simply
        serves a small monitoring endpoint that reports,

        - ISO time-stamp of the last served model-output
        - the worker's average "compute-iterations-per-second"
        '''
        app = Quart(__name__)

        @app.route('/')
        async def health():
            return jsonify(
                account=self.account,
                version=VERSION,
                last_generation_ts=self._last_generation_ts,
                last_generation_speed=self._get_benchmark_speed()
            )

        return app
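
    # Example `GET /` response shape (fields mirror the `jsonify()` call
    # above, the concrete values here are illustrative only):
    #
    #   {
    #       "account": "testworkeracct",
    #       "version": "<VERSION>",
    #       "last_generation_ts": "2025-01-01T00:00:00.000000",
    #       "last_generation_speed": 2.5
    #   }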

    async def _update_balance(self):
        if self._tui:
            # update balance
            balance = await self.conn.get_worker_balance()
            self._tui.set_header_text(new_balance=f'balance: {balance}')

    # TODO? this func is kinda big and maybe is better at module
    # level to reduce indentation?
    # -[ ] just pass `daemon: WorkerDaemon` vs. `self`
    async def maybe_serve_one(
        self,
        req: dict,
    ):
        rid = req['id']
        logging.info(f'maybe serve request #{rid}')

        # parse request
        body = json.loads(req['body'])
        model = body['params']['model']

        # if model not known, ignore.
        if (
            model != 'RealESRGAN_x4plus'
            and
            model not in MODELS
        ):
            logging.warning(f'unknown model {model}!, skip...')
            return False

        # only handle whitelisted models
        if (
            len(self.model_whitelist) > 0
            and
            model not in self.model_whitelist
        ):
            logging.warning('model not whitelisted!, skip...')
            return False

        # if blacklist contains model skip
        if model in self.model_blacklist:
            logging.warning('model is blacklisted!, skip...')
            return False

        results = [res['request_id'] for res in self._snap['results']]

        # if worker already produced a result for this request
        if rid in results:
            logging.info(f'worker already submitted a result for request #{rid}, skip...')
            return False

        statuses = self._snap['requests'][rid]

        # skip if workers in non_compete already on it
        competitors = set((status['worker'] for status in statuses))
        if bool(self.non_compete & competitors):
            logging.info('worker in configured non_compete list already working on request, skip...')
            return False

        # resolve the ipfs hashes into the actual data behind them
        inputs = []
        raw_inputs = req['binary_data'].split(',')
        if raw_inputs:
            logging.info(f'fetching IPFS inputs: {raw_inputs}')

        retry = 3
        for _input in req['binary_data'].split(','):
            if _input:
                for r in range(retry):
                    try:
                        # use the `NetConnector` to IO with the
                        # storage layer to seed the compute
                        # task.
                        img = await self.conn.get_input_data(_input)
                        inputs.append(img)
                        logging.info(f'retrieved {_input}!')
                        break

                    except BaseException:
                        logging.exception(
                            f'IPFS fetch input error !?! retries left {retry - r - 1}\n'
                        )

        # compute unique request hash used on submit
        hash_str = (
            str(req['nonce'])
            +
            req['body']
            +
            req['binary_data']
        )
        logging.debug(f'hashing: {hash_str}')
        request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
        logging.info(f'calculated request hash: {request_hash}')

        total_step = body['params']['step']
        model = body['params']['model']
        mode = body['method']

        # TODO: validate request

        resp = await self.conn.begin_work(rid)
        if not resp or 'code' in resp:
            logging.info('begin_work error, probably being worked on already... skip.')
            return False

        with maybe_load_model(model, mode):
            try:
                if self._tui:
                    self._tui.set_progress(0, done=total_step)

                output_type = 'png'
                if 'output_type' in body['params']:
                    output_type = body['params']['output_type']

                output = None
                output_hash = None
                match self.backend:
                    case 'sync-on-thread':
                        output_hash, output = await trio.to_thread.run_sync(
                            partial(
                                compute_one,
                                rid,
                                mode, body['params'],
                                inputs=inputs,
                                should_cancel=self.should_cancel_work,
                                tui=self._tui
                            )
                        )

                    case _:
                        raise DGPUComputeError(
                            f'Unsupported backend {self.backend}'
                        )

                if self._tui:
                    self._tui.set_progress(total_step)

                self._last_generation_ts: str = datetime.now().isoformat()
                self._last_benchmark: list[float] = self._benchmark
                self._benchmark: list[float] = []

                ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)

                await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)

                await self._update_balance()

            except BaseException as err:
                if 'network cancel' not in str(err):
                    logging.exception('Failed to serve model request !?\n')

                if rid in self._snap['requests']:
                    await self.conn.cancel_work(rid, 'reason not provided')

            finally:
                return True

    # TODO, as per above on `.maybe_serve_one()`, it's likely a bit
    # more *trionic* to define this all as a module level task-func
    # which operates on a `daemon: WorkerDaemon`?
    #
    # -[ ] keeps tasks-as-funcs style prominent
    # -[ ] avoids so much indentation due to methods
    async def serve_forever(self):
        await self._update_balance()
        try:
            while True:
                if self.auto_withdraw:
                    await self.conn.maybe_withdraw_all()

                queue = self._snap['queue']

                random.shuffle(queue)
                queue = sorted(
                    queue,
                    key=lambda req: convert_reward_to_int(req['reward']),
                    reverse=True
                )

                for req in queue:
                    # TODO, as mentioned above just inline this once
                    # converted to a mod level func.
                    if (await self.maybe_serve_one(req)):
                        break

                await trio.sleep(1)

        except KeyboardInterrupt:
            ...
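

# Rough wiring sketch (not part of this module): how the daemon's two main
# tasks are typically run together under trio. The `config` contents and the
# `NetConnector(config)` constructor call are assumptions for illustration.
#
#   async def main():
#       conn = NetConnector(config)
#       daemon = WorkerDaemon(conn, config)
#       async with trio.open_nursery() as n:
#           n.start_soon(daemon.snap_updater_task)
#           n.start_soon(daemon.serve_forever)
#
#   trio.run(main)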