mirror of https://github.com/skygpu/skynet.git
Add round robin
parent
563efed3e9
commit
318a21ac81
|
@ -3,9 +3,9 @@ docker run \
|
|||
--rm \
|
||||
--gpus=all \
|
||||
--env HF_TOKEN='' \
|
||||
--env DB_USER='' \
|
||||
--env DB_PASS='' \
|
||||
--env DB_USER='skynet' \
|
||||
--env DB_PASS='password' \
|
||||
--mount type=bind,source="$(pwd)"/outputs,target=/outputs \
|
||||
--mount type=bind,source="$(pwd)"/hf_home,target=/hf_home \
|
||||
--mount type=bind,source="$(pwd)"/scripts,target=/scripts \
|
||||
skynet-art-bot:0.1a3 $1
|
||||
skynet:dif python telegram-bot-dev.py
|
||||
|
|
|
@ -87,6 +87,17 @@ DEFAULT_GUIDANCE = 7.5
|
|||
DEFAULT_STEP = 75
|
||||
DEFAULT_CREDITS = 10
|
||||
|
||||
rr_total = 2
|
||||
rr_id = 0
|
||||
request_counter = 0
|
||||
|
||||
def its_my_turn():
|
||||
global request_counter, rr_total, rr_id
|
||||
my_turn = request_counter % rr_total == rr_id
|
||||
logging.info(f'new request {request_counter}, turn: {my_turn} rr_total: {rr_total}, rr_id {rr_id}')
|
||||
request_counter += 1
|
||||
return my_turn
|
||||
|
||||
|
||||
def generate_image(i, prompt, name, step, size, guidance, seed):
|
||||
assert torch.cuda.is_available()
|
||||
|
@ -118,9 +129,10 @@ if __name__ == '__main__':
|
|||
API_TOKEN = '5880619053:AAFge2UfObw1kCn9Kb7AAyqaIHs_HgM0Fx0'
|
||||
|
||||
bot = telebot.TeleBot(API_TOKEN)
|
||||
db_client = MongoClient(f'mongodb://{db_user}:{db_pass}@localhost:27017')
|
||||
|
||||
rr_id = 0
|
||||
db_client = MongoClient(
|
||||
host=['ancap.tech:64000'],
|
||||
username=db_user,
|
||||
password=db_pass)
|
||||
|
||||
tgdb = db_client.get_database('telegram')
|
||||
|
||||
|
@ -179,8 +191,8 @@ if __name__ == '__main__':
|
|||
return tg_users.find_one_and_update(
|
||||
{'uid': uid}, updt_cmd)
|
||||
|
||||
# bot handler
|
||||
|
||||
# bot handler
|
||||
def img_for_user_with_prompt(
|
||||
uid: int,
|
||||
prompt: str, step: int, size: tuple[int, int], guidance: int, seed: int
|
||||
|
@ -210,14 +222,19 @@ if __name__ == '__main__':
|
|||
|
||||
@bot.message_handler(commands=['help'])
|
||||
def send_help(message):
|
||||
bot.reply_to(message, HELP_TEXT)
|
||||
if its_my_turn():
|
||||
bot.reply_to(message, HELP_TEXT)
|
||||
|
||||
@bot.message_handler(commands=['cool'])
|
||||
def send_help(message):
|
||||
bot.reply_to(message, '\n'.join(COOL_WORDS))
|
||||
def send_cool_words(message):
|
||||
if its_my_turn():
|
||||
bot.reply_to(message, '\n'.join(COOL_WORDS))
|
||||
|
||||
@bot.message_handler(commands=['txt2img'])
|
||||
def send_txt2img(message):
|
||||
if not its_my_turn():
|
||||
return
|
||||
|
||||
# check msg comes from testing group
|
||||
chat = message.chat
|
||||
if chat.type != 'group' and chat.id != GROUP_ID:
|
||||
|
@ -263,6 +280,9 @@ if __name__ == '__main__':
|
|||
|
||||
@bot.message_handler(commands=['redo'])
|
||||
def redo_txt2img(message):
|
||||
if not its_my_turn():
|
||||
return
|
||||
|
||||
# check msg comes from testing group
|
||||
chat = message.chat
|
||||
if chat.type != 'group' and chat.id != GROUP_ID:
|
||||
|
@ -307,6 +327,9 @@ if __name__ == '__main__':
|
|||
|
||||
@bot.message_handler(commands=['config'])
|
||||
def set_config(message):
|
||||
if not its_my_turn():
|
||||
return
|
||||
|
||||
params = message.text.split(' ')
|
||||
|
||||
if len(params) < 3:
|
||||
|
@ -362,6 +385,9 @@ if __name__ == '__main__':
|
|||
|
||||
@bot.message_handler(commands=['stats'])
|
||||
def user_stats(message):
|
||||
if not its_my_turn():
|
||||
return
|
||||
|
||||
user = message.from_user
|
||||
db_user = get_user(user.id)
|
||||
|
||||
|
|
Loading…
Reference in New Issue