Finish testing img2img

frontend_abc
Guillermo Rodriguez 2025-02-19 00:33:24 -03:00
parent fedbd95ca8
commit 83b44e5e69
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
3 changed files with 45 additions and 15 deletions

View File

@ -240,6 +240,14 @@ class BaseChatbot(ABC):
await self.reply_to(msg, f'unknown request of type {msg.command}') await self.reply_to(msg, f'unknown request of type {msg.command}')
return return
if (
msg.command == BaseCommands.IMG2IMG
and
len(inputs) == 0
):
await self.edit_msg(status_msg, 'seems you tried to do an img2img command without sending image')
return
# maybe apply recomended settings to this request # maybe apply recomended settings to this request
del user_row['id'] del user_row['id']
if user_row['autoconf']: if user_row['autoconf']:
@ -258,7 +266,7 @@ class BaseChatbot(ABC):
# publish inputs to ipfs # publish inputs to ipfs
input_cids = [] input_cids = []
for i in inputs: for i in inputs:
i.publish() await i.publish(self.ipfs, user_row)
input_cids.append(i.cid) input_cids.append(i.cid)
inputs_str = ','.join((i for i in input_cids)) inputs_str = ','.join((i for i in input_cids))

View File

@ -101,6 +101,9 @@ class TelegramFileInput(BaseFileInput):
raise ValueError raise ValueError
def set_cid(self, cid: str):
self._cid = cid
async def download(self, bot: AsyncTeleBot) -> bytes: async def download(self, bot: AsyncTeleBot) -> bytes:
file_path = (await bot.get_file(self.id)).file_path file_path = (await bot.get_file(self.id)).file_path
self._raw = await bot.download_file(file_path) self._raw = await bot.download_file(file_path)
@ -113,6 +116,7 @@ class TelegramMessage(BaseMessage):
self._msg = msg self._msg = msg
self._cmd = cmd self._cmd = cmd
self._chat = TelegramChatRoom(msg.chat) self._chat = TelegramChatRoom(msg.chat)
self._inputs: list[TelegramFileInput] | None = None
@property @property
def id(self) -> int: def id(self) -> int:
@ -124,7 +128,11 @@ class TelegramMessage(BaseMessage):
@property @property
def text(self) -> str: def text(self) -> str:
return self._msg.text[len(self._cmd) + 2:] # remove command name, slash and first space # remove command name, slash and first space
if self._msg.text:
return self._msg.text[len(self._cmd) + 2:]
return self._msg.caption[len(self._cmd) + 2:]
@property @property
def author(self) -> TelegramUser: def author(self) -> TelegramUser:
@ -136,13 +144,15 @@ class TelegramMessage(BaseMessage):
@property @property
def inputs(self) -> list[TelegramFileInput]: def inputs(self) -> list[TelegramFileInput]:
if self._msg.photo: if self._inputs is None:
return [ self._inputs = []
TelegramFileInput(photo=p) if self._msg.photo:
for p in self._msg.photo self._inputs = [
] TelegramFileInput(photo=p)
for p in self._msg.photo
]
return [] return self._inputs
# generic tg utils # generic tg utils
@ -242,7 +252,16 @@ class TelegramChatbot(BaseChatbot):
append_handler(bot, BaseCommands.SAY, self.say) append_handler(bot, BaseCommands.SAY, self.say)
append_handler(bot, BaseCommands.TXT2IMG, self.handle_request) append_handler(bot, BaseCommands.TXT2IMG, self.handle_request)
append_handler(bot, BaseCommands.IMG2IMG, self.handle_request) append_handler(bot, BaseCommands.IMG2IMG, self.handle_request)
@bot.message_handler(func=lambda _: True, content_types=['photo', 'document'])
async def handle_img2img(tg_msg: TGMessage):
msg = TelegramMessage(cmd='img2img', msg=tg_msg)
for file in msg.inputs:
await file.download(bot)
await self.handle_request(msg)
append_handler(bot, BaseCommands.REDO, self.handle_request) append_handler(bot, BaseCommands.REDO, self.handle_request)
self.bot = bot self.bot = bot
@ -267,7 +286,7 @@ class TelegramChatbot(BaseChatbot):
return TelegramMessage(cmd=None, msg=msg) return TelegramMessage(cmd=None, msg=msg)
async def reply_to(self, msg: TelegramMessage, text: str) -> TelegramMessage: async def reply_to(self, msg: TelegramMessage, text: str) -> TelegramMessage:
msg = await self.bot.reply_to(msg._msg, text) msg = await self.bot.reply_to(msg._msg, text, parse_mode='HTML')
return TelegramMessage(cmd=None, msg=msg) return TelegramMessage(cmd=None, msg=msg)
async def edit_msg(self, msg: TelegramMessage, text: str): async def edit_msg(self, msg: TelegramMessage, text: str):
@ -370,8 +389,8 @@ class TelegramChatbot(BaseChatbot):
parse_mode='HTML' parse_mode='HTML'
) )
case 1: case _:
_input = inputs.pop() _input = inputs[-1]
await self.bot.send_media_group( await self.bot.send_media_group(
status_msg.chat.id, status_msg.chat.id,
media=[ media=[
@ -379,6 +398,3 @@ class TelegramChatbot(BaseChatbot):
InputMediaPhoto(result_img, caption=caption, parse_mode='HTML') InputMediaPhoto(result_img, caption=caption, parse_mode='HTML')
] ]
) )
case _:
raise NotImplementedError

View File

@ -3,6 +3,7 @@ import io
from abc import ABC, abstractproperty, abstractmethod from abc import ABC, abstractproperty, abstractmethod
from enum import StrEnum from enum import StrEnum
from typing import Self from typing import Self
from pathlib import Path
from PIL import Image from PIL import Image
from skynet.ipfs import AsyncIPFSHTTP from skynet.ipfs import AsyncIPFSHTTP
@ -52,6 +53,10 @@ class BaseFileInput(ABC):
async def download(self, *args) -> bytes: async def download(self, *args) -> bytes:
... ...
@abstractmethod
def set_cid(self, cid: str):
...
async def publish(self, ipfs_api: AsyncIPFSHTTP, user_row: dict): async def publish(self, ipfs_api: AsyncIPFSHTTP, user_row: dict):
with Image.open(io.BytesIO(self._raw)) as img: with Image.open(io.BytesIO(self._raw)) as img:
w, h = img.size w, h = img.size
@ -63,11 +68,12 @@ class BaseFileInput(ABC):
): ):
img.thumbnail((user_row['width'], user_row['height'])) img.thumbnail((user_row['width'], user_row['height']))
img_path = '/tmp/ipfs-staging/img.png' img_path = Path('/tmp/ipfs-staging/img.png')
img.save(img_path, format='PNG') img.save(img_path, format='PNG')
ipfs_info = await ipfs_api.add(img_path) ipfs_info = await ipfs_api.add(img_path)
ipfs_hash = ipfs_info['Hash'] ipfs_hash = ipfs_info['Hash']
self.set_cid(ipfs_hash)
await ipfs_api.pin(ipfs_hash) await ipfs_api.pin(ipfs_hash)