mirror of https://github.com/skygpu/skynet.git
				
				
				
			Start testing inpainting mode
							parent
							
								
									1e40c05da6
								
							
						
					
					
						commit
						8d35e5ed9a
					
				|  | @ -20,7 +20,7 @@ def skynet(*args, **kwargs): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @click.command() | @click.command() | ||||||
| @click.option('--model', '-m', default='midj') | @click.option('--model', '-m', default=list(MODELS.keys())[-1]) | ||||||
| @click.option( | @click.option( | ||||||
|     '--prompt', '-p', default='a red old tractor in a sunny wheat field') |     '--prompt', '-p', default='a red old tractor in a sunny wheat field') | ||||||
| @click.option('--output', '-o', default='output.png') | @click.option('--output', '-o', default='output.png') | ||||||
|  | @ -39,7 +39,7 @@ def txt2img(*args, **kwargs): | ||||||
|     utils.txt2img(hf_token, **kwargs) |     utils.txt2img(hf_token, **kwargs) | ||||||
| 
 | 
 | ||||||
| @click.command() | @click.command() | ||||||
| @click.option('--model', '-m', default=list(MODELS.keys())[0]) | @click.option('--model', '-m', default=list(MODELS.keys())[-2]) | ||||||
| @click.option( | @click.option( | ||||||
|     '--prompt', '-p', default='a red old tractor in a sunny wheat field') |     '--prompt', '-p', default='a red old tractor in a sunny wheat field') | ||||||
| @click.option('--input', '-i', default='input.png') | @click.option('--input', '-i', default='input.png') | ||||||
|  | @ -68,7 +68,7 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @click.command() | @click.command() | ||||||
| @click.option('--model', '-m', default=list(MODELS.keys())[-1]) | @click.option('--model', '-m', default=list(MODELS.keys())[-3]) | ||||||
| @click.option( | @click.option( | ||||||
|     '--prompt', '-p', default='a red old tractor in a sunny wheat field') |     '--prompt', '-p', default='a red old tractor in a sunny wheat field') | ||||||
| @click.option('--input', '-i', default='input.png') | @click.option('--input', '-i', default='input.png') | ||||||
|  |  | ||||||
|  | @ -4,34 +4,108 @@ VERSION = '0.1a12' | ||||||
| 
 | 
 | ||||||
| DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda' | DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda' | ||||||
| 
 | 
 | ||||||
