Suggest `skynet.dgpu` docs, typing, pythonisms

From the deep-ish dive drafting our first set of design/architecture
diagrams in https://github.com/skygpu/cyberdyne/pull/2, this adds
a buncha suggestions, typing, and styling adjustments.

Namely the code tweaks include,
- changing to multi-line import tuples where appropriate (since they're
  much handier to modify ;)
- adding typing in many spots where it wasn't clear to me the types
  being returned/operated-with in various (internal) methods.
- doc strings (in mostly random spots Xp ) where i had the need to
  remember the impl's purpose but didn't want to re-read the code in
  detail again.
- A LOT of TODOs surrounding various potential style changes,
  re-factorings, naming and in some cases "modernization" according to
  the latest python3.12 feats/spec/stdlib.
fomo_polish
Tyler Goodlet 2025-02-03 10:25:14 -05:00
parent c0ac6298a9
commit 7cb9f09d95
5 changed files with 201 additions and 61 deletions

View File

@ -14,8 +14,8 @@ from skynet.dgpu.network import SkynetGPUConnector
async def open_dgpu_node(config: dict) -> None: async def open_dgpu_node(config: dict) -> None:
''' '''
Open a top level "GPU mgmt daemon", keep the Open a top level "GPU mgmt daemon", keep the
`SkynetDGPUDaemon._snap: dict[str, list|dict]` table and *maybe* `SkynetDGPUDaemon._snap: dict[str, list|dict]` table
serve a `hypercorn` web API. and *maybe* serve a `hypercorn` web API.
''' '''
conn = SkynetGPUConnector(config) conn = SkynetGPUConnector(config)
@ -32,6 +32,8 @@ async def open_dgpu_node(config: dict) -> None:
async with trio.open_nursery() as tn: async with trio.open_nursery() as tn:
tn.start_soon(daemon.snap_updater_task) tn.start_soon(daemon.snap_updater_task)
# TODO, consider a more explicit `as hypercorn_serve`
# to clarify?
if api: if api:
tn.start_soon(serve, api, api_conf) tn.start_soon(serve, api, api_conf)

52
skynet/dgpu/compute.py 100644 → 100755
View File

