Refactor ModelMngr to be a context manager + function combo

guilles_counter_review
Guillermo Rodriguez 2025-02-05 19:24:21 -03:00
parent b3dc7c1074
commit cd028d15e7
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
3 changed files with 124 additions and 139 deletions
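
In outline, the commit drops the ModelMngr class in favor of module-level model state guarded by a contextlib.contextmanager (maybe_load_model) plus a free compute_one function that asserts the right model is loaded while the daemon holds the context manager. A minimal, runnable sketch of that pattern, with illustrative names rather than the project's own (the real hunks follow below):

    # illustration only: module-level cache + contextmanager + free function
    from contextlib import contextmanager

    _loaded_name: str = ''
    _loaded_model = None

    @contextmanager
    def keep_model_loaded(name: str):
        global _loaded_name, _loaded_model
        if _loaded_name != name:
            _loaded_model = f'pipeline<{name}>'  # stand-in for an expensive pipeline load
            _loaded_name = name
        # everything above runs on entering the `with` block; the bare yield
        # means nothing is torn down on exit, so the model stays cached
        yield

    def run_once(name: str) -> str:
        # mirrors compute_one's assertion that callers loaded the model first
        assert name == _loaded_name
        return f'ran {name} on {_loaded_model}'

    with keep_model_loaded('model-a'):
        print(run_once('model-a'))
    with keep_model_loaded('model-a'):   # second request reuses the cached object
        print(run_once('model-a'))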

View File

@@ -8,7 +8,6 @@ from hypercorn.trio import serve
 from quart_trio import QuartTrio as Quart

 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
 from skynet.dgpu.daemon import WorkerDaemon
 from skynet.dgpu.network import NetConnector

@@ -48,8 +47,7 @@ async def open_dgpu_node(config: dict) -> None:
     tui = WorkerMonitor()
     conn = NetConnector(config, tui=tui)
-    mm = ModelMngr(config, tui=tui)
-    daemon = WorkerDaemon(mm, conn, config, tui=tui)
+    daemon = WorkerDaemon(conn, config, tui=tui)

     api: Quart|None = None
     if 'api_bind' in config:

View File

@@ -7,6 +7,7 @@ import gc
 import logging
 from hashlib import sha256
+from contextlib import contextmanager as cm

 import trio
 import torch

@@ -66,166 +67,150 @@ def prepare_params_for_diffuse(
         _params
     )

-class ModelMngr:
-    '''
-    (AI algo) Model manager for loading models, computing outputs,
-    checking load state, and unloading when no-longer-needed/finished.
-    '''
-    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
-        self._tui = tui
-        self.cache_dir = None
-        if 'hf_home' in config:
-            self.cache_dir = config['hf_home']
-
-        self._model_name: str = ''
-        self._model_mode: str = ''
-
-    def log_debug_info(self):
-        logging.debug('memory summary:')
-        logging.debug('\n' + torch.cuda.memory_summary())
-
-    def is_model_loaded(self, name: str, mode: str):
-        if (name == self._model_name and
-            mode == self._model_mode):
-            return True
-
-        return False
-
-    def unload_model(self) -> None:
-        if getattr(self, '_model', None):
-            del self._model
-
-        gc.collect()
-        torch.cuda.empty_cache()
-
-        self._model_name = ''
-        self._model_mode = ''
-
-    def load_model(
-        self,
-        name: str,
-        mode: str
-    ) -> None:
-        logging.info(f'loading model {name}...')
-        self.unload_model()
-
-        self._model = pipeline_for(
-            name, mode, cache_dir=self.cache_dir)
-        self._model_mode = mode
-        self._model_name = name
-
-        logging.info(f'{name} loaded!')
-
-        self.log_debug_info()
-
-    def compute_one(
-        self,
-        request_id: int,
-        method: str,
-        params: dict,
-        inputs: list[bytes] = []
-    ):
-        total_steps = params['step']
-        def inference_step_wakeup(*args, **kwargs):
-            '''This is a callback function that gets invoked every inference step,
-            we need to raise an exception here if we need to cancel work
-            '''
-            step = args[0]
-            # compat with callback_on_step_end
-            if not isinstance(step, int):
-                step = args[1]
-
-            if self._tui:
-                self._tui.set_progress(step, done=total_steps)
-
-            should_raise = trio.from_thread.run(self._should_cancel, request_id)
-            if should_raise:
-                logging.warning(f'CANCELLING work at step {step}')
-                raise DGPUInferenceCancelled('network cancel')
-
-            return {}
-
-        if self._tui:
-            self._tui.set_status(f'Request #{request_id}')
-
-        inference_step_wakeup(0)
-
-        output_type = 'png'
-        if 'output_type' in params:
-            output_type = params['output_type']
-
-        output = None
-        output_hash = None
-        try:
-            name = params['model']
-
-            match method:
-                case 'diffuse' | 'txt2img' | 'img2img' | 'inpaint':
-                    if not self.is_model_loaded(name, method):
-                        self.load_model(name, method)
-
-                    arguments = prepare_params_for_diffuse(
-                        params, method, inputs)
-                    prompt, guidance, step, seed, upscaler, extra_params = arguments
-
-                    if 'flux' in name.lower():
-                        extra_params['callback_on_step_end'] = inference_step_wakeup
-
-                    else:
-                        extra_params['callback'] = inference_step_wakeup
-                        extra_params['callback_steps'] = 1
-
-                    output = self._model(
-                        prompt,
-                        guidance_scale=guidance,
-                        num_inference_steps=step,
-                        generator=seed,
-                        **extra_params
-                    ).images[0]
-
-                    output_binary = b''
-                    match output_type:
-                        case 'png':
-                            if upscaler == 'x4':
-                                input_img = output.convert('RGB')
-                                up_img, _ = init_upscaler().enhance(
-                                    convert_from_image_to_cv2(input_img), outscale=4)
-
-                                output = convert_from_cv2_to_image(up_img)
-
-                            output_binary = convert_from_img_to_bytes(output)
-
-                        case _:
-                            raise DGPUComputeError(f'Unsupported output type: {output_type}')
-
-                    output_hash = sha256(output_binary).hexdigest()
-
-                case 'upscale':
-                    if self._model_mode != 'upscale':
-                        self.unload_model()
-                        self._model = init_upscaler()
-                        self._model_mode = 'upscale'
-                        self._model_name = 'realesrgan'
-
-                    input_img = inputs[0].convert('RGB')
-
-                    up_img, _ = self._model.enhance(
-                        convert_from_image_to_cv2(input_img), outscale=4)
-
-                    output = convert_from_cv2_to_image(up_img)
-
-                    output_binary = convert_from_img_to_bytes(output)
-                    output_hash = sha256(output_binary).hexdigest()
-
-                case _:
-                    raise DGPUComputeError('Unsupported compute method')
-
-        except BaseException as err:
-            raise DGPUComputeError(str(err)) from err
-
-        finally:
-            torch.cuda.empty_cache()
-            if self._tui:
-                self._tui.set_status('')
-
-        return output_hash, output
+_model_name: str = ''
+_model_mode: str = ''
+_model = None
+
+@cm
+def maybe_load_model(name: str, mode: str):
+    if mode == 'diffuse':
+        mode = 'txt2img'
+
+    global _model_name, _model_mode, _model
+    if _model_name != name or _model_mode != mode:
+        # unload model
+        _model = None
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        _model_name = _model_mode = ''
+
+        # load model
+        if mode == 'upscale':
+            _model = init_upscaler()
+
+        else:
+            _model = pipeline_for(
+                name, mode, cache_dir='hf_home')
+
+        _model_name = name
+        _model_mode = mode
+
+        logging.debug('memory summary:')
+        logging.debug('\n' + torch.cuda.memory_summary())
+
+    yield
+
+
+def compute_one(
+    request_id: int,
+    method: str,
+    params: dict,
+    inputs: list[bytes] = [],
+    should_cancel = None,
+    tui: WorkerMonitor | None = None
+):
+    if method == 'diffuse':
+        method = 'txt2img'
+
+    global _model, _model_name, _model_mode
+    # validate correct model is loaded
+    assert params['model'] == _model_name
+    assert method == _model_mode
+
+    total_steps = params['step']
+    def inference_step_wakeup(*args, **kwargs):
+        '''This is a callback function that gets invoked every inference step,
+        we need to raise an exception here if we need to cancel work
+        '''
+        step = args[0]
+        # compat with callback_on_step_end
+        if not isinstance(step, int):
+            step = args[1]
+
+        if tui:
+            tui.set_progress(step, done=total_steps)
+
+        if should_cancel:
+            should_raise = trio.from_thread.run(should_cancel, request_id)
+
+        if should_raise:
+            logging.warning(f'CANCELLING work at step {step}')
+            raise DGPUInferenceCancelled('network cancel')
+
+        return {}
+
+    if tui:
+        tui.set_status(f'Request #{request_id}')
+
+    inference_step_wakeup(0)
+
+    output_type = 'png'
+    if 'output_type' in params:
+        output_type = params['output_type']
+
+    output = None
+    output_hash = None
+    try:
+        name = params['model']
+
+        match method:
+            case 'txt2img' | 'img2img' | 'inpaint':
+                arguments = prepare_params_for_diffuse(
+                    params, method, inputs)
+                prompt, guidance, step, seed, upscaler, extra_params = arguments
+
+                if 'flux' in name.lower():
+                    extra_params['callback_on_step_end'] = inference_step_wakeup
+
+                else:
+                    extra_params['callback'] = inference_step_wakeup
+                    extra_params['callback_steps'] = 1
+
+                output = _model(
+                    prompt,
+                    guidance_scale=guidance,
+                    num_inference_steps=step,
+                    generator=seed,
+                    **extra_params
+                ).images[0]
+
+                output_binary = b''
+                match output_type:
+                    case 'png':
+                        if upscaler == 'x4':
+                            input_img = output.convert('RGB')
+                            up_img, _ = init_upscaler().enhance(
+                                convert_from_image_to_cv2(input_img), outscale=4)
+
+                            output = convert_from_cv2_to_image(up_img)
+
+                        output_binary = convert_from_img_to_bytes(output)
+
+                    case _:
+                        raise DGPUComputeError(f'Unsupported output type: {output_type}')
+
+                output_hash = sha256(output_binary).hexdigest()
+
+            case 'upscale':
+                input_img = inputs[0].convert('RGB')
+
+                up_img, _ = _model.enhance(
+                    convert_from_image_to_cv2(input_img), outscale=4)
+
+                output = convert_from_cv2_to_image(up_img)
+
+                output_binary = convert_from_img_to_bytes(output)
+                output_hash = sha256(output_binary).hexdigest()
+
+            case _:
+                raise DGPUComputeError('Unsupported compute method')
+
+    except BaseException as err:
+        raise DGPUComputeError(str(err)) from err
+
+    if tui:
+        tui.set_status('')
+
+    return output_hash, output

View File

@@ -18,7 +18,7 @@ from skynet.dgpu.errors import (
     DGPUComputeError,
 )

 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
+from skynet.dgpu.compute import maybe_load_model, compute_one
 from skynet.dgpu.network import NetConnector

@@ -40,12 +40,10 @@ class WorkerDaemon:
     '''
     def __init__(
         self,
-        mm: ModelMngr,
         conn: NetConnector,
         config: dict,
         tui: WorkerMonitor | None = None
     ):
-        self.mm: ModelMngr = mm
         self.conn: NetConnector = conn
         self._tui = tui
         self.auto_withdraw = (

@@ -248,14 +246,17 @@ class WorkerDaemon:
         logging.info(f'calculated request hash: {request_hash}')

         total_step = body['params']['step']
+        model = body['params']['model']
+        mode = body['method']

         # TODO: validate request

         resp = await self.conn.begin_work(rid)
         if not resp or 'code' in resp:
             logging.info('begin_work error, probably being worked on already... skip.')
+            return False

-        else:
+        with maybe_load_model(model, mode):
             try:
                 if self._tui:
                     self._tui.set_progress(0, done=total_step)

@@ -268,13 +269,14 @@ class WorkerDaemon:
                 output_hash = None
                 match self.backend:
                     case 'sync-on-thread':
-                        self.mm._should_cancel = self.should_cancel_work
                         output_hash, output = await trio.to_thread.run_sync(
                             partial(
-                                self.mm.compute_one,
+                                compute_one,
                                 rid,
-                                body['method'], body['params'],
-                                inputs=inputs
+                                mode, body['params'],
+                                inputs=inputs,
+                                should_cancel=self.should_cancel_work,
+                                tui=self._tui
                             )
                         )
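
As before the refactor, compute_one still runs on a worker thread (trio.to_thread.run_sync in the daemon hunk above) and its per-step callback hops back into the trio loop with trio.from_thread.run to ask whether the request should be cancelled; the change is that the check now arrives as the should_cancel argument instead of being set as a _should_cancel attribute. A stripped-down, runnable sketch of that handshake, with illustrative names rather than the project's own:

    import trio

    async def should_cancel(request_id: int) -> bool:
        # stand-in for the daemon's async should_cancel_work check
        return False

    def compute_one(request_id: int, should_cancel=None) -> str:
        for step in range(3):
            # runs in the worker thread; call back into the trio loop each step
            if should_cancel and trio.from_thread.run(should_cancel, request_id):
                raise RuntimeError(f'cancelled at step {step}')
        return 'done'

    async def main():
        result = await trio.to_thread.run_sync(
            lambda: compute_one(1, should_cancel=should_cancel))
        print(result)

    trio.run(main)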