mirror of https://github.com/skygpu/skynet.git
Mostly minor typing and comment changes remaining from fomos' re-review; the only big change is removing the BaseException catch inside compute_one
parent 1dd2a8ed89
commit 6a991561de
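The headline change is dropping the blanket except BaseException inside compute_one. For context, a minimal self-contained sketch of why such a catch is a problem, using stand-in exception classes (the real hierarchy of DGPUComputeError / DGPUInferenceCancelled in skynet may differ, this is an assumption for illustration):

# Minimal sketch with stand-in classes (not the real skynet types; the
# exception hierarchy here is assumed for illustration only).

class DGPUComputeError(Exception):
    ...

class DGPUInferenceCancelled(BaseException):
    ...

def step_callback():
    # the per-step callback raises to abort an in-flight request
    raise DGPUInferenceCancelled()

def with_blanket_catch():
    try:
        step_callback()
    except BaseException as err:
        # old behaviour: cancellation (and even KeyboardInterrupt) gets
        # re-wrapped, so the caller can't tell "cancelled" from "failed"
        raise DGPUComputeError(str(err)) from err

def without_blanket_catch():
    # new behaviour: DGPUInferenceCancelled propagates untouched
    step_callback()

if __name__ == '__main__':
    try:
        with_blanket_catch()
    except DGPUComputeError:
        print('blanket catch: cancellation masked as a compute error')

    try:
        without_blanket_catch()
    except DGPUInferenceCancelled:
        print('no blanket catch: cancellation visible to the caller')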
@@ -1,16 +1,8 @@
import json
import logging
import random

from functools import partial

import click

from leap.protocol import (
    Name,
    Asset,
)

from .config import (
    load_skynet_toml,
    set_hf_vars,
@@ -49,7 +41,7 @@ def txt2img(*args, **kwargs):
    config = load_skynet_toml()
    set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
-    utils.txt2img(hf_token, **kwargs)
+    utils.txt2img(config.dgpu.hf_token, **kwargs)


@click.command()
@@ -74,7 +66,7 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed):
    config = load_skynet_toml()
    set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
    utils.img2img(
-        hf_token,
+        config.dgpu.hf_token,
        model=model,
        prompt=prompt,
        img_path=input,
@@ -102,7 +94,7 @@ def inpaint(model, prompt, input, mask, output, strength, guidance, steps, seed)
    config = load_skynet_toml()
    set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
    utils.inpaint(
-        hf_token,
+        config.dgpu.hf_token,
        model=model,
        prompt=prompt,
        img_path=input,
@@ -5,7 +5,7 @@ import trio
import urwid

from skynet.config import Config
-from skynet.dgpu.tui import init_tui
+from skynet.dgpu.tui import init_tui, WorkerMonitor
from skynet.dgpu.daemon import dgpu_serve_forever
from skynet.dgpu.network import NetConnector, maybe_open_contract_state_mngr

@@ -15,7 +15,7 @@ async def open_worker(config: Config):
    # suppress logs from httpx (logs url + status after every query)
    logging.getLogger("httpx").setLevel(logging.WARNING)

-    tui = None
+    tui: WorkerMonitor | None = None
    if config.tui:
        tui = init_tui(config)

@@ -7,6 +7,7 @@ import gc
import logging

from hashlib import sha256
+from typing import Callable, Generator
from contextlib import contextmanager as cm

import trio
@@ -20,7 +21,14 @@ from skynet.dgpu.errors import (
    DGPUInferenceCancelled,
)

-from skynet.dgpu.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
+from skynet.dgpu.utils import (
+    Pipeline,
+    crop_image,
+    convert_from_cv2_to_image,
+    convert_from_image_to_cv2,
+    convert_from_img_to_bytes,
+    pipeline_for
+)


def prepare_params_for_diffuse(
@@ -68,17 +76,21 @@ def prepare_params_for_diffuse(

_model_name: str = ''
_model_mode: str = ''
-_model = None
+_model: Pipeline | None = None

@cm
-def maybe_load_model(name: str, mode: ModelMode):
+def maybe_load_model(name: str, mode: ModelMode) -> Generator[Pipeline, None, None]:
    if mode == ModelMode.DIFFUSE:
        mode = ModelMode.TXT2IMG

    global _model_name, _model_mode, _model
    config = load_skynet_toml().dgpu

-    if _model_name != name or _model_mode != mode:
+    if (
+        _model_name != name
+        or
+        _model_mode != mode
+    ):
        # unload model
        _model = None
        gc.collect()
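maybe_load_model keeps a single pipeline cached in module globals and only reloads when the requested name or mode changes, freeing the old one first. A rough sketch of that single-slot cache pattern (load_pipeline is a hypothetical stand-in for skynet's pipeline_for; the typing mirrors the Generator annotation added above):

import gc
from contextlib import contextmanager
from typing import Generator

_cached_key: tuple[str, str] | None = None
_cached_pipe: object | None = None

def load_pipeline(name: str, mode: str) -> object:
    # hypothetical stand-in loader; the real code calls pipeline_for(...)
    return object()

@contextmanager
def maybe_load(name: str, mode: str) -> Generator[object, None, None]:
    global _cached_key, _cached_pipe
    key = (name, mode)
    if _cached_key != key:
        # drop the old pipeline and collect so its memory is released
        # before the replacement is loaded
        _cached_pipe = None
        gc.collect()
        _cached_pipe = load_pipeline(name, mode)
        _cached_key = key
    yield _cached_pipe

# usage: consecutive requests for the same (name, mode) reuse the cache
with maybe_load('some-model', 'txt2img') as pipe:
    pass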
@@ -94,24 +106,26 @@ def maybe_load_model(name: str, mode: ModelMode):
        _model_mode = mode

        if torch.cuda.is_available():
-            logging.debug('memory summary:')
-            logging.debug('\n' + torch.cuda.memory_summary())
+            logging.debug(
+                'memory summary:\n'
+                f'{torch.cuda.memory_summary()}'
+            )

    yield _model


def compute_one(
-    model,
+    model: Pipeline,
    request_id: int,
    method: ModelMode,
    params: BodyV0Params,
    inputs: list[bytes] = [],
-    should_cancel = None
+    should_cancel: Callable[[int, ...], dict] = None
):
    total_steps = params.step
    def inference_step_wakeup(*args, **kwargs):
        '''This is a callback function that gets invoked every inference step,
-        we need to raise an exception here if we need to cancel work
+        we must raise DGPUInferenceCancelled here if we need to cancel work
        '''
        step = args[0]
        # compat with callback_on_step_end
@@ -122,6 +136,9 @@ def compute_one(

        should_raise = False
        if should_cancel:
+            '''Pump main thread event loop, evaluate if we should keep working
+            on this request, based on latest network info like competitors...
+            '''
            should_raise = trio.from_thread.run(should_cancel, request_id)

        if should_raise:
@@ -137,60 +154,56 @@ def compute_one(
    output_type = params.output_type
    output = None
    output_hash = None
-    try:
-        name = params.model
+    name = params.model

-        match method:
-            case (
-                ModelMode.DIFFUSE |
-                ModelMode.TXT2IMG |
-                ModelMode.IMG2IMG |
-                ModelMode.INPAINT
-            ):
-                arguments = prepare_params_for_diffuse(
-                    params, method, inputs)
-                prompt, guidance, step, seed, extra_params = arguments
+    match method:
+        case (
+            ModelMode.DIFFUSE |
+            ModelMode.TXT2IMG |
+            ModelMode.IMG2IMG |
+            ModelMode.INPAINT
+        ):
+            arguments = prepare_params_for_diffuse(
+                params, method, inputs)
+            prompt, guidance, step, seed, extra_params = arguments

-                if 'flux' in name.lower():
-                    extra_params['callback_on_step_end'] = inference_step_wakeup
+            if 'flux' in name.lower():
+                extra_params['callback_on_step_end'] = inference_step_wakeup

-                else:
-                    extra_params['callback'] = inference_step_wakeup
-                    extra_params['callback_steps'] = 1
+            else:
+                extra_params['callback'] = inference_step_wakeup
+                extra_params['callback_steps'] = 1

-                output = model(
-                    prompt,
-                    guidance_scale=guidance,
-                    num_inference_steps=step,
-                    generator=seed,
-                    **extra_params
-                ).images[0]
+            output = model(
+                prompt,
+                guidance_scale=guidance,
+                num_inference_steps=step,
+                generator=seed,
+                **extra_params
+            ).images[0]

-                output_binary = b''
-                match output_type:
-                    case 'png':
-                        output_binary = convert_from_img_to_bytes(output)
+            output_binary = b''
+            match output_type:
+                case 'png':
+                    output_binary = convert_from_img_to_bytes(output)

-                    case _:
-                        raise DGPUComputeError(f'Unsupported output type: {output_type}')
+                case _:
+                    raise DGPUComputeError(f'Unsupported output type: {output_type}')

-                output_hash = sha256(output_binary).hexdigest()
+            output_hash = sha256(output_binary).hexdigest()

-            case 'upscale':
-                input_img = inputs[0].convert('RGB')
-                up_img, _ = model.enhance(
-                    convert_from_image_to_cv2(input_img), outscale=4)
+        case 'upscale':
+            input_img = inputs[0].convert('RGB')
+            up_img, _ = model.enhance(
+                convert_from_image_to_cv2(input_img), outscale=4)

-                output = convert_from_cv2_to_image(up_img)
+            output = convert_from_cv2_to_image(up_img)

-                output_binary = convert_from_img_to_bytes(output)
-                output_hash = sha256(output_binary).hexdigest()
+            output_binary = convert_from_img_to_bytes(output)
+            output_hash = sha256(output_binary).hexdigest()

-            case _:
-                raise DGPUComputeError('Unsupported compute method')
-
-    except BaseException as err:
-        raise DGPUComputeError(str(err)) from err
+        case _:
+            raise DGPUComputeError('Unsupported compute method')

    maybe_update_tui(lambda tui: tui.set_status(''))
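The flux / non-flux branch above exists because diffusers has two step-callback conventions; roughly (signatures vary across diffusers versions, so treat this as a sketch):

# Sketch of the two diffusers step-callback shapes that inference_step_wakeup
# has to satisfy; signatures are approximate and version dependent.

def on_step_end(pipe, step: int, timestep, callback_kwargs: dict) -> dict:
    # newer pipelines: passed as callback_on_step_end=..., and expected to
    # return the callback_kwargs dict so the pipeline can continue
    print(f'step {step}')
    return callback_kwargs

def legacy_callback(step: int, timestep, latents) -> None:
    # older pipelines: passed as callback=... together with callback_steps=1
    print(f'step {step}')

# hypothetical usage with an already-loaded pipeline `pipe`:
#   pipe(prompt, callback_on_step_end=on_step_end)
#   pipe(prompt, callback=legacy_callback, callback_steps=1)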
@@ -132,6 +132,11 @@ async def maybe_serve_one(
    output_hash = None
    match config.backend:
        case 'sync-on-thread':
+            '''Block this task until inference completes, pass
+            state_mngr.should_cancel_work predicate as the inference_step_wakeup cb
+            used by torch each step of the inference, it will use a
+            trio.from_thread to unblock the main thread and pump the event loop
+            '''
            output_hash, output = await trio.to_thread.run_sync(
                partial(
                    compute_one,
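The docstring added above describes the trio bridge used by the 'sync-on-thread' backend: the blocking inference runs off the event loop via trio.to_thread.run_sync, and each inference step hops back with trio.from_thread.run to evaluate an async cancellation predicate. A minimal runnable sketch of that round trip with dummy work and predicate (not the skynet functions):

import trio

async def should_cancel(request_id: int) -> bool:
    # async predicate evaluated back on the main trio loop
    # (the real worker consults latest chain/network state here)
    await trio.sleep(0)
    return False

def blocking_inference(request_id: int) -> str:
    for step in range(3):
        # worker thread -> main loop hop, once per "inference step"
        if trio.from_thread.run(should_cancel, request_id):
            raise RuntimeError('cancelled')
    return 'done'

async def main():
    # push the blocking call onto a worker thread so the loop stays live
    result = await trio.to_thread.run_sync(blocking_inference, 42)
    print(result)

trio.run(main)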
@@ -141,10 +141,6 @@ class WorkerMonitor:

        self.progress_bar.current = current

-        pct = 0
-        if self.progress_bar.done != 0:
-            pct = int((self.progress_bar.current / self.progress_bar.done) * 100)
-
    def update_requests(self, new_requests):
        """
        Replace the data in the existing ListBox with new request widgets.
@@ -75,6 +75,9 @@ class DummyPB:
    def update(self):
        ...

+
+type Pipeline = DiffusionPipeline | RealESRGANer
+
@torch.compiler.disable
@contextmanager
def dummy_progress_bar(*args, **kwargs):
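The type Pipeline = ... line above uses the PEP 695 type statement, which requires Python 3.12 or newer; on earlier interpreters the equivalent spelling would be roughly:

# pre-3.12 equivalent of the alias added above
from typing import TypeAlias

from diffusers import DiffusionPipeline
from realesrgan import RealESRGANer

Pipeline: TypeAlias = DiffusionPipeline | RealESRGANer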
@@ -90,7 +93,7 @@ def pipeline_for(
    mode: str,
    mem_fraction: float = 1.0,
    cache_dir: str | None = None
-) -> DiffusionPipeline:
+) -> Pipeline:
    diffusers.utils.logging.disable_progress_bar()

    logging.info(f'pipeline_for {model} {mode}')
@@ -13,7 +13,7 @@ class ModelMode(StrEnum):

class ModelDesc(Struct):
    short: str  # short unique name
-    mem: float  # recomended mem
+    mem: float  # recomended mem in gb
    attrs: dict  # additional mode specific attrs
    tags: list[ModelMode]
