Add redo support to img2img also switch pinner to use http api

add-txt2txt-models
Guillermo Rodriguez 2023-05-29 12:42:55 -03:00
parent 22c403d3ae
commit 2b18fa376b
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
3 changed files with 189 additions and 108 deletions

View File

@ -384,22 +384,66 @@ def telegram(
)) ))
class IPFSHTTP:
def __init__(self, endpoint: str):
self.endpoint = endpoint
def pin(self, cid: str):
return requests.post(
f'{self.endpoint}/api/v0/pin/add',
params={'arg': cid}
)
@run.command() @run.command()
@click.option('--loglevel', '-l', default='INFO', help='logging level') @click.option('--loglevel', '-l', default='INFO', help='logging level')
@click.option( @click.option(
'--container', '-c', default='ipfs_host') '--ipfs-rpc', '-i', default='http://127.0.0.1:5001')
@click.option( @click.option(
'--hyperion-url', '-n', default='http://127.0.0.1:42001') '--hyperion-url', '-n', default='http://127.0.0.1:42001')
def pinner(loglevel, container, hyperion_url): def pinner(loglevel, ipfs_rpc, hyperion_url):
logging.basicConfig(level=loglevel) logging.basicConfig(level=loglevel)
dclient = docker.from_env() ipfs_node = IPFSHTTP(ipfs_rpc)
container = dclient.containers.get(container)
ipfs_node = IPFSDocker(container)
hyperion = HyperionAPI(hyperion_url) hyperion = HyperionAPI(hyperion_url)
last_pinned: dict[str, datetime] = {} last_pinned: dict[str, datetime] = {}
def capture_enqueues(half_min_ago: datetime):
# get all enqueues with binary data
# in the last minute
enqueues = hyperion.get_actions(
account='telos.gpu',
filter='telos.gpu:enqueue',
sort='desc',
after=half_min_ago.isoformat()
)
cids = []
for action in enqueues['actions']:
cid = action['act']['data']['binary_data']
if cid and cid not in last_pinned:
cids.append(cid)
return cids
def capture_submits(half_min_ago: datetime):
# get all submits in the last minute
submits = hyperion.get_actions(
account='telos.gpu',
filter='telos.gpu:submit',
sort='desc',
after=half_min_ago.isoformat()
)
cids = []
for action in submits['actions']:
cid = action['act']['data']['ipfs_hash']
if cid and cid not in last_pinned:
cids.append(cid)
return cids
def cleanup_pinned(now: datetime): def cleanup_pinned(now: datetime):
for cid in set(last_pinned.keys()): for cid in set(last_pinned.keys()):
ts = last_pinned[cid] ts = last_pinned[cid]
@ -411,50 +455,23 @@ def pinner(loglevel, container, hyperion_url):
now = datetime.now() now = datetime.now()
half_min_ago = now - timedelta(seconds=30) half_min_ago = now - timedelta(seconds=30)
# get all enqueues with binary data
# in the last minute
enqueues = hyperion.get_actions(
account='telos.gpu',
filter='telos.gpu:enqueue',
sort='desc',
after=half_min_ago.isoformat()
)
# get all submits in the last minute
submits = hyperion.get_actions(
account='telos.gpu',
filter='telos.gpu:submit',
sort='desc',
after=half_min_ago.isoformat()
)
# filter for the ones not already pinned # filter for the ones not already pinned
cids = [ cids = [*capture_enqueues(half_min_ago), *capture_submits(half_min_ago)]
*[
action['act']['data']['binary_data']
for action in enqueues['actions']
if action['act']['data']['binary_data']
not in last_pinned
],
*[
action['act']['data']['ipfs_hash']
for action in submits['actions']
if action['act']['data']['ipfs_hash']
not in last_pinned
]
]
# pin and remember # pin and remember
for cid in cids: for cid in cids:
last_pinned[cid] = now last_pinned[cid] = now
ipfs_node.pin(cid) resp = ipfs_node.pin(cid)
if resp.status_code != 200:
logging.error(f'error pinning {cid}:\n{resp.text}')
logging.info(f'pinned {cid}') else:
logging.info(f'pinned {cid}')
cleanup_pinned(now) cleanup_pinned(now)
time.sleep(1) time.sleep(0.1)
except KeyboardInterrupt: except KeyboardInterrupt:
... ...

