diff --git a/skynet/dgpu/__init__.py b/skynet/dgpu/__init__.py
index 59af61c..f5fcc9e 100755
--- a/skynet/dgpu/__init__.py
+++ b/skynet/dgpu/__init__.py
@@ -8,7 +8,6 @@ from hypercorn.trio import serve
 from quart_trio import QuartTrio as Quart
 
 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
 from skynet.dgpu.daemon import WorkerDaemon
 from skynet.dgpu.network import NetConnector
 
@@ -48,8 +47,7 @@ async def open_dgpu_node(config: dict) -> None:
     tui = WorkerMonitor()
 
     conn = NetConnector(config, tui=tui)
-    mm = ModelMngr(config, tui=tui)
-    daemon = WorkerDaemon(mm, conn, config, tui=tui)
+    daemon = WorkerDaemon(conn, config, tui=tui)
 
     api: Quart|None = None
     if 'api_bind' in config:
diff --git a/skynet/dgpu/compute.py b/skynet/dgpu/compute.py
index d0e8689..ec50054 100755
--- a/skynet/dgpu/compute.py
+++ b/skynet/dgpu/compute.py
@@ -7,6 +7,7 @@ import gc
 import logging
 
 from hashlib import sha256
+from contextlib import contextmanager as cm
 
 import trio
 import torch
@@ -66,166 +67,150 @@ def prepare_params_for_diffuse(
         _params
     )
 
+_model_name: str = ''
+_model_mode: str = ''
+_model = None
 
-class ModelMngr:
-    '''
-    (AI algo) Model manager for loading models, computing outputs,
-    checking load state, and unloading when no-longer-needed/finished.
+@cm
+def maybe_load_model(name: str, mode: str):
+    if mode == 'diffuse':
+        mode = 'txt2img'
 
-    '''
-    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
-        self._tui = tui
-        self.cache_dir = None
-        if 'hf_home' in config:
-            self.cache_dir = config['hf_home']
-
-        self._model_name: str = ''
-        self._model_mode: str = ''
-
-    def log_debug_info(self):
-        logging.debug('memory summary:')
-        logging.debug('\n' + torch.cuda.memory_summary())
-
-    def is_model_loaded(self, name: str, mode: str):
-        if (name == self._model_name and
-            mode == self._model_mode):
-            return True
-
-        return False
-
-    def unload_model(self) -> None:
-        if getattr(self, '_model', None):
-            del self._model
+    global _model_name, _model_mode, _model
+    if _model_name != name or _model_mode != mode:
+        # unload model
+        _model = None
         gc.collect()
         torch.cuda.empty_cache()
 
-        self._model_name = ''
-        self._model_mode = ''
+        _model_name = _model_mode = ''
 
-    def load_model(
-        self,
-        name: str,
-        mode: str
-    ) -> None:
-        logging.info(f'loading model {name}...')
-        self.unload_model()
+        # load model
+        if mode == 'upscale':
+            _model = init_upscaler()
 
-        self._model = pipeline_for(
-            name, mode, cache_dir=self.cache_dir)
-        self._model_mode = mode
-        self._model_name = name
-        logging.info(f'{name} loaded!')
-        self.log_debug_info()
+        else:
+            _model = pipeline_for(
+                name, mode, cache_dir='hf_home')
 
-    def compute_one(
-        self,
-        request_id: int,
-        method: str,
-        params: dict,
-        inputs: list[bytes] = []
-    ):
-        total_steps = params['step']
-        def inference_step_wakeup(*args, **kwargs):
-            '''This is a callback function that gets invoked every inference step,
-            we need to raise an exception here if we need to cancel work
-            '''
-            step = args[0]
-            # compat with callback_on_step_end
-            if not isinstance(step, int):
-                step = args[1]
+        _model_name = name
+        _model_mode = mode
 
-            if self._tui:
-                self._tui.set_progress(step, done=total_steps)
+    logging.debug('memory summary:')
+    logging.debug('\n' + torch.cuda.memory_summary())
 
-            should_raise = trio.from_thread.run(self._should_cancel, request_id)
-            if should_raise:
-                logging.warning(f'CANCELLING work at step {step}')
-                raise DGPUInferenceCancelled('network cancel')
+    yield
 
-            return {}
-        if self._tui:
-            self._tui.set_status(f'Request #{request_id}')
+def compute_one(
+    request_id: int,
+    method: str,
+    params: dict,
+    inputs: list[bytes] = [],
+    should_cancel = None,
+    tui: WorkerMonitor | None = None
+):
+    if method == 'diffuse':
+        method = 'txt2img'
 
-        inference_step_wakeup(0)
+    global _model, _model_name, _model_mode
 
-        output_type = 'png'
-        if 'output_type' in params:
-            output_type = params['output_type']
+    # validate correct model is loaded
+    assert params['model'] == _model_name
+    assert method == _model_mode
 
-        output = None
-        output_hash = None
-        try:
-            name = params['model']
+    total_steps = params['step']
+    def inference_step_wakeup(*args, **kwargs):
+        '''This is a callback function that gets invoked every inference step,
+        we need to raise an exception here if we need to cancel work
+        '''
+        step = args[0]
+        # compat with callback_on_step_end
+        if not isinstance(step, int):
+            step = args[1]
 
-            match method:
-                case 'diffuse' | 'txt2img' | 'img2img' | 'inpaint':
-                    if not self.is_model_loaded(name, method):
-                        self.load_model(name, method)
+        if tui:
+            tui.set_progress(step, done=total_steps)
 
-                    arguments = prepare_params_for_diffuse(
-                        params, method, inputs)
-                    prompt, guidance, step, seed, upscaler, extra_params = arguments
+        if should_cancel:
+            should_raise = trio.from_thread.run(should_cancel, request_id)
 
-                    if 'flux' in name.lower():
-                        extra_params['callback_on_step_end'] = inference_step_wakeup
+            if should_raise:
+                logging.warning(f'CANCELLING work at step {step}')
+                raise DGPUInferenceCancelled('network cancel')
 
-                    else:
-                        extra_params['callback'] = inference_step_wakeup
-                        extra_params['callback_steps'] = 1
+        return {}
 
-                    output = self._model(
-                        prompt,
-                        guidance_scale=guidance,
-                        num_inference_steps=step,
-                        generator=seed,
-                        **extra_params
-                    ).images[0]
+    if tui:
+        tui.set_status(f'Request #{request_id}')
 
-                    output_binary = b''
-                    match output_type:
-                        case 'png':
-                            if upscaler == 'x4':
-                                input_img = output.convert('RGB')
-                                up_img, _ = init_upscaler().enhance(
-                                    convert_from_image_to_cv2(input_img), outscale=4)
+    inference_step_wakeup(0)
 
-                                output = convert_from_cv2_to_image(up_img)
+    output_type = 'png'
+    if 'output_type' in params:
+        output_type = params['output_type']
 
-                            output_binary = convert_from_img_to_bytes(output)
+    output = None
+    output_hash = None
+    try:
+        name = params['model']
 
-                        case _:
-                            raise DGPUComputeError(f'Unsupported output type: {output_type}')
+        match method:
+            case 'txt2img' | 'img2img' | 'inpaint':
+                arguments = prepare_params_for_diffuse(
+                    params, method, inputs)
+                prompt, guidance, step, seed, upscaler, extra_params = arguments
 
-                    output_hash = sha256(output_binary).hexdigest()
+                if 'flux' in name.lower():
+                    extra_params['callback_on_step_end'] = inference_step_wakeup
 
-                case 'upscale':
-                    if self._model_mode != 'upscale':
-                        self.unload_model()
-                        self._model = init_upscaler()
-                        self._model_mode = 'upscale'
-                        self._model_name = 'realesrgan'
+                else:
+                    extra_params['callback'] = inference_step_wakeup
+                    extra_params['callback_steps'] = 1
 
-                    input_img = inputs[0].convert('RGB')
-                    up_img, _ = self._model.enhance(
-                        convert_from_image_to_cv2(input_img), outscale=4)
+                output = _model(
+                    prompt,
+                    guidance_scale=guidance,
+                    num_inference_steps=step,
+                    generator=seed,
+                    **extra_params
+                ).images[0]
 
-                    output = convert_from_cv2_to_image(up_img)
+                output_binary = b''
+                match output_type:
+                    case 'png':
+                        if upscaler == 'x4':
+                            input_img = output.convert('RGB')
+                            up_img, _ = init_upscaler().enhance(
+                                convert_from_image_to_cv2(input_img), outscale=4)
 
-                    output_binary = convert_from_img_to_bytes(output)
-                    output_hash = sha256(output_binary).hexdigest()
+                            output = convert_from_cv2_to_image(up_img)
 
-                case _:
-                    raise DGPUComputeError('Unsupported compute method')
+                        output_binary = convert_from_img_to_bytes(output)
 
-        except BaseException as err:
-            raise DGPUComputeError(str(err)) from err
+                    case _:
+                        raise DGPUComputeError(f'Unsupported output type: {output_type}')
 
-        finally:
-            torch.cuda.empty_cache()
+                output_hash = sha256(output_binary).hexdigest()
 
-        if self._tui:
-            self._tui.set_status('')
+            case 'upscale':
+                input_img = inputs[0].convert('RGB')
+                up_img, _ = _model.enhance(
+                    convert_from_image_to_cv2(input_img), outscale=4)
 
-        return output_hash, output
+                output = convert_from_cv2_to_image(up_img)
+
+                output_binary = convert_from_img_to_bytes(output)
+                output_hash = sha256(output_binary).hexdigest()
+
+            case _:
+                raise DGPUComputeError('Unsupported compute method')
+
+    except BaseException as err:
+        raise DGPUComputeError(str(err)) from err
+
+    if tui:
+        tui.set_status('')
+
+    return output_hash, output
diff --git a/skynet/dgpu/daemon.py b/skynet/dgpu/daemon.py
index 98d3eda..4c0bdce 100755
--- a/skynet/dgpu/daemon.py
+++ b/skynet/dgpu/daemon.py
@@ -18,7 +18,7 @@ from skynet.dgpu.errors import (
     DGPUComputeError,
 )
 
 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
+from skynet.dgpu.compute import maybe_load_model, compute_one
 from skynet.dgpu.network import NetConnector
 
@@ -40,12 +40,10 @@ class WorkerDaemon:
     '''
     def __init__(
         self,
-        mm: ModelMngr,
        conn: NetConnector,
        config: dict,
        tui: WorkerMonitor | None = None
    ):
-        self.mm: ModelMngr = mm
        self.conn: NetConnector = conn
        self._tui = tui
        self.auto_withdraw = (
@@ -248,14 +246,17 @@
         logging.info(f'calculated request hash: {request_hash}')
 
         total_step = body['params']['step']
+        model = body['params']['model']
+        mode = body['method']
 
         # TODO: validate request
 
         resp = await self.conn.begin_work(rid)
         if not resp or 'code' in resp:
             logging.info('begin_work error, probably being worked on already... skip.')
+            return False
 
-        else:
+        with maybe_load_model(model, mode):
             try:
                 if self._tui:
                     self._tui.set_progress(0, done=total_step)
@@ -268,13 +269,14 @@ class WorkerDaemon:
                 output_hash = None
                 match self.backend:
                     case 'sync-on-thread':
-                        self.mm._should_cancel = self.should_cancel_work
                         output_hash, output = await trio.to_thread.run_sync(
                             partial(
-                                self.mm.compute_one,
+                                compute_one,
                                 rid,
-                                body['method'], body['params'],
-                                inputs=inputs
+                                mode, body['params'],
+                                inputs=inputs,
+                                should_cancel=self.should_cancel_work,
+                                tui=self._tui
                             )
                         )
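
For reviewers, a rough usage sketch of the new module-level API in skynet.dgpu.compute (not part of the diff; the handle_request wrapper and its arguments are hypothetical, modelled on the WorkerDaemon call site above):

from functools import partial

import trio

from skynet.dgpu.compute import maybe_load_model, compute_one


async def handle_request(rid: int, body: dict, inputs: list):
    # hypothetical wrapper mirroring daemon.py: keep the requested pipeline
    # resident for the whole request, and run the blocking inference on a
    # worker thread so the trio event loop stays responsive
    with maybe_load_model(body['params']['model'], body['method']):
        output_hash, output = await trio.to_thread.run_sync(
            partial(
                compute_one,
                rid,
                body['method'], body['params'],
                inputs=inputs
            )
        )
    return output_hash, output

Note that compute_one asserts the loaded model and (normalized) mode match params['model'] and method, so it has to run inside the matching maybe_load_model block.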