mirror of https://github.com/skygpu/skynet.git
Add img2img mode
parent 585d304f86
commit 97f7d51782

setup.py
@@ -13,6 +13,7 @@ setup(
         'console_scripts': [
             'skynet = skynet.cli:skynet',
             'txt2img = skynet.cli:txt2img',
+            'img2img = skynet.cli:img2img',
             'upscale = skynet.cli:upscale'
         ]
     },

@@ -164,7 +164,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
         event.set()
         del wip_reqs[rid]

-    async def dgpu_stream_one_img(req: Text2ImageParameters):
+    async def dgpu_stream_one_img(req: DiffusionParameters, img_buf=None):
         nonlocal wip_reqs, fin_reqs, next_worker
         nid = get_next_worker()
         idx = list(nodes.keys()).index(nid)

@@ -186,7 +186,13 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
         dgpu_req.auth.cert = 'skynet'
         dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key)

-        await dgpu_bus.asend(dgpu_req.SerializeToString())
+        msg = dgpu_req.SerializeToString()
+        if img_buf:
+            logging.info(f'sending img of size {len(img_buf)} as attachment')
+            logging.info(img_buf[:10])
+            msg = f'BINEXT%$%$'.encode() + msg + b'%$%$' + img_buf
+
+        await dgpu_bus.asend(msg)

         with trio.move_on_after(4):
             await ack_event.wait()

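
The hunk above rides the source image on the same bus message as the serialized request, framing it with a 'BINEXT%$%$' header and a '%$%$' separator. A minimal sketch of that framing follows; the helper names are illustrative only (the real code inlines this in dgpu_stream_one_img and in the dgpu receive loop) and, like the original code, it assumes the separator bytes never occur inside the payloads:

    import zlib

    SEP = b'%$%$'
    MAGIC = b'BINEXT'

    def pack_with_attachment(serialized_req: bytes, img_buf: bytes) -> bytes:
        # header + serialized protobuf + separator + (already zlib-compressed) image
        return MAGIC + SEP + serialized_req + SEP + img_buf

    def unpack(msg: bytes):
        # mirror of the dgpu receive side: split on the separator when the header is present
        if MAGIC in msg:
            _header, serialized_req, img_attachment = msg.split(SEP)
            return serialized_req, zlib.decompress(img_attachment)
        return msg, None

    # round trip
    req_bytes = b'\x08\x01'    # stand-in for DGPUBusMessage.SerializeToString()
    img_buf = zlib.compress(b'png bytes here')
    assert unpack(pack_with_attachment(req_bytes, img_buf)) == (req_bytes, b'png bytes here')
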
@@ -237,12 +243,38 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
                 del user_config['id']
                 user_config.update(MessageToDict(req.params))

-                req = Text2ImageParameters(**user_config)
+                req = DiffusionParameters(**user_config, image=False)
                 rid, img, meta = await dgpu_stream_one_img(req)
                 logging.info(f'done streaming {rid}')
                 result = {
                     'id': rid,
-                    'img': zlib.compress(img).hex(),
+                    'img': img.hex(),
                     'meta': meta
                 }

+                await update_user_stats(conn, user, last_prompt=user_config['prompt'])
+                logging.info('updated user stats.')
+
+            case 'img2img':
+                logging.info('img2img')
+                user_config = {**(await get_user_config(conn, user))}
+                del user_config['id']
+
+                params = MessageToDict(req.params)
+                img_buf = bytes.fromhex(params['img'])
+                del params['img']
+                user_config.update(params)
+
+                req = DiffusionParameters(**user_config, image=True)
+
+                if not req.image:
+                    raise AssertionError('Didn\'t enable image flag for img2img?')
+
+                rid, img, meta = await dgpu_stream_one_img(req, img_buf=img_buf)
+                logging.info(f'done streaming {rid}')
+                result = {
+                    'id': rid,
+                    'img': img.hex(),
+                    'meta': meta
+                }
+

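
The new img2img case expects the source image as a hex string inside the protobuf params and strips it out before the remaining fields are merged into the user config. A small sketch of that transport convention, using plain dicts in place of the protobuf req.params:

    import zlib

    # client side: the telegram frontend zlib-compresses the downloaded photo
    # and sends it as a hex string next to the prompt
    img_buf = zlib.compress(b'...raw photo bytes...')
    params = {
        'prompt': 'a red old tractor in a sunny wheat field',
        'img': img_buf.hex(),
    }

    # brain side, mirroring the case above: recover the bytes, drop the key,
    # keep the rest as generation settings
    img_buf = bytes.fromhex(params['img'])
    del params['img']
    assert zlib.decompress(img_buf) == b'...raw photo bytes...'
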
@@ -256,14 +288,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
                 prompt = await get_last_prompt_of(conn, user)

                 if prompt:
-                    req = Text2ImageParameters(
+                    req = DiffusionParameters(
                         prompt=prompt,
                         **user_config
                     )
                     rid, img, meta = await dgpu_stream_one_img(req)
                     result = {
                         'id': rid,
-                        'img': zlib.compress(img).hex(),
+                        'img': img.hex(),
                         'meta': meta
                     }
                     await update_user_stats(conn, user)

@@ -41,6 +41,28 @@ def txt2img(*args, **kwargs):
     assert 'HF_TOKEN' in os.environ
     utils.txt2img(os.environ['HF_TOKEN'], **kwargs)

+@click.command()
+@click.option('--model', '-m', default='midj')
+@click.option(
+    '--prompt', '-p', default='a red old tractor in a sunny wheat field')
+@click.option('--input', '-i', default='input.png')
+@click.option('--output', '-o', default='output.png')
+@click.option('--guidance', '-g', default=10.0)
+@click.option('--steps', '-s', default=26)
+@click.option('--seed', '-S', default=None)
+def img2img(model, prompt, input, output, guidance, steps, seed):
+    assert 'HF_TOKEN' in os.environ
+    utils.img2img(
+        os.environ['HF_TOKEN'],
+        model=model,
+        prompt=prompt,
+        img_path=input,
+        output=output,
+        guidance=guidance,
+        steps=steps,
+        seed=seed
+    )
+
 @click.command()
 @click.option('--input', '-i', default='input.png')
 @click.option('--output', '-o', default='output.png')

@@ -6,10 +6,12 @@ import trio
 import json
 import uuid
 import time
+import zlib
 import random
 import logging
 import traceback

+from PIL import Image
 from typing import List, Optional
 from pathlib import Path
 from contextlib import ExitStack

@@ -25,6 +27,7 @@ from OpenSSL.crypto import (
 )
 from diffusers import (
     StableDiffusionPipeline,
+    StableDiffusionImg2ImgPipeline,
     EulerAncestralDiscreteScheduler
 )
 from realesrgan import RealESRGANer

@@ -138,8 +141,9 @@ async def open_dgpu_node(
         logging.info('memory summary:')
         logging.info('\n' + torch.cuda.memory_summary())

-    async def gpu_compute_one(ireq: Text2ImageParameters):
-        if ireq.algo not in models:
+    async def gpu_compute_one(ireq: DiffusionParameters, image=None):
+        algo = ireq.algo + 'img' if image else ireq.algo
+        if algo not in models:
             least_used = list(models.keys())[0]
             for model in models:
                 if models[least_used]['generated'] > models[model]['generated']:

@@ -148,16 +152,23 @@ async def open_dgpu_node(
             del models[least_used]
             gc.collect()

-            models[ireq.algo] = {
-                'pipe': pipeline_for(ireq.algo),
+            models[algo] = {
+                'pipe': pipeline_for(ireq.algo, image=True if image else False),
                 'generated': 0
             }

+        _params = {}
+        if ireq.image:
+            _params['image'] = image
+
+        else:
+            _params['width'] = int(ireq.width)
+            _params['height'] = int(ireq.height)
+
         try:
-            image = models[ireq.algo]['pipe'](
+            image = models[algo]['pipe'](
                 ireq.prompt,
-                width=int(ireq.width),
-                height=int(ireq.height),
+                **_params,
                 guidance_scale=ireq.guidance,
                 num_inference_steps=int(ireq.step),
                 generator=torch.Generator("cuda").manual_seed(ireq.seed)

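
gpu_compute_one keeps one pipeline per cache key and now gives img2img its own slot by suffixing the algo name, evicting the pipeline that has generated the fewest images when it needs room. A reduced sketch of that bookkeeping, with a plain string standing in for the diffusers pipeline and a load_pipe helper standing in for pipeline_for (MAX_LOADED is an assumption for the sketch; the real node is bounded by GPU memory):

    models: dict = {}
    MAX_LOADED = 1

    def load_pipe(algo: str, image: bool = False) -> str:
        # stand-in for pipeline_for(algo, image=...)
        return f'{algo}-{"img2img" if image else "txt2img"}-pipe'

    def get_pipe(algo: str, image: bool = False):
        key = algo + 'img' if image else algo    # txt2img and img2img cached under separate keys
        if key not in models:
            if len(models) >= MAX_LOADED:
                # evict the least used pipeline, as in the hunk above
                least_used = min(models, key=lambda k: models[k]['generated'])
                del models[least_used]
            models[key] = {'pipe': load_pipe(algo, image=image), 'generated': 0}

        models[key]['generated'] += 1
        return models[key]['pipe']

    get_pipe('midj')                # loads the txt2img pipeline
    get_pipe('midj', image=True)    # evicts it and loads the img2img variant
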
@@ -173,7 +184,9 @@ async def open_dgpu_node(
             image = convert_from_cv2_to_image(up_img)
             logging.info('done')

-        raw_img = image.tobytes()
+        img_byte_arr = io.BytesIO()
+        image.save(img_byte_arr, format='PNG')
+        raw_img = img_byte_arr.getvalue()
         logging.info(f'final img size {len(raw_img)} bytes.')

         return raw_img

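
Raw Image.tobytes() output carries no size or mode information and is large, so the node now returns an in-memory PNG, which the receiving side can reopen with Image.open without knowing the dimensions up front. A quick round-trip sketch of the pattern:

    import io
    from PIL import Image

    image = Image.new('RGB', (512, 512), color=(200, 60, 20))

    # serialize: PIL image -> PNG bytes, as gpu_compute_one now does
    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='PNG')
    raw_img = img_byte_arr.getvalue()

    # deserialize: PNG bytes are self-describing
    restored = Image.open(io.BytesIO(raw_img))
    assert restored.size == (512, 512) and restored.mode == 'RGB'
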
@@ -256,8 +269,19 @@ async def open_dgpu_node(

     try:
         while True:
+            msg = await dgpu_bus.arecv()
+
+            img = None
+            if b'BINEXT' in msg:
+                header, msg, img_raw = msg.split(b'%$%$')
+                logging.info(f'got img attachment of size {len(img_raw)}')
+                logging.info(img_raw[:10])
+                raw_img = zlib.decompress(img_raw)
+                logging.info(raw_img[:10])
+                img = Image.open(io.BytesIO(raw_img))
+
             req = DGPUBusMessage()
-            req.ParseFromString(await dgpu_bus.arecv())
+            req.ParseFromString(msg)
             last_msg = time.time()

             if req.method == 'heartbeat':

@@ -301,11 +325,12 @@ async def open_dgpu_node(
            logging.info(f'sent ack, processing {req.rid}...')

            try:
-                img_req = Text2ImageParameters(**req.params)
+                img_req = DiffusionParameters(**req.params)
+
                if not img_req.seed:
                    img_req.seed = random.randint(0, 2 ** 64)

-                img = await gpu_compute_one(img_req)
+                img = await gpu_compute_one(img_req, image=img)
                img_resp = DGPUBusMessage(
                    rid=req.rid,
                    nid=req.nid,

@@ -335,7 +360,7 @@ async def open_dgpu_node(
                await dgpu_bus.asend(raw_msg)
                logging.info(f'sent {len(raw_msg)} bytes.')
                if img_resp.method == 'binary-reply':
-                    await dgpu_bus.asend(img)
+                    await dgpu_bus.asend(zlib.compress(img))
                    logging.info(f'sent {len(img)} bytes.')

    except KeyboardInterrupt:

@@ -130,6 +130,57 @@ async def run_skynet_telegram(

         await bot.reply_to(message, resp_txt)

+    @bot.message_handler(commands=['img2img'], content_types=['photo'])
+    async def send_img2img(message):
+        chat = message.chat
+        if chat.type != 'group' and chat.id != GROUP_ID:
+            return
+
+        prompt = ' '.join(message.caption.split(' ')[1:])
+
+        if len(prompt) == 0:
+            await bot.reply_to(message, 'Empty text prompt ignored.')
+            return
+
+        file_id = message.photo[-1].file_id
+        file_path = bot.get_file(file_id).file_path
+        file_raw = bot.download_file(file_path)
+        img = zlib.compress(file_raw)
+
+        logging.info(f'mid: {message.id}')
+        resp = await _rpc_call(
+            message.from_user.id,
+            'img2img',
+            {'prompt': prompt, 'img': img.hex()}
+        )
+        logging.info(f'resp to {message.id} arrived')
+
+        resp_txt = ''
+        result = MessageToDict(resp.result)
+        if 'error' in resp.result:
+            resp_txt = resp.result['message']
+
+        else:
+            logging.info(result['id'])
+            img_raw = zlib.decompress(bytes.fromhex(result['img']))
+            logging.info(f'got image of size: {len(img_raw)}')
+            meta = result['meta']['meta']
+            size = (int(meta['width']), int(meta['height']))
+            if meta['upscaler'] == 'x4':
+                size = (size[0] * 4, size[1] * 4)
+
+            img = Image.frombytes('RGB', size, img_raw)
+
+            await bot.send_photo(
+                message.chat.id,
+                caption=prepare_metainfo_caption(meta),
+                photo=img,
+                reply_to_message_id=message.id
+            )
+            return
+
+        await bot.reply_to(message, resp_txt)
+
     @bot.message_handler(commands=['redo'])
     async def redo_txt2img(message):
         chat = message.chat

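
Because the bot rebuilds the reply with Image.frombytes, it needs the pixel dimensions up front; they come from the returned metadata and are multiplied by four per side when the x4 upscaler ran. A small sketch of that size bookkeeping (reply_image is an illustrative helper, not part of the bot):

    from PIL import Image

    def reply_image(img_raw: bytes, meta: dict) -> Image.Image:
        size = (int(meta['width']), int(meta['height']))
        if meta.get('upscaler') == 'x4':
            # RealESRGAN x4 output is four times larger on each side
            size = (size[0] * 4, size[1] * 4)
        return Image.frombytes('RGB', size, img_raw)

    # 2x2 dummy payload, no upscaling: 2 * 2 * 3 RGB bytes
    img = reply_image(b'\x00' * 12, {'width': '2', 'height': '2', 'upscaler': None})
    assert img.size == (2, 2)
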
@@ -16,7 +16,7 @@ class Struct:


 @dataclass
-class Text2ImageParameters(Struct):
+class DiffusionParameters(Struct):
     algo: str
     prompt: str
     step: int

@@ -24,4 +24,5 @@ class Text2ImageParameters(Struct):
     height: int
     guidance: float
     seed: Optional[int]
+    image: bool # if true indicates a bytestream is next msg
     upscaler: Optional[str]

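
With the rename a single dataclass now describes both request kinds, and the new image flag is what tells a dgpu node that an image bytestream follows in the next bus message. A sketch of the two shapes, constructing the dataclass above by keyword; the values and the width field (elided from the hunk) are taken from the test payloads at the end of this diff:

    # txt2img request: dimensions drive the pipeline, no attachment follows
    txt_req = DiffusionParameters(
        algo='midj',
        prompt='a red old tractor in a sunny wheat field',
        step=28,
        width=512, height=512,
        guidance=7.5,
        seed=None,
        image=False,
        upscaler=None,
    )

    # img2img request: the source image arrives out of band as a BINEXT attachment
    img_req = DiffusionParameters(
        algo='midj',
        prompt='red sports car in a sunny wheat field',
        step=28,
        width=512, height=512,
        guidance=12,
        seed=None,
        image=True,
        upscaler='x4',
    )
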
@@ -12,7 +12,7 @@ from PIL import Image
 from basicsr.archs.rrdbnet_arch import RRDBNet
 from diffusers import (
     StableDiffusionPipeline,
-    StableDiffusionUpscalePipeline,
+    StableDiffusionImg2ImgPipeline,
     EulerAncestralDiscreteScheduler
 )
 from realesrgan import RealESRGANer

@@ -31,7 +31,7 @@ def convert_from_image_to_cv2(img: Image) -> np.ndarray:
     return np.asarray(img)


-def pipeline_for(algo: str, mem_fraction: float = 1.0):
+def pipeline_for(algo: str, mem_fraction: float = 1.0, image=False):
     assert torch.cuda.is_available()
     torch.cuda.empty_cache()
     torch.cuda.set_per_process_memory_fraction(mem_fraction)

@@ -46,12 +46,18 @@ def pipeline_for(algo: str, mem_fraction: float = 1.0):
     if algo == 'stable':
         params['revision'] = 'fp16'

-    pipe = StableDiffusionPipeline.from_pretrained(
+    if image:
+        pipe_class = StableDiffusionImg2ImgPipeline
+    else:
+        pipe_class = StableDiffusionPipeline
+
+    pipe = pipe_class.from_pretrained(
         ALGOS[algo], **params)

     pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
         pipe.scheduler.config)

+    if not image:
         pipe.enable_vae_slicing()

     return pipe.to('cuda')

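
pipeline_for now loads either pipeline class from the same ALGOS[algo] weights, keyed off the new image flag, and only enables VAE slicing on the txt2img path. A short usage sketch (needs a CUDA device, the model weights, and a local txt2img.png, so treat it as illustrative only):

    from PIL import Image

    txt_pipe = pipeline_for('midj')                 # StableDiffusionPipeline, VAE slicing on
    img_pipe = pipeline_for('midj', image=True)     # StableDiffusionImg2ImgPipeline

    init_img = Image.open('txt2img.png').convert('RGB')
    out = img_pipe(
        'red sports car in a sunny wheat field',
        image=init_img,
        guidance_scale=12,
        num_inference_steps=28,
    ).images[0]
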
@@ -89,6 +95,39 @@ def txt2img(
     image.save(output)


+def img2img(
+    hf_token: str,
+    model: str = 'midj',
+    prompt: str = 'a red old tractor in a sunny wheat field',
+    img_path: str = 'input.png',
+    output: str = 'output.png',
+    guidance: float = 10,
+    steps: int = 28,
+    seed: Optional[int] = None
+):
+    assert torch.cuda.is_available()
+    torch.cuda.empty_cache()
+    torch.cuda.set_per_process_memory_fraction(1.0)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+
+    login(token=hf_token)
+    pipe = pipeline_for(model, image=True)
+
+    input_img = Image.open(img_path).convert('RGB')
+
+    seed = seed if seed else random.randint(0, 2 ** 64)
+    prompt = prompt
+    image = pipe(
+        prompt,
+        image=input_img,
+        guidance_scale=guidance, num_inference_steps=steps,
+        generator=torch.Generator("cuda").manual_seed(seed)
+    ).images[0]
+
+    image.save(output)
+
+
 def upscale(
     img_path: str = 'input.png',
     output: str = 'output.png',

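
The new helper mirrors utils.txt2img and is what the img2img console script added in setup.py calls into. A minimal driver sketch (it assumes the utils module lives at skynet.utils, and needs a CUDA GPU, an HF_TOKEN, and a local input.png):

    import os
    from skynet import utils   # assumed module path, matching the skynet.cli usage above

    utils.img2img(
        os.environ['HF_TOKEN'],
        model='midj',
        prompt='red sports car in a sunny wheat field',
        img_path='input.png',
        output='output.png',
        guidance=12,
        steps=28,
        seed=None,
    )
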
@@ -321,3 +321,59 @@ async def test_dgpu_heartbeat(dgpu_workers):
     ) as test_rpc:
         await wait_for_dgpus(test_rpc, 1)
         await trio.sleep(120)
+
+
+@pytest.mark.parametrize(
+    'dgpu_workers', [(1, ['midj'])], indirect=True)
+async def test_dgpu_img2img(dgpu_workers):
+
+    async with open_skynet_rpc(
+        '1',
+        security=True,
+        cert_name='whitelist/testing',
+        key_name='testing'
+    ) as rpc_call:
+        await wait_for_dgpus(rpc_call, 1)
+
+
+        res = await rpc_call(
+            'txt2img', {
+                'prompt': 'red old tractor in a sunny wheat field',
+                'step': 28,
+                'width': 512, 'height': 512,
+                'guidance': 7.5,
+                'seed': None,
+                'algo': list(ALGOS.keys())[0],
+                'upscaler': None
+            })
+
+        if 'error' in res.result:
+            raise SkynetDGPUComputeError(MessageToDict(res.result))
+
+        img_raw = res.result['img']
+        img = zlib.decompress(bytes.fromhex(img_raw))
+        logging.info(img[:10])
+        img = Image.open(io.BytesIO(img))
+
+        img.save('txt2img.png')
+
+        res = await rpc_call(
+            'img2img', {
+                'prompt': 'red sports car in a sunny wheat field',
+                'step': 28,
+                'img': img_raw,
+                'guidance': 12,
+                'seed': None,
+                'algo': list(ALGOS.keys())[0],
+                'upscaler': 'x4'
+            })
+
+        if 'error' in res.result:
+            raise SkynetDGPUComputeError(MessageToDict(res.result))
+
+        img_raw = res.result['img']
+        img = zlib.decompress(bytes.fromhex(img_raw))
+        logging.info(img[:10])
+        img = Image.open(io.BytesIO(img))
+
+        img.save('img2img.png')