mirror of https://github.com/skygpu/skynet.git
				
				
				
			Add redo support to img2img also switch pinner to use http api
							parent
							
								
									22c403d3ae
								
							
						
					
					
						commit
						2b18fa376b
					
				| 
						 | 
					@ -384,22 +384,66 @@ def telegram(
 | 
				
			||||||
    ))
 | 
					    ))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class IPFSHTTP:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, endpoint: str):
 | 
				
			||||||
 | 
					        self.endpoint = endpoint
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def pin(self, cid: str):
 | 
				
			||||||
 | 
					        return requests.post(
 | 
				
			||||||
 | 
					            f'{self.endpoint}/api/v0/pin/add',
 | 
				
			||||||
 | 
					            params={'arg': cid}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@run.command()
 | 
					@run.command()
 | 
				
			||||||
@click.option('--loglevel', '-l', default='INFO', help='logging level')
 | 
					@click.option('--loglevel', '-l', default='INFO', help='logging level')
 | 
				
			||||||
@click.option(
 | 
					@click.option(
 | 
				
			||||||
    '--container', '-c', default='ipfs_host')
 | 
					    '--ipfs-rpc', '-i', default='http://127.0.0.1:5001')
 | 
				
			||||||
@click.option(
 | 
					@click.option(
 | 
				
			||||||
    '--hyperion-url', '-n', default='http://127.0.0.1:42001')
 | 
					    '--hyperion-url', '-n', default='http://127.0.0.1:42001')
 | 
				
			||||||
def pinner(loglevel, container, hyperion_url):
 | 
					def pinner(loglevel, ipfs_rpc, hyperion_url):
 | 
				
			||||||
    logging.basicConfig(level=loglevel)
 | 
					    logging.basicConfig(level=loglevel)
 | 
				
			||||||
    dclient = docker.from_env()
 | 
					    ipfs_node = IPFSHTTP(ipfs_rpc)
 | 
				
			||||||
 | 
					 | 
				
			||||||
    container = dclient.containers.get(container)
 | 
					 | 
				
			||||||
    ipfs_node = IPFSDocker(container)
 | 
					 | 
				
			||||||
    hyperion = HyperionAPI(hyperion_url)
 | 
					    hyperion = HyperionAPI(hyperion_url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    last_pinned: dict[str, datetime] = {}
 | 
					    last_pinned: dict[str, datetime] = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def capture_enqueues(half_min_ago: datetime):
 | 
				
			||||||
 | 
					        # get all enqueues with binary data
 | 
				
			||||||
 | 
					        # in the last minute
 | 
				
			||||||
 | 
					        enqueues = hyperion.get_actions(
 | 
				
			||||||
 | 
					            account='telos.gpu',
 | 
				
			||||||
 | 
					            filter='telos.gpu:enqueue',
 | 
				
			||||||
 | 
					            sort='desc',
 | 
				
			||||||
 | 
					            after=half_min_ago.isoformat()
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cids = []
 | 
				
			||||||
 | 
					        for action in enqueues['actions']:
 | 
				
			||||||
 | 
					            cid = action['act']['data']['binary_data']
 | 
				
			||||||
 | 
					            if cid and cid not in last_pinned:
 | 
				
			||||||
 | 
					                cids.append(cid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return cids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def capture_submits(half_min_ago: datetime):
 | 
				
			||||||
 | 
					        # get all submits in the last minute
 | 
				
			||||||
 | 
					        submits = hyperion.get_actions(
 | 
				
			||||||
 | 
					            account='telos.gpu',
 | 
				
			||||||
 | 
					            filter='telos.gpu:submit',
 | 
				
			||||||
 | 
					            sort='desc',
 | 
				
			||||||
 | 
					            after=half_min_ago.isoformat()
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cids = []
 | 
				
			||||||
 | 
					        for action in submits['actions']:
 | 
				
			||||||
 | 
					            cid = action['act']['data']['ipfs_hash']
 | 
				
			||||||
 | 
					            if cid and cid not in last_pinned:
 | 
				
			||||||
 | 
					                cids.append(cid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return cids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def cleanup_pinned(now: datetime):
 | 
					    def cleanup_pinned(now: datetime):
 | 
				
			||||||
        for cid in set(last_pinned.keys()):
 | 
					        for cid in set(last_pinned.keys()):
 | 
				
			||||||
            ts = last_pinned[cid]
 | 
					            ts = last_pinned[cid]
 | 
				
			||||||
| 
						 | 
					@ -411,50 +455,23 @@ def pinner(loglevel, container, hyperion_url):
 | 
				
			||||||
            now = datetime.now()
 | 
					            now = datetime.now()
 | 
				
			||||||
            half_min_ago = now - timedelta(seconds=30)
 | 
					            half_min_ago = now - timedelta(seconds=30)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # get all enqueues with binary data
 | 
					 | 
				
			||||||
            # in the last minute
 | 
					 | 
				
			||||||
            enqueues = hyperion.get_actions(
 | 
					 | 
				
			||||||
                account='telos.gpu',
 | 
					 | 
				
			||||||
                filter='telos.gpu:enqueue',
 | 
					 | 
				
			||||||
                sort='desc',
 | 
					 | 
				
			||||||
                after=half_min_ago.isoformat()
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # get all submits in the last minute
 | 
					 | 
				
			||||||
            submits = hyperion.get_actions(
 | 
					 | 
				
			||||||
                account='telos.gpu',
 | 
					 | 
				
			||||||
                filter='telos.gpu:submit',
 | 
					 | 
				
			||||||
                sort='desc',
 | 
					 | 
				
			||||||
                after=half_min_ago.isoformat()
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # filter for the ones not already pinned
 | 
					            # filter for the ones not already pinned
 | 
				
			||||||
            cids = [
 | 
					            cids = [*capture_enqueues(half_min_ago), *capture_submits(half_min_ago)]
 | 
				
			||||||
                *[
 | 
					 | 
				
			||||||
                    action['act']['data']['binary_data']
 | 
					 | 
				
			||||||
                    for action in enqueues['actions']
 | 
					 | 
				
			||||||
                    if action['act']['data']['binary_data']
 | 
					 | 
				
			||||||
                    not in last_pinned
 | 
					 | 
				
			||||||
                ],
 | 
					 | 
				
			||||||
                *[
 | 
					 | 
				
			||||||
                    action['act']['data']['ipfs_hash']
 | 
					 | 
				
			||||||
                    for action in submits['actions']
 | 
					 | 
				
			||||||
                    if action['act']['data']['ipfs_hash']
 | 
					 | 
				
			||||||
                    not in last_pinned
 | 
					 | 
				
			||||||
                ]
 | 
					 | 
				
			||||||
            ]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # pin and remember
 | 
					            # pin and remember
 | 
				
			||||||
            for cid in cids:
 | 
					            for cid in cids:
 | 
				
			||||||
                last_pinned[cid] = now
 | 
					                last_pinned[cid] = now
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                ipfs_node.pin(cid)
 | 
					                resp = ipfs_node.pin(cid)
 | 
				
			||||||
 | 
					                if resp.status_code != 200:
 | 
				
			||||||
 | 
					                    logging.error(f'error pinning {cid}:\n{resp.text}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
                    logging.info(f'pinned {cid}')
 | 
					                    logging.info(f'pinned {cid}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            cleanup_pinned(now)
 | 
					            cleanup_pinned(now)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            time.sleep(1)
 | 
					            time.sleep(0.1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    except KeyboardInterrupt:
 | 
					    except KeyboardInterrupt:
 | 
				
			||||||
        ...
 | 
					        ...
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -28,7 +28,10 @@ CREATE TABLE IF NOT EXISTS skynet.user(
 | 
				
			||||||
    id SERIAL PRIMARY KEY NOT NULL,
 | 
					    id SERIAL PRIMARY KEY NOT NULL,
 | 
				
			||||||
    generated INT NOT NULL,
 | 
					    generated INT NOT NULL,
 | 
				
			||||||
    joined TIMESTAMP NOT NULL,
 | 
					    joined TIMESTAMP NOT NULL,
 | 
				
			||||||
 | 
					    last_method TEXT,
 | 
				
			||||||
    last_prompt TEXT,
 | 
					    last_prompt TEXT,
 | 
				
			||||||
 | 
					    last_file   TEXT,
 | 
				
			||||||
 | 
					    last_binary TEXT,
 | 
				
			||||||
    role VARCHAR(128) NOT NULL
 | 
					    role VARCHAR(128) NOT NULL
 | 
				
			||||||
);
 | 
					);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -175,12 +178,26 @@ async def get_user_config(conn, user: int):
 | 
				
			||||||
async def get_user(conn, uid: int):
 | 
					async def get_user(conn, uid: int):
 | 
				
			||||||
    return await get_user_config(conn, uid)
 | 
					    return await get_user_config(conn, uid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def get_last_method_of(conn, user: int):
 | 
				
			||||||
 | 
					    stmt = await conn.prepare(
 | 
				
			||||||
 | 
					        'SELECT last_method FROM skynet.user WHERE id = $1')
 | 
				
			||||||
 | 
					    return await stmt.fetchval(user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def get_last_prompt_of(conn, user: int):
 | 
					async def get_last_prompt_of(conn, user: int):
 | 
				
			||||||
    stmt = await conn.prepare(
 | 
					    stmt = await conn.prepare(
 | 
				
			||||||
        'SELECT last_prompt FROM skynet.user WHERE id = $1')
 | 
					        'SELECT last_prompt FROM skynet.user WHERE id = $1')
 | 
				
			||||||
    return await stmt.fetchval(user)
 | 
					    return await stmt.fetchval(user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def get_last_file_of(conn, user: int):
 | 
				
			||||||
 | 
					    stmt = await conn.prepare(
 | 
				
			||||||
 | 
					        'SELECT last_file FROM skynet.user WHERE id = $1')
 | 
				
			||||||
 | 
					    return await stmt.fetchval(user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def get_last_binary_of(conn, user: int):
 | 
				
			||||||
 | 
					    stmt = await conn.prepare(
 | 
				
			||||||
 | 
					        'SELECT last_binary FROM skynet.user WHERE id = $1')
 | 
				
			||||||
 | 
					    return await stmt.fetchval(user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def new_user(conn, uid: int):
 | 
					async def new_user(conn, uid: int):
 | 
				
			||||||
    if await get_user(conn, uid):
 | 
					    if await get_user(conn, uid):
 | 
				
			||||||
| 
						 | 
					@ -192,12 +209,15 @@ async def new_user(conn, uid: int):
 | 
				
			||||||
    async with conn.transaction():
 | 
					    async with conn.transaction():
 | 
				
			||||||
        stmt = await conn.prepare('''
 | 
					        stmt = await conn.prepare('''
 | 
				
			||||||
            INSERT INTO skynet.user(
 | 
					            INSERT INTO skynet.user(
 | 
				
			||||||
                id, generated, joined, last_prompt, role)
 | 
					                id, generated, joined,
 | 
				
			||||||
 | 
					                last_method, last_prompt, last_file, last_binary,
 | 
				
			||||||
 | 
					                role
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            VALUES($1, $2, $3, $4, $5)
 | 
					            VALUES($1, $2, $3, $4, $5, $6, $7, $8)
 | 
				
			||||||
        ''')
 | 
					        ''')
 | 
				
			||||||
        await stmt.fetch(
 | 
					        await stmt.fetch(
 | 
				
			||||||
            uid, 0, date, None, DEFAULT_ROLE
 | 
					            uid, 0, date, 'txt2img', None, None, None, DEFAULT_ROLE
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        stmt = await conn.prepare('''
 | 
					        stmt = await conn.prepare('''
 | 
				
			||||||
| 
						 | 
					@ -222,7 +242,8 @@ async def get_or_create_user(conn, uid: str):
 | 
				
			||||||
    user = await get_user(conn, uid)
 | 
					    user = await get_user(conn, uid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if not user:
 | 
					    if not user:
 | 
				
			||||||
        user = await new_user(conn, uid)
 | 
					        await new_user(conn, uid)
 | 
				
			||||||
 | 
					        user = await get_user(conn, uid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return user
 | 
					    return user
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -253,11 +274,7 @@ async def get_user_stats(conn, user: int):
 | 
				
			||||||
    record = records[0]
 | 
					    record = records[0]
 | 
				
			||||||
    return record
 | 
					    return record
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def update_user_stats(
 | 
					async def increment_generated(conn, user: int):
 | 
				
			||||||
    conn,
 | 
					 | 
				
			||||||
    user: int,
 | 
					 | 
				
			||||||
    last_prompt: Optional[str] = None
 | 
					 | 
				
			||||||
):
 | 
					 | 
				
			||||||
    stmt = await conn.prepare('''
 | 
					    stmt = await conn.prepare('''
 | 
				
			||||||
        UPDATE skynet.user
 | 
					        UPDATE skynet.user
 | 
				
			||||||
        SET generated = generated + 1
 | 
					        SET generated = generated + 1
 | 
				
			||||||
| 
						 | 
					@ -265,5 +282,20 @@ async def update_user_stats(
 | 
				
			||||||
    ''')
 | 
					    ''')
 | 
				
			||||||
    await stmt.fetch(user)
 | 
					    await stmt.fetch(user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def update_user_stats(
 | 
				
			||||||
 | 
					    conn,
 | 
				
			||||||
 | 
					    user: int,
 | 
				
			||||||
 | 
					    method: str,
 | 
				
			||||||
 | 
					    last_prompt: str | None = None,
 | 
				
			||||||
 | 
					    last_file: str | None = None,
 | 
				
			||||||
 | 
					    last_binary: str | None = None
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    await update_user(conn, user, 'last_method', method)
 | 
				
			||||||
    if last_prompt:
 | 
					    if last_prompt:
 | 
				
			||||||
        await update_user(conn, user, 'last_prompt', last_prompt)
 | 
					        await update_user(conn, user, 'last_prompt', last_prompt)
 | 
				
			||||||
 | 
					    if last_file:
 | 
				
			||||||
 | 
					        await update_user(conn, user, 'last_file', last_file)
 | 
				
			||||||
 | 
					    if last_binary:
 | 
				
			||||||
 | 
					        await update_user(conn, user, 'last_binary', last_binary)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    logging.info((method, last_prompt, last_binary))
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -128,9 +128,8 @@ async def work_request(
 | 
				
			||||||
    account: str,
 | 
					    account: str,
 | 
				
			||||||
    permission: str,
 | 
					    permission: str,
 | 
				
			||||||
    params: dict,
 | 
					    params: dict,
 | 
				
			||||||
    ipfs_node,
 | 
					 | 
				
			||||||
    file_id: str | None = None,
 | 
					    file_id: str | None = None,
 | 
				
			||||||
    file_path: str | None = None
 | 
					    binary_data: str = ''
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    if params['seed'] == None:
 | 
					    if params['seed'] == None:
 | 
				
			||||||
        params['seed'] = random.randint(0, 9e18)
 | 
					        params['seed'] = random.randint(0, 9e18)
 | 
				
			||||||
| 
						 | 
					@ -141,30 +140,8 @@ async def work_request(
 | 
				
			||||||
    })
 | 
					    })
 | 
				
			||||||
    request_time = datetime.now().isoformat()
 | 
					    request_time = datetime.now().isoformat()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if file_id:
 | 
					 | 
				
			||||||
        image_raw = await bot.download_file(file_path)
 | 
					 | 
				
			||||||
        with Image.open(io.BytesIO(image_raw)) as image:
 | 
					 | 
				
			||||||
            w, h = image.size
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if w > 512 or h > 512:
 | 
					 | 
				
			||||||
                logging.warning(f'user sent img of size {image.size}')
 | 
					 | 
				
			||||||
                image.thumbnail((512, 512))
 | 
					 | 
				
			||||||
                logging.warning(f'resized it to {image.size}')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            image.save(f'ipfs-docker-staging/image.png', format='PNG')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            ipfs_hash = ipfs_node.add('image.png')
 | 
					 | 
				
			||||||
            ipfs_node.pin(ipfs_hash)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            logging.info(f'published input image {ipfs_hash} on ipfs')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            binary = ipfs_hash
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        binary = ''
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    ec, out = cleos.push_action(
 | 
					    ec, out = cleos.push_action(
 | 
				
			||||||
        'telos.gpu', 'enqueue', [account, body, binary, '20.0000 GPU'], f'{account}@{permission}'
 | 
					        'telos.gpu', 'enqueue', [account, body, binary_data, '20.0000 GPU'], f'{account}@{permission}'
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    out = collect_stdout(out)
 | 
					    out = collect_stdout(out)
 | 
				
			||||||
    if ec != 0:
 | 
					    if ec != 0:
 | 
				
			||||||
| 
						 | 
					@ -173,7 +150,7 @@ async def work_request(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    nonce = await get_user_nonce(cleos, account)
 | 
					    nonce = await get_user_nonce(cleos, account)
 | 
				
			||||||
    request_hash = sha256(
 | 
					    request_hash = sha256(
 | 
				
			||||||
        (str(nonce) + body + binary).encode('utf-8')).hexdigest().upper()
 | 
					        (str(nonce) + body + binary_data).encode('utf-8')).hexdigest().upper()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    request_id = int(out)
 | 
					    request_id = int(out)
 | 
				
			||||||
    logging.info(f'{request_id} enqueued.')
 | 
					    logging.info(f'{request_id} enqueued.')
 | 
				
			||||||
| 
						 | 
					@ -209,13 +186,13 @@ async def work_request(
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # attempt to get the image and send it
 | 
					    # attempt to get the image and send it
 | 
				
			||||||
    resp = await get_ipfs_file(
 | 
					    ipfs_link = f'http://test1.us.telos.net:8080/ipfs/{ipfs_hash}/image.png'
 | 
				
			||||||
        f'http://test1.us.telos.net:8080/ipfs/{ipfs_hash}/image.png')
 | 
					    resp = await get_ipfs_file(ipfs_link)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    caption = generate_reply_caption(
 | 
					    caption = generate_reply_caption(
 | 
				
			||||||
        user, params, ipfs_hash, tx_hash, worker)
 | 
					        user, params, ipfs_hash, tx_hash, worker)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if resp.status_code != 200:
 | 
					    if not resp or resp.status_code != 200:
 | 
				
			||||||
        logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!')
 | 
					        logging.error(f'couldn\'t get ipfs hosted image at {ipfs_link}!')
 | 
				
			||||||
        await bot.reply_to(
 | 
					        await bot.reply_to(
 | 
				
			||||||
            message,
 | 
					            message,
 | 
				
			||||||
| 
						 | 
					@ -233,11 +210,10 @@ async def work_request(
 | 
				
			||||||
                    InputMediaPhoto(file_id),
 | 
					                    InputMediaPhoto(file_id),
 | 
				
			||||||
                    InputMediaPhoto(
 | 
					                    InputMediaPhoto(
 | 
				
			||||||
                        resp.raw,
 | 
					                        resp.raw,
 | 
				
			||||||
                        caption=caption
 | 
					                        caption=caption,
 | 
				
			||||||
 | 
					                        parse_mode='HTML'
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
                ],
 | 
					                ],
 | 
				
			||||||
                reply_markup=build_redo_menu(),
 | 
					 | 
				
			||||||
                parse_mode='HTML'
 | 
					 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        else:  # txt2img
 | 
					        else:  # txt2img
 | 
				
			||||||
| 
						 | 
					@ -307,10 +283,18 @@ async def run_skynet_telegram(
 | 
				
			||||||
            async def send_cool_words(message):
 | 
					            async def send_cool_words(message):
 | 
				
			||||||
                await bot.reply_to(message, '\n'.join(COOL_WORDS))
 | 
					                await bot.reply_to(message, '\n'.join(COOL_WORDS))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            @bot.message_handler(commands=['txt2img'])
 | 
					            async def _generic_txt2img(message_or_query):
 | 
				
			||||||
            async def send_txt2img(message):
 | 
					                if isinstance(message_or_query, CallbackQuery):
 | 
				
			||||||
 | 
					                    query = message_or_query
 | 
				
			||||||
 | 
					                    message = query.message
 | 
				
			||||||
 | 
					                    user = query.from_user
 | 
				
			||||||
 | 
					                    chat = query.message.chat
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    message = message_or_query
 | 
				
			||||||
                    user = message.from_user
 | 
					                    user = message.from_user
 | 
				
			||||||
                    chat = message.chat
 | 
					                    chat = message.chat
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                reply_id = None
 | 
					                reply_id = None
 | 
				
			||||||
                if chat.type == 'group' and chat.id == GROUP_ID:
 | 
					                if chat.type == 'group' and chat.id == GROUP_ID:
 | 
				
			||||||
                    reply_id = message.message_id
 | 
					                    reply_id = message.message_id
 | 
				
			||||||
| 
						 | 
					@ -332,19 +316,30 @@ async def run_skynet_telegram(
 | 
				
			||||||
                    **user_config
 | 
					                    **user_config
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                await db_call('update_user_stats', user.id, last_prompt=prompt)
 | 
					                await db_call(
 | 
				
			||||||
 | 
					                    'update_user_stats', user.id, 'txt2img', last_prompt=prompt)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                await work_request(
 | 
					                ec = await work_request(
 | 
				
			||||||
                    bot, cleos, hyperion,
 | 
					                    bot, cleos, hyperion,
 | 
				
			||||||
                    message, user, chat,
 | 
					                    message, user, chat,
 | 
				
			||||||
                    account, permission, params,
 | 
					                    account, permission, params
 | 
				
			||||||
                    ipfs_node
 | 
					 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            @bot.message_handler(func=lambda message: True, content_types=['photo'])
 | 
					                if ec == 0:
 | 
				
			||||||
            async def send_img2img(message):
 | 
					                    await db_call('increment_generated', user.id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            async def _generic_img2img(message_or_query):
 | 
				
			||||||
 | 
					                if isinstance(message_or_query, CallbackQuery):
 | 
				
			||||||
 | 
					                    query = message_or_query
 | 
				
			||||||
 | 
					                    message = query.message
 | 
				
			||||||
 | 
					                    user = query.from_user
 | 
				
			||||||
 | 
					                    chat = query.message.chat
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    message = message_or_query
 | 
				
			||||||
                    user = message.from_user
 | 
					                    user = message.from_user
 | 
				
			||||||
                    chat = message.chat
 | 
					                    chat = message.chat
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                reply_id = None
 | 
					                reply_id = None
 | 
				
			||||||
                if chat.type == 'group' and chat.id == GROUP_ID:
 | 
					                if chat.type == 'group' and chat.id == GROUP_ID:
 | 
				
			||||||
                    reply_id = message.message_id
 | 
					                    reply_id = message.message_id
 | 
				
			||||||
| 
						 | 
					@ -364,6 +359,22 @@ async def run_skynet_telegram(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                file_id = message.photo[-1].file_id
 | 
					                file_id = message.photo[-1].file_id
 | 
				
			||||||
                file_path = (await bot.get_file(file_id)).file_path
 | 
					                file_path = (await bot.get_file(file_id)).file_path
 | 
				
			||||||
 | 
					                image_raw = await bot.download_file(file_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                with Image.open(io.BytesIO(image_raw)) as image:
 | 
				
			||||||
 | 
					                    w, h = image.size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    if w > 512 or h > 512:
 | 
				
			||||||
 | 
					                        logging.warning(f'user sent img of size {image.size}')
 | 
				
			||||||
 | 
					                        image.thumbnail((512, 512))
 | 
				
			||||||
 | 
					                        logging.warning(f'resized it to {image.size}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    image.save(f'ipfs-docker-staging/image.png', format='PNG')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    ipfs_hash = ipfs_node.add('image.png')
 | 
				
			||||||
 | 
					                    ipfs_node.pin(ipfs_hash)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    logging.info(f'published input image {ipfs_hash} on ipfs')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                logging.info(f'mid: {message.id}')
 | 
					                logging.info(f'mid: {message.id}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -376,16 +387,34 @@ async def run_skynet_telegram(
 | 
				
			||||||
                    **user_config
 | 
					                    **user_config
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                await db_call('update_user_stats', user.id, last_prompt=prompt)
 | 
					                await db_call(
 | 
				
			||||||
 | 
					                    'update_user_stats',
 | 
				
			||||||
 | 
					                    user.id,
 | 
				
			||||||
 | 
					                    'img2img',
 | 
				
			||||||
 | 
					                    last_file=file_id,
 | 
				
			||||||
 | 
					                    last_prompt=prompt,
 | 
				
			||||||
 | 
					                    last_binary=ipfs_hash
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                await work_request(
 | 
					                ec = await work_request(
 | 
				
			||||||
                    bot, cleos, hyperion,
 | 
					                    bot, cleos, hyperion,
 | 
				
			||||||
                    message, user, chat,
 | 
					                    message, user, chat,
 | 
				
			||||||
                    account, permission, params,
 | 
					                    account, permission, params,
 | 
				
			||||||
                    ipfs_node,
 | 
					                    file_id=file_id,
 | 
				
			||||||
                    file_id=file_id, file_path=file_path
 | 
					                    binary_data=ipfs_hash
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if ec == 0:
 | 
				
			||||||
 | 
					                    await db_call('increment_generated', user.id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            @bot.message_handler(commands=['txt2img'])
 | 
				
			||||||
 | 
					            async def send_txt2img(message):
 | 
				
			||||||
 | 
					                await _generic_txt2img(message)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            @bot.message_handler(func=lambda message: True, content_types=[
 | 
				
			||||||
 | 
					                'photo', 'document'])
 | 
				
			||||||
 | 
					            async def send_img2img(message):
 | 
				
			||||||
 | 
					                await _generic_img2img(message)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            @bot.message_handler(commands=['img2img'])
 | 
					            @bot.message_handler(commands=['img2img'])
 | 
				
			||||||
            async def img2img_missing_image(message):
 | 
					            async def img2img_missing_image(message):
 | 
				
			||||||
| 
						 | 
					@ -406,12 +435,15 @@ async def run_skynet_telegram(
 | 
				
			||||||
                    user = message.from_user
 | 
					                    user = message.from_user
 | 
				
			||||||
                    chat = message.chat
 | 
					                    chat = message.chat
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                reply_id = None
 | 
					                method = await db_call('get_last_method_of', user.id)
 | 
				
			||||||
                if chat.type == 'group' and chat.id == GROUP_ID:
 | 
					 | 
				
			||||||
                    reply_id = message.message_id
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                prompt = await db_call('get_last_prompt_of', user.id)
 | 
					                prompt = await db_call('get_last_prompt_of', user.id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                file_id = None
 | 
				
			||||||
 | 
					                binary = ''
 | 
				
			||||||
 | 
					                if method == 'img2img':
 | 
				
			||||||
 | 
					                    file_id = await db_call('get_last_file_of', user.id)
 | 
				
			||||||
 | 
					                    binary = await db_call('get_last_binary_of', user.id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                if not prompt:
 | 
					                if not prompt:
 | 
				
			||||||
                    await bot.reply_to(
 | 
					                    await bot.reply_to(
 | 
				
			||||||
                        message,
 | 
					                        message,
 | 
				
			||||||
| 
						 | 
					@ -433,7 +465,8 @@ async def run_skynet_telegram(
 | 
				
			||||||
                    bot, cleos, hyperion,
 | 
					                    bot, cleos, hyperion,
 | 
				
			||||||
                    message, user, chat,
 | 
					                    message, user, chat,
 | 
				
			||||||
                    account, permission, params,
 | 
					                    account, permission, params,
 | 
				
			||||||
                    ipfs_node
 | 
					                    file_id=file_id,
 | 
				
			||||||
 | 
					                    binary_data=binary
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            @bot.message_handler(commands=['redo'])
 | 
					            @bot.message_handler(commands=['redo'])
 | 
				
			||||||
| 
						 | 
					@ -485,7 +518,6 @@ async def run_skynet_telegram(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                await bot.send_message(GROUP_ID, message.text[4:])
 | 
					                await bot.send_message(GROUP_ID, message.text[4:])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
            @bot.message_handler(func=lambda message: True)
 | 
					            @bot.message_handler(func=lambda message: True)
 | 
				
			||||||
            async def echo_message(message):
 | 
					            async def echo_message(message):
 | 
				
			||||||
                if message.text[0] == '/':
 | 
					                if message.text[0] == '/':
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue