mirror of https://github.com/skygpu/skynet.git

Mostly minor typing and comment changes remaining from fomo's re-review; the only big change is the removal of the BaseException catch inside compute_one.

parent 1dd2a8ed89
commit 6a991561de
@@ -1,16 +1,8 @@
 import json
 import logging
-import random
 
-from functools import partial
-
 import click
 
-from leap.protocol import (
-    Name,
-    Asset,
-)
-
 from .config import (
     load_skynet_toml,
     set_hf_vars,
@@ -49,7 +41,7 @@ def txt2img(*args, **kwargs):
 
     config = load_skynet_toml()
     set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
-    utils.txt2img(hf_token, **kwargs)
+    utils.txt2img(config.dgpu.hf_token, **kwargs)
 
 
 @click.command()
@@ -74,7 +66,7 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed):
     config = load_skynet_toml()
     set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
     utils.img2img(
-        hf_token,
+        config.dgpu.hf_token,
         model=model,
         prompt=prompt,
         img_path=input,
@@ -102,7 +94,7 @@ def inpaint(model, prompt, input, mask, output, strength, guidance, steps, seed)
     config = load_skynet_toml()
     set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
     utils.inpaint(
-        hf_token,
+        config.dgpu.hf_token,
         model=model,
         prompt=prompt,
         img_path=input,
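
All three CLI hunks above fix the same latent bug: hf_token was an undefined name left over from an earlier config scheme, so invoking txt2img, img2img, or inpaint would raise NameError at call time; the token now comes off the loaded config. A minimal sketch of the before/after (load_config and do_work are illustrative stand-ins, not skynet's API):

    # Sketch of the NameError fixed above; names here are stand-ins.
    def load_config() -> dict:
        return {'dgpu': {'hf_token': 'hf_xxx'}}  # mimics load_skynet_toml()

    def do_work(token: str, **kwargs):
        print(f'running with token {token!r}')

    def broken_cmd(**kwargs):
        config = load_config()
        do_work(hf_token, **kwargs)  # NameError: hf_token is not defined

    def fixed_cmd(**kwargs):
        config = load_config()
        do_work(config['dgpu']['hf_token'], **kwargs)  # read it off the config

    fixed_cmd()
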
@@ -5,7 +5,7 @@ import trio
 import urwid
 
 from skynet.config import Config
-from skynet.dgpu.tui import init_tui
+from skynet.dgpu.tui import init_tui, WorkerMonitor
 from skynet.dgpu.daemon import dgpu_serve_forever
 from skynet.dgpu.network import NetConnector, maybe_open_contract_state_mngr
 
@@ -15,7 +15,7 @@ async def open_worker(config: Config):
     # suppress logs from httpx (logs url + status after every query)
     logging.getLogger("httpx").setLevel(logging.WARNING)
 
-    tui = None
+    tui: WorkerMonitor | None = None
     if config.tui:
         tui = init_tui(config)
 
@@ -7,6 +7,7 @@ import gc
 import logging
 
 from hashlib import sha256
+from typing import Callable, Generator
 from contextlib import contextmanager as cm
 
 import trio
@@ -20,7 +21,14 @@ from skynet.dgpu.errors import (
     DGPUInferenceCancelled,
 )
 
-from skynet.dgpu.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
+from skynet.dgpu.utils import (
+    Pipeline,
+    crop_image,
+    convert_from_cv2_to_image,
+    convert_from_image_to_cv2,
+    convert_from_img_to_bytes,
+    pipeline_for
+)
 
 
 def prepare_params_for_diffuse(
@@ -68,17 +76,21 @@ def prepare_params_for_diffuse(
 
 _model_name: str = ''
 _model_mode: str = ''
-_model = None
+_model: Pipeline | None = None
 
 @cm
-def maybe_load_model(name: str, mode: ModelMode):
+def maybe_load_model(name: str, mode: ModelMode) -> Generator[Pipeline, None, None]:
     if mode == ModelMode.DIFFUSE:
         mode = ModelMode.TXT2IMG
 
     global _model_name, _model_mode, _model
     config = load_skynet_toml().dgpu
 
-    if _model_name != name or _model_mode != mode:
+    if (
+        _model_name != name
+        or
+        _model_mode != mode
+    ):
         # unload model
         _model = None
         gc.collect()
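
The new -> Generator[Pipeline, None, None] annotation is the standard way to type a @contextmanager-decorated function: the annotation describes the generator the function returns, with Pipeline as the yield type and no send/return values. A minimal standalone sketch of the same module-level caching pattern (stand-in types and values, not the repo's code):

    from contextlib import contextmanager
    from typing import Generator

    Resource = str                    # stand-in for the Pipeline union type

    _cached_name: str = ''
    _cached: Resource | None = None   # module-level cache, like _model above

    @contextmanager
    def maybe_load(name: str) -> Generator[Resource, None, None]:
        global _cached_name, _cached
        if _cached_name != name:
            _cached = f'loaded:{name}'  # (re)load only when the request changes
            _cached_name = name
        yield _cached                   # yield type matches Generator[...] above

    with maybe_load('sdxl') as res:
        print(res)                      # -> loaded:sdxl
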
@@ -94,24 +106,26 @@ def maybe_load_model(name: str, mode: ModelMode):
         _model_mode = mode
 
         if torch.cuda.is_available():
-            logging.debug('memory summary:')
-            logging.debug('\n' + torch.cuda.memory_summary())
+            logging.debug(
+                'memory summary:\n'
+                f'{torch.cuda.memory_summary()}'
+            )
 
     yield _model
 
 
 def compute_one(
-    model,
+    model: Pipeline,
     request_id: int,
     method: ModelMode,
     params: BodyV0Params,
     inputs: list[bytes] = [],
-    should_cancel = None
+    should_cancel: Callable[[int, ...], dict] = None
 ):
     total_steps = params.step
     def inference_step_wakeup(*args, **kwargs):
         '''This is a callback function that gets invoked every inference step,
-        we need to raise an exception here if we need to cancel work
+        we must raise DGPUInferenceCancelled here if we need to cancel work
         '''
         step = args[0]
         # compat with callback_on_step_end
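
The "# compat with callback_on_step_end" comment refers to diffusers supporting two callback conventions: the legacy per-step callback(step, timestep, latents) and the newer callback_on_step_end(pipe, step, timestep, callback_kwargs), where the first positional argument is the pipeline rather than the step index. A hedged sketch of a shim that accepts both shapes (illustrative, not the repo's exact code):

    # Shim usable under either diffusers callback convention (sketch).
    def make_wakeup(total_steps: int, is_cancelled):
        def wakeup(*args, **kwargs):
            step = args[0]
            if not isinstance(step, int):
                # callback_on_step_end passes the pipeline first: (pipe, step, ...)
                step = args[1]

            if is_cancelled():
                # stand-in for raising DGPUInferenceCancelled
                raise RuntimeError('inference cancelled')

            print(f'step {step + 1}/{total_steps}')
            # the newer hook expects a dict of tensors back
            return kwargs.get('callback_kwargs', {})

        return wakeup

    cb = make_wakeup(3, lambda: False)
    cb(0, 999, None)   # legacy style: (step, timestep, latents)
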
@@ -122,6 +136,9 @@ def compute_one(
 
         should_raise = False
         if should_cancel:
+            '''Pump main thread event loop, evaluate if we should keep working
+            on this request, based on latest network info like competitors...
+            '''
             should_raise = trio.from_thread.run(should_cancel, request_id)
 
         if should_raise:
@@ -137,7 +154,6 @@ def compute_one(
     output_type = params.output_type
     output = None
     output_hash = None
-    try:
     name = params.model
 
     match method:
@@ -189,9 +205,6 @@ def compute_one(
         case _:
             raise DGPUComputeError('Unsupported compute method')
 
-    except BaseException as err:
-        raise DGPUComputeError(str(err)) from err
-
     maybe_update_tui(lambda tui: tui.set_status(''))
 
     return output_hash, output
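
Removing the blanket except BaseException is the one behavioural change in this commit: that clause also caught trio.Cancelled and KeyboardInterrupt (both BaseException subclasses) and re-raised them as DGPUComputeError, masking cancellation as an ordinary compute failure. A small self-contained sketch of the failure mode (stand-in exception classes, not trio's):

    # Sketch: why re-wrapping BaseException is harmful (stand-in classes).
    class Cancelled(BaseException):
        '''Like trio.Cancelled: must reach the scheduler untouched.'''

    class ComputeError(Exception):
        '''Stand-in for DGPUComputeError.'''

    def worker_old():
        try:
            raise Cancelled()  # e.g. cancellation firing mid-inference
        except BaseException as err:
            # old behaviour: cancellation is swallowed and re-raised as a
            # plain Exception, so the cancellation machinery never sees it
            raise ComputeError(str(err)) from err

    def worker_new():
        raise Cancelled()      # new behaviour: propagates untouched

    try:
        worker_old()
    except ComputeError:
        print('cancellation was masked as a compute error')  # this prints
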
@@ -132,6 +132,11 @@ async def maybe_serve_one(
     output_hash = None
     match config.backend:
         case 'sync-on-thread':
+            '''Block this task until inference completes, pass
+            state_mngr.should_cancel_work predicate as the inference_step_wakeup cb
+            used by torch each step of the inference, it will use a
+            trio.from_thread to unblock the main thread and pump the event loop
+            '''
             output_hash, output = await trio.to_thread.run_sync(
                 partial(
                     compute_one,
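
The docstring added here describes trio's standard thread bridge: the blocking inference runs under trio.to_thread.run_sync, and each per-step wakeup inside the worker thread hops back into the trio loop with trio.from_thread.run to evaluate the cancel predicate. A self-contained sketch of the round trip (function names are illustrative):

    import trio
    from functools import partial

    async def should_cancel(request_id: int) -> bool:
        # stand-in predicate; real code checks fresh network/contract state
        await trio.sleep(0)
        return request_id == 42

    def blocking_compute(request_id: int, steps: int) -> str:
        for step in range(steps):
            # hop back into the trio loop from this worker thread
            if trio.from_thread.run(should_cancel, request_id):
                raise RuntimeError('cancelled')  # stand-in for DGPUInferenceCancelled
        return f'request {request_id} done after {steps} steps'

    async def main():
        # awaiting the thread keeps the main loop free to run other tasks
        result = await trio.to_thread.run_sync(
            partial(blocking_compute, 7, 3)
        )
        print(result)

    trio.run(main)
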
@@ -141,10 +141,6 @@ class WorkerMonitor:
 
         self.progress_bar.current = current
 
-        pct = 0
-        if self.progress_bar.done != 0:
-            pct = int((self.progress_bar.current / self.progress_bar.done) * 100)
-
     def update_requests(self, new_requests):
         """
         Replace the data in the existing ListBox with new request widgets.
@@ -75,6 +75,9 @@ class DummyPB:
     def update(self):
         ...
 
+
+type Pipeline = DiffusionPipeline | RealESRGANer
+
 @torch.compiler.disable
 @contextmanager
 def dummy_progress_bar(*args, **kwargs):
@@ -90,7 +93,7 @@ def pipeline_for(
     mode: str,
     mem_fraction: float = 1.0,
     cache_dir: str | None = None
-) -> DiffusionPipeline:
+) -> Pipeline:
     diffusers.utils.logging.disable_progress_bar()
 
     logging.info(f'pipeline_for {model} {mode}')
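
The type Pipeline = DiffusionPipeline | RealESRGANer statement is PEP 695 syntax, added in Python 3.12, so this change quietly assumes a 3.12+ interpreter; on older Pythons the same alias is spelled as a plain assignment. A sketch of both spellings with stand-in classes:

    # Stand-in classes; the real ones come from diffusers and realesrgan.
    class DiffusionPipeline: ...
    class RealESRGANer: ...

    # PEP 695 spelling, Python >= 3.12 (lazily evaluated alias):
    # type Pipeline = DiffusionPipeline | RealESRGANer

    # pre-3.12 equivalent, evaluated eagerly at definition time:
    Pipeline = DiffusionPipeline | RealESRGANer

    def pipeline_for(model: str) -> Pipeline:
        # upscale models get the ESRGAN runner, everything else a diffusion pipe
        return RealESRGANer() if 'esrgan' in model else DiffusionPipeline()

    print(type(pipeline_for('realesrgan-x4')).__name__)  # RealESRGANer
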
@@ -13,7 +13,7 @@ class ModelMode(StrEnum):
 
 class ModelDesc(Struct):
     short: str  # short unique name
-    mem: float  # recomended mem
+    mem: float  # recomended mem in gb
     attrs: dict  # additional mode specific attrs
     tags: list[ModelMode]
 
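
ModelDesc subclasses Struct, which here appears to be msgspec's typed record class (an assumption; only the base name is visible in the hunk). A hedged usage sketch with illustrative values, showing why pinning the unit in the mem comment matters:

    # Hedged sketch assuming Struct is msgspec.Struct; values are illustrative.
    from enum import StrEnum
    from msgspec import Struct

    class ModelMode(StrEnum):
        TXT2IMG = 'txt2img'
        IMG2IMG = 'img2img'

    class ModelDesc(Struct):
        short: str            # short unique name
        mem: float            # recommended mem in gb
        attrs: dict           # additional mode specific attrs
        tags: list[ModelMode]

    desc = ModelDesc(
        short='sdxl',
        mem=12.0,             # unambiguous now that the comment says "in gb"
        attrs={'steps': 28},
        tags=[ModelMode.TXT2IMG, ModelMode.IMG2IMG],
    )
    print(desc.short, desc.mem)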