mirror of https://github.com/skygpu/skynet.git
Fix for img2img mode on new worker system
parent
91edb2aa56
commit
c8a0a390a6
|
@ -50,7 +50,7 @@ def txt2img(*args, **kwargs):
|
|||
utils.txt2img(hf_token, **kwargs)
|
||||
|
||||
@click.command()
|
||||
@click.option('--model', '-m', default='midj')
|
||||
@click.option('--model', '-m', default=list(MODELS.keys())[0])
|
||||
@click.option(
|
||||
'--prompt', '-p', default='a red old tractor in a sunny wheat field')
|
||||
@click.option('--input', '-i', default='input.png')
|
||||
|
|
|
@ -34,6 +34,7 @@ def init_env_from_config(
|
|||
sub_config = config['skynet.dgpu']
|
||||
if 'hf_token' in sub_config:
|
||||
hf_token = sub_config['hf_token']
|
||||
os.environ['HF_TOKEN'] = hf_token
|
||||
|
||||
if 'HF_HOME' in os.environ:
|
||||
hf_home = os.environ['HF_HOME']
|
||||
|
@ -42,6 +43,7 @@ def init_env_from_config(
|
|||
sub_config = config['skynet.dgpu']
|
||||
if 'hf_home' in sub_config:
|
||||
hf_home = sub_config['hf_home']
|
||||
os.environ['HF_HOME'] = hf_home
|
||||
|
||||
if 'TG_TOKEN' in os.environ:
|
||||
tg_token = os.environ['TG_TOKEN']
|
||||
|
|
|
@ -133,7 +133,7 @@ class SkynetMM:
|
|||
|
||||
arguments = prepare_params_for_diffuse(params, binary)
|
||||
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
||||
model = self.get_model(params['model'], 'image' in params)
|
||||
model = self.get_model(params['model'], 'image' in extra_params)
|
||||
|
||||
image = model(
|
||||
prompt,
|
||||
|
|
|
@ -55,6 +55,8 @@ def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
|
|||
if w > max_w or h > max_h:
|
||||
image.thumbnail((512, 512))
|
||||
|
||||
return image.convert('RGB')
|
||||
|
||||
|
||||
def pipeline_for(model: str, mem_fraction: float = 1.0, image=False) -> DiffusionPipeline:
|
||||
assert torch.cuda.is_available()
|
||||
|
@ -147,7 +149,8 @@ def img2img(
|
|||
login(token=hf_token)
|
||||
pipe = pipeline_for(model, image=True)
|
||||
|
||||
input_img = Image.open(img_path).convert('RGB')
|
||||
with open(img_path, 'rb') as img_file:
|
||||
input_img = convert_from_bytes_and_crop(img_file.read(), 512, 512)
|
||||
|
||||
seed = seed if seed else random.randint(0, 2 ** 64)
|
||||
prompt = prompt
|
||||
|
@ -162,7 +165,6 @@ def img2img(
|
|||
image.save(output)
|
||||
|
||||
|
||||
|
||||
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
|
||||
return RealESRGANer(
|
||||
scale=4,
|
||||
|
|
Loading…
Reference in New Issue