Refactor ModelMngr to be a context manager + function combo

guilles_counter_review
Guillermo Rodriguez 2025-02-05 19:24:21 -03:00
parent b3dc7c1074
commit cd028d15e7
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
3 changed files with 124 additions and 139 deletions

View File

@ -8,7 +8,6 @@ from hypercorn.trio import serve
from quart_trio import QuartTrio as Quart
from skynet.dgpu.tui import WorkerMonitor
from skynet.dgpu.compute import ModelMngr
from skynet.dgpu.daemon import WorkerDaemon
from skynet.dgpu.network import NetConnector
@ -48,8 +47,7 @@ async def open_dgpu_node(config: dict) -> None:
tui = WorkerMonitor()
conn = NetConnector(config, tui=tui)
mm = ModelMngr(config, tui=tui)
daemon = WorkerDaemon(mm, conn, config, tui=tui)
daemon = WorkerDaemon(conn, config, tui=tui)
api: Quart|None = None
if 'api_bind' in config:

View File

@ -7,6 +7,7 @@ import gc
import logging
from hashlib import sha256
from contextlib import contextmanager as cm
import trio
import torch
@ -66,65 +67,59 @@ def prepare_params_for_diffuse(
_params
)
_model_name: str = ''
_model_mode: str = ''
_model = None
class ModelMngr:
'''
(AI algo) Model manager for loading models, computing outputs,
checking load state, and unloading when no-longer-needed/finished.
@cm
def maybe_load_model(name: str, mode: str):
if mode == 'diffuse':
mode = 'txt2img'
'''
def __init__(self, config: dict, tui: WorkerMonitor | None = None):
self._tui = tui
self.cache_dir = None
if 'hf_home' in config:
self.cache_dir = config['hf_home']
self._model_name: str = ''
self._model_mode: str = ''
def log_debug_info(self):
logging.debug('memory summary:')
logging.debug('\n' + torch.cuda.memory_summary())
def is_model_loaded(self, name: str, mode: str):
if (name == self._model_name and
mode == self._model_mode):
return True
return False
def unload_model(self) -> None:
if getattr(self, '_model', None):
del self._model
global _model_name, _model_mode, _model
if _model_name != name or _model_mode != mode:
# unload model
_model = None
gc.collect()
torch.cuda.empty_cache()
self._model_name = ''
self._model_mode = ''
_model_name = _model_mode = ''
def load_model(
self,
name: str,
mode: str
) -> None:
logging.info(f'loading model {name}...')
self.unload_model()
# load model
if mode == 'upscale':
_model = init_upscaler()
self._model = pipeline_for(
name, mode, cache_dir=self.cache_dir)
self._model_mode = mode
self._model_name = name
logging.info(f'{name} loaded!')
self.log_debug_info()
else:
_model = pipeline_for(
name, mode, cache_dir='hf_home')
def compute_one(
self,
_model_name = name
_model_mode = mode
logging.debug('memory summary:')
logging.debug('\n' + torch.cuda.memory_summary())
yield
def compute_one(
request_id: int,
method: str,
params: dict,
inputs: list[bytes] = []
):
inputs: list[bytes] = [],
should_cancel = None,
tui: WorkerMonitor | None = None
):
if method == 'diffuse':
method = 'txt2img'
global _model, _model_name, _model_mode
# validate correct model is loaded
assert params['model'] == _model_name
assert method == _model_mode
total_steps = params['step']
def inference_step_wakeup(*args, **kwargs):
'''This is a callback function that gets invoked every inference step,
@ -135,18 +130,20 @@ class ModelMngr:
if not isinstance(step, int):
step = args[1]
if self._tui:
self._tui.set_progress(step, done=total_steps)
if tui:
tui.set_progress(step, done=total_steps)
if should_cancel:
should_raise = trio.from_thread.run(should_cancel, request_id)
should_raise = trio.from_thread.run(self._should_cancel, request_id)
if should_raise:
logging.warning(f'CANCELLING work at step {step}')
raise DGPUInferenceCancelled('network cancel')
return {}
if self._tui:
self._tui.set_status(f'Request #{request_id}')
if tui:
tui.set_status(f'Request #{request_id}')
inference_step_wakeup(0)
@ -160,10 +157,7 @@ class ModelMngr:
name = params['model']
match method:
case 'diffuse' | 'txt2img' | 'img2img' | 'inpaint':
if not self.is_model_loaded(name, method):
self.load_model(name, method)
case 'txt2img' | 'img2img' | 'inpaint':
arguments = prepare_params_for_diffuse(
params, method, inputs)
prompt, guidance, step, seed, upscaler, extra_params = arguments
@ -175,7 +169,7 @@ class ModelMngr:
extra_params['callback'] = inference_step_wakeup
extra_params['callback_steps'] = 1
output = self._model(
output = _model(
prompt,
guidance_scale=guidance,
num_inference_steps=step,
@ -201,14 +195,8 @@ class ModelMngr:
output_hash = sha256(output_binary).hexdigest()
case 'upscale':
if self._model_mode != 'upscale':
self.unload_model()
self._model = init_upscaler()
self._model_mode = 'upscale'
self._model_name = 'realesrgan'
input_img = inputs[0].convert('RGB')
up_img, _ = self._model.enhance(
up_img, _ = _model.enhance(
convert_from_image_to_cv2(input_img), outscale=4)
output = convert_from_cv2_to_image(up_img)
@ -222,10 +210,7 @@ class ModelMngr:
except BaseException as err:
raise DGPUComputeError(str(err)) from err
finally:
torch.cuda.empty_cache()
if self._tui:
self._tui.set_status('')
if tui:
tui.set_status('')
return output_hash, output

View File

@ -18,7 +18,7 @@ from skynet.dgpu.errors import (
DGPUComputeError,
)
from skynet.dgpu.tui import WorkerMonitor
from skynet.dgpu.compute import ModelMngr
from skynet.dgpu.compute import maybe_load_model, compute_one
from skynet.dgpu.network import NetConnector
@ -40,12 +40,10 @@ class WorkerDaemon:
'''
def __init__(
self,
mm: ModelMngr,
conn: NetConnector,
config: dict,
tui: WorkerMonitor | None = None
):
self.mm: ModelMngr = mm
self.conn: NetConnector = conn
self._tui = tui
self.auto_withdraw = (
@ -248,14 +246,17 @@ class WorkerDaemon:
logging.info(f'calculated request hash: {request_hash}')
total_step = body['params']['step']
model = body['params']['model']
mode = body['method']
# TODO: validate request
resp = await self.conn.begin_work(rid)
if not resp or 'code' in resp:
logging.info('begin_work error, probably being worked on already... skip.')
return False
else:
with maybe_load_model(model, mode):
try:
if self._tui:
self._tui.set_progress(0, done=total_step)
@ -268,13 +269,14 @@ class WorkerDaemon:
output_hash = None
match self.backend:
case 'sync-on-thread':
self.mm._should_cancel = self.should_cancel_work
output_hash, output = await trio.to_thread.run_sync(
partial(
self.mm.compute_one,
compute_one,
rid,
body['method'], body['params'],
inputs=inputs
mode, body['params'],
inputs=inputs,
should_cancel=self.should_cancel_work,
tui=self._tui
)
)