From 83b44e5e691fbddbefcbed357d5ddc946348bc37 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Wed, 19 Feb 2025 00:33:24 -0300 Subject: [PATCH] Finish testing img2img --- skynet/frontend/chatbot/__init__.py | 10 ++++++- skynet/frontend/chatbot/telegram.py | 42 ++++++++++++++++++++--------- skynet/frontend/chatbot/types.py | 8 +++++- 3 files changed, 45 insertions(+), 15 deletions(-) diff --git a/skynet/frontend/chatbot/__init__.py b/skynet/frontend/chatbot/__init__.py index ab985a9..ac60eee 100644 --- a/skynet/frontend/chatbot/__init__.py +++ b/skynet/frontend/chatbot/__init__.py @@ -240,6 +240,14 @@ class BaseChatbot(ABC): await self.reply_to(msg, f'unknown request of type {msg.command}') return + if ( + msg.command == BaseCommands.IMG2IMG + and + len(inputs) == 0 + ): + await self.edit_msg(status_msg, 'seems you tried to do an img2img command without sending image') + return + # maybe apply recomended settings to this request del user_row['id'] if user_row['autoconf']: @@ -258,7 +266,7 @@ class BaseChatbot(ABC): # publish inputs to ipfs input_cids = [] for i in inputs: - i.publish() + await i.publish(self.ipfs, user_row) input_cids.append(i.cid) inputs_str = ','.join((i for i in input_cids)) diff --git a/skynet/frontend/chatbot/telegram.py b/skynet/frontend/chatbot/telegram.py index aa604c5..58e2bf6 100644 --- a/skynet/frontend/chatbot/telegram.py +++ b/skynet/frontend/chatbot/telegram.py @@ -101,6 +101,9 @@ class TelegramFileInput(BaseFileInput): raise ValueError + def set_cid(self, cid: str): + self._cid = cid + async def download(self, bot: AsyncTeleBot) -> bytes: file_path = (await bot.get_file(self.id)).file_path self._raw = await bot.download_file(file_path) @@ -113,6 +116,7 @@ class TelegramMessage(BaseMessage): self._msg = msg self._cmd = cmd self._chat = TelegramChatRoom(msg.chat) + self._inputs: list[TelegramFileInput] | None = None @property def id(self) -> int: @@ -124,7 +128,11 @@ class TelegramMessage(BaseMessage): @property def text(self) -> str: - return self._msg.text[len(self._cmd) + 2:] # remove command name, slash and first space + # remove command name, slash and first space + if self._msg.text: + return self._msg.text[len(self._cmd) + 2:] + + return self._msg.caption[len(self._cmd) + 2:] @property def author(self) -> TelegramUser: @@ -136,13 +144,15 @@ class TelegramMessage(BaseMessage): @property def inputs(self) -> list[TelegramFileInput]: - if self._msg.photo: - return [ - TelegramFileInput(photo=p) - for p in self._msg.photo - ] + if self._inputs is None: + self._inputs = [] + if self._msg.photo: + self._inputs = [ + TelegramFileInput(photo=p) + for p in self._msg.photo + ] - return [] + return self._inputs # generic tg utils @@ -242,7 +252,16 @@ class TelegramChatbot(BaseChatbot): append_handler(bot, BaseCommands.SAY, self.say) append_handler(bot, BaseCommands.TXT2IMG, self.handle_request) + append_handler(bot, BaseCommands.IMG2IMG, self.handle_request) + + @bot.message_handler(func=lambda _: True, content_types=['photo', 'document']) + async def handle_img2img(tg_msg: TGMessage): + msg = TelegramMessage(cmd='img2img', msg=tg_msg) + for file in msg.inputs: + await file.download(bot) + await self.handle_request(msg) + append_handler(bot, BaseCommands.REDO, self.handle_request) self.bot = bot @@ -267,7 +286,7 @@ class TelegramChatbot(BaseChatbot): return TelegramMessage(cmd=None, msg=msg) async def reply_to(self, msg: TelegramMessage, text: str) -> TelegramMessage: - msg = await self.bot.reply_to(msg._msg, text) + msg = await self.bot.reply_to(msg._msg, text, parse_mode='HTML') return TelegramMessage(cmd=None, msg=msg) async def edit_msg(self, msg: TelegramMessage, text: str): @@ -370,8 +389,8 @@ class TelegramChatbot(BaseChatbot): parse_mode='HTML' ) - case 1: - _input = inputs.pop() + case _: + _input = inputs[-1] await self.bot.send_media_group( status_msg.chat.id, media=[ @@ -379,6 +398,3 @@ class TelegramChatbot(BaseChatbot): InputMediaPhoto(result_img, caption=caption, parse_mode='HTML') ] ) - - case _: - raise NotImplementedError diff --git a/skynet/frontend/chatbot/types.py b/skynet/frontend/chatbot/types.py index 283029b..a5376d0 100644 --- a/skynet/frontend/chatbot/types.py +++ b/skynet/frontend/chatbot/types.py @@ -3,6 +3,7 @@ import io from abc import ABC, abstractproperty, abstractmethod from enum import StrEnum from typing import Self +from pathlib import Path from PIL import Image from skynet.ipfs import AsyncIPFSHTTP @@ -52,6 +53,10 @@ class BaseFileInput(ABC): async def download(self, *args) -> bytes: ... + @abstractmethod + def set_cid(self, cid: str): + ... + async def publish(self, ipfs_api: AsyncIPFSHTTP, user_row: dict): with Image.open(io.BytesIO(self._raw)) as img: w, h = img.size @@ -63,11 +68,12 @@ class BaseFileInput(ABC): ): img.thumbnail((user_row['width'], user_row['height'])) - img_path = '/tmp/ipfs-staging/img.png' + img_path = Path('/tmp/ipfs-staging/img.png') img.save(img_path, format='PNG') ipfs_info = await ipfs_api.add(img_path) ipfs_hash = ipfs_info['Hash'] + self.set_cid(ipfs_hash) await ipfs_api.pin(ipfs_hash)