@ -1,20 +1,36 @@
#!/usr/bin/python #!/usr/bin/python
# ^TODO? again, why..
#
# Do we expect this mod
# to be invoked? if so why is there no
# `if __name__ == '__main__'` guard?
#
# if anything this should contain a license header ;)
# Skynet Memory Manager '''
Skynet Memory Manager
'''
import gc import gc
import logging import logging
from hashlib import sha256 from hashlib import sha256
import zipfile # import zipfile
from PIL import Image # from PIL import Image
from diffusers import DiffusionPipeline # from diffusers import DiffusionPipeline
import trio import trio
import torch import torch
from skynet.constants import DEFAULT_INITAL_MODEL, MODELS # from skynet.constants import (
from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled # DEFAULT_INITAL_MODEL,
# MODELS,
# )
from skynet.dgpu.errors import (
DGPUComputeError,
DGPUInferenceCancelled,
)
from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
@ -66,15 +82,20 @@ def prepare_params_for_diffuse(
) )
# TODO, yet again - drop the redundant prefix ;)
class SkynetMM: class SkynetMM:
'''
(AI algo) Model manager for loading models, computing outputs,
checking load state, and unloading when no-longer-needed/finished.
'''
def __init__(self, config: dict): def __init__(self, config: dict):
self.cache_dir = None self.cache_dir = None
if 'hf_home' in config: if 'hf_home' in config:
self.cache_dir = config['hf_home'] self.cache_dir = config['hf_home']
self._model_name = '' self._model_name: str = ''
self._model_mode = '' self._model_mode: str = ''
# self.load_model(DEFAULT_INITAL_MODEL, 'txt2img') # self.load_model(DEFAULT_INITAL_MODEL, 'txt2img')
@ -89,7 +110,7 @@ class SkynetMM:
return False return False
def unload_model(self): def unload_model(self) -> None:
if getattr(self, '_model', None): if getattr(self, '_model', None):
del self._model del self._model
@ -103,7 +124,7 @@ class SkynetMM:
self, self,
name: str, name: str,
mode: str mode: str
): ) -> None:
logging.info(f'loading model {name}...') logging.info(f'loading model {name}...')
self.unload_model() self.unload_model()
self._model = pipeline_for( self._model = pipeline_for(
@ -111,7 +132,6 @@ class SkynetMM:
self._model_mode = mode self._model_mode = mode
self._model_name = name self._model_name = name
def compute_one( def compute_one(
self, self,
request_id: int, request_id: int,
@ -124,6 +144,9 @@ class SkynetMM:
should_raise = trio.from_thread.run(self._should_cancel, request_id) should_raise = trio.from_thread.run(self._should_cancel, request_id)
if should_raise: if should_raise:
logging.warn(f'cancelling work at step {step}') logging.warn(f'cancelling work at step {step}')
# ?TODO, this is never caught, so why is it
# raised specially?
raise DGPUInferenceCancelled() raise DGPUInferenceCancelled()
return {} return {}
@ -199,9 +222,10 @@ class SkynetMM:
case _: case _:
raise DGPUComputeError('Unsupported compute method') raise DGPUComputeError('Unsupported compute method')
except BaseException as e: except BaseException as err:
logging.error(e) logging.error(err)
raise DGPUComputeError(str(e)) # to see the src exc in tb
raise DGPUComputeError(str(err)) from err
finally: finally:
torch.cuda.empty_cache() torch.cuda.empty_cache()

138
skynet/dgpu/daemon.py 100644 → 100755
View File

@ -1,23 +1,25 @@
#!/usr/bin/python #!/usr/bin/python
import json
import random
import logging
import time
import traceback
from hashlib import sha256
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from hashlib import sha256
import json
import logging
import random
# import traceback
import time
import trio import trio
from quart import jsonify from quart import jsonify
from quart_trio import QuartTrio as Quart from quart_trio import QuartTrio as Quart
from skynet.constants import MODELS, VERSION from skynet.constants import (
MODELS,
from skynet.dgpu.errors import * VERSION,
)
from skynet.dgpu.errors import (
DGPUComputeError,
)
from skynet.dgpu.compute import SkynetMM from skynet.dgpu.compute import SkynetMM
from skynet.dgpu.network import SkynetGPUConnector from skynet.dgpu.network import SkynetGPUConnector
@ -30,22 +32,29 @@ def convert_reward_to_int(reward_str):
return int(int_part + decimal_part) return int(int_part + decimal_part)
# prolly don't need the `Skynet` prefix since that's kinda implied ;p
class SkynetDGPUDaemon: class SkynetDGPUDaemon:
'''
The root "GPU daemon".
Contains/manages underlying subsystems:
- a GPU connector
'''
def __init__( def __init__(
self, self,
mm: SkynetMM, mm: SkynetMM,
conn: SkynetGPUConnector, conn: SkynetGPUConnector,
config: dict config: dict
): ):
self.mm = mm self.mm: SkynetMM = mm
self.conn = conn self.conn: SkynetGPUConnector = conn
self.auto_withdraw = ( self.auto_withdraw = (
config['auto_withdraw'] config['auto_withdraw']
if 'auto_withdraw' in config else False if 'auto_withdraw' in config else False
) )
self.account = config['account'] self.account: str = config['account']
self.non_compete = set() self.non_compete = set()
if 'non_compete' in config: if 'non_compete' in config:
@ -67,13 +76,20 @@ class SkynetDGPUDaemon:
'queue': [], 'queue': [],
'requests': {}, 'requests': {},
'my_results': [] 'my_results': []
# ^and here i thot they were **my** results..
# :sadcat:
} }
self._benchmark = [] self._benchmark: list[float] = []
self._last_benchmark = None self._last_benchmark: list[float]|None = None
self._last_generation_ts = None self._last_generation_ts: str|None = None
def _get_benchmark_speed(self) -> float: def _get_benchmark_speed(self) -> float:
'''
Return the (arithmetic) average work-iterations-per-second
conducted by this compute worker.
'''
if not self._last_benchmark: if not self._last_benchmark:
return 0 return 0
@ -99,11 +115,26 @@ class SkynetDGPUDaemon:
async def snap_updater_task(self): async def snap_updater_task(self):
'''
Busy loop update the local `._snap: dict` table from
`.conn.get_full_queue_snapshot()`.
'''
while True: while True:
self._snap = await self.conn.get_full_queue_snapshot() self._snap = await self.conn.get_full_queue_snapshot()
await trio.sleep(1) await trio.sleep(1)
async def generate_api(self): # TODO, design suggestion, just make this a lazily accessed
# `@cached_property` if we're 3.12+
# |_ https://docs.python.org/3/library/functools.html#functools.cached_property
async def generate_api(self) -> Quart:
'''
Gen a `Quart`-compat web API spec which (for now) simply
serves a small monitoring ep that reports,
- iso-time-stamp of the last served model-output
- the worker's average "compute-iterations-per-second"
'''
app = Quart(__name__) app = Quart(__name__)
@app.route('/') @app.route('/')
@ -117,21 +148,34 @@ class SkynetDGPUDaemon:
return app return app
async def maybe_serve_one(self, req): # TODO? this func is kinda big and maybe is better at module
# level to reduce indentation?
# -[ ] just pass `daemon: SkynetDGPUDaemon` vs. `self`
async def maybe_serve_one(
self,
req: dict,
):
rid = req['id'] rid = req['id']
# parse request # parse request
body = json.loads(req['body']) body = json.loads(req['body'])
model = body['params']['model'] model = body['params']['model']
# if model not known # if model not known, ignore.
if model != 'RealESRGAN_x4plus' and model not in MODELS: if (
model != 'RealESRGAN_x4plus'
and
model not in MODELS
):
logging.warning(f'Unknown model {model}') logging.warning(f'Unknown model {model}')
return False return False
# if whitelist enabled and model not in it continue # only handle whitelisted models
if (len(self.model_whitelist) > 0 and if (
not model in self.model_whitelist): len(self.model_whitelist) > 0
and
model not in self.model_whitelist
):
return False return False
# if blacklist contains model skip # if blacklist contains model skip
@ -139,21 +183,29 @@ class SkynetDGPUDaemon:
return False return False
my_results = [res['id'] for res in self._snap['my_results']] my_results = [res['id'] for res in self._snap['my_results']]
if rid not in my_results and rid in self._snap['requests']: if (
rid not in my_results
and
rid in self._snap['requests']
):
statuses = self._snap['requests'][rid] statuses = self._snap['requests'][rid]
if len(statuses) == 0: if len(statuses) == 0:
inputs = [] inputs = []
for _input in req['binary_data'].split(','): for _input in req['binary_data'].split(','):
if _input: if _input:
for _ in range(3): for _ in range(3):
try: try:
# use `GPUConnector` to IO with
# storage layer to seed the compute
# task.
img = await self.conn.get_input_data(_input) img = await self.conn.get_input_data(_input)
inputs.append(img) inputs.append(img)
break break
except: except BaseException:
... logging.exception(
'Model input error !?!\n'
)
hash_str = ( hash_str = (
str(req['nonce']) str(req['nonce'])
@ -172,7 +224,7 @@ class SkynetDGPUDaemon:
resp = await self.conn.begin_work(rid) resp = await self.conn.begin_work(rid)
if not resp or 'code' in resp: if not resp or 'code' in resp:
logging.info(f'probably being worked on already... skip.') logging.info('probably being worked on already... skip.')
else: else:
try: try:
@ -195,25 +247,37 @@ class SkynetDGPUDaemon:
) )
case _: case _:
raise DGPUComputeError(f'Unsupported backend {self.backend}') raise DGPUComputeError(
self._last_generation_ts = datetime.now().isoformat() f'Unsupported backend {self.backend}'
self._last_benchmark = self._benchmark )
self._benchmark = []
self._last_generation_ts: str = datetime.now().isoformat()
self._last_benchmark: list[float] = self._benchmark
self._benchmark: list[float] = []
ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type) ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash) await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
except BaseException as e: except BaseException as err:
traceback.print_exc() logging.exception('Failed to serve model request !?\n')
await self.conn.cancel_work(rid, str(e)) # traceback.print_exc() # TODO? <- replaced by above ya?
await self.conn.cancel_work(rid, str(err))
finally: finally:
return True return True
# TODO, i would invert this case logic to avoid an indent
# level in above block ;)
else: else:
logging.info(f'request {rid} already beign worked on, skip...') logging.info(f'request {rid} already beign worked on, skip...')
# TODO, as per above on `.maybe_serve_one()`, it's likely a bit
# more *trionic* to define this all as a module level task-func
# which operates on a `daemon: SkynetDGPUDaemon`?
#
# -[ ] keeps tasks-as-funcs style prominent
# -[ ] avoids so much indentation due to methods
async def serve_forever(self): async def serve_forever(self):
try: try:
while True: while True:
@ -230,6 +294,8 @@ class SkynetDGPUDaemon:
) )
for req in queue: for req in queue:
# TODO, as mentioned above just inline this once
# converted to a mod level func.
if (await self.maybe_serve_one(req)): if (await self.maybe_serve_one(req)):
break break

1
skynet/dgpu/errors.py 100644 → 100755
View File

@ -1,4 +1,5 @@
#!/usr/bin/python #!/usr/bin/python
# ^TODO, why..
class DGPUComputeError(BaseException): class DGPUComputeError(BaseException):

65
skynet/dgpu/network.py 100644 → 100755
View File

@ -13,40 +13,72 @@ import leap
import anyio import anyio
import httpx import httpx
from PIL import Image, UnidentifiedImageError from PIL import (
Image,
# UnidentifiedImageError, # TODO, remove?
)
from leap.cleos import CLEOS from leap.cleos import CLEOS
from leap.protocol import Asset from leap.protocol import Asset
from skynet.constants import DEFAULT_IPFS_DOMAIN, GPU_CONTRACT_ABI from skynet.constants import (
DEFAULT_IPFS_DOMAIN,
GPU_CONTRACT_ABI,
)
from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file from skynet.ipfs import (
from skynet.dgpu.errors import DGPUComputeError AsyncIPFSHTTP,
get_ipfs_file,
)
# TODO, remove?
# from skynet.dgpu.errors import DGPUComputeError
REQUEST_UPDATE_TIME = 3 REQUEST_UPDATE_TIME: int = 3
async def failable(fn: partial, ret_fail=None): # TODO, consider using the `outcome` lib instead?
# - it's already purpose built for exactly this, boxing (async)
# function invocations..
# |_ https://outcome.readthedocs.io/en/latest/api.html#outcome.capture
async def failable(
fn: partial,
ret_fail=None,
):
try: try:
return await fn() return await fn()
except ( except (
OSError, OSError,
json.JSONDecodeError, json.JSONDecodeError,
anyio.BrokenResourceError, anyio.BrokenResourceError,
httpx.ReadError, httpx.ReadError,
leap.errors.TransactionPushError leap.errors.TransactionPushError
) as e: ):
return ret_fail return ret_fail
# TODO, again the prefix XD
# -[ ] better name than `GPUConnector` ??
# |_ `Compute[Net]IO[Mngr]`
class SkynetGPUConnector: class SkynetGPUConnector:
'''
An API for connecting to and conducting various "high level"
network-service operations in the skynet.
- skynet user account creds
- hyperion API
- IPFS client
- CLEOS client
'''
def __init__(self, config: dict): def __init__(self, config: dict):
# TODO, why these extra instance vars for an (unsynced)
# copy of the `config` state?
self.account = config['account'] self.account = config['account']
self.permission = config['permission'] self.permission = config['permission']
self.key = config['key'] self.key = config['key']
# TODO, neither of these instance vars are used anywhere in
# methods? so why are they set on this type?
self.node_url = config['node_url'] self.node_url = config['node_url']
self.hyperion_url = config['hyperion_url'] self.hyperion_url = config['hyperion_url']
@ -125,7 +157,9 @@ class SkynetGPUConnector:
logging.info(f'competitors: {competitors}') logging.info(f'competitors: {competitors}')
return set(competitors) return set(competitors)
# TODO, consider making this a NON-method and instead
# handing in the `snap['queue']` output beforehand?
# -> since that call is the only usage of `self`?
async def get_full_queue_snapshot(self): async def get_full_queue_snapshot(self):
snap = { snap = {
'requests': {}, 'requests': {},
@ -146,6 +180,11 @@ class SkynetGPUConnector:
return snap return snap
async def begin_work(self, request_id: int): async def begin_work(self, request_id: int):
'''
Publish to the bc that the worker is beginning a model-computation
step.
'''
logging.info('begin_work') logging.info('begin_work')
return await failable( return await failable(
partial( partial(
@ -269,6 +308,14 @@ class SkynetGPUConnector:
return file_cid return file_cid
async def get_input_data(self, ipfs_hash: str) -> Image: async def get_input_data(self, ipfs_hash: str) -> Image:
'''
Retrieve an input (image) from the IPFS layer.
Normally used to retrieve seed (visual) content previously
generated/validated by the network to be fed to some
consuming AI model.
'''
link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}' link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}'
res = await get_ipfs_file(link, timeout=1) res = await get_ipfs_file(link, timeout=1)