mirror of https://github.com/skygpu/skynet.git
Finish testing img2img
parent
fedbd95ca8
commit
83b44e5e69
|
@ -240,6 +240,14 @@ class BaseChatbot(ABC):
|
|||
await self.reply_to(msg, f'unknown request of type {msg.command}')
|
||||
return
|
||||
|
||||
if (
|
||||
msg.command == BaseCommands.IMG2IMG
|
||||
and
|
||||
len(inputs) == 0
|
||||
):
|
||||
await self.edit_msg(status_msg, 'seems you tried to do an img2img command without sending image')
|
||||
return
|
||||
|
||||
# maybe apply recomended settings to this request
|
||||
del user_row['id']
|
||||
if user_row['autoconf']:
|
||||
|
@ -258,7 +266,7 @@ class BaseChatbot(ABC):
|
|||
# publish inputs to ipfs
|
||||
input_cids = []
|
||||
for i in inputs:
|
||||
i.publish()
|
||||
await i.publish(self.ipfs, user_row)
|
||||
input_cids.append(i.cid)
|
||||
|
||||
inputs_str = ','.join((i for i in input_cids))
|
||||
|
|
|
@ -101,6 +101,9 @@ class TelegramFileInput(BaseFileInput):
|
|||
|
||||
raise ValueError
|
||||
|
||||
def set_cid(self, cid: str):
|
||||
self._cid = cid
|
||||
|
||||
async def download(self, bot: AsyncTeleBot) -> bytes:
|
||||
file_path = (await bot.get_file(self.id)).file_path
|
||||
self._raw = await bot.download_file(file_path)
|
||||
|
@ -113,6 +116,7 @@ class TelegramMessage(BaseMessage):
|
|||
self._msg = msg
|
||||
self._cmd = cmd
|
||||
self._chat = TelegramChatRoom(msg.chat)
|
||||
self._inputs: list[TelegramFileInput] | None = None
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
|
@ -124,7 +128,11 @@ class TelegramMessage(BaseMessage):
|
|||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
return self._msg.text[len(self._cmd) + 2:] # remove command name, slash and first space
|
||||
# remove command name, slash and first space
|
||||
if self._msg.text:
|
||||
return self._msg.text[len(self._cmd) + 2:]
|
||||
|
||||
return self._msg.caption[len(self._cmd) + 2:]
|
||||
|
||||
@property
|
||||
def author(self) -> TelegramUser:
|
||||
|
@ -136,13 +144,15 @@ class TelegramMessage(BaseMessage):
|
|||
|
||||
@property
|
||||
def inputs(self) -> list[TelegramFileInput]:
|
||||
if self._msg.photo:
|
||||
return [
|
||||
TelegramFileInput(photo=p)
|
||||
for p in self._msg.photo
|
||||
]
|
||||
if self._inputs is None:
|
||||
self._inputs = []
|
||||
if self._msg.photo:
|
||||
self._inputs = [
|
||||
TelegramFileInput(photo=p)
|
||||
for p in self._msg.photo
|
||||
]
|
||||
|
||||
return []
|
||||
return self._inputs
|
||||
|
||||
|
||||
# generic tg utils
|
||||
|
@ -242,7 +252,16 @@ class TelegramChatbot(BaseChatbot):
|
|||
append_handler(bot, BaseCommands.SAY, self.say)
|
||||
|
||||
append_handler(bot, BaseCommands.TXT2IMG, self.handle_request)
|
||||
|
||||
append_handler(bot, BaseCommands.IMG2IMG, self.handle_request)
|
||||
|
||||
@bot.message_handler(func=lambda _: True, content_types=['photo', 'document'])
|
||||
async def handle_img2img(tg_msg: TGMessage):
|
||||
msg = TelegramMessage(cmd='img2img', msg=tg_msg)
|
||||
for file in msg.inputs:
|
||||
await file.download(bot)
|
||||
await self.handle_request(msg)
|
||||
|
||||
append_handler(bot, BaseCommands.REDO, self.handle_request)
|
||||
|
||||
self.bot = bot
|
||||
|
@ -267,7 +286,7 @@ class TelegramChatbot(BaseChatbot):
|
|||
return TelegramMessage(cmd=None, msg=msg)
|
||||
|
||||
async def reply_to(self, msg: TelegramMessage, text: str) -> TelegramMessage:
|
||||
msg = await self.bot.reply_to(msg._msg, text)
|
||||
msg = await self.bot.reply_to(msg._msg, text, parse_mode='HTML')
|
||||
return TelegramMessage(cmd=None, msg=msg)
|
||||
|
||||
async def edit_msg(self, msg: TelegramMessage, text: str):
|
||||
|
@ -370,8 +389,8 @@ class TelegramChatbot(BaseChatbot):
|
|||
parse_mode='HTML'
|
||||
)
|
||||
|
||||
case 1:
|
||||
_input = inputs.pop()
|
||||
case _:
|
||||
_input = inputs[-1]
|
||||
await self.bot.send_media_group(
|
||||
status_msg.chat.id,
|
||||
media=[
|
||||
|
@ -379,6 +398,3 @@ class TelegramChatbot(BaseChatbot):
|
|||
InputMediaPhoto(result_img, caption=caption, parse_mode='HTML')
|
||||
]
|
||||
)
|
||||
|
||||
case _:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -3,6 +3,7 @@ import io
|
|||
from abc import ABC, abstractproperty, abstractmethod
|
||||
from enum import StrEnum
|
||||
from typing import Self
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
|
||||
from skynet.ipfs import AsyncIPFSHTTP
|
||||
|
@ -52,6 +53,10 @@ class BaseFileInput(ABC):
|
|||
async def download(self, *args) -> bytes:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def set_cid(self, cid: str):
|
||||
...
|
||||
|
||||
async def publish(self, ipfs_api: AsyncIPFSHTTP, user_row: dict):
|
||||
with Image.open(io.BytesIO(self._raw)) as img:
|
||||
w, h = img.size
|
||||
|
@ -63,11 +68,12 @@ class BaseFileInput(ABC):
|
|||
):
|
||||
img.thumbnail((user_row['width'], user_row['height']))
|
||||
|
||||
img_path = '/tmp/ipfs-staging/img.png'
|
||||
img_path = Path('/tmp/ipfs-staging/img.png')
|
||||
img.save(img_path, format='PNG')
|
||||
|
||||
ipfs_info = await ipfs_api.add(img_path)
|
||||
ipfs_hash = ipfs_info['Hash']
|
||||
self.set_cid(ipfs_hash)
|
||||
await ipfs_api.pin(ipfs_hash)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue