Add img2img mode

pull/4/head
Guillermo Rodriguez 2023-01-15 23:42:45 -03:00
parent 585d304f86
commit 97f7d51782
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
8 changed files with 250 additions and 23 deletions

View File

@ -13,6 +13,7 @@ setup(
'console_scripts': [ 'console_scripts': [
'skynet = skynet.cli:skynet', 'skynet = skynet.cli:skynet',
'txt2img = skynet.cli:txt2img', 'txt2img = skynet.cli:txt2img',
'img2img = skynet.cli:img2img',
'upscale = skynet.cli:upscale' 'upscale = skynet.cli:upscale'
] ]
}, },

View File

@ -164,7 +164,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
event.set() event.set()
del wip_reqs[rid] del wip_reqs[rid]
async def dgpu_stream_one_img(req: Text2ImageParameters): async def dgpu_stream_one_img(req: DiffusionParameters, img_buf=None):
nonlocal wip_reqs, fin_reqs, next_worker nonlocal wip_reqs, fin_reqs, next_worker
nid = get_next_worker() nid = get_next_worker()
idx = list(nodes.keys()).index(nid) idx = list(nodes.keys()).index(nid)
@ -186,7 +186,13 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
dgpu_req.auth.cert = 'skynet' dgpu_req.auth.cert = 'skynet'
dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key) dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key)
await dgpu_bus.asend(dgpu_req.SerializeToString()) msg = dgpu_req.SerializeToString()
if img_buf:
logging.info(f'sending img of size {len(img_buf)} as attachment')
logging.info(img_buf[:10])
msg = f'BINEXT%$%$'.encode() + msg + b'%$%$' + img_buf
await dgpu_bus.asend(msg)
with trio.move_on_after(4): with trio.move_on_after(4):
await ack_event.wait() await ack_event.wait()
@ -237,12 +243,38 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
del user_config['id'] del user_config['id']
user_config.update(MessageToDict(req.params)) user_config.update(MessageToDict(req.params))
req = Text2ImageParameters(**user_config) req = DiffusionParameters(**user_config, image=False)
rid, img, meta = await dgpu_stream_one_img(req) rid, img, meta = await dgpu_stream_one_img(req)
logging.info(f'done streaming {rid}') logging.info(f'done streaming {rid}')
result = { result = {
'id': rid, 'id': rid,
'img': zlib.compress(img).hex(), 'img': img.hex(),
'meta': meta
}
await update_user_stats(conn, user, last_prompt=user_config['prompt'])
logging.info('updated user stats.')
case 'img2img':
logging.info('img2img')
user_config = {**(await get_user_config(conn, user))}
del user_config['id']
params = MessageToDict(req.params)
img_buf = bytes.fromhex(params['img'])
del params['img']
user_config.update(params)
req = DiffusionParameters(**user_config, image=True)
if not req.image:
raise AssertionError('Didn\'t enable image flag for img2img?')
rid, img, meta = await dgpu_stream_one_img(req, img_buf=img_buf)
logging.info(f'done streaming {rid}')
result = {
'id': rid,
'img': img.hex(),
'meta': meta 'meta': meta
} }
@ -256,14 +288,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
prompt = await get_last_prompt_of(conn, user) prompt = await get_last_prompt_of(conn, user)
if prompt: if prompt:
req = Text2ImageParameters( req = DiffusionParameters(
prompt=prompt, prompt=prompt,
**user_config **user_config
) )
rid, img, meta = await dgpu_stream_one_img(req) rid, img, meta = await dgpu_stream_one_img(req)
result = { result = {
'id': rid, 'id': rid,
'img': zlib.compress(img).hex(), 'img': img.hex(),
'meta': meta 'meta': meta
} }
await update_user_stats(conn, user) await update_user_stats(conn, user)

View File

