diff --git a/skynet/cli.py b/skynet/cli.py
index 4a5850e..58297ca 100755
--- a/skynet/cli.py
+++ b/skynet/cli.py
@@ -1,16 +1,8 @@
 import json
 import logging
-import random
-
-from functools import partial

 import click

-from leap.protocol import (
-    Name,
-    Asset,
-)
-
 from .config import (
     load_skynet_toml,
     set_hf_vars,
@@ -49,7 +41,7 @@ def txt2img(*args, **kwargs):
     config = load_skynet_toml()
     set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
-    utils.txt2img(hf_token, **kwargs)
+    utils.txt2img(config.dgpu.hf_token, **kwargs)
@@ -74,7 +66,7 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed):
     config = load_skynet_toml()
     set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
     utils.img2img(
-        hf_token,
+        config.dgpu.hf_token,
         model=model,
         prompt=prompt,
         img_path=input,
@@ -102,7 +94,7 @@ def inpaint(model, prompt, input, mask, output, strength, guidance, steps, seed)
     config = load_skynet_toml()
     set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
     utils.inpaint(
-        hf_token,
+        config.dgpu.hf_token,
         model=model,
         prompt=prompt,
         img_path=input,
diff --git a/skynet/dgpu/__init__.py b/skynet/dgpu/__init__.py
index a22655d..cd460ae 100755
--- a/skynet/dgpu/__init__.py
+++ b/skynet/dgpu/__init__.py
@@ -5,7 +5,7 @@ import trio
 import urwid

 from skynet.config import Config
-from skynet.dgpu.tui import init_tui
+from skynet.dgpu.tui import init_tui, WorkerMonitor
 from skynet.dgpu.daemon import dgpu_serve_forever
 from skynet.dgpu.network import NetConnector, maybe_open_contract_state_mngr
@@ -15,7 +15,7 @@ async def open_worker(config: Config):
     # suppress logs from httpx (logs url + status after every query)
     logging.getLogger("httpx").setLevel(logging.WARNING)

-    tui = None
+    tui: WorkerMonitor | None = None
     if config.tui:
         tui = init_tui(config)
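The three skynet/cli.py hunks above all fix the same bug: hf_token was an unbound name inside each command body, so invoking txt2img, img2img, or inpaint raised NameError before any work started. A minimal sketch of the flow the commands now share, using names taken from the hunks (the import path for utils and the exact behavior of set_hf_vars are assumptions, not confirmed by this patch):

    from skynet.config import load_skynet_toml, set_hf_vars
    from skynet import utils  # assumed import path for the utils module used above

    def run_txt2img(**kwargs):
        # parse skynet.toml once, then read credentials off the dgpu section
        config = load_skynet_toml()
        # assumed: exports the Hugging Face token/cache dir before models load
        set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
        # pass the token explicitly instead of referencing an unbound global
        utils.txt2img(config.dgpu.hf_token, **kwargs)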
diff --git a/skynet/dgpu/compute.py b/skynet/dgpu/compute.py
index c309d4b..82e68e2 100755
--- a/skynet/dgpu/compute.py
+++ b/skynet/dgpu/compute.py
@@ -7,6 +7,7 @@ import gc
 import logging

 from hashlib import sha256
+from typing import Callable, Generator
 from contextlib import contextmanager as cm

 import trio
@@ -20,7 +21,14 @@ from skynet.dgpu.errors import (
     DGPUInferenceCancelled,
 )

-from skynet.dgpu.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
+from skynet.dgpu.utils import (
+    Pipeline,
+    crop_image,
+    convert_from_cv2_to_image,
+    convert_from_image_to_cv2,
+    convert_from_img_to_bytes,
+    pipeline_for
+)


 def prepare_params_for_diffuse(
@@ -68,17 +76,21 @@ def prepare_params_for_diffuse(

 _model_name: str = ''
 _model_mode: str = ''
-_model = None
+_model: Pipeline | None = None

 @cm
-def maybe_load_model(name: str, mode: ModelMode):
+def maybe_load_model(name: str, mode: ModelMode) -> Generator[Pipeline, None, None]:
     if mode == ModelMode.DIFFUSE:
         mode = ModelMode.TXT2IMG

     global _model_name, _model_mode, _model
     config = load_skynet_toml().dgpu
-    if _model_name != name or _model_mode != mode:
+    if (
+        _model_name != name
+        or
+        _model_mode != mode
+    ):
         # unload model
         _model = None
         gc.collect()
@@ -94,24 +106,26 @@ def maybe_load_model(name: str, mode: ModelMode):
         _model_mode = mode

     if torch.cuda.is_available():
-        logging.debug('memory summary:')
-        logging.debug('\n' + torch.cuda.memory_summary())
+        logging.debug(
+            'memory summary:\n'
+            f'{torch.cuda.memory_summary()}'
+        )

     yield _model


 def compute_one(
-    model,
+    model: Pipeline,
     request_id: int,
     method: ModelMode,
     params: BodyV0Params,
     inputs: list[bytes] = [],
-    should_cancel = None
+    should_cancel: Callable[..., dict] | None = None
 ):
     total_steps = params.step
     def inference_step_wakeup(*args, **kwargs):
         '''This is a callback function that gets invoked every inference step,
-        we need to raise an exception here if we need to cancel work
+        we must raise DGPUInferenceCancelled here if we need to cancel work
         '''
         step = args[0]
         # compat with callback_on_step_end
@@ -122,6 +136,9 @@ def compute_one(
         should_raise = False
         if should_cancel:
+            '''Pump the main thread's event loop and evaluate whether we should
+            keep working on this request, based on the latest network info
+            (e.g. competing workers).'''
             should_raise = trio.from_thread.run(should_cancel, request_id)

         if should_raise:
@@ -137,60 +154,56 @@ def compute_one(
     output_type = params.output_type
     output = None
     output_hash = None
-    try:
-        name = params.model
+    name = params.model

-        match method:
-            case (
-                ModelMode.DIFFUSE |
-                ModelMode.TXT2IMG |
-                ModelMode.IMG2IMG |
-                ModelMode.INPAINT
-            ):
-                arguments = prepare_params_for_diffuse(
-                    params, method, inputs)
-                prompt, guidance, step, seed, extra_params = arguments
+    match method:
+        case (
+            ModelMode.DIFFUSE |
+            ModelMode.TXT2IMG |
+            ModelMode.IMG2IMG |
+            ModelMode.INPAINT
+        ):
+            arguments = prepare_params_for_diffuse(
+                params, method, inputs)
+            prompt, guidance, step, seed, extra_params = arguments

-                if 'flux' in name.lower():
-                    extra_params['callback_on_step_end'] = inference_step_wakeup
+            if 'flux' in name.lower():
+                extra_params['callback_on_step_end'] = inference_step_wakeup

-                else:
-                    extra_params['callback'] = inference_step_wakeup
-                    extra_params['callback_steps'] = 1
+            else:
+                extra_params['callback'] = inference_step_wakeup
+                extra_params['callback_steps'] = 1

-                output = model(
-                    prompt,
-                    guidance_scale=guidance,
-                    num_inference_steps=step,
-                    generator=seed,
-                    **extra_params
-                ).images[0]
+            output = model(
+                prompt,
+                guidance_scale=guidance,
+                num_inference_steps=step,
+                generator=seed,
+                **extra_params
+            ).images[0]

-                output_binary = b''
-                match output_type:
-                    case 'png':
-                        output_binary = convert_from_img_to_bytes(output)
+            output_binary = b''
+            match output_type:
+                case 'png':
+                    output_binary = convert_from_img_to_bytes(output)

-                    case _:
-                        raise DGPUComputeError(f'Unsupported output type: {output_type}')
+                case _:
+                    raise DGPUComputeError(f'Unsupported output type: {output_type}')

-                output_hash = sha256(output_binary).hexdigest()
+            output_hash = sha256(output_binary).hexdigest()

-            case 'upscale':
-                input_img = inputs[0].convert('RGB')
-                up_img, _ = model.enhance(
-                    convert_from_image_to_cv2(input_img), outscale=4)
+        case 'upscale':
+            input_img = inputs[0].convert('RGB')
+            up_img, _ = model.enhance(
+                convert_from_image_to_cv2(input_img), outscale=4)

-                output = convert_from_cv2_to_image(up_img)
+            output = convert_from_cv2_to_image(up_img)

-                output_binary = convert_from_img_to_bytes(output)
-                output_hash = sha256(output_binary).hexdigest()
+            output_binary = convert_from_img_to_bytes(output)
+            output_hash = sha256(output_binary).hexdigest()

-            case _:
-                raise DGPUComputeError('Unsupported compute method')
-
-    except BaseException as err:
-        raise DGPUComputeError(str(err)) from err
+        case _:
+            raise DGPUComputeError('Unsupported compute method')

     maybe_update_tui(lambda tui: tui.set_status(''))
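The cancellation path above crosses the thread boundary twice: compute_one itself runs under trio.to_thread.run_sync (see the daemon.py hunk below), and inference_step_wakeup calls back into the trio loop with trio.from_thread.run once per step, raising DGPUInferenceCancelled when the async predicate answers yes. A self-contained sketch of that hand-off, with illustrative names (fake_should_cancel, blocking_inference) that are not part of this patch:

    import trio

    class Cancelled(Exception):
        pass

    async def fake_should_cancel(request_id: int) -> bool:
        # stand-in for state_mngr.should_cancel_work: the real predicate
        # consults the latest contract state (competing workers, etc.)
        await trio.sleep(0)
        return False

    def blocking_inference(request_id: int, steps: int) -> str:
        for step in range(steps):
            # re-enter the trio loop from this worker thread; this pumps the
            # main thread's event loop and runs the async predicate there
            if trio.from_thread.run(fake_should_cancel, request_id):
                raise Cancelled(f'request {request_id} cancelled at step {step}')
        return 'output'

    async def main():
        # run the blocking loop off the main thread, as the daemon does
        result = await trio.to_thread.run_sync(blocking_inference, 42, 3)
        print(result)

    trio.run(main)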
diff --git a/skynet/dgpu/daemon.py b/skynet/dgpu/daemon.py
index dd7022b..8016e64 100755
--- a/skynet/dgpu/daemon.py
+++ b/skynet/dgpu/daemon.py
@@ -132,6 +132,11 @@ async def maybe_serve_one(
     output_hash = None
     match config.backend:
         case 'sync-on-thread':
+            '''Block this task until inference completes. Pass the
+            state_mngr.should_cancel_work predicate as the inference_step_wakeup
+            callback that torch invokes on every inference step; it uses
+            trio.from_thread to re-enter the main thread and pump the event loop.
+            '''
             output_hash, output = await trio.to_thread.run_sync(
                 partial(
                     compute_one,
diff --git a/skynet/dgpu/tui.py b/skynet/dgpu/tui.py
index e1c572a..ce8db5c 100644
--- a/skynet/dgpu/tui.py
+++ b/skynet/dgpu/tui.py
@@ -141,10 +141,6 @@ class WorkerMonitor:

         self.progress_bar.current = current

-        pct = 0
-        if self.progress_bar.done != 0:
-            pct = int((self.progress_bar.current / self.progress_bar.done) * 100)
-
     def update_requests(self, new_requests):
         """
         Replace the data in the existing ListBox with new request widgets.
diff --git a/skynet/dgpu/utils.py b/skynet/dgpu/utils.py
index 6a17c93..daf955f 100755
--- a/skynet/dgpu/utils.py
+++ b/skynet/dgpu/utils.py
@@ -75,6 +75,9 @@ class DummyPB:
     def update(self):
         ...

+
+type Pipeline = DiffusionPipeline | RealESRGANer
+
 @torch.compiler.disable
 @contextmanager
 def dummy_progress_bar(*args, **kwargs):
@@ -90,7 +93,7 @@ def pipeline_for(
     mode: str,
     mem_fraction: float = 1.0,
     cache_dir: str | None = None
-) -> DiffusionPipeline:
+) -> Pipeline:
     diffusers.utils.logging.disable_progress_bar()

     logging.info(f'pipeline_for {model} {mode}')
diff --git a/skynet/types.py b/skynet/types.py
index 45450aa..44e1bce 100644
--- a/skynet/types.py
+++ b/skynet/types.py
@@ -13,7 +13,7 @@ class ModelMode(StrEnum):

 class ModelDesc(Struct):
     short: str  # short unique name
-    mem: float  # recomended mem
+    mem: float  # recommended mem in GB
     attrs: dict  # additional mode specific attrs
     tags: list[ModelMode]
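A portability note on the utils.py hunk: the "type Pipeline = ..." alias uses PEP 695 syntax, which only parses on Python 3.12 or newer. If the worker still has to import on 3.10/3.11, an equivalent spelling is the sketch below (assuming the standard import locations for the two pipeline classes; this patch does not show utils.py's imports):

    from typing import TypeAlias

    from diffusers import DiffusionPipeline
    from realesrgan import RealESRGANer

    # pre-3.12 equivalent of: type Pipeline = DiffusionPipeline | RealESRGANer
    Pipeline: TypeAlias = DiffusionPipeline | RealESRGANer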