mirror of https://github.com/skygpu/skynet.git
Refactor ModelMngr to be a context manager + function combo
parent b3dc7c1074
commit cd028d15e7
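This commit dissolves the ModelMngr class into module-level state plus two entry points: a maybe_load_model() context manager that swaps weights in and out only when the requested model or mode changes, and a plain compute_one() function that runs inference against whatever is currently loaded. A minimal sketch of the resulting call pattern, condensed from the daemon changes below (names taken from the diff; error handling and TUI updates trimmed):

from functools import partial

import trio

from skynet.dgpu.compute import maybe_load_model, compute_one

async def run_request(daemon, rid: int, body: dict, inputs: list):
    model = body['params']['model']
    mode = body['method']

    # hold the requested model loaded for the duration of the request
    with maybe_load_model(model, mode):
        # run the blocking torch work on a worker thread, keeping the trio loop free
        output_hash, output = await trio.to_thread.run_sync(
            partial(
                compute_one,
                rid,
                mode, body['params'],
                inputs=inputs,
                should_cancel=daemon.should_cancel_work,
                tui=daemon._tui
            )
        )
    return output_hash, output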
@@ -8,7 +8,6 @@ from hypercorn.trio import serve
 from quart_trio import QuartTrio as Quart

 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
 from skynet.dgpu.daemon import WorkerDaemon
 from skynet.dgpu.network import NetConnector

@@ -48,8 +47,7 @@ async def open_dgpu_node(config: dict) -> None:
     tui = WorkerMonitor()

     conn = NetConnector(config, tui=tui)
-    mm = ModelMngr(config, tui=tui)
-    daemon = WorkerDaemon(mm, conn, config, tui=tui)
+    daemon = WorkerDaemon(conn, config, tui=tui)

     api: Quart|None = None
     if 'api_bind' in config:
@@ -7,6 +7,7 @@ import gc
 import logging

 from hashlib import sha256
+from contextlib import contextmanager as cm

 import trio
 import torch
@@ -66,65 +67,59 @@ def prepare_params_for_diffuse(
         _params
     )

+_model_name: str = ''
+_model_mode: str = ''
+_model = None

-class ModelMngr:
-    '''
-    (AI algo) Model manager for loading models, computing outputs,
-    checking load state, and unloading when no-longer-needed/finished.
-
-    '''
-    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
-        self._tui = tui
-        self.cache_dir = None
-        if 'hf_home' in config:
-            self.cache_dir = config['hf_home']
-
-        self._model_name: str = ''
-        self._model_mode: str = ''
-
-    def log_debug_info(self):
-        logging.debug('memory summary:')
-        logging.debug('\n' + torch.cuda.memory_summary())
-
-    def is_model_loaded(self, name: str, mode: str):
-        if (name == self._model_name and
-            mode == self._model_mode):
-            return True
-
-        return False
-
-    def unload_model(self) -> None:
-        if getattr(self, '_model', None):
-            del self._model
-
+@cm
+def maybe_load_model(name: str, mode: str):
+    if mode == 'diffuse':
+        mode = 'txt2img'
+
+    global _model_name, _model_mode, _model
+
+    if _model_name != name or _model_mode != mode:
+        # unload model
+        _model = None
         gc.collect()
         torch.cuda.empty_cache()

-        self._model_name = ''
-        self._model_mode = ''
+        _model_name = _model_mode = ''

-    def load_model(
-        self,
-        name: str,
-        mode: str
-    ) -> None:
-        logging.info(f'loading model {name}...')
-        self.unload_model()
-
-        self._model = pipeline_for(
-            name, mode, cache_dir=self.cache_dir)
-        self._model_mode = mode
-        self._model_name = name
-        logging.info(f'{name} loaded!')
-        self.log_debug_info()
+        # load model
+        if mode == 'upscale':
+            _model = init_upscaler()
+
+        else:
+            _model = pipeline_for(
+                name, mode, cache_dir='hf_home')
+
+        _model_name = name
+        _model_mode = mode
+
+        logging.debug('memory summary:')
+        logging.debug('\n' + torch.cuda.memory_summary())
+
+    yield

-    def compute_one(
-        self,
+def compute_one(
     request_id: int,
     method: str,
     params: dict,
-    inputs: list[bytes] = []
+    inputs: list[bytes] = [],
+    should_cancel = None,
+    tui: WorkerMonitor | None = None
 ):
+    if method == 'diffuse':
+        method = 'txt2img'
+
+    global _model, _model_name, _model_mode
+
+    # validate correct model is loaded
+    assert params['model'] == _model_name
+    assert method == _model_mode

     total_steps = params['step']
     def inference_step_wakeup(*args, **kwargs):
         '''This is a callback function that gets invoked every inference step,
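For context on the shape of maybe_load_model: with contextlib.contextmanager, everything before the yield runs when the with block is entered, and anything after it would run on exit. Since nothing follows the yield here, the loaded pipeline simply stays behind in the module globals, so the next request for the same model and mode enters the block without reloading. A small self-contained illustration of that caching pattern (toy names, not from this repo):

from contextlib import contextmanager

_cache: dict[str, str] = {}

@contextmanager
def maybe_load(key: str):
    if key not in _cache:
        _cache[key] = f'loaded:{key}'   # stand-in for the expensive weight load
    yield _cache[key]
    # no teardown after the yield: the entry intentionally stays cached

with maybe_load('some/model') as model:
    print(model)   # loaded:some/model

with maybe_load('some/model') as model:
    pass           # second entry skips the load entirely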
@@ -135,18 +130,20 @@ class ModelMngr:
         if not isinstance(step, int):
             step = args[1]

-        if self._tui:
-            self._tui.set_progress(step, done=total_steps)
+        if tui:
+            tui.set_progress(step, done=total_steps)

-        should_raise = trio.from_thread.run(self._should_cancel, request_id)
+        if should_cancel:
+            should_raise = trio.from_thread.run(should_cancel, request_id)
+
         if should_raise:
             logging.warning(f'CANCELLING work at step {step}')
             raise DGPUInferenceCancelled('network cancel')

         return {}

-    if self._tui:
-        self._tui.set_status(f'Request #{request_id}')
+    if tui:
+        tui.set_status(f'Request #{request_id}')

     inference_step_wakeup(0)

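The per-step callback now receives the cancel check as a plain should_cancel argument instead of reaching for self._should_cancel. compute_one runs on a worker thread (see the daemon's trio.to_thread.run_sync call further down), so the callback uses trio.from_thread.run to hop back onto the trio event loop and await the async check there. A minimal sketch of that thread-to-loop round trip (toy functions, same mechanism):

import trio

async def should_cancel(request_id: int) -> bool:
    # stand-in for WorkerDaemon.should_cancel_work: an async check that
    # must run on the trio event loop, not on the worker thread
    await trio.sleep(0)
    return False

def blocking_work(request_id: int):
    for step in range(3):
        # hop from the worker thread back onto the event loop
        if trio.from_thread.run(should_cancel, request_id):
            raise RuntimeError('cancelled')

async def main():
    await trio.to_thread.run_sync(blocking_work, 42)

trio.run(main)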
@@ -160,10 +157,7 @@ class ModelMngr:
     name = params['model']

     match method:
-        case 'diffuse' | 'txt2img' | 'img2img' | 'inpaint':
-            if not self.is_model_loaded(name, method):
-                self.load_model(name, method)
-
+        case 'txt2img' | 'img2img' | 'inpaint':
             arguments = prepare_params_for_diffuse(
                 params, method, inputs)
             prompt, guidance, step, seed, upscaler, extra_params = arguments
@@ -175,7 +169,7 @@ class ModelMngr:
             extra_params['callback'] = inference_step_wakeup
             extra_params['callback_steps'] = 1

-            output = self._model(
+            output = _model(
                 prompt,
                 guidance_scale=guidance,
                 num_inference_steps=step,
@@ -201,14 +195,8 @@ class ModelMngr:
             output_hash = sha256(output_binary).hexdigest()

         case 'upscale':
-            if self._model_mode != 'upscale':
-                self.unload_model()
-                self._model = init_upscaler()
-                self._model_mode = 'upscale'
-                self._model_name = 'realesrgan'
-
             input_img = inputs[0].convert('RGB')
-            up_img, _ = self._model.enhance(
+            up_img, _ = _model.enhance(
                 convert_from_image_to_cv2(input_img), outscale=4)

             output = convert_from_cv2_to_image(up_img)
@@ -222,10 +210,7 @@ class ModelMngr:
     except BaseException as err:
         raise DGPUComputeError(str(err)) from err

-    finally:
-        torch.cuda.empty_cache()
-
-        if self._tui:
-            self._tui.set_status('')
+    if tui:
+        tui.set_status('')

     return output_hash, output
@@ -18,7 +18,7 @@ from skynet.dgpu.errors import (
     DGPUComputeError,
 )
 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
+from skynet.dgpu.compute import maybe_load_model, compute_one
 from skynet.dgpu.network import NetConnector


@@ -40,12 +40,10 @@ class WorkerDaemon:
     '''
     def __init__(
         self,
-        mm: ModelMngr,
         conn: NetConnector,
         config: dict,
         tui: WorkerMonitor | None = None
     ):
-        self.mm: ModelMngr = mm
         self.conn: NetConnector = conn
         self._tui = tui
         self.auto_withdraw = (
|
@ -248,14 +246,17 @@ class WorkerDaemon:
|
||||||
logging.info(f'calculated request hash: {request_hash}')
|
logging.info(f'calculated request hash: {request_hash}')
|
||||||
|
|
||||||
total_step = body['params']['step']
|
total_step = body['params']['step']
|
||||||
|
model = body['params']['model']
|
||||||
|
mode = body['method']
|
||||||
|
|
||||||
# TODO: validate request
|
# TODO: validate request
|
||||||
|
|
||||||
resp = await self.conn.begin_work(rid)
|
resp = await self.conn.begin_work(rid)
|
||||||
if not resp or 'code' in resp:
|
if not resp or 'code' in resp:
|
||||||
logging.info('begin_work error, probably being worked on already... skip.')
|
logging.info('begin_work error, probably being worked on already... skip.')
|
||||||
|
return False
|
||||||
|
|
||||||
else:
|
with maybe_load_model(model, mode):
|
||||||
try:
|
try:
|
||||||
if self._tui:
|
if self._tui:
|
||||||
self._tui.set_progress(0, done=total_step)
|
self._tui.set_progress(0, done=total_step)
|
||||||
|
@@ -268,13 +269,14 @@ class WorkerDaemon:
                 output_hash = None
                 match self.backend:
                     case 'sync-on-thread':
-                        self.mm._should_cancel = self.should_cancel_work
                         output_hash, output = await trio.to_thread.run_sync(
                             partial(
-                                self.mm.compute_one,
+                                compute_one,
                                 rid,
-                                body['method'], body['params'],
-                                inputs=inputs
+                                mode, body['params'],
+                                inputs=inputs,
+                                should_cancel=self.should_cancel_work,
+                                tui=self._tui
                             )
                         )

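The daemon keeps the sync-on-thread shape: the blocking compute_one call is pushed onto a worker thread with trio.to_thread.run_sync, which only forwards positional arguments, so functools.partial is used to pre-bind the keyword arguments (inputs, should_cancel, tui) rather than mutating state on a manager object. Roughly, with a toy callable standing in for compute_one:

from functools import partial

import trio

def compute(request_id: int, method: str, *, inputs=None, should_cancel=None):
    # stand-in for compute_one: a blocking, CPU/GPU-bound call
    return f'{method}:{request_id}:{len(inputs or [])}'

async def main():
    # run_sync only passes positional args through, so keyword
    # arguments are pre-bound with functools.partial
    result = await trio.to_thread.run_sync(
        partial(compute, 1, 'txt2img', inputs=[b'...'], should_cancel=None)
    )
    print(result)

trio.run(main)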