Add redo support to img2img also switch pinner to use http api

add-txt2txt-models
Guillermo Rodriguez 2023-05-29 12:42:55 -03:00
parent 22c403d3ae
commit 2b18fa376b
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
3 changed files with 189 additions and 108 deletions

View File

@ -384,22 +384,66 @@ def telegram(
))
class IPFSHTTP:
def __init__(self, endpoint: str):
self.endpoint = endpoint
def pin(self, cid: str):
return requests.post(
f'{self.endpoint}/api/v0/pin/add',
params={'arg': cid}
)
@run.command()
@click.option('--loglevel', '-l', default='INFO', help='logging level')
@click.option(
'--container', '-c', default='ipfs_host')
'--ipfs-rpc', '-i', default='http://127.0.0.1:5001')
@click.option(
'--hyperion-url', '-n', default='http://127.0.0.1:42001')
def pinner(loglevel, container, hyperion_url):
def pinner(loglevel, ipfs_rpc, hyperion_url):
logging.basicConfig(level=loglevel)
dclient = docker.from_env()
container = dclient.containers.get(container)
ipfs_node = IPFSDocker(container)
ipfs_node = IPFSHTTP(ipfs_rpc)
hyperion = HyperionAPI(hyperion_url)
last_pinned: dict[str, datetime] = {}
def capture_enqueues(half_min_ago: datetime):
# get all enqueues with binary data
# in the last minute
enqueues = hyperion.get_actions(
account='telos.gpu',
filter='telos.gpu:enqueue',
sort='desc',
after=half_min_ago.isoformat()
)
cids = []
for action in enqueues['actions']:
cid = action['act']['data']['binary_data']
if cid and cid not in last_pinned:
cids.append(cid)
return cids
def capture_submits(half_min_ago: datetime):
# get all submits in the last minute
submits = hyperion.get_actions(
account='telos.gpu',
filter='telos.gpu:submit',
sort='desc',
after=half_min_ago.isoformat()
)
cids = []
for action in submits['actions']:
cid = action['act']['data']['ipfs_hash']
if cid and cid not in last_pinned:
cids.append(cid)
return cids
def cleanup_pinned(now: datetime):
for cid in set(last_pinned.keys()):
ts = last_pinned[cid]
@ -411,50 +455,23 @@ def pinner(loglevel, container, hyperion_url):
now = datetime.now()
half_min_ago = now - timedelta(seconds=30)
# get all enqueues with binary data
# in the last minute
enqueues = hyperion.get_actions(
account='telos.gpu',
filter='telos.gpu:enqueue',
sort='desc',
after=half_min_ago.isoformat()
)
# get all submits in the last minute
submits = hyperion.get_actions(
account='telos.gpu',
filter='telos.gpu:submit',
sort='desc',
after=half_min_ago.isoformat()
)
# filter for the ones not already pinned
cids = [
*[
action['act']['data']['binary_data']
for action in enqueues['actions']
if action['act']['data']['binary_data']
not in last_pinned
],
*[
action['act']['data']['ipfs_hash']
for action in submits['actions']
if action['act']['data']['ipfs_hash']
not in last_pinned
]
]
cids = [*capture_enqueues(half_min_ago), *capture_submits(half_min_ago)]
# pin and remember
for cid in cids:
last_pinned[cid] = now
ipfs_node.pin(cid)
resp = ipfs_node.pin(cid)
if resp.status_code != 200:
logging.error(f'error pinning {cid}:\n{resp.text}')
logging.info(f'pinned {cid}')
else:
logging.info(f'pinned {cid}')
cleanup_pinned(now)
time.sleep(1)
time.sleep(0.1)
except KeyboardInterrupt:
...

View File

