mirror of https://github.com/skygpu/skynet.git
				
				
				
			Add autowithdraw switch, start storing input images on ipfs
							parent
							
								
									303ed7b24f
								
							
						
					
					
						commit
						22c403d3ae
					
				|  | @ -282,6 +282,8 @@ def nodeos(): | ||||||
|     '--permission', '-p', default='active') |     '--permission', '-p', default='active') | ||||||
| @click.option( | @click.option( | ||||||
|     '--key', '-k', default=None) |     '--key', '-k', default=None) | ||||||
|  | @click.option( | ||||||
|  |     '--auto-withdraw', '-w', default=True) | ||||||
| @click.option( | @click.option( | ||||||
|     '--node-url', '-n', default='http://skynet.ancap.tech') |     '--node-url', '-n', default='http://skynet.ancap.tech') | ||||||
| @click.option( | @click.option( | ||||||
|  | @ -293,6 +295,7 @@ def dgpu( | ||||||
|     account: str, |     account: str, | ||||||
|     permission: str, |     permission: str, | ||||||
|     key: str | None, |     key: str | None, | ||||||
|  |     auto_withdraw: bool, | ||||||
|     node_url: str, |     node_url: str, | ||||||
|     ipfs_url: str, |     ipfs_url: str, | ||||||
|     algos: list[str] |     algos: list[str] | ||||||
|  | @ -321,6 +324,7 @@ def dgpu( | ||||||
|                 account, permission, |                 account, permission, | ||||||
|                 cleos, |                 cleos, | ||||||
|                 ipfs_url, |                 ipfs_url, | ||||||
|  |                 auto_withdraw=auto_withdraw, | ||||||
|                 key=key, initial_algos=json.loads(algos) |                 key=key, initial_algos=json.loads(algos) | ||||||
|         )) |         )) | ||||||
| 
 | 
 | ||||||
|  | @ -341,6 +345,8 @@ def dgpu( | ||||||
|     '--hyperion-url', '-n', default='http://test1.us.telos.net:42001') |     '--hyperion-url', '-n', default='http://test1.us.telos.net:42001') | ||||||
| @click.option( | @click.option( | ||||||
|     '--node-url', '-n', default='http://skynet.ancap.tech') |     '--node-url', '-n', default='http://skynet.ancap.tech') | ||||||
|  | @click.option( | ||||||
|  |     '--ipfs-url', '-n', default='/ip4/169.197.142.4/tcp/4001/p2p/12D3KooWKHKPFuqJPeqYgtUJtfZTHvEArRX2qvThYBrjuTuPg2Nx') | ||||||
| @click.option( | @click.option( | ||||||
|     '--db-host', '-h', default='localhost:5432') |     '--db-host', '-h', default='localhost:5432') | ||||||
| @click.option( | @click.option( | ||||||
|  | @ -352,8 +358,9 @@ def telegram( | ||||||
|     account: str, |     account: str, | ||||||
|     permission: str, |     permission: str, | ||||||
|     key: str | None, |     key: str | None, | ||||||
|     node_url: str, |  | ||||||
|     hyperion_url: str, |     hyperion_url: str, | ||||||
|  |     ipfs_url: str, | ||||||
|  |     node_url: str, | ||||||
|     db_host: str, |     db_host: str, | ||||||
|     db_user: str, |     db_user: str, | ||||||
|     db_pass: str |     db_pass: str | ||||||
|  | @ -372,6 +379,7 @@ def telegram( | ||||||
|             node_url, |             node_url, | ||||||
|             hyperion_url, |             hyperion_url, | ||||||
|             db_host, db_user, db_pass, |             db_host, db_user, db_pass, | ||||||
|  |             remote_ipfs_node=ipfs_url, | ||||||
|             key=key |             key=key | ||||||
|     )) |     )) | ||||||
| 
 | 
 | ||||||
|  | @ -400,9 +408,19 @@ def pinner(loglevel, container, hyperion_url): | ||||||
| 
 | 
 | ||||||
|     try: |     try: | ||||||
|         while True: |         while True: | ||||||
|             # get all submits in the last minute |  | ||||||
|             now = datetime.now() |             now = datetime.now() | ||||||
|             half_min_ago = now - timedelta(seconds=30) |             half_min_ago = now - timedelta(seconds=30) | ||||||
|  | 
 | ||||||
