Add img2img mode

pull/4/head
Guillermo Rodriguez 2023-01-15 23:42:45 -03:00
parent 585d304f86
commit 97f7d51782
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
8 changed files with 250 additions and 23 deletions

View File

@ -13,6 +13,7 @@ setup(
'console_scripts': [
'skynet = skynet.cli:skynet',
'txt2img = skynet.cli:txt2img',
'img2img = skynet.cli:img2img',
'upscale = skynet.cli:upscale'
]
},

View File

@ -164,7 +164,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
event.set()
del wip_reqs[rid]
async def dgpu_stream_one_img(req: Text2ImageParameters):
async def dgpu_stream_one_img(req: DiffusionParameters, img_buf=None):
nonlocal wip_reqs, fin_reqs, next_worker
nid = get_next_worker()
idx = list(nodes.keys()).index(nid)
@ -186,7 +186,13 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
dgpu_req.auth.cert = 'skynet'
dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key)
await dgpu_bus.asend(dgpu_req.SerializeToString())
msg = dgpu_req.SerializeToString()
if img_buf:
logging.info(f'sending img of size {len(img_buf)} as attachment')
logging.info(img_buf[:10])
msg = f'BINEXT%$%$'.encode() + msg + b'%$%$' + img_buf
await dgpu_bus.asend(msg)
with trio.move_on_after(4):
await ack_event.wait()
@ -237,12 +243,38 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
del user_config['id']
user_config.update(MessageToDict(req.params))
req = Text2ImageParameters(**user_config)
req = DiffusionParameters(**user_config, image=False)
rid, img, meta = await dgpu_stream_one_img(req)
logging.info(f'done streaming {rid}')
result = {
'id': rid,
'img': zlib.compress(img).hex(),
'img': img.hex(),
'meta': meta
}
await update_user_stats(conn, user, last_prompt=user_config['prompt'])
logging.info('updated user stats.')
case 'img2img':
logging.info('img2img')
user_config = {**(await get_user_config(conn, user))}
del user_config['id']
params = MessageToDict(req.params)
img_buf = bytes.fromhex(params['img'])
del params['img']
user_config.update(params)
req = DiffusionParameters(**user_config, image=True)
if not req.image:
raise AssertionError('Didn\'t enable image flag for img2img?')
rid, img, meta = await dgpu_stream_one_img(req, img_buf=img_buf)
logging.info(f'done streaming {rid}')
result = {
'id': rid,
'img': img.hex(),
'meta': meta
}
@ -256,14 +288,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
prompt = await get_last_prompt_of(conn, user)
if prompt:
req = Text2ImageParameters(
req = DiffusionParameters(
prompt=prompt,
**user_config
)
rid, img, meta = await dgpu_stream_one_img(req)
result = {
'id': rid,
'img': zlib.compress(img).hex(),
'img': img.hex(),
'meta': meta
}
await update_user_stats(conn, user)

View File

@ -41,6 +41,28 @@ def txt2img(*args, **kwargs):
assert 'HF_TOKEN' in os.environ
utils.txt2img(os.environ['HF_TOKEN'], **kwargs)
@click.command()
@click.option('--model', '-m', default='midj')
@click.option(
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
@click.option('--input', '-i', default='input.png')
@click.option('--output', '-o', default='output.png')
@click.option('--guidance', '-g', default=10.0)
@click.option('--steps', '-s', default=26)
@click.option('--seed', '-S', default=None)
def img2img(model, prompt, input, output, guidance, steps, seed):
assert 'HF_TOKEN' in os.environ
utils.img2img(
os.environ['HF_TOKEN'],
model=model,
prompt=prompt,
img_path=input,
output=output,
guidance=guidance,
steps=steps,
seed=seed
)
@click.command()
@click.option('--input', '-i', default='input.png')
@click.option('--output', '-o', default='output.png')

View File

@ -6,10 +6,12 @@ import trio
import json
import uuid
import time
import zlib
import random
import logging
import traceback
from PIL import Image
from typing import List, Optional
from pathlib import Path
from contextlib import ExitStack
@ -25,6 +27,7 @@ from OpenSSL.crypto import (
)
from diffusers import (
StableDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
EulerAncestralDiscreteScheduler
)
from realesrgan import RealESRGANer
@ -138,8 +141,9 @@ async def open_dgpu_node(
logging.info('memory summary:')
logging.info('\n' + torch.cuda.memory_summary())
async def gpu_compute_one(ireq: Text2ImageParameters):
if ireq.algo not in models:
async def gpu_compute_one(ireq: DiffusionParameters, image=None):
algo = ireq.algo + 'img' if image else ireq.algo
if algo not in models:
least_used = list(models.keys())[0]
for model in models:
if models[least_used]['generated'] > models[model]['generated']:
@ -148,16 +152,23 @@ async def open_dgpu_node(
del models[least_used]
gc.collect()
models[ireq.algo] = {
'pipe': pipeline_for(ireq.algo),
models[algo] = {
'pipe': pipeline_for(ireq.algo, image=True if image else False),
'generated': 0
}
_params = {}
if ireq.image:
_params['image'] = image
else:
_params['width'] = int(ireq.width)
_params['height'] = int(ireq.height)
try:
image = models[ireq.algo]['pipe'](
image = models[algo]['pipe'](
ireq.prompt,
width=int(ireq.width),
height=int(ireq.height),
**_params,
guidance_scale=ireq.guidance,
num_inference_steps=int(ireq.step),
generator=torch.Generator("cuda").manual_seed(ireq.seed)
@ -173,7 +184,9 @@ async def open_dgpu_node(
image = convert_from_cv2_to_image(up_img)
logging.info('done')
raw_img = image.tobytes()
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format='PNG')
raw_img = img_byte_arr.getvalue()
logging.info(f'final img size {len(raw_img)} bytes.')
return raw_img
@ -256,8 +269,19 @@ async def open_dgpu_node(
try:
while True:
msg = await dgpu_bus.arecv()
img = None
if b'BINEXT' in msg:
header, msg, img_raw = msg.split(b'%$%$')
logging.info(f'got img attachment of size {len(img_raw)}')
logging.info(img_raw[:10])
raw_img = zlib.decompress(img_raw)
logging.info(raw_img[:10])
img = Image.open(io.BytesIO(raw_img))
req = DGPUBusMessage()
req.ParseFromString(await dgpu_bus.arecv())
req.ParseFromString(msg)
last_msg = time.time()
if req.method == 'heartbeat':
@ -301,11 +325,12 @@ async def open_dgpu_node(
logging.info(f'sent ack, processing {req.rid}...')
try:
img_req = Text2ImageParameters(**req.params)
img_req = DiffusionParameters(**req.params)
if not img_req.seed:
img_req.seed = random.randint(0, 2 ** 64)
img = await gpu_compute_one(img_req)
img = await gpu_compute_one(img_req, image=img)
img_resp = DGPUBusMessage(
rid=req.rid,
nid=req.nid,
@ -335,7 +360,7 @@ async def open_dgpu_node(
await dgpu_bus.asend(raw_msg)
logging.info(f'sent {len(raw_msg)} bytes.')
if img_resp.method == 'binary-reply':
await dgpu_bus.asend(img)
await dgpu_bus.asend(zlib.compress(img))
logging.info(f'sent {len(img)} bytes.')
except KeyboardInterrupt:

View File

@ -130,6 +130,57 @@ async def run_skynet_telegram(
await bot.reply_to(message, resp_txt)
@bot.message_handler(commands=['img2img'], content_types=['photo'])
async def send_img2img(message):
chat = message.chat
if chat.type != 'group' and chat.id != GROUP_ID:
return
prompt = ' '.join(message.caption.split(' ')[1:])
if len(prompt) == 0:
await bot.reply_to(message, 'Empty text prompt ignored.')
return
file_id = message.photo[-1].file_id
file_path = bot.get_file(file_id).file_path
file_raw = bot.download_file(file_path)
img = zlib.compress(file_raw)
logging.info(f'mid: {message.id}')
resp = await _rpc_call(
message.from_user.id,
'img2img',
{'prompt': prompt, 'img': img.hex()}
)
logging.info(f'resp to {message.id} arrived')
resp_txt = ''
result = MessageToDict(resp.result)
if 'error' in resp.result:
resp_txt = resp.result['message']
else:
logging.info(result['id'])
img_raw = zlib.decompress(bytes.fromhex(result['img']))
logging.info(f'got image of size: {len(img_raw)}')
meta = result['meta']['meta']
size = (int(meta['width']), int(meta['height']))
if meta['upscaler'] == 'x4':
size = (size[0] * 4, size[1] * 4)
img = Image.frombytes('RGB', size, img_raw)
await bot.send_photo(
message.chat.id,
caption=prepare_metainfo_caption(meta),
photo=img,
reply_to_message_id=message.id
)
return
await bot.reply_to(message, resp_txt)
@bot.message_handler(commands=['redo'])
async def redo_txt2img(message):
chat = message.chat

View File

@ -16,7 +16,7 @@ class Struct:
@dataclass
class Text2ImageParameters(Struct):
class DiffusionParameters(Struct):
algo: str
prompt: str
step: int
@ -24,4 +24,5 @@ class Text2ImageParameters(Struct):
height: int
guidance: float
seed: Optional[int]
image: bool # if true indicates a bytestream is next msg
upscaler: Optional[str]

View File

@ -12,7 +12,7 @@ from PIL import Image
from basicsr.archs.rrdbnet_arch import RRDBNet
from diffusers import (
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionImg2ImgPipeline,
EulerAncestralDiscreteScheduler
)
from realesrgan import RealESRGANer
@ -31,7 +31,7 @@ def convert_from_image_to_cv2(img: Image) -> np.ndarray:
return np.asarray(img)
def pipeline_for(algo: str, mem_fraction: float = 1.0):
def pipeline_for(algo: str, mem_fraction: float = 1.0, image=False):
assert torch.cuda.is_available()
torch.cuda.empty_cache()
torch.cuda.set_per_process_memory_fraction(mem_fraction)
@ -46,13 +46,19 @@ def pipeline_for(algo: str, mem_fraction: float = 1.0):
if algo == 'stable':
params['revision'] = 'fp16'
pipe = StableDiffusionPipeline.from_pretrained(
if image:
pipe_class = StableDiffusionImg2ImgPipeline
else:
pipe_class = StableDiffusionPipeline
pipe = pipe_class.from_pretrained(
ALGOS[algo], **params)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config)
pipe.enable_vae_slicing()
if not image:
pipe.enable_vae_slicing()
return pipe.to('cuda')
@ -89,6 +95,39 @@ def txt2img(
image.save(output)
def img2img(
hf_token: str,
model: str = 'midj',
prompt: str = 'a red old tractor in a sunny wheat field',
img_path: str = 'input.png',
output: str = 'output.png',
guidance: float = 10,
steps: int = 28,
seed: Optional[int] = None
):
assert torch.cuda.is_available()
torch.cuda.empty_cache()
torch.cuda.set_per_process_memory_fraction(1.0)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
login(token=hf_token)
pipe = pipeline_for(model, image=True)
input_img = Image.open(img_path).convert('RGB')
seed = seed if seed else random.randint(0, 2 ** 64)
prompt = prompt
image = pipe(
prompt,
image=input_img,
guidance_scale=guidance, num_inference_steps=steps,
generator=torch.Generator("cuda").manual_seed(seed)
).images[0]
image.save(output)
def upscale(
img_path: str = 'input.png',
output: str = 'output.png',

View File

@ -321,3 +321,59 @@ async def test_dgpu_heartbeat(dgpu_workers):
) as test_rpc:
await wait_for_dgpus(test_rpc, 1)
await trio.sleep(120)
@pytest.mark.parametrize(
'dgpu_workers', [(1, ['midj'])], indirect=True)
async def test_dgpu_img2img(dgpu_workers):
async with open_skynet_rpc(
'1',
security=True,
cert_name='whitelist/testing',
key_name='testing'
) as rpc_call:
await wait_for_dgpus(rpc_call, 1)
res = await rpc_call(
'txt2img', {
'prompt': 'red old tractor in a sunny wheat field',
'step': 28,
'width': 512, 'height': 512,
'guidance': 7.5,
'seed': None,
'algo': list(ALGOS.keys())[0],
'upscaler': None
})
if 'error' in res.result:
raise SkynetDGPUComputeError(MessageToDict(res.result))
img_raw = res.result['img']
img = zlib.decompress(bytes.fromhex(img_raw))
logging.info(img[:10])
img = Image.open(io.BytesIO(img))
img.save('txt2img.png')
res = await rpc_call(
'img2img', {
'prompt': 'red sports car in a sunny wheat field',
'step': 28,
'img': img_raw,
'guidance': 12,
'seed': None,
'algo': list(ALGOS.keys())[0],
'upscaler': 'x4'
})
if 'error' in res.result:
raise SkynetDGPUComputeError(MessageToDict(res.result))
img_raw = res.result['img']
img = zlib.decompress(bytes.fromhex(img_raw))
logging.info(img[:10])
img = Image.open(io.BytesIO(img))
img.save('img2img.png')