From 318a21ac818de85b96bb95488b11e5cccb5a9da6 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Sun, 4 Dec 2022 21:31:57 -0300 Subject: [PATCH] Add round robin --- run-bot.sh | 6 +++--- scripts/telegram-bot-dev.py | 40 ++++++++++++++++++++++++++++++------- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/run-bot.sh b/run-bot.sh index eb136d5..acc606b 100755 --- a/run-bot.sh +++ b/run-bot.sh @@ -3,9 +3,9 @@ docker run \ --rm \ --gpus=all \ --env HF_TOKEN='' \ - --env DB_USER='' \ - --env DB_PASS='' \ + --env DB_USER='skynet' \ + --env DB_PASS='password' \ --mount type=bind,source="$(pwd)"/outputs,target=/outputs \ --mount type=bind,source="$(pwd)"/hf_home,target=/hf_home \ --mount type=bind,source="$(pwd)"/scripts,target=/scripts \ - skynet-art-bot:0.1a3 $1 + skynet:dif python telegram-bot-dev.py diff --git a/scripts/telegram-bot-dev.py b/scripts/telegram-bot-dev.py index 8d1d2ae..62cba42 100644 --- a/scripts/telegram-bot-dev.py +++ b/scripts/telegram-bot-dev.py @@ -87,6 +87,17 @@ DEFAULT_GUIDANCE = 7.5 DEFAULT_STEP = 75 DEFAULT_CREDITS = 10 +rr_total = 2 +rr_id = 0 +request_counter = 0 + +def its_my_turn(): + global request_counter, rr_total, rr_id + my_turn = request_counter % rr_total == rr_id + logging.info(f'new request {request_counter}, turn: {my_turn} rr_total: {rr_total}, rr_id {rr_id}') + request_counter += 1 + return my_turn + def generate_image(i, prompt, name, step, size, guidance, seed): assert torch.cuda.is_available() @@ -118,9 +129,10 @@ if __name__ == '__main__': API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0' bot = telebot.TeleBot(API_TOKEN) - db_client = MongoClient(f'mongodb://{db_user}:{db_pass}@localhost:27017') - - rr_id = 0 + db_client = MongoClient( + host=['ancap.tech:64000'], + username=db_user, + password=db_pass) tgdb = db_client.get_database('telegram') @@ -179,8 +191,8 @@ if __name__ == '__main__': return tg_users.find_one_and_update( {'uid': uid}, updt_cmd) - # bot handler + # bot handler def img_for_user_with_prompt( uid: int, prompt: str, step: int, size: tuple[int, int], guidance: int, seed: int @@ -210,14 +222,19 @@ if __name__ == '__main__': @bot.message_handler(commands=['help']) def send_help(message): - bot.reply_to(message, HELP_TEXT) + if its_my_turn(): + bot.reply_to(message, HELP_TEXT) @bot.message_handler(commands=['cool']) - def send_help(message): - bot.reply_to(message, '\n'.join(COOL_WORDS)) + def send_cool_words(message): + if its_my_turn(): + bot.reply_to(message, '\n'.join(COOL_WORDS)) @bot.message_handler(commands=['txt2img']) def send_txt2img(message): + if not its_my_turn(): + return + # check msg comes from testing group chat = message.chat if chat.type != 'group' and chat.id != GROUP_ID: @@ -263,6 +280,9 @@ if __name__ == '__main__': @bot.message_handler(commands=['redo']) def redo_txt2img(message): + if not its_my_turn(): + return + # check msg comes from testing group chat = message.chat if chat.type != 'group' and chat.id != GROUP_ID: @@ -307,6 +327,9 @@ if __name__ == '__main__': @bot.message_handler(commands=['config']) def set_config(message): + if not its_my_turn(): + return + params = message.text.split(' ') if len(params) < 3: @@ -362,6 +385,9 @@ if __name__ == '__main__': @bot.message_handler(commands=['stats']) def user_stats(message): + if not its_my_turn(): + return + user = message.from_user db_user = get_user(user.id)