|  |             # get all enqueues with binary data | ||||||
|  |             # in the last minute | ||||||
|  |             enqueues = hyperion.get_actions( | ||||||
|  |                 account='telos.gpu', | ||||||
|  |                 filter='telos.gpu:enqueue', | ||||||
|  |                 sort='desc', | ||||||
|  |                 after=half_min_ago.isoformat() | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |             # get all submits in the last minute | ||||||
|             submits = hyperion.get_actions( |             submits = hyperion.get_actions( | ||||||
|                 account='telos.gpu', |                 account='telos.gpu', | ||||||
|                 filter='telos.gpu:submit', |                 filter='telos.gpu:submit', | ||||||
|  | @ -411,16 +429,23 @@ def pinner(loglevel, container, hyperion_url): | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|             # filter for the ones not already pinned |             # filter for the ones not already pinned | ||||||
|             actions = [ |             cids = [ | ||||||
|                 action |                 *[ | ||||||
|                 for action in submits['actions'] |                     action['act']['data']['binary_data'] | ||||||
|                 if action['act']['data']['ipfs_hash'] |                     for action in enqueues['actions'] | ||||||
|                 not in last_pinned |                     if action['act']['data']['binary_data'] | ||||||
|  |                     not in last_pinned | ||||||
|  |                 ], | ||||||
|  |                 *[ | ||||||
|  |                     action['act']['data']['ipfs_hash'] | ||||||
|  |                     for action in submits['actions'] | ||||||
|  |                     if action['act']['data']['ipfs_hash'] | ||||||
|  |                     not in last_pinned | ||||||
|  |                 ] | ||||||
|             ] |             ] | ||||||
| 
 | 
 | ||||||
|             # pin and remember |             # pin and remember | ||||||
|             for action in actions: |             for cid in cids: | ||||||
|                 cid = action['act']['data']['ipfs_hash'] |  | ||||||
|                 last_pinned[cid] = now |                 last_pinned[cid] = now | ||||||
| 
 | 
 | ||||||
|                 ipfs_node.pin(cid) |                 ipfs_node.pin(cid) | ||||||
|  |  | ||||||
|  | @ -27,7 +27,7 @@ from realesrgan import RealESRGANer | ||||||
| from basicsr.archs.rrdbnet_arch import RRDBNet | from basicsr.archs.rrdbnet_arch import RRDBNet | ||||||
| from diffusers.models import UNet2DConditionModel | from diffusers.models import UNet2DConditionModel | ||||||
| 
 | 
 | ||||||
| from .ipfs import IPFSDocker, open_ipfs_node | from .ipfs import IPFSDocker, open_ipfs_node, get_ipfs_file | ||||||
| from .utils import * | from .utils import * | ||||||
| from .constants import * | from .constants import * | ||||||
| 
 | 
 | ||||||
|  | @ -60,6 +60,7 @@ async def open_dgpu_node( | ||||||
|     remote_ipfs_node: str, |     remote_ipfs_node: str, | ||||||
|     key: str = None, |     key: str = None, | ||||||
|     initial_algos: Optional[List[str]] = None, |     initial_algos: Optional[List[str]] = None, | ||||||
|  |     auto_withdraw: bool = True | ||||||
| ): | ): | ||||||
| 
 | 
 | ||||||
|     logging.basicConfig(level=logging.INFO) |     logging.basicConfig(level=logging.INFO) | ||||||
|  | @ -103,7 +104,7 @@ async def open_dgpu_node( | ||||||
|                         logging.info(f'resized it to {image.size}') |                         logging.info(f'resized it to {image.size}') | ||||||
| 
 | 
 | ||||||
|                 if algo not in models: |                 if algo not in models: | ||||||
|                     if algo not in ALGOS: |                     if params['algo'] not in ALGOS: | ||||||
|                         raise DGPUComputeError(f'Unknown algo \"{algo}\"') |                         raise DGPUComputeError(f'Unknown algo \"{algo}\"') | ||||||
| 
 | 
 | ||||||
|                     logging.info(f'{algo} not in loaded models, swapping...') |                     logging.info(f'{algo} not in loaded models, swapping...') | ||||||
|  | @ -266,7 +267,7 @@ async def open_dgpu_node( | ||||||
|     def publish_on_ipfs(img_sha: str, raw_img: bytes): |     def publish_on_ipfs(img_sha: str, raw_img: bytes): | ||||||
|         logging.info('publish_on_ipfs') |         logging.info('publish_on_ipfs') | ||||||
|         img = Image.open(io.BytesIO(raw_img)) |         img = Image.open(io.BytesIO(raw_img)) | ||||||
|         img.save(f'tmp/ipfs-docker-staging/image.png') |         img.save(f'ipfs-docker-staging/image.png') | ||||||
| 
 | 
 | ||||||
|         ipfs_hash = ipfs_node.add('image.png') |         ipfs_hash = ipfs_node.add('image.png') | ||||||
| 
 | 
 | ||||||
|  | @ -291,13 +292,24 @@ async def open_dgpu_node( | ||||||
|         print(collect_stdout(out)) |         print(collect_stdout(out)) | ||||||
|         assert ec == 0 |         assert ec == 0 | ||||||
| 
 | 
 | ||||||
|  |     async def get_input_data(ipfs_hash: str) -> bytes: | ||||||
|  |         if ipfs_hash == '': | ||||||
|  |             return b'' | ||||||
|  | 
 | ||||||
|  |         resp = await get_ipfs_file(f'http://test1.us.telos.net:8080/ipfs/{ipfs_hash}/image.png') | ||||||
|  |         if resp.status_code != 200: | ||||||
|  |             raise DGPUComputeError('Couldn\'t gather input data from ipfs') | ||||||
|  | 
 | ||||||
|  |         return resp.raw | ||||||
|  | 
 | ||||||
|     config = await get_global_config() |     config = await get_global_config() | ||||||
| 
 | 
 | ||||||
|     with open_ipfs_node() as ipfs_node: |     with open_ipfs_node() as ipfs_node: | ||||||
|         ipfs_node.connect(remote_ipfs_node) |         ipfs_node.connect(remote_ipfs_node) | ||||||
|         try: |         try: | ||||||
|             while True: |             while True: | ||||||
|                 maybe_withdraw_all() |                 if auto_withdraw: | ||||||
|  |                     maybe_withdraw_all() | ||||||
| 
 | 
 | ||||||
|                 queue = await get_work_requests_last_hour() |                 queue = await get_work_requests_last_hour() | ||||||
| 
 | 
 | ||||||
|  | @ -314,11 +326,15 @@ async def open_dgpu_node( | ||||||
| 
 | 
 | ||||||
|                         # parse request |                         # parse request | ||||||
|                         body = json.loads(req['body']) |                         body = json.loads(req['body']) | ||||||
|                         binary = bytes.fromhex(req['binary_data']) | 
 | ||||||
|  |                         binary = await get_input_data(req['binary_data']) | ||||||
|  | 
 | ||||||
|                         hash_str = ( |                         hash_str = ( | ||||||
|                             str(await get_user_nonce(req['user'])) |                             str(await get_user_nonce(req['user'])) | ||||||
|                             + |                             + | ||||||
|                             req['body'] |                             req['body'] | ||||||
|  |                             + | ||||||
|  |                             req['binary_data'] | ||||||
|                         ) |                         ) | ||||||
|                         logging.info(f'hashing: {hash_str}') |                         logging.info(f'hashing: {hash_str}') | ||||||
|                         request_hash = sha256(hash_str.encode('utf-8')).hexdigest() |                         request_hash = sha256(hash_str.encode('utf-8')).hexdigest() | ||||||
|  |  | ||||||
|  | @ -27,6 +27,7 @@ from telebot.async_telebot import AsyncTeleBot, ExceptionHandler | ||||||
| from telebot.formatting import hlink | from telebot.formatting import hlink | ||||||
| 
 | 
 | ||||||
| from ..db import open_new_database, open_database_connection | from ..db import open_new_database, open_database_connection | ||||||
|  | from ..ipfs import open_ipfs_node, get_ipfs_file | ||||||
| from ..constants import * | from ..constants import * | ||||||
| 
 | 
 | ||||||
| from . import * | from . import * | ||||||
|  | @ -45,7 +46,7 @@ def build_redo_menu(): | ||||||
|     return inline_keyboard |     return inline_keyboard | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def prepare_metainfo_caption(tguser, meta: dict) -> str: | def prepare_metainfo_caption(tguser, worker: str, meta: dict) -> str: | ||||||
|     prompt = meta["prompt"] |     prompt = meta["prompt"] | ||||||
|     if len(prompt) > 256: |     if len(prompt) > 256: | ||||||
|         prompt = prompt[:256] |         prompt = prompt[:256] | ||||||
|  | @ -55,7 +56,7 @@ def prepare_metainfo_caption(tguser, meta: dict) -> str: | ||||||
|     else: |     else: | ||||||
|         user = f'{tguser.first_name} id: {tguser.id}' |         user = f'{tguser.first_name} id: {tguser.id}' | ||||||
| 
 | 
 | ||||||
|     meta_str = f'<u>by {user}</u>\n' |     meta_str = f'<u>by {user}</u> <i>performed by {worker}</i>\n' | ||||||
| 
 | 
 | ||||||
|     meta_str += f'<code>prompt:</code> {prompt}\n' |     meta_str += f'<code>prompt:</code> {prompt}\n' | ||||||
|     meta_str += f'<code>seed: {meta["seed"]}</code>\n' |     meta_str += f'<code>seed: {meta["seed"]}</code>\n' | ||||||
|  | @ -76,7 +77,8 @@ def generate_reply_caption( | ||||||
|     tguser,  # telegram user |     tguser,  # telegram user | ||||||
|     params: dict, |     params: dict, | ||||||
|     ipfs_hash: str, |     ipfs_hash: str, | ||||||
|     tx_hash: str |     tx_hash: str, | ||||||
|  |     worker: str | ||||||
| ): | ): | ||||||
|     ipfs_link = hlink( |     ipfs_link = hlink( | ||||||
|         'Get your image on IPFS', |         'Get your image on IPFS', | ||||||
|  | @ -87,7 +89,7 @@ def generate_reply_caption( | ||||||
|         f'http://test1.us.telos.net:42001/v2/explore/transaction/{tx_hash}' |         f'http://test1.us.telos.net:42001/v2/explore/transaction/{tx_hash}' | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     meta_info = prepare_metainfo_caption(tguser, params) |     meta_info = prepare_metainfo_caption(tguser, worker, params) | ||||||
| 
 | 
 | ||||||
|     final_msg = '\n'.join([ |     final_msg = '\n'.join([ | ||||||
|         'Worker finished your task!', |         'Worker finished your task!', | ||||||
|  | @ -126,6 +128,7 @@ async def work_request( | ||||||
|     account: str, |     account: str, | ||||||
|     permission: str, |     permission: str, | ||||||
|     params: dict, |     params: dict, | ||||||
|  |     ipfs_node, | ||||||
|     file_id: str | None = None, |     file_id: str | None = None, | ||||||
|     file_path: str | None = None |     file_path: str | None = None | ||||||
| ): | ): | ||||||
|  | @ -147,11 +150,15 @@ async def work_request( | ||||||
|                 logging.warning(f'user sent img of size {image.size}') |                 logging.warning(f'user sent img of size {image.size}') | ||||||
|                 image.thumbnail((512, 512)) |                 image.thumbnail((512, 512)) | ||||||
|                 logging.warning(f'resized it to {image.size}') |                 logging.warning(f'resized it to {image.size}') | ||||||
|                 img_byte_arr = io.BytesIO() |  | ||||||
|                 image.save(img_byte_arr, format='PNG') |  | ||||||
|                 image_raw = img_byte_arr.getvalue() |  | ||||||
| 
 | 
 | ||||||
|         binary = image_raw.hex() |             image.save(f'ipfs-docker-staging/image.png', format='PNG') | ||||||
|  | 
 | ||||||
|  |             ipfs_hash = ipfs_node.add('image.png') | ||||||
|  |             ipfs_node.pin(ipfs_hash) | ||||||
|  | 
 | ||||||
|  |             logging.info(f'published input image {ipfs_hash} on ipfs') | ||||||
|  | 
 | ||||||
|  |             binary = ipfs_hash | ||||||
| 
 | 
 | ||||||
|     else: |     else: | ||||||
|         binary = '' |         binary = '' | ||||||
|  | @ -166,7 +173,7 @@ async def work_request( | ||||||
| 
 | 
 | ||||||
|     nonce = await get_user_nonce(cleos, account) |     nonce = await get_user_nonce(cleos, account) | ||||||
|     request_hash = sha256( |     request_hash = sha256( | ||||||
|         (str(nonce) + body).encode('utf-8')).hexdigest().upper() |         (str(nonce) + body + binary).encode('utf-8')).hexdigest().upper() | ||||||
| 
 | 
 | ||||||
|     request_id = int(out) |     request_id = int(out) | ||||||
|     logging.info(f'{request_id} enqueued.') |     logging.info(f'{request_id} enqueued.') | ||||||
|  | @ -190,7 +197,9 @@ async def work_request( | ||||||
|         ] |         ] | ||||||
|         if len(actions) > 0: |         if len(actions) > 0: | ||||||
|             tx_hash = actions[0]['trx_id'] |             tx_hash = actions[0]['trx_id'] | ||||||
|             ipfs_hash = actions[0]['act']['data']['ipfs_hash'] |             data = actions[0]['act']['data'] | ||||||
|  |             ipfs_hash = data['ipfs_hash'] | ||||||
|  |             worker = data['worker'] | ||||||
|             break |             break | ||||||
| 
 | 
 | ||||||
|         await asyncio.sleep(1) |         await asyncio.sleep(1) | ||||||
|  | @ -200,23 +209,14 @@ async def work_request( | ||||||
|         return |         return | ||||||
| 
 | 
 | ||||||
|     # attempt to get the image and send it |     # attempt to get the image and send it | ||||||
|     ipfs_link = f'http://test1.us.telos.net:8080/ipfs/{ipfs_hash}/image.png' |     resp = await get_ipfs_file( | ||||||
|     logging.info(f'attempting to get image at {ipfs_link}') |         f'http://test1.us.telos.net:8080/ipfs/{ipfs_hash}/image.png') | ||||||
|     resp = None |  | ||||||
|     for i in range(10): |  | ||||||
|         try: |  | ||||||
|             resp = await asks.get(ipfs_link, timeout=2) |  | ||||||
| 
 |  | ||||||
|         except asks.errors.RequestTimeout: |  | ||||||
|             logging.warning('timeout...') |  | ||||||
|             ... |  | ||||||
| 
 |  | ||||||
|     logging.info(f'status_code: {resp.status_code}') |  | ||||||
| 
 | 
 | ||||||
|     caption = generate_reply_caption( |     caption = generate_reply_caption( | ||||||
|         user, params, ipfs_hash, tx_hash) |         user, params, ipfs_hash, tx_hash, worker) | ||||||
| 
 | 
 | ||||||
|     if resp.status_code != 200: |     if resp.status_code != 200: | ||||||
|  |         logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!') | ||||||
|         await bot.reply_to( |         await bot.reply_to( | ||||||
|             message, |             message, | ||||||
|             caption, |             caption, | ||||||
|  | @ -225,6 +225,7 @@ async def work_request( | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     else: |     else: | ||||||
|  |         logging.info(f'succes! sending generated image') | ||||||
|         if file_id:  # img2img |         if file_id:  # img2img | ||||||
|             await bot.send_media_group( |             await bot.send_media_group( | ||||||
|                 chat.id, |                 chat.id, | ||||||
|  | @ -258,6 +259,7 @@ async def run_skynet_telegram( | ||||||
|     db_host: str, |     db_host: str, | ||||||
|     db_user: str, |     db_user: str, | ||||||
|     db_pass: str, |     db_pass: str, | ||||||
|  |     remote_ipfs_node: str, | ||||||
|     key: str = None |     key: str = None | ||||||
| ): | ): | ||||||
|     dclient = docker.from_env() |     dclient = docker.from_env() | ||||||
|  | @ -280,224 +282,229 @@ async def run_skynet_telegram( | ||||||
|     bot = AsyncTeleBot(tg_token, exception_handler=SKYExceptionHandler) |     bot = AsyncTeleBot(tg_token, exception_handler=SKYExceptionHandler) | ||||||
|     logging.info(f'tg_token: {tg_token}') |     logging.info(f'tg_token: {tg_token}') | ||||||
| 
 | 
 | ||||||
|     async with open_database_connection( |     with open_ipfs_node() as ipfs_node: | ||||||
|         db_user, db_pass, db_host |         ipfs_node.connect(remote_ipfs_node) | ||||||
|     ) as db_call: |         async with open_database_connection( | ||||||
|  |             db_user, db_pass, db_host | ||||||
|  |         ) as db_call: | ||||||
| 
 | 
 | ||||||
|         @bot.message_handler(commands=['help']) |             @bot.message_handler(commands=['help']) | ||||||
|         async def send_help(message): |             async def send_help(message): | ||||||
|             splt_msg = message.text.split(' ') |                 splt_msg = message.text.split(' ') | ||||||
| 
 | 
 | ||||||
|             if len(splt_msg) == 1: |                 if len(splt_msg) == 1: | ||||||
|                 await bot.reply_to(message, HELP_TEXT) |                     await bot.reply_to(message, HELP_TEXT) | ||||||
| 
 |  | ||||||
|             else: |  | ||||||
|                 param = splt_msg[1] |  | ||||||
|                 if param in HELP_TOPICS: |  | ||||||
|                     await bot.reply_to(message, HELP_TOPICS[param]) |  | ||||||
| 
 | 
 | ||||||
|                 else: |                 else: | ||||||
|                     await bot.reply_to(message, HELP_UNKWNOWN_PARAM) |                     param = splt_msg[1] | ||||||
|  |                     if param in HELP_TOPICS: | ||||||
|  |                         await bot.reply_to(message, HELP_TOPICS[param]) | ||||||
| 
 | 
 | ||||||
|         @bot.message_handler(commands=['cool']) |                     else: | ||||||
|         async def send_cool_words(message): |                         await bot.reply_to(message, HELP_UNKWNOWN_PARAM) | ||||||
|             await bot.reply_to(message, '\n'.join(COOL_WORDS)) |  | ||||||
| 
 | 
 | ||||||
|         @bot.message_handler(commands=['txt2img']) |             @bot.message_handler(commands=['cool']) | ||||||
|         async def send_txt2img(message): |             async def send_cool_words(message): | ||||||
|             user = message.from_user |                 await bot.reply_to(message, '\n'.join(COOL_WORDS)) | ||||||
|             chat = message.chat |  | ||||||
|             reply_id = None |  | ||||||
|             if chat.type == 'group' and chat.id == GROUP_ID: |  | ||||||
|                 reply_id = message.message_id |  | ||||||
| 
 | 
 | ||||||
|             prompt = ' '.join(message.text.split(' ')[1:]) |             @bot.message_handler(commands=['txt2img']) | ||||||
| 
 |             async def send_txt2img(message): | ||||||
|             if len(prompt) == 0: |  | ||||||
|                 await bot.reply_to(message, 'Empty text prompt ignored.') |  | ||||||
|                 return |  | ||||||
| 
 |  | ||||||
|             logging.info(f'mid: {message.id}') |  | ||||||
| 
 |  | ||||||
|             user_row = await db_call('get_or_create_user', user.id) |  | ||||||
|             user_config = {**user_row} |  | ||||||
|             del user_config['id'] |  | ||||||
| 
 |  | ||||||
|             params = { |  | ||||||
|                 'prompt': prompt, |  | ||||||
|                 **user_config |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             await db_call('update_user_stats', user.id, last_prompt=prompt) |  | ||||||
| 
 |  | ||||||
|             await work_request( |  | ||||||
|                 bot, cleos, hyperion, |  | ||||||
|                 message, user, chat, |  | ||||||
|                 account, permission, params |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|         @bot.message_handler(func=lambda message: True, content_types=['photo']) |  | ||||||
|         async def send_img2img(message): |  | ||||||
|             user = message.from_user |  | ||||||
|             chat = message.chat |  | ||||||
|             reply_id = None |  | ||||||
|             if chat.type == 'group' and chat.id == GROUP_ID: |  | ||||||
|                 reply_id = message.message_id |  | ||||||
| 
 |  | ||||||
|             if not message.caption.startswith('/img2img'): |  | ||||||
|                 await bot.reply_to( |  | ||||||
|                     message, |  | ||||||
|                     'For image to image you need to add /img2img to the beggining of your caption' |  | ||||||
|                 ) |  | ||||||
|                 return |  | ||||||
| 
 |  | ||||||
|             prompt = ' '.join(message.caption.split(' ')[1:]) |  | ||||||
| 
 |  | ||||||
|             if len(prompt) == 0: |  | ||||||
|                 await bot.reply_to(message, 'Empty text prompt ignored.') |  | ||||||
|                 return |  | ||||||
| 
 |  | ||||||
|             file_id = message.photo[-1].file_id |  | ||||||
|             file_path = (await bot.get_file(file_id)).file_path |  | ||||||
| 
 |  | ||||||
|             logging.info(f'mid: {message.id}') |  | ||||||
| 
 |  | ||||||
|             user_row = await db_call('get_or_create_user', user.id) |  | ||||||
|             user_config = {**user_row} |  | ||||||
|             del user_config['id'] |  | ||||||
| 
 |  | ||||||
|             params = { |  | ||||||
|                 'prompt': prompt, |  | ||||||
|                 **user_config |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             await db_call('update_user_stats', user.id, last_prompt=prompt) |  | ||||||
| 
 |  | ||||||
|             await work_request( |  | ||||||
|                 bot, cleos, hyperion, |  | ||||||
|                 message, user, chat, |  | ||||||
|                 account, permission, params, |  | ||||||
|                 file_id=file_id, file_path=file_path |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         @bot.message_handler(commands=['img2img']) |  | ||||||
|         async def img2img_missing_image(message): |  | ||||||
|             await bot.reply_to( |  | ||||||
|                 message, |  | ||||||
|                 'seems you tried to do an img2img command without sending image' |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|         async def _redo(message_or_query): |  | ||||||
|             if isinstance(message_or_query, CallbackQuery): |  | ||||||
|                 query = message_or_query |  | ||||||
|                 message = query.message |  | ||||||
|                 user = query.from_user |  | ||||||
|                 chat = query.message.chat |  | ||||||
| 
 |  | ||||||
|             else: |  | ||||||
|                 message = message_or_query |  | ||||||
|                 user = message.from_user |                 user = message.from_user | ||||||
|                 chat = message.chat |                 chat = message.chat | ||||||
|  |                 reply_id = None | ||||||
|  |                 if chat.type == 'group' and chat.id == GROUP_ID: | ||||||
|  |                     reply_id = message.message_id | ||||||
| 
 | 
 | ||||||
|             reply_id = None |                 prompt = ' '.join(message.text.split(' ')[1:]) | ||||||
|             if chat.type == 'group' and chat.id == GROUP_ID: |  | ||||||
|                 reply_id = message.message_id |  | ||||||
| 
 | 
 | ||||||
|             prompt = await db_call('get_last_prompt_of', user.id) |                 if len(prompt) == 0: | ||||||
|  |                     await bot.reply_to(message, 'Empty text prompt ignored.') | ||||||
|  |                     return | ||||||
| 
 | 
 | ||||||
|             if not prompt: |                 logging.info(f'mid: {message.id}') | ||||||
|  | 
 | ||||||
|  |                 user_row = await db_call('get_or_create_user', user.id) | ||||||
|  |                 user_config = {**user_row} | ||||||
|  |                 del user_config['id'] | ||||||
|  | 
 | ||||||
|  |                 params = { | ||||||
|  |                     'prompt': prompt, | ||||||
|  |                     **user_config | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 await db_call('update_user_stats', user.id, last_prompt=prompt) | ||||||
|  | 
 | ||||||
|  |                 await work_request( | ||||||
|  |                     bot, cleos, hyperion, | ||||||
|  |                     message, user, chat, | ||||||
|  |                     account, permission, params, | ||||||
|  |                     ipfs_node | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|  |             @bot.message_handler(func=lambda message: True, content_types=['photo']) | ||||||
|  |             async def send_img2img(message): | ||||||
|  |                 user = message.from_user | ||||||
|  |                 chat = message.chat | ||||||
|  |                 reply_id = None | ||||||
|  |                 if chat.type == 'group' and chat.id == GROUP_ID: | ||||||
|  |                     reply_id = message.message_id | ||||||
|  | 
 | ||||||
|  |                 if not message.caption.startswith('/img2img'): | ||||||
|  |                     await bot.reply_to( | ||||||
|  |                         message, | ||||||
|  |                         'For image to image you need to add /img2img to the beggining of your caption' | ||||||
|  |                     ) | ||||||
|  |                     return | ||||||
|  | 
 | ||||||
|  |                 prompt = ' '.join(message.caption.split(' ')[1:]) | ||||||
|  | 
 | ||||||
|  |                 if len(prompt) == 0: | ||||||
|  |                     await bot.reply_to(message, 'Empty text prompt ignored.') | ||||||
|  |                     return | ||||||
|  | 
 | ||||||
|  |                 file_id = message.photo[-1].file_id | ||||||
|  |                 file_path = (await bot.get_file(file_id)).file_path | ||||||
|  | 
 | ||||||
|  |                 logging.info(f'mid: {message.id}') | ||||||
|  | 
 | ||||||
|  |                 user_row = await db_call('get_or_create_user', user.id) | ||||||
|  |                 user_config = {**user_row} | ||||||
|  |                 del user_config['id'] | ||||||
|  | 
 | ||||||
|  |                 params = { | ||||||
|  |                     'prompt': prompt, | ||||||
|  |                     **user_config | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 await db_call('update_user_stats', user.id, last_prompt=prompt) | ||||||
|  | 
 | ||||||
|  |                 await work_request( | ||||||
|  |                     bot, cleos, hyperion, | ||||||
|  |                     message, user, chat, | ||||||
|  |                     account, permission, params, | ||||||
|  |                     ipfs_node, | ||||||
|  |                     file_id=file_id, file_path=file_path | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |             @bot.message_handler(commands=['img2img']) | ||||||
|  |             async def img2img_missing_image(message): | ||||||
|                 await bot.reply_to( |                 await bot.reply_to( | ||||||
|                     message, |                     message, | ||||||
|                     'no last prompt found, do a txt2img cmd first!' |                     'seems you tried to do an img2img command without sending image' | ||||||
|                 ) |                 ) | ||||||
|                 return | 
 | ||||||
|  |             async def _redo(message_or_query): | ||||||
|  |                 if isinstance(message_or_query, CallbackQuery): | ||||||
|  |                     query = message_or_query | ||||||
|  |                     message = query.message | ||||||
|  |                     user = query.from_user | ||||||
|  |                     chat = query.message.chat | ||||||
|  | 
 | ||||||
|  |                 else: | ||||||
|  |                     message = message_or_query | ||||||
|  |                     user = message.from_user | ||||||
|  |                     chat = message.chat | ||||||
|  | 
 | ||||||
|  |                 reply_id = None | ||||||
|  |                 if chat.type == 'group' and chat.id == GROUP_ID: | ||||||
|  |                     reply_id = message.message_id | ||||||
|  | 
 | ||||||
|  |                 prompt = await db_call('get_last_prompt_of', user.id) | ||||||
|  | 
 | ||||||
|  |                 if not prompt: | ||||||
|  |                     await bot.reply_to( | ||||||
|  |                         message, | ||||||
|  |                         'no last prompt found, do a txt2img cmd first!' | ||||||
|  |                     ) | ||||||
|  |                     return | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|             user_row = await db_call('get_or_create_user', user.id) |                 user_row = await db_call('get_or_create_user', user.id) | ||||||
|             user_config = {**user_row} |                 user_config = {**user_row} | ||||||
|             del user_config['id'] |                 del user_config['id'] | ||||||
| 
 | 
 | ||||||
|             params = { |                 params = { | ||||||
|                 'prompt': prompt, |                     'prompt': prompt, | ||||||
|                 **user_config |                     **user_config | ||||||
|             } |                 } | ||||||
| 
 | 
 | ||||||
|             await work_request( |                 await work_request( | ||||||
|                 bot, cleos, hyperion, |                     bot, cleos, hyperion, | ||||||
|                 message, user, chat, |                     message, user, chat, | ||||||
|                 account, permission, params |                     account, permission, params, | ||||||
|             ) |                     ipfs_node | ||||||
|  |                 ) | ||||||
| 
 | 
 | ||||||
|         @bot.message_handler(commands=['redo']) |             @bot.message_handler(commands=['redo']) | ||||||
|         async def redo(message): |             async def redo(message): | ||||||
|             await _redo(message) |                 await _redo(message) | ||||||
| 
 | 
 | ||||||
|         @bot.message_handler(commands=['config']) |             @bot.message_handler(commands=['config']) | ||||||
|         async def set_config(message): |             async def set_config(message): | ||||||
|             user = message.from_user.id |                 user = message.from_user.id | ||||||
|             try: |                 try: | ||||||
|                 attr, val, reply_txt = validate_user_config_request( |                     attr, val, reply_txt = validate_user_config_request( | ||||||
|                     message.text) |                         message.text) | ||||||
| 
 | 
 | ||||||
|                 logging.info(f'user config update: {attr} to {val}') |                     logging.info(f'user config update: {attr} to {val}') | ||||||
|                 await db_call('update_user_config', user, attr, val) |                     await db_call('update_user_config', user, attr, val) | ||||||
|                 logging.info('done') |                     logging.info('done') | ||||||
| 
 | 
 | ||||||
|             except BaseException as e: |                 except BaseException as e: | ||||||
|                 reply_txt = str(e) |                     reply_txt = str(e) | ||||||
| 
 | 
 | ||||||
|             finally: |                 finally: | ||||||
|                 await bot.reply_to(message, reply_txt) |                     await bot.reply_to(message, reply_txt) | ||||||
| 
 | 
 | ||||||
|         @bot.message_handler(commands=['stats']) |             @bot.message_handler(commands=['stats']) | ||||||
|         async def user_stats(message): |             async def user_stats(message): | ||||||
|             user = message.from_user.id |                 user = message.from_user.id | ||||||
| 
 | 
 | ||||||
|             generated, joined, role = await db_call('get_user_stats', user) |                 generated, joined, role = await db_call('get_user_stats', user) | ||||||
| 
 | 
 | ||||||
|             stats_str = f'generated: {generated}\n' |                 stats_str = f'generated: {generated}\n' | ||||||
|             stats_str += f'joined: {joined}\n' |                 stats_str += f'joined: {joined}\n' | ||||||
|             stats_str += f'role: {role}\n' |                 stats_str += f'role: {role}\n' | ||||||
| 
 | 
 | ||||||
|             await bot.reply_to( |                 await bot.reply_to( | ||||||
|                 message, stats_str) |                     message, stats_str) | ||||||
| 
 | 
 | ||||||
|         @bot.message_handler(commands=['donate']) |             @bot.message_handler(commands=['donate']) | ||||||
|         async def donation_info(message): |             async def donation_info(message): | ||||||
|             await bot.reply_to( |                 await bot.reply_to( | ||||||
|                 message, DONATION_INFO) |                     message, DONATION_INFO) | ||||||
| 
 | 
 | ||||||
|         @bot.message_handler(commands=['say']) |             @bot.message_handler(commands=['say']) | ||||||
|         async def say(message): |             async def say(message): | ||||||
|             chat = message.chat |                 chat = message.chat | ||||||
|             user = message.from_user |                 user = message.from_user | ||||||
| 
 | 
 | ||||||
|             if (chat.type == 'group') or (user.id != 383385940): |                 if (chat.type == 'group') or (user.id != 383385940): | ||||||
|                 return |                     return | ||||||
| 
 | 
 | ||||||
|             await bot.send_message(GROUP_ID, message.text[4:]) |                 await bot.send_message(GROUP_ID, message.text[4:]) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|         @bot.message_handler(func=lambda message: True) |             @bot.message_handler(func=lambda message: True) | ||||||
|         async def echo_message(message): |             async def echo_message(message): | ||||||
|             if message.text[0] == '/': |                 if message.text[0] == '/': | ||||||
|                 await bot.reply_to(message, UNKNOWN_CMD_TEXT) |                     await bot.reply_to(message, UNKNOWN_CMD_TEXT) | ||||||
| 
 | 
 | ||||||
|     @bot.callback_query_handler(func=lambda call: True) |         @bot.callback_query_handler(func=lambda call: True) | ||||||
|     async def callback_query(call): |         async def callback_query(call): | ||||||
|         msg = json.loads(call.data) |             msg = json.loads(call.data) | ||||||
|         logging.info(call.data) |             logging.info(call.data) | ||||||
|         method = msg.get('method') |             method = msg.get('method') | ||||||
|         match method: |             match method: | ||||||
|             case 'redo': |                 case 'redo': | ||||||
|                 await _redo(call) |                     await _redo(call) | ||||||
| 
 | 
 | ||||||
|     try: |         try: | ||||||
|         await bot.infinity_polling() |             await bot.infinity_polling() | ||||||
| 
 | 
 | ||||||
|     except KeyboardInterrupt: |         except KeyboardInterrupt: | ||||||
|         ... |             ... | ||||||
| 
 | 
 | ||||||
|     finally: |         finally: | ||||||
|         vtestnet.stop() |             vtestnet.stop() | ||||||
|  |  | ||||||
|  | @ -6,12 +6,31 @@ import logging | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from contextlib import contextmanager as cm | from contextlib import contextmanager as cm | ||||||
| 
 | 
 | ||||||
|  | import asks | ||||||
| import docker | import docker | ||||||
| 
 | 
 | ||||||
|  | from asks.errors import RequestTimeout | ||||||
| from docker.types import Mount | from docker.types import Mount | ||||||
| from docker.models.containers import Container | from docker.models.containers import Container | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | async def get_ipfs_file(ipfs_link: str): | ||||||
|  |     logging.info(f'attempting to get image at {ipfs_link}') | ||||||
|  |     resp = None | ||||||
|  |     for i in range(10): | ||||||
|  |         try: | ||||||
|  |             resp = await asks.get(ipfs_link, timeout=3) | ||||||
|  | 
 | ||||||
|  |         except asks.errors.RequestTimeout: | ||||||
|  |             logging.warning('timeout...') | ||||||
|  | 
 | ||||||
|  |     if resp: | ||||||
|  |         logging.info(f'status_code: {resp.status_code}') | ||||||
|  |     else: | ||||||
|  |         logging.error(f'timeout') | ||||||
|  |     return resp | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class IPFSDocker: | class IPFSDocker: | ||||||
| 
 | 
 | ||||||
|     def __init__(self, container: Container): |     def __init__(self, container: Container): | ||||||
|  | @ -39,39 +58,42 @@ class IPFSDocker: | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @cm | @cm | ||||||
| def open_ipfs_node(): | def open_ipfs_node(name='skynet-ipfs'): | ||||||
|     dclient = docker.from_env() |     dclient = docker.from_env() | ||||||
| 
 | 
 | ||||||
|     staging_dir = (Path().resolve() / 'ipfs-docker-staging').mkdir( |  | ||||||
|         parents=True, exist_ok=True) |  | ||||||
|     data_dir = (Path().resolve() / 'ipfs-docker-data').mkdir( |  | ||||||
|         parents=True, exist_ok=True) |  | ||||||
| 
 |  | ||||||
|     export_target = '/export' |  | ||||||
|     data_target = '/data/ipfs' |  | ||||||
| 
 |  | ||||||
|     container = dclient.containers.run( |  | ||||||
|         'ipfs/go-ipfs:latest', |  | ||||||
|         name='skynet-ipfs', |  | ||||||
|         ports={ |  | ||||||
|             '8080/tcp': 8080, |  | ||||||
|             '4001/tcp': 4001, |  | ||||||
|             '5001/tcp': ('127.0.0.1', 5001) |  | ||||||
|         }, |  | ||||||
|         mounts=[ |  | ||||||
|             Mount(export_target, str(staging_dir), 'bind'), |  | ||||||
|             Mount(data_target, str(data_dir), 'bind') |  | ||||||
|         ], |  | ||||||
|         detach=True, |  | ||||||
|         remove=True |  | ||||||
|     ) |  | ||||||
|     uid = os.getuid() |  | ||||||
|     gid = os.getgid() |  | ||||||
|     ec, out = container.exec_run(['chown', f'{uid}:{gid}', '-R', export_target]) |  | ||||||
|     assert ec == 0 |  | ||||||
|     ec, out = container.exec_run(['chown', f'{uid}:{gid}', '-R', data_target]) |  | ||||||
|     assert ec == 0 |  | ||||||
|     try: |     try: | ||||||
|  |         container = dclient.containers.get(name) | ||||||
|  | 
 | ||||||
|  |     except docker.errors.NotFound: | ||||||
|  |         staging_dir = Path().resolve() / 'ipfs-docker-staging' | ||||||
|  |         staging_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  | 
 | ||||||
|  |         data_dir = Path().resolve() / 'ipfs-docker-data' | ||||||
|  |         data_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  | 
 | ||||||
|  |         export_target = '/export' | ||||||
|  |         data_target = '/data/ipfs' | ||||||
|  | 
 | ||||||
|  |         container = dclient.containers.run( | ||||||
|  |             'ipfs/go-ipfs:latest', | ||||||
|  |             name='skynet-ipfs', | ||||||
|  |             ports={ | ||||||
|  |                 '8080/tcp': 8080, | ||||||
|  |                 '4001/tcp': 4001, | ||||||
|  |                 '5001/tcp': ('127.0.0.1', 5001) | ||||||
|  |             }, | ||||||
|  |             mounts=[ | ||||||
|  |                 Mount(export_target, str(staging_dir), 'bind'), | ||||||
|  |                 Mount(data_target, str(data_dir), 'bind') | ||||||
|  |             ], | ||||||
|  |             detach=True | ||||||
|  |         ) | ||||||
|  |         uid = os.getuid() | ||||||
|  |         gid = os.getgid() | ||||||
|  |         ec, out = container.exec_run(['chown', f'{uid}:{gid}', '-R', export_target]) | ||||||
|  |         assert ec == 0 | ||||||
|  |         ec, out = container.exec_run(['chown', f'{uid}:{gid}', '-R', data_target]) | ||||||
|  |         assert ec == 0 | ||||||
| 
 | 
 | ||||||
|         for log in container.logs(stream=True): |         for log in container.logs(stream=True): | ||||||
|             log = log.decode().rstrip() |             log = log.decode().rstrip() | ||||||
|  | @ -79,9 +101,5 @@ def open_ipfs_node(): | ||||||
|             if 'Daemon is ready' in log: |             if 'Daemon is ready' in log: | ||||||
|                 break |                 break | ||||||
| 
 | 
 | ||||||
|         yield IPFSDocker(container) |     yield IPFSDocker(container) | ||||||
| 
 |  | ||||||
|     finally: |  | ||||||
|         if container: |  | ||||||
|             container.stop() |  | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue