From e2ea9eb47c6680d6c01228fdecdd35a9c89ebd08 Mon Sep 17 00:00:00 2001 From: Konstantine Tsafatinos Date: Sat, 3 Jun 2023 23:05:01 -0400 Subject: [PATCH] start adding txt2txt model support --- skynet/cli.py | 34 +++++++++++++++++++++++++++++++--- skynet/utils.py | 40 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/skynet/cli.py b/skynet/cli.py index d7e389d..2558ca1 100644 --- a/skynet/cli.py +++ b/skynet/cli.py @@ -46,6 +46,7 @@ def txt2img(*args, **kwargs): _, hf_token, _ = init_env_from_config() utils.txt2img(hf_token, **kwargs) + @click.command() @click.option('--model', '-m', default=list(MODELS.keys())[0]) @click.option( @@ -71,6 +72,20 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed): seed=seed ) + +@click.command() +@click.option('--model', '-m', default='microsoft/DialoGPT-small') +@click.option( + '--prompt', '-p', default='a red old tractor in a sunny wheat field') +@click.option('--output', '-o', default='output.txt') +@click.option('--temperature', '-t', default=1.0) +@click.option('--max-length', '-ml', default=256) +def txt2txt(*args, **kwargs): + from . import utils + _, hf_token, _, cfg = init_env_from_config() + utils.txt2txt(hf_token, **kwargs) + + @click.command() @click.option('--input', '-i', default='input.png') @click.option('--output', '-o', default='output.png') @@ -89,6 +104,7 @@ def download(): _, hf_token, _ = init_env_from_config() utils.download_all_models(hf_token) + @skynet.command() @click.option( '--account', '-A', default=None) @@ -135,12 +151,14 @@ def enqueue( binary = '' ec, out = cleos.push_action( - 'telos.gpu', 'enqueue', [account, req, binary, reward], f'{account}@{permission}' + 'telos.gpu', 'enqueue', [account, req, + binary, reward], f'{account}@{permission}' ) print(collect_stdout(out)) assert ec == 0 + @skynet.command() @click.option('--loglevel', '-l', default='INFO', help='Logging level') @click.option( @@ -176,6 +194,7 @@ def clean( ) ) + @skynet.command() @click.option( '--node-url', '-n', default='https://skynet.ancap.tech') @@ -193,6 +212,7 @@ def queue(node_url: str): ) print(json.dumps(resp.json(), indent=4)) + @skynet.command() @click.option( '--node-url', '-n', default='https://skynet.ancap.tech') @@ -211,6 +231,7 @@ def status(node_url: str, request_id: int): ) print(json.dumps(resp.json(), indent=4)) + @skynet.command() @click.option( '--account', '-a', default='telegram') @@ -236,12 +257,14 @@ def dequeue( with open_cleos(node_url, key=key) as cleos: ec, out = cleos.push_action( - 'telos.gpu', 'dequeue', [account, request_id], f'{account}@{permission}' + 'telos.gpu', 'dequeue', [ + account, request_id], f'{account}@{permission}' ) print(collect_stdout(out)) assert ec == 0 + @skynet.command() @click.option( '--account', '-a', default='telos.gpu') @@ -276,6 +299,7 @@ def config( print(collect_stdout(out)) assert ec == 0 + @skynet.command() @click.option( '--account', '-a', default='telegram') @@ -304,10 +328,12 @@ def deposit( print(collect_stdout(out)) assert ec == 0 + @skynet.group() def run(*args, **kwargs): pass + @run.command() def db(): logging.basicConfig(level=logging.INFO) @@ -315,12 +341,14 @@ def db(): container, passwd, host = db_params logging.info(('skynet', passwd, host)) + @run.command() def nodeos(): logging.basicConfig(filename='skynet-nodeos.log', level=logging.INFO) with open_nodeos(cleanup=False): ... + @run.command() @click.option('--loglevel', '-l', default='INFO', help='Logging level') @click.option( @@ -397,7 +425,6 @@ def telegram( async with frontend.open(): await frontend.bot.infinity_polling() - asyncio.run(_async_main()) @@ -411,6 +438,7 @@ def ipfs(loglevel, name): with open_ipfs_node(name=name): ... + @run.command() @click.option('--loglevel', '-l', default='INFO', help='logging level') @click.option( diff --git a/skynet/utils.py b/skynet/utils.py index 2837118..e522e7e 100644 --- a/skynet/utils.py +++ b/skynet/utils.py @@ -19,6 +19,7 @@ from diffusers import ( StableDiffusionImg2ImgPipeline, EulerAncestralDiscreteScheduler ) +from tansformers import pipeline, Conversation from realesrgan import RealESRGANer from huggingface_hub import login @@ -165,6 +166,43 @@ def img2img( image.save(output) +def txt2txt( + hf_token: str, + # TODO: change this to actual model ref + # add more granular control of models + model: str = 'microsoft/DialoGPT-small', + prompt: str = 'a red old tractor in a sunny wheat field', + output: str = 'output.txt', + temperature: float = 1.0, + max_length: int = 256, +): + assert torch.cuda.is_available() + torch.cuda.empty_cache() + torch.cuda.set_per_process_memory_fraction(1.0) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + login(token=hf_token) + chatbot = pipeline('text-generation', model=model, device_map='auto') + + prompt = prompt + conversation = Conversation(prompt) + conversation = chatbot( + conversation, + max_length=max_length, + do_sample=True, + temperature=temperature + ) + response = conversation.generated_responses[-1] + with open(output, 'w', encoding='utf-8') as f: + f.write(response) + + # This if for continued conversatin, need to figure out how to store convo + # conversation.add_user_input("Is it an action movie?") + # conversation = chatbot(conversation) + # conversation.generated_responses[-1] + + def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'): return RealESRGANer( scale=4, @@ -181,6 +219,7 @@ def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'): half=True ) + def upscale( img_path: str = 'input.png', output: str = 'output.png', @@ -201,7 +240,6 @@ def upscale( image = convert_from_cv2_to_image(up_img) - image.save(output)