@ -41,6 +41,28 @@ def txt2img(*args, **kwargs):
assert 'HF_TOKEN' in os.environ assert 'HF_TOKEN' in os.environ
utils.txt2img(os.environ['HF_TOKEN'], **kwargs) utils.txt2img(os.environ['HF_TOKEN'], **kwargs)
@click.command()
@click.option('--model', '-m', default='midj')
@click.option(
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
@click.option('--input', '-i', default='input.png')
@click.option('--output', '-o', default='output.png')
@click.option('--guidance', '-g', default=10.0)
@click.option('--steps', '-s', default=26)
@click.option('--seed', '-S', default=None)
def img2img(model, prompt, input, output, guidance, steps, seed):
assert 'HF_TOKEN' in os.environ
utils.img2img(
os.environ['HF_TOKEN'],
model=model,
prompt=prompt,
img_path=input,
output=output,
guidance=guidance,
steps=steps,
seed=seed
)
@click.command() @click.command()
@click.option('--input', '-i', default='input.png') @click.option('--input', '-i', default='input.png')
@click.option('--output', '-o', default='output.png') @click.option('--output', '-o', default='output.png')

View File

@ -6,10 +6,12 @@ import trio
import json import json
import uuid import uuid
import time import time
import zlib
import random import random
import logging import logging
import traceback import traceback
from PIL import Image
from typing import List, Optional from typing import List, Optional
from pathlib import Path from pathlib import Path
from contextlib import ExitStack from contextlib import ExitStack
@ -25,6 +27,7 @@ from OpenSSL.crypto import (
) )
from diffusers import ( from diffusers import (
StableDiffusionPipeline, StableDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
EulerAncestralDiscreteScheduler EulerAncestralDiscreteScheduler
) )
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
@ -138,8 +141,9 @@ async def open_dgpu_node(
logging.info('memory summary:') logging.info('memory summary:')
logging.info('\n' + torch.cuda.memory_summary()) logging.info('\n' + torch.cuda.memory_summary())
async def gpu_compute_one(ireq: Text2ImageParameters): async def gpu_compute_one(ireq: DiffusionParameters, image=None):
if ireq.algo not in models: algo = ireq.algo + 'img' if image else ireq.algo
if algo not in models:
least_used = list(models.keys())[0] least_used = list(models.keys())[0]
for model in models: for model in models:
if models[least_used]['generated'] > models[model]['generated']: if models[least_used]['generated'] > models[model]['generated']:
@ -148,16 +152,23 @@ async def open_dgpu_node(
del models[least_used] del models[least_used]
gc.collect() gc.collect()
models[ireq.algo] = { models[algo] = {
'pipe': pipeline_for(ireq.algo), 'pipe': pipeline_for(ireq.algo, image=True if image else False),
'generated': 0 'generated': 0
} }
_params = {}
if ireq.image:
_params['image'] = image
else:
_params['width'] = int(ireq.width)
_params['height'] = int(ireq.height)
try: try:
image = models[ireq.algo]['pipe']( image = models[algo]['pipe'](
ireq.prompt, ireq.prompt,
width=int(ireq.width), **_params,
height=int(ireq.height),
guidance_scale=ireq.guidance, guidance_scale=ireq.guidance,
num_inference_steps=int(ireq.step), num_inference_steps=int(ireq.step),
generator=torch.Generator("cuda").manual_seed(ireq.seed) generator=torch.Generator("cuda").manual_seed(ireq.seed)
@ -173,7 +184,9 @@ async def open_dgpu_node(
image = convert_from_cv2_to_image(up_img) image = convert_from_cv2_to_image(up_img)
logging.info('done') logging.info('done')
raw_img = image.tobytes() img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format='PNG')
raw_img = img_byte_arr.getvalue()
logging.info(f'final img size {len(raw_img)} bytes.') logging.info(f'final img size {len(raw_img)} bytes.')
return raw_img return raw_img
@ -256,8 +269,19 @@ async def open_dgpu_node(
try: try:
while True: while True:
msg = await dgpu_bus.arecv()
img = None
if b'BINEXT' in msg:
header, msg, img_raw = msg.split(b'%$%$')
logging.info(f'got img attachment of size {len(img_raw)}')
logging.info(img_raw[:10])
raw_img = zlib.decompress(img_raw)
logging.info(raw_img[:10])
img = Image.open(io.BytesIO(raw_img))
req = DGPUBusMessage() req = DGPUBusMessage()
req.ParseFromString(await dgpu_bus.arecv()) req.ParseFromString(msg)
last_msg = time.time() last_msg = time.time()
if req.method == 'heartbeat': if req.method == 'heartbeat':
@ -301,11 +325,12 @@ async def open_dgpu_node(
logging.info(f'sent ack, processing {req.rid}...') logging.info(f'sent ack, processing {req.rid}...')
try: try:
img_req = Text2ImageParameters(**req.params) img_req = DiffusionParameters(**req.params)
if not img_req.seed: if not img_req.seed:
img_req.seed = random.randint(0, 2 ** 64) img_req.seed = random.randint(0, 2 ** 64)
img = await gpu_compute_one(img_req) img = await gpu_compute_one(img_req, image=img)
img_resp = DGPUBusMessage( img_resp = DGPUBusMessage(
rid=req.rid, rid=req.rid,
nid=req.nid, nid=req.nid,
@ -335,7 +360,7 @@ async def open_dgpu_node(
await dgpu_bus.asend(raw_msg) await dgpu_bus.asend(raw_msg)
logging.info(f'sent {len(raw_msg)} bytes.') logging.info(f'sent {len(raw_msg)} bytes.')
if img_resp.method == 'binary-reply': if img_resp.method == 'binary-reply':
await dgpu_bus.asend(img) await dgpu_bus.asend(zlib.compress(img))
logging.info(f'sent {len(img)} bytes.') logging.info(f'sent {len(img)} bytes.')
except KeyboardInterrupt: except KeyboardInterrupt:

View File

@ -130,6 +130,57 @@ async def run_skynet_telegram(
await bot.reply_to(message, resp_txt) await bot.reply_to(message, resp_txt)
@bot.message_handler(commands=['img2img'], content_types=['photo'])
async def send_img2img(message):
chat = message.chat
if chat.type != 'group' and chat.id != GROUP_ID:
return
prompt = ' '.join(message.caption.split(' ')[1:])
if len(prompt) == 0:
await bot.reply_to(message, 'Empty text prompt ignored.')
return
file_id = message.photo[-1].file_id
file_path = bot.get_file(file_id).file_path
file_raw = bot.download_file(file_path)
img = zlib.compress(file_raw)
logging.info(f'mid: {message.id}')
resp = await _rpc_call(
message.from_user.id,
'img2img',
{'prompt': prompt, 'img': img.hex()}
)
logging.info(f'resp to {message.id} arrived')
resp_txt = ''
result = MessageToDict(resp.result)
if 'error' in resp.result:
resp_txt = resp.result['message']
else:
logging.info(result['id'])
img_raw = zlib.decompress(bytes.fromhex(result['img']))
logging.info(f'got image of size: {len(img_raw)}')
meta = result['meta']['meta']
size = (int(meta['width']), int(meta['height']))
if meta['upscaler'] == 'x4':
size = (size[0] * 4, size[1] * 4)
img = Image.frombytes('RGB', size, img_raw)
await bot.send_photo(
message.chat.id,
caption=prepare_metainfo_caption(meta),
photo=img,
reply_to_message_id=message.id
)
return
await bot.reply_to(message, resp_txt)
@bot.message_handler(commands=['redo']) @bot.message_handler(commands=['redo'])
async def redo_txt2img(message): async def redo_txt2img(message):
chat = message.chat chat = message.chat

View File

@ -16,7 +16,7 @@ class Struct:
@dataclass @dataclass
class Text2ImageParameters(Struct): class DiffusionParameters(Struct):
algo: str algo: str
prompt: str prompt: str
step: int step: int
@ -24,4 +24,5 @@ class Text2ImageParameters(Struct):
height: int height: int
guidance: float guidance: float
seed: Optional[int] seed: Optional[int]
image: bool # if true indicates a bytestream is next msg
upscaler: Optional[str] upscaler: Optional[str]

View File

@ -12,7 +12,7 @@ from PIL import Image
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from diffusers import ( from diffusers import (
StableDiffusionPipeline, StableDiffusionPipeline,
StableDiffusionUpscalePipeline, StableDiffusionImg2ImgPipeline,
EulerAncestralDiscreteScheduler EulerAncestralDiscreteScheduler
) )
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
@ -31,7 +31,7 @@ def convert_from_image_to_cv2(img: Image) -> np.ndarray:
return np.asarray(img) return np.asarray(img)
def pipeline_for(algo: str, mem_fraction: float = 1.0): def pipeline_for(algo: str, mem_fraction: float = 1.0, image=False):
assert torch.cuda.is_available() assert torch.cuda.is_available()
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.set_per_process_memory_fraction(mem_fraction) torch.cuda.set_per_process_memory_fraction(mem_fraction)
@ -46,12 +46,18 @@ def pipeline_for(algo: str, mem_fraction: float = 1.0):
if algo == 'stable': if algo == 'stable':
params['revision'] = 'fp16' params['revision'] = 'fp16'
pipe = StableDiffusionPipeline.from_pretrained( if image:
pipe_class = StableDiffusionImg2ImgPipeline
else:
pipe_class = StableDiffusionPipeline
pipe = pipe_class.from_pretrained(
ALGOS[algo], **params) ALGOS[algo], **params)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config) pipe.scheduler.config)
if not image:
pipe.enable_vae_slicing() pipe.enable_vae_slicing()
return pipe.to('cuda') return pipe.to('cuda')
@ -89,6 +95,39 @@ def txt2img(
image.save(output) image.save(output)
def img2img(
hf_token: str,
model: str = 'midj',
prompt: str = 'a red old tractor in a sunny wheat field',
img_path: str = 'input.png',
output: str = 'output.png',
guidance: float = 10,
steps: int = 28,
seed: Optional[int] = None
):
assert torch.cuda.is_available()
torch.cuda.empty_cache()
torch.cuda.set_per_process_memory_fraction(1.0)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
login(token=hf_token)
pipe = pipeline_for(model, image=True)
input_img = Image.open(img_path).convert('RGB')
seed = seed if seed else random.randint(0, 2 ** 64)
prompt = prompt
image = pipe(
prompt,
image=input_img,
guidance_scale=guidance, num_inference_steps=steps,
generator=torch.Generator("cuda").manual_seed(seed)
).images[0]
image.save(output)
def upscale( def upscale(
img_path: str = 'input.png', img_path: str = 'input.png',
output: str = 'output.png', output: str = 'output.png',

View File

@ -321,3 +321,59 @@ async def test_dgpu_heartbeat(dgpu_workers):
) as test_rpc: ) as test_rpc:
await wait_for_dgpus(test_rpc, 1) await wait_for_dgpus(test_rpc, 1)
await trio.sleep(120) await trio.sleep(120)
@pytest.mark.parametrize(
'dgpu_workers', [(1, ['midj'])], indirect=True)
async def test_dgpu_img2img(dgpu_workers):
async with open_skynet_rpc(
'1',
security=True,
cert_name='whitelist/testing',
key_name='testing'
) as rpc_call:
await wait_for_dgpus(rpc_call, 1)
res = await rpc_call(
'txt2img', {
'prompt': 'red old tractor in a sunny wheat field',
'step': 28,
'width': 512, 'height': 512,
'guidance': 7.5,
'seed': None,
'algo': list(ALGOS.keys())[0],
'upscaler': None
})
if 'error' in res.result:
raise SkynetDGPUComputeError(MessageToDict(res.result))
img_raw = res.result['img']
img = zlib.decompress(bytes.fromhex(img_raw))
logging.info(img[:10])
img = Image.open(io.BytesIO(img))
img.save('txt2img.png')
res = await rpc_call(
'img2img', {
'prompt': 'red sports car in a sunny wheat field',
'step': 28,
'img': img_raw,
'guidance': 12,
'seed': None,
'algo': list(ALGOS.keys())[0],
'upscaler': 'x4'
})
if 'error' in res.result:
raise SkynetDGPUComputeError(MessageToDict(res.result))
img_raw = res.result['img']
img = zlib.decompress(bytes.fromhex(img_raw))
logging.info(img[:10])
img = Image.open(io.BytesIO(img))
img.save('img2img.png')