mirror of https://github.com/skygpu/skynet.git
Add img2img mode
parent
585d304f86
commit
97f7d51782
1
setup.py
1
setup.py
|
@ -13,6 +13,7 @@ setup(
|
||||||
'console_scripts': [
|
'console_scripts': [
|
||||||
'skynet = skynet.cli:skynet',
|
'skynet = skynet.cli:skynet',
|
||||||
'txt2img = skynet.cli:txt2img',
|
'txt2img = skynet.cli:txt2img',
|
||||||
|
'img2img = skynet.cli:img2img',
|
||||||
'upscale = skynet.cli:upscale'
|
'upscale = skynet.cli:upscale'
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|
|
@ -164,7 +164,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
event.set()
|
event.set()
|
||||||
del wip_reqs[rid]
|
del wip_reqs[rid]
|
||||||
|
|
||||||
async def dgpu_stream_one_img(req: Text2ImageParameters):
|
async def dgpu_stream_one_img(req: DiffusionParameters, img_buf=None):
|
||||||
nonlocal wip_reqs, fin_reqs, next_worker
|
nonlocal wip_reqs, fin_reqs, next_worker
|
||||||
nid = get_next_worker()
|
nid = get_next_worker()
|
||||||
idx = list(nodes.keys()).index(nid)
|
idx = list(nodes.keys()).index(nid)
|
||||||
|
@ -186,7 +186,13 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
dgpu_req.auth.cert = 'skynet'
|
dgpu_req.auth.cert = 'skynet'
|
||||||
dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key)
|
dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key)
|
||||||
|
|
||||||
await dgpu_bus.asend(dgpu_req.SerializeToString())
|
msg = dgpu_req.SerializeToString()
|
||||||
|
if img_buf:
|
||||||
|
logging.info(f'sending img of size {len(img_buf)} as attachment')
|
||||||
|
logging.info(img_buf[:10])
|
||||||
|
msg = f'BINEXT%$%$'.encode() + msg + b'%$%$' + img_buf
|
||||||
|
|
||||||
|
await dgpu_bus.asend(msg)
|
||||||
|
|
||||||
with trio.move_on_after(4):
|
with trio.move_on_after(4):
|
||||||
await ack_event.wait()
|
await ack_event.wait()
|
||||||
|
@ -237,12 +243,38 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
del user_config['id']
|
del user_config['id']
|
||||||
user_config.update(MessageToDict(req.params))
|
user_config.update(MessageToDict(req.params))
|
||||||
|
|
||||||
req = Text2ImageParameters(**user_config)
|
req = DiffusionParameters(**user_config, image=False)
|
||||||
rid, img, meta = await dgpu_stream_one_img(req)
|
rid, img, meta = await dgpu_stream_one_img(req)
|
||||||
logging.info(f'done streaming {rid}')
|
logging.info(f'done streaming {rid}')
|
||||||
result = {
|
result = {
|
||||||
'id': rid,
|
'id': rid,
|
||||||
'img': zlib.compress(img).hex(),
|
'img': img.hex(),
|
||||||
|
'meta': meta
|
||||||
|
}
|
||||||
|
|
||||||
|
await update_user_stats(conn, user, last_prompt=user_config['prompt'])
|
||||||
|
logging.info('updated user stats.')
|
||||||
|
|
||||||
|
case 'img2img':
|
||||||
|
logging.info('img2img')
|
||||||
|
user_config = {**(await get_user_config(conn, user))}
|
||||||
|
del user_config['id']
|
||||||
|
|
||||||
|
params = MessageToDict(req.params)
|
||||||
|
img_buf = bytes.fromhex(params['img'])
|
||||||
|
del params['img']
|
||||||
|
user_config.update(params)
|
||||||
|
|
||||||
|
req = DiffusionParameters(**user_config, image=True)
|
||||||
|
|
||||||
|
if not req.image:
|
||||||
|
raise AssertionError('Didn\'t enable image flag for img2img?')
|
||||||
|
|
||||||
|
rid, img, meta = await dgpu_stream_one_img(req, img_buf=img_buf)
|
||||||
|
logging.info(f'done streaming {rid}')
|
||||||
|
result = {
|
||||||
|
'id': rid,
|
||||||
|
'img': img.hex(),
|
||||||
'meta': meta
|
'meta': meta
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -256,14 +288,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
prompt = await get_last_prompt_of(conn, user)
|
prompt = await get_last_prompt_of(conn, user)
|
||||||
|
|
||||||
if prompt:
|
if prompt:
|
||||||
req = Text2ImageParameters(
|
req = DiffusionParameters(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
**user_config
|
**user_config
|
||||||
)
|
)
|
||||||
rid, img, meta = await dgpu_stream_one_img(req)
|
rid, img, meta = await dgpu_stream_one_img(req)
|
||||||
result = {
|
result = {
|
||||||
'id': rid,
|
'id': rid,
|
||||||
'img': zlib.compress(img).hex(),
|
'img': img.hex(),
|
||||||
'meta': meta
|
'meta': meta
|
||||||
}
|
}
|
||||||
await update_user_stats(conn, user)
|
await update_user_stats(conn, user)
|
||||||
|
|
|
@ -41,6 +41,28 @@ def txt2img(*args, **kwargs):
|
||||||
assert 'HF_TOKEN' in os.environ
|
assert 'HF_TOKEN' in os.environ
|
||||||
utils.txt2img(os.environ['HF_TOKEN'], **kwargs)
|
utils.txt2img(os.environ['HF_TOKEN'], **kwargs)
|
||||||
|
|
||||||
|
@click.command()
|
||||||
|
@click.option('--model', '-m', default='midj')
|
||||||
|
@click.option(
|
||||||
|
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
|
||||||
|
@click.option('--input', '-i', default='input.png')
|
||||||
|
@click.option('--output', '-o', default='output.png')
|
||||||
|
@click.option('--guidance', '-g', default=10.0)
|
||||||
|
@click.option('--steps', '-s', default=26)
|
||||||
|
@click.option('--seed', '-S', default=None)
|
||||||
|
def img2img(model, prompt, input, output, guidance, steps, seed):
|
||||||
|
assert 'HF_TOKEN' in os.environ
|
||||||
|
utils.img2img(
|
||||||
|
os.environ['HF_TOKEN'],
|
||||||
|
model=model,
|
||||||
|
prompt=prompt,
|
||||||
|
img_path=input,
|
||||||
|
output=output,
|
||||||
|
guidance=guidance,
|
||||||
|
steps=steps,
|
||||||
|
seed=seed
|
||||||
|
)
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option('--input', '-i', default='input.png')
|
@click.option('--input', '-i', default='input.png')
|
||||||
@click.option('--output', '-o', default='output.png')
|
@click.option('--output', '-o', default='output.png')
|
||||||
|
|
|
@ -6,10 +6,12 @@ import trio
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
|
import zlib
|
||||||
import random
|
import random
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
|
@ -25,6 +27,7 @@ from OpenSSL.crypto import (
|
||||||
)
|
)
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
|
StableDiffusionImg2ImgPipeline,
|
||||||
EulerAncestralDiscreteScheduler
|
EulerAncestralDiscreteScheduler
|
||||||
)
|
)
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
@ -138,8 +141,9 @@ async def open_dgpu_node(
|
||||||
logging.info('memory summary:')
|
logging.info('memory summary:')
|
||||||
logging.info('\n' + torch.cuda.memory_summary())
|
logging.info('\n' + torch.cuda.memory_summary())
|
||||||
|
|
||||||
async def gpu_compute_one(ireq: Text2ImageParameters):
|
async def gpu_compute_one(ireq: DiffusionParameters, image=None):
|
||||||
if ireq.algo not in models:
|
algo = ireq.algo + 'img' if image else ireq.algo
|
||||||
|
if algo not in models:
|
||||||
least_used = list(models.keys())[0]
|
least_used = list(models.keys())[0]
|
||||||
for model in models:
|
for model in models:
|
||||||
if models[least_used]['generated'] > models[model]['generated']:
|
if models[least_used]['generated'] > models[model]['generated']:
|
||||||
|
@ -148,16 +152,23 @@ async def open_dgpu_node(
|
||||||
del models[least_used]
|
del models[least_used]
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
models[ireq.algo] = {
|
models[algo] = {
|
||||||
'pipe': pipeline_for(ireq.algo),
|
'pipe': pipeline_for(ireq.algo, image=True if image else False),
|
||||||
'generated': 0
|
'generated': 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_params = {}
|
||||||
|
if ireq.image:
|
||||||
|
_params['image'] = image
|
||||||
|
|
||||||
|
else:
|
||||||
|
_params['width'] = int(ireq.width)
|
||||||
|
_params['height'] = int(ireq.height)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
image = models[ireq.algo]['pipe'](
|
image = models[algo]['pipe'](
|
||||||
ireq.prompt,
|
ireq.prompt,
|
||||||
width=int(ireq.width),
|
**_params,
|
||||||
height=int(ireq.height),
|
|
||||||
guidance_scale=ireq.guidance,
|
guidance_scale=ireq.guidance,
|
||||||
num_inference_steps=int(ireq.step),
|
num_inference_steps=int(ireq.step),
|
||||||
generator=torch.Generator("cuda").manual_seed(ireq.seed)
|
generator=torch.Generator("cuda").manual_seed(ireq.seed)
|
||||||
|
@ -173,7 +184,9 @@ async def open_dgpu_node(
|
||||||
image = convert_from_cv2_to_image(up_img)
|
image = convert_from_cv2_to_image(up_img)
|
||||||
logging.info('done')
|
logging.info('done')
|
||||||
|
|
||||||
raw_img = image.tobytes()
|
img_byte_arr = io.BytesIO()
|
||||||
|
image.save(img_byte_arr, format='PNG')
|
||||||
|
raw_img = img_byte_arr.getvalue()
|
||||||
logging.info(f'final img size {len(raw_img)} bytes.')
|
logging.info(f'final img size {len(raw_img)} bytes.')
|
||||||
|
|
||||||
return raw_img
|
return raw_img
|
||||||
|
@ -256,8 +269,19 @@ async def open_dgpu_node(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
msg = await dgpu_bus.arecv()
|
||||||
|
|
||||||
|
img = None
|
||||||
|
if b'BINEXT' in msg:
|
||||||
|
header, msg, img_raw = msg.split(b'%$%$')
|
||||||
|
logging.info(f'got img attachment of size {len(img_raw)}')
|
||||||
|
logging.info(img_raw[:10])
|
||||||
|
raw_img = zlib.decompress(img_raw)
|
||||||
|
logging.info(raw_img[:10])
|
||||||
|
img = Image.open(io.BytesIO(raw_img))
|
||||||
|
|
||||||
req = DGPUBusMessage()
|
req = DGPUBusMessage()
|
||||||
req.ParseFromString(await dgpu_bus.arecv())
|
req.ParseFromString(msg)
|
||||||
last_msg = time.time()
|
last_msg = time.time()
|
||||||
|
|
||||||
if req.method == 'heartbeat':
|
if req.method == 'heartbeat':
|
||||||
|
@ -301,11 +325,12 @@ async def open_dgpu_node(
|
||||||
logging.info(f'sent ack, processing {req.rid}...')
|
logging.info(f'sent ack, processing {req.rid}...')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
img_req = Text2ImageParameters(**req.params)
|
img_req = DiffusionParameters(**req.params)
|
||||||
|
|
||||||
if not img_req.seed:
|
if not img_req.seed:
|
||||||
img_req.seed = random.randint(0, 2 ** 64)
|
img_req.seed = random.randint(0, 2 ** 64)
|
||||||
|
|
||||||
img = await gpu_compute_one(img_req)
|
img = await gpu_compute_one(img_req, image=img)
|
||||||
img_resp = DGPUBusMessage(
|
img_resp = DGPUBusMessage(
|
||||||
rid=req.rid,
|
rid=req.rid,
|
||||||
nid=req.nid,
|
nid=req.nid,
|
||||||
|
@ -335,7 +360,7 @@ async def open_dgpu_node(
|
||||||
await dgpu_bus.asend(raw_msg)
|
await dgpu_bus.asend(raw_msg)
|
||||||
logging.info(f'sent {len(raw_msg)} bytes.')
|
logging.info(f'sent {len(raw_msg)} bytes.')
|
||||||
if img_resp.method == 'binary-reply':
|
if img_resp.method == 'binary-reply':
|
||||||
await dgpu_bus.asend(img)
|
await dgpu_bus.asend(zlib.compress(img))
|
||||||
logging.info(f'sent {len(img)} bytes.')
|
logging.info(f'sent {len(img)} bytes.')
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
|
|
@ -130,6 +130,57 @@ async def run_skynet_telegram(
|
||||||
|
|
||||||
await bot.reply_to(message, resp_txt)
|
await bot.reply_to(message, resp_txt)
|
||||||
|
|
||||||
|
@bot.message_handler(commands=['img2img'], content_types=['photo'])
|
||||||
|
async def send_img2img(message):
|
||||||
|
chat = message.chat
|
||||||
|
if chat.type != 'group' and chat.id != GROUP_ID:
|
||||||
|
return
|
||||||
|
|
||||||
|
prompt = ' '.join(message.caption.split(' ')[1:])
|
||||||
|
|
||||||
|
if len(prompt) == 0:
|
||||||
|
await bot.reply_to(message, 'Empty text prompt ignored.')
|
||||||
|
return
|
||||||
|
|
||||||
|
file_id = message.photo[-1].file_id
|
||||||
|
file_path = bot.get_file(file_id).file_path
|
||||||
|
file_raw = bot.download_file(file_path)
|
||||||
|
img = zlib.compress(file_raw)
|
||||||
|
|
||||||
|
logging.info(f'mid: {message.id}')
|
||||||
|
resp = await _rpc_call(
|
||||||
|
message.from_user.id,
|
||||||
|
'img2img',
|
||||||
|
{'prompt': prompt, 'img': img.hex()}
|
||||||
|
)
|
||||||
|
logging.info(f'resp to {message.id} arrived')
|
||||||
|
|
||||||
|
resp_txt = ''
|
||||||
|
result = MessageToDict(resp.result)
|
||||||
|
if 'error' in resp.result:
|
||||||
|
resp_txt = resp.result['message']
|
||||||
|
|
||||||
|
else:
|
||||||
|
logging.info(result['id'])
|
||||||
|
img_raw = zlib.decompress(bytes.fromhex(result['img']))
|
||||||
|
logging.info(f'got image of size: {len(img_raw)}')
|
||||||
|
meta = result['meta']['meta']
|
||||||
|
size = (int(meta['width']), int(meta['height']))
|
||||||
|
if meta['upscaler'] == 'x4':
|
||||||
|
size = (size[0] * 4, size[1] * 4)
|
||||||
|
|
||||||
|
img = Image.frombytes('RGB', size, img_raw)
|
||||||
|
|
||||||
|
await bot.send_photo(
|
||||||
|
message.chat.id,
|
||||||
|
caption=prepare_metainfo_caption(meta),
|
||||||
|
photo=img,
|
||||||
|
reply_to_message_id=message.id
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
await bot.reply_to(message, resp_txt)
|
||||||
|
|
||||||
@bot.message_handler(commands=['redo'])
|
@bot.message_handler(commands=['redo'])
|
||||||
async def redo_txt2img(message):
|
async def redo_txt2img(message):
|
||||||
chat = message.chat
|
chat = message.chat
|
||||||
|
|
|
@ -16,7 +16,7 @@ class Struct:
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Text2ImageParameters(Struct):
|
class DiffusionParameters(Struct):
|
||||||
algo: str
|
algo: str
|
||||||
prompt: str
|
prompt: str
|
||||||
step: int
|
step: int
|
||||||
|
@ -24,4 +24,5 @@ class Text2ImageParameters(Struct):
|
||||||
height: int
|
height: int
|
||||||
guidance: float
|
guidance: float
|
||||||
seed: Optional[int]
|
seed: Optional[int]
|
||||||
|
image: bool # if true indicates a bytestream is next msg
|
||||||
upscaler: Optional[str]
|
upscaler: Optional[str]
|
||||||
|
|
|
@ -12,7 +12,7 @@ from PIL import Image
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
StableDiffusionUpscalePipeline,
|
StableDiffusionImg2ImgPipeline,
|
||||||
EulerAncestralDiscreteScheduler
|
EulerAncestralDiscreteScheduler
|
||||||
)
|
)
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
@ -31,7 +31,7 @@ def convert_from_image_to_cv2(img: Image) -> np.ndarray:
|
||||||
return np.asarray(img)
|
return np.asarray(img)
|
||||||
|
|
||||||
|
|
||||||
def pipeline_for(algo: str, mem_fraction: float = 1.0):
|
def pipeline_for(algo: str, mem_fraction: float = 1.0, image=False):
|
||||||
assert torch.cuda.is_available()
|
assert torch.cuda.is_available()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
torch.cuda.set_per_process_memory_fraction(mem_fraction)
|
||||||
|
@ -46,13 +46,19 @@ def pipeline_for(algo: str, mem_fraction: float = 1.0):
|
||||||
if algo == 'stable':
|
if algo == 'stable':
|
||||||
params['revision'] = 'fp16'
|
params['revision'] = 'fp16'
|
||||||
|
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
if image:
|
||||||
|
pipe_class = StableDiffusionImg2ImgPipeline
|
||||||
|
else:
|
||||||
|
pipe_class = StableDiffusionPipeline
|
||||||
|
|
||||||
|
pipe = pipe_class.from_pretrained(
|
||||||
ALGOS[algo], **params)
|
ALGOS[algo], **params)
|
||||||
|
|
||||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||||
pipe.scheduler.config)
|
pipe.scheduler.config)
|
||||||
|
|
||||||
pipe.enable_vae_slicing()
|
if not image:
|
||||||
|
pipe.enable_vae_slicing()
|
||||||
|
|
||||||
return pipe.to('cuda')
|
return pipe.to('cuda')
|
||||||
|
|
||||||
|
@ -89,6 +95,39 @@ def txt2img(
|
||||||
image.save(output)
|
image.save(output)
|
||||||
|
|
||||||
|
|
||||||
|
def img2img(
|
||||||
|
hf_token: str,
|
||||||
|
model: str = 'midj',
|
||||||
|
prompt: str = 'a red old tractor in a sunny wheat field',
|
||||||
|
img_path: str = 'input.png',
|
||||||
|
output: str = 'output.png',
|
||||||
|
guidance: float = 10,
|
||||||
|
steps: int = 28,
|
||||||
|
seed: Optional[int] = None
|
||||||
|
):
|
||||||
|
assert torch.cuda.is_available()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.set_per_process_memory_fraction(1.0)
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
login(token=hf_token)
|
||||||
|
pipe = pipeline_for(model, image=True)
|
||||||
|
|
||||||
|
input_img = Image.open(img_path).convert('RGB')
|
||||||
|
|
||||||
|
seed = seed if seed else random.randint(0, 2 ** 64)
|
||||||
|
prompt = prompt
|
||||||
|
image = pipe(
|
||||||
|
prompt,
|
||||||
|
image=input_img,
|
||||||
|
guidance_scale=guidance, num_inference_steps=steps,
|
||||||
|
generator=torch.Generator("cuda").manual_seed(seed)
|
||||||
|
).images[0]
|
||||||
|
|
||||||
|
image.save(output)
|
||||||
|
|
||||||
|
|
||||||
def upscale(
|
def upscale(
|
||||||
img_path: str = 'input.png',
|
img_path: str = 'input.png',
|
||||||
output: str = 'output.png',
|
output: str = 'output.png',
|
||||||
|
|
|
@ -321,3 +321,59 @@ async def test_dgpu_heartbeat(dgpu_workers):
|
||||||
) as test_rpc:
|
) as test_rpc:
|
||||||
await wait_for_dgpus(test_rpc, 1)
|
await wait_for_dgpus(test_rpc, 1)
|
||||||
await trio.sleep(120)
|
await trio.sleep(120)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'dgpu_workers', [(1, ['midj'])], indirect=True)
|
||||||
|
async def test_dgpu_img2img(dgpu_workers):
|
||||||
|
|
||||||
|
async with open_skynet_rpc(
|
||||||
|
'1',
|
||||||
|
security=True,
|
||||||
|
cert_name='whitelist/testing',
|
||||||
|
key_name='testing'
|
||||||
|
) as rpc_call:
|
||||||
|
await wait_for_dgpus(rpc_call, 1)
|
||||||
|
|
||||||
|
|
||||||
|
res = await rpc_call(
|
||||||
|
'txt2img', {
|
||||||
|
'prompt': 'red old tractor in a sunny wheat field',
|
||||||
|
'step': 28,
|
||||||
|
'width': 512, 'height': 512,
|
||||||
|
'guidance': 7.5,
|
||||||
|
'seed': None,
|
||||||
|
'algo': list(ALGOS.keys())[0],
|
||||||
|
'upscaler': None
|
||||||
|
})
|
||||||
|
|
||||||
|
if 'error' in res.result:
|
||||||
|
raise SkynetDGPUComputeError(MessageToDict(res.result))
|
||||||
|
|
||||||
|
img_raw = res.result['img']
|
||||||
|
img = zlib.decompress(bytes.fromhex(img_raw))
|
||||||
|
logging.info(img[:10])
|
||||||
|
img = Image.open(io.BytesIO(img))
|
||||||
|
|
||||||
|
img.save('txt2img.png')
|
||||||
|
|
||||||
|
res = await rpc_call(
|
||||||
|
'img2img', {
|
||||||
|
'prompt': 'red sports car in a sunny wheat field',
|
||||||
|
'step': 28,
|
||||||
|
'img': img_raw,
|
||||||
|
'guidance': 12,
|
||||||
|
'seed': None,
|
||||||
|
'algo': list(ALGOS.keys())[0],
|
||||||
|
'upscaler': 'x4'
|
||||||
|
})
|
||||||
|
|
||||||
|
if 'error' in res.result:
|
||||||
|
raise SkynetDGPUComputeError(MessageToDict(res.result))
|
||||||
|
|
||||||
|
img_raw = res.result['img']
|
||||||
|
img = zlib.decompress(bytes.fromhex(img_raw))
|
||||||
|
logging.info(img[:10])
|
||||||
|
img = Image.open(io.BytesIO(img))
|
||||||
|
|
||||||
|
img.save('img2img.png')
|
||||||
|
|
Loading…
Reference in New Issue