# skynet/skynet/dgpu/daemon.py
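'''
Worker-side daemon: polls the shared request queue through a
`NetConnector`, runs compute jobs (via `skynet.dgpu.compute`) and
exposes a small Quart health/monitoring API.
'''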

import json
import logging
import random
import time
from datetime import datetime
from functools import partial
from hashlib import sha256
import trio
from quart import jsonify
from quart_trio import QuartTrio as Quart
from skynet.constants import (
MODELS,
VERSION,
)
from skynet.dgpu.errors import (
DGPUComputeError,
)
from skynet.dgpu.tui import WorkerMonitor
from skynet.dgpu.compute import maybe_load_model, compute_one
from skynet.dgpu.network import NetConnector
def convert_reward_to_int(reward_str):
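    '''
    Convert an asset-amount string of the form '<int>.<decimals> <symbol>'
    into a plain integer of its smallest units, e.g. (illustrative value,
    symbol depends on the chain config) '1.0000 GPU' -> 10000.
    '''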
int_part, decimal_part = (
reward_str.split('.')[0],
reward_str.split('.')[1].split(' ')[0]
)
return int(int_part + decimal_part)
class WorkerDaemon:
    '''
    The root "GPU daemon".

    Contains/manages underlying subsystems:
    - a GPU connector (the `NetConnector`)
    '''
def __init__(
self,
conn: NetConnector,
config: dict,
tui: WorkerMonitor | None = None
):
self.conn: NetConnector = conn
self._tui = tui
self.auto_withdraw = (
config['auto_withdraw']
if 'auto_withdraw' in config else False
)
self.account: str = config['account']
self.non_compete = set()
if 'non_compete' in config:
self.non_compete = set(config['non_compete'])
self.model_whitelist = set()
if 'model_whitelist' in config:
self.model_whitelist = set(config['model_whitelist'])
self.model_blacklist = set()
if 'model_blacklist' in config:
self.model_blacklist = set(config['model_blacklist'])
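        # compute backend; only 'sync-on-thread' (the blocking compute call
        # run in a trio worker thread) is currently handled in
        # `maybe_serve_one()`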
self.backend = 'sync-on-thread'
if 'backend' in config:
self.backend = config['backend']
self._snap = {
'queue': [],
'requests': {},
'results': []
}
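        # per-iteration timestamps for the job currently being computed;
        # rotated into `_last_benchmark` when a generation finishes so the
        # health endpoint can report an average it/s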
self._benchmark: list[float] = []
self._last_benchmark: list[float]|None = None
self._last_generation_ts: str|None = None
def _get_benchmark_speed(self) -> float:
'''
        Return the (arithmetic) average work-iterations-per-second
        conducted by this compute worker over the last completed job.
        '''
        if not self._last_benchmark or len(self._last_benchmark) < 2:
            # not enough samples to compute a rate (avoids division by zero)
            return 0
start = self._last_benchmark[0]
end = self._last_benchmark[-1]
elapsed = end - start
its = len(self._last_benchmark)
speed = its / elapsed
logging.info(f'{elapsed} s total its: {its}, at {speed} it/s ')
return speed
async def should_cancel_work(self, request_id: int):
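        '''
        Callback handed to the compute backend and polled between work
        iterations; it also records an iteration timestamp for benchmarking.

        Returns True if the request is no longer in the queue (likely
        already filled by another worker) or if a worker from the
        configured `non_compete` set has picked it up.
        '''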
self._benchmark.append(time.time())
logging.info('should cancel work?')
if request_id not in self._snap['requests']:
            logging.info(
                f'request #{request_id} no longer in queue, '
                'likely it has been filled by another worker, cancelling work...'
            )
return True
competitors = set([
status['worker']
for status in self._snap['requests'][request_id]
if status['worker'] != self.account
])
logging.info(f'competitors: {competitors}')
should_cancel = bool(self.non_compete & competitors)
logging.info(f'cancel: {should_cancel}')
return should_cancel
async def snap_updater_task(self):
        '''
        Busy loop that continuously refreshes the local `._snap: dict`
        table from the connector's full queue snapshot.
        '''
while True:
self._snap = await self.conn.get_full_queue_snapshot()
await trio.sleep(1)
# TODO, design suggestion, just make this a lazily accessed
# `@class_property` if we're 3.12+
# |_ https://docs.python.org/3/library/functools.html#functools.cached_property
async def generate_api(self) -> Quart:
        '''
        Generate a `Quart`-compatible web API which (for now) simply
        serves a small monitoring endpoint that reports,
        - iso-time-stamp of the last served model-output
        - the worker's average "compute-iterations-per-second"
        '''
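        # Example JSON returned by the `/` health endpoint
        # (field values are illustrative):
        #   {
        #     "account": "<worker account>",
        #     "version": "<skynet VERSION>",
        #     "last_generation_ts": "2025-01-01T00:00:00.000000",
        #     "last_generation_speed": 1.5
        #   }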
app = Quart(__name__)
@app.route('/')
async def health():
return jsonify(
account=self.account,
version=VERSION,
last_generation_ts=self._last_generation_ts,
last_generation_speed=self._get_benchmark_speed()
)
return app
async def _update_balance(self):
if self._tui:
# update balance
balance = await self.conn.get_worker_balance()
self._tui.set_header_text(new_balance=f'balance: {balance}')
# TODO? this func is kinda big and maybe is better at module
# level to reduce indentation?
# -[ ] just pass `daemon: WorkerDaemon` vs. `self`
async def maybe_serve_one(
self,
req: dict,
):
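        '''
        Attempt to serve a single queued request.

        Returns False when the request is skipped (unknown or filtered
        model, result already submitted, non-compete worker on it, or
        `begin_work` failed); returns True once work was actually attempted.
        '''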
rid = req['id']
logging.info(f'maybe serve request #{rid}')
# parse request
body = json.loads(req['body'])
model = body['params']['model']
# if model not known, ignore.
if (
model != 'RealESRGAN_x4plus'
and
model not in MODELS
):
logging.warning(f'unknown model {model}!, skip...')
return False
# only handle whitelisted models
if (
len(self.model_whitelist) > 0
and
model not in self.model_whitelist
):
logging.warning('model not whitelisted!, skip...')
return False
# if blacklist contains model skip
if model in self.model_blacklist:
            logging.warning('model is blacklisted!, skip...')
return False
results = [res['request_id'] for res in self._snap['results']]
# if worker already produced a result for this request
if rid in results:
logging.info(f'worker already submitted a result for request #{rid}, skip...')
return False
statuses = self._snap['requests'][rid]
# skip if workers in non_compete already on it
competitors = set((status['worker'] for status in statuses))
if bool(self.non_compete & competitors):
logging.info('worker in configured non_compete list already working on request, skip...')
return False
        # resolve the ipfs hashes into the actual data behind them
        inputs = []
        raw_inputs = req['binary_data'].split(',')
        if raw_inputs:
            logging.info(f'fetching IPFS inputs: {raw_inputs}')

            retry = 3
            for _input in raw_inputs:
                if _input:
                    for r in range(retry):
                        try:
                            # use the connector to do IO with the
                            # storage layer and seed the compute
                            # task.
                            img = await self.conn.get_input_data(_input)
                            inputs.append(img)
                            logging.info(f'retrieved {_input}!')
                            break

                        except BaseException:
                            logging.exception(
                                f'IPFS fetch input error !?! retries left {retry - r - 1}\n'
                            )
# compute unique request hash used on submit
hash_str = (
str(req['nonce'])
+
req['body']
+
req['binary_data']
)
logging.debug(f'hashing: {hash_str}')
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
logging.info(f'calculated request hash: {request_hash}')
total_step = body['params']['step']
model = body['params']['model']
mode = body['method']
# TODO: validate request
resp = await self.conn.begin_work(rid)
if not resp or 'code' in resp:
logging.info('begin_work error, probably being worked on already... skip.')
return False
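        # we only get here if `begin_work` succeeded, i.e. this worker has
        # claimed the request; (maybe) load the model and run the compute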
with maybe_load_model(model, mode):
try:
if self._tui:
self._tui.set_progress(0, done=total_step)
output_type = 'png'
if 'output_type' in body['params']:
output_type = body['params']['output_type']
output = None
output_hash = None
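                # dispatch to the configured compute backend; only
                # 'sync-on-thread' (run the blocking `compute_one` call in
                # a trio worker thread) is currently supported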
match self.backend:
case 'sync-on-thread':
output_hash, output = await trio.to_thread.run_sync(
partial(
compute_one,
rid,
mode, body['params'],
inputs=inputs,
should_cancel=self.should_cancel_work,
tui=self._tui
)
)
case _:
raise DGPUComputeError(
f'Unsupported backend {self.backend}'
)
if self._tui:
self._tui.set_progress(total_step)
self._last_generation_ts: str = datetime.now().isoformat()
self._last_benchmark: list[float] = self._benchmark
self._benchmark: list[float] = []
ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
await self._update_balance()
except BaseException as err:
if 'network cancel' not in str(err):
logging.exception('Failed to serve model request !?\n')
if rid in self._snap['requests']:
await self.conn.cancel_work(rid, 'reason not provided')
finally:
return True
# TODO, as per above on `.maybe_serve_one()`, it's likely a bit
# more *trionic* to define this all as a module level task-func
# which operates on a `daemon: WorkerDaemon`?
#
# -[ ] keeps tasks-as-funcs style prominent
# -[ ] avoids so much indentation due to methods
async def serve_forever(self):
await self._update_balance()
try:
while True:
if self.auto_withdraw:
await self.conn.maybe_withdraw_all()
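                # scan the queue highest-reward first; the shuffle before
                # the (stable) sort randomizes ordering among requests with
                # equal reward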
queue = self._snap['queue']
random.shuffle(queue)
queue = sorted(
queue,
key=lambda req: convert_reward_to_int(req['reward']),
reverse=True
)
for req in queue:
# TODO, as mentioned above just inline this once
# converted to a mod level func.
if (await self.maybe_serve_one(req)):
break
await trio.sleep(1)
except KeyboardInterrupt:
...