Finish testing img2img

frontend_abc
Guillermo Rodriguez 2025-02-19 00:33:24 -03:00
parent fedbd95ca8
commit 83b44e5e69
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
3 changed files with 45 additions and 15 deletions

View File

@ -240,6 +240,14 @@ class BaseChatbot(ABC):
await self.reply_to(msg, f'unknown request of type {msg.command}')
return
if (
msg.command == BaseCommands.IMG2IMG
and
len(inputs) == 0
):
await self.edit_msg(status_msg, 'seems you tried to do an img2img command without sending image')
return
# maybe apply recomended settings to this request
del user_row['id']
if user_row['autoconf']:
@ -258,7 +266,7 @@ class BaseChatbot(ABC):
# publish inputs to ipfs
input_cids = []
for i in inputs:
i.publish()
await i.publish(self.ipfs, user_row)
input_cids.append(i.cid)
inputs_str = ','.join((i for i in input_cids))

View File

@ -101,6 +101,9 @@ class TelegramFileInput(BaseFileInput):
raise ValueError
def set_cid(self, cid: str):
self._cid = cid
async def download(self, bot: AsyncTeleBot) -> bytes:
file_path = (await bot.get_file(self.id)).file_path
self._raw = await bot.download_file(file_path)
@ -113,6 +116,7 @@ class TelegramMessage(BaseMessage):
self._msg = msg
self._cmd = cmd
self._chat = TelegramChatRoom(msg.chat)
self._inputs: list[TelegramFileInput] | None = None
@property
def id(self) -> int:
@ -124,7 +128,11 @@ class TelegramMessage(BaseMessage):
@property
def text(self) -> str:
return self._msg.text[len(self._cmd) + 2:] # remove command name, slash and first space
# remove command name, slash and first space
if self._msg.text:
return self._msg.text[len(self._cmd) + 2:]
return self._msg.caption[len(self._cmd) + 2:]
@property
def author(self) -> TelegramUser:
@ -136,13 +144,15 @@ class TelegramMessage(BaseMessage):
@property
def inputs(self) -> list[TelegramFileInput]:
if self._inputs is None:
self._inputs = []
if self._msg.photo:
return [
self._inputs = [
TelegramFileInput(photo=p)
for p in self._msg.photo
]
return []
return self._inputs
# generic tg utils
@ -242,7 +252,16 @@ class TelegramChatbot(BaseChatbot):
append_handler(bot, BaseCommands.SAY, self.say)
append_handler(bot, BaseCommands.TXT2IMG, self.handle_request)
append_handler(bot, BaseCommands.IMG2IMG, self.handle_request)
@bot.message_handler(func=lambda _: True, content_types=['photo', 'document'])
async def handle_img2img(tg_msg: TGMessage):
msg = TelegramMessage(cmd='img2img', msg=tg_msg)
for file in msg.inputs:
await file.download(bot)
await self.handle_request(msg)
append_handler(bot, BaseCommands.REDO, self.handle_request)
self.bot = bot
@ -267,7 +286,7 @@ class TelegramChatbot(BaseChatbot):
return TelegramMessage(cmd=None, msg=msg)
async def reply_to(self, msg: TelegramMessage, text: str) -> TelegramMessage:
msg = await self.bot.reply_to(msg._msg, text)
msg = await self.bot.reply_to(msg._msg, text, parse_mode='HTML')
return TelegramMessage(cmd=None, msg=msg)
async def edit_msg(self, msg: TelegramMessage, text: str):
@ -370,8 +389,8 @@ class TelegramChatbot(BaseChatbot):
parse_mode='HTML'
)
case 1:
_input = inputs.pop()
case _:
_input = inputs[-1]
await self.bot.send_media_group(
status_msg.chat.id,
media=[
@ -379,6 +398,3 @@ class TelegramChatbot(BaseChatbot):
InputMediaPhoto(result_img, caption=caption, parse_mode='HTML')
]
)
case _:
raise NotImplementedError

View File

@ -3,6 +3,7 @@ import io
from abc import ABC, abstractproperty, abstractmethod
from enum import StrEnum
from typing import Self
from pathlib import Path
from PIL import Image
from skynet.ipfs import AsyncIPFSHTTP
@ -52,6 +53,10 @@ class BaseFileInput(ABC):
async def download(self, *args) -> bytes:
...
@abstractmethod
def set_cid(self, cid: str):
...
async def publish(self, ipfs_api: AsyncIPFSHTTP, user_row: dict):
with Image.open(io.BytesIO(self._raw)) as img:
w, h = img.size
@ -63,11 +68,12 @@ class BaseFileInput(ABC):
):
img.thumbnail((user_row['width'], user_row['height']))
img_path = '/tmp/ipfs-staging/img.png'
img_path = Path('/tmp/ipfs-staging/img.png')
img.save(img_path, format='PNG')
ipfs_info = await ipfs_api.add(img_path)
ipfs_hash = ipfs_info['Hash']
self.set_cid(ipfs_hash)
await ipfs_api.pin(ipfs_hash)