mirror of https://github.com/skygpu/skynet.git

Refactor ModelMngr to be a context manager + function combo

commit 12b32a7188
parent 8b45fb5979
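This commit dissolves the stateful ModelMngr class in skynet/dgpu/compute.py into module-level state plus two entry points: maybe_load_model(name, mode), a context manager that (re)loads the requested model only when it differs from the one already resident, and a plain compute_one() function that now takes the cancellation callback and TUI handle as arguments instead of reading them off a manager object. WorkerDaemon drops its mm attribute and wraps each unit of work in the context manager instead. A condensed sketch of the call-shape change, assembled from the hunks below (the real daemon runs compute_one on a worker thread via trio.to_thread.run_sync):

```python
# before: a manager object is built once and threaded through the daemon
mm = ModelMngr(config, tui=tui)
daemon = WorkerDaemon(mm, conn, config, tui=tui)

# after: no manager object; per request the daemon does (simplified;
# model/mode/body/inputs come from the incoming request)
daemon = WorkerDaemon(conn, config, tui=tui)
with maybe_load_model(model, mode):
    output_hash, output = compute_one(
        rid, mode, body['params'],
        inputs=inputs,
        should_cancel=daemon.should_cancel_work,
        tui=tui,
    )
```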
In the dgpu node entrypoint (the module defining open_dgpu_node):

```diff
@@ -8,7 +8,6 @@ from hypercorn.trio import serve
 from quart_trio import QuartTrio as Quart
 
 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
 from skynet.dgpu.daemon import WorkerDaemon
 from skynet.dgpu.network import NetConnector
 
```
```diff
@@ -48,8 +47,7 @@ async def open_dgpu_node(config: dict) -> None:
         tui = WorkerMonitor()
 
     conn = NetConnector(config, tui=tui)
-    mm = ModelMngr(config, tui=tui)
-    daemon = WorkerDaemon(mm, conn, config, tui=tui)
+    daemon = WorkerDaemon(conn, config, tui=tui)
 
     api: Quart|None = None
     if 'api_bind' in config:
```
In skynet/dgpu/compute.py:

```diff
@@ -7,6 +7,7 @@ import gc
 import logging
 
 from hashlib import sha256
+from contextlib import contextmanager as cm
 
 import trio
 import torch
```
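The new cm alias is contextlib.contextmanager, which turns a generator into a context manager: code before the yield runs on entering the with block, code after it runs on exit. A minimal, self-contained illustration of the pattern (names here are hypothetical, not from the repo):

```python
from contextlib import contextmanager as cm

@cm
def managed(name: str):
    print(f'enter: {name}')   # runs at the top of the `with` block
    yield                     # the body of the `with` block executes here
    print(f'exit: {name}')    # runs when the block exits normally

with managed('example'):
    print('inside')
# prints: enter: example, inside, exit: example
```

Note that maybe_load_model in the next hunk has no code after its yield: the loaded model deliberately stays resident when the block exits, and is only torn down by a later call that asks for a different name/mode.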
```diff
@@ -66,166 +67,150 @@ def prepare_params_for_diffuse(
         _params
     )
 
 
-class ModelMngr:
-    '''
-    (AI algo) Model manager for loading models, computing outputs,
-    checking load state, and unloading when no-longer-needed/finished.
-
-    '''
-    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
-        self._tui = tui
-        self.cache_dir = None
-        if 'hf_home' in config:
-            self.cache_dir = config['hf_home']
-
-        self._model_name: str = ''
-        self._model_mode: str = ''
-
-    def log_debug_info(self):
-        logging.debug('memory summary:')
-        logging.debug('\n' + torch.cuda.memory_summary())
-
-    def is_model_loaded(self, name: str, mode: str):
-        if (name == self._model_name and
-            mode == self._model_mode):
-            return True
-
-        return False
-
-    def unload_model(self) -> None:
-        if getattr(self, '_model', None):
-            del self._model
-
-        gc.collect()
-        torch.cuda.empty_cache()
-
-        self._model_name = ''
-        self._model_mode = ''
-
-    def load_model(
-        self,
-        name: str,
-        mode: str
-    ) -> None:
-        logging.info(f'loading model {name}...')
-        self.unload_model()
-
-        self._model = pipeline_for(
-            name, mode, cache_dir=self.cache_dir)
-        self._model_mode = mode
-        self._model_name = name
-        logging.info(f'{name} loaded!')
-        self.log_debug_info()
-
-    def compute_one(
-        self,
-        request_id: int,
-        method: str,
-        params: dict,
-        inputs: list[bytes] = []
-    ):
-        total_steps = params['step']
-        def inference_step_wakeup(*args, **kwargs):
-            '''This is a callback function that gets invoked every inference step,
-            we need to raise an exception here if we need to cancel work
-            '''
-            step = args[0]
-            # compat with callback_on_step_end
-            if not isinstance(step, int):
-                step = args[1]
-
-            if self._tui:
-                self._tui.set_progress(step, done=total_steps)
-
-            should_raise = trio.from_thread.run(self._should_cancel, request_id)
-            if should_raise:
-                logging.warning(f'CANCELLING work at step {step}')
-                raise DGPUInferenceCancelled('network cancel')
-
-            return {}
-
-        if self._tui:
-            self._tui.set_status(f'Request #{request_id}')
-
-        inference_step_wakeup(0)
-
-        output_type = 'png'
-        if 'output_type' in params:
-            output_type = params['output_type']
-
-        output = None
-        output_hash = None
-        try:
-            name = params['model']
-
-            match method:
-                case 'diffuse' | 'txt2img' | 'img2img' | 'inpaint':
-                    if not self.is_model_loaded(name, method):
-                        self.load_model(name, method)
-
-                    arguments = prepare_params_for_diffuse(
-                        params, method, inputs)
-                    prompt, guidance, step, seed, upscaler, extra_params = arguments
-
-                    if 'flux' in name.lower():
-                        extra_params['callback_on_step_end'] = inference_step_wakeup
-
-                    else:
-                        extra_params['callback'] = inference_step_wakeup
-                        extra_params['callback_steps'] = 1
-
-                    output = self._model(
-                        prompt,
-                        guidance_scale=guidance,
-                        num_inference_steps=step,
-                        generator=seed,
-                        **extra_params
-                    ).images[0]
-
-                    output_binary = b''
-                    match output_type:
-                        case 'png':
-                            if upscaler == 'x4':
-                                input_img = output.convert('RGB')
-                                up_img, _ = init_upscaler().enhance(
-                                    convert_from_image_to_cv2(input_img), outscale=4)
-
-                                output = convert_from_cv2_to_image(up_img)
-
-                            output_binary = convert_from_img_to_bytes(output)
-
-                        case _:
-                            raise DGPUComputeError(f'Unsupported output type: {output_type}')
-
-                    output_hash = sha256(output_binary).hexdigest()
-
-                case 'upscale':
-                    if self._model_mode != 'upscale':
-                        self.unload_model()
-                        self._model = init_upscaler()
-                        self._model_mode = 'upscale'
-                        self._model_name = 'realesrgan'
-
-                    input_img = inputs[0].convert('RGB')
-                    up_img, _ = self._model.enhance(
-                        convert_from_image_to_cv2(input_img), outscale=4)
-
-                    output = convert_from_cv2_to_image(up_img)
-
-                    output_binary = convert_from_img_to_bytes(output)
-                    output_hash = sha256(output_binary).hexdigest()
-
-                case _:
-                    raise DGPUComputeError('Unsupported compute method')
-
-        except BaseException as err:
-            raise DGPUComputeError(str(err)) from err
-
-        finally:
-            torch.cuda.empty_cache()
-
-        if self._tui:
-            self._tui.set_status('')
-
-        return output_hash, output
+_model_name: str = ''
+_model_mode: str = ''
+_model = None
+
+
+@cm
+def maybe_load_model(name: str, mode: str):
+    if mode == 'diffuse':
+        mode = 'txt2img'
+
+    global _model_name, _model_mode, _model
+
+    if _model_name != name or _model_mode != mode:
+        # unload model
+        _model = None
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        _model_name = _model_mode = ''
+
+        # load model
+        if mode == 'upscale':
+            _model = init_upscaler()
+
+        else:
+            _model = pipeline_for(
+                name, mode, cache_dir='hf_home')
+
+        _model_name = name
+        _model_mode = mode
+
+        logging.debug('memory summary:')
+        logging.debug('\n' + torch.cuda.memory_summary())
+
+    yield
+
+
+def compute_one(
+    request_id: int,
+    method: str,
+    params: dict,
+    inputs: list[bytes] = [],
+    should_cancel = None,
+    tui: WorkerMonitor | None = None
+):
+    if method == 'diffuse':
+        method = 'txt2img'
+
+    global _model, _model_name, _model_mode
+
+    # validate correct model is loaded
+    assert params['model'] == _model_name
+    assert method == _model_mode
+
+    total_steps = params['step']
+    def inference_step_wakeup(*args, **kwargs):
+        '''This is a callback function that gets invoked every inference step,
+        we need to raise an exception here if we need to cancel work
+        '''
+        step = args[0]
+        # compat with callback_on_step_end
+        if not isinstance(step, int):
+            step = args[1]
+
+        if tui:
+            tui.set_progress(step, done=total_steps)
+
+        should_raise = False  # default when no cancel callback is provided
+        if should_cancel:
+            should_raise = trio.from_thread.run(should_cancel, request_id)
+
+        if should_raise:
+            logging.warning(f'CANCELLING work at step {step}')
+            raise DGPUInferenceCancelled('network cancel')
+
+        return {}
+
+    if tui:
+        tui.set_status(f'Request #{request_id}')
+
+    inference_step_wakeup(0)
+
+    output_type = 'png'
+    if 'output_type' in params:
+        output_type = params['output_type']
+
+    output = None
+    output_hash = None
+    try:
+        name = params['model']
+
+        match method:
+            case 'txt2img' | 'img2img' | 'inpaint':
+                arguments = prepare_params_for_diffuse(
+                    params, method, inputs)
+                prompt, guidance, step, seed, upscaler, extra_params = arguments
+
+                if 'flux' in name.lower():
+                    extra_params['callback_on_step_end'] = inference_step_wakeup
+
+                else:
+                    extra_params['callback'] = inference_step_wakeup
+                    extra_params['callback_steps'] = 1
+
+                output = _model(
+                    prompt,
+                    guidance_scale=guidance,
+                    num_inference_steps=step,
+                    generator=seed,
+                    **extra_params
+                ).images[0]
+
+                output_binary = b''
+                match output_type:
+                    case 'png':
+                        if upscaler == 'x4':
+                            input_img = output.convert('RGB')
+                            up_img, _ = init_upscaler().enhance(
+                                convert_from_image_to_cv2(input_img), outscale=4)
+
+                            output = convert_from_cv2_to_image(up_img)
+
+                        output_binary = convert_from_img_to_bytes(output)
+
+                    case _:
+                        raise DGPUComputeError(f'Unsupported output type: {output_type}')
+
+                output_hash = sha256(output_binary).hexdigest()
+
+            case 'upscale':
+                input_img = inputs[0].convert('RGB')
+                up_img, _ = _model.enhance(
+                    convert_from_image_to_cv2(input_img), outscale=4)
+
+                output = convert_from_cv2_to_image(up_img)
+
+                output_binary = convert_from_img_to_bytes(output)
+                output_hash = sha256(output_binary).hexdigest()
+
+            case _:
+                raise DGPUComputeError('Unsupported compute method')
+
+    except BaseException as err:
+        raise DGPUComputeError(str(err)) from err
+
+    if tui:
+        tui.set_status('')
+
+    return output_hash, output
```
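Taken together, the new module-level API is driven like this; a minimal sketch assuming a txt2img request, where every params key other than 'model', 'step' and 'output_type' is an assumption based on what prepare_params_for_diffuse appears to consume:

```python
from skynet.dgpu.compute import maybe_load_model, compute_one

# hypothetical request payload, for illustration only
params = {
    'model': 'some-org/some-diffusion-model',
    'prompt': 'a watercolor fox',
    'step': 28,
    'seed': 0,
    'guidance': 7.5,
    'output_type': 'png',
}

# compute_one asserts that the resident model matches params['model'],
# so it must be called inside the corresponding maybe_load_model block
with maybe_load_model(params['model'], 'txt2img'):
    output_hash, output = compute_one(0, 'txt2img', params)
```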
In skynet/dgpu/daemon.py:

```diff
@@ -18,7 +18,7 @@ from skynet.dgpu.errors import (
     DGPUComputeError,
 )
 from skynet.dgpu.tui import WorkerMonitor
-from skynet.dgpu.compute import ModelMngr
+from skynet.dgpu.compute import maybe_load_model, compute_one
 from skynet.dgpu.network import NetConnector
 
 
```
```diff
@@ -40,12 +40,10 @@ class WorkerDaemon:
     '''
     def __init__(
         self,
-        mm: ModelMngr,
         conn: NetConnector,
         config: dict,
         tui: WorkerMonitor | None = None
     ):
-        self.mm: ModelMngr = mm
         self.conn: NetConnector = conn
         self._tui = tui
         self.auto_withdraw = (
```
```diff
@@ -248,14 +246,17 @@ class WorkerDaemon:
         logging.info(f'calculated request hash: {request_hash}')
 
         total_step = body['params']['step']
+        model = body['params']['model']
+        mode = body['method']
 
         # TODO: validate request
 
         resp = await self.conn.begin_work(rid)
         if not resp or 'code' in resp:
             logging.info('begin_work error, probably being worked on already... skip.')
+            return False
 
-        else:
+        with maybe_load_model(model, mode):
             try:
                 if self._tui:
                     self._tui.set_progress(0, done=total_step)
```
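Because maybe_load_model keys its cache on (name, mode) and never unloads on exit, wrapping every request in the with block stays cheap when consecutive requests hit the same model; a sketch (model name hypothetical):

```python
# back-to-back requests for the same model reload nothing
with maybe_load_model('some-model', 'txt2img'):
    ...  # first call: the pipeline is loaded here
with maybe_load_model('some-model', 'txt2img'):
    ...  # second call: _model_name/_model_mode already match, no reload
```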
```diff
@@ -268,13 +269,14 @@ class WorkerDaemon:
                 output_hash = None
                 match self.backend:
                     case 'sync-on-thread':
-                        self.mm._should_cancel = self.should_cancel_work
                         output_hash, output = await trio.to_thread.run_sync(
                             partial(
-                                self.mm.compute_one,
+                                compute_one,
                                 rid,
-                                body['method'], body['params'],
-                                inputs=inputs
+                                mode, body['params'],
+                                inputs=inputs,
+                                should_cancel=self.should_cancel_work,
+                                tui=self._tui
                             )
                         )
 
```
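The daemon side completes the concurrency picture: compute_one blocks, so it is pushed onto a worker thread with trio.to_thread.run_sync, and its per-step callback hops back into the trio event loop via trio.from_thread.run(should_cancel, request_id) to ask whether the job was cancelled. A stripped-down, runnable sketch of that round trip (the work function and cancel predicate are stand-ins, not skynet code):

```python
from functools import partial

import trio


async def should_cancel(request_id: int) -> bool:
    # stand-in for WorkerDaemon.should_cancel_work; skynet would poll
    # network/chain state here to see if the request was cancelled
    return False


def blocking_work(request_id: int, should_cancel=None) -> str:
    # runs on a worker thread, like compute_one
    for step in range(3):
        if should_cancel:
            # hop from the worker thread back into the trio loop
            if trio.from_thread.run(should_cancel, request_id):
                raise RuntimeError(f'cancelled at step {step}')
    return 'done'


async def main():
    result = await trio.to_thread.run_sync(
        partial(blocking_work, 0, should_cancel=should_cancel))
    print(result)  # -> done


trio.run(main)
```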