Mostly minor typing and comment changes remaining from fomos' re-review; the only big change is the removal of the BaseException catch inside compute_one

structify
Guillermo Rodriguez 2025-02-18 15:47:05 -03:00
parent 1dd2a8ed89
commit 6a991561de
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
7 changed files with 80 additions and 71 deletions
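Note on the headline change: under trio, a blanket except BaseException also intercepts trio.Cancelled (and KeyboardInterrupt), so re-raising it as DGPUComputeError would swallow cancellation and confuse trio's cancel scopes. A minimal sketch of the removed anti-pattern, reduced from the compute_one diff below:

import trio

class DGPUComputeError(Exception):
    ...

async def worker():
    try:
        await trio.sleep(10)  # stand-in for the blocking inference work
    except BaseException as err:
        # BAD: trio.Cancelled derives from BaseException, so this turns a
        # cancellation into an ordinary exception and breaks trio's
        # cancel-scope bookkeeping
        raise DGPUComputeError(str(err)) from err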

View File

@@ -1,16 +1,8 @@
 import json
 import logging
-import random
-from functools import partial

 import click
-from leap.protocol import (
-    Name,
-    Asset,
-)

 from .config import (
     load_skynet_toml,
     set_hf_vars,
@@ -49,7 +41,7 @@ def txt2img(*args, **kwargs):
     config = load_skynet_toml()
     set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
-    utils.txt2img(hf_token, **kwargs)
+    utils.txt2img(config.dgpu.hf_token, **kwargs)


 @click.command()
@@ -74,7 +66,7 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed):
     config = load_skynet_toml()
     set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
     utils.img2img(
-        hf_token,
+        config.dgpu.hf_token,
         model=model,
         prompt=prompt,
         img_path=input,
@@ -102,7 +94,7 @@ def inpaint(model, prompt, input, mask, output, strength, guidance, steps, seed)
     config = load_skynet_toml()
     set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
     utils.inpaint(
-        hf_token,
+        config.dgpu.hf_token,
        model=model,
         prompt=prompt,
         img_path=input,

View File

@@ -5,7 +5,7 @@ import trio
 import urwid

 from skynet.config import Config
-from skynet.dgpu.tui import init_tui
+from skynet.dgpu.tui import init_tui, WorkerMonitor
 from skynet.dgpu.daemon import dgpu_serve_forever
 from skynet.dgpu.network import NetConnector, maybe_open_contract_state_mngr
@@ -15,7 +15,7 @@ async def open_worker(config: Config):
     # suppress logs from httpx (logs url + status after every query)
     logging.getLogger("httpx").setLevel(logging.WARNING)

-    tui = None
+    tui: WorkerMonitor | None = None
     if config.tui:
         tui = init_tui(config)

View File

@@ -7,6 +7,7 @@ import gc
 import logging

 from hashlib import sha256
+from typing import Callable, Generator
 from contextlib import contextmanager as cm

 import trio
@@ -20,7 +21,14 @@ from skynet.dgpu.errors import (
     DGPUInferenceCancelled,
 )

-from skynet.dgpu.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
+from skynet.dgpu.utils import (
+    Pipeline,
+    crop_image,
+    convert_from_cv2_to_image,
+    convert_from_image_to_cv2,
+    convert_from_img_to_bytes,
+    pipeline_for
+)


 def prepare_params_for_diffuse(
@@ -68,17 +76,21 @@ def prepare_params_for_diffuse(
 _model_name: str = ''
 _model_mode: str = ''
-_model = None
+_model: Pipeline | None = None


 @cm
-def maybe_load_model(name: str, mode: ModelMode):
+def maybe_load_model(name: str, mode: ModelMode) -> Generator[Pipeline, None, None]:
     if mode == ModelMode.DIFFUSE:
         mode = ModelMode.TXT2IMG

     global _model_name, _model_mode, _model
     config = load_skynet_toml().dgpu

-    if _model_name != name or _model_mode != mode:
+    if (
+        _model_name != name
+        or
+        _model_mode != mode
+    ):
         # unload model
         _model = None
         gc.collect()
@@ -94,24 +106,26 @@ def maybe_load_model(name: str, mode: ModelMode):
         _model_mode = mode

     if torch.cuda.is_available():
-        logging.debug('memory summary:')
-        logging.debug('\n' + torch.cuda.memory_summary())
+        logging.debug(
+            'memory summary:\n'
+            f'{torch.cuda.memory_summary()}'
+        )

     yield _model


 def compute_one(
-    model,
+    model: Pipeline,
     request_id: int,
     method: ModelMode,
     params: BodyV0Params,
     inputs: list[bytes] = [],
-    should_cancel = None
+    should_cancel: Callable[[int, ...], dict] = None
 ):
     total_steps = params.step
     def inference_step_wakeup(*args, **kwargs):
         '''This is a callback function that gets invoked every inference step,
-        we need to raise an exception here if we need to cancel work
+        we must raise DGPUInferenceCancelled here if we need to cancel work
         '''
         step = args[0]
         # compat with callback_on_step_end
@@ -122,6 +136,9 @@ def compute_one(
         should_raise = False
         if should_cancel:
+            '''Pump main thread event loop, evaluate if we should keep working
+            on this request, based on latest network info like competitors...
+            '''
             should_raise = trio.from_thread.run(should_cancel, request_id)

         if should_raise:
@@ -137,7 +154,6 @@ def compute_one(
     output_type = params.output_type
     output = None
     output_hash = None
-    try:
     name = params.model
     match method:
@@ -189,9 +205,6 @@ def compute_one(
         case _:
             raise DGPUComputeError('Unsupported compute method')

-    except BaseException as err:
-        raise DGPUComputeError(str(err)) from err
-
     maybe_update_tui(lambda tui: tui.set_status(''))
     return output_hash, output
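The per-step callback above is how the blocking pipeline call gets aborted mid-run: the pipeline invokes the callback every inference step, and raising inside it unwinds out of the compute. A stripped-down sketch of that pattern (hypothetical names, not the repo's exact code):

class InferenceCancelled(Exception):
    ...

def make_step_callback(should_cancel, request_id: int):
    def on_step(step: int, *args, **kwargs):
        # invoked by the pipeline once per inference step
        if should_cancel(request_id):
            raise InferenceCancelled
        # callback_on_step_end-style callbacks return the kwargs dict
        return kwargs
    return on_step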

View File

@@ -132,6 +132,11 @@ async def maybe_serve_one(
     output_hash = None
     match config.backend:
         case 'sync-on-thread':
+            '''Block this task until inference completes, pass
+            state_mngr.should_cancel_work predicate as the inference_step_wakeup cb
+            used by torch each step of the inference, it will use a
+            trio.from_thread to unblock the main thread and pump the event loop
+            '''
             output_hash, output = await trio.to_thread.run_sync(
                 partial(
                     compute_one,
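The new comment documents a two-sided pattern: the daemon pushes the blocking compute_one onto a worker thread with trio.to_thread.run_sync, and inside that thread each inference step re-enters the trio event loop via trio.from_thread.run to evaluate the cancel predicate. A self-contained sketch under hypothetical names:

import trio

async def should_cancel(request_id: int) -> bool:
    # stand-in for the latest-network-info check (competitors etc.)
    return False

def blocking_inference(request_id: int) -> bytes:
    for step in range(30):
        # hops back to the main trio thread, pumping its event loop
        if trio.from_thread.run(should_cancel, request_id):
            raise RuntimeError('cancelled')
    return b'result'

async def serve_one(request_id: int) -> bytes:
    return await trio.to_thread.run_sync(blocking_inference, request_id)

trio.run(serve_one, 1)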

View File

@@ -141,10 +141,6 @@ class WorkerMonitor:
         self.progress_bar.current = current

-        pct = 0
-        if self.progress_bar.done != 0:
-            pct = int((self.progress_bar.current / self.progress_bar.done) * 100)
-
     def update_requests(self, new_requests):
         """
         Replace the data in the existing ListBox with new request widgets.

View File

@@ -75,6 +75,9 @@ class DummyPB:
     def update(self):
         ...

+
+type Pipeline = DiffusionPipeline | RealESRGANer
+
 @torch.compiler.disable
 @contextmanager
 def dummy_progress_bar(*args, **kwargs):
@@ -90,7 +93,7 @@ def pipeline_for(
     mode: str,
     mem_fraction: float = 1.0,
     cache_dir: str | None = None
-) -> DiffusionPipeline:
+) -> Pipeline:
     diffusers.utils.logging.disable_progress_bar()
     logging.info(f'pipeline_for {model} {mode}')
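The new PEP 695 type alias unions the two backend classes pipeline_for can return (RealESRGANer for upscaling, DiffusionPipeline for everything else), so call sites narrow with isinstance before touching backend-specific API. A small sketch, assuming the diffusers/realesrgan imports the file already uses:

from diffusers import DiffusionPipeline
from realesrgan import RealESRGANer

type Pipeline = DiffusionPipeline | RealESRGANer

def describe(pipe: Pipeline) -> str:
    # isinstance checks go against the concrete classes, not the alias
    if isinstance(pipe, RealESRGANer):
        return 'upscaler'
    return 'diffusion'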

View File

@@ -13,7 +13,7 @@ class ModelMode(StrEnum):

 class ModelDesc(Struct):
     short: str # short unique name
-    mem: float # recomended mem
+    mem: float # recomended mem in gb
     attrs: dict # additional mode specific attrs
     tags: list[ModelMode]
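For reference, a hypothetical ModelDesc entry showing how the fields fit together (values are illustrative, not from the repo):

ModelDesc(
    short='sdxl',                 # short unique name
    mem=8.5,                      # recommended mem in GB
    attrs={'size': {'w': 1024, 'h': 1024}},
    tags=[ModelMode.TXT2IMG]
)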