Refactor ModelMngr to be a context manager + function combo

guilles_counter_review
Guillermo Rodriguez 2025-02-05 19:24:21 -03:00
parent b3dc7c1074
commit cd028d15e7
3 changed files with 124 additions and 139 deletions
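In short: where the daemon previously constructed and held a ModelMngr instance, it now wraps each request in the maybe_load_model() context manager and calls the standalone compute_one() function. A minimal sketch of the new call pattern, mirroring the daemon change below; run_request and its rid/body/inputs/daemon arguments are placeholder names for illustration, not part of this commit:

    from functools import partial

    import trio

    from skynet.dgpu.compute import maybe_load_model, compute_one

    async def run_request(daemon, rid: int, body: dict, inputs: list):
        model = body['params']['model']
        mode = body['method']

        # the context manager loads the requested model only if it is not already resident
        with maybe_load_model(model, mode):
            # the plain function runs the actual inference on a worker thread
            output_hash, output = await trio.to_thread.run_sync(
                partial(
                    compute_one,
                    rid,
                    mode, body['params'],
                    inputs=inputs,
                    should_cancel=daemon.should_cancel_work,
                    tui=daemon._tui
                )
            )

        return output_hash, output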

View File

@@ -8,7 +8,6 @@ from hypercorn.trio import serve
 from quart_trio import QuartTrio as Quart
 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
 from skynet.dgpu.daemon import WorkerDaemon
 from skynet.dgpu.network import NetConnector
@@ -48,8 +47,7 @@ async def open_dgpu_node(config: dict) -> None:
     tui = WorkerMonitor()
     conn = NetConnector(config, tui=tui)
-    mm = ModelMngr(config, tui=tui)
-    daemon = WorkerDaemon(mm, conn, config, tui=tui)
+    daemon = WorkerDaemon(conn, config, tui=tui)

     api: Quart|None = None
     if 'api_bind' in config:

View File

@@ -7,6 +7,7 @@ import gc
 import logging
 from hashlib import sha256
+from contextlib import contextmanager as cm

 import trio
 import torch
@@ -66,65 +67,59 @@ def prepare_params_for_diffuse(
         _params
     )

-class ModelMngr:
-    '''
-    (AI algo) Model manager for loading models, computing outputs,
-    checking load state, and unloading when no-longer-needed/finished.
-    '''
-    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
-        self._tui = tui
-        self.cache_dir = None
-        if 'hf_home' in config:
-            self.cache_dir = config['hf_home']
-
-        self._model_name: str = ''
-        self._model_mode: str = ''
-
-    def log_debug_info(self):
-        logging.debug('memory summary:')
-        logging.debug('\n' + torch.cuda.memory_summary())
-
-    def is_model_loaded(self, name: str, mode: str):
-        if (name == self._model_name and
-            mode == self._model_mode):
-            return True
-
-        return False
-
-    def unload_model(self) -> None:
-        if getattr(self, '_model', None):
-            del self._model
+_model_name: str = ''
+_model_mode: str = ''
+_model = None
+
+
+@cm
+def maybe_load_model(name: str, mode: str):
+    if mode == 'diffuse':
+        mode = 'txt2img'
+
+    global _model_name, _model_mode, _model
+    if _model_name != name or _model_mode != mode:
+        # unload model
+        _model = None

         gc.collect()
         torch.cuda.empty_cache()

-        self._model_name = ''
-        self._model_mode = ''
-
-    def load_model(
-        self,
-        name: str,
-        mode: str
-    ) -> None:
-        logging.info(f'loading model {name}...')
-        self.unload_model()
-
-        self._model = pipeline_for(
-            name, mode, cache_dir=self.cache_dir)
-        self._model_mode = mode
-        self._model_name = name
-        logging.info(f'{name} loaded!')
-
-        self.log_debug_info()
-
-    def compute_one(
-        self,
+        _model_name = _model_mode = ''
+
+        # load model
+        if mode == 'upscale':
+            _model = init_upscaler()
+
+        else:
+            _model = pipeline_for(
+                name, mode, cache_dir='hf_home')
+
+        _model_name = name
+        _model_mode = mode
+
+        logging.debug('memory summary:')
+        logging.debug('\n' + torch.cuda.memory_summary())
+
+    yield
+
+
+def compute_one(
     request_id: int,
     method: str,
     params: dict,
-    inputs: list[bytes] = []
-):
+    inputs: list[bytes] = [],
+    should_cancel = None,
+    tui: WorkerMonitor | None = None
+):
+    if method == 'diffuse':
+        method = 'txt2img'
+
+    global _model, _model_name, _model_mode
+    # validate correct model is loaded
+    assert params['model'] == _model_name
+    assert method == _model_mode
+
     total_steps = params['step']

     def inference_step_wakeup(*args, **kwargs):
         '''This is a callback function that gets invoked every inference step,
@@ -135,18 +130,20 @@ class ModelMngr:
         if not isinstance(step, int):
             step = args[1]

-        if self._tui:
-            self._tui.set_progress(step, done=total_steps)
+        if tui:
+            tui.set_progress(step, done=total_steps)
+
+        if should_cancel:
+            should_raise = trio.from_thread.run(should_cancel, request_id)

-        should_raise = trio.from_thread.run(self._should_cancel, request_id)
         if should_raise:
             logging.warning(f'CANCELLING work at step {step}')
             raise DGPUInferenceCancelled('network cancel')

         return {}

-    if self._tui:
-        self._tui.set_status(f'Request #{request_id}')
+    if tui:
+        tui.set_status(f'Request #{request_id}')

     inference_step_wakeup(0)
@@ -160,10 +157,7 @@ class ModelMngr:
     name = params['model']

     match method:
-        case 'diffuse' | 'txt2img' | 'img2img' | 'inpaint':
-            if not self.is_model_loaded(name, method):
-                self.load_model(name, method)
-
+        case 'txt2img' | 'img2img' | 'inpaint':
             arguments = prepare_params_for_diffuse(
                 params, method, inputs)
             prompt, guidance, step, seed, upscaler, extra_params = arguments
@@ -175,7 +169,7 @@ class ModelMngr:
             extra_params['callback'] = inference_step_wakeup
             extra_params['callback_steps'] = 1

-            output = self._model(
+            output = _model(
                 prompt,
                 guidance_scale=guidance,
                 num_inference_steps=step,
@@ -201,14 +195,8 @@ class ModelMngr:
             output_hash = sha256(output_binary).hexdigest()

         case 'upscale':
-            if self._model_mode != 'upscale':
-                self.unload_model()
-                self._model = init_upscaler()
-                self._model_mode = 'upscale'
-                self._model_name = 'realesrgan'
-
             input_img = inputs[0].convert('RGB')

-            up_img, _ = self._model.enhance(
+            up_img, _ = _model.enhance(
                 convert_from_image_to_cv2(input_img), outscale=4)

             output = convert_from_cv2_to_image(up_img)
@@ -222,10 +210,7 @@ class ModelMngr:
     except BaseException as err:
         raise DGPUComputeError(str(err)) from err

-    finally:
-        torch.cuda.empty_cache()
-
-        if self._tui:
-            self._tui.set_status('')
+    if tui:
+        tui.set_status('')

     return output_hash, output

View File

@@ -18,7 +18,7 @@ from skynet.dgpu.errors import (
     DGPUComputeError,
 )

 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
+from skynet.dgpu.compute import maybe_load_model, compute_one
 from skynet.dgpu.network import NetConnector
@@ -40,12 +40,10 @@ class WorkerDaemon:
     '''
     def __init__(
         self,
-        mm: ModelMngr,
         conn: NetConnector,
         config: dict,
         tui: WorkerMonitor | None = None
     ):
-        self.mm: ModelMngr = mm
         self.conn: NetConnector = conn
         self._tui = tui
         self.auto_withdraw = (
@@ -248,14 +246,17 @@ class WorkerDaemon:
         logging.info(f'calculated request hash: {request_hash}')

         total_step = body['params']['step']
+        model = body['params']['model']
+        mode = body['method']

         # TODO: validate request

         resp = await self.conn.begin_work(rid)
         if not resp or 'code' in resp:
             logging.info('begin_work error, probably being worked on already... skip.')
+            return False

-        else:
+        with maybe_load_model(model, mode):
             try:
                 if self._tui:
                     self._tui.set_progress(0, done=total_step)
@@ -268,13 +269,14 @@ class WorkerDaemon:
                 output_hash = None
                 match self.backend:
                     case 'sync-on-thread':
-                        self.mm._should_cancel = self.should_cancel_work
                         output_hash, output = await trio.to_thread.run_sync(
                             partial(
-                                self.mm.compute_one,
+                                compute_one,
                                 rid,
-                                body['method'], body['params'],
-                                inputs=inputs
+                                mode, body['params'],
+                                inputs=inputs,
+                                should_cancel=self.should_cancel_work,
+                                tui=self._tui
                             )
                         )