mirror of https://github.com/skygpu/skynet.git
				
				
				
			start adding txt2txt model support
							parent
							
								
									cbc9a89bb8
								
							
						
					
					
						commit
						e2ea9eb47c
					
				|  | @ -46,6 +46,7 @@ def txt2img(*args, **kwargs): | ||||||
|     _, hf_token, _ = init_env_from_config() |     _, hf_token, _ = init_env_from_config() | ||||||
|     utils.txt2img(hf_token, **kwargs) |     utils.txt2img(hf_token, **kwargs) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @click.command() | @click.command() | ||||||
| @click.option('--model', '-m', default=list(MODELS.keys())[0]) | @click.option('--model', '-m', default=list(MODELS.keys())[0]) | ||||||
| @click.option( | @click.option( | ||||||
|  | @ -71,6 +72,20 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed): | ||||||
|         seed=seed |         seed=seed | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  | @click.command() | ||||||
|  | @click.option('--model', '-m', default='microsoft/DialoGPT-small') | ||||||
|  | @click.option( | ||||||
|  |     '--prompt', '-p', default='a red old tractor in a sunny wheat field') | ||||||
|  | @click.option('--output', '-o', default='output.txt') | ||||||
|  | @click.option('--temperature', '-t', default=1.0) | ||||||
|  | @click.option('--max-length', '-ml', default=256) | ||||||
|  | def txt2txt(*args, **kwargs): | ||||||
|  |     from . import utils | ||||||
|  |     _, hf_token, _, cfg = init_env_from_config() | ||||||
|  |     utils.txt2txt(hf_token, **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @click.command() | @click.command() | ||||||
| @click.option('--input', '-i', default='input.png') | @click.option('--input', '-i', default='input.png') | ||||||
| @click.option('--output', '-o', default='output.png') | @click.option('--output', '-o', default='output.png') | ||||||
|  | @ -89,6 +104,7 @@ def download(): | ||||||
|     _, hf_token, _ = init_env_from_config() |     _, hf_token, _ = init_env_from_config() | ||||||
|     utils.download_all_models(hf_token) |     utils.download_all_models(hf_token) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @skynet.command() | @skynet.command() | ||||||
| @click.option( | @click.option( | ||||||
|     '--account', '-A', default=None) |     '--account', '-A', default=None) | ||||||
|  | @ -135,12 +151,14 @@ def enqueue( | ||||||
|         binary = '' |         binary = '' | ||||||
| 
 | 
 | ||||||
|         ec, out = cleos.push_action( |         ec, out = cleos.push_action( | ||||||
|             'telos.gpu', 'enqueue', [account, req, binary, reward], f'{account}@{permission}' |             'telos.gpu', 'enqueue', [account, req, | ||||||
|  |                                      binary, reward], f'{account}@{permission}' | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         print(collect_stdout(out)) |         print(collect_stdout(out)) | ||||||
|         assert ec == 0 |         assert ec == 0 | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @skynet.command() | @skynet.command() | ||||||
| @click.option('--loglevel', '-l', default='INFO', help='Logging level') | @click.option('--loglevel', '-l', default='INFO', help='Logging level') | ||||||
| @click.option( | @click.option( | ||||||
|  | @ -176,6 +194,7 @@ def clean( | ||||||
|         ) |         ) | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @skynet.command() | @skynet.command() | ||||||
| @click.option( | @click.option( | ||||||
|     '--node-url', '-n', default='https://skynet.ancap.tech') |     '--node-url', '-n', default='https://skynet.ancap.tech') | ||||||
|  | @ -193,6 +212,7 @@ def queue(node_url: str): | ||||||
|     ) |     ) | ||||||
|     print(json.dumps(resp.json(), indent=4)) |     print(json.dumps(resp.json(), indent=4)) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @skynet.command() | @skynet.command() | ||||||
| @click.option( | @click.option( | ||||||
|     '--node-url', '-n', default='https://skynet.ancap.tech') |     '--node-url', '-n', default='https://skynet.ancap.tech') | ||||||
|  | @ -211,6 +231,7 @@ def status(node_url: str, request_id: int): | ||||||
|     ) |     ) | ||||||
|     print(json.dumps(resp.json(), indent=4)) |     print(json.dumps(resp.json(), indent=4)) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @skynet.command() | @skynet.command() | ||||||
| @click.option( | @click.option( | ||||||
|     '--account', '-a', default='telegram') |     '--account', '-a', default='telegram') | ||||||
|  | @ -236,12 +257,14 @@ def dequeue( | ||||||
| 
 | 
 | ||||||
|     with open_cleos(node_url, key=key) as cleos: |     with open_cleos(node_url, key=key) as cleos: | ||||||
|         ec, out = cleos.push_action( |         ec, out = cleos.push_action( | ||||||
|             'telos.gpu', 'dequeue', [account, request_id], f'{account}@{permission}' |             'telos.gpu', 'dequeue', [ | ||||||
|  |                 account, request_id], f'{account}@{permission}' | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         print(collect_stdout(out)) |         print(collect_stdout(out)) | ||||||
|         assert ec == 0 |         assert ec == 0 | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @skynet.command() | @skynet.command() | ||||||
| @click.option( | @click.option( | ||||||
|     '--account', '-a', default='telos.gpu') |     '--account', '-a', default='telos.gpu') | ||||||
|  | @ -276,6 +299,7 @@ def config( | ||||||
|         print(collect_stdout(out)) |         print(collect_stdout(out)) | ||||||
|         assert ec == 0 |         assert ec == 0 | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @skynet.command() | @skynet.command() | ||||||
| @click.option( | @click.option( | ||||||
|     '--account', '-a', default='telegram') |     '--account', '-a', default='telegram') | ||||||
|  | @ -304,10 +328,12 @@ def deposit( | ||||||
|         print(collect_stdout(out)) |         print(collect_stdout(out)) | ||||||
|         assert ec == 0 |         assert ec == 0 | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @skynet.group() | @skynet.group() | ||||||
| def run(*args, **kwargs): | def run(*args, **kwargs): | ||||||
|     pass |     pass | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @run.command() | @run.command() | ||||||
| def db(): | def db(): | ||||||
|     logging.basicConfig(level=logging.INFO) |     logging.basicConfig(level=logging.INFO) | ||||||
|  | @ -315,12 +341,14 @@ def db(): | ||||||
|         container, passwd, host = db_params |         container, passwd, host = db_params | ||||||
|         logging.info(('skynet', passwd, host)) |         logging.info(('skynet', passwd, host)) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @run.command() | @run.command() | ||||||
| def nodeos(): | def nodeos(): | ||||||
|     logging.basicConfig(filename='skynet-nodeos.log', level=logging.INFO) |     logging.basicConfig(filename='skynet-nodeos.log', level=logging.INFO) | ||||||
|     with open_nodeos(cleanup=False): |     with open_nodeos(cleanup=False): | ||||||
|         ... |         ... | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @run.command() | @run.command() | ||||||
| @click.option('--loglevel', '-l', default='INFO', help='Logging level') | @click.option('--loglevel', '-l', default='INFO', help='Logging level') | ||||||
| @click.option( | @click.option( | ||||||
|  | @ -397,7 +425,6 @@ def telegram( | ||||||
|         async with frontend.open(): |         async with frontend.open(): | ||||||
|             await frontend.bot.infinity_polling() |             await frontend.bot.infinity_polling() | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|     asyncio.run(_async_main()) |     asyncio.run(_async_main()) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -411,6 +438,7 @@ def ipfs(loglevel, name): | ||||||
|     with open_ipfs_node(name=name): |     with open_ipfs_node(name=name): | ||||||
|         ... |         ... | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @run.command() | @run.command() | ||||||
| @click.option('--loglevel', '-l', default='INFO', help='logging level') | @click.option('--loglevel', '-l', default='INFO', help='logging level') | ||||||
| @click.option( | @click.option( | ||||||
|  |  | ||||||
|  | @ -19,6 +19,7 @@ from diffusers import ( | ||||||
|     StableDiffusionImg2ImgPipeline, |     StableDiffusionImg2ImgPipeline, | ||||||
|     EulerAncestralDiscreteScheduler |     EulerAncestralDiscreteScheduler | ||||||
| ) | ) | ||||||
|  | from tansformers import pipeline, Conversation | ||||||
| from realesrgan import RealESRGANer | from realesrgan import RealESRGANer | ||||||
| from huggingface_hub import login | from huggingface_hub import login | ||||||
| 
 | 
 | ||||||
|  | @ -165,6 +166,43 @@ def img2img( | ||||||
|     image.save(output) |     image.save(output) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def txt2txt( | ||||||
|  |     hf_token: str, | ||||||
|  |     # TODO: change this to actual model ref | ||||||
|  |     #       add more granular control of models | ||||||
|  |     model: str = 'microsoft/DialoGPT-small', | ||||||
|  |     prompt: str = 'a red old tractor in a sunny wheat field', | ||||||
|  |     output: str = 'output.txt', | ||||||
|  |     temperature: float = 1.0, | ||||||
|  |     max_length: int = 256, | ||||||
|  | ): | ||||||
|  |     assert torch.cuda.is_available() | ||||||
|  |     torch.cuda.empty_cache() | ||||||
|  |     torch.cuda.set_per_process_memory_fraction(1.0) | ||||||
|  |     torch.backends.cuda.matmul.allow_tf32 = True | ||||||
|  |     torch.backends.cudnn.allow_tf32 = True | ||||||
|  | 
 | ||||||
|  |     login(token=hf_token) | ||||||
|  |     chatbot = pipeline('text-generation', model=model, device_map='auto') | ||||||
|  | 
 | ||||||
|  |     prompt = prompt | ||||||
|  |     conversation = Conversation(prompt) | ||||||
|  |     conversation = chatbot( | ||||||
|  |         conversation, | ||||||
|  |         max_length=max_length, | ||||||
|  |         do_sample=True, | ||||||
|  |         temperature=temperature | ||||||
|  |     ) | ||||||
|  |     response = conversation.generated_responses[-1] | ||||||
|  |     with open(output, 'w', encoding='utf-8') as f: | ||||||
|  |         f.write(response) | ||||||
|  | 
 | ||||||
|  |     # This if for continued conversatin, need to figure out how to store convo | ||||||
|  |     # conversation.add_user_input("Is it an action movie?") | ||||||
|  |     # conversation = chatbot(conversation) | ||||||
|  |     # conversation.generated_responses[-1] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'): | def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'): | ||||||
|     return RealESRGANer( |     return RealESRGANer( | ||||||
|         scale=4, |         scale=4, | ||||||
|  | @ -181,6 +219,7 @@ def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'): | ||||||
|         half=True |         half=True | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| def upscale( | def upscale( | ||||||
|     img_path: str = 'input.png', |     img_path: str = 'input.png', | ||||||
|     output: str = 'output.png', |     output: str = 'output.png', | ||||||
|  | @ -201,7 +240,6 @@ def upscale( | ||||||
| 
 | 
 | ||||||
|     image = convert_from_cv2_to_image(up_img) |     image = convert_from_cv2_to_image(up_img) | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|     image.save(output) |     image.save(output) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue