Mostly minor typing and comment changes remaining from fomos' re-review; the only big change is the removed BaseException catch inside compute_one

structify
Guillermo Rodriguez 2025-02-18 15:47:05 -03:00
parent 1dd2a8ed89
commit 6a991561de
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
7 changed files with 80 additions and 71 deletions
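
A note on the headline change: compute_one used to wrap its whole body in `except BaseException as err: raise DGPUComputeError(str(err)) from err`. A BaseException catch also matches trio.Cancelled, KeyboardInterrupt, and the DGPUInferenceCancelled raised by the step callback, so the wrap turned cancellation into a generic compute error; the commit drops the catch so those exceptions keep their type. A minimal sketch of the distinction, with a hypothetical run_inference stand-in:

class DGPUComputeError(Exception):
    ...

def old_style(run_inference):
    # removed anti-pattern: BaseException also swallows cancellation
    # signals (trio.Cancelled, KeyboardInterrupt, the cancel exception
    # raised by the step callback) and re-raises them as a plain error
    try:
        return run_inference()
    except BaseException as err:
        raise DGPUComputeError(str(err)) from err

def new_style(run_inference):
    # post-commit shape: no blanket catch, exceptions propagate as-is
    return run_inference()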

View File

@@ -1,16 +1,8 @@
import json
import logging
import random
from functools import partial
import click
from leap.protocol import (
Name,
Asset,
)
from .config import (
load_skynet_toml,
set_hf_vars,
@@ -49,7 +41,7 @@ def txt2img(*args, **kwargs):
config = load_skynet_toml()
set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
utils.txt2img(hf_token, **kwargs)
utils.txt2img(config.dgpu.hf_token, **kwargs)
@click.command()
@@ -74,7 +66,7 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed):
config = load_skynet_toml()
set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
utils.img2img(
hf_token,
config.dgpu.hf_token,
model=model,
prompt=prompt,
img_path=input,
@@ -102,7 +94,7 @@ def inpaint(model, prompt, input, mask, output, strength, guidance, steps, seed)
config = load_skynet_toml()
set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
utils.inpaint(
hf_token,
config.dgpu.hf_token,
model=model,
prompt=prompt,
img_path=input,

View File

@@ -5,7 +5,7 @@ import trio
import urwid
from skynet.config import Config
from skynet.dgpu.tui import init_tui
from skynet.dgpu.tui import init_tui, WorkerMonitor
from skynet.dgpu.daemon import dgpu_serve_forever
from skynet.dgpu.network import NetConnector, maybe_open_contract_state_mngr
@@ -15,7 +15,7 @@ async def open_worker(config: Config):
# suppress logs from httpx (logs url + status after every query)
logging.getLogger("httpx").setLevel(logging.WARNING)
tui = None
tui: WorkerMonitor | None = None
if config.tui:
tui = init_tui(config)

View File

@@ -7,6 +7,7 @@ import gc
import logging
from hashlib import sha256
from typing import Callable, Generator
from contextlib import contextmanager as cm
import trio
@@ -20,7 +21,14 @@ from skynet.dgpu.errors import (
DGPUInferenceCancelled,
)
from skynet.dgpu.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
from skynet.dgpu.utils import (
Pipeline,
crop_image,
convert_from_cv2_to_image,
convert_from_image_to_cv2,
convert_from_img_to_bytes,
pipeline_for
)
def prepare_params_for_diffuse(
@@ -68,17 +76,21 @@ def prepare_params_for_diffuse(
_model_name: str = ''
_model_mode: str = ''
_model = None
_model: Pipeline | None = None
@cm
def maybe_load_model(name: str, mode: ModelMode):
def maybe_load_model(name: str, mode: ModelMode) -> Generator[Pipeline, None, None]:
if mode == ModelMode.DIFFUSE:
mode = ModelMode.TXT2IMG
global _model_name, _model_mode, _model
config = load_skynet_toml().dgpu
if _model_name != name or _model_mode != mode:
if (
_model_name != name
or
_model_mode != mode
):
# unload model
_model = None
gc.collect()
@@ -94,24 +106,26 @@ def maybe_load_model(name: str, mode: ModelMode):
_model_mode = mode
if torch.cuda.is_available():
logging.debug('memory summary:')
logging.debug('\n' + torch.cuda.memory_summary())
logging.debug(
'memory summary:\n'
f'{torch.cuda.memory_summary()}'
)
yield _model
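
The caching shape above, boiled down to a minimal sketch (loader and names are illustrative, not the project's API):

import gc
from contextlib import contextmanager

_cached_key: tuple[str, str] | None = None
_cached_model = None

def _load_pipeline(name: str, mode: str):
    ...  # stand-in for skynet.dgpu.utils.pipeline_for

@contextmanager
def maybe_load(name: str, mode: str):
    # single-slot cache: keep at most one model resident, reload only
    # when the (name, mode) pair changes
    global _cached_key, _cached_model
    if _cached_key != (name, mode):
        _cached_model = None    # drop the old weights first
        gc.collect()            # reclaim memory before loading new ones
        _cached_model = _load_pipeline(name, mode)
        _cached_key = (name, mode)
    yield _cached_model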
def compute_one(
model,
model: Pipeline,
request_id: int,
method: ModelMode,
params: BodyV0Params,
inputs: list[bytes] = [],
should_cancel = None
should_cancel: Callable[[int], bool] | None = None
):
total_steps = params.step
def inference_step_wakeup(*args, **kwargs):
'''This is a callback function that gets invoked every inference step,
we need to raise an exception here if we need to cancel work
we must raise DGPUInferenceCancelled here if we need to cancel work
'''
step = args[0]
# compat with callback_on_step_end
@@ -122,6 +136,9 @@ def compute_one(
should_raise = False
if should_cancel:
'''Pump the main thread's event loop and evaluate whether we should
keep working on this request, based on the latest network info
like competitors...
'''
should_raise = trio.from_thread.run(should_cancel, request_id)
if should_raise:
@@ -137,60 +154,56 @@ def compute_one(
output_type = params.output_type
output = None
output_hash = None
try:
name = params.model
name = params.model
match method:
case (
ModelMode.DIFFUSE |
ModelMode.TXT2IMG |
ModelMode.IMG2IMG |
ModelMode.INPAINT
):
arguments = prepare_params_for_diffuse(
params, method, inputs)
prompt, guidance, step, seed, extra_params = arguments
match method:
case (
ModelMode.DIFFUSE |
ModelMode.TXT2IMG |
ModelMode.IMG2IMG |
ModelMode.INPAINT
):
arguments = prepare_params_for_diffuse(
params, method, inputs)
prompt, guidance, step, seed, extra_params = arguments
if 'flux' in name.lower():
extra_params['callback_on_step_end'] = inference_step_wakeup
if 'flux' in name.lower():
extra_params['callback_on_step_end'] = inference_step_wakeup
else:
extra_params['callback'] = inference_step_wakeup
extra_params['callback_steps'] = 1
else:
extra_params['callback'] = inference_step_wakeup
extra_params['callback_steps'] = 1
output = model(
prompt,
guidance_scale=guidance,
num_inference_steps=step,
generator=seed,
**extra_params
).images[0]
output = model(
prompt,
guidance_scale=guidance,
num_inference_steps=step,
generator=seed,
**extra_params
).images[0]
output_binary = b''
match output_type:
case 'png':
output_binary = convert_from_img_to_bytes(output)
output_binary = b''
match output_type:
case 'png':
output_binary = convert_from_img_to_bytes(output)
case _:
raise DGPUComputeError(f'Unsupported output type: {output_type}')
case _:
raise DGPUComputeError(f'Unsupported output type: {output_type}')
output_hash = sha256(output_binary).hexdigest()
output_hash = sha256(output_binary).hexdigest()
case 'upscale':
input_img = inputs[0].convert('RGB')
up_img, _ = model.enhance(
convert_from_image_to_cv2(input_img), outscale=4)
case 'upscale':
input_img = inputs[0].convert('RGB')
up_img, _ = model.enhance(
convert_from_image_to_cv2(input_img), outscale=4)
output = convert_from_cv2_to_image(up_img)
output = convert_from_cv2_to_image(up_img)
output_binary = convert_from_img_to_bytes(output)
output_hash = sha256(output_binary).hexdigest()
output_binary = convert_from_img_to_bytes(output)
output_hash = sha256(output_binary).hexdigest()
case _:
raise DGPUComputeError('Unsupported compute method')
except BaseException as err:
raise DGPUComputeError(str(err)) from err
case _:
raise DGPUComputeError('Unsupported compute method')
maybe_update_tui(lambda tui: tui.set_status(''))
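
Stepping back, the cancellation plumbing in compute_one reduces to this pattern; the sketch assumes should_cancel is an async predicate owned by the trio loop and uses a stand-in exception class:

import trio

class InferenceCancelled(Exception):
    '''stand-in for skynet.dgpu.errors.DGPUInferenceCancelled'''

def make_step_callback(request_id: int, should_cancel):
    def step_wakeup(step: int, *args, **kwargs):
        # runs on the inference thread once per step: hop back into the
        # trio event loop, ask whether this request should be abandoned
        # (outbid by competitors, cancelled, etc.) and bail out if so
        if should_cancel and trio.from_thread.run(should_cancel, request_id):
            raise InferenceCancelled(
                f'request {request_id} cancelled at step {step}')
    return step_wakeup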

View File

@@ -132,6 +132,11 @@ async def maybe_serve_one(
output_hash = None
match config.backend:
case 'sync-on-thread':
'''Block this task until inference completes; pass the
state_mngr.should_cancel_work predicate as the inference_step_wakeup
callback invoked by torch on each inference step. It uses
trio.from_thread to unblock the main thread and pump the event loop.
'''
output_hash, output = await trio.to_thread.run_sync(
partial(
compute_one,
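
Filled out, the sync-on-thread dispatch above looks roughly like this; the request and state-manager attribute names are assumptions drawn from the docstring, not verified against the repo:

from functools import partial
import trio

from skynet.dgpu.compute import compute_one  # module path assumed

async def serve_sync_on_thread(request, params, model, inputs, state_mngr):
    # offload the blocking torch call to a worker thread; compute_one's
    # step callback re-enters this event loop via trio.from_thread, so
    # the loop keeps pumping while inference runs and can cancel it
    output_hash, output = await trio.to_thread.run_sync(
        partial(
            compute_one,
            model,
            request.id,          # assumed attribute
            request.method,      # assumed attribute
            params,
            inputs=inputs,
            should_cancel=state_mngr.should_cancel_work,
        )
    )
    return output_hash, output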

View File

@@ -141,10 +141,6 @@ class WorkerMonitor:
self.progress_bar.current = current
pct = 0
if self.progress_bar.done != 0:
pct = int((self.progress_bar.current / self.progress_bar.done) * 100)
def update_requests(self, new_requests):
"""
Replace the data in the existing ListBox with new request widgets.

View File

@@ -75,6 +75,9 @@ class DummyPB:
def update(self):
...
type Pipeline = DiffusionPipeline | RealESRGANer
@torch.compiler.disable
@contextmanager
def dummy_progress_bar(*args, **kwargs):
@@ -90,7 +93,7 @@ def pipeline_for(
mode: str,
mem_fraction: float = 1.0,
cache_dir: str | None = None
) -> DiffusionPipeline:
) -> Pipeline:
diffusers.utils.logging.disable_progress_bar()
logging.info(f'pipeline_for {model} {mode}')
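
One portability note on the new alias: `type Pipeline = DiffusionPipeline | RealESRGANer` is PEP 695 syntax and requires Python 3.12+; on older interpreters the same alias can be spelled with typing.TypeAlias:

from typing import TypeAlias

from diffusers import DiffusionPipeline
from realesrgan import RealESRGANer

# pre-3.12 spelling of: type Pipeline = DiffusionPipeline | RealESRGANer
Pipeline: TypeAlias = DiffusionPipeline | RealESRGANer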

View File

@@ -13,7 +13,7 @@ class ModelMode(StrEnum):
class ModelDesc(Struct):
short: str # short unique name
mem: float # recomended mem
mem: float # recommended mem in GB
attrs: dict # additional mode specific attrs
tags: list[ModelMode]
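
For a sense of the fields, a hypothetical entry (values invented for illustration, not taken from the project's model table):

desc = ModelDesc(
    short='sdxl',                                   # short unique name
    mem=8.5,                                        # recommended mem in GB
    attrs={'size': {'w': 1024, 'h': 1024}},         # made-up mode specific attrs
    tags=[ModelMode.TXT2IMG, ModelMode.IMG2IMG],
)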