diff --git a/build_docker.sh b/build_docker.sh index acec498..8ca9905 100755 --- a/build_docker.sh +++ b/build_docker.sh @@ -5,3 +5,7 @@ docker build \ docker build \ -t guilledk/skynet:runtime-cuda-py311 \ -f docker/Dockerfile.runtime+cuda-py311 . + +docker build \ + -t guilledk/skynet:runtime-cuda \ + -f docker/Dockerfile.runtime+cuda-py311 . diff --git a/skynet/cli.py b/skynet/cli.py index ae56adb..370cda2 100755 --- a/skynet/cli.py +++ b/skynet/cli.py @@ -33,8 +33,8 @@ def txt2img(*args, **kwargs): from . import utils config = load_skynet_toml() - hf_token = load_key(config, 'skynet.dgpu', 'hf_token') - hf_home = load_key(config, 'skynet.dgpu', 'hf_home') + hf_token = load_key(config, 'skynet.dgpu.hf_token') + hf_home = load_key(config, 'skynet.dgpu.hf_home') set_hf_vars(hf_token, hf_home) utils.txt2img(hf_token, **kwargs) @@ -51,8 +51,8 @@ def txt2img(*args, **kwargs): def img2img(model, prompt, input, output, strength, guidance, steps, seed): from . import utils config = load_skynet_toml() - hf_token = load_key(config, 'skynet.dgpu', 'hf_token') - hf_home = load_key(config, 'skynet.dgpu', 'hf_home') + hf_token = load_key(config, 'skynet.dgpu.hf_token') + hf_home = load_key(config, 'skynet.dgpu.hf_home') set_hf_vars(hf_token, hf_home) utils.img2img( hf_token, @@ -82,8 +82,8 @@ def upscale(input, output, model): def download(): from . import utils config = load_skynet_toml() - hf_token = load_key(config, 'skynet.dgpu', 'hf_token') - hf_home = load_key(config, 'skynet.dgpu', 'hf_home') + hf_token = load_key(config, 'skynet.dgpu.hf_token') + hf_home = load_key(config, 'skynet.dgpu.hf_home') set_hf_vars(hf_token, hf_home) utils.download_all_models(hf_token) @@ -112,10 +112,10 @@ def enqueue( config = load_skynet_toml() - key = load_key(config, 'skynet.user', 'key') - account = load_key(config, 'skynet.user', 'account') - permission = load_key(config, 'skynet.user', 'permission') - node_url = load_key(config, 'skynet.user', 'node_url') + key = load_key(config, 'skynet.user.key') + account = load_key(config, 'skynet.user.account') + permission = load_key(config, 'skynet.user.permission') + node_url = load_key(config, 'skynet.user.node_url') cleos = CLEOS(None, None, url=node_url, remote=node_url) @@ -156,10 +156,10 @@ def clean( from leap.cleos import CLEOS config = load_skynet_toml() - key = load_key(config, 'skynet.user', 'key') - account = load_key(config, 'skynet.user', 'account') - permission = load_key(config, 'skynet.user', 'permission') - node_url = load_key(config, 'skynet.user', 'node_url') + key = load_key(config, 'skynet.user.key') + account = load_key(config, 'skynet.user.account') + permission = load_key(config, 'skynet.user.permission') + node_url = load_key(config, 'skynet.user.node_url') logging.basicConfig(level=loglevel) cleos = CLEOS(None, None, url=node_url, remote=node_url) @@ -177,7 +177,7 @@ def clean( def queue(): import requests config = load_skynet_toml() - node_url = load_key(config, 'skynet.user', 'node_url') + node_url = load_key(config, 'skynet.user.node_url') resp = requests.post( f'{node_url}/v1/chain/get_table_rows', json={ @@ -194,7 +194,7 @@ def queue(): def status(request_id: int): import requests config = load_skynet_toml() - node_url = load_key(config, 'skynet.user', 'node_url') + node_url = load_key(config, 'skynet.user.node_url') resp = requests.post( f'{node_url}/v1/chain/get_table_rows', json={ @@ -213,10 +213,10 @@ def dequeue(request_id: int): from leap.cleos import CLEOS config = load_skynet_toml() - key = load_key(config, 'skynet.user', 'key') - account = load_key(config, 'skynet.user', 'account') - permission = load_key(config, 'skynet.user', 'permission') - node_url = load_key(config, 'skynet.user', 'node_url') + key = load_key(config, 'skynet.user.key') + account = load_key(config, 'skynet.user.account') + permission = load_key(config, 'skynet.user.permission') + node_url = load_key(config, 'skynet.user.node_url') cleos = CLEOS(None, None, url=node_url, remote=node_url) res = trio.run( @@ -248,10 +248,10 @@ def config( config = load_skynet_toml() - key = load_key(config, 'skynet.user', 'key') - account = load_key(config, 'skynet.user', 'account') - permission = load_key(config, 'skynet.user', 'permission') - node_url = load_key(config, 'skynet.user', 'node_url') + key = load_key(config, 'skynet.user.key') + account = load_key(config, 'skynet.user.account') + permission = load_key(config, 'skynet.user.permission') + node_url = load_key(config, 'skynet.user.node_url') cleos = CLEOS(None, None, url=node_url, remote=node_url) res = trio.run( @@ -277,10 +277,10 @@ def deposit(quantity: str): config = load_skynet_toml() - key = load_key(config, 'skynet.user', 'key') - account = load_key(config, 'skynet.user', 'account') - permission = load_key(config, 'skynet.user', 'permission') - node_url = load_key(config, 'skynet.user', 'node_url') + key = load_key(config, 'skynet.user.key') + account = load_key(config, 'skynet.user.account') + permission = load_key(config, 'skynet.user.permission') + node_url = load_key(config, 'skynet.user.node_url') cleos = CLEOS(None, None, url=node_url, remote=node_url) res = trio.run( @@ -365,21 +365,21 @@ def telegram( logging.basicConfig(level=loglevel) config = load_skynet_toml() - tg_token = load_key(config, 'skynet.telegram', 'tg_token') + tg_token = load_key(config, 'skynet.telegram.tg_token') - key = load_key(config, 'skynet.telegram', 'key') - account = load_key(config, 'skynet.telegram', 'account') - permission = load_key(config, 'skynet.telegram', 'permission') - node_url = load_key(config, 'skynet.telegram', 'node_url') - hyperion_url = load_key(config, 'skynet.telegram', 'hyperion_url') + key = load_key(config, 'skynet.telegram.key') + account = load_key(config, 'skynet.telegram.account') + permission = load_key(config, 'skynet.telegram.permission') + node_url = load_key(config, 'skynet.telegram.node_url') + hyperion_url = load_key(config, 'skynet.telegram.hyperion_url') try: - ipfs_gateway_url = load_key(config, 'skynet.telegram', 'ipfs_gateway_url') + ipfs_gateway_url = load_key(config, 'skynet.telegram.ipfs_gateway_url') except ConfigParsingError: ipfs_gateway_url = None - ipfs_url = load_key(config, 'skynet.telegram', 'ipfs_url') + ipfs_url = load_key(config, 'skynet.telegram.ipfs_url') async def _async_main(): frontend = SkynetTelegramFrontend( @@ -421,16 +421,16 @@ def discord( logging.basicConfig(level=loglevel) config = load_skynet_toml() - dc_token = load_key(config, 'skynet.discord', 'dc_token') + dc_token = load_key(config, 'skynet.discord.dc_token') - key = load_key(config, 'skynet.discord', 'key') - account = load_key(config, 'skynet.discord', 'account') - permission = load_key(config, 'skynet.discord', 'permission') - node_url = load_key(config, 'skynet.discord', 'node_url') - hyperion_url = load_key(config, 'skynet.discord', 'hyperion_url') + key = load_key(config, 'skynet.discord.key') + account = load_key(config, 'skynet.discord.account') + permission = load_key(config, 'skynet.discord.permission') + node_url = load_key(config, 'skynet.discord.node_url') + hyperion_url = load_key(config, 'skynet.discord.hyperion_url') - ipfs_gateway_url = load_key(config, 'skynet.discord', 'ipfs_gateway_url') - ipfs_url = load_key(config, 'skynet.discord', 'ipfs_url') + ipfs_gateway_url = load_key(config, 'skynet.discord.ipfs_gateway_url') + ipfs_url = load_key(config, 'skynet.discord.ipfs_url') async def _async_main(): frontend = SkynetDiscordFrontend( @@ -471,8 +471,8 @@ def pinner(loglevel): from .ipfs.pinner import SkynetPinner config = load_skynet_toml() - hyperion_url = load_key(config, 'skynet.pinner', 'hyperion_url') - ipfs_url = load_key(config, 'skynet.pinner', 'ipfs_url') + hyperion_url = load_key(config, 'skynet.pinner.hyperion_url') + ipfs_url = load_key(config, 'skynet.pinner.ipfs_url') logging.basicConfig(level=loglevel) ipfs_node = AsyncIPFSHTTP(ipfs_url) diff --git a/skynet/dgpu/compute.py b/skynet/dgpu/compute.py index c343ae7..7df8d95 100644 --- a/skynet/dgpu/compute.py +++ b/skynet/dgpu/compute.py @@ -13,7 +13,7 @@ import trio import torch from skynet.constants import DEFAULT_INITAL_MODELS, MODELS -from skynet.dgpu.errors import DGPUComputeError +from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for @@ -132,16 +132,19 @@ class SkynetMM: def compute_one( self, + request_id: int, should_cancel_work, method: str, params: dict, binary: bytes | None = None ): - def callback_fn(step: int, timestep: int, latents: torch.FloatTensor): - should_raise = trio.from_thread.run(should_cancel_work) + def maybe_cancel_work(step, *args, **kwargs): + should_raise = trio.from_thread.run(should_cancel_work, request_id) if should_raise: logging.warn(f'cancelling work at step {step}') - raise DGPUComputeError('Inference cancelled') + raise DGPUInferenceCancelled() + + maybe_cancel_work(0) try: match method: @@ -157,7 +160,7 @@ class SkynetMM: guidance_scale=guidance, num_inference_steps=step, generator=seed, - callback=callback_fn, + callback=maybe_cancel_work, callback_steps=2, **extra_params ).images[0] diff --git a/skynet/dgpu/daemon.py b/skynet/dgpu/daemon.py index 6fb1c39..5d6971d 100644 --- a/skynet/dgpu/daemon.py +++ b/skynet/dgpu/daemon.py @@ -40,15 +40,9 @@ class SkynetDGPUDaemon: if 'model_blacklist' in config: self.model_blacklist = set(config['model_blacklist']) - self.current_request = None - - async def should_cancel_work(self): - competitors = set(( - status['worker'] - for status in - (await self.conn.get_status_by_request_id(self.current_request)) - )) - return self.non_compete & competitors + async def should_cancel_work(self, request_id: int): + competitors = await self.conn.get_competitors_for_req(request_id) + return bool(self.non_compete & competitors) async def serve_forever(self): try: @@ -79,7 +73,7 @@ class SkynetDGPUDaemon: statuses = await self.conn.get_status_by_request_id(rid) if len(statuses) == 0: - self.current_request = rid + self.conn.monitor_request(rid) binary = await self.conn.get_input_data(req['binary_data']) @@ -107,6 +101,7 @@ class SkynetDGPUDaemon: img_sha, img_raw = await trio.to_thread.run_sync( partial( self.mm.compute_one, + rid, self.should_cancel_work, body['method'], body['params'], binary=binary ) @@ -115,11 +110,13 @@ class SkynetDGPUDaemon: ipfs_hash = await self.conn.publish_on_ipfs(img_raw) await self.conn.submit_work(rid, request_hash, img_sha, ipfs_hash) - break except BaseException as e: traceback.print_exc() await self.conn.cancel_work(rid, str(e)) + + finally: + self.conn.forget_request(rid) break else: diff --git a/skynet/dgpu/errors.py b/skynet/dgpu/errors.py index 1f08624..91db585 100644 --- a/skynet/dgpu/errors.py +++ b/skynet/dgpu/errors.py @@ -3,3 +3,6 @@ class DGPUComputeError(BaseException): ... + +class DGPUInferenceCancelled(BaseException): + ... diff --git a/skynet/dgpu/network.py b/skynet/dgpu/network.py index 83e896d..ad082b3 100644 --- a/skynet/dgpu/network.py +++ b/skynet/dgpu/network.py @@ -1,13 +1,15 @@ #!/usr/bin/python -from functools import partial import io import json -from pathlib import Path import time import logging +from pathlib import Path +from functools import partial + import asks +import trio import anyio from PIL import Image @@ -20,6 +22,9 @@ from skynet.dgpu.errors import DGPUComputeError from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file +REQUEST_UPDATE_TIME = 3 + + async def failable(fn: partial, ret_fail=None): try: return await fn() @@ -54,6 +59,8 @@ class SkynetGPUConnector: self.ipfs_client = AsyncIPFSHTTP(self.ipfs_url) + self._wip_requests = {} + # blockchain helpers async def get_work_requests_last_hour(self): @@ -103,6 +110,36 @@ class SkynetGPUConnector: else: return None + def monitor_request(self, request_id: int): + logging.info(f'begin monitoring request: {request_id}') + self._wip_requests[request_id] = { + 'last_update': None, + 'competitors': set() + } + + async def maybe_update_request(self, request_id: int): + now = time.time() + stats = self._wip_requests[request_id] + if (not stats['last_update'] or + (now - stats['last_update']) > REQUEST_UPDATE_TIME): + stats['competitors'] = [ + status['worker'] + for status in + (await self.get_status_by_request_id(request_id)) + if status['worker'] != self.account + ] + stats['last_update'] = now + + async def get_competitors_for_req(self, request_id: int) -> set: + await self.maybe_update_request(request_id) + competitors = set(self._wip_requests[request_id]['competitors']) + logging.info(f'competitors: {competitors}') + return competitors + + def forget_request(self, request_id: int): + logging.info(f'end monitoring request: {request_id}') + del self._wip_requests[request_id] + async def begin_work(self, request_id: int): logging.info('begin_work') return await failable(