mirror of https://github.com/skygpu/skynet.git
				
				
				
			Finish telegram frontend integration
Correctly update user stats after a txt2img request Swap tg_id to BIGINT on database Refactor help topic system to be more generic Fix skynet brain cli entry point Fix multi request rpc session, we werent creating a context on the req side Fix redo method DGPU now returns image metadata Improve error messages brain to frontend Add version as constant Add telegram dep to requirementspull/2/head
							parent
							
								
									896b0f684b
								
							
						
					
					
						commit
						cb92aed51c
					
				| 
						 | 
					@ -5,3 +5,4 @@ aiohttp
 | 
				
			||||||
msgspec
 | 
					msgspec
 | 
				
			||||||
pyOpenSSL
 | 
					pyOpenSSL
 | 
				
			||||||
trio_asyncio
 | 
					trio_asyncio
 | 
				
			||||||
 | 
					pyTelegramBotAPI
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										4
									
								
								setup.py
								
								
								
								
							
							
						
						
									
										4
									
								
								setup.py
								
								
								
								
							| 
						 | 
					@ -1,8 +1,10 @@
 | 
				
			||||||
from setuptools import setup, find_packages
 | 
					from setuptools import setup, find_packages
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from skynet.constants import VERSION
 | 
				
			||||||
 | 
					
 | 
				
			||||||
setup(
 | 
					setup(
 | 
				
			||||||
    name='skynet',
 | 
					    name='skynet',
 | 
				
			||||||
    version='0.1.0a6',
 | 
					    version=VERSION,
 | 
				
			||||||
    description='Decentralized compute platform',
 | 
					    description='Decentralized compute platform',
 | 
				
			||||||
    author='Guillermo Rodriguez',
 | 
					    author='Guillermo Rodriguez',
 | 
				
			||||||
    author_email='guillermo@telos.net',
 | 
					    author_email='guillermo@telos.net',
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -4,6 +4,7 @@ import json
 | 
				
			||||||
import uuid
 | 
					import uuid
 | 
				
			||||||
import base64
 | 
					import base64
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
 | 
					import traceback
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from uuid import UUID
 | 
					from uuid import UUID
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
| 
						 | 
					@ -90,10 +91,10 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
 | 
				
			||||||
        logging.info(f'pre next_worker: {next_worker}')
 | 
					        logging.info(f'pre next_worker: {next_worker}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if next_worker == None:
 | 
					        if next_worker == None:
 | 
				
			||||||
            raise SkynetDGPUOffline
 | 
					            raise SkynetDGPUOffline('No workers connected, try again later')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if are_all_workers_busy():
 | 
					        if are_all_workers_busy():
 | 
				
			||||||
            raise SkynetDGPUOverloaded
 | 
					            raise SkynetDGPUOverloaded('All workers are busy at the moment')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        nid = list(nodes.keys())[next_worker]
 | 
					        nid = list(nodes.keys())[next_worker]
 | 
				
			||||||
| 
						 | 
					@ -175,6 +176,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
 | 
				
			||||||
        with trio.move_on_after(30):
 | 
					        with trio.move_on_after(30):
 | 
				
			||||||
            await img_event.wait()
 | 
					            await img_event.wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        logging.info(f'img event: {ack_event.is_set()}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not img_event.is_set():
 | 
					        if not img_event.is_set():
 | 
				
			||||||
            disconnect_node(nid)
 | 
					            disconnect_node(nid)
 | 
				
			||||||
            raise SkynetDGPUComputeError('30 seconds timeout while processing request')
 | 
					            raise SkynetDGPUComputeError('30 seconds timeout while processing request')
 | 
				
			||||||
| 
						 | 
					@ -187,7 +190,7 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
 | 
				
			||||||
        if 'error' in img_resp.params:
 | 
					        if 'error' in img_resp.params:
 | 
				
			||||||
            raise SkynetDGPUComputeError(img_resp.params['error'])
 | 
					            raise SkynetDGPUComputeError(img_resp.params['error'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return rid, img_resp.params['img']
 | 
					        return rid, img_resp.params['img'], img_resp.params['meta']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def handle_user_request(rpc_ctx, req):
 | 
					    async def handle_user_request(rpc_ctx, req):
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
| 
						 | 
					@ -202,39 +205,54 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
 | 
				
			||||||
                        user_config = {**(await get_user_config(conn, user))}
 | 
					                        user_config = {**(await get_user_config(conn, user))}
 | 
				
			||||||
                        del user_config['id']
 | 
					                        del user_config['id']
 | 
				
			||||||
                        prompt = req.params['prompt']
 | 
					                        prompt = req.params['prompt']
 | 
				
			||||||
                        user_config= {
 | 
					 | 
				
			||||||
                            key : req.params.get(key, val) 
 | 
					 | 
				
			||||||
                            for key, val in user_config.items()
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                        req = ImageGenRequest(
 | 
					                        req = ImageGenRequest(
 | 
				
			||||||
                            prompt=prompt,
 | 
					                            prompt=prompt,
 | 
				
			||||||
                            **user_config
 | 
					                            **user_config
 | 
				
			||||||
                        )
 | 
					                        )
 | 
				
			||||||
                        rid, img = await dgpu_stream_one_img(req)
 | 
					                        rid, img, meta = await dgpu_stream_one_img(req)
 | 
				
			||||||
 | 
					                        logging.info(f'done streaming {rid}')
 | 
				
			||||||
                        result = {
 | 
					                        result = {
 | 
				
			||||||
                            'id': rid,
 | 
					                            'id': rid,
 | 
				
			||||||
                            'img': img
 | 
					                            'img': img,
 | 
				
			||||||
 | 
					                            'meta': meta
 | 
				
			||||||
                        }
 | 
					                        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        await update_user_stats(conn, user, last_prompt=prompt)
 | 
				
			||||||
 | 
					                        logging.info('updated user stats.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    case 'redo':
 | 
					                    case 'redo':
 | 
				
			||||||
                        logging.info('redo')
 | 
					                        logging.info('redo')
 | 
				
			||||||
                        user_config = await get_user_config(conn, user)
 | 
					                        user_config = {**(await get_user_config(conn, user))}
 | 
				
			||||||
 | 
					                        del user_config['id']
 | 
				
			||||||
                        prompt = await get_last_prompt_of(conn, user)
 | 
					                        prompt = await get_last_prompt_of(conn, user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        if prompt:
 | 
				
			||||||
                            req = ImageGenRequest(
 | 
					                            req = ImageGenRequest(
 | 
				
			||||||
                                prompt=prompt,
 | 
					                                prompt=prompt,
 | 
				
			||||||
                                **user_config
 | 
					                                **user_config
 | 
				
			||||||
                            )
 | 
					                            )
 | 
				
			||||||
                        rid, img = await dgpu_stream_one_img(req)
 | 
					                            rid, img, meta = await dgpu_stream_one_img(req)
 | 
				
			||||||
                            result = {
 | 
					                            result = {
 | 
				
			||||||
                                'id': rid,
 | 
					                                'id': rid,
 | 
				
			||||||
                            'img': img
 | 
					                                'img': img,
 | 
				
			||||||
 | 
					                                'meta': meta
 | 
				
			||||||
 | 
					                            }
 | 
				
			||||||
 | 
					                            await update_user_stats(conn, user)
 | 
				
			||||||
 | 
					                            logging.info('updated user stats.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        else:
 | 
				
			||||||
 | 
					                            result = {
 | 
				
			||||||
 | 
					                                'error': 'skynet_no_last_prompt',
 | 
				
			||||||
 | 
					                                'message': 'No prompt to redo, do txt2img first'
 | 
				
			||||||
                            }
 | 
					                            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    case 'config':
 | 
					                    case 'config':
 | 
				
			||||||
                        logging.info('config')
 | 
					                        logging.info('config')
 | 
				
			||||||
                        if req.params['attr'] in CONFIG_ATTRS:
 | 
					                        if req.params['attr'] in CONFIG_ATTRS:
 | 
				
			||||||
 | 
					                            logging.info(f'update: {req.params}')
 | 
				
			||||||
                            await update_user_config(
 | 
					                            await update_user_config(
 | 
				
			||||||
                                conn, user, req.params['attr'], req.params['val'])
 | 
					                                conn, user, req.params['attr'], req.params['val'])
 | 
				
			||||||
 | 
					                            logging.info('done')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    case 'stats':
 | 
					                    case 'stats':
 | 
				
			||||||
                        logging.info('stats')
 | 
					                        logging.info('stats')
 | 
				
			||||||
| 
						 | 
					@ -255,9 +273,10 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
 | 
				
			||||||
                'message': str(e)
 | 
					                'message': str(e)
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        except SkynetDGPUOverloaded:
 | 
					        except SkynetDGPUOverloaded as e:
 | 
				
			||||||
            result = {
 | 
					            result = {
 | 
				
			||||||
                'error': 'skynet_dgpu_overloaded',
 | 
					                'error': 'skynet_dgpu_overloaded',
 | 
				
			||||||
 | 
					                'message': str(e),
 | 
				
			||||||
                'nodes': len(nodes)
 | 
					                'nodes': len(nodes)
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -266,14 +285,23 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
 | 
				
			||||||
                'error': 'skynet_dgpu_compute_error',
 | 
					                'error': 'skynet_dgpu_compute_error',
 | 
				
			||||||
                'message': str(e)
 | 
					                'message': str(e)
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					        except BaseException as e:
 | 
				
			||||||
 | 
					            traceback.print_exception(type(e), e, e.__traceback__)
 | 
				
			||||||
 | 
					            result = {
 | 
				
			||||||
 | 
					                'error': 'skynet_internal_error',
 | 
				
			||||||
 | 
					                'message': str(e)
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        resp = SkynetRPCResponse(result=result)
 | 
					        resp = SkynetRPCResponse(result=result)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if security:
 | 
					        if security:
 | 
				
			||||||
            resp.sign(tls_key, 'skynet')
 | 
					            resp.sign(tls_key, 'skynet')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        logging.info('sending response')
 | 
				
			||||||
        await rpc_ctx.asend(
 | 
					        await rpc_ctx.asend(
 | 
				
			||||||
            json.dumps(resp.to_dict()).encode())
 | 
					            json.dumps(resp.to_dict()).encode())
 | 
				
			||||||
 | 
					        rpc_ctx.close()
 | 
				
			||||||
 | 
					        logging.info('done')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def request_service(n):
 | 
					    async def request_service(n):
 | 
				
			||||||
        nonlocal next_worker
 | 
					        nonlocal next_worker
 | 
				
			||||||
| 
						 | 
					@ -329,6 +357,8 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
 | 
				
			||||||
            await ctx.asend(
 | 
					            await ctx.asend(
 | 
				
			||||||
                json.dumps(resp.to_dict()).encode())
 | 
					                json.dumps(resp.to_dict()).encode())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            ctx.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async with trio.open_nursery() as n:
 | 
					    async with trio.open_nursery() as n:
 | 
				
			||||||
        n.start_soon(dgpu_image_streamer)
 | 
					        n.start_soon(dgpu_image_streamer)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -8,10 +8,12 @@ from functools import partial
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import trio
 | 
					import trio
 | 
				
			||||||
import click
 | 
					import click
 | 
				
			||||||
 | 
					import trio_asyncio
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from . import utils
 | 
					from . import utils
 | 
				
			||||||
from .dgpu import open_dgpu_node
 | 
					from .dgpu import open_dgpu_node
 | 
				
			||||||
from .brain import run_skynet
 | 
					from .brain import run_skynet
 | 
				
			||||||
 | 
					from .constants import ALGOS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .frontend.telegram import run_skynet_telegram
 | 
					from .frontend.telegram import run_skynet_telegram
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -61,16 +63,16 @@ def run(*args, **kwargs):
 | 
				
			||||||
@click.option(
 | 
					@click.option(
 | 
				
			||||||
    '--host', '-h', default='localhost:5432')
 | 
					    '--host', '-h', default='localhost:5432')
 | 
				
			||||||
@click.option(
 | 
					@click.option(
 | 
				
			||||||
    '--pass', '-p', default='password')
 | 
					    '--passwd', '-p', default='password')
 | 
				
			||||||
def skynet(
 | 
					def brain(
 | 
				
			||||||
    loglevel: str,
 | 
					    loglevel: str,
 | 
				
			||||||
    host: str,
 | 
					    host: str,
 | 
				
			||||||
    passw: str
 | 
					    passwd: str
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    async def _run_skynet():
 | 
					    async def _run_skynet():
 | 
				
			||||||
        async with run_skynet(
 | 
					        async with run_skynet(
 | 
				
			||||||
            db_host=host,
 | 
					            db_host=host,
 | 
				
			||||||
            db_pass=passw
 | 
					            db_pass=passwd
 | 
				
			||||||
        ):
 | 
					        ):
 | 
				
			||||||
            await trio.sleep_forever()
 | 
					            await trio.sleep_forever()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -86,13 +88,13 @@ def skynet(
 | 
				
			||||||
@click.option(
 | 
					@click.option(
 | 
				
			||||||
    '--cert', '-c', default='whitelist/dgpu')
 | 
					    '--cert', '-c', default='whitelist/dgpu')
 | 
				
			||||||
@click.option(
 | 
					@click.option(
 | 
				
			||||||
    '--algos', '-a', default=None)
 | 
					    '--algos', '-a', default=json.dumps(['midj']))
 | 
				
			||||||
def dgpu(
 | 
					def dgpu(
 | 
				
			||||||
    loglevel: str,
 | 
					    loglevel: str,
 | 
				
			||||||
    uid: str,
 | 
					    uid: str,
 | 
				
			||||||
    key: str,
 | 
					    key: str,
 | 
				
			||||||
    cert: str,
 | 
					    cert: str,
 | 
				
			||||||
    algos: Optional[str]
 | 
					    algos: str
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    trio.run(
 | 
					    trio.run(
 | 
				
			||||||
        partial(
 | 
					        partial(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,8 +1,10 @@
 | 
				
			||||||
#!/usr/bin/python
 | 
					#!/usr/bin/python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					VERSION = '0.1a6'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
 | 
					DOCKER_RUNTIME_CUDA = 'skynet:runtime-cuda'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
DB_HOST = 'ancap.tech:34508'
 | 
					DB_HOST = 'localhost:5432'
 | 
				
			||||||
DB_USER = 'skynet'
 | 
					DB_USER = 'skynet'
 | 
				
			||||||
DB_PASS = 'password'
 | 
					DB_PASS = 'password'
 | 
				
			||||||
DB_NAME = 'skynet'
 | 
					DB_NAME = 'skynet'
 | 
				
			||||||
| 
						 | 
					@ -21,7 +23,7 @@ ALGOS = {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
N = '\n'
 | 
					N = '\n'
 | 
				
			||||||
HELP_TEXT = f'''
 | 
					HELP_TEXT = f'''
 | 
				
			||||||
test art bot v0.1a4
 | 
					test art bot v{VERSION}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
commands work on a user per user basis!
 | 
					commands work on a user per user basis!
 | 
				
			||||||
config is individual to each user!
 | 
					config is individual to each user!
 | 
				
			||||||
| 
						 | 
					@ -47,7 +49,7 @@ config is individual to each user!
 | 
				
			||||||
/config guidance NUMBER - prompt text importance
 | 
					/config guidance NUMBER - prompt text importance
 | 
				
			||||||
'''
 | 
					'''
 | 
				
			||||||
 | 
					
 | 
				
			||||||
UNKNOWN_CMD_TEXT = 'unknown command! try sending \"/help\"'
 | 
					UNKNOWN_CMD_TEXT = 'Unknown command! Try sending \"/help\"'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
DONATION_INFO = '0xf95335682DF281FFaB7E104EB87B69625d9622B6\ngoal: 25/650usd'
 | 
					DONATION_INFO = '0xf95335682DF281FFaB7E104EB87B69625d9622B6\ngoal: 25/650usd'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -74,22 +76,24 @@ COOL_WORDS = [
 | 
				
			||||||
    'michelangelo'
 | 
					    'michelangelo'
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
HELP_STEP = '''
 | 
					HELP_TOPICS = {
 | 
				
			||||||
diffusion models are iterative processes – a repeated cycle that starts with a\
 | 
					    'step': '''
 | 
				
			||||||
 | 
					Diffusion models are iterative processes – a repeated cycle that starts with a\
 | 
				
			||||||
 random noise generated from text input. With each step, some noise is removed\
 | 
					 random noise generated from text input. With each step, some noise is removed\
 | 
				
			||||||
, resulting in a higher-quality image over time. The repetition stops when the\
 | 
					, resulting in a higher-quality image over time. The repetition stops when the\
 | 
				
			||||||
 desired number of steps completes.
 | 
					 desired number of steps completes.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
around 25 sampling steps are usually enough to achieve high-quality images. Us\
 | 
					Around 25 sampling steps are usually enough to achieve high-quality images. Us\
 | 
				
			||||||
ing more may produce a slightly different picture, but not necessarily better \
 | 
					ing more may produce a slightly different picture, but not necessarily better \
 | 
				
			||||||
quality.
 | 
					quality.
 | 
				
			||||||
'''
 | 
					''',
 | 
				
			||||||
 | 
					
 | 
				
			||||||
HELP_GUIDANCE = '''
 | 
					'guidance': '''
 | 
				
			||||||
the guidance scale is a parameter that controls how much the image generation\
 | 
					The guidance scale is a parameter that controls how much the image generation\
 | 
				
			||||||
 process follows the text prompt. The higher the value, the more image sticks\
 | 
					 process follows the text prompt. The higher the value, the more image sticks\
 | 
				
			||||||
 to a given text input.
 | 
					 to a given text input.
 | 
				
			||||||
'''
 | 
					'''
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
HELP_UNKWNOWN_PARAM = 'don\'t have any info on that.'
 | 
					HELP_UNKWNOWN_PARAM = 'don\'t have any info on that.'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										29
									
								
								skynet/db.py
								
								
								
								
							
							
						
						
									
										29
									
								
								skynet/db.py
								
								
								
								
							| 
						 | 
					@ -2,6 +2,7 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
from datetime import datetime
 | 
					from datetime import datetime
 | 
				
			||||||
from contextlib import asynccontextmanager as acm
 | 
					from contextlib import asynccontextmanager as acm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -19,7 +20,7 @@ CREATE SCHEMA IF NOT EXISTS skynet;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
CREATE TABLE IF NOT EXISTS skynet.user(
 | 
					CREATE TABLE IF NOT EXISTS skynet.user(
 | 
				
			||||||
   id SERIAL PRIMARY KEY NOT NULL,
 | 
					   id SERIAL PRIMARY KEY NOT NULL,
 | 
				
			||||||
   tg_id INT,
 | 
					   tg_id BIGINT,
 | 
				
			||||||
   wp_id VARCHAR(128),
 | 
					   wp_id VARCHAR(128),
 | 
				
			||||||
   mx_id VARCHAR(128),
 | 
					   mx_id VARCHAR(128),
 | 
				
			||||||
   ig_id VARCHAR(128),
 | 
					   ig_id VARCHAR(128),
 | 
				
			||||||
| 
						 | 
					@ -47,7 +48,7 @@ CREATE TABLE IF NOT EXISTS skynet.user_config(
 | 
				
			||||||
    step INT NOT NULL,
 | 
					    step INT NOT NULL,
 | 
				
			||||||
    width INT NOT NULL,
 | 
					    width INT NOT NULL,
 | 
				
			||||||
    height INT NOT NULL,
 | 
					    height INT NOT NULL,
 | 
				
			||||||
    seed INT,
 | 
					    seed BIGINT,
 | 
				
			||||||
    guidance INT NOT NULL,
 | 
					    guidance INT NOT NULL,
 | 
				
			||||||
    upscaler VARCHAR(128)
 | 
					    upscaler VARCHAR(128)
 | 
				
			||||||
);
 | 
					);
 | 
				
			||||||
| 
						 | 
					@ -124,7 +125,7 @@ async def get_user_config(conn, user: int):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def get_last_prompt_of(conn, user: int):
 | 
					async def get_last_prompt_of(conn, user: int):
 | 
				
			||||||
    stms = await conn.prepare(
 | 
					    stmt = await conn.prepare(
 | 
				
			||||||
        'SELECT last_prompt FROM skynet.user WHERE id = $1')
 | 
					        'SELECT last_prompt FROM skynet.user WHERE id = $1')
 | 
				
			||||||
    return await stmt.fetchval(user)
 | 
					    return await stmt.fetchval(user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -198,7 +199,12 @@ async def get_or_create_user(conn, uid: str):
 | 
				
			||||||
    return user
 | 
					    return user
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def update_user(conn, user: int, attr: str, val):
 | 
					async def update_user(conn, user: int, attr: str, val):
 | 
				
			||||||
    ...
 | 
					    stmt = await conn.prepare(f'''
 | 
				
			||||||
 | 
					        UPDATE skynet.user
 | 
				
			||||||
 | 
					        SET {attr} = $2
 | 
				
			||||||
 | 
					        WHERE id = $1
 | 
				
			||||||
 | 
					    ''')
 | 
				
			||||||
 | 
					    await stmt.fetch(user, val)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def update_user_config(conn, user: int, attr: str, val):
 | 
					async def update_user_config(conn, user: int, attr: str, val):
 | 
				
			||||||
    stmt = await conn.prepare(f'''
 | 
					    stmt = await conn.prepare(f'''
 | 
				
			||||||
| 
						 | 
					@ -218,3 +224,18 @@ async def get_user_stats(conn, user: int):
 | 
				
			||||||
    assert len(records) == 1
 | 
					    assert len(records) == 1
 | 
				
			||||||
    record = records[0]
 | 
					    record = records[0]
 | 
				
			||||||
    return record
 | 
					    return record
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def update_user_stats(
 | 
				
			||||||
 | 
					    conn,
 | 
				
			||||||
 | 
					    user: int,
 | 
				
			||||||
 | 
					    last_prompt: Optional[str] = None
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    stmt = await conn.prepare('''
 | 
				
			||||||
 | 
					        UPDATE skynet.user
 | 
				
			||||||
 | 
					        SET generated = generated + 1
 | 
				
			||||||
 | 
					        WHERE id = $1
 | 
				
			||||||
 | 
					    ''')
 | 
				
			||||||
 | 
					    await stmt.fetch(user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if last_prompt:
 | 
				
			||||||
 | 
					        await update_user(conn, user, 'last_prompt', last_prompt)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -109,7 +109,6 @@ async def open_dgpu_node(
 | 
				
			||||||
                'generated': 0
 | 
					                'generated': 0
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        seed = ireq.seed if ireq.seed else random.randint(0, 2 ** 64)
 | 
					 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            image = models[ireq.algo]['pipe'](
 | 
					            image = models[ireq.algo]['pipe'](
 | 
				
			||||||
                ireq.prompt,
 | 
					                ireq.prompt,
 | 
				
			||||||
| 
						 | 
					@ -117,7 +116,7 @@ async def open_dgpu_node(
 | 
				
			||||||
                height=ireq.height,
 | 
					                height=ireq.height,
 | 
				
			||||||
                guidance_scale=ireq.guidance,
 | 
					                guidance_scale=ireq.guidance,
 | 
				
			||||||
                num_inference_steps=ireq.step,
 | 
					                num_inference_steps=ireq.step,
 | 
				
			||||||
                generator=torch.Generator("cuda").manual_seed(seed)
 | 
					                generator=torch.Generator("cuda").manual_seed(ireq.seed)
 | 
				
			||||||
            ).images[0]
 | 
					            ).images[0]
 | 
				
			||||||
            return image.tobytes()
 | 
					            return image.tobytes()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -207,12 +206,18 @@ async def open_dgpu_node(
 | 
				
			||||||
                    logging.info(f'sent ack, processing {req.rid}...')
 | 
					                    logging.info(f'sent ack, processing {req.rid}...')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    try:
 | 
					                    try:
 | 
				
			||||||
                        img = await gpu_compute_one(
 | 
					                        img_req = ImageGenRequest(**req.params)
 | 
				
			||||||
                            ImageGenRequest(**req.params))
 | 
					                        if not img_req.seed:
 | 
				
			||||||
 | 
					                            img_req.seed = random.randint(0, 2 ** 64)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        img = await gpu_compute_one(img_req)
 | 
				
			||||||
                        img_resp = DGPUBusResponse(
 | 
					                        img_resp = DGPUBusResponse(
 | 
				
			||||||
                            rid=req.rid,
 | 
					                            rid=req.rid,
 | 
				
			||||||
                            nid=req.nid,
 | 
					                            nid=req.nid,
 | 
				
			||||||
                            params={'img': base64.b64encode(img).hex()}
 | 
					                            params={
 | 
				
			||||||
 | 
					                                'img': base64.b64encode(img).hex(),
 | 
				
			||||||
 | 
					                                'meta': img_req.to_dict()
 | 
				
			||||||
 | 
					                            }
 | 
				
			||||||
                        )
 | 
					                        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    except DGPUComputeError as e:
 | 
					                    except DGPUComputeError as e:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -90,13 +90,14 @@ async def open_skynet_rpc(
 | 
				
			||||||
            if security:
 | 
					            if security:
 | 
				
			||||||
                req.sign(tls_key, cert_name)
 | 
					                req.sign(tls_key, cert_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            await sock.asend(
 | 
					            ctx = sock.new_context()
 | 
				
			||||||
 | 
					            await ctx.asend(
 | 
				
			||||||
                json.dumps(
 | 
					                json.dumps(
 | 
				
			||||||
                    req.to_dict()).encode())
 | 
					                    req.to_dict()).encode())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            resp = SkynetRPCResponse(
 | 
					            resp = SkynetRPCResponse(
 | 
				
			||||||
                **json.loads(
 | 
					                **json.loads((await ctx.arecv()).decode()))
 | 
				
			||||||
                    (await sock.arecv_msg()).bytes.decode()))
 | 
					            ctx.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if security:
 | 
					            if security:
 | 
				
			||||||
                resp.verify(skynet_cert)
 | 
					                resp.verify(skynet_cert)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,14 +1,19 @@
 | 
				
			||||||
#!/usr/bin/python
 | 
					#!/usr/bin/python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import io
 | 
				
			||||||
 | 
					import base64
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from datetime import datetime
 | 
					from datetime import datetime
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import pynng
 | 
					import pynng
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from telebot.async_telebot import AsyncTeleBot
 | 
					from PIL import Image
 | 
				
			||||||
from trio_asyncio import aio_as_trio
 | 
					from trio_asyncio import aio_as_trio
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from telebot.types import InputFile
 | 
				
			||||||
 | 
					from telebot.async_telebot import AsyncTeleBot
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..constants import *
 | 
					from ..constants import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from . import *
 | 
					from . import *
 | 
				
			||||||
| 
						 | 
					@ -17,6 +22,17 @@ from . import *
 | 
				
			||||||
PREFIX = 'tg'
 | 
					PREFIX = 'tg'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def prepare_metainfo_caption(meta: dict) -> str:
 | 
				
			||||||
 | 
					    meta_str = f'prompt: \"{meta["prompt"]}\"\n'
 | 
				
			||||||
 | 
					    meta_str += f'seed: {meta["seed"]}\n'
 | 
				
			||||||
 | 
					    meta_str += f'step: {meta["step"]}\n'
 | 
				
			||||||
 | 
					    meta_str += f'guidance: {meta["guidance"]}\n'
 | 
				
			||||||
 | 
					    meta_str += f'algo: \"{meta["algo"]}\"\n'
 | 
				
			||||||
 | 
					    meta_str += f'sampler: k_euler_ancestral\n'
 | 
				
			||||||
 | 
					    meta_str += f'skynet v{VERSION}'
 | 
				
			||||||
 | 
					    return meta_str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def run_skynet_telegram(
 | 
					async def run_skynet_telegram(
 | 
				
			||||||
    tg_token: str,
 | 
					    tg_token: str,
 | 
				
			||||||
    key_name: str = 'telegram-frontend',
 | 
					    key_name: str = 'telegram-frontend',
 | 
				
			||||||
| 
						 | 
					@ -26,44 +42,96 @@ async def run_skynet_telegram(
 | 
				
			||||||
    logging.basicConfig(level=logging.INFO)
 | 
					    logging.basicConfig(level=logging.INFO)
 | 
				
			||||||
    bot = AsyncTeleBot(tg_token)
 | 
					    bot = AsyncTeleBot(tg_token)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    with open_skynet_rpc(
 | 
					    async with open_skynet_rpc(
 | 
				
			||||||
        'skynet-telegram-0',
 | 
					        'skynet-telegram-0',
 | 
				
			||||||
        security=True,
 | 
					        security=True,
 | 
				
			||||||
        cert_name=cert,
 | 
					        cert_name=cert_name,
 | 
				
			||||||
        key_name=key
 | 
					        key_name=key_name
 | 
				
			||||||
    ) as rpc_call:
 | 
					    ) as rpc_call:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        async def _rpc_call(
 | 
					        async def _rpc_call(
 | 
				
			||||||
            uid: int,
 | 
					            uid: int,
 | 
				
			||||||
            method: str,
 | 
					            method: str,
 | 
				
			||||||
            params: dict
 | 
					            params: dict = {}
 | 
				
			||||||
        ):
 | 
					        ):
 | 
				
			||||||
            return await rpc_call(
 | 
					            return await rpc_call(
 | 
				
			||||||
                method, params, uid=f'{PREFIX}+{uid}')
 | 
					                method, params, uid=f'{PREFIX}+{uid}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @bot.message_handler(commands=['help'])
 | 
					        @bot.message_handler(commands=['help'])
 | 
				
			||||||
        async def send_help(message):
 | 
					        async def send_help(message):
 | 
				
			||||||
 | 
					            splt_msg = message.text.split(' ')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if len(splt_msg) == 1:
 | 
				
			||||||
                await bot.reply_to(message, HELP_TEXT)
 | 
					                await bot.reply_to(message, HELP_TEXT)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                param = splt_msg[1]
 | 
				
			||||||
 | 
					                if param in HELP_TOPICS:
 | 
				
			||||||
 | 
					                    await bot.reply_to(message, HELP_TOPICS[param])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    await bot.reply_to(message, HELP_UNKWNOWN_PARAM)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @bot.message_handler(commands=['cool'])
 | 
					        @bot.message_handler(commands=['cool'])
 | 
				
			||||||
        async def send_cool_words(message):
 | 
					        async def send_cool_words(message):
 | 
				
			||||||
            await bot.reply_to(message, '\n'.join(COOL_WORDS))
 | 
					            await bot.reply_to(message, '\n'.join(COOL_WORDS))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @bot.message_handler(commands=['txt2img'])
 | 
					        @bot.message_handler(commands=['txt2img'])
 | 
				
			||||||
        async def send_txt2img(message):
 | 
					        async def send_txt2img(message):
 | 
				
			||||||
 | 
					            prompt = ' '.join(message.text.split(' ')[1:])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if len(prompt) == 0:
 | 
				
			||||||
 | 
					                await bot.reply_to(message, 'Empty text prompt ignored.')
 | 
				
			||||||
 | 
					                return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            logging.info(f'mid: {message.id}')
 | 
				
			||||||
            resp = await _rpc_call(
 | 
					            resp = await _rpc_call(
 | 
				
			||||||
                message.from_user.id,
 | 
					                message.from_user.id,
 | 
				
			||||||
                'txt2img',
 | 
					                'txt2img',
 | 
				
			||||||
                {}
 | 
					                {'prompt': prompt}
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					            logging.info(f'resp to {message.id} arrived')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            resp_txt = ''
 | 
				
			||||||
 | 
					            if 'error' in resp.result:
 | 
				
			||||||
 | 
					                resp_txt = resp.result['message']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                logging.info(resp.result['id'])
 | 
				
			||||||
 | 
					                img_raw = base64.b64decode(bytes.fromhex(resp.result['img']))
 | 
				
			||||||
 | 
					                img = Image.frombytes('RGB', (512, 512), img_raw)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                await bot.send_photo(
 | 
				
			||||||
 | 
					                    message.chat.id,
 | 
				
			||||||
 | 
					                    caption=prepare_metainfo_caption(resp.result['meta']),
 | 
				
			||||||
 | 
					                    photo=img,
 | 
				
			||||||
 | 
					                    reply_to_message_id=message.id
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            await bot.reply_to(message, resp_txt)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @bot.message_handler(commands=['redo'])
 | 
					        @bot.message_handler(commands=['redo'])
 | 
				
			||||||
        async def redo_txt2img(message):
 | 
					        async def redo_txt2img(message):
 | 
				
			||||||
            resp = await _rpc_call(
 | 
					            resp = await _rpc_call(message.from_user.id, 'redo')
 | 
				
			||||||
                message.from_user.id,
 | 
					
 | 
				
			||||||
                'redo',
 | 
					            resp_txt = ''
 | 
				
			||||||
                {}
 | 
					            if 'error' in resp.result:
 | 
				
			||||||
 | 
					                resp_txt = resp.result['message']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                img_raw = base64.b64decode(bytes.fromhex(resp.result['img']))
 | 
				
			||||||
 | 
					                img = Image.frombytes('RGB', (512, 512), img_raw)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                await bot.send_photo(
 | 
				
			||||||
 | 
					                    message.chat.id,
 | 
				
			||||||
 | 
					                    caption=prepare_metainfo_caption(resp.result['meta']),
 | 
				
			||||||
 | 
					                    photo=img,
 | 
				
			||||||
 | 
					                    reply_to_message_id=message.id
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
					                return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            await bot.reply_to(message, resp_txt)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @bot.message_handler(commands=['config'])
 | 
					        @bot.message_handler(commands=['config'])
 | 
				
			||||||
        async def set_config(message):
 | 
					        async def set_config(message):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue