mirror of https://github.com/skygpu/skynet.git
				
				
				
			Add round robin
							parent
							
								
									563efed3e9
								
							
						
					
					
						commit
						318a21ac81
					
				| 
						 | 
				
			
			@ -3,9 +3,9 @@ docker run \
 | 
			
		|||
    --rm \
 | 
			
		||||
    --gpus=all \
 | 
			
		||||
    --env HF_TOKEN='' \
 | 
			
		||||
    --env DB_USER='' \
 | 
			
		||||
    --env DB_PASS='' \
 | 
			
		||||
    --env DB_USER='skynet' \
 | 
			
		||||
    --env DB_PASS='password' \
 | 
			
		||||
    --mount type=bind,source="$(pwd)"/outputs,target=/outputs \
 | 
			
		||||
    --mount type=bind,source="$(pwd)"/hf_home,target=/hf_home \
 | 
			
		||||
    --mount type=bind,source="$(pwd)"/scripts,target=/scripts \
 | 
			
		||||
    skynet-art-bot:0.1a3 $1
 | 
			
		||||
    skynet:dif python telegram-bot-dev.py
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -87,6 +87,17 @@ DEFAULT_GUIDANCE = 7.5
 | 
			
		|||
DEFAULT_STEP = 75
 | 
			
		||||
DEFAULT_CREDITS = 10
 | 
			
		||||
 | 
			
		||||
rr_total = 2
 | 
			
		||||
rr_id = 0
 | 
			
		||||
request_counter = 0
 | 
			
		||||
 | 
			
		||||
def its_my_turn():
 | 
			
		||||
    global request_counter, rr_total, rr_id
 | 
			
		||||
    my_turn = request_counter % rr_total == rr_id
 | 
			
		||||
    logging.info(f'new request {request_counter}, turn: {my_turn} rr_total: {rr_total}, rr_id {rr_id}')
 | 
			
		||||
    request_counter += 1
 | 
			
		||||
    return my_turn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_image(i, prompt, name, step, size, guidance, seed):
 | 
			
		||||
    assert torch.cuda.is_available()
 | 
			
		||||
| 
						 | 
				
			
			@ -118,9 +129,10 @@ if __name__ == '__main__':
 | 
			
		|||
    API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0'
 | 
			
		||||
 | 
			
		||||
    bot = telebot.TeleBot(API_TOKEN)
 | 
			
		||||
    db_client = MongoClient(f'mongodb://{db_user}:{db_pass}@localhost:27017')
 | 
			
		||||
 | 
			
		||||
    rr_id = 0
 | 
			
		||||
    db_client = MongoClient(
 | 
			
		||||
        host=['ancap.tech:64000'],
 | 
			
		||||
        username=db_user,
 | 
			
		||||
        password=db_pass)
 | 
			
		||||
 | 
			
		||||
    tgdb = db_client.get_database('telegram')
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -179,8 +191,8 @@ if __name__ == '__main__':
 | 
			
		|||
        return tg_users.find_one_and_update(
 | 
			
		||||
            {'uid': uid}, updt_cmd)
 | 
			
		||||
 | 
			
		||||
    # bot handler
 | 
			
		||||
 | 
			
		||||
    # bot handler
 | 
			
		||||
    def img_for_user_with_prompt(
 | 
			
		||||
        uid: int,
 | 
			
		||||
        prompt: str, step: int, size: tuple[int, int], guidance: int, seed: int
 | 
			
		||||
| 
						 | 
				
			
			@ -210,14 +222,19 @@ if __name__ == '__main__':
 | 
			
		|||
 | 
			
		||||
    @bot.message_handler(commands=['help'])
 | 
			
		||||
    def send_help(message):
 | 
			
		||||
        bot.reply_to(message, HELP_TEXT)
 | 
			
		||||
        if its_my_turn():
 | 
			
		||||
            bot.reply_to(message, HELP_TEXT)
 | 
			
		||||
 | 
			
		||||
    @bot.message_handler(commands=['cool'])
 | 
			
		||||
    def send_help(message):
 | 
			
		||||
        bot.reply_to(message, '\n'.join(COOL_WORDS))
 | 
			
		||||
    def send_cool_words(message):
 | 
			
		||||
        if its_my_turn():
 | 
			
		||||
            bot.reply_to(message, '\n'.join(COOL_WORDS))
 | 
			
		||||
 | 
			
		||||
    @bot.message_handler(commands=['txt2img'])
 | 
			
		||||
    def send_txt2img(message):
 | 
			
		||||
        if not its_my_turn():
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        # check msg comes from testing group
 | 
			
		||||
        chat = message.chat
 | 
			
		||||
        if chat.type != 'group' and chat.id != GROUP_ID:
 | 
			
		||||
| 
						 | 
				
			
			@ -263,6 +280,9 @@ if __name__ == '__main__':
 | 
			
		|||
 | 
			
		||||
    @bot.message_handler(commands=['redo'])
 | 
			
		||||
    def redo_txt2img(message):
 | 
			
		||||
        if not its_my_turn():
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        # check msg comes from testing group
 | 
			
		||||
        chat = message.chat
 | 
			
		||||
        if chat.type != 'group' and chat.id != GROUP_ID:
 | 
			
		||||
| 
						 | 
				
			
			@ -307,6 +327,9 @@ if __name__ == '__main__':
 | 
			
		|||
 | 
			
		||||
    @bot.message_handler(commands=['config'])
 | 
			
		||||
    def set_config(message):
 | 
			
		||||
        if not its_my_turn():
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        params = message.text.split(' ')
 | 
			
		||||
 | 
			
		||||
        if len(params) < 3:
 | 
			
		||||
| 
						 | 
				
			
			@ -362,6 +385,9 @@ if __name__ == '__main__':
 | 
			
		|||
 | 
			
		||||
    @bot.message_handler(commands=['stats'])
 | 
			
		||||
    def user_stats(message):
 | 
			
		||||
        if not its_my_turn():
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        user = message.from_user
 | 
			
		||||
        db_user = get_user(user.id)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue