From 2b18fa376be3a81671baaedce221bae4b4497c07 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Mon, 29 May 2023 12:42:55 -0300 Subject: [PATCH] Add redo support to img2img also switch pinner to use http api --- skynet/cli.py | 97 ++++++++++++++----------- skynet/db/functions.py | 60 ++++++++++++---- skynet/frontend/telegram.py | 140 ++++++++++++++++++++++-------------- 3 files changed, 189 insertions(+), 108 deletions(-) diff --git a/skynet/cli.py b/skynet/cli.py index a5d4601..1d19ae3 100644 --- a/skynet/cli.py +++ b/skynet/cli.py @@ -384,22 +384,66 @@ def telegram( )) +class IPFSHTTP: + + def __init__(self, endpoint: str): + self.endpoint = endpoint + + def pin(self, cid: str): + return requests.post( + f'{self.endpoint}/api/v0/pin/add', + params={'arg': cid} + ) + + @run.command() @click.option('--loglevel', '-l', default='INFO', help='logging level') @click.option( - '--container', '-c', default='ipfs_host') + '--ipfs-rpc', '-i', default='http://127.0.0.1:5001') @click.option( '--hyperion-url', '-n', default='http://127.0.0.1:42001') -def pinner(loglevel, container, hyperion_url): +def pinner(loglevel, ipfs_rpc, hyperion_url): logging.basicConfig(level=loglevel) - dclient = docker.from_env() - - container = dclient.containers.get(container) - ipfs_node = IPFSDocker(container) + ipfs_node = IPFSHTTP(ipfs_rpc) hyperion = HyperionAPI(hyperion_url) last_pinned: dict[str, datetime] = {} + def capture_enqueues(half_min_ago: datetime): + # get all enqueues with binary data + # in the last minute + enqueues = hyperion.get_actions( + account='telos.gpu', + filter='telos.gpu:enqueue', + sort='desc', + after=half_min_ago.isoformat() + ) + + cids = [] + for action in enqueues['actions']: + cid = action['act']['data']['binary_data'] + if cid and cid not in last_pinned: + cids.append(cid) + + return cids + + def capture_submits(half_min_ago: datetime): + # get all submits in the last minute + submits = hyperion.get_actions( + account='telos.gpu', + filter='telos.gpu:submit', + sort='desc', + after=half_min_ago.isoformat() + ) + + cids = [] + for action in submits['actions']: + cid = action['act']['data']['ipfs_hash'] + if cid and cid not in last_pinned: + cids.append(cid) + + return cids + def cleanup_pinned(now: datetime): for cid in set(last_pinned.keys()): ts = last_pinned[cid] @@ -411,50 +455,23 @@ def pinner(loglevel, container, hyperion_url): now = datetime.now() half_min_ago = now - timedelta(seconds=30) - # get all enqueues with binary data - # in the last minute - enqueues = hyperion.get_actions( - account='telos.gpu', - filter='telos.gpu:enqueue', - sort='desc', - after=half_min_ago.isoformat() - ) - - # get all submits in the last minute - submits = hyperion.get_actions( - account='telos.gpu', - filter='telos.gpu:submit', - sort='desc', - after=half_min_ago.isoformat() - ) - # filter for the ones not already pinned - cids = [ - *[ - action['act']['data']['binary_data'] - for action in enqueues['actions'] - if action['act']['data']['binary_data'] - not in last_pinned - ], - *[ - action['act']['data']['ipfs_hash'] - for action in submits['actions'] - if action['act']['data']['ipfs_hash'] - not in last_pinned - ] - ] + cids = [*capture_enqueues(half_min_ago), *capture_submits(half_min_ago)] # pin and remember for cid in cids: last_pinned[cid] = now - ipfs_node.pin(cid) + resp = ipfs_node.pin(cid) + if resp.status_code != 200: + logging.error(f'error pinning {cid}:\n{resp.text}') - logging.info(f'pinned {cid}') + else: + logging.info(f'pinned {cid}') cleanup_pinned(now) - time.sleep(1) + time.sleep(0.1) except KeyboardInterrupt: ... diff --git a/skynet/db/functions.py b/skynet/db/functions.py index 2b8e192..08067ae 100644 --- a/skynet/db/functions.py +++ b/skynet/db/functions.py @@ -25,11 +25,14 @@ DB_INIT_SQL = ''' CREATE SCHEMA IF NOT EXISTS skynet; CREATE TABLE IF NOT EXISTS skynet.user( - id SERIAL PRIMARY KEY NOT NULL, - generated INT NOT NULL, - joined TIMESTAMP NOT NULL, - last_prompt TEXT, - role VARCHAR(128) NOT NULL + id SERIAL PRIMARY KEY NOT NULL, + generated INT NOT NULL, + joined TIMESTAMP NOT NULL, + last_method TEXT, + last_prompt TEXT, + last_file TEXT, + last_binary TEXT, + role VARCHAR(128) NOT NULL ); CREATE TABLE IF NOT EXISTS skynet.user_config( @@ -175,12 +178,26 @@ async def get_user_config(conn, user: int): async def get_user(conn, uid: int): return await get_user_config(conn, uid) +async def get_last_method_of(conn, user: int): + stmt = await conn.prepare( + 'SELECT last_method FROM skynet.user WHERE id = $1') + return await stmt.fetchval(user) async def get_last_prompt_of(conn, user: int): stmt = await conn.prepare( 'SELECT last_prompt FROM skynet.user WHERE id = $1') return await stmt.fetchval(user) +async def get_last_file_of(conn, user: int): + stmt = await conn.prepare( + 'SELECT last_file FROM skynet.user WHERE id = $1') + return await stmt.fetchval(user) + +async def get_last_binary_of(conn, user: int): + stmt = await conn.prepare( + 'SELECT last_binary FROM skynet.user WHERE id = $1') + return await stmt.fetchval(user) + async def new_user(conn, uid: int): if await get_user(conn, uid): @@ -192,12 +209,15 @@ async def new_user(conn, uid: int): async with conn.transaction(): stmt = await conn.prepare(''' INSERT INTO skynet.user( - id, generated, joined, last_prompt, role) + id, generated, joined, + last_method, last_prompt, last_file, last_binary, + role + ) - VALUES($1, $2, $3, $4, $5) + VALUES($1, $2, $3, $4, $5, $6, $7, $8) ''') await stmt.fetch( - uid, 0, date, None, DEFAULT_ROLE + uid, 0, date, 'txt2img', None, None, None, DEFAULT_ROLE ) stmt = await conn.prepare(''' @@ -222,7 +242,8 @@ async def get_or_create_user(conn, uid: str): user = await get_user(conn, uid) if not user: - user = await new_user(conn, uid) + await new_user(conn, uid) + user = await get_user(conn, uid) return user @@ -253,11 +274,7 @@ async def get_user_stats(conn, user: int): record = records[0] return record -async def update_user_stats( - conn, - user: int, - last_prompt: Optional[str] = None -): +async def increment_generated(conn, user: int): stmt = await conn.prepare(''' UPDATE skynet.user SET generated = generated + 1 @@ -265,5 +282,20 @@ async def update_user_stats( ''') await stmt.fetch(user) +async def update_user_stats( + conn, + user: int, + method: str, + last_prompt: str | None = None, + last_file: str | None = None, + last_binary: str | None = None +): + await update_user(conn, user, 'last_method', method) if last_prompt: await update_user(conn, user, 'last_prompt', last_prompt) + if last_file: + await update_user(conn, user, 'last_file', last_file) + if last_binary: + await update_user(conn, user, 'last_binary', last_binary) + + logging.info((method, last_prompt, last_binary)) diff --git a/skynet/frontend/telegram.py b/skynet/frontend/telegram.py index 193a6fd..3a9c964 100644 --- a/skynet/frontend/telegram.py +++ b/skynet/frontend/telegram.py @@ -128,9 +128,8 @@ async def work_request( account: str, permission: str, params: dict, - ipfs_node, file_id: str | None = None, - file_path: str | None = None + binary_data: str = '' ): if params['seed'] == None: params['seed'] = random.randint(0, 9e18) @@ -141,30 +140,8 @@ async def work_request( }) request_time = datetime.now().isoformat() - if file_id: - image_raw = await bot.download_file(file_path) - with Image.open(io.BytesIO(image_raw)) as image: - w, h = image.size - - if w > 512 or h > 512: - logging.warning(f'user sent img of size {image.size}') - image.thumbnail((512, 512)) - logging.warning(f'resized it to {image.size}') - - image.save(f'ipfs-docker-staging/image.png', format='PNG') - - ipfs_hash = ipfs_node.add('image.png') - ipfs_node.pin(ipfs_hash) - - logging.info(f'published input image {ipfs_hash} on ipfs') - - binary = ipfs_hash - - else: - binary = '' - ec, out = cleos.push_action( - 'telos.gpu', 'enqueue', [account, body, binary, '20.0000 GPU'], f'{account}@{permission}' + 'telos.gpu', 'enqueue', [account, body, binary_data, '20.0000 GPU'], f'{account}@{permission}' ) out = collect_stdout(out) if ec != 0: @@ -173,7 +150,7 @@ async def work_request( nonce = await get_user_nonce(cleos, account) request_hash = sha256( - (str(nonce) + body + binary).encode('utf-8')).hexdigest().upper() + (str(nonce) + body + binary_data).encode('utf-8')).hexdigest().upper() request_id = int(out) logging.info(f'{request_id} enqueued.') @@ -209,13 +186,13 @@ async def work_request( return # attempt to get the image and send it - resp = await get_ipfs_file( - f'http://test1.us.telos.net:8080/ipfs/{ipfs_hash}/image.png') + ipfs_link = f'http://test1.us.telos.net:8080/ipfs/{ipfs_hash}/image.png' + resp = await get_ipfs_file(ipfs_link) caption = generate_reply_caption( user, params, ipfs_hash, tx_hash, worker) - if resp.status_code != 200: + if not resp or resp.status_code != 200: logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!') await bot.reply_to( message, @@ -233,11 +210,10 @@ async def work_request( InputMediaPhoto(file_id), InputMediaPhoto( resp.raw, - caption=caption + caption=caption, + parse_mode='HTML' ) ], - reply_markup=build_redo_menu(), - parse_mode='HTML' ) else: # txt2img @@ -307,10 +283,18 @@ async def run_skynet_telegram( async def send_cool_words(message): await bot.reply_to(message, '\n'.join(COOL_WORDS)) - @bot.message_handler(commands=['txt2img']) - async def send_txt2img(message): - user = message.from_user - chat = message.chat + async def _generic_txt2img(message_or_query): + if isinstance(message_or_query, CallbackQuery): + query = message_or_query + message = query.message + user = query.from_user + chat = query.message.chat + + else: + message = message_or_query + user = message.from_user + chat = message.chat + reply_id = None if chat.type == 'group' and chat.id == GROUP_ID: reply_id = message.message_id @@ -332,19 +316,30 @@ async def run_skynet_telegram( **user_config } - await db_call('update_user_stats', user.id, last_prompt=prompt) + await db_call( + 'update_user_stats', user.id, 'txt2img', last_prompt=prompt) - await work_request( + ec = await work_request( bot, cleos, hyperion, message, user, chat, - account, permission, params, - ipfs_node + account, permission, params ) - @bot.message_handler(func=lambda message: True, content_types=['photo']) - async def send_img2img(message): - user = message.from_user - chat = message.chat + if ec == 0: + await db_call('increment_generated', user.id) + + async def _generic_img2img(message_or_query): + if isinstance(message_or_query, CallbackQuery): + query = message_or_query + message = query.message + user = query.from_user + chat = query.message.chat + + else: + message = message_or_query + user = message.from_user + chat = message.chat + reply_id = None if chat.type == 'group' and chat.id == GROUP_ID: reply_id = message.message_id @@ -364,6 +359,22 @@ async def run_skynet_telegram( file_id = message.photo[-1].file_id file_path = (await bot.get_file(file_id)).file_path + image_raw = await bot.download_file(file_path) + + with Image.open(io.BytesIO(image_raw)) as image: + w, h = image.size + + if w > 512 or h > 512: + logging.warning(f'user sent img of size {image.size}') + image.thumbnail((512, 512)) + logging.warning(f'resized it to {image.size}') + + image.save(f'ipfs-docker-staging/image.png', format='PNG') + + ipfs_hash = ipfs_node.add('image.png') + ipfs_node.pin(ipfs_hash) + + logging.info(f'published input image {ipfs_hash} on ipfs') logging.info(f'mid: {message.id}') @@ -376,16 +387,34 @@ async def run_skynet_telegram( **user_config } - await db_call('update_user_stats', user.id, last_prompt=prompt) + await db_call( + 'update_user_stats', + user.id, + 'img2img', + last_file=file_id, + last_prompt=prompt, + last_binary=ipfs_hash + ) - await work_request( + ec = await work_request( bot, cleos, hyperion, message, user, chat, account, permission, params, - ipfs_node, - file_id=file_id, file_path=file_path + file_id=file_id, + binary_data=ipfs_hash ) + if ec == 0: + await db_call('increment_generated', user.id) + + @bot.message_handler(commands=['txt2img']) + async def send_txt2img(message): + await _generic_txt2img(message) + + @bot.message_handler(func=lambda message: True, content_types=[ + 'photo', 'document']) + async def send_img2img(message): + await _generic_img2img(message) @bot.message_handler(commands=['img2img']) async def img2img_missing_image(message): @@ -406,12 +435,15 @@ async def run_skynet_telegram( user = message.from_user chat = message.chat - reply_id = None - if chat.type == 'group' and chat.id == GROUP_ID: - reply_id = message.message_id - + method = await db_call('get_last_method_of', user.id) prompt = await db_call('get_last_prompt_of', user.id) + file_id = None + binary = '' + if method == 'img2img': + file_id = await db_call('get_last_file_of', user.id) + binary = await db_call('get_last_binary_of', user.id) + if not prompt: await bot.reply_to( message, @@ -433,7 +465,8 @@ async def run_skynet_telegram( bot, cleos, hyperion, message, user, chat, account, permission, params, - ipfs_node + file_id=file_id, + binary_data=binary ) @bot.message_handler(commands=['redo']) @@ -485,7 +518,6 @@ async def run_skynet_telegram( await bot.send_message(GROUP_ID, message.text[4:]) - @bot.message_handler(func=lambda message: True) async def echo_message(message): if message.text[0] == '/':