mirror of https://github.com/skygpu/skynet.git
Finish testing img2img
parent
fedbd95ca8
commit
83b44e5e69
|
@ -240,6 +240,14 @@ class BaseChatbot(ABC):
|
||||||
await self.reply_to(msg, f'unknown request of type {msg.command}')
|
await self.reply_to(msg, f'unknown request of type {msg.command}')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if (
|
||||||
|
msg.command == BaseCommands.IMG2IMG
|
||||||
|
and
|
||||||
|
len(inputs) == 0
|
||||||
|
):
|
||||||
|
await self.edit_msg(status_msg, 'seems you tried to do an img2img command without sending image')
|
||||||
|
return
|
||||||
|
|
||||||
# maybe apply recomended settings to this request
|
# maybe apply recomended settings to this request
|
||||||
del user_row['id']
|
del user_row['id']
|
||||||
if user_row['autoconf']:
|
if user_row['autoconf']:
|
||||||
|
@ -258,7 +266,7 @@ class BaseChatbot(ABC):
|
||||||
# publish inputs to ipfs
|
# publish inputs to ipfs
|
||||||
input_cids = []
|
input_cids = []
|
||||||
for i in inputs:
|
for i in inputs:
|
||||||
i.publish()
|
await i.publish(self.ipfs, user_row)
|
||||||
input_cids.append(i.cid)
|
input_cids.append(i.cid)
|
||||||
|
|
||||||
inputs_str = ','.join((i for i in input_cids))
|
inputs_str = ','.join((i for i in input_cids))
|
||||||
|
|
|
@ -101,6 +101,9 @@ class TelegramFileInput(BaseFileInput):
|
||||||
|
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
|
def set_cid(self, cid: str):
|
||||||
|
self._cid = cid
|
||||||
|
|
||||||
async def download(self, bot: AsyncTeleBot) -> bytes:
|
async def download(self, bot: AsyncTeleBot) -> bytes:
|
||||||
file_path = (await bot.get_file(self.id)).file_path
|
file_path = (await bot.get_file(self.id)).file_path
|
||||||
self._raw = await bot.download_file(file_path)
|
self._raw = await bot.download_file(file_path)
|
||||||
|
@ -113,6 +116,7 @@ class TelegramMessage(BaseMessage):
|
||||||
self._msg = msg
|
self._msg = msg
|
||||||
self._cmd = cmd
|
self._cmd = cmd
|
||||||
self._chat = TelegramChatRoom(msg.chat)
|
self._chat = TelegramChatRoom(msg.chat)
|
||||||
|
self._inputs: list[TelegramFileInput] | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id(self) -> int:
|
def id(self) -> int:
|
||||||
|
@ -124,7 +128,11 @@ class TelegramMessage(BaseMessage):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self) -> str:
|
def text(self) -> str:
|
||||||
return self._msg.text[len(self._cmd) + 2:] # remove command name, slash and first space
|
# remove command name, slash and first space
|
||||||
|
if self._msg.text:
|
||||||
|
return self._msg.text[len(self._cmd) + 2:]
|
||||||
|
|
||||||
|
return self._msg.caption[len(self._cmd) + 2:]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def author(self) -> TelegramUser:
|
def author(self) -> TelegramUser:
|
||||||
|
@ -136,13 +144,15 @@ class TelegramMessage(BaseMessage):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def inputs(self) -> list[TelegramFileInput]:
|
def inputs(self) -> list[TelegramFileInput]:
|
||||||
|
if self._inputs is None:
|
||||||
|
self._inputs = []
|
||||||
if self._msg.photo:
|
if self._msg.photo:
|
||||||
return [
|
self._inputs = [
|
||||||
TelegramFileInput(photo=p)
|
TelegramFileInput(photo=p)
|
||||||
for p in self._msg.photo
|
for p in self._msg.photo
|
||||||
]
|
]
|
||||||
|
|
||||||
return []
|
return self._inputs
|
||||||
|
|
||||||
|
|
||||||
# generic tg utils
|
# generic tg utils
|
||||||
|
@ -242,7 +252,16 @@ class TelegramChatbot(BaseChatbot):
|
||||||
append_handler(bot, BaseCommands.SAY, self.say)
|
append_handler(bot, BaseCommands.SAY, self.say)
|
||||||
|
|
||||||
append_handler(bot, BaseCommands.TXT2IMG, self.handle_request)
|
append_handler(bot, BaseCommands.TXT2IMG, self.handle_request)
|
||||||
|
|
||||||
append_handler(bot, BaseCommands.IMG2IMG, self.handle_request)
|
append_handler(bot, BaseCommands.IMG2IMG, self.handle_request)
|
||||||
|
|
||||||
|
@bot.message_handler(func=lambda _: True, content_types=['photo', 'document'])
|
||||||
|
async def handle_img2img(tg_msg: TGMessage):
|
||||||
|
msg = TelegramMessage(cmd='img2img', msg=tg_msg)
|
||||||
|
for file in msg.inputs:
|
||||||
|
await file.download(bot)
|
||||||
|
await self.handle_request(msg)
|
||||||
|
|
||||||
append_handler(bot, BaseCommands.REDO, self.handle_request)
|
append_handler(bot, BaseCommands.REDO, self.handle_request)
|
||||||
|
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
@ -267,7 +286,7 @@ class TelegramChatbot(BaseChatbot):
|
||||||
return TelegramMessage(cmd=None, msg=msg)
|
return TelegramMessage(cmd=None, msg=msg)
|
||||||
|
|
||||||
async def reply_to(self, msg: TelegramMessage, text: str) -> TelegramMessage:
|
async def reply_to(self, msg: TelegramMessage, text: str) -> TelegramMessage:
|
||||||
msg = await self.bot.reply_to(msg._msg, text)
|
msg = await self.bot.reply_to(msg._msg, text, parse_mode='HTML')
|
||||||
return TelegramMessage(cmd=None, msg=msg)
|
return TelegramMessage(cmd=None, msg=msg)
|
||||||
|
|
||||||
async def edit_msg(self, msg: TelegramMessage, text: str):
|
async def edit_msg(self, msg: TelegramMessage, text: str):
|
||||||
|
@ -370,8 +389,8 @@ class TelegramChatbot(BaseChatbot):
|
||||||
parse_mode='HTML'
|
parse_mode='HTML'
|
||||||
)
|
)
|
||||||
|
|
||||||
case 1:
|
case _:
|
||||||
_input = inputs.pop()
|
_input = inputs[-1]
|
||||||
await self.bot.send_media_group(
|
await self.bot.send_media_group(
|
||||||
status_msg.chat.id,
|
status_msg.chat.id,
|
||||||
media=[
|
media=[
|
||||||
|
@ -379,6 +398,3 @@ class TelegramChatbot(BaseChatbot):
|
||||||
InputMediaPhoto(result_img, caption=caption, parse_mode='HTML')
|
InputMediaPhoto(result_img, caption=caption, parse_mode='HTML')
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
case _:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ import io
|
||||||
from abc import ABC, abstractproperty, abstractmethod
|
from abc import ABC, abstractproperty, abstractmethod
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
from pathlib import Path
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from skynet.ipfs import AsyncIPFSHTTP
|
from skynet.ipfs import AsyncIPFSHTTP
|
||||||
|
@ -52,6 +53,10 @@ class BaseFileInput(ABC):
|
||||||
async def download(self, *args) -> bytes:
|
async def download(self, *args) -> bytes:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_cid(self, cid: str):
|
||||||
|
...
|
||||||
|
|
||||||
async def publish(self, ipfs_api: AsyncIPFSHTTP, user_row: dict):
|
async def publish(self, ipfs_api: AsyncIPFSHTTP, user_row: dict):
|
||||||
with Image.open(io.BytesIO(self._raw)) as img:
|
with Image.open(io.BytesIO(self._raw)) as img:
|
||||||
w, h = img.size
|
w, h = img.size
|
||||||
|
@ -63,11 +68,12 @@ class BaseFileInput(ABC):
|
||||||
):
|
):
|
||||||
img.thumbnail((user_row['width'], user_row['height']))
|
img.thumbnail((user_row['width'], user_row['height']))
|
||||||
|
|
||||||
img_path = '/tmp/ipfs-staging/img.png'
|
img_path = Path('/tmp/ipfs-staging/img.png')
|
||||||
img.save(img_path, format='PNG')
|
img.save(img_path, format='PNG')
|
||||||
|
|
||||||
ipfs_info = await ipfs_api.add(img_path)
|
ipfs_info = await ipfs_api.add(img_path)
|
||||||
ipfs_hash = ipfs_info['Hash']
|
ipfs_hash = ipfs_info['Hash']
|
||||||
|
self.set_cid(ipfs_hash)
|
||||||
await ipfs_api.pin(ipfs_hash)
|
await ipfs_api.pin(ipfs_hash)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue