From 97f7d517829b14a5b262e20021852f932a81e049 Mon Sep 17 00:00:00 2001
From: Guillermo Rodriguez
Date: Sun, 15 Jan 2023 23:42:45 -0300
Subject: [PATCH] Add img2img mode

---
 setup.py                    |  1 +
 skynet/brain.py             | 44 ++++++++++++++++++++++++----
 skynet/cli.py               | 22 +++++++++++++++
 skynet/dgpu.py              | 49 ++++++++++++++++++++++--------
 skynet/frontend/telegram.py | 51 +++++++++++++++++++++++++++++
 skynet/protobuf/__init__.py |  3 +-
 skynet/utils.py             | 47 +++++++++++++++++++++++++---
 tests/test_dgpu.py          | 56 +++++++++++++++++++++++++++++++
 8 files changed, 250 insertions(+), 23 deletions(-)

diff --git a/setup.py b/setup.py
index 0a822be..b1866d9 100644
--- a/setup.py
+++ b/setup.py
@@ -13,6 +13,7 @@ setup(
         'console_scripts': [
             'skynet = skynet.cli:skynet',
             'txt2img = skynet.cli:txt2img',
+            'img2img = skynet.cli:img2img',
             'upscale = skynet.cli:upscale'
         ]
     },
diff --git a/skynet/brain.py b/skynet/brain.py
index d649483..e5fe3fe 100644
--- a/skynet/brain.py
+++ b/skynet/brain.py
@@ -164,7 +164,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
             event.set()
             del wip_reqs[rid]
 
-    async def dgpu_stream_one_img(req: Text2ImageParameters):
+    async def dgpu_stream_one_img(req: DiffusionParameters, img_buf=None):
         nonlocal wip_reqs, fin_reqs, next_worker
         nid = get_next_worker()
         idx = list(nodes.keys()).index(nid)
@@ -186,7 +186,13 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
         dgpu_req.auth.cert = 'skynet'
         dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key)
 
-        await dgpu_bus.asend(dgpu_req.SerializeToString())
+        msg = dgpu_req.SerializeToString()
+        if img_buf:
+            logging.info(f'sending img of size {len(img_buf)} as attachment')
+            logging.info(img_buf[:10])
+            msg = b'BINEXT%$%$' + msg + b'%$%$' + img_buf
+
+        await dgpu_bus.asend(msg)
 
         with trio.move_on_after(4):
             await ack_event.wait()
@@ -237,12 +243,38 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
                     del user_config['id']
                     user_config.update(MessageToDict(req.params))
 
-                    req = Text2ImageParameters(**user_config)
+                    req = DiffusionParameters(**user_config, image=False)
                     rid, img, meta = await dgpu_stream_one_img(req)
                     logging.info(f'done streaming {rid}')
                     result = {
                         'id': rid,
-                        'img': zlib.compress(img).hex(),
+                        'img': img.hex(),
                         'meta': meta
                     }
 
                     await update_user_stats(conn, user, last_prompt=user_config['prompt'])
                     logging.info('updated user stats.')
 
+                case 'img2img':
+                    logging.info('img2img')
+                    user_config = {**(await get_user_config(conn, user))}
+                    del user_config['id']
+
+                    params = MessageToDict(req.params)
+                    img_buf = bytes.fromhex(params['img'])
+                    del params['img']
+                    user_config.update(params)
+
+                    req = DiffusionParameters(**user_config, image=True)
+
+                    if not req.image:
+                        raise AssertionError('Didn\'t enable image flag for img2img?')
+
+                    rid, img, meta = await dgpu_stream_one_img(req, img_buf=img_buf)
+                    logging.info(f'done streaming {rid}')
+                    result = {
+                        'id': rid,
+                        'img': img.hex(),
+                        'meta': meta
+                    }
+
+                    await update_user_stats(conn, user, last_prompt=user_config['prompt'])
+                    logging.info('updated user stats.')
+
@@ -256,14 +288,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
                     prompt = await get_last_prompt_of(conn, user)
 
                     if prompt:
-                        req = Text2ImageParameters(
+                        req = DiffusionParameters(
                             prompt=prompt,
                             **user_config
                         )
                         rid, img, meta = await dgpu_stream_one_img(req)
                         result = {
                             'id': rid,
-                            'img': zlib.compress(img).hex(),
+                            'img': img.hex(),
                             'meta': meta
                         }
                         await update_user_stats(conn, user)
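The `BINEXT%$%$` framing introduced in dgpu_stream_one_img() above is ad-hoc: the serialized protobuf and the zlib-compressed image travel as a single delimited payload, and the receiving side (skynet/dgpu.py below) recovers them with msg.split(b'%$%$'). A minimal round-trip sketch of that framing; the helper names are illustrative and not part of the patch:

    MAGIC = b'BINEXT'
    SEP = b'%$%$'

    def pack_with_attachment(pb_msg: bytes, img_buf: bytes) -> bytes:
        # wire format: BINEXT %$%$ <serialized protobuf> %$%$ <zlib image>
        return MAGIC + SEP + pb_msg + SEP + img_buf

    def unpack(raw: bytes):
        if not raw.startswith(MAGIC + SEP):
            return raw, None
        # Caveat (shared with the patch's split): if the serialized
        # protobuf happens to contain the delimiter bytes, the split
        # misfires; a length-prefixed header would be a safer framing.
        _, pb_msg, img_buf = raw.split(SEP, 2)
        return pb_msg, img_buf
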
diff --git a/skynet/cli.py b/skynet/cli.py
index 1d7b501..de7a37d 100644
--- a/skynet/cli.py
+++ b/skynet/cli.py
@@ -41,6 +41,28 @@ def txt2img(*args, **kwargs):
     assert 'HF_TOKEN' in os.environ
     utils.txt2img(os.environ['HF_TOKEN'], **kwargs)
 
+@click.command()
+@click.option('--model', '-m', default='midj')
+@click.option(
+    '--prompt', '-p', default='a red old tractor in a sunny wheat field')
+@click.option('--input', '-i', default='input.png')
+@click.option('--output', '-o', default='output.png')
+@click.option('--guidance', '-g', default=10.0)
+@click.option('--steps', '-s', default=26)
+@click.option('--seed', '-S', default=None)
+def img2img(model, prompt, input, output, guidance, steps, seed):
+    assert 'HF_TOKEN' in os.environ
+    utils.img2img(
+        os.environ['HF_TOKEN'],
+        model=model,
+        prompt=prompt,
+        img_path=input,
+        output=output,
+        guidance=guidance,
+        steps=steps,
+        seed=seed
+    )
+
 @click.command()
 @click.option('--input', '-i', default='input.png')
 @click.option('--output', '-o', default='output.png')
diff --git a/skynet/dgpu.py b/skynet/dgpu.py
index 4bb4f10..2a8118e 100644
--- a/skynet/dgpu.py
+++ b/skynet/dgpu.py
@@ -6,10 +6,12 @@ import trio
 import json
 import uuid
 import time
+import zlib
 import random
 import logging
 import traceback
 
+from PIL import Image
 from typing import List, Optional
 from pathlib import Path
 from contextlib import ExitStack
@@ -25,6 +27,7 @@ from OpenSSL.crypto import (
 )
 from diffusers import (
     StableDiffusionPipeline,
+    StableDiffusionImg2ImgPipeline,
     EulerAncestralDiscreteScheduler
 )
 from realesrgan import RealESRGANer
@@ -138,8 +141,9 @@ async def open_dgpu_node(
         logging.info('memory summary:')
         logging.info('\n' + torch.cuda.memory_summary())
 
-    async def gpu_compute_one(ireq: Text2ImageParameters):
-        if ireq.algo not in models:
+    async def gpu_compute_one(ireq: DiffusionParameters, image=None):
+        algo = ireq.algo + 'img' if image else ireq.algo
+        if algo not in models:
             least_used = list(models.keys())[0]
             for model in models:
                 if models[least_used]['generated'] > models[model]['generated']:
@@ -148,16 +152,23 @@ async def open_dgpu_node(
             del models[least_used]
             gc.collect()
 
-            models[ireq.algo] = {
-                'pipe': pipeline_for(ireq.algo),
+            models[algo] = {
+                'pipe': pipeline_for(ireq.algo, image=True if image else False),
                 'generated': 0
             }
 
+        _params = {}
+        if ireq.image:
+            _params['image'] = image
+
+        else:
+            _params['width'] = int(ireq.width)
+            _params['height'] = int(ireq.height)
+
         try:
-            image = models[ireq.algo]['pipe'](
+            image = models[algo]['pipe'](
                 ireq.prompt,
-                width=int(ireq.width),
-                height=int(ireq.height),
+                **_params,
                 guidance_scale=ireq.guidance,
                 num_inference_steps=int(ireq.step),
                 generator=torch.Generator("cuda").manual_seed(ireq.seed)
@@ -173,7 +184,9 @@ async def open_dgpu_node(
             image = convert_from_cv2_to_image(up_img)
 
         logging.info('done')
-        raw_img = image.tobytes()
+        img_byte_arr = io.BytesIO()
+        image.save(img_byte_arr, format='PNG')
+        raw_img = img_byte_arr.getvalue()
         logging.info(f'final img size {len(raw_img)} bytes.')
 
         return raw_img
@@ -256,8 +269,19 @@ async def open_dgpu_node(
 
         try:
             while True:
+                msg = await dgpu_bus.arecv()
+
+                img = None
+                if b'BINEXT' in msg:
+                    header, msg, img_raw = msg.split(b'%$%$')
+                    logging.info(f'got img attachment of size {len(img_raw)}')
+                    logging.info(img_raw[:10])
+                    raw_img = zlib.decompress(img_raw)
+                    logging.info(raw_img[:10])
+                    img = Image.open(io.BytesIO(raw_img))
+
                 req = DGPUBusMessage()
-                req.ParseFromString(await dgpu_bus.arecv())
+                req.ParseFromString(msg)
                 last_msg = time.time()
 
                 if req.method == 'heartbeat':
@@ -301,11 +325,12 @@ async def open_dgpu_node(
                     logging.info(f'sent ack, processing {req.rid}...')
 
                     try:
-                        img_req = Text2ImageParameters(**req.params)
+                        img_req = DiffusionParameters(**req.params)
+
                         if not img_req.seed:
                             img_req.seed = random.randint(0, 2 ** 64)
 
-                        img = await gpu_compute_one(img_req)
+                        img = await gpu_compute_one(img_req, image=img)
                         img_resp = DGPUBusMessage(
                             rid=req.rid,
                             nid=req.nid,
@@ -335,7 +360,7 @@ async def open_dgpu_node(
                 await dgpu_bus.asend(raw_msg)
                 logging.info(f'sent {len(raw_msg)} bytes.')
                 if img_resp.method == 'binary-reply':
-                    await dgpu_bus.asend(img)
+                    await dgpu_bus.asend(zlib.compress(img))
                     logging.info(f'sent {len(img)} bytes.')
 
         except KeyboardInterrupt:
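Two details in the skynet/dgpu.py hunks above are worth spelling out. First, img2img pipelines are cached under a separate key (algo + 'img'), so a txt2img and an img2img pipeline for the same model can coexist and be evicted independently. Second, replies are now PNG-encoded instead of raw Image.tobytes() buffers; PNG is self-describing (it embeds size and mode), which is what lets receivers decode with Image.open() and no out-of-band width/height metadata. A sketch of the reply round trip, assuming only PIL and zlib (helper names are illustrative):

    import io
    import zlib

    from PIL import Image

    def encode_reply(image: Image.Image) -> bytes:
        buf = io.BytesIO()
        image.save(buf, format='PNG')  # self-describing container
        return zlib.compress(buf.getvalue())

    def decode_reply(payload: bytes) -> Image.Image:
        return Image.open(io.BytesIO(zlib.decompress(payload)))
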
diff --git a/skynet/frontend/telegram.py b/skynet/frontend/telegram.py
index f53a116..7152b32 100644
--- a/skynet/frontend/telegram.py
+++ b/skynet/frontend/telegram.py
@@ -130,6 +130,54 @@ async def run_skynet_telegram(
 
         await bot.reply_to(message, resp_txt)
 
+    @bot.message_handler(commands=['img2img'], content_types=['photo'])
+    async def send_img2img(message):
+        chat = message.chat
+        if chat.type != 'group' and chat.id != GROUP_ID:
+            return
+
+        prompt = ' '.join(message.caption.split(' ')[1:])
+
+        if len(prompt) == 0:
+            await bot.reply_to(message, 'Empty text prompt ignored.')
+            return
+
+        file_id = message.photo[-1].file_id
+        file_path = bot.get_file(file_id).file_path
+        file_raw = bot.download_file(file_path)
+        img = zlib.compress(file_raw)
+
+        logging.info(f'mid: {message.id}')
+        resp = await _rpc_call(
+            message.from_user.id,
+            'img2img',
+            {'prompt': prompt, 'img': img.hex()}
+        )
+        logging.info(f'resp to {message.id} arrived')
+
+        resp_txt = ''
+        result = MessageToDict(resp.result)
+        if 'error' in resp.result:
+            resp_txt = resp.result['message']
+
+        else:
+            logging.info(result['id'])
+            img_raw = zlib.decompress(bytes.fromhex(result['img']))
+            logging.info(f'got image of size: {len(img_raw)}')
+            meta = result['meta']['meta']
+            img = Image.open(io.BytesIO(img_raw))
+
+            await bot.send_photo(
+                message.chat.id,
+                caption=prepare_metainfo_caption(meta),
+                photo=img,
+                reply_to_message_id=message.id
+            )
+            return
+
+        await bot.reply_to(message, resp_txt)
+
     @bot.message_handler(commands=['redo'])
     async def redo_txt2img(message):
         chat = message.chat
diff --git a/skynet/protobuf/__init__.py b/skynet/protobuf/__init__.py
index eb99c3f..15af051 100644
--- a/skynet/protobuf/__init__.py
+++ b/skynet/protobuf/__init__.py
@@ -16,7 +16,7 @@ class Struct:
 
 
 @dataclass
-class Text2ImageParameters(Struct):
+class DiffusionParameters(Struct):
     algo: str
     prompt: str
     step: int
@@ -24,4 +24,5 @@ class Text2ImageParameters(Struct):
     height: int
     guidance: float
     seed: Optional[int]
+    image: bool  # if true indicates a bytestream is next msg
     upscaler: Optional[str]
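After the rename, a single dataclass describes both modes; image=True signals that a compressed bytestream follows the protobuf message on the dgpu bus. Purely illustrative constructions, using the same values as the test at the end of this patch:

    txt2img_req = DiffusionParameters(
        algo='midj', prompt='a red old tractor in a sunny wheat field',
        step=28, width=512, height=512, guidance=7.5, seed=None,
        image=False, upscaler=None)

    img2img_req = DiffusionParameters(
        algo='midj', prompt='red sports car in a sunny wheat field',
        step=28, width=512, height=512, guidance=12.0, seed=None,
        image=True, upscaler='x4')
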
diff --git a/skynet/utils.py b/skynet/utils.py
index c511453..64a4583 100644
--- a/skynet/utils.py
+++ b/skynet/utils.py
@@ -12,7 +12,7 @@ from PIL import Image
 from basicsr.archs.rrdbnet_arch import RRDBNet
 from diffusers import (
     StableDiffusionPipeline,
-    StableDiffusionUpscalePipeline,
+    StableDiffusionImg2ImgPipeline,
    EulerAncestralDiscreteScheduler
 )
 from realesrgan import RealESRGANer
@@ -31,7 +31,7 @@ def convert_from_image_to_cv2(img: Image) -> np.ndarray:
     return np.asarray(img)
 
 
-def pipeline_for(algo: str, mem_fraction: float = 1.0):
+def pipeline_for(algo: str, mem_fraction: float = 1.0, image=False):
     assert torch.cuda.is_available()
     torch.cuda.empty_cache()
     torch.cuda.set_per_process_memory_fraction(mem_fraction)
@@ -46,13 +46,19 @@ def pipeline_for(algo: str, mem_fraction: float = 1.0):
     if algo == 'stable':
         params['revision'] = 'fp16'
 
-    pipe = StableDiffusionPipeline.from_pretrained(
+    if image:
+        pipe_class = StableDiffusionImg2ImgPipeline
+    else:
+        pipe_class = StableDiffusionPipeline
+
+    pipe = pipe_class.from_pretrained(
         ALGOS[algo], **params)
 
     pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
         pipe.scheduler.config)
 
-    pipe.enable_vae_slicing()
+    if not image:
+        pipe.enable_vae_slicing()
 
     return pipe.to('cuda')
@@ -89,6 +95,38 @@ def txt2img(
     image.save(output)
 
 
+def img2img(
+    hf_token: str,
+    model: str = 'midj',
+    prompt: str = 'a red old tractor in a sunny wheat field',
+    img_path: str = 'input.png',
+    output: str = 'output.png',
+    guidance: float = 10,
+    steps: int = 28,
+    seed: Optional[int] = None
+):
+    assert torch.cuda.is_available()
+    torch.cuda.empty_cache()
+    torch.cuda.set_per_process_memory_fraction(1.0)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+
+    login(token=hf_token)
+    pipe = pipeline_for(model, image=True)
+
+    input_img = Image.open(img_path).convert('RGB')
+
+    seed = seed if seed else random.randint(0, 2 ** 64)
+    image = pipe(
+        prompt,
+        image=input_img,
+        guidance_scale=guidance, num_inference_steps=steps,
+        generator=torch.Generator("cuda").manual_seed(seed)
+    ).images[0]
+
+    image.save(output)
+
+
 def upscale(
     img_path: str = 'input.png',
     output: str = 'output.png',
diff --git a/tests/test_dgpu.py b/tests/test_dgpu.py
index 7d6ef07..4ce93bf 100644
--- a/tests/test_dgpu.py
+++ b/tests/test_dgpu.py
@@ -321,3 +321,59 @@ async def test_dgpu_heartbeat(dgpu_workers):
     ) as test_rpc:
         await wait_for_dgpus(test_rpc, 1)
         await trio.sleep(120)
+
+
+@pytest.mark.parametrize(
+    'dgpu_workers', [(1, ['midj'])], indirect=True)
+async def test_dgpu_img2img(dgpu_workers):
+
+    async with open_skynet_rpc(
+        '1',
+        security=True,
+        cert_name='whitelist/testing',
+        key_name='testing'
+    ) as rpc_call:
+        await wait_for_dgpus(rpc_call, 1)
+
+
+        res = await rpc_call(
+            'txt2img', {
+                'prompt': 'red old tractor in a sunny wheat field',
+                'step': 28,
+                'width': 512, 'height': 512,
+                'guidance': 7.5,
+                'seed': None,
+                'algo': list(ALGOS.keys())[0],
+                'upscaler': None
+            })
+
+        if 'error' in res.result:
+            raise SkynetDGPUComputeError(MessageToDict(res.result))
+
+        img_raw = res.result['img']
+        img = zlib.decompress(bytes.fromhex(img_raw))
+        logging.info(img[:10])
+        img = Image.open(io.BytesIO(img))
+
+        img.save('txt2img.png')
+
+        res = await rpc_call(
+            'img2img', {
+                'prompt': 'red sports car in a sunny wheat field',
+                'step': 28,
+                'img': img_raw,
+                'guidance': 12,
+                'seed': None,
+                'algo': list(ALGOS.keys())[0],
+                'upscaler': 'x4'
+            })
+
+        if 'error' in res.result:
+            raise SkynetDGPUComputeError(MessageToDict(res.result))
+
+        img_raw = res.result['img']
+        img = zlib.decompress(bytes.fromhex(img_raw))
+        logging.info(img[:10])
+        img = Image.open(io.BytesIO(img))
+
+        img.save('img2img.png')
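For reference, the new mode can be exercised end to end through the img2img console script registered in setup.py, or directly against the helper; a minimal sketch (the token value, file paths, and the import path are assumptions, not part of the patch):

    import os

    from skynet import utils

    # equivalent to:
    #   img2img -m midj -p 'red sports car in a sunny wheat field' \
    #           -i input.png -o output.png -g 12 -s 28
    utils.img2img(
        os.environ['HF_TOKEN'],  # HuggingFace token, as in the cli wrapper
        model='midj',
        prompt='red sports car in a sunny wheat field',
        img_path='input.png',
        output='output.png',
        guidance=12,
        steps=28,
        seed=None  # None picks a random seed
    )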