@ -25,11 +25,14 @@ DB_INIT_SQL = '''
CREATE SCHEMA IF NOT EXISTS skynet;
CREATE TABLE IF NOT EXISTS skynet.user(
id SERIAL PRIMARY KEY NOT NULL,
generated INT NOT NULL,
joined TIMESTAMP NOT NULL,
last_prompt TEXT,
role VARCHAR(128) NOT NULL
id SERIAL PRIMARY KEY NOT NULL,
generated INT NOT NULL,
joined TIMESTAMP NOT NULL,
last_method TEXT,
last_prompt TEXT,
last_file TEXT,
last_binary TEXT,
role VARCHAR(128) NOT NULL
);
CREATE TABLE IF NOT EXISTS skynet.user_config(
@ -175,12 +178,26 @@ async def get_user_config(conn, user: int):
async def get_user(conn, uid: int):
return await get_user_config(conn, uid)
async def get_last_method_of(conn, user: int):
stmt = await conn.prepare(
'SELECT last_method FROM skynet.user WHERE id = $1')
return await stmt.fetchval(user)
async def get_last_prompt_of(conn, user: int):
stmt = await conn.prepare(
'SELECT last_prompt FROM skynet.user WHERE id = $1')
return await stmt.fetchval(user)
async def get_last_file_of(conn, user: int):
stmt = await conn.prepare(
'SELECT last_file FROM skynet.user WHERE id = $1')
return await stmt.fetchval(user)
async def get_last_binary_of(conn, user: int):
stmt = await conn.prepare(
'SELECT last_binary FROM skynet.user WHERE id = $1')
return await stmt.fetchval(user)
async def new_user(conn, uid: int):
if await get_user(conn, uid):
@ -192,12 +209,15 @@ async def new_user(conn, uid: int):
async with conn.transaction():
stmt = await conn.prepare('''
INSERT INTO skynet.user(
id, generated, joined, last_prompt, role)
id, generated, joined,
last_method, last_prompt, last_file, last_binary,
role
)
VALUES($1, $2, $3, $4, $5)
VALUES($1, $2, $3, $4, $5, $6, $7, $8)
''')
await stmt.fetch(
uid, 0, date, None, DEFAULT_ROLE
uid, 0, date, 'txt2img', None, None, None, DEFAULT_ROLE
)
stmt = await conn.prepare('''
@ -222,7 +242,8 @@ async def get_or_create_user(conn, uid: str):
user = await get_user(conn, uid)
if not user:
user = await new_user(conn, uid)
await new_user(conn, uid)
user = await get_user(conn, uid)
return user
@ -253,11 +274,7 @@ async def get_user_stats(conn, user: int):
record = records[0]
return record
async def update_user_stats(
conn,
user: int,
last_prompt: Optional[str] = None
):
async def increment_generated(conn, user: int):
stmt = await conn.prepare('''
UPDATE skynet.user
SET generated = generated + 1
@ -265,5 +282,20 @@ async def update_user_stats(
''')
await stmt.fetch(user)
async def update_user_stats(
conn,
user: int,
method: str,
last_prompt: str | None = None,
last_file: str | None = None,
last_binary: str | None = None
):
await update_user(conn, user, 'last_method', method)
if last_prompt:
await update_user(conn, user, 'last_prompt', last_prompt)
if last_file:
await update_user(conn, user, 'last_file', last_file)
if last_binary:
await update_user(conn, user, 'last_binary', last_binary)
logging.info((method, last_prompt, last_binary))

View File

