mirror of https://github.com/skygpu/skynet.git

Refactor ModelMngr to be a context manager + function combo

parent b3dc7c1074
commit cd028d15e7
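
In short, this commit dissolves the ModelMngr class into a module-level
maybe_load_model() context manager (which loads, caches and unloads the
current pipeline) plus a plain compute_one() function that now takes its
cancellation callback and TUI handle as arguments. A rough sketch of the
resulting call pattern, using only names visible in this diff (the params
dict contents are illustrative, and the real daemon runs compute_one in a
worker thread via trio.to_thread.run_sync rather than calling it inline):

    from skynet.dgpu.compute import maybe_load_model, compute_one

    params = {'model': 'some-model-name', 'step': 28}  # illustrative payload
    mode = 'txt2img'

    # (re)loads the module-level pipeline only if model/mode differ from
    # what is already resident; otherwise the cached _model is reused
    with maybe_load_model(params['model'], mode):
        output_hash, output = compute_one(
            0,                   # request id
            mode, params,
            inputs=[],
            should_cancel=None,  # optional async predicate, polled each step
            tui=None,            # optional WorkerMonitor for progress updates
        )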
					
@@ -8,7 +8,6 @@ from hypercorn.trio import serve
 from quart_trio import QuartTrio as Quart

 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
 from skynet.dgpu.daemon import WorkerDaemon
 from skynet.dgpu.network import NetConnector

@@ -48,8 +47,7 @@ async def open_dgpu_node(config: dict) -> None:
         tui = WorkerMonitor()

     conn = NetConnector(config, tui=tui)
-    mm = ModelMngr(config, tui=tui)
-    daemon = WorkerDaemon(mm, conn, config, tui=tui)
+    daemon = WorkerDaemon(conn, config, tui=tui)

     api: Quart|None = None
     if 'api_bind' in config:

@@ -7,6 +7,7 @@ import gc
 import logging

 from hashlib import sha256
+from contextlib import contextmanager as cm

 import trio
 import torch
@@ -66,65 +67,59 @@ def prepare_params_for_diffuse(
         _params
     )

+_model_name: str = ''
+_model_mode: str = ''
+_model = None
+
-class ModelMngr:
-    '''
-    (AI algo) Model manager for loading models, computing outputs,
-    checking load state, and unloading when no-longer-needed/finished.
+@cm
+def maybe_load_model(name: str, mode: str):
+    if mode == 'diffuse':
+        mode = 'txt2img'

-    '''
-    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
-        self._tui = tui
-        self.cache_dir = None
-        if 'hf_home' in config:
-            self.cache_dir = config['hf_home']
-
-        self._model_name: str = ''
-        self._model_mode: str = ''
-
-    def log_debug_info(self):
-        logging.debug('memory summary:')
-        logging.debug('\n' + torch.cuda.memory_summary())
-
-    def is_model_loaded(self, name: str, mode: str):
-        if (name == self._model_name and
-            mode == self._model_mode):
-            return True
-
-        return False
-
-    def unload_model(self) -> None:
-        if getattr(self, '_model', None):
-            del self._model
+    global _model_name, _model_mode, _model

+    if _model_name != name or _model_mode != mode:
+        # unload model
+        _model = None
         gc.collect()
         torch.cuda.empty_cache()

-        self._model_name = ''
-        self._model_mode = ''
+        _model_name = _model_mode = ''

-    def load_model(
-        self,
-        name: str,
-        mode: str
-    ) -> None:
-        logging.info(f'loading model {name}...')
-        self.unload_model()
+        # load model
+        if mode == 'upscale':
+            _model = init_upscaler()

-        self._model = pipeline_for(
-            name, mode, cache_dir=self.cache_dir)
-        self._model_mode = mode
-        self._model_name = name
-        logging.info(f'{name} loaded!')
-        self.log_debug_info()
+        else:
+            _model = pipeline_for(
+                name, mode, cache_dir='hf_home')

-    def compute_one(
-        self,
+        _model_name = name
+        _model_mode = mode
+
+        logging.debug('memory summary:')
+        logging.debug('\n' + torch.cuda.memory_summary())
+
+    yield
+
+
+def compute_one(
     request_id: int,
     method: str,
     params: dict,
-        inputs: list[bytes] = []
-    ):
+    inputs: list[bytes] = [],
+    should_cancel = None,
+    tui: WorkerMonitor | None = None
+):
+    if method == 'diffuse':
+        method = 'txt2img'
+
+    global _model, _model_name, _model_mode
+
+    # validate correct model is loaded
+    assert params['model'] == _model_name
+    assert method == _model_mode

     total_steps = params['step']
     def inference_step_wakeup(*args, **kwargs):
         '''This is a callback function that gets invoked every inference step,
@@ -135,18 +130,20 @@ class ModelMngr:
         if not isinstance(step, int):
             step = args[1]

-            if self._tui:
-                self._tui.set_progress(step, done=total_steps)
+        if tui:
+            tui.set_progress(step, done=total_steps)

+        if should_cancel:
+            should_raise = trio.from_thread.run(should_cancel, request_id)
+
-            should_raise = trio.from_thread.run(self._should_cancel, request_id)
         if should_raise:
             logging.warning(f'CANCELLING work at step {step}')
             raise DGPUInferenceCancelled('network cancel')

         return {}

-        if self._tui:
-            self._tui.set_status(f'Request #{request_id}')
+    if tui:
+        tui.set_status(f'Request #{request_id}')

     inference_step_wakeup(0)

@@ -160,10 +157,7 @@ class ModelMngr:
         name = params['model']

         match method:
-                case 'diffuse' | 'txt2img' | 'img2img' | 'inpaint':
-                    if not self.is_model_loaded(name, method):
-                        self.load_model(name, method)
-
+            case 'txt2img' | 'img2img' | 'inpaint':
                 arguments = prepare_params_for_diffuse(
                     params, method, inputs)
                 prompt, guidance, step, seed, upscaler, extra_params = arguments
@@ -175,7 +169,7 @@ class ModelMngr:
                     extra_params['callback'] = inference_step_wakeup
                     extra_params['callback_steps'] = 1

-                    output = self._model(
+                output = _model(
                     prompt,
                     guidance_scale=guidance,
                     num_inference_steps=step,
@@ -201,14 +195,8 @@ class ModelMngr:
                 output_hash = sha256(output_binary).hexdigest()

             case 'upscale':
-                    if self._model_mode != 'upscale':
-                        self.unload_model()
-                        self._model = init_upscaler()
-                        self._model_mode = 'upscale'
-                        self._model_name = 'realesrgan'
-
                 input_img = inputs[0].convert('RGB')
-                    up_img, _ = self._model.enhance(
+                up_img, _ = _model.enhance(
                     convert_from_image_to_cv2(input_img), outscale=4)

                 output = convert_from_cv2_to_image(up_img)
@@ -222,10 +210,7 @@ class ModelMngr:
     except BaseException as err:
         raise DGPUComputeError(str(err)) from err

-        finally:
-            torch.cuda.empty_cache()
-
-        if self._tui:
-            self._tui.set_status('')
+    if tui:
+        tui.set_status('')

     return output_hash, output

@@ -18,7 +18,7 @@ from skynet.dgpu.errors import (
     DGPUComputeError,
 )
 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
+from skynet.dgpu.compute import maybe_load_model, compute_one
 from skynet.dgpu.network import NetConnector


@@ -40,12 +40,10 @@ class WorkerDaemon:
     '''
     def __init__(
         self,
-        mm: ModelMngr,
         conn: NetConnector,
         config: dict,
         tui: WorkerMonitor | None = None
     ):
-        self.mm: ModelMngr = mm
         self.conn: NetConnector = conn
         self._tui = tui
         self.auto_withdraw = (
@@ -248,14 +246,17 @@ class WorkerDaemon:
         logging.info(f'calculated request hash: {request_hash}')

         total_step = body['params']['step']
+        model = body['params']['model']
+        mode = body['method']

         # TODO: validate request

         resp = await self.conn.begin_work(rid)
         if not resp or 'code' in resp:
             logging.info('begin_work error, probably being worked on already... skip.')
             return False

-        else:
+        with maybe_load_model(model, mode):
             try:
                 if self._tui:
                     self._tui.set_progress(0, done=total_step)
@@ -268,13 +269,14 @@ class WorkerDaemon:
                 output_hash = None
                 match self.backend:
                     case 'sync-on-thread':
-                        self.mm._should_cancel = self.should_cancel_work
                         output_hash, output = await trio.to_thread.run_sync(
                             partial(
-                                self.mm.compute_one,
+                                compute_one,
                                 rid,
-                                body['method'], body['params'],
-                                inputs=inputs
+                                mode, body['params'],
+                                inputs=inputs,
+                                should_cancel=self.should_cancel_work,
+                                tui=self._tui
                             )
                         )

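
A note on the cancellation hookup introduced here: compute_one still runs
synchronously inside trio.to_thread.run_sync, so its per-step callback must
hop back into the trio event loop with trio.from_thread.run in order to await
the daemon's async should_cancel_work. A minimal, self-contained sketch of
that thread-to-loop pattern (hypothetical names, not skynet code):

    import trio

    async def should_cancel(request_id: int) -> bool:
        # stand-in for WorkerDaemon.should_cancel_work; any async check works
        return False

    def blocking_compute(request_id: int) -> str:
        # runs in a worker thread; from_thread.run re-enters the trio loop
        for step in range(3):
            if trio.from_thread.run(should_cancel, request_id):
                raise RuntimeError(f'cancelled at step {step}')
        return 'done'

    async def main():
        # to_thread.run_sync keeps the event loop responsive while computing
        result = await trio.to_thread.run_sync(blocking_compute, 0)
        print(result)

    trio.run(main)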