Refactor ModelMngr to be a context manager + function combo

guilles_counter_review
Guillermo Rodriguez 2025-02-05 19:24:21 -03:00
parent b3dc7c1074
commit cd028d15e7
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
3 changed files with 124 additions and 139 deletions
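In short: the stateful ModelMngr class goes away; model loading/caching moves into a module-level maybe_load_model() context manager and inference into a plain compute_one() function. A minimal sketch of the resulting call pattern, assuming the signatures shown in the diffs below (the request values here are illustrative only; a real request body also carries the full generation parameters):

    from skynet.dgpu.compute import maybe_load_model, compute_one

    # illustrative request; real bodies also carry prompt, seed, etc.
    rid = 0
    method = 'txt2img'
    params = {'model': 'some-diffusion-model', 'step': 28}

    # load (or reuse) the pipeline for this model/mode, then run the request
    with maybe_load_model(params['model'], method):
        output_hash, output = compute_one(rid, method, params)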

skynet/dgpu/__init__.py

@@ -8,7 +8,6 @@ from hypercorn.trio import serve
 from quart_trio import QuartTrio as Quart
 
 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
 from skynet.dgpu.daemon import WorkerDaemon
 from skynet.dgpu.network import NetConnector
@@ -48,8 +47,7 @@ async def open_dgpu_node(config: dict) -> None:
     tui = WorkerMonitor()
     conn = NetConnector(config, tui=tui)
-    mm = ModelMngr(config, tui=tui)
-    daemon = WorkerDaemon(mm, conn, config, tui=tui)
+    daemon = WorkerDaemon(conn, config, tui=tui)
 
     api: Quart|None = None
     if 'api_bind' in config:

skynet/dgpu/compute.py

@@ -7,6 +7,7 @@ import gc
 import logging
 
 from hashlib import sha256
+from contextlib import contextmanager as cm
 
 import trio
 import torch
@@ -66,166 +67,150 @@ def prepare_params_for_diffuse(
         _params
     )
 
-class ModelMngr:
-    '''
-    (AI algo) Model manager for loading models, computing outputs,
-    checking load state, and unloading when no-longer-needed/finished.
-    '''
-    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
-        self._tui = tui
-        self.cache_dir = None
-        if 'hf_home' in config:
-            self.cache_dir = config['hf_home']
-
-        self._model_name: str = ''
-        self._model_mode: str = ''
-
-    def log_debug_info(self):
-        logging.debug('memory summary:')
-        logging.debug('\n' + torch.cuda.memory_summary())
-
-    def is_model_loaded(self, name: str, mode: str):
-        if (name == self._model_name and
-            mode == self._model_mode):
-            return True
-
-        return False
-
-    def unload_model(self) -> None:
-        if getattr(self, '_model', None):
-            del self._model
-
-        gc.collect()
-        torch.cuda.empty_cache()
-
-        self._model_name = ''
-        self._model_mode = ''
-
-    def load_model(
-        self,
-        name: str,
-        mode: str
-    ) -> None:
-        logging.info(f'loading model {name}...')
-        self.unload_model()
-
-        self._model = pipeline_for(
-            name, mode, cache_dir=self.cache_dir)
-        self._model_mode = mode
-        self._model_name = name
-        logging.info(f'{name} loaded!')
-        self.log_debug_info()
-
-    def compute_one(
-        self,
-        request_id: int,
-        method: str,
-        params: dict,
-        inputs: list[bytes] = []
-    ):
-        total_steps = params['step']
-
-        def inference_step_wakeup(*args, **kwargs):
-            '''This is a callback function that gets invoked every inference step,
-            we need to raise an exception here if we need to cancel work
-            '''
-            step = args[0]
-            # compat with callback_on_step_end
-            if not isinstance(step, int):
-                step = args[1]
-
-            if self._tui:
-                self._tui.set_progress(step, done=total_steps)
-
-            should_raise = trio.from_thread.run(self._should_cancel, request_id)
-            if should_raise:
-                logging.warning(f'CANCELLING work at step {step}')
-                raise DGPUInferenceCancelled('network cancel')
-
-            return {}
-
-        if self._tui:
-            self._tui.set_status(f'Request #{request_id}')
-
-        inference_step_wakeup(0)
-
-        output_type = 'png'
-        if 'output_type' in params:
-            output_type = params['output_type']
-
-        output = None
-        output_hash = None
-        try:
-            name = params['model']
-
-            match method:
-                case 'diffuse' | 'txt2img' | 'img2img' | 'inpaint':
-                    if not self.is_model_loaded(name, method):
-                        self.load_model(name, method)
-
-                    arguments = prepare_params_for_diffuse(
-                        params, method, inputs)
-                    prompt, guidance, step, seed, upscaler, extra_params = arguments
-
-                    if 'flux' in name.lower():
-                        extra_params['callback_on_step_end'] = inference_step_wakeup
-
-                    else:
-                        extra_params['callback'] = inference_step_wakeup
-                        extra_params['callback_steps'] = 1
-
-                    output = self._model(
-                        prompt,
-                        guidance_scale=guidance,
-                        num_inference_steps=step,
-                        generator=seed,
-                        **extra_params
-                    ).images[0]
-
-                    output_binary = b''
-                    match output_type:
-                        case 'png':
-                            if upscaler == 'x4':
-                                input_img = output.convert('RGB')
-                                up_img, _ = init_upscaler().enhance(
-                                    convert_from_image_to_cv2(input_img), outscale=4)
-
-                                output = convert_from_cv2_to_image(up_img)
-
-                            output_binary = convert_from_img_to_bytes(output)
-
-                        case _:
-                            raise DGPUComputeError(f'Unsupported output type: {output_type}')
-
-                    output_hash = sha256(output_binary).hexdigest()
-
-                case 'upscale':
-                    if self._model_mode != 'upscale':
-                        self.unload_model()
-                        self._model = init_upscaler()
-                        self._model_mode = 'upscale'
-                        self._model_name = 'realesrgan'
-
-                    input_img = inputs[0].convert('RGB')
-                    up_img, _ = self._model.enhance(
-                        convert_from_image_to_cv2(input_img), outscale=4)
-
-                    output = convert_from_cv2_to_image(up_img)
-
-                    output_binary = convert_from_img_to_bytes(output)
-                    output_hash = sha256(output_binary).hexdigest()
-
-                case _:
-                    raise DGPUComputeError('Unsupported compute method')
-
-        except BaseException as err:
-            raise DGPUComputeError(str(err)) from err
-
-        finally:
-            torch.cuda.empty_cache()
-            if self._tui:
-                self._tui.set_status('')
-
-        return output_hash, output
+_model_name: str = ''
+_model_mode: str = ''
+_model = None
+
+@cm
+def maybe_load_model(name: str, mode: str):
+    if mode == 'diffuse':
+        mode = 'txt2img'
+
+    global _model_name, _model_mode, _model
+    if _model_name != name or _model_mode != mode:
+        # unload model
+        _model = None
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        _model_name = _model_mode = ''
+
+        # load model
+        if mode == 'upscale':
+            _model = init_upscaler()
+
+        else:
+            _model = pipeline_for(
+                name, mode, cache_dir='hf_home')
+
+        _model_name = name
+        _model_mode = mode
+
+        logging.debug('memory summary:')
+        logging.debug('\n' + torch.cuda.memory_summary())
+
+    yield
+
+
+def compute_one(
+    request_id: int,
+    method: str,
+    params: dict,
+    inputs: list[bytes] = [],
+    should_cancel = None,
+    tui: WorkerMonitor | None = None
+):
+    if method == 'diffuse':
+        method = 'txt2img'
+
+    global _model, _model_name, _model_mode
+    # validate correct model is loaded
+    assert params['model'] == _model_name
+    assert method == _model_mode
+
+    total_steps = params['step']
+
+    def inference_step_wakeup(*args, **kwargs):
+        '''This is a callback function that gets invoked every inference step,
+        we need to raise an exception here if we need to cancel work
+        '''
+        step = args[0]
+        # compat with callback_on_step_end
+        if not isinstance(step, int):
+            step = args[1]
+
+        if tui:
+            tui.set_progress(step, done=total_steps)
+
+        if should_cancel:
+            should_raise = trio.from_thread.run(should_cancel, request_id)
+            if should_raise:
+                logging.warning(f'CANCELLING work at step {step}')
+                raise DGPUInferenceCancelled('network cancel')
+
+        return {}
+
+    if tui:
+        tui.set_status(f'Request #{request_id}')
+
+    inference_step_wakeup(0)
+
+    output_type = 'png'
+    if 'output_type' in params:
+        output_type = params['output_type']
+
+    output = None
+    output_hash = None
+    try:
+        name = params['model']
+
+        match method:
+            case 'txt2img' | 'img2img' | 'inpaint':
+                arguments = prepare_params_for_diffuse(
+                    params, method, inputs)
+                prompt, guidance, step, seed, upscaler, extra_params = arguments
+
+                if 'flux' in name.lower():
+                    extra_params['callback_on_step_end'] = inference_step_wakeup
+
+                else:
+                    extra_params['callback'] = inference_step_wakeup
+                    extra_params['callback_steps'] = 1
+
+                output = _model(
+                    prompt,
+                    guidance_scale=guidance,
+                    num_inference_steps=step,
+                    generator=seed,
+                    **extra_params
+                ).images[0]
+
+                output_binary = b''
+                match output_type:
+                    case 'png':
+                        if upscaler == 'x4':
+                            input_img = output.convert('RGB')
+                            up_img, _ = init_upscaler().enhance(
+                                convert_from_image_to_cv2(input_img), outscale=4)
+
+                            output = convert_from_cv2_to_image(up_img)
+
+                        output_binary = convert_from_img_to_bytes(output)
+
+                    case _:
+                        raise DGPUComputeError(f'Unsupported output type: {output_type}')
+
+                output_hash = sha256(output_binary).hexdigest()
+
+            case 'upscale':
+                input_img = inputs[0].convert('RGB')
+                up_img, _ = _model.enhance(
+                    convert_from_image_to_cv2(input_img), outscale=4)
+
+                output = convert_from_cv2_to_image(up_img)
+
+                output_binary = convert_from_img_to_bytes(output)
+                output_hash = sha256(output_binary).hexdigest()
+
+            case _:
+                raise DGPUComputeError('Unsupported compute method')
+
+    except BaseException as err:
+        raise DGPUComputeError(str(err)) from err
+
+    if tui:
+        tui.set_status('')
+
+    return output_hash, output
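For clarity, a minimal self-contained analogue of the pattern introduced above: a contextlib.contextmanager generator guarding a lazily (re)loaded module-level resource. The names here are illustrative and not part of the skynet codebase:

    from contextlib import contextmanager

    _loaded_key: str = ''
    _resource = None

    def _expensive_load(key: str) -> dict:
        # stand-in for pipeline_for()/init_upscaler(); pretend this is slow
        return {'key': key}

    @contextmanager
    def maybe_load(key: str):
        global _loaded_key, _resource
        if _loaded_key != key:
            _resource = None                  # drop the previous resource first
            _resource = _expensive_load(key)
            _loaded_key = key
        yield _resource                       # the `with` body runs here

    with maybe_load('model-a') as r:
        print(r)    # key changed from '': loads model-a
    with maybe_load('model-a') as r:
        print(r)    # same key: reuses the cached resource
    with maybe_load('model-b') as r:
        print(r)    # key changed: swaps in model-b

The version in this commit yields nothing and instead leaves the loaded pipeline in the module-level _model global, which compute_one() then asserts against before running.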

skynet/dgpu/daemon.py

@@ -18,7 +18,7 @@ from skynet.dgpu.errors import (
     DGPUComputeError,
 )
 
 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
+from skynet.dgpu.compute import maybe_load_model, compute_one
 from skynet.dgpu.network import NetConnector
@@ -40,12 +40,10 @@ class WorkerDaemon:
     '''
     def __init__(
         self,
-        mm: ModelMngr,
        conn: NetConnector,
        config: dict,
        tui: WorkerMonitor | None = None
    ):
-        self.mm: ModelMngr = mm
        self.conn: NetConnector = conn
        self._tui = tui
        self.auto_withdraw = (
@@ -248,14 +246,17 @@ class WorkerDaemon:
        logging.info(f'calculated request hash: {request_hash}')
 
        total_step = body['params']['step']
+        model = body['params']['model']
+        mode = body['method']
 
        # TODO: validate request
 
        resp = await self.conn.begin_work(rid)
        if not resp or 'code' in resp:
            logging.info('begin_work error, probably being worked on already... skip.')
            return False
 
-        else:
+        with maybe_load_model(model, mode):
            try:
                if self._tui:
                    self._tui.set_progress(0, done=total_step)
@@ -268,13 +269,14 @@ class WorkerDaemon:
                output_hash = None
                match self.backend:
                    case 'sync-on-thread':
-                        self.mm._should_cancel = self.should_cancel_work
                        output_hash, output = await trio.to_thread.run_sync(
                            partial(
-                                self.mm.compute_one,
+                                compute_one,
                                rid,
-                                body['method'], body['params'],
-                                inputs=inputs
+                                mode, body['params'],
+                                inputs=inputs,
+                                should_cancel=self.should_cancel_work,
+                                tui=self._tui
                            )
                        )
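The cancellation hook is now passed explicitly (should_cancel=self.should_cancel_work) rather than patched onto the manager as self.mm._should_cancel. A self-contained sketch of the underlying trio pattern, with made-up names, showing how a blocking function running under trio.to_thread.run_sync can poll the async side via trio.from_thread.run:

    import trio

    async def should_cancel(step: int) -> bool:
        # stand-in for WorkerDaemon.should_cancel_work; cancel once step 3 is hit
        return step >= 3

    def blocking_inference(request_id: int) -> str:
        # runs on a worker thread, so it must not await directly
        for step in range(10):
            # hop back into the trio event loop to ask whether to keep going
            if trio.from_thread.run(should_cancel, step):
                return f'request {request_id} cancelled at step {step}'
        return f'request {request_id} finished'

    async def main():
        result = await trio.to_thread.run_sync(blocking_inference, 42)
        print(result)

    trio.run(main)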