| MODELS = { | import msgspec | ||||||
|     'prompthero/openjourney':                           {'short': 'midj',                'mem': 6,   'size': {'w': 512,  'h': 512}}, | from typing import Literal | ||||||
|     'runwayml/stable-diffusion-v1-5':                   {'short': 'stable',              'mem': 6,   'size': {'w': 512,  'h': 512}}, |  | ||||||
|     'stabilityai/stable-diffusion-2-1-base':            {'short': 'stable2',             'mem': 6,   'size': {'w': 512,  'h': 512}}, |  | ||||||
|     'snowkidy/stable-diffusion-xl-base-0.9':            {'short': 'stablexl0.9',         'mem': 8.3, 'size': {'w': 1024, 'h': 1024}}, |  | ||||||
|     'Linaqruf/anything-v3.0':                           {'short': 'hdanime',             'mem': 6,   'size': {'w': 512,  'h': 512}}, |  | ||||||
|     'hakurei/waifu-diffusion':                          {'short': 'waifu',               'mem': 6,   'size': {'w': 512,  'h': 512}}, |  | ||||||
|     'nitrosocke/Ghibli-Diffusion':                      {'short': 'ghibli',              'mem': 6,   'size': {'w': 512,  'h': 512}}, |  | ||||||
|     'dallinmackay/Van-Gogh-diffusion':                  {'short': 'van-gogh',            'mem': 6,   'size': {'w': 512,  'h': 512}}, |  | ||||||
|     'lambdalabs/sd-pokemon-diffusers':                  {'short': 'pokemon',             'mem': 6,   'size': {'w': 512,  'h': 512}}, |  | ||||||
|     'Envvi/Inkpunk-Diffusion':                          {'short': 'ink',                 'mem': 6,   'size': {'w': 512,  'h': 512}}, |  | ||||||
|     'nousr/robo-diffusion':                             {'short': 'robot',               'mem': 6,   'size': {'w': 512,  'h': 512}}, |  | ||||||
| 
 | 
 | ||||||
|     # -1 is always inpaint default | class Size(msgspec.Struct): | ||||||
|     'diffusers/stable-diffusion-xl-1.0-inpainting-0.1': {'short': 'stablexl-inpainting', 'mem': 8.3, 'size': {'w': 1024, 'h': 1024}}, |     w: int | ||||||
|  |     h: int | ||||||
| 
 | 
 | ||||||
|     # default is always last | class ModelDesc(msgspec.Struct): | ||||||
|     'stabilityai/stable-diffusion-xl-base-1.0':         {'short': 'stablexl',            'mem': 8.3, 'size': {'w': 1024, 'h': 1024}}, |     short: str | ||||||
|  |     mem: float | ||||||
|  |     size: Size | ||||||
|  |     tags: list[Literal['txt2img', 'img2img', 'inpaint']] | ||||||
|  | 
 | ||||||
|  | MODELS: dict[str, ModelDesc] = { | ||||||
|  |     'runwayml/stable-diffusion-v1-5': ModelDesc( | ||||||
|  |         short='stable', | ||||||
|  |         mem=6, | ||||||
|  |         size=Size(w=512, h=512), | ||||||
|  |         tags=['txt2img'] | ||||||
|  |     ), | ||||||
|  |     'stabilityai/stable-diffusion-2-1-base': ModelDesc( | ||||||
|  |         short='stable2', | ||||||
|  |         mem=6, | ||||||
|  |         size=Size(w=512, h=512), | ||||||
|  |         tags=['txt2img'] | ||||||
|  |     ), | ||||||
|  |     'snowkidy/stable-diffusion-xl-base-0.9': ModelDesc( | ||||||
|  |         short='stablexl0.9', | ||||||
|  |         mem=8.3, | ||||||
|  |         size=Size(w=1024, h=1024), | ||||||
|  |         tags=['txt2img'] | ||||||
|  |     ), | ||||||
|  |     'Linaqruf/anything-v3.0': ModelDesc( | ||||||
|  |         short='hdanime', | ||||||
|  |         mem=6, | ||||||
|  |         size=Size(w=512, h=512), | ||||||
|  |         tags=['txt2img'] | ||||||
|  |     ), | ||||||
|  |     'hakurei/waifu-diffusion': ModelDesc( | ||||||
|  |         short='waifu', | ||||||
|  |         mem=6, | ||||||
|  |         size=Size(w=512, h=512), | ||||||
|  |         tags=['txt2img'] | ||||||
|  |     ), | ||||||
|  |     'nitrosocke/Ghibli-Diffusion': ModelDesc( | ||||||
|  |         short='ghibli', | ||||||
|  |         mem=6, | ||||||
|  |         size=Size(w=512, h=512), | ||||||
|  |         tags=['txt2img'] | ||||||
|  |     ), | ||||||
|  |     'dallinmackay/Van-Gogh-diffusion': ModelDesc( | ||||||
|  |         short='van-gogh', | ||||||
|  |         mem=6, | ||||||
|  |         size=Size(w=512, h=512), | ||||||
|  |         tags=['txt2img'] | ||||||
|  |     ), | ||||||
|  |     'lambdalabs/sd-pokemon-diffusers': ModelDesc( | ||||||
|  |         short='pokemon', | ||||||
|  |         mem=6, | ||||||
|  |         size=Size(w=512, h=512), | ||||||
|  |         tags=['txt2img'] | ||||||
|  |     ), | ||||||
|  |     'Envvi/Inkpunk-Diffusion': ModelDesc( | ||||||
|  |         short='ink', | ||||||
|  |         mem=6, | ||||||
|  |         size=Size(w=512, h=512), | ||||||
|  |         tags=['txt2img'] | ||||||
|  |     ), | ||||||
|  |     'nousr/robo-diffusion': ModelDesc( | ||||||
|  |         short='robot', | ||||||
|  |         mem=6, | ||||||
|  |         size=Size(w=512, h=512), | ||||||
|  |         tags=['txt2img'] | ||||||
|  |     ), | ||||||
|  |     'diffusers/stable-diffusion-xl-1.0-inpainting-0.1': ModelDesc( | ||||||
|  |         short='stablexl-inpainting', | ||||||
|  |         mem=8.3, | ||||||
|  |         size=Size(w=1024, h=1024), | ||||||
|  |         tags=['inpaint'] | ||||||
|  |     ), | ||||||
|  |     'prompthero/openjourney': ModelDesc( | ||||||
|  |         short='midj', | ||||||
|  |         mem=6, | ||||||
|  |         size=Size(w=512, h=512), | ||||||
|  |         tags=['txt2img', 'img2img'] | ||||||
|  |     ), | ||||||
|  |     'stabilityai/stable-diffusion-xl-base-1.0': ModelDesc( | ||||||
|  |         short='stablexl', | ||||||
|  |         mem=8.3, | ||||||
|  |         size=Size(w=1024, h=1024), | ||||||
|  |         tags=['txt2img'] | ||||||
|  |     ), | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| SHORT_NAMES = [ | SHORT_NAMES = [ | ||||||
|     model_info['short'] |     model_info.short | ||||||
|     for model_info in MODELS.values() |     for model_info in MODELS.values() | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
| def get_model_by_shortname(short: str): | def get_model_by_shortname(short: str): | ||||||
|     for model, info in MODELS.items(): |     for model, info in MODELS.items(): | ||||||
|         if short == info['short']: |         if short == info.short: | ||||||
|             return model |             return model | ||||||
| 
 | 
 | ||||||
| N = '\n' | N = '\n' | ||||||
|  | @ -169,9 +243,7 @@ DEFAULT_UPSCALER = None | ||||||
| 
 | 
 | ||||||
| DEFAULT_CONFIG_PATH = 'skynet.toml' | DEFAULT_CONFIG_PATH = 'skynet.toml' | ||||||
| 
 | 
 | ||||||
| DEFAULT_INITAL_MODELS = [ | DEFAULT_INITAL_MODEL = list(MODELS.keys())[-1] | ||||||
|     'stabilityai/stable-diffusion-xl-base-1.0' |  | ||||||
| ] |  | ||||||
| 
 | 
 | ||||||
| DATE_FORMAT = '%B the %dth %Y, %H:%M:%S' | DATE_FORMAT = '%B the %dth %Y, %H:%M:%S' | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -13,7 +13,7 @@ from diffusers import DiffusionPipeline | ||||||
| import trio | import trio | ||||||
| import torch | import torch | ||||||
| 
 | 
 | ||||||
| from skynet.constants import DEFAULT_INITAL_MODELS, MODELS | from skynet.constants import DEFAULT_INITAL_MODEL, MODELS | ||||||
| from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled | from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled | ||||||
| 
 | 
 | ||||||
| from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for | from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for | ||||||
|  | @ -21,26 +21,34 @@ from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_ima | ||||||
| 
 | 
 | ||||||
| def prepare_params_for_diffuse( | def prepare_params_for_diffuse( | ||||||
|     params: dict, |     params: dict, | ||||||
|     input_type: str, |     mode: str, | ||||||
|     binary = None |     inputs: list[bytes] | ||||||
| ): | ): | ||||||
|     _params = {} |     _params = {} | ||||||
|     if binary != None: |     match mode: | ||||||
|         match input_type: |         case 'inpaint': | ||||||
|             case 'png': |  | ||||||
|             image = crop_image( |             image = crop_image( | ||||||
|                     binary, params['width'], params['height']) |                 inputs[0], params['width'], params['height']) | ||||||
|  | 
 | ||||||
|  |             mask = crop_image( | ||||||
|  |                 inputs[1], params['width'], params['height']) | ||||||
| 
 | 
 | ||||||
|             _params['image'] = image |             _params['image'] = image | ||||||
|             _params['strength'] = float(params['strength']) |             _params['strength'] = float(params['strength']) | ||||||
| 
 | 
 | ||||||
|             case 'none': |         case 'img2img': | ||||||
|  |             image = crop_image( | ||||||
|  |                 inputs[0], params['width'], params['height']) | ||||||
|  | 
 | ||||||
|  |             _params['image'] = image | ||||||
|  |             _params['strength'] = float(params['strength']) | ||||||
|  | 
 | ||||||
|  |         case 'txt2img': | ||||||
|             ... |             ... | ||||||
| 
 | 
 | ||||||
|         case _: |         case _: | ||||||
|             raise DGPUComputeError(f'Unknown input_type {input_type}') |             raise DGPUComputeError(f'Unknown input_type {input_type}') | ||||||
| 
 | 
 | ||||||
|     else: |  | ||||||
|     _params['width'] = int(params['width']) |     _params['width'] = int(params['width']) | ||||||
|     _params['height'] = int(params['height']) |     _params['height'] = int(params['height']) | ||||||
| 
 | 
 | ||||||
|  | @ -58,94 +66,52 @@ class SkynetMM: | ||||||
| 
 | 
 | ||||||
|     def __init__(self, config: dict): |     def __init__(self, config: dict): | ||||||
|         self.upscaler = init_upscaler() |         self.upscaler = init_upscaler() | ||||||
|         self.initial_models = ( |  | ||||||
|             config['initial_models'] |  | ||||||
|             if 'initial_models' in config else DEFAULT_INITAL_MODELS |  | ||||||
|         ) |  | ||||||
| 
 | 
 | ||||||
|         self.cache_dir = None |         self.cache_dir = None | ||||||
|         if 'hf_home' in config: |         if 'hf_home' in config: | ||||||
|             self.cache_dir = config['hf_home'] |             self.cache_dir = config['hf_home'] | ||||||
| 
 | 
 | ||||||
|         self._models = {} |         self.load_model(DEFAULT_INITAL_MODEL, 'txt2img') | ||||||
|         for model in self.initial_models: |  | ||||||
|             self.load_model(model, False, force=True) |  | ||||||
| 
 | 
 | ||||||
|     def log_debug_info(self): |     def log_debug_info(self): | ||||||
|         logging.info('memory summary:') |         logging.info('memory summary:') | ||||||
|         logging.info('\n' + torch.cuda.memory_summary()) |         logging.info('\n' + torch.cuda.memory_summary()) | ||||||
| 
 | 
 | ||||||
|     def is_model_loaded(self, model_name: str, image: bool): |     def is_model_loaded(self, name: str, mode: str): | ||||||
|         for model_key, model_data in self._models.items(): |         if (name == self._model_name and | ||||||
|             if (model_key == model_name and |             mode == self._model_mode): | ||||||
|                 model_data['image'] == image): |  | ||||||
|             return True |             return True | ||||||
| 
 | 
 | ||||||
|         return False |         return False | ||||||
| 
 | 
 | ||||||
|     def load_model( |     def load_model( | ||||||
|         self, |         self, | ||||||
|         model_name: str, |         name: str, | ||||||
|         image: bool, |         mode: str | ||||||
|         force=False |  | ||||||
|     ): |     ): | ||||||
|         logging.info(f'loading model {model_name}...') |         logging.info(f'loading model {model_name}...') | ||||||
|         if force or len(self._models.keys()) == 0: |         self._model_mode = mode | ||||||
|             pipe = pipeline_for( |         self._model_name = name | ||||||
|                 model_name, image=image, cache_dir=self.cache_dir) |  | ||||||
| 
 |  | ||||||
|             self._models[model_name] = { |  | ||||||
|                 'pipe': pipe, |  | ||||||
|                 'generated': 0, |  | ||||||
|                 'image': image |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|         else: |  | ||||||
|             least_used = list(self._models.keys())[0] |  | ||||||
| 
 |  | ||||||
|             for model in self._models: |  | ||||||
|                 if self._models[ |  | ||||||
|                     least_used]['generated'] > self._models[model]['generated']: |  | ||||||
|                     least_used = model |  | ||||||
| 
 |  | ||||||
|             del self._models[least_used] |  | ||||||
| 
 |  | ||||||
|             logging.info(f'swapping model {least_used} for {model_name}...') |  | ||||||
| 
 | 
 | ||||||
|         gc.collect() |         gc.collect() | ||||||
|         torch.cuda.empty_cache() |         torch.cuda.empty_cache() | ||||||
| 
 | 
 | ||||||
|             pipe = pipeline_for( |         self._model = pipeline_for( | ||||||
|                 model_name, image=image, cache_dir=self.cache_dir) |             name, mode, cache_dir=self.cache_dir) | ||||||
| 
 | 
 | ||||||
|             self._models[model_name] = { |     def get_model(self, name: str, mode: str) -> DiffusionPipeline: | ||||||
|                 'pipe': pipe, |         if name not in MODELS: | ||||||
|                 'generated': 0, |  | ||||||
|                 'image': image |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|         logging.info(f'loaded model {model_name}') |  | ||||||
|         return pipe |  | ||||||
| 
 |  | ||||||
|     def get_model(self, model_name: str, image: bool) -> DiffusionPipeline: |  | ||||||
|         if model_name not in MODELS: |  | ||||||
|             raise DGPUComputeError(f'Unknown model {model_name}') |             raise DGPUComputeError(f'Unknown model {model_name}') | ||||||
| 
 | 
 | ||||||
|         if not self.is_model_loaded(model_name, image): |         if not self.is_model_loaded(name, mode): | ||||||
|             pipe = self.load_model(model_name, image=image) |             self.load_model(name, mode) | ||||||
| 
 |  | ||||||
|         else: |  | ||||||
|             pipe = self._models[model_name]['pipe'] |  | ||||||
| 
 |  | ||||||
|         return pipe |  | ||||||
| 
 | 
 | ||||||
|     def compute_one( |     def compute_one( | ||||||
|         self, |         self, | ||||||
|         request_id: int, |         request_id: int, | ||||||
|         method: str, |         method: str, | ||||||
|         params: dict, |         params: dict, | ||||||
|         input_type: str = 'png', |         inputs: list[bytes] = [] | ||||||
|         binary: bytes | None = None |  | ||||||
|     ): |     ): | ||||||
|         def maybe_cancel_work(step, *args, **kwargs): |         def maybe_cancel_work(step, *args, **kwargs): | ||||||
|             if self._should_cancel: |             if self._should_cancel: | ||||||
|  | @ -164,17 +130,16 @@ class SkynetMM: | ||||||
|         output_hash = None |         output_hash = None | ||||||
|         try: |         try: | ||||||
|             match method: |             match method: | ||||||
|                 case 'diffuse': |                 case 'txt2img' | 'img2img' | 'inpaint': | ||||||
|                     arguments = prepare_params_for_diffuse( |                     arguments = prepare_params_for_diffuse( | ||||||
|                         params, input_type, binary=binary) |                         params, method, inputs) | ||||||
|                     prompt, guidance, step, seed, upscaler, extra_params = arguments |                     prompt, guidance, step, seed, upscaler, extra_params = arguments | ||||||
|                     model = self.get_model( |                     self.get_model( | ||||||
|                         params['model'], |                         params['model'], | ||||||
|                         'image' in extra_params, |                         method | ||||||
|                         'mask_image' in extra_params |  | ||||||
|                     ) |                     ) | ||||||
| 
 | 
 | ||||||
|                     output = model( |                     output = self._model( | ||||||
|                         prompt, |                         prompt, | ||||||
|                         guidance_scale=guidance, |                         guidance_scale=guidance, | ||||||
|                         num_inference_steps=step, |                         num_inference_steps=step, | ||||||
|  |  | ||||||
|  | @ -117,22 +117,7 @@ class SkynetDGPUDaemon: | ||||||
| 
 | 
 | ||||||
|         return app |         return app | ||||||
| 
 | 
 | ||||||
|     async def serve_forever(self): |     async def maybe_serve_one(self, req): | ||||||
|         try: |  | ||||||
|             while True: |  | ||||||
|                 if self.auto_withdraw: |  | ||||||
|                     await self.conn.maybe_withdraw_all() |  | ||||||
| 
 |  | ||||||
|                 queue = self._snap['queue'] |  | ||||||
| 
 |  | ||||||
|                 random.shuffle(queue) |  | ||||||
|                 queue = sorted( |  | ||||||
|                     queue, |  | ||||||
|                     key=lambda req: convert_reward_to_int(req['reward']), |  | ||||||
|                     reverse=True |  | ||||||
|                 ) |  | ||||||
| 
 |  | ||||||
|                 for req in queue: |  | ||||||
|         rid = req['id'] |         rid = req['id'] | ||||||
| 
 | 
 | ||||||
|         # parse request |         # parse request | ||||||
|  | @ -142,23 +127,26 @@ class SkynetDGPUDaemon: | ||||||
|         # if model not known |         # if model not known | ||||||
|         if model not in MODELS: |         if model not in MODELS: | ||||||
|             logging.warning(f'Unknown model {model}') |             logging.warning(f'Unknown model {model}') | ||||||
|                         continue |             return False | ||||||
| 
 | 
 | ||||||
|         # if whitelist enabled and model not in it continue |         # if whitelist enabled and model not in it continue | ||||||
|         if (len(self.model_whitelist) > 0 and |         if (len(self.model_whitelist) > 0 and | ||||||
|             not model in self.model_whitelist): |             not model in self.model_whitelist): | ||||||
|                         continue |             return False | ||||||
| 
 | 
 | ||||||
|         # if blacklist contains model skip |         # if blacklist contains model skip | ||||||
|         if model in self.model_blacklist: |         if model in self.model_blacklist: | ||||||
|                         continue |             return False | ||||||
| 
 | 
 | ||||||
|         my_results = [res['id'] for res in self._snap['my_results']] |         my_results = [res['id'] for res in self._snap['my_results']] | ||||||
|         if rid not in my_results and rid in self._snap['requests']: |         if rid not in my_results and rid in self._snap['requests']: | ||||||
|             statuses = self._snap['requests'][rid] |             statuses = self._snap['requests'][rid] | ||||||
| 
 | 
 | ||||||
|             if len(statuses) == 0: |             if len(statuses) == 0: | ||||||
|                             binary, input_type = await self.conn.get_input_data(req['binary_data']) |                 inputs = [ | ||||||
|  |                     await self.conn.get_input_data(_input) | ||||||
|  |                     for _input in req['binary_data'].split(',') | ||||||
|  |                 ] | ||||||
| 
 | 
 | ||||||
|                 hash_str = ( |                 hash_str = ( | ||||||
|                     str(req['nonce']) |                     str(req['nonce']) | ||||||
|  | @ -195,8 +183,7 @@ class SkynetDGPUDaemon: | ||||||
|                                         self.mm.compute_one, |                                         self.mm.compute_one, | ||||||
|                                         rid, |                                         rid, | ||||||
|                                         body['method'], body['params'], |                                         body['method'], body['params'], | ||||||
|                                                     input_type=input_type, |                                         inputs=inputs | ||||||
|                                                     binary=binary |  | ||||||
|                                     ) |                                     ) | ||||||
|                                 ) |                                 ) | ||||||
| 
 | 
 | ||||||
|  | @ -215,11 +202,30 @@ class SkynetDGPUDaemon: | ||||||
|                         await self.conn.cancel_work(rid, str(e)) |                         await self.conn.cancel_work(rid, str(e)) | ||||||
| 
 | 
 | ||||||
|                     finally: |                     finally: | ||||||
|                                     break |                         return True | ||||||
| 
 | 
 | ||||||
|         else: |         else: | ||||||
|             logging.info(f'request {rid} already beign worked on, skip...') |             logging.info(f'request {rid} already beign worked on, skip...') | ||||||
| 
 | 
 | ||||||
|  |     async def serve_forever(self): | ||||||
|  |         try: | ||||||
|  |             while True: | ||||||
|  |                 if self.auto_withdraw: | ||||||
|  |                     await self.conn.maybe_withdraw_all() | ||||||
|  | 
 | ||||||
|  |                 queue = self._snap['queue'] | ||||||
|  | 
 | ||||||
|  |                 random.shuffle(queue) | ||||||
|  |                 queue = sorted( | ||||||
|  |                     queue, | ||||||
|  |                     key=lambda req: convert_reward_to_int(req['reward']), | ||||||
|  |                     reverse=True | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|  |                 for req in queue: | ||||||
|  |                     if (await self.maybe_serve_one(req)): | ||||||
|  |                         break | ||||||
|  | 
 | ||||||
|                 await trio.sleep(1) |                 await trio.sleep(1) | ||||||
| 
 | 
 | ||||||
|         except KeyboardInterrupt: |         except KeyboardInterrupt: | ||||||
|  |  | ||||||
|  | @ -267,46 +267,15 @@ class SkynetGPUConnector: | ||||||
| 
 | 
 | ||||||
|         return file_cid |         return file_cid | ||||||
| 
 | 
 | ||||||
|     async def get_input_data(self, ipfs_hash: str) -> tuple[bytes, str]: |     async def get_input_data(self, ipfs_hash: str) -> Image: | ||||||
|         input_type = 'none' |  | ||||||
| 
 |  | ||||||
|         if ipfs_hash == '': |  | ||||||
|             return b'', input_type |  | ||||||
| 
 |  | ||||||
|         results = {} |  | ||||||
|         ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}' |         ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}' | ||||||
|         ipfs_link_legacy = ipfs_link + '/image.png' |  | ||||||
| 
 | 
 | ||||||
|         async with trio.open_nursery() as n: |  | ||||||
|             async def get_and_set_results(link: str): |  | ||||||
|         res = await get_ipfs_file(link, timeout=1) |         res = await get_ipfs_file(link, timeout=1) | ||||||
|         logging.info(f'got response from {link}') |         logging.info(f'got response from {link}') | ||||||
|         if not res or res.status_code != 200: |         if not res or res.status_code != 200: | ||||||
|             logging.warning(f'couldn\'t get ipfs binary data at {link}!') |             logging.warning(f'couldn\'t get ipfs binary data at {link}!') | ||||||
| 
 | 
 | ||||||
|                 else: |  | ||||||
|                     try: |  | ||||||
|         # attempt to decode as image |         # attempt to decode as image | ||||||
|                         results[link] = Image.open(io.BytesIO(res.raw)) |         input_data = Image.open(io.BytesIO(res.raw)) | ||||||
|                         input_type = 'png' |  | ||||||
|                         n.cancel_scope.cancel() |  | ||||||
| 
 | 
 | ||||||
|                     except UnidentifiedImageError: |         return input_data | ||||||
|                         logging.warning(f'couldn\'t get ipfs binary data at {link}!') |  | ||||||
| 
 |  | ||||||
|             n.start_soon( |  | ||||||
|                 get_and_set_results, ipfs_link) |  | ||||||
|             n.start_soon( |  | ||||||
|                 get_and_set_results, ipfs_link_legacy) |  | ||||||
| 
 |  | ||||||
|         input_data = None |  | ||||||
|         if ipfs_link_legacy in results: |  | ||||||
|             input_data = results[ipfs_link_legacy] |  | ||||||
| 
 |  | ||||||
|         if ipfs_link in results: |  | ||||||
|             input_data = results[ipfs_link] |  | ||||||
| 
 |  | ||||||
|         if input_data == None: |  | ||||||
|             raise DGPUComputeError('Couldn\'t gather input data from ipfs') |  | ||||||
| 
 |  | ||||||
|         return input_data, input_type |  | ||||||
|  |  | ||||||
|  | @ -39,7 +39,7 @@ def validate_user_config_request(req: str): | ||||||
|                 case 'model' | 'algo': |                 case 'model' | 'algo': | ||||||
|                     attr = 'model' |                     attr = 'model' | ||||||
|                     val = params[2] |                     val = params[2] | ||||||
|                     shorts = [model_info['short'] for model_info in MODELS.values()] |                     shorts = [model_info.short for model_info in MODELS.values()] | ||||||
|                     if val not in shorts: |                     if val not in shorts: | ||||||
|                         raise ConfigUnknownAlgorithm(f'no model named {val}') |                         raise ConfigUnknownAlgorithm(f'no model named {val}') | ||||||
| 
 | 
 | ||||||
|  | @ -112,20 +112,10 @@ def validate_user_config_request(req: str): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def perform_auto_conf(config: dict) -> dict: | def perform_auto_conf(config: dict) -> dict: | ||||||
|     model = config['model'] |     model = MODELS[config['model']] | ||||||
|     prefered_size_w = 512 |  | ||||||
|     prefered_size_h = 512 |  | ||||||
| 
 |  | ||||||
|     if 'xl' in model: |  | ||||||
|         prefered_size_w = 1024 |  | ||||||
|         prefered_size_h = 1024 |  | ||||||
| 
 |  | ||||||
|     else: |  | ||||||
|         prefered_size_w = 512 |  | ||||||
|         prefered_size_h = 512 |  | ||||||
| 
 | 
 | ||||||
|     config['step'] = random.randint(20, 35) |     config['step'] = random.randint(20, 35) | ||||||
|     config['width'] = prefered_size_w |     config['width'] = model.size.w | ||||||
|     config['height'] = prefered_size_h |     config['height'] = model.size.h | ||||||
| 
 | 
 | ||||||
|     return config |     return config | ||||||
|  |  | ||||||
|  | @ -116,7 +116,7 @@ class SkynetTelegramFrontend: | ||||||
|         method: str, |         method: str, | ||||||
|         params: dict, |         params: dict, | ||||||
|         file_id: str | None = None, |         file_id: str | None = None, | ||||||
|         binary_data: str = '' |         inputs: list[str] = [] | ||||||
|     ) -> bool: |     ) -> bool: | ||||||
|         if params['seed'] == None: |         if params['seed'] == None: | ||||||
|             params['seed'] = random.randint(0, 0xFFFFFFFF) |             params['seed'] = random.randint(0, 0xFFFFFFFF) | ||||||
|  | @ -148,7 +148,7 @@ class SkynetTelegramFrontend: | ||||||
|             { |             { | ||||||
|                 'user': Name(self.account), |                 'user': Name(self.account), | ||||||
|                 'request_body': body, |                 'request_body': body, | ||||||
|                 'binary_data': binary_data, |                 'binary_data': inputs.joint(','), | ||||||
|                 'reward': asset_from_str(reward), |                 'reward': asset_from_str(reward), | ||||||
|                 'min_verification': 1 |                 'min_verification': 1 | ||||||
|             }, |             }, | ||||||
|  | @ -181,7 +181,7 @@ class SkynetTelegramFrontend: | ||||||
|         request_id, nonce = out.split(':') |         request_id, nonce = out.split(':') | ||||||
| 
 | 
 | ||||||
|         request_hash = sha256( |         request_hash = sha256( | ||||||
|             (nonce + body + binary_data).encode('utf-8')).hexdigest().upper() |             (nonce + body + inputs.join(',')).encode('utf-8')).hexdigest().upper() | ||||||
| 
 | 
 | ||||||
|         request_id = int(request_id) |         request_id = int(request_id) | ||||||
| 
 | 
 | ||||||
|  | @ -241,11 +241,8 @@ class SkynetTelegramFrontend: | ||||||
|             user, params, tx_hash, worker, reward, self.explorer_domain) |             user, params, tx_hash, worker, reward, self.explorer_domain) | ||||||
| 
 | 
 | ||||||
|         # attempt to get the image and send it |         # attempt to get the image and send it | ||||||
|         results = {} |  | ||||||
|         ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}' |         ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}' | ||||||
|         ipfs_link_legacy = ipfs_link + '/image.png' |  | ||||||
| 
 | 
 | ||||||
|         async def get_and_set_results(link: str): |  | ||||||
|         res = await get_ipfs_file(link) |         res = await get_ipfs_file(link) | ||||||
|         logging.info(f'got response from {link}') |         logging.info(f'got response from {link}') | ||||||
|         if not res or res.status_code != 200: |         if not res or res.status_code != 200: | ||||||
|  | @ -264,24 +261,9 @@ class SkynetTelegramFrontend: | ||||||
|                     image.save(tmp_buf, format='PNG') |                     image.save(tmp_buf, format='PNG') | ||||||
|                     png_img = tmp_buf.getvalue() |                     png_img = tmp_buf.getvalue() | ||||||
| 
 | 
 | ||||||
|                         results[link] = png_img |  | ||||||
| 
 |  | ||||||
|             except UnidentifiedImageError: |             except UnidentifiedImageError: | ||||||
|                 logging.warning(f'couldn\'t get ipfs binary data at {link}!') |                 logging.warning(f'couldn\'t get ipfs binary data at {link}!') | ||||||
| 
 | 
 | ||||||
|         tasks = [ |  | ||||||
|             get_and_set_results(ipfs_link), |  | ||||||
|             get_and_set_results(ipfs_link_legacy) |  | ||||||
|         ] |  | ||||||
|         await asyncio.gather(*tasks) |  | ||||||
| 
 |  | ||||||
|         png_img = None |  | ||||||
|         if ipfs_link_legacy in results: |  | ||||||
|             png_img = results[ipfs_link_legacy] |  | ||||||
| 
 |  | ||||||
|         if ipfs_link in results: |  | ||||||
|             png_img = results[ipfs_link] |  | ||||||
| 
 |  | ||||||
|         if not png_img: |         if not png_img: | ||||||
|             await self.update_status_message( |             await self.update_status_message( | ||||||
|                 status_msg, |                 status_msg, | ||||||
|  |  | ||||||
|  | @ -66,8 +66,7 @@ def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image: | ||||||
| def pipeline_for( | def pipeline_for( | ||||||
|     model: str, |     model: str, | ||||||
|     mem_fraction: float = 1.0, |     mem_fraction: float = 1.0, | ||||||
|     image: bool = False, |     mode: str = [], | ||||||
|     inpainting: bool = False, |  | ||||||
|     cache_dir: str | None = None |     cache_dir: str | None = None | ||||||
| ) -> DiffusionPipeline: | ) -> DiffusionPipeline: | ||||||
| 
 | 
 | ||||||
|  | @ -85,14 +84,14 @@ def pipeline_for( | ||||||
| 
 | 
 | ||||||
|     model_info = MODELS[model] |     model_info = MODELS[model] | ||||||
| 
 | 
 | ||||||
|     req_mem = model_info['mem'] |     req_mem = model_info.mem | ||||||
|     mem_gb = torch.cuda.mem_get_info()[1] / (10**9) |     mem_gb = torch.cuda.mem_get_info()[1] / (10**9) | ||||||
|     mem_gb *= mem_fraction |     mem_gb *= mem_fraction | ||||||
|     over_mem = mem_gb < req_mem |     over_mem = mem_gb < req_mem | ||||||
|     if over_mem: |     if over_mem: | ||||||
|         logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..') |         logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..') | ||||||
| 
 | 
 | ||||||
|     shortname = model_info['short'] |     shortname = model_info.short | ||||||
| 
 | 
 | ||||||
|     params = { |     params = { | ||||||
|         'safety_checker': None, |         'safety_checker': None, | ||||||
|  | @ -107,12 +106,13 @@ def pipeline_for( | ||||||
| 
 | 
 | ||||||
|     torch.cuda.set_per_process_memory_fraction(mem_fraction) |     torch.cuda.set_per_process_memory_fraction(mem_fraction) | ||||||
| 
 | 
 | ||||||
|     if inpainting: |     if 'inpaint' in mode: | ||||||
|         pipe = AutoPipelineForInpainting.from_pretrained( |         pipe_class = AutoPipelineForInpainting | ||||||
|             model, **params) |  | ||||||
| 
 | 
 | ||||||
|     else: |     else: | ||||||
|         pipe = DiffusionPipeline.from_pretrained( |         pipe_class = DiffusionPipeline | ||||||
|  | 
 | ||||||
|  |     pipe = AutoPipelineForInpainting.from_pretrained( | ||||||
|         model, **params) |         model, **params) | ||||||
| 
 | 
 | ||||||
|     pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( |     pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( | ||||||
|  | @ -121,7 +121,7 @@ def pipeline_for( | ||||||
|     pipe.enable_xformers_memory_efficient_attention() |     pipe.enable_xformers_memory_efficient_attention() | ||||||
| 
 | 
 | ||||||
|     if over_mem: |     if over_mem: | ||||||
|         if not image: |         if 'img2img' not in mode: | ||||||
|             pipe.enable_vae_slicing() |             pipe.enable_vae_slicing() | ||||||
|             pipe.enable_vae_tiling() |             pipe.enable_vae_tiling() | ||||||
| 
 | 
 | ||||||
|  | @ -140,7 +140,7 @@ def pipeline_for( | ||||||
| 
 | 
 | ||||||
| def txt2img( | def txt2img( | ||||||
|     hf_token: str, |     hf_token: str, | ||||||
|     model: str = 'prompthero/openjourney', |     model: str = list(MODELS.keys())[-1], | ||||||
|     prompt: str = 'a red old tractor in a sunny wheat field', |     prompt: str = 'a red old tractor in a sunny wheat field', | ||||||
|     output: str = 'output.png', |     output: str = 'output.png', | ||||||
|     width: int = 512, height: int = 512, |     width: int = 512, height: int = 512, | ||||||
|  | @ -166,7 +166,7 @@ def txt2img( | ||||||
| 
 | 
 | ||||||
| def img2img( | def img2img( | ||||||
|     hf_token: str, |     hf_token: str, | ||||||
|     model: str = 'prompthero/openjourney', |     model: str = list(MODELS.keys())[-2], | ||||||
|     prompt: str = 'a red old tractor in a sunny wheat field', |     prompt: str = 'a red old tractor in a sunny wheat field', | ||||||
|     img_path: str = 'input.png', |     img_path: str = 'input.png', | ||||||
|     output: str = 'output.png', |     output: str = 'output.png', | ||||||
|  | @ -181,7 +181,7 @@ def img2img( | ||||||
|     model_info = MODELS[model] |     model_info = MODELS[model] | ||||||
| 
 | 
 | ||||||
|     with open(img_path, 'rb') as img_file: |     with open(img_path, 'rb') as img_file: | ||||||
|         input_img = convert_from_bytes_and_crop(img_file.read(), model_info['size']['w'], model_info['size']['h']) |         input_img = convert_from_bytes_and_crop(img_file.read(), model_info.size.w, model_info.size.h) | ||||||
| 
 | 
 | ||||||
|     seed = seed if seed else random.randint(0, 2 ** 64) |     seed = seed if seed else random.randint(0, 2 ** 64) | ||||||
|     prompt = prompt |     prompt = prompt | ||||||
|  | @ -198,7 +198,7 @@ def img2img( | ||||||
| 
 | 
 | ||||||
| def inpaint( | def inpaint( | ||||||
|     hf_token: str, |     hf_token: str, | ||||||
|     model: str = 'diffusers/stable-diffusion-xl-1.0-inpainting-0.1', |     model: str = list(MODELS.keys())[-3], | ||||||
|     prompt: str = 'a red old tractor in a sunny wheat field', |     prompt: str = 'a red old tractor in a sunny wheat field', | ||||||
|     img_path: str = 'input.png', |     img_path: str = 'input.png', | ||||||
|     mask_path: str = 'mask.png', |     mask_path: str = 'mask.png', | ||||||
|  | @ -214,10 +214,10 @@ def inpaint( | ||||||
|     model_info = MODELS[model] |     model_info = MODELS[model] | ||||||
| 
 | 
 | ||||||
|     with open(img_path, 'rb') as img_file: |     with open(img_path, 'rb') as img_file: | ||||||
|         input_img = convert_from_bytes_and_crop(img_file.read(), model_info['size']['w'], model_info['size']['h']) |         input_img = convert_from_bytes_and_crop(img_file.read(), model_info.size.w, model_info.size.h) | ||||||
| 
 | 
 | ||||||
|     with open(mask_path, 'rb') as mask_file: |     with open(mask_path, 'rb') as mask_file: | ||||||
|         mask_img = convert_from_bytes_and_crop(mask_file.read(), model_info['size']['w'], model_info['size']['h']) |         mask_img = convert_from_bytes_and_crop(mask_file.read(), model_info.size.w, model_info.size.h) | ||||||
| 
 | 
 | ||||||
|     seed = seed if seed else random.randint(0, 2 ** 64) |     seed = seed if seed else random.randint(0, 2 ** 64) | ||||||
|     prompt = prompt |     prompt = prompt | ||||||
|  |  | ||||||
|  | @ -0,0 +1,98 @@ | ||||||
|  | 
 | ||||||
|  | from skynet.config import * | ||||||
|  | 
 | ||||||
|  | async def test_txt2img(): | ||||||
|  |     req = { | ||||||
|  |         'id': 0, | ||||||
|  |         'body': json.dumps({ | ||||||
|  |             "method": "txt2img", | ||||||
|  |             "params": { | ||||||
|  |                 "prompt": "Kronos God Realistic 4k", | ||||||
|  |                 "model": list(MODELS.keys())[-1], | ||||||
|  |                 "step": 21, | ||||||
|  |                 "width": 1024, | ||||||
|  |                 "height": 1024, | ||||||
|  |                 "seed": 168402949, | ||||||
|  |                 "guidance": "7.5" | ||||||
|  |             } | ||||||
|  |         }), | ||||||
|  |         'inputs': [], | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     config = load_skynet_toml(file_path=config_path) | ||||||
|  |     hf_token = load_key(config, 'skynet.dgpu.hf_token') | ||||||
|  |     hf_home = load_key(config, 'skynet.dgpu.hf_home') | ||||||
|  |     set_hf_vars(hf_token, hf_home) | ||||||
|  | 
 | ||||||
|  |     assert 'skynet' in config | ||||||
|  |     assert 'dgpu' in config['skynet'] | ||||||
|  | 
 | ||||||
|  |     mm = SkynetMM(config['skynet']['dgpu']) | ||||||
|  | 
 | ||||||
|  |     mm.maybe_serve_one(req) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | async def test_img2img(): | ||||||
|  |     req = { | ||||||
|  |         'id': 0, | ||||||
|  |         'body': json.dumps({ | ||||||
|  |             "method": "img2img", | ||||||
|  |             "params": { | ||||||
|  |                 "prompt": "Kronos God Realistic 4k", | ||||||
|  |                 "model": list(MODELS.keys())[-2], | ||||||
|  |                 "step": 21, | ||||||
|  |                 "width": 1024, | ||||||
|  |                 "height": 1024, | ||||||
|  |                 "seed": 168402949, | ||||||
|  |                 "guidance": "7.5", | ||||||
|  |                 "strength": "0.5" | ||||||
|  |             } | ||||||
|  |         }), | ||||||
|  |         'inputs': ['QmZcGdXXVQfpco1G3tr2CGFBtv8xVsCwcwuq9gnJBWDymi'], | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     config = load_skynet_toml(file_path=config_path) | ||||||
|  |     hf_token = load_key(config, 'skynet.dgpu.hf_token') | ||||||
|  |     hf_home = load_key(config, 'skynet.dgpu.hf_home') | ||||||
|  |     set_hf_vars(hf_token, hf_home) | ||||||
|  | 
 | ||||||
|  |     assert 'skynet' in config | ||||||
|  |     assert 'dgpu' in config['skynet'] | ||||||
|  | 
 | ||||||
|  |     mm = SkynetMM(config['skynet']['dgpu']) | ||||||
|  | 
 | ||||||
|  |     mm.maybe_serve_one(req) | ||||||
|  | 
 | ||||||
|  | async def test_inpaint(): | ||||||
|  |     req = { | ||||||
|  |         'id': 0, | ||||||
|  |         'body': json.dumps({ | ||||||
|  |             "method": "inpaint", | ||||||
|  |             "params": { | ||||||
|  |                 "prompt": "a black panther on a sunny roof", | ||||||
|  |                 "model": list(MODELS.keys())[-3], | ||||||
|  |                 "step": 21, | ||||||
|  |                 "width": 1024, | ||||||
|  |                 "height": 1024, | ||||||
|  |                 "seed": 168402949, | ||||||
|  |                 "guidance": "7.5", | ||||||
|  |                 "strength": "0.5" | ||||||
|  |             } | ||||||
|  |         }), | ||||||
|  |         'inputs': [ | ||||||
|  |             'QmZcGdXXVQfpco1G3tr2CGFBtv8xVsCwcwuq9gnJBWDymi', | ||||||
|  |             'Qmccx1aXNmq5mZDS3YviUhgGHXWhQeHvca3AgA7MDjj2hR' | ||||||
|  |         ], | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     config = load_skynet_toml(file_path=config_path) | ||||||
|  |     hf_token = load_key(config, 'skynet.dgpu.hf_token') | ||||||
|  |     hf_home = load_key(config, 'skynet.dgpu.hf_home') | ||||||
|  |     set_hf_vars(hf_token, hf_home) | ||||||
|  | 
 | ||||||
|  |     assert 'skynet' in config | ||||||
|  |     assert 'dgpu' in config['skynet'] | ||||||
|  | 
 | ||||||
|  |     mm = SkynetMM(config['skynet']['dgpu']) | ||||||
|  | 
 | ||||||
|  |     mm.maybe_serve_one(req) | ||||||
		Loading…
	
		Reference in New Issue