@ -128,9 +128,8 @@ async def work_request(
account: str,
permission: str,
params: dict,
ipfs_node,
file_id: str | None = None,
file_path: str | None = None
binary_data: str = ''
):
if params['seed'] == None:
params['seed'] = random.randint(0, 9e18)
@ -141,30 +140,8 @@ async def work_request(
})
request_time = datetime.now().isoformat()
if file_id:
image_raw = await bot.download_file(file_path)
with Image.open(io.BytesIO(image_raw)) as image:
w, h = image.size
if w > 512 or h > 512:
logging.warning(f'user sent img of size {image.size}')
image.thumbnail((512, 512))
logging.warning(f'resized it to {image.size}')
image.save(f'ipfs-docker-staging/image.png', format='PNG')
ipfs_hash = ipfs_node.add('image.png')
ipfs_node.pin(ipfs_hash)
logging.info(f'published input image {ipfs_hash} on ipfs')
binary = ipfs_hash
else:
binary = ''
ec, out = cleos.push_action(
'telos.gpu', 'enqueue', [account, body, binary, '20.0000 GPU'], f'{account}@{permission}'
'telos.gpu', 'enqueue', [account, body, binary_data, '20.0000 GPU'], f'{account}@{permission}'
)
out = collect_stdout(out)
if ec != 0:
@ -173,7 +150,7 @@ async def work_request(
nonce = await get_user_nonce(cleos, account)
request_hash = sha256(
(str(nonce) + body + binary).encode('utf-8')).hexdigest().upper()
(str(nonce) + body + binary_data).encode('utf-8')).hexdigest().upper()
request_id = int(out)
logging.info(f'{request_id} enqueued.')
@ -209,13 +186,13 @@ async def work_request(
return
# attempt to get the image and send it
resp = await get_ipfs_file(
f'http://test1.us.telos.net:8080/ipfs/{ipfs_hash}/image.png')
ipfs_link = f'http://test1.us.telos.net:8080/ipfs/{ipfs_hash}/image.png'
resp = await get_ipfs_file(ipfs_link)
caption = generate_reply_caption(
user, params, ipfs_hash, tx_hash, worker)
if resp.status_code != 200:
if not resp or resp.status_code != 200:
logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!')
await bot.reply_to(
message,
@ -233,11 +210,10 @@ async def work_request(
InputMediaPhoto(file_id),
InputMediaPhoto(
resp.raw,
caption=caption
caption=caption,
parse_mode='HTML'
)
],
reply_markup=build_redo_menu(),
parse_mode='HTML'
)
else: # txt2img
@ -307,10 +283,18 @@ async def run_skynet_telegram(
async def send_cool_words(message):
await bot.reply_to(message, '\n'.join(COOL_WORDS))
@bot.message_handler(commands=['txt2img'])
async def send_txt2img(message):
user = message.from_user
chat = message.chat
async def _generic_txt2img(message_or_query):
if isinstance(message_or_query, CallbackQuery):
query = message_or_query
message = query.message
user = query.from_user
chat = query.message.chat
else:
message = message_or_query
user = message.from_user
chat = message.chat
reply_id = None
if chat.type == 'group' and chat.id == GROUP_ID:
reply_id = message.message_id
@ -332,19 +316,30 @@ async def run_skynet_telegram(
**user_config
}
await db_call('update_user_stats', user.id, last_prompt=prompt)
await db_call(
'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
await work_request(
ec = await work_request(
bot, cleos, hyperion,
message, user, chat,
account, permission, params,
ipfs_node
account, permission, params
)
@bot.message_handler(func=lambda message: True, content_types=['photo'])
async def send_img2img(message):
user = message.from_user
chat = message.chat
if ec == 0:
await db_call('increment_generated', user.id)
async def _generic_img2img(message_or_query):
if isinstance(message_or_query, CallbackQuery):
query = message_or_query
message = query.message
user = query.from_user
chat = query.message.chat
else:
message = message_or_query
user = message.from_user
chat = message.chat
reply_id = None
if chat.type == 'group' and chat.id == GROUP_ID:
reply_id = message.message_id
@ -364,6 +359,22 @@ async def run_skynet_telegram(
file_id = message.photo[-1].file_id
file_path = (await bot.get_file(file_id)).file_path
image_raw = await bot.download_file(file_path)
with Image.open(io.BytesIO(image_raw)) as image:
w, h = image.size
if w > 512 or h > 512:
logging.warning(f'user sent img of size {image.size}')
image.thumbnail((512, 512))
logging.warning(f'resized it to {image.size}')
image.save(f'ipfs-docker-staging/image.png', format='PNG')
ipfs_hash = ipfs_node.add('image.png')
ipfs_node.pin(ipfs_hash)
logging.info(f'published input image {ipfs_hash} on ipfs')
logging.info(f'mid: {message.id}')
@ -376,16 +387,34 @@ async def run_skynet_telegram(
**user_config
}
await db_call('update_user_stats', user.id, last_prompt=prompt)
await db_call(
'update_user_stats',
user.id,
'img2img',
last_file=file_id,
last_prompt=prompt,
last_binary=ipfs_hash
)
await work_request(
ec = await work_request(
bot, cleos, hyperion,
message, user, chat,
account, permission, params,
ipfs_node,
file_id=file_id, file_path=file_path
file_id=file_id,
binary_data=ipfs_hash
)
if ec == 0:
await db_call('increment_generated', user.id)
@bot.message_handler(commands=['txt2img'])
async def send_txt2img(message):
await _generic_txt2img(message)
@bot.message_handler(func=lambda message: True, content_types=[
'photo', 'document'])
async def send_img2img(message):
await _generic_img2img(message)
@bot.message_handler(commands=['img2img'])
async def img2img_missing_image(message):
@ -406,12 +435,15 @@ async def run_skynet_telegram(
user = message.from_user
chat = message.chat
reply_id = None
if chat.type == 'group' and chat.id == GROUP_ID:
reply_id = message.message_id
method = await db_call('get_last_method_of', user.id)
prompt = await db_call('get_last_prompt_of', user.id)
file_id = None
binary = ''
if method == 'img2img':
file_id = await db_call('get_last_file_of', user.id)
binary = await db_call('get_last_binary_of', user.id)
if not prompt:
await bot.reply_to(
message,
@ -433,7 +465,8 @@ async def run_skynet_telegram(
bot, cleos, hyperion,
message, user, chat,
account, permission, params,
ipfs_node
file_id=file_id,
binary_data=binary
)
@bot.message_handler(commands=['redo'])
@ -485,7 +518,6 @@ async def run_skynet_telegram(
await bot.send_message(GROUP_ID, message.text[4:])
@bot.message_handler(func=lambda message: True)
async def echo_message(message):
if message.text[0] == '/':