diff --git a/skynet/cli.py b/skynet/cli.py index e35b9f4..9da80de 100755 --- a/skynet/cli.py +++ b/skynet/cli.py @@ -20,7 +20,7 @@ def skynet(*args, **kwargs): @click.command() -@click.option('--model', '-m', default='midj') +@click.option('--model', '-m', default=list(MODELS.keys())[-1]) @click.option( '--prompt', '-p', default='a red old tractor in a sunny wheat field') @click.option('--output', '-o', default='output.png') @@ -39,7 +39,7 @@ def txt2img(*args, **kwargs): utils.txt2img(hf_token, **kwargs) @click.command() -@click.option('--model', '-m', default=list(MODELS.keys())[0]) +@click.option('--model', '-m', default=list(MODELS.keys())[-2]) @click.option( '--prompt', '-p', default='a red old tractor in a sunny wheat field') @click.option('--input', '-i', default='input.png') @@ -68,7 +68,7 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed): @click.command() -@click.option('--model', '-m', default=list(MODELS.keys())[-1]) +@click.option('--model', '-m', default=list(MODELS.keys())[-3]) @click.option( '--prompt', '-p', default='a red old tractor in a sunny wheat field') @click.option('--input', '-i', default='input.png') diff --git a/skynet/constants.py b/skynet/constants.py index 28b5db3..a4a7bde 100755 --- a/skynet/constants.py +++ b/skynet/constants.py @@ -4,34 +4,108 @@ VERSION = '0.1a12' DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda' -MODELS = { - 'prompthero/openjourney': {'short': 'midj', 'mem': 6, 'size': {'w': 512, 'h': 512}}, - 'runwayml/stable-diffusion-v1-5': {'short': 'stable', 'mem': 6, 'size': {'w': 512, 'h': 512}}, - 'stabilityai/stable-diffusion-2-1-base': {'short': 'stable2', 'mem': 6, 'size': {'w': 512, 'h': 512}}, - 'snowkidy/stable-diffusion-xl-base-0.9': {'short': 'stablexl0.9', 'mem': 8.3, 'size': {'w': 1024, 'h': 1024}}, - 'Linaqruf/anything-v3.0': {'short': 'hdanime', 'mem': 6, 'size': {'w': 512, 'h': 512}}, - 'hakurei/waifu-diffusion': {'short': 'waifu', 'mem': 6, 'size': {'w': 512, 'h': 512}}, - 
'nitrosocke/Ghibli-Diffusion': {'short': 'ghibli', 'mem': 6, 'size': {'w': 512, 'h': 512}}, - 'dallinmackay/Van-Gogh-diffusion': {'short': 'van-gogh', 'mem': 6, 'size': {'w': 512, 'h': 512}}, - 'lambdalabs/sd-pokemon-diffusers': {'short': 'pokemon', 'mem': 6, 'size': {'w': 512, 'h': 512}}, - 'Envvi/Inkpunk-Diffusion': {'short': 'ink', 'mem': 6, 'size': {'w': 512, 'h': 512}}, - 'nousr/robo-diffusion': {'short': 'robot', 'mem': 6, 'size': {'w': 512, 'h': 512}}, +import msgspec +from typing import Literal - # -1 is always inpaint default - 'diffusers/stable-diffusion-xl-1.0-inpainting-0.1': {'short': 'stablexl-inpainting', 'mem': 8.3, 'size': {'w': 1024, 'h': 1024}}, +class Size(msgspec.Struct): + w: int + h: int - # default is always last - 'stabilityai/stable-diffusion-xl-base-1.0': {'short': 'stablexl', 'mem': 8.3, 'size': {'w': 1024, 'h': 1024}}, +class ModelDesc(msgspec.Struct): + short: str + mem: float + size: Size + tags: list[Literal['txt2img', 'img2img', 'inpaint']] + +MODELS: dict[str, ModelDesc] = { + 'runwayml/stable-diffusion-v1-5': ModelDesc( + short='stable', + mem=6, + size=Size(w=512, h=512), + tags=['txt2img'] + ), + 'stabilityai/stable-diffusion-2-1-base': ModelDesc( + short='stable2', + mem=6, + size=Size(w=512, h=512), + tags=['txt2img'] + ), + 'snowkidy/stable-diffusion-xl-base-0.9': ModelDesc( + short='stablexl0.9', + mem=8.3, + size=Size(w=1024, h=1024), + tags=['txt2img'] + ), + 'Linaqruf/anything-v3.0': ModelDesc( + short='hdanime', + mem=6, + size=Size(w=512, h=512), + tags=['txt2img'] + ), + 'hakurei/waifu-diffusion': ModelDesc( + short='waifu', + mem=6, + size=Size(w=512, h=512), + tags=['txt2img'] + ), + 'nitrosocke/Ghibli-Diffusion': ModelDesc( + short='ghibli', + mem=6, + size=Size(w=512, h=512), + tags=['txt2img'] + ), + 'dallinmackay/Van-Gogh-diffusion': ModelDesc( + short='van-gogh', + mem=6, + size=Size(w=512, h=512), + tags=['txt2img'] + ), + 'lambdalabs/sd-pokemon-diffusers': ModelDesc( + short='pokemon', + mem=6, + 
size=Size(w=512, h=512), + tags=['txt2img'] + ), + 'Envvi/Inkpunk-Diffusion': ModelDesc( + short='ink', + mem=6, + size=Size(w=512, h=512), + tags=['txt2img'] + ), + 'nousr/robo-diffusion': ModelDesc( + short='robot', + mem=6, + size=Size(w=512, h=512), + tags=['txt2img'] + ), + 'diffusers/stable-diffusion-xl-1.0-inpainting-0.1': ModelDesc( + short='stablexl-inpainting', + mem=8.3, + size=Size(w=1024, h=1024), + tags=['inpaint'] + ), + 'prompthero/openjourney': ModelDesc( + short='midj', + mem=6, + size=Size(w=512, h=512), + tags=['txt2img', 'img2img'] + ), + 'stabilityai/stable-diffusion-xl-base-1.0': ModelDesc( + short='stablexl', + mem=8.3, + size=Size(w=1024, h=1024), + tags=['txt2img'] + ), } SHORT_NAMES = [ - model_info['short'] + model_info.short for model_info in MODELS.values() ] def get_model_by_shortname(short: str): for model, info in MODELS.items(): - if short == info['short']: + if short == info.short: return model N = '\n' @@ -169,9 +243,7 @@ DEFAULT_UPSCALER = None DEFAULT_CONFIG_PATH = 'skynet.toml' -DEFAULT_INITAL_MODELS = [ - 'stabilityai/stable-diffusion-xl-base-1.0' -] +DEFAULT_INITAL_MODEL = list(MODELS.keys())[-1] DATE_FORMAT = '%B the %dth %Y, %H:%M:%S' diff --git a/skynet/dgpu/compute.py b/skynet/dgpu/compute.py index 909ce7a..a0cbe62 100644 --- a/skynet/dgpu/compute.py +++ b/skynet/dgpu/compute.py @@ -13,7 +13,7 @@ from diffusers import DiffusionPipeline import trio import torch -from skynet.constants import DEFAULT_INITAL_MODELS, MODELS +from skynet.constants import DEFAULT_INITAL_MODEL, MODELS from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for @@ -21,28 +21,36 @@ from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_ima def prepare_params_for_diffuse( params: dict, - input_type: str, - binary = None + mode: str, + inputs: list[bytes] ): _params = {} - 
if binary != None: - match input_type: - case 'png': - image = crop_image( - binary, params['width'], params['height']) + match mode: + case 'inpaint': + image = crop_image( + inputs[0], params['width'], params['height']) - _params['image'] = image - _params['strength'] = float(params['strength']) + _params['mask_image'] = crop_image( + inputs[1], params['width'], params['height']) - case 'none': - ... + _params['image'] = image + _params['strength'] = float(params['strength']) - case _: - raise DGPUComputeError(f'Unknown input_type {input_type}') + case 'img2img': + image = crop_image( + inputs[0], params['width'], params['height']) - else: - _params['width'] = int(params['width']) - _params['height'] = int(params['height']) + _params['image'] = image + _params['strength'] = float(params['strength']) + + case 'txt2img': + ... + + case _: + raise DGPUComputeError(f'Unknown mode {mode}') + + _params['width'] = int(params['width']) + _params['height'] = int(params['height']) return ( params['prompt'], @@ -58,94 +66,52 @@ class SkynetMM: def __init__(self, config: dict): self.upscaler = init_upscaler() - self.initial_models = ( - config['initial_models'] - if 'initial_models' in config else DEFAULT_INITAL_MODELS - ) self.cache_dir = None if 'hf_home' in config: self.cache_dir = config['hf_home'] - self._models = {} - for model in self.initial_models: - self.load_model(model, False, force=True) + self.load_model(DEFAULT_INITAL_MODEL, 'txt2img') def log_debug_info(self): logging.info('memory summary:') logging.info('\n' + torch.cuda.memory_summary()) - def is_model_loaded(self, model_name: str, image: bool): - for model_key, model_data in self._models.items(): - if (model_key == model_name and - model_data['image'] == image): - return True + def is_model_loaded(self, name: str, mode: str): + if (name == self._model_name and + mode == self._model_mode): + return True return False def load_model( self, - model_name: str, - image: bool, - force=False + name: str, + mode: str ): 
- logging.info(f'loading model {model_name}...') + logging.info(f'loading model {name}...') - if force or len(self._models.keys()) == 0: - pipe = pipeline_for( - model_name, image=image, cache_dir=self.cache_dir) + self._model_mode = mode + self._model_name = name - self._models[model_name] = { - 'pipe': pipe, - 'generated': 0, - 'image': image - } + gc.collect() + torch.cuda.empty_cache() - else: - least_used = list(self._models.keys())[0] + self._model = pipeline_for( + name, mode, cache_dir=self.cache_dir) - for model in self._models: - if self._models[ - least_used]['generated'] > self._models[model]['generated']: - least_used = model - - del self._models[least_used] - - logging.info(f'swapping model {least_used} for {model_name}...') - - gc.collect() - torch.cuda.empty_cache() - - pipe = pipeline_for( - model_name, image=image, cache_dir=self.cache_dir) - - self._models[model_name] = { - 'pipe': pipe, - 'generated': 0, - 'image': image - } - - logging.info(f'loaded model {model_name}') - return pipe - - def get_model(self, model_name: str, image: bool) -> DiffusionPipeline: - if model_name not in MODELS: + def get_model(self, name: str, mode: str) -> DiffusionPipeline: + if name not in MODELS: - raise DGPUComputeError(f'Unknown model {model_name}') + raise DGPUComputeError(f'Unknown model {name}') - if not self.is_model_loaded(model_name, image): - pipe = self.load_model(model_name, image=image) - - else: - pipe = self._models[model_name]['pipe'] - - return pipe + if not self.is_model_loaded(name, mode): + self.load_model(name, mode) def compute_one( self, request_id: int, method: str, params: dict, - input_type: str = 'png', - binary: bytes | None = None + inputs: list[bytes] = [] ): def maybe_cancel_work(step, *args, **kwargs): if self._should_cancel: @@ -164,17 +130,16 @@ class SkynetMM: output_hash = None try: match method: - case 'diffuse': + case 'txt2img' | 'img2img' | 'inpaint': arguments = prepare_params_for_diffuse( - params, input_type, binary=binary) + params, method, inputs) prompt, guidance, step, seed, upscaler, extra_params = arguments - 
model = self.get_model( + self.get_model( params['model'], - 'image' in extra_params, - 'mask_image' in extra_params + method ) - output = model( + output = self._model( prompt, guidance_scale=guidance, num_inference_steps=step, diff --git a/skynet/dgpu/daemon.py b/skynet/dgpu/daemon.py index 7cd68d2..34b574c 100644 --- a/skynet/dgpu/daemon.py +++ b/skynet/dgpu/daemon.py @@ -117,6 +117,96 @@ class SkynetDGPUDaemon: return app + async def maybe_serve_one(self, req): + rid = req['id'] + + # parse request + body = json.loads(req['body']) + model = body['params']['model'] + + # if model not known + if model not in MODELS: + logging.warning(f'Unknown model {model}') + return False + + # if whitelist enabled and model not in it continue + if (len(self.model_whitelist) > 0 and + not model in self.model_whitelist): + return False + + # if blacklist contains model skip + if model in self.model_blacklist: + return False + + my_results = [res['id'] for res in self._snap['my_results']] + if rid not in my_results and rid in self._snap['requests']: + statuses = self._snap['requests'][rid] + + if len(statuses) == 0: + inputs = [ + await self.conn.get_input_data(_input) + for _input in req['binary_data'].split(',') if _input + ] + + hash_str = ( + str(req['nonce']) + + + req['body'] + + + req['binary_data'] + ) + logging.info(f'hashing: {hash_str}') + request_hash = sha256(hash_str.encode('utf-8')).hexdigest() + + # TODO: validate request + + # perform work + logging.info(f'working on {body}') + + resp = await self.conn.begin_work(rid) + if 'code' in resp: + logging.info(f'probably being worked on already... 
skip.') + + else: + try: + output_type = 'png' + if 'output_type' in body['params']: + output_type = body['params']['output_type'] + + output = None + output_hash = None + match self.backend: + case 'sync-on-thread': + self.mm._should_cancel = self.should_cancel_work + output_hash, output = await trio.to_thread.run_sync( + partial( + self.mm.compute_one, + rid, + body['method'], body['params'], + inputs=inputs + ) + ) + + case _: + raise DGPUComputeError(f'Unsupported backend {self.backend}') + self._last_generation_ts = datetime.now().isoformat() + self._last_benchmark = self._benchmark + self._benchmark = [] + + ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type) + + await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash) + + except BaseException as e: + traceback.print_exc() + await self.conn.cancel_work(rid, str(e)) + + finally: + return True + + else: + logging.info(f'request {rid} already beign worked on, skip...') + async def serve_forever(self): try: while True: @@ -133,92 +223,8 @@ class SkynetDGPUDaemon: ) for req in queue: - rid = req['id'] - - # parse request - body = json.loads(req['body']) - model = body['params']['model'] - - # if model not known - if model not in MODELS: - logging.warning(f'Unknown model {model}') - continue - - # if whitelist enabled and model not in it continue - if (len(self.model_whitelist) > 0 and - not model in self.model_whitelist): - continue - - # if blacklist contains model skip - if model in self.model_blacklist: - continue - - my_results = [res['id'] for res in self._snap['my_results']] - if rid not in my_results and rid in self._snap['requests']: - statuses = self._snap['requests'][rid] - - if len(statuses) == 0: - binary, input_type = await self.conn.get_input_data(req['binary_data']) - - hash_str = ( - str(req['nonce']) - + - req['body'] - + - req['binary_data'] - ) - logging.info(f'hashing: {hash_str}') - request_hash = sha256(hash_str.encode('utf-8')).hexdigest() - - # TODO: validate 
request - - # perform work - logging.info(f'working on {body}') - - resp = await self.conn.begin_work(rid) - if 'code' in resp: - logging.info(f'probably being worked on already... skip.') - - else: - try: - output_type = 'png' - if 'output_type' in body['params']: - output_type = body['params']['output_type'] - - output = None - output_hash = None - match self.backend: - case 'sync-on-thread': - self.mm._should_cancel = self.should_cancel_work - output_hash, output = await trio.to_thread.run_sync( - partial( - self.mm.compute_one, - rid, - body['method'], body['params'], - input_type=input_type, - binary=binary - ) - ) - - case _: - raise DGPUComputeError(f'Unsupported backend {self.backend}') - self._last_generation_ts = datetime.now().isoformat() - self._last_benchmark = self._benchmark - self._benchmark = [] - - ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type) - - await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash) - - except BaseException as e: - traceback.print_exc() - await self.conn.cancel_work(rid, str(e)) - - finally: - break - - else: - logging.info(f'request {rid} already beign worked on, skip...') + if (await self.maybe_serve_one(req)): + break await trio.sleep(1) diff --git a/skynet/dgpu/network.py b/skynet/dgpu/network.py index dfe4d67..46f1363 100644 --- a/skynet/dgpu/network.py +++ b/skynet/dgpu/network.py @@ -267,46 +267,15 @@ class SkynetGPUConnector: return file_cid - async def get_input_data(self, ipfs_hash: str) -> tuple[bytes, str]: - input_type = 'none' - - if ipfs_hash == '': - return b'', input_type - - results = {} + async def get_input_data(self, ipfs_hash: str) -> Image: ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}' - ipfs_link_legacy = ipfs_link + '/image.png' - async with trio.open_nursery() as n: - async def get_and_set_results(link: str): - res = await get_ipfs_file(link, timeout=1) - logging.info(f'got response from {link}') - if not res or res.status_code != 200: - 
logging.warning(f'couldn\'t get ipfs binary data at {link}!') + res = await get_ipfs_file(link, timeout=1) + logging.info(f'got response from {link}') + if not res or res.status_code != 200: + raise DGPUComputeError(f'couldn\'t get ipfs binary data at {link}!') - else: - try: - # attempt to decode as image - results[link] = Image.open(io.BytesIO(res.raw)) - input_type = 'png' - n.cancel_scope.cancel() + # attempt to decode as image + input_data = Image.open(io.BytesIO(res.raw)) - except UnidentifiedImageError: - logging.warning(f'couldn\'t get ipfs binary data at {link}!') - - n.start_soon( - get_and_set_results, ipfs_link) - n.start_soon( - get_and_set_results, ipfs_link_legacy) - - input_data = None - if ipfs_link_legacy in results: - input_data = results[ipfs_link_legacy] - - if ipfs_link in results: - input_data = results[ipfs_link] - - if input_data == None: - raise DGPUComputeError('Couldn\'t gather input data from ipfs') - - return input_data, input_type + return input_data diff --git a/skynet/frontend/__init__.py b/skynet/frontend/__init__.py index 30d6fa1..bb5e9bc 100644 --- a/skynet/frontend/__init__.py +++ b/skynet/frontend/__init__.py @@ -39,7 +39,7 @@ def validate_user_config_request(req: str): case 'model' | 'algo': attr = 'model' val = params[2] - shorts = [model_info['short'] for model_info in MODELS.values()] + shorts = [model_info.short for model_info in MODELS.values()] if val not in shorts: raise ConfigUnknownAlgorithm(f'no model named {val}') @@ -112,20 +112,10 @@ def validate_user_config_request(req: str): def perform_auto_conf(config: dict) -> dict: - model = config['model'] - prefered_size_w = 512 - prefered_size_h = 512 - - if 'xl' in model: - prefered_size_w = 1024 - prefered_size_h = 1024 - - else: - prefered_size_w = 512 - prefered_size_h = 512 + model = MODELS[config['model']] config['step'] = random.randint(20, 35) - config['width'] = prefered_size_w - config['height'] = prefered_size_h + config['width'] = model.size.w + config['height'] = 
model.size.h return config diff --git a/skynet/frontend/telegram/__init__.py b/skynet/frontend/telegram/__init__.py index d4fd549..a6eadbb 100644 --- a/skynet/frontend/telegram/__init__.py +++ b/skynet/frontend/telegram/__init__.py @@ -116,7 +116,7 @@ class SkynetTelegramFrontend: method: str, params: dict, file_id: str | None = None, - binary_data: str = '' + inputs: list[str] = [] ) -> bool: if params['seed'] == None: params['seed'] = random.randint(0, 0xFFFFFFFF) @@ -148,7 +148,7 @@ class SkynetTelegramFrontend: { 'user': Name(self.account), 'request_body': body, - 'binary_data': binary_data, + 'binary_data': ','.join(inputs), 'reward': asset_from_str(reward), 'min_verification': 1 }, @@ -181,7 +181,7 @@ class SkynetTelegramFrontend: request_id, nonce = out.split(':') request_hash = sha256( - (nonce + body + binary_data).encode('utf-8')).hexdigest().upper() + (nonce + body + ','.join(inputs)).encode('utf-8')).hexdigest().upper() request_id = int(request_id) @@ -241,47 +241,29 @@ class SkynetTelegramFrontend: user, params, tx_hash, worker, reward, self.explorer_domain) # attempt to get the image and send it - results = {} ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}' - ipfs_link_legacy = ipfs_link + '/image.png' - async def get_and_set_results(link: str): - res = await get_ipfs_file(link) - logging.info(f'got response from {link}') - if not res or res.status_code != 200: + res = await get_ipfs_file(link) + logging.info(f'got response from {link}') + if not res or res.status_code != 200: + logging.warning(f'couldn\'t get ipfs binary data at {link}!') + + else: + try: + with Image.open(io.BytesIO(res.raw)) as image: + w, h = image.size + + if w > TG_MAX_WIDTH or h > TG_MAX_HEIGHT: + logging.warning(f'result is of size {image.size}') + image.thumbnail((TG_MAX_WIDTH, TG_MAX_HEIGHT)) + + tmp_buf = io.BytesIO() + image.save(tmp_buf, format='PNG') + png_img = tmp_buf.getvalue() + + except UnidentifiedImageError: logging.warning(f'couldn\'t get ipfs binary 
data at {link}!') - else: - try: - with Image.open(io.BytesIO(res.raw)) as image: - w, h = image.size - - if w > TG_MAX_WIDTH or h > TG_MAX_HEIGHT: - logging.warning(f'result is of size {image.size}') - image.thumbnail((TG_MAX_WIDTH, TG_MAX_HEIGHT)) - - tmp_buf = io.BytesIO() - image.save(tmp_buf, format='PNG') - png_img = tmp_buf.getvalue() - - results[link] = png_img - - except UnidentifiedImageError: - logging.warning(f'couldn\'t get ipfs binary data at {link}!') - - tasks = [ - get_and_set_results(ipfs_link), - get_and_set_results(ipfs_link_legacy) - ] - await asyncio.gather(*tasks) - - png_img = None - if ipfs_link_legacy in results: - png_img = results[ipfs_link_legacy] - - if ipfs_link in results: - png_img = results[ipfs_link] - if not png_img: await self.update_status_message( status_msg, diff --git a/skynet/utils.py b/skynet/utils.py index e2a8468..9ea2cf0 100755 --- a/skynet/utils.py +++ b/skynet/utils.py @@ -66,8 +66,7 @@ def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image: def pipeline_for( model: str, mem_fraction: float = 1.0, - image: bool = False, - inpainting: bool = False, + mode: str = [], cache_dir: str | None = None ) -> DiffusionPipeline: @@ -85,14 +84,14 @@ def pipeline_for( model_info = MODELS[model] - req_mem = model_info['mem'] + req_mem = model_info.mem mem_gb = torch.cuda.mem_get_info()[1] / (10**9) mem_gb *= mem_fraction over_mem = mem_gb < req_mem if over_mem: logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..') - shortname = model_info['short'] + shortname = model_info.short params = { 'safety_checker': None, @@ -107,13 +106,14 @@ def pipeline_for( torch.cuda.set_per_process_memory_fraction(mem_fraction) - if inpainting: - pipe = AutoPipelineForInpainting.from_pretrained( - model, **params) + if 'inpaint' in mode: + pipe_class = AutoPipelineForInpainting else: - pipe = DiffusionPipeline.from_pretrained( - model, **params) + pipe_class = DiffusionPipeline + + pipe = 
pipe_class.from_pretrained( + model, **params) pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( pipe.scheduler.config) @@ -121,7 +121,7 @@ def pipeline_for( pipe.enable_xformers_memory_efficient_attention() if over_mem: - if not image: + if 'img2img' not in mode: pipe.enable_vae_slicing() pipe.enable_vae_tiling() @@ -140,7 +140,7 @@ def pipeline_for( def txt2img( hf_token: str, - model: str = 'prompthero/openjourney', + model: str = list(MODELS.keys())[-1], prompt: str = 'a red old tractor in a sunny wheat field', output: str = 'output.png', width: int = 512, height: int = 512, @@ -166,7 +166,7 @@ def txt2img( def img2img( hf_token: str, - model: str = 'prompthero/openjourney', + model: str = list(MODELS.keys())[-2], prompt: str = 'a red old tractor in a sunny wheat field', img_path: str = 'input.png', output: str = 'output.png', @@ -181,7 +181,7 @@ def img2img( model_info = MODELS[model] with open(img_path, 'rb') as img_file: - input_img = convert_from_bytes_and_crop(img_file.read(), model_info['size']['w'], model_info['size']['h']) + input_img = convert_from_bytes_and_crop(img_file.read(), model_info.size.w, model_info.size.h) seed = seed if seed else random.randint(0, 2 ** 64) prompt = prompt @@ -198,7 +198,7 @@ def img2img( def inpaint( hf_token: str, - model: str = 'diffusers/stable-diffusion-xl-1.0-inpainting-0.1', + model: str = list(MODELS.keys())[-3], prompt: str = 'a red old tractor in a sunny wheat field', img_path: str = 'input.png', mask_path: str = 'mask.png', @@ -214,10 +214,10 @@ def inpaint( model_info = MODELS[model] with open(img_path, 'rb') as img_file: - input_img = convert_from_bytes_and_crop(img_file.read(), model_info['size']['w'], model_info['size']['h']) + input_img = convert_from_bytes_and_crop(img_file.read(), model_info.size.w, model_info.size.h) with open(mask_path, 'rb') as mask_file: - mask_img = convert_from_bytes_and_crop(mask_file.read(), model_info['size']['w'], model_info['size']['h']) + mask_img = 
convert_from_bytes_and_crop(mask_file.read(), model_info.size.w, model_info.size.h) seed = seed if seed else random.randint(0, 2 ** 64) prompt = prompt diff --git a/tests/test_reqs.py b/tests/test_reqs.py new file mode 100644 index 0000000..cf97cee --- /dev/null +++ b/tests/test_reqs.py @@ -0,0 +1,98 @@ + +from skynet.config import * + +async def test_txt2img(): + req = { + 'id': 0, + 'body': json.dumps({ + "method": "txt2img", + "params": { + "prompt": "Kronos God Realistic 4k", + "model": list(MODELS.keys())[-1], + "step": 21, + "width": 1024, + "height": 1024, + "seed": 168402949, + "guidance": "7.5" + } + }), + 'inputs': [], + } + + config = load_skynet_toml(file_path=config_path) + hf_token = load_key(config, 'skynet.dgpu.hf_token') + hf_home = load_key(config, 'skynet.dgpu.hf_home') + set_hf_vars(hf_token, hf_home) + + assert 'skynet' in config + assert 'dgpu' in config['skynet'] + + mm = SkynetMM(config['skynet']['dgpu']) + + mm.maybe_serve_one(req) + + +async def test_img2img(): + req = { + 'id': 0, + 'body': json.dumps({ + "method": "img2img", + "params": { + "prompt": "Kronos God Realistic 4k", + "model": list(MODELS.keys())[-2], + "step": 21, + "width": 1024, + "height": 1024, + "seed": 168402949, + "guidance": "7.5", + "strength": "0.5" + } + }), + 'inputs': ['QmZcGdXXVQfpco1G3tr2CGFBtv8xVsCwcwuq9gnJBWDymi'], + } + + config = load_skynet_toml(file_path=config_path) + hf_token = load_key(config, 'skynet.dgpu.hf_token') + hf_home = load_key(config, 'skynet.dgpu.hf_home') + set_hf_vars(hf_token, hf_home) + + assert 'skynet' in config + assert 'dgpu' in config['skynet'] + + mm = SkynetMM(config['skynet']['dgpu']) + + mm.maybe_serve_one(req) + +async def test_inpaint(): + req = { + 'id': 0, + 'body': json.dumps({ + "method": "inpaint", + "params": { + "prompt": "a black panther on a sunny roof", + "model": list(MODELS.keys())[-3], + "step": 21, + "width": 1024, + "height": 1024, + "seed": 168402949, + "guidance": "7.5", + "strength": "0.5" + } + }), + 
'inputs': [ + 'QmZcGdXXVQfpco1G3tr2CGFBtv8xVsCwcwuq9gnJBWDymi', + 'Qmccx1aXNmq5mZDS3YviUhgGHXWhQeHvca3AgA7MDjj2hR' + ], + } + + config = load_skynet_toml(file_path=config_path) + hf_token = load_key(config, 'skynet.dgpu.hf_token') + hf_home = load_key(config, 'skynet.dgpu.hf_home') + set_hf_vars(hf_token, hf_home) + + assert 'skynet' in config + assert 'dgpu' in config['skynet'] + + mm = SkynetMM(config['skynet']['dgpu']) + + mm.maybe_serve_one(req)