skynet/skynet/dgpu/daemon.py

import json
import logging
import random
import time

from datetime import datetime
from functools import partial
from hashlib import sha256

import trio

from quart import jsonify
from quart_trio import QuartTrio as Quart

from skynet.constants import (
    MODELS,
    VERSION,
)
from skynet.dgpu.errors import (
    DGPUComputeError,
)
from skynet.dgpu.compute import ModelMngr
from skynet.dgpu.network import NetConnector
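

# NOTE: queue entries carry an asset-style reward string; assuming the usual
# 4-decimal format (eg. '1.0000 GPU'), the helper below simply concatenates
# the integer and decimal digits ('1.0000 GPU' -> 10000) so rewards can be
# compared as plain ints when sorting the work queue in `serve_forever()`.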
def convert_reward_to_int(reward_str):
    int_part, decimal_part = (
        reward_str.split('.')[0],
        reward_str.split('.')[1].split(' ')[0]
    )
    return int(int_part + decimal_part)


class WorkerDaemon:
    '''
    The root "GPU daemon".

    Contains/manages underlying subsystems:
    - a GPU compute/model manager (`ModelMngr`)
    - a network connector (`NetConnector`)

    '''
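
    # NOTE: recognized `config` keys (all optional except 'account'; see the
    # defaults assigned below): 'account', 'auto_withdraw', 'non_compete',
    # 'model_whitelist', 'model_blacklist', 'backend'.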
    def __init__(
        self,
        mm: ModelMngr,
        conn: NetConnector,
        config: dict
    ):
        self.mm: ModelMngr = mm
        self.conn: NetConnector = conn

        self.auto_withdraw = (
            config['auto_withdraw']
            if 'auto_withdraw' in config else False
        )

        self.account: str = config['account']

        self.non_compete = set()
        if 'non_compete' in config:
            self.non_compete = set(config['non_compete'])

        self.model_whitelist = set()
        if 'model_whitelist' in config:
            self.model_whitelist = set(config['model_whitelist'])

        self.model_blacklist = set()
        if 'model_blacklist' in config:
            self.model_blacklist = set(config['model_blacklist'])

        self.backend = 'sync-on-thread'
        if 'backend' in config:
            self.backend = config['backend']

        self._snap = {
            'queue': [],
            'requests': {},
            'results': []
        }

        self._benchmark: list[float] = []
        self._last_benchmark: list[float]|None = None
        self._last_generation_ts: str|None = None
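
    # `._snap` mirrors the network-side queue state ('queue', 'requests',
    # 'results') and is refreshed by `snap_updater_task()`; `._benchmark`
    # collects per-iteration timestamps for the in-flight job and is rotated
    # into `._last_benchmark` when a generation finishes.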

    def _get_benchmark_speed(self) -> float:
        '''
        Return the (arithmetic) average work-iterations-per-second
        conducted by this compute worker.

        '''
        if not self._last_benchmark:
            return 0

        start = self._last_benchmark[0]
        end = self._last_benchmark[-1]

        elapsed = end - start
        if not elapsed:
            # a single sample gives no measurable interval
            return 0

        its = len(self._last_benchmark)
        speed = its / elapsed

        logging.info(f'{elapsed} s total its: {its}, at {speed} it/s ')

        return speed
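
    # NOTE: installed as `mm._should_cancel` in `maybe_serve_one()`, so the
    # compute backend can poll it between iterations; each call doubles as a
    # benchmark tick and it answers True as soon as any worker from the
    # `non_compete` set shows up on the same request's status list.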
    async def should_cancel_work(self, request_id: int):
        self._benchmark.append(time.time())
        competitors = set([
            status['worker']
            for status in self._snap['requests'][request_id]
            if status['worker'] != self.account
        ])
        return bool(self.non_compete & competitors)

    async def snap_updater_task(self):
        '''
        Busy loop that updates the local `._snap: dict` table from the
        network via `NetConnector.get_full_queue_snapshot()` once per second.
        '''
        while True:
            self._snap = await self.conn.get_full_queue_snapshot()
            await trio.sleep(1)

    # TODO, design suggestion, just make this a lazily accessed
    # `@class_property` if we're 3.12+
    # |_ https://docs.python.org/3/library/functools.html#functools.cached_property
    async def generate_api(self) -> Quart:
        '''
        Generate a `Quart`-compatible web API which (for now) simply
        serves a small monitoring endpoint that reports:

        - ISO time-stamp of the last served model-output
        - the worker's average "compute-iterations-per-second"

        '''
        app = Quart(__name__)

        @app.route('/')
        async def health():
            return jsonify(
                account=self.account,
                version=VERSION,
                last_generation_ts=self._last_generation_ts,
                last_generation_speed=self._get_benchmark_speed()
            )

        return app
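
    # Usage sketch (assumption: callers serve the app with hypercorn's trio
    # worker alongside the daemon tasks; any trio-compatible ASGI server
    # would do):
    #
    #   from hypercorn.config import Config
    #   from hypercorn.trio import serve
    #
    #   api = await daemon.generate_api()
    #   async with trio.open_nursery() as n:
    #       n.start_soon(serve, api, Config())
    #       n.start_soon(daemon.snap_updater_task)
    #       n.start_soon(daemon.serve_forever)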

    # TODO? this func is kinda big and maybe is better at module
    # level to reduce indentation?
    # -[ ] just pass `daemon: WorkerDaemon` vs. `self`
    async def maybe_serve_one(
        self,
        req: dict,
    ):
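        # Overall flow (sketch): filter the request by model, skip it if
        # results already exist or any worker has posted a status for it,
        # otherwise fetch any binary inputs, hash the request, claim it via
        # `begin_work()`, run the model, publish the output to IPFS and
        # `submit_work()` the result; returns True once this worker actually
        # attempted the request.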
        rid = req['id']

        # parse request
        body = json.loads(req['body'])
        model = body['params']['model']

        # if model not known, ignore.
        if (
            model != 'RealESRGAN_x4plus'
            and
            model not in MODELS
        ):
            logging.warning(f'Unknown model {model}')
            return False

        # only handle whitelisted models
        if (
            len(self.model_whitelist) > 0
            and
            model not in self.model_whitelist
        ):
            return False

        # if blacklist contains model skip
        if model in self.model_blacklist:
            return False

        results = [res['id'] for res in self._snap['results']]
        if (
            rid not in results
            and
            rid in self._snap['requests']
        ):
            statuses = self._snap['requests'][rid]
            if len(statuses) == 0:
                inputs = []
                for _input in req['binary_data'].split(','):
                    if _input:
                        for _ in range(3):
                            try:
                                # use the `NetConnector` to do IO with
                                # the storage layer to seed the compute
                                # task.
                                img = await self.conn.get_input_data(_input)
                                inputs.append(img)
                                break
                            except BaseException:
                                logging.exception(
                                    'Model input error !?!\n'
                                )

                hash_str = (
                    str(req['nonce'])
                    +
                    req['body']
                    +
                    req['binary_data']
                )
                logging.info(f'hashing: {hash_str}')
                request_hash = sha256(hash_str.encode('utf-8')).hexdigest()

                # TODO: validate request

                # perform work
                logging.info(f'working on {body}')

                resp = await self.conn.begin_work(rid)
                if not resp or 'code' in resp:
                    logging.info('probably being worked on already... skip.')
                else:
                    try:
                        output_type = 'png'
                        if 'output_type' in body['params']:
                            output_type = body['params']['output_type']

                        output = None
                        output_hash = None
                        match self.backend:
                            case 'sync-on-thread':
                                self.mm._should_cancel = self.should_cancel_work
                                output_hash, output = await trio.to_thread.run_sync(
                                    partial(
                                        self.mm.compute_one,
                                        rid,
                                        body['method'], body['params'],
                                        inputs=inputs
                                    )
                                )

                            case _:
                                raise DGPUComputeError(
                                    f'Unsupported backend {self.backend}'
                                )
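
                        # NOTE: the blocking GPU call runs on a worker thread
                        # via `trio.to_thread.run_sync()` so the trio
                        # scheduler stays responsive; setting
                        # `mm._should_cancel` above is (presumably) how the
                        # compute side checks back for an early abort.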

                        self._last_generation_ts: str = datetime.now().isoformat()
                        self._last_benchmark: list[float] = self._benchmark
                        self._benchmark: list[float] = []

                        ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)

                        await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)

                    except BaseException as err:
                        logging.exception('Failed to serve model request !?\n')
                        # traceback.print_exc()  # TODO? <- replaced by above ya?
                        await self.conn.cancel_work(rid, str(err))

                    finally:
                        return True

        # TODO, i would inverse this case logic to avoid an indent
        # level in above block ;)
        else:
            logging.info(f'request {rid} already being worked on, skip...')

    # TODO, as per above on `.maybe_serve_one()`, it's likely a bit
    # more *trionic* to define this all as a module level task-func
    # which operates on a `daemon: WorkerDaemon`?
    #
    # -[ ] keeps tasks-as-funcs style prominent
    # -[ ] avoids so much indentation due to methods
    async def serve_forever(self):
        try:
            while True:
                if self.auto_withdraw:
                    await self.conn.maybe_withdraw_all()

                queue = self._snap['queue']

                random.shuffle(queue)
                queue = sorted(
                    queue,
                    key=lambda req: convert_reward_to_int(req['reward']),
                    reverse=True
                )
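
                # NOTE: python's `sorted()` is stable, so shuffling first
                # randomizes the order among equal-reward requests while the
                # sort still puts higher-paying work first.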

                for req in queue:
                    # TODO, as mentioned above just inline this once
                    # converted to a mod level func.
                    if (await self.maybe_serve_one(req)):
                        break

                await trio.sleep(1)

        except KeyboardInterrupt:
            ...