mirror of https://github.com/skygpu/skynet.git

Refactor ModelMngr to be a context manager + function combo

parent: b3dc7c1074
commit: cd028d15e7
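In short: the ModelMngr class is replaced by module-level model state in skynet.dgpu.compute, exposed as a maybe_load_model() context manager plus a plain compute_one() function. A minimal sketch of the resulting call pattern, reconstructed from the hunks below (the model name, request id, and params shown are illustrative, not the full request schema):

    from skynet.dgpu.compute import maybe_load_model, compute_one

    # ensure the requested model is resident, then run one job; the model
    # stays cached in module globals for the next request with the same
    # name/mode, so re-entering the context is cheap
    with maybe_load_model('some-model', 'txt2img'):
        output_hash, output = compute_one(
            0,            # request id (illustrative)
            'txt2img',
            {'model': 'some-model', 'step': 28},  # plus prompt/guidance/seed etc.
            inputs=[]
        )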
@@ -8,7 +8,6 @@ from hypercorn.trio import serve
 from quart_trio import QuartTrio as Quart
 
 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
 from skynet.dgpu.daemon import WorkerDaemon
 from skynet.dgpu.network import NetConnector
 
@@ -48,8 +47,7 @@ async def open_dgpu_node(config: dict) -> None:
     tui = WorkerMonitor()
 
     conn = NetConnector(config, tui=tui)
-    mm = ModelMngr(config, tui=tui)
-    daemon = WorkerDaemon(mm, conn, config, tui=tui)
+    daemon = WorkerDaemon(conn, config, tui=tui)
 
     api: Quart|None = None
     if 'api_bind' in config:
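With this change the daemon no longer owns a ModelMngr instance: model state now lives in skynet.dgpu.compute and is entered per request, as the daemon.py hunks further down show.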
@@ -7,6 +7,7 @@ import gc
 import logging
 
 from hashlib import sha256
+from contextlib import contextmanager as cm
 
 import trio
 import torch
@@ -66,65 +67,59 @@ def prepare_params_for_diffuse(
         _params
     )
 
+_model_name: str = ''
+_model_mode: str = ''
+_model = None
 
-class ModelMngr:
-    '''
-    (AI algo) Model manager for loading models, computing outputs,
-    checking load state, and unloading when no-longer-needed/finished.
-
-    '''
-    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
-        self._tui = tui
-        self.cache_dir = None
-        if 'hf_home' in config:
-            self.cache_dir = config['hf_home']
-
-        self._model_name: str = ''
-        self._model_mode: str = ''
-
-    def log_debug_info(self):
-        logging.debug('memory summary:')
-        logging.debug('\n' + torch.cuda.memory_summary())
-
-    def is_model_loaded(self, name: str, mode: str):
-        if (name == self._model_name and
-            mode == self._model_mode):
-            return True
-
-        return False
-
-    def unload_model(self) -> None:
-        if getattr(self, '_model', None):
-            del self._model
-
-        gc.collect()
-        torch.cuda.empty_cache()
-
-        self._model_name = ''
-        self._model_mode = ''
-
-    def load_model(
-        self,
-        name: str,
-        mode: str
-    ) -> None:
-        logging.info(f'loading model {name}...')
-        self.unload_model()
-
-        self._model = pipeline_for(
-            name, mode, cache_dir=self.cache_dir)
-        self._model_mode = mode
-        self._model_name = name
-        logging.info(f'{name} loaded!')
-        self.log_debug_info()
-
-    def compute_one(
-        self,
+@cm
+def maybe_load_model(name: str, mode: str):
+    if mode == 'diffuse':
+        mode = 'txt2img'
+
+    global _model_name, _model_mode, _model
+
+    if _model_name != name or _model_mode != mode:
+        # unload model
+        _model = None
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        _model_name = _model_mode = ''
+
+        # load model
+        if mode == 'upscale':
+            _model = init_upscaler()
+
+        else:
+            _model = pipeline_for(
+                name, mode, cache_dir='hf_home')
+
+        _model_name = name
+        _model_mode = mode
+
+        logging.debug('memory summary:')
+        logging.debug('\n' + torch.cuda.memory_summary())
+
+    yield
+
+
+def compute_one(
     request_id: int,
     method: str,
     params: dict,
-    inputs: list[bytes] = []
-):
+    inputs: list[bytes] = [],
+    should_cancel = None,
+    tui: WorkerMonitor | None = None
+):
+    if method == 'diffuse':
+        method = 'txt2img'
+
+    global _model, _model_name, _model_mode
+
+    # validate correct model is loaded
+    assert params['model'] == _model_name
+    assert method == _model_mode
+
     total_steps = params['step']
     def inference_step_wakeup(*args, **kwargs):
         '''This is a callback function that gets invoked every inference step,
@@ -135,18 +130,20 @@
         if not isinstance(step, int):
             step = args[1]
 
-        if self._tui:
-            self._tui.set_progress(step, done=total_steps)
+        if tui:
+            tui.set_progress(step, done=total_steps)
 
-        should_raise = trio.from_thread.run(self._should_cancel, request_id)
-        if should_raise:
-            logging.warning(f'CANCELLING work at step {step}')
-            raise DGPUInferenceCancelled('network cancel')
+        if should_cancel:
+            should_raise = trio.from_thread.run(should_cancel, request_id)
+            if should_raise:
+                logging.warning(f'CANCELLING work at step {step}')
+                raise DGPUInferenceCancelled('network cancel')
 
         return {}
 
-    if self._tui:
-        self._tui.set_status(f'Request #{request_id}')
+    if tui:
+        tui.set_status(f'Request #{request_id}')
 
     inference_step_wakeup(0)
@@ -160,10 +157,7 @@
     name = params['model']
 
     match method:
-        case 'diffuse' | 'txt2img' | 'img2img' | 'inpaint':
-            if not self.is_model_loaded(name, method):
-                self.load_model(name, method)
-
+        case 'txt2img' | 'img2img' | 'inpaint':
             arguments = prepare_params_for_diffuse(
                 params, method, inputs)
             prompt, guidance, step, seed, upscaler, extra_params = arguments
@@ -175,7 +169,7 @@
             extra_params['callback'] = inference_step_wakeup
             extra_params['callback_steps'] = 1
 
-            output = self._model(
+            output = _model(
                 prompt,
                 guidance_scale=guidance,
                 num_inference_steps=step,
@@ -201,14 +195,8 @@
             output_hash = sha256(output_binary).hexdigest()
 
         case 'upscale':
-            if self._model_mode != 'upscale':
-                self.unload_model()
-                self._model = init_upscaler()
-                self._model_mode = 'upscale'
-                self._model_name = 'realesrgan'
-
             input_img = inputs[0].convert('RGB')
-            up_img, _ = self._model.enhance(
+            up_img, _ = _model.enhance(
                 convert_from_image_to_cv2(input_img), outscale=4)
 
             output = convert_from_cv2_to_image(up_img)
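The per-request upscaler bootstrap disappears from the 'upscale' branch because maybe_load_model() now calls init_upscaler() for that mode before compute_one() ever runs, so this branch can assume _model is already the upscaler.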
@@ -222,10 +210,7 @@
     except BaseException as err:
         raise DGPUComputeError(str(err)) from err
 
-    finally:
-        torch.cuda.empty_cache()
-
-    if self._tui:
-        self._tui.set_status('')
+    if tui:
+        tui.set_status('')
 
     return output_hash, output
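maybe_load_model is a single-yield generator wrapped by contextlib.contextmanager (imported as cm above): everything before the yield runs on entering the with block, and nothing runs after it, so the loaded pipeline intentionally stays resident in the module globals for the next request with the same name/mode. A standalone sketch of that caching pattern, with load_fn as a hypothetical stand-in for pipeline_for/init_upscaler:

    from contextlib import contextmanager

    _cached_key = None
    _cached_model = None

    @contextmanager
    def maybe_load(key, load_fn):
        # code before `yield` runs on __enter__
        global _cached_key, _cached_model
        if _cached_key != key:
            _cached_model = load_fn(key)  # (re)load only on a cache miss
            _cached_key = key
        yield
        # no code after `yield`: nothing is torn down on __exit__,
        # the model stays cached for the next caller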
@@ -18,7 +18,7 @@ from skynet.dgpu.errors import (
     DGPUComputeError,
 )
 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
+from skynet.dgpu.compute import maybe_load_model, compute_one
 from skynet.dgpu.network import NetConnector
 
 
@@ -40,12 +40,10 @@ class WorkerDaemon:
     '''
     def __init__(
        self,
-        mm: ModelMngr,
         conn: NetConnector,
         config: dict,
         tui: WorkerMonitor | None = None
     ):
-        self.mm: ModelMngr = mm
         self.conn: NetConnector = conn
         self._tui = tui
         self.auto_withdraw = (
@@ -248,14 +246,17 @@
         logging.info(f'calculated request hash: {request_hash}')
 
         total_step = body['params']['step']
+        model = body['params']['model']
+        mode = body['method']
 
         # TODO: validate request
 
         resp = await self.conn.begin_work(rid)
         if not resp or 'code' in resp:
             logging.info('begin_work error, probably being worked on already... skip.')
             return False
 
-        else:
+        with maybe_load_model(model, mode):
             try:
                 if self._tui:
                     self._tui.set_progress(0, done=total_step)
@@ -268,13 +269,14 @@
                 output_hash = None
                 match self.backend:
                     case 'sync-on-thread':
-                        self.mm._should_cancel = self.should_cancel_work
                         output_hash, output = await trio.to_thread.run_sync(
                             partial(
-                                self.mm.compute_one,
+                                compute_one,
                                 rid,
-                                body['method'], body['params'],
-                                inputs=inputs
+                                mode, body['params'],
+                                inputs=inputs,
+                                should_cancel=self.should_cancel_work,
+                                tui=self._tui
                             )
                         )
 
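The call site above runs compute_one on a worker thread with trio.to_thread.run_sync, while the per-step callback hops back onto the trio event loop via trio.from_thread.run to poll should_cancel_work. A self-contained sketch of that pattern (the predicate and step loop are illustrative stand-ins):

    import trio

    async def should_cancel(request_id: int) -> bool:
        # stand-in for WorkerDaemon.should_cancel_work's network check
        return False

    def compute(request_id: int):
        for step in range(4):
            # worker thread -> trio loop: run the async cancel check
            if trio.from_thread.run(should_cancel, request_id):
                raise RuntimeError(f'cancelled at step {step}')
        return 'done'

    async def main():
        # spawn the sync worker on a thread; from_thread.run works because
        # to_thread.run_sync passes the trio token along implicitly
        print(await trio.to_thread.run_sync(compute, 42))

    trio.run(main)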