mirror of https://github.com/skygpu/skynet.git
Suggest `skynet.dgpu` docs, typing, pythonisms
From the deep-ish dive drafting our first set of design/architecture diagrams in https://github.com/skygpu/cyberdyne/pull/2, this adds a buncha suggestions, typing, and styling adjustments. Namely the code tweaks include, - changing to multi-line import tuples where appropriate (since they're much handier to modify ;) - adding typing in many spots where it wasn't clear to me the types being returned/operated-with in various (internal) methods. - doc strings (in mostly random spots Xp ) where i had the need to remember the impl's purpose but didn't want to re-read the code in detail again. - ALOT of TODOs surrounding various potential style changes, re-factorings, naming and in some cases "modernization" according to the latest python3.12 feats/spec/stdlib.pull/48/merge
parent
c0ac6298a9
commit
7cb9f09d95
|
@ -14,8 +14,8 @@ from skynet.dgpu.network import SkynetGPUConnector
|
|||
async def open_dgpu_node(config: dict) -> None:
|
||||
'''
|
||||
Open a top level "GPU mgmt daemon", keep the
|
||||
`SkynetDGPUDaemon._snap: dict[str, list|dict]` table and *maybe*
|
||||
serve a `hypercorn` web API.
|
||||
`SkynetDGPUDaemon._snap: dict[str, list|dict]` table
|
||||
and *maybe* serve a `hypercorn` web API.
|
||||
|
||||
'''
|
||||
conn = SkynetGPUConnector(config)
|
||||
|
@ -32,6 +32,8 @@ async def open_dgpu_node(config: dict) -> None:
|
|||
async with trio.open_nursery() as tn:
|
||||
tn.start_soon(daemon.snap_updater_task)
|
||||
|
||||
# TODO, consider a more explicit `as hypercorn_serve`
|
||||
# to clarify?
|
||||
if api:
|
||||
tn.start_soon(serve, api, api_conf)
|
||||
|
||||
|
|
|
@ -1,20 +1,36 @@
|
|||
#!/usr/bin/python
|
||||
# ^TODO? again, why..
|
||||
#
|
||||
# Do we expect this mod
|
||||
# to be invoked? if so why is there no
|
||||
# `if __name__ == '__main__'` guard?
|
||||
#
|
||||
# if anything this should contain a license header ;)
|
||||
|
||||
# Skynet Memory Manager
|
||||
'''
|
||||
Skynet Memory Manager
|
||||
|
||||
'''
|
||||
|
||||
import gc
|
||||
import logging
|
||||
|
||||
from hashlib import sha256
|
||||
import zipfile
|
||||
from PIL import Image
|
||||
from diffusers import DiffusionPipeline
|
||||
# import zipfile
|
||||
# from PIL import Image
|
||||
# from diffusers import DiffusionPipeline
|
||||
|
||||
import trio
|
||||
import torch
|
||||
|
||||
from skynet.constants import DEFAULT_INITAL_MODEL, MODELS
|
||||
from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled
|
||||
# from skynet.constants import (
|
||||
# DEFAULT_INITAL_MODEL,
|
||||
# MODELS,
|
||||
# )
|
||||
from skynet.dgpu.errors import (
|
||||
DGPUComputeError,
|
||||
DGPUInferenceCancelled,
|
||||
)
|
||||
|
||||
from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
|
||||
|
||||
|
@ -66,15 +82,20 @@ def prepare_params_for_diffuse(
|
|||
)
|
||||
|
||||
|
||||
# TODO, yet again - drop the redundant prefix ;)
|
||||
class SkynetMM:
|
||||
'''
|
||||
(AI algo) Model manager for loading models, computing outputs,
|
||||
checking load state, and unloading when no-longer-needed/finished.
|
||||
|
||||
'''
|
||||
def __init__(self, config: dict):
|
||||
self.cache_dir = None
|
||||
if 'hf_home' in config:
|
||||
self.cache_dir = config['hf_home']
|
||||
|
||||
self._model_name = ''
|
||||
self._model_mode = ''
|
||||
self._model_name: str = ''
|
||||
self._model_mode: str = ''
|
||||
|
||||
# self.load_model(DEFAULT_INITAL_MODEL, 'txt2img')
|
||||
|
||||
|
@ -89,7 +110,7 @@ class SkynetMM:
|
|||
|
||||
return False
|
||||
|
||||
def unload_model(self):
|
||||
def unload_model(self) -> None:
|
||||
if getattr(self, '_model', None):
|
||||
del self._model
|
||||
|
||||
|
@ -103,7 +124,7 @@ class SkynetMM:
|
|||
self,
|
||||
name: str,
|
||||
mode: str
|
||||
):
|
||||
) -> None:
|
||||
logging.info(f'loading model {name}...')
|
||||
self.unload_model()
|
||||
self._model = pipeline_for(
|
||||
|
@ -111,7 +132,6 @@ class SkynetMM:
|
|||
self._model_mode = mode
|
||||
self._model_name = name
|
||||
|
||||
|
||||
def compute_one(
|
||||
self,
|
||||
request_id: int,
|
||||
|
@ -124,6 +144,9 @@ class SkynetMM:
|
|||
should_raise = trio.from_thread.run(self._should_cancel, request_id)
|
||||
if should_raise:
|
||||
logging.warn(f'cancelling work at step {step}')
|
||||
|
||||
# ?TODO, this is never caught, so why is it
|
||||
# raised specially?
|
||||
raise DGPUInferenceCancelled()
|
||||
|
||||
return {}
|
||||
|
@ -199,9 +222,10 @@ class SkynetMM:
|
|||
case _:
|
||||
raise DGPUComputeError('Unsupported compute method')
|
||||
|
||||
except BaseException as e:
|
||||
logging.error(e)
|
||||
raise DGPUComputeError(str(e))
|
||||
except BaseException as err:
|
||||
logging.error(err)
|
||||
# to see the src exc in tb
|
||||
raise DGPUComputeError(str(err)) from err
|
||||
|
||||
finally:
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
@ -1,23 +1,25 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import json
|
||||
import random
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from hashlib import sha256
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from hashlib import sha256
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
# import traceback
|
||||
import time
|
||||
|
||||
import trio
|
||||
|
||||
from quart import jsonify
|
||||
from quart_trio import QuartTrio as Quart
|
||||
|
||||
from skynet.constants import MODELS, VERSION
|
||||
|
||||
from skynet.dgpu.errors import *
|
||||
from skynet.constants import (
|
||||
MODELS,
|
||||
VERSION,
|
||||
)
|
||||
from skynet.dgpu.errors import (
|
||||
DGPUComputeError,
|
||||
)
|
||||
from skynet.dgpu.compute import SkynetMM
|
||||
from skynet.dgpu.network import SkynetGPUConnector
|
||||
|
||||
|
@ -30,22 +32,29 @@ def convert_reward_to_int(reward_str):
|
|||
return int(int_part + decimal_part)
|
||||
|
||||
|
||||
# prolly don't need the `Skynet` prefix since that's kinda implied ;p
|
||||
class SkynetDGPUDaemon:
|
||||
'''
|
||||
The root "GPU daemon".
|
||||
|
||||
Contains/manages underlying subsystems:
|
||||
- a GPU connector
|
||||
|
||||
'''
|
||||
def __init__(
|
||||
self,
|
||||
mm: SkynetMM,
|
||||
conn: SkynetGPUConnector,
|
||||
config: dict
|
||||
):
|
||||
self.mm = mm
|
||||
self.conn = conn
|
||||
self.mm: SkynetMM = mm
|
||||
self.conn: SkynetGPUConnector = conn
|
||||
self.auto_withdraw = (
|
||||
config['auto_withdraw']
|
||||
if 'auto_withdraw' in config else False
|
||||
)
|
||||
|
||||
self.account = config['account']
|
||||
self.account: str = config['account']
|
||||
|
||||
self.non_compete = set()
|
||||
if 'non_compete' in config:
|
||||
|
@ -67,13 +76,20 @@ class SkynetDGPUDaemon:
|
|||
'queue': [],
|
||||
'requests': {},
|
||||
'my_results': []
|
||||
# ^and here i thot they were **my** results..
|
||||
# :sadcat:
|
||||
}
|
||||
|
||||
self._benchmark = []
|
||||
self._last_benchmark = None
|
||||
self._last_generation_ts = None
|
||||
self._benchmark: list[float] = []
|
||||
self._last_benchmark: list[float]|None = None
|
||||
self._last_generation_ts: str|None = None
|
||||
|
||||
def _get_benchmark_speed(self) -> float:
|
||||
'''
|
||||
Return the (arithmetic) average work-iterations-per-second
|
||||
conducted by this compute worker.
|
||||
|
||||
'''
|
||||
if not self._last_benchmark:
|
||||
return 0
|
||||
|
||||
|
@ -99,11 +115,26 @@ class SkynetDGPUDaemon:
|
|||
|
||||
|
||||
async def snap_updater_task(self):
|
||||
'''
|
||||
Busy-loop update the local `._snap: dict` table from `conn.get_full_queue_snapshot()`.
|
||||
|
||||
'''
|
||||
while True:
|
||||
self._snap = await self.conn.get_full_queue_snapshot()
|
||||
await trio.sleep(1)
|
||||
|
||||
async def generate_api(self):
|
||||
# TODO, design suggestion, just make this a lazily accessed
|
||||
# `@class_property` if we're 3.12+
|
||||
# |_ https://docs.python.org/3/library/functools.html#functools.cached_property
|
||||
async def generate_api(self) -> Quart:
|
||||
'''
|
||||
Gen a `Quart`-compat web API spec which (for now) simply
|
||||
serves a small monitoring ep that reports,
|
||||
|
||||
- iso-time-stamp of the last served model-output
|
||||
- the worker's average "compute-iterations-per-second"
|
||||
|
||||
'''
|
||||
app = Quart(__name__)
|
||||
|
||||
@app.route('/')
|
||||
|
@ -117,21 +148,34 @@ class SkynetDGPUDaemon:
|
|||
|
||||
return app
|
||||
|
||||
async def maybe_serve_one(self, req):
|
||||
# TODO? this func is kinda big and maybe is better at module
|
||||
# level to reduce indentation?
|
||||
# -[ ] just pass `daemon: SkynetDGPUDaemon` vs. `self`
|
||||
async def maybe_serve_one(
|
||||
self,
|
||||
req: dict,
|
||||
):
|
||||
rid = req['id']
|
||||
|
||||
# parse request
|
||||
body = json.loads(req['body'])
|
||||
model = body['params']['model']
|
||||
|
||||
# if model not known
|
||||
if model != 'RealESRGAN_x4plus' and model not in MODELS:
|
||||
# if model not known, ignore.
|
||||
if (
|
||||
model != 'RealESRGAN_x4plus'
|
||||
and
|
||||
model not in MODELS
|
||||
):
|
||||
logging.warning(f'Unknown model {model}')
|
||||
return False
|
||||
|
||||
# if whitelist enabled and model not in it continue
|
||||
if (len(self.model_whitelist) > 0 and
|
||||
not model in self.model_whitelist):
|
||||
# only handle whitelisted models
|
||||
if (
|
||||
len(self.model_whitelist) > 0
|
||||
and
|
||||
model not in self.model_whitelist
|
||||
):
|
||||
return False
|
||||
|
||||
# if blacklist contains model skip
|
||||
|
@ -139,21 +183,29 @@ class SkynetDGPUDaemon:
|
|||
return False
|
||||
|
||||
my_results = [res['id'] for res in self._snap['my_results']]
|
||||
if rid not in my_results and rid in self._snap['requests']:
|
||||
if (
|
||||
rid not in my_results
|
||||
and
|
||||
rid in self._snap['requests']
|
||||
):
|
||||
statuses = self._snap['requests'][rid]
|
||||
|
||||
if len(statuses) == 0:
|
||||
inputs = []
|
||||
for _input in req['binary_data'].split(','):
|
||||
if _input:
|
||||
for _ in range(3):
|
||||
try:
|
||||
# use `GPUConnector` to IO with
|
||||
# storage layer to seed the compute
|
||||
# task.
|
||||
img = await self.conn.get_input_data(_input)
|
||||
inputs.append(img)
|
||||
break
|
||||
|
||||
except:
|
||||
...
|
||||
except BaseException:
|
||||
logging.exception(
|
||||
'Model input error !?!\n'
|
||||
)
|
||||
|
||||
hash_str = (
|
||||
str(req['nonce'])
|
||||
|
@ -172,7 +224,7 @@ class SkynetDGPUDaemon:
|
|||
|
||||
resp = await self.conn.begin_work(rid)
|
||||
if not resp or 'code' in resp:
|
||||
logging.info(f'probably being worked on already... skip.')
|
||||
logging.info('probably being worked on already... skip.')
|
||||
|
||||
else:
|
||||
try:
|
||||
|
@ -195,25 +247,37 @@ class SkynetDGPUDaemon:
|
|||
)
|
||||
|
||||
case _:
|
||||
raise DGPUComputeError(f'Unsupported backend {self.backend}')
|
||||
self._last_generation_ts = datetime.now().isoformat()
|
||||
self._last_benchmark = self._benchmark
|
||||
self._benchmark = []
|
||||
raise DGPUComputeError(
|
||||
f'Unsupported backend {self.backend}'
|
||||
)
|
||||
|
||||
self._last_generation_ts: str = datetime.now().isoformat()
|
||||
self._last_benchmark: list[float] = self._benchmark
|
||||
self._benchmark: list[float] = []
|
||||
|
||||
ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
|
||||
|
||||
await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
|
||||
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
await self.conn.cancel_work(rid, str(e))
|
||||
except BaseException as err:
|
||||
logging.exception('Failed to serve model request !?\n')
|
||||
# traceback.print_exc() # TODO? <- replaced by above ya?
|
||||
await self.conn.cancel_work(rid, str(err))
|
||||
|
||||
finally:
|
||||
return True
|
||||
|
||||
# TODO, i would inverse this case logic to avoid an indent
|
||||
# level in above block ;)
|
||||
else:
|
||||
logging.info(f'request {rid} already beign worked on, skip...')
|
||||
|
||||
# TODO, as per above on `.maybe_serve_one()`, it's likely a bit
|
||||
# more *trionic* to define this all as a module level task-func
|
||||
# which operates on a `daemon: SkynetDGPUDaemon`?
|
||||
#
|
||||
# -[ ] keeps tasks-as-funcs style prominent
|
||||
# -[ ] avoids so much indentation due to methods
|
||||
async def serve_forever(self):
|
||||
try:
|
||||
while True:
|
||||
|
@ -230,6 +294,8 @@ class SkynetDGPUDaemon:
|
|||
)
|
||||
|
||||
for req in queue:
|
||||
# TODO, as mentioned above just inline this once
|
||||
# converted to a mod level func.
|
||||
if (await self.maybe_serve_one(req)):
|
||||
break
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
#!/usr/bin/python
|
||||
# ^TODO, why..
|
||||
|
||||
|
||||
class DGPUComputeError(BaseException):
|
||||
|
|
|
@ -13,40 +13,72 @@ import leap
|
|||
import anyio
|
||||
import httpx
|
||||
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
from PIL import (
|
||||
Image,
|
||||
# UnidentifiedImageError, # TODO, remove?
|
||||
)
|
||||
|
||||
from leap.cleos import CLEOS
|
||||
from leap.protocol import Asset
|
||||
from skynet.constants import DEFAULT_IPFS_DOMAIN, GPU_CONTRACT_ABI
|
||||
from skynet.constants import (
|
||||
DEFAULT_IPFS_DOMAIN,
|
||||
GPU_CONTRACT_ABI,
|
||||
)
|
||||
|
||||
from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
|
||||
from skynet.dgpu.errors import DGPUComputeError
|
||||
from skynet.ipfs import (
|
||||
AsyncIPFSHTTP,
|
||||
get_ipfs_file,
|
||||
)
|
||||
# TODO, remove?
|
||||
# from skynet.dgpu.errors import DGPUComputeError
|
||||
|
||||
|
||||
REQUEST_UPDATE_TIME = 3
|
||||
REQUEST_UPDATE_TIME: int = 3
|
||||
|
||||
|
||||
async def failable(fn: partial, ret_fail=None):
|
||||
# TODO, consider using the `outcome` lib instead?
|
||||
# - it's already purpose built for exactly this, boxing (async)
|
||||
# function invocations..
|
||||
# |_ https://outcome.readthedocs.io/en/latest/api.html#outcome.capture
|
||||
async def failable(
|
||||
fn: partial,
|
||||
ret_fail=None,
|
||||
):
|
||||
try:
|
||||
return await fn()
|
||||
|
||||
except (
|
||||
OSError,
|
||||
json.JSONDecodeError,
|
||||
anyio.BrokenResourceError,
|
||||
httpx.ReadError,
|
||||
leap.errors.TransactionPushError
|
||||
) as e:
|
||||
):
|
||||
return ret_fail
|
||||
|
||||
|
||||
# TODO, again the prefix XD
|
||||
# -[ ] better name then `GPUConnector` ??
|
||||
# |_ `Compute[Net]IO[Mngr]`
|
||||
class SkynetGPUConnector:
|
||||
'''
|
||||
An API for connecting to and conducting various "high level"
|
||||
network-service operations in the skynet.
|
||||
|
||||
- skynet user account creds
|
||||
- hyperion API
|
||||
- IPFs client
|
||||
- CLEOS client
|
||||
|
||||
'''
|
||||
def __init__(self, config: dict):
|
||||
# TODO, why these extra instance vars for an (unsynced)
|
||||
# copy of the `config` state?
|
||||
self.account = config['account']
|
||||
self.permission = config['permission']
|
||||
self.key = config['key']
|
||||
|
||||
# TODO, neither of these instance vars are used anywhere in
|
||||
# methods? so why are they set on this type?
|
||||
self.node_url = config['node_url']
|
||||
self.hyperion_url = config['hyperion_url']
|
||||
|
||||
|
@ -125,7 +157,9 @@ class SkynetGPUConnector:
|
|||
logging.info(f'competitors: {competitors}')
|
||||
return set(competitors)
|
||||
|
||||
|
||||
# TODO, considery making this a NON-method and instead
|
||||
# handing in the `snap['queue']` output beforehand?
|
||||
# -> since that call is the only usage of `self`?
|
||||
async def get_full_queue_snapshot(self):
|
||||
snap = {
|
||||
'requests': {},
|
||||
|
@ -146,6 +180,11 @@ class SkynetGPUConnector:
|
|||
return snap
|
||||
|
||||
async def begin_work(self, request_id: int):
|
||||
'''
|
||||
Publish to the blockchain that the worker is beginning a model-computation
|
||||
step.
|
||||
|
||||
'''
|
||||
logging.info('begin_work')
|
||||
return await failable(
|
||||
partial(
|
||||
|
@ -269,6 +308,14 @@ class SkynetGPUConnector:
|
|||
return file_cid
|
||||
|
||||
async def get_input_data(self, ipfs_hash: str) -> Image:
|
||||
'''
|
||||
Retrieve an input (image) from the IPFs layer.
|
||||
|
||||
Normally used to retrieve seed (visual) content previously
|
||||
generated/validated by the network to be fed to some
|
||||
consuming AI model.
|
||||
|
||||
'''
|
||||
link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}'
|
||||
|
||||
res = await get_ipfs_file(link, timeout=1)
|
||||
|
|
Loading…
Reference in New Issue