mirror of https://github.com/skygpu/skynet.git

Refactor ModelMngr to be a context manager + function combo

commit 12b32a7188
parent 8b45fb5979
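The commit collapses the stateful ModelMngr class into module-level globals guarded by a @contextmanager. A minimal sketch of that refactor pattern — illustrative names only, load_resource/maybe_acquire are not skynet's API:

    from contextlib import contextmanager

    _resource = None   # module-level slot replacing the instance attribute
    _resource_name = ''

    def load_resource(name: str):
        return f'<model {name}>'   # stand-in for a real (slow) model load

    @contextmanager
    def maybe_acquire(name: str):
        global _resource, _resource_name
        if _resource_name != name:        # swap only when the request differs
            _resource = load_resource(name)
            _resource_name = name
        yield _resource                   # caller's block runs with it loaded

    with maybe_acquire('stable-diffusion'):
        pass  # compute against the module-level _resource here

The "maybe" in the name carries the semantics: entering the block is a no-op when the requested model is already resident.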
@@ -8,7 +8,6 @@ from hypercorn.trio import serve
 from quart_trio import QuartTrio as Quart
 
 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
 from skynet.dgpu.daemon import WorkerDaemon
 from skynet.dgpu.network import NetConnector
 
@@ -48,8 +47,7 @@ async def open_dgpu_node(config: dict) -> None:
         tui = WorkerMonitor()
 
     conn = NetConnector(config, tui=tui)
-    mm = ModelMngr(config, tui=tui)
-    daemon = WorkerDaemon(mm, conn, config, tui=tui)
+    daemon = WorkerDaemon(conn, config, tui=tui)
 
     api: Quart|None = None
     if 'api_bind' in config:
@@ -7,6 +7,7 @@ import gc
 import logging
 
 from hashlib import sha256
+from contextlib import contextmanager as cm
 
 import trio
 import torch
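For reference, contextlib.contextmanager (aliased to cm here) turns a generator into a context manager: everything before the yield runs on entry, everything after it on exit. A tiny self-contained demo:

    from contextlib import contextmanager as cm

    @cm
    def demo():
        print('setup')     # runs when the `with` block is entered
        yield
        print('teardown')  # runs when the block exits normally

    with demo():
        print('body')
    # output: setup / body / teardown

Note that maybe_load_model below has no code after its yield: the loaded model deliberately survives the with-block, and is only torn down by a later call that asks for a different name/mode.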
@@ -66,65 +67,59 @@ def prepare_params_for_diffuse(
         _params
     )
 
+_model_name: str = ''
+_model_mode: str = ''
+_model = None
 
-class ModelMngr:
-    '''
-    (AI algo) Model manager for loading models, computing outputs,
-    checking load state, and unloading when no-longer-needed/finished.
-    '''
-    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
-        self._tui = tui
-        self.cache_dir = None
-        if 'hf_home' in config:
-            self.cache_dir = config['hf_home']
-
-        self._model_name: str = ''
-        self._model_mode: str = ''
-
-    def log_debug_info(self):
-        logging.debug('memory summary:')
-        logging.debug('\n' + torch.cuda.memory_summary())
-
-    def is_model_loaded(self, name: str, mode: str):
-        if (name == self._model_name and
-            mode == self._model_mode):
-            return True
-
-        return False
-
-    def unload_model(self) -> None:
-        if getattr(self, '_model', None):
-            del self._model
-            gc.collect()
-            torch.cuda.empty_cache()
-
-        self._model_name = ''
-        self._model_mode = ''
-
-    def load_model(
-        self,
-        name: str,
-        mode: str
-    ) -> None:
-        logging.info(f'loading model {name}...')
-        self.unload_model()
-        self._model = pipeline_for(
-            name, mode, cache_dir=self.cache_dir)
-        self._model_mode = mode
-        self._model_name = name
-        logging.info(f'{name} loaded!')
-        self.log_debug_info()
-
-    def compute_one(
-        self,
-        request_id: int,
-        method: str,
-        params: dict,
-        inputs: list[bytes] = []
-    ):
-        total_steps = params['step']
-        def inference_step_wakeup(*args, **kwargs):
-            '''This is a callback function that gets invoked every inference step,
+@cm
+def maybe_load_model(name: str, mode: str):
+    if mode == 'diffuse':
+        mode = 'txt2img'
+
+    global _model_name, _model_mode, _model
+
+    if _model_name != name or _model_mode != mode:
+        # unload model
+        _model = None
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        _model_name = _model_mode = ''
+
+        # load model
+        if mode == 'upscale':
+            _model = init_upscaler()
+
+        else:
+            _model = pipeline_for(
+                name, mode, cache_dir='hf_home')
+
+        _model_name = name
+        _model_mode = mode
+
+        logging.debug('memory summary:')
+        logging.debug('\n' + torch.cuda.memory_summary())
+
+    yield
+
+
+def compute_one(
+    request_id: int,
+    method: str,
+    params: dict,
+    inputs: list[bytes] = [],
+    should_cancel = None,
+    tui: WorkerMonitor | None = None
+):
+    if method == 'diffuse':
+        method = 'txt2img'
+
+    global _model, _model_name, _model_mode
+
+    # validate correct model is loaded
+    assert params['model'] == _model_name
+    assert method == _model_mode
+
+    total_steps = params['step']
+    def inference_step_wakeup(*args, **kwargs):
+        '''This is a callback function that gets invoked every inference step,
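Taken together, the new module API composes like this — a sketch only, assuming the imports from skynet.dgpu.compute; params is abbreviated to the keys visible in this diff (a real request carries the prompt, seed, etc.) and an actual GPU/model is required:

    # from skynet.dgpu.compute import maybe_load_model, compute_one
    with maybe_load_model('some-model', 'txt2img'):
        output_hash, output = compute_one(
            0,                        # request_id
            'txt2img',                # method; must match the loaded mode
            {'model': 'some-model', 'step': 35},  # abbreviated params
            inputs=[],
        )

The asserts at the top of compute_one enforce the contract: callers must have entered maybe_load_model with the same model/mode before computing.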
@@ -135,18 +130,20 @@ class ModelMngr:
         if not isinstance(step, int):
             step = args[1]
 
-            if self._tui:
-                self._tui.set_progress(step, done=total_steps)
+        if tui:
+            tui.set_progress(step, done=total_steps)
 
-            should_raise = trio.from_thread.run(self._should_cancel, request_id)
+        if should_cancel:
+            should_raise = trio.from_thread.run(should_cancel, request_id)
+
         if should_raise:
             logging.warning(f'CANCELLING work at step {step}')
             raise DGPUInferenceCancelled('network cancel')
 
         return {}
 
-        if self._tui:
-            self._tui.set_status(f'Request #{request_id}')
+    if tui:
+        tui.set_status(f'Request #{request_id}')
 
     inference_step_wakeup(0)
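compute_one runs on a worker thread (see the daemon hunk further down, which schedules it with trio.to_thread.run_sync), so the per-step callback has to hop back into the trio event loop to poll for cancellation; that is what trio.from_thread.run does here. A self-contained sketch of the pattern:

    import trio

    async def should_cancel(request_id: int) -> bool:
        return False  # stand-in for the daemon's async status check

    def blocking_work(request_id: int) -> str:
        # sync code on the worker thread, polling trio once per "step"
        for step in range(3):
            if trio.from_thread.run(should_cancel, request_id):
                raise RuntimeError(f'cancelled at step {step}')
        return 'done'

    async def main():
        print(await trio.to_thread.run_sync(blocking_work, 0))

    trio.run(main)

Passing should_cancel in as a plain callable (instead of the old monkey-patched self._should_cancel) makes the dependency explicit in the function signature.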
@@ -160,10 +157,7 @@ class ModelMngr:
         name = params['model']
 
-        match method:
-            case 'diffuse' | 'txt2img' | 'img2img' | 'inpaint':
-                if not self.is_model_loaded(name, method):
-                    self.load_model(name, method)
-
+    match method:
+        case 'txt2img' | 'img2img' | 'inpaint':
             arguments = prepare_params_for_diffuse(
                 params, method, inputs)
             prompt, guidance, step, seed, upscaler, extra_params = arguments
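The 'diffuse' alternative drops out of the case pattern because compute_one now normalizes it to 'txt2img' before matching. A small illustration of that normalize-then-match shape (classify is a made-up name):

    def classify(method: str) -> str:
        if method == 'diffuse':   # legacy alias, folded in up front
            method = 'txt2img'
        match method:
            case 'txt2img' | 'img2img' | 'inpaint':
                return 'diffusion'
            case 'upscale':
                return 'upscale'
            case _:
                return 'unknown'

    assert classify('diffuse') == 'diffusion'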
@@ -175,7 +169,7 @@ class ModelMngr:
             extra_params['callback'] = inference_step_wakeup
             extra_params['callback_steps'] = 1
 
-            output = self._model(
+            output = _model(
                 prompt,
                 guidance_scale=guidance,
                 num_inference_steps=step,
@@ -201,14 +195,8 @@ class ModelMngr:
                 output_hash = sha256(output_binary).hexdigest()
 
             case 'upscale':
-                if self._model_mode != 'upscale':
-                    self.unload_model()
-                    self._model = init_upscaler()
-                    self._model_mode = 'upscale'
-                    self._model_name = 'realesrgan'
-
                 input_img = inputs[0].convert('RGB')
-                up_img, _ = self._model.enhance(
+                up_img, _ = _model.enhance(
                     convert_from_image_to_cv2(input_img), outscale=4)
 
                 output = convert_from_cv2_to_image(up_img)
@@ -222,10 +210,7 @@ class ModelMngr:
     except BaseException as err:
         raise DGPUComputeError(str(err)) from err
 
-        finally:
-            torch.cuda.empty_cache()
-
-        if self._tui:
-            self._tui.set_status('')
+    if tui:
+        tui.set_status('')
 
     return output_hash, output
@@ -18,7 +18,7 @@ from skynet.dgpu.errors import (
     DGPUComputeError,
 )
 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
+from skynet.dgpu.compute import maybe_load_model, compute_one
 from skynet.dgpu.network import NetConnector
 
@@ -40,12 +40,10 @@ class WorkerDaemon:
     '''
     def __init__(
         self,
-        mm: ModelMngr,
         conn: NetConnector,
         config: dict,
         tui: WorkerMonitor | None = None
    ):
-        self.mm: ModelMngr = mm
         self.conn: NetConnector = conn
         self._tui = tui
         self.auto_withdraw = (
@@ -248,14 +246,17 @@ class WorkerDaemon:
         logging.info(f'calculated request hash: {request_hash}')
 
         total_step = body['params']['step']
+        model = body['params']['model']
+        mode = body['method']
 
         # TODO: validate request
 
         resp = await self.conn.begin_work(rid)
         if not resp or 'code' in resp:
             logging.info('begin_work error, probably being worked on already... skip.')
             return False
 
-        else:
+        with maybe_load_model(model, mode):
             try:
                 if self._tui:
                     self._tui.set_progress(0, done=total_step)
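The else: branch becomes a guard clause plus a with-block, so the entire work section executes with the right model held. The resulting control flow, condensed into a hypothetical handler (handle stands in for the daemon's own method):

    # Hypothetical condensation of the daemon flow shown above.
    async def handle(self, rid, body):
        resp = await self.conn.begin_work(rid)
        if not resp or 'code' in resp:
            return False  # job already taken elsewhere: bail early

        with maybe_load_model(body['params']['model'], body['method']):
            ...  # run compute, publish results, ack the work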
@@ -268,13 +269,14 @@ class WorkerDaemon:
                 output_hash = None
                 match self.backend:
                     case 'sync-on-thread':
-                        self.mm._should_cancel = self.should_cancel_work
                         output_hash, output = await trio.to_thread.run_sync(
                             partial(
-                                self.mm.compute_one,
+                                compute_one,
                                 rid,
-                                body['method'], body['params'],
-                                inputs=inputs
+                                mode, body['params'],
+                                inputs=inputs,
+                                should_cancel=self.should_cancel_work,
+                                tui=self._tui
                             )
                         )
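One detail worth calling out in the hunk above: trio.to_thread.run_sync only forwards positional arguments, so the new keyword arguments (inputs, should_cancel, tui) have to be bound with functools.partial first. A runnable sketch:

    from functools import partial
    import trio

    def work(a, b, *, flag=False):
        return (a, b, flag)

    async def main():
        # run_sync(fn, *args) has no **kwargs passthrough, so bind up front
        result = await trio.to_thread.run_sync(partial(work, 1, 2, flag=True))
        print(result)  # (1, 2, True)

    trio.run(main)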