View File

@ -25,11 +25,14 @@ DB_INIT_SQL = '''
CREATE SCHEMA IF NOT EXISTS skynet; CREATE SCHEMA IF NOT EXISTS skynet;
CREATE TABLE IF NOT EXISTS skynet.user( CREATE TABLE IF NOT EXISTS skynet.user(
id SERIAL PRIMARY KEY NOT NULL, id SERIAL PRIMARY KEY NOT NULL,
generated INT NOT NULL, generated INT NOT NULL,
joined TIMESTAMP NOT NULL, joined TIMESTAMP NOT NULL,
last_prompt TEXT, last_method TEXT,
role VARCHAR(128) NOT NULL last_prompt TEXT,
last_file TEXT,
last_binary TEXT,
role VARCHAR(128) NOT NULL
); );
CREATE TABLE IF NOT EXISTS skynet.user_config( CREATE TABLE IF NOT EXISTS skynet.user_config(
@ -175,12 +178,26 @@ async def get_user_config(conn, user: int):
async def get_user(conn, uid: int): async def get_user(conn, uid: int):
return await get_user_config(conn, uid) return await get_user_config(conn, uid)
async def get_last_method_of(conn, user: int):
stmt = await conn.prepare(
'SELECT last_method FROM skynet.user WHERE id = $1')
return await stmt.fetchval(user)
async def get_last_prompt_of(conn, user: int): async def get_last_prompt_of(conn, user: int):
stmt = await conn.prepare( stmt = await conn.prepare(
'SELECT last_prompt FROM skynet.user WHERE id = $1') 'SELECT last_prompt FROM skynet.user WHERE id = $1')
return await stmt.fetchval(user) return await stmt.fetchval(user)
async def get_last_file_of(conn, user: int):
stmt = await conn.prepare(
'SELECT last_file FROM skynet.user WHERE id = $1')
return await stmt.fetchval(user)
async def get_last_binary_of(conn, user: int):
stmt = await conn.prepare(
'SELECT last_binary FROM skynet.user WHERE id = $1')
return await stmt.fetchval(user)
async def new_user(conn, uid: int): async def new_user(conn, uid: int):
if await get_user(conn, uid): if await get_user(conn, uid):
@ -192,12 +209,15 @@ async def new_user(conn, uid: int):
async with conn.transaction(): async with conn.transaction():
stmt = await conn.prepare(''' stmt = await conn.prepare('''
INSERT INTO skynet.user( INSERT INTO skynet.user(
id, generated, joined, last_prompt, role) id, generated, joined,
last_method, last_prompt, last_file, last_binary,
role
)
VALUES($1, $2, $3, $4, $5) VALUES($1, $2, $3, $4, $5, $6, $7, $8)
''') ''')
await stmt.fetch( await stmt.fetch(
uid, 0, date, None, DEFAULT_ROLE uid, 0, date, 'txt2img', None, None, None, DEFAULT_ROLE
) )
stmt = await conn.prepare(''' stmt = await conn.prepare('''
@ -222,7 +242,8 @@ async def get_or_create_user(conn, uid: str):
user = await get_user(conn, uid) user = await get_user(conn, uid)
if not user: if not user:
user = await new_user(conn, uid) await new_user(conn, uid)
user = await get_user(conn, uid)
return user return user
@ -253,11 +274,7 @@ async def get_user_stats(conn, user: int):
record = records[0] record = records[0]
return record return record
async def update_user_stats( async def increment_generated(conn, user: int):
conn,
user: int,
last_prompt: Optional[str] = None
):
stmt = await conn.prepare(''' stmt = await conn.prepare('''
UPDATE skynet.user UPDATE skynet.user
SET generated = generated + 1 SET generated = generated + 1
@ -265,5 +282,20 @@ async def update_user_stats(
''') ''')
await stmt.fetch(user) await stmt.fetch(user)
async def update_user_stats(
conn,
user: int,
method: str,
last_prompt: str | None = None,
last_file: str | None = None,
last_binary: str | None = None
):
await update_user(conn, user, 'last_method', method)
if last_prompt: if last_prompt:
await update_user(conn, user, 'last_prompt', last_prompt) await update_user(conn, user, 'last_prompt', last_prompt)
if last_file:
await update_user(conn, user, 'last_file', last_file)
if last_binary:
await update_user(conn, user, 'last_binary', last_binary)
logging.info((method, last_prompt, last_binary))

View File

@ -128,9 +128,8 @@ async def work_request(
account: str, account: str,
permission: str, permission: str,
params: dict, params: dict,
ipfs_node,
file_id: str | None = None, file_id: str | None = None,
file_path: str | None = None binary_data: str = ''
): ):
if params['seed'] == None: if params['seed'] == None:
params['seed'] = random.randint(0, 9e18) params['seed'] = random.randint(0, 9e18)
@ -141,30 +140,8 @@ async def work_request(
}) })
request_time = datetime.now().isoformat() request_time = datetime.now().isoformat()
if file_id:
image_raw = await bot.download_file(file_path)
with Image.open(io.BytesIO(image_raw)) as image:
w, h = image.size
if w > 512 or h > 512:
logging.warning(f'user sent img of size {image.size}')
image.thumbnail((512, 512))
logging.warning(f'resized it to {image.size}')
image.save(f'ipfs-docker-staging/image.png', format='PNG')
ipfs_hash = ipfs_node.add('image.png')
ipfs_node.pin(ipfs_hash)
logging.info(f'published input image {ipfs_hash} on ipfs')
binary = ipfs_hash
else:
binary = ''
ec, out = cleos.push_action( ec, out = cleos.push_action(
'telos.gpu', 'enqueue', [account, body, binary, '20.0000 GPU'], f'{account}@{permission}' 'telos.gpu', 'enqueue', [account, body, binary_data, '20.0000 GPU'], f'{account}@{permission}'
) )
out = collect_stdout(out) out = collect_stdout(out)
if ec != 0: if ec != 0:
@ -173,7 +150,7 @@ async def work_request(
nonce = await get_user_nonce(cleos, account) nonce = await get_user_nonce(cleos, account)
request_hash = sha256( request_hash = sha256(
(str(nonce) + body + binary).encode('utf-8')).hexdigest().upper() (str(nonce) + body + binary_data).encode('utf-8')).hexdigest().upper()
request_id = int(out) request_id = int(out)
logging.info(f'{request_id} enqueued.') logging.info(f'{request_id} enqueued.')
@ -209,13 +186,13 @@ async def work_request(
return return
# attempt to get the image and send it # attempt to get the image and send it
resp = await get_ipfs_file( ipfs_link = f'http://test1.us.telos.net:8080/ipfs/{ipfs_hash}/image.png'
f'http://test1.us.telos.net:8080/ipfs/{ipfs_hash}/image.png') resp = await get_ipfs_file(ipfs_link)
caption = generate_reply_caption( caption = generate_reply_caption(
user, params, ipfs_hash, tx_hash, worker) user, params, ipfs_hash, tx_hash, worker)
if resp.status_code != 200: if not resp or resp.status_code != 200:
logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!') logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!')
await bot.reply_to( await bot.reply_to(
message, message,
@ -233,11 +210,10 @@ async def work_request(
InputMediaPhoto(file_id), InputMediaPhoto(file_id),
InputMediaPhoto( InputMediaPhoto(
resp.raw, resp.raw,
caption=caption caption=caption,
parse_mode='HTML'
) )
], ],
reply_markup=build_redo_menu(),
parse_mode='HTML'
) )
else: # txt2img else: # txt2img
@ -307,10 +283,18 @@ async def run_skynet_telegram(
async def send_cool_words(message): async def send_cool_words(message):
await bot.reply_to(message, '\n'.join(COOL_WORDS)) await bot.reply_to(message, '\n'.join(COOL_WORDS))
@bot.message_handler(commands=['txt2img']) async def _generic_txt2img(message_or_query):
async def send_txt2img(message): if isinstance(message_or_query, CallbackQuery):
user = message.from_user query = message_or_query
chat = message.chat message = query.message
user = query.from_user
chat = query.message.chat
else:
message = message_or_query
user = message.from_user
chat = message.chat
reply_id = None reply_id = None
if chat.type == 'group' and chat.id == GROUP_ID: if chat.type == 'group' and chat.id == GROUP_ID:
reply_id = message.message_id reply_id = message.message_id
@ -332,19 +316,30 @@ async def run_skynet_telegram(
**user_config **user_config
} }
await db_call('update_user_stats', user.id, last_prompt=prompt) await db_call(
'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
await work_request( ec = await work_request(
bot, cleos, hyperion, bot, cleos, hyperion,
message, user, chat, message, user, chat,
account, permission, params, account, permission, params
ipfs_node
) )
@bot.message_handler(func=lambda message: True, content_types=['photo']) if ec == 0:
async def send_img2img(message): await db_call('increment_generated', user.id)
user = message.from_user
chat = message.chat async def _generic_img2img(message_or_query):
if isinstance(message_or_query, CallbackQuery):
query = message_or_query
message = query.message
user = query.from_user
chat = query.message.chat
else:
message = message_or_query
user = message.from_user
chat = message.chat
reply_id = None reply_id = None
if chat.type == 'group' and chat.id == GROUP_ID: if chat.type == 'group' and chat.id == GROUP_ID:
reply_id = message.message_id reply_id = message.message_id
@ -364,6 +359,22 @@ async def run_skynet_telegram(
file_id = message.photo[-1].file_id file_id = message.photo[-1].file_id
file_path = (await bot.get_file(file_id)).file_path file_path = (await bot.get_file(file_id)).file_path
image_raw = await bot.download_file(file_path)
with Image.open(io.BytesIO(image_raw)) as image:
w, h = image.size
if w > 512 or h > 512:
logging.warning(f'user sent img of size {image.size}')
image.thumbnail((512, 512))
logging.warning(f'resized it to {image.size}')
image.save(f'ipfs-docker-staging/image.png', format='PNG')
ipfs_hash = ipfs_node.add('image.png')
ipfs_node.pin(ipfs_hash)
logging.info(f'published input image {ipfs_hash} on ipfs')
logging.info(f'mid: {message.id}') logging.info(f'mid: {message.id}')
@ -376,16 +387,34 @@ async def run_skynet_telegram(
**user_config **user_config
} }
await db_call('update_user_stats', user.id, last_prompt=prompt) await db_call(
'update_user_stats',
user.id,
'img2img',
last_file=file_id,
last_prompt=prompt,
last_binary=ipfs_hash
)
await work_request( ec = await work_request(
bot, cleos, hyperion, bot, cleos, hyperion,
message, user, chat, message, user, chat,
account, permission, params, account, permission, params,
ipfs_node, file_id=file_id,
file_id=file_id, file_path=file_path binary_data=ipfs_hash
) )
if ec == 0:
await db_call('increment_generated', user.id)
@bot.message_handler(commands=['txt2img'])
async def send_txt2img(message):
await _generic_txt2img(message)
@bot.message_handler(func=lambda message: True, content_types=[
'photo', 'document'])
async def send_img2img(message):
await _generic_img2img(message)
@bot.message_handler(commands=['img2img']) @bot.message_handler(commands=['img2img'])
async def img2img_missing_image(message): async def img2img_missing_image(message):
@ -406,12 +435,15 @@ async def run_skynet_telegram(
user = message.from_user user = message.from_user
chat = message.chat chat = message.chat
reply_id = None method = await db_call('get_last_method_of', user.id)
if chat.type == 'group' and chat.id == GROUP_ID:
reply_id = message.message_id
prompt = await db_call('get_last_prompt_of', user.id) prompt = await db_call('get_last_prompt_of', user.id)
file_id = None
binary = ''
if method == 'img2img':
file_id = await db_call('get_last_file_of', user.id)
binary = await db_call('get_last_binary_of', user.id)
if not prompt: if not prompt:
await bot.reply_to( await bot.reply_to(
message, message,
@ -433,7 +465,8 @@ async def run_skynet_telegram(
bot, cleos, hyperion, bot, cleos, hyperion,
message, user, chat, message, user, chat,
account, permission, params, account, permission, params,
ipfs_node file_id=file_id,
binary_data=binary
) )
@bot.message_handler(commands=['redo']) @bot.message_handler(commands=['redo'])
@ -485,7 +518,6 @@ async def run_skynet_telegram(
await bot.send_message(GROUP_ID, message.text[4:]) await bot.send_message(GROUP_ID, message.text[4:])
@bot.message_handler(func=lambda message: True) @bot.message_handler(func=lambda message: True)
async def echo_message(message): async def echo_message(message):
if message.text[0] == '/': if message.text[0] == '/':