mirror of https://github.com/skygpu/skynet.git
Add whitelist & blacklist
parent
ad1a9ef9ea
commit
342dd9ac1c
|
@ -28,6 +28,15 @@ class SkynetDGPUDaemon:
|
||||||
if 'auto_withdraw' in config else False
|
if 'auto_withdraw' in config else False
|
||||||
)
|
)
|
||||||
self.non_compete = set(config['non_compete'])
|
self.non_compete = set(config['non_compete'])
|
||||||
|
|
||||||
|
self.model_whitelist = set()
|
||||||
|
if 'model_whitelist' in config:
|
||||||
|
self.model_whitelist = set(config['model_whitelist'])
|
||||||
|
|
||||||
|
self.model_blacklist = set()
|
||||||
|
if 'model_blacklist' in config:
|
||||||
|
self.model_blacklist = set(config['model_blacklist'])
|
||||||
|
|
||||||
self.current_request = None
|
self.current_request = None
|
||||||
|
|
||||||
async def should_cancel_work(self):
|
async def should_cancel_work(self):
|
||||||
|
@ -49,14 +58,25 @@ class SkynetDGPUDaemon:
|
||||||
for req in queue:
|
for req in queue:
|
||||||
rid = req['id']
|
rid = req['id']
|
||||||
|
|
||||||
|
# parse request
|
||||||
|
body = json.loads(req['body'])
|
||||||
|
model = body['params']['model']
|
||||||
|
|
||||||
|
# if whitelist enabled and model not in it continue
|
||||||
|
if (len(self.model_whitelist) > 0 and
|
||||||
|
not model in self.model_whitelist):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# if blacklist contains model skip
|
||||||
|
if model in self.model_blacklist:
|
||||||
|
continue
|
||||||
|
|
||||||
my_results = [res['id'] for res in (await self.conn.find_my_results())]
|
my_results = [res['id'] for res in (await self.conn.find_my_results())]
|
||||||
if rid not in my_results:
|
if rid not in my_results:
|
||||||
statuses = await self.conn.get_status_by_request_id(rid)
|
statuses = await self.conn.get_status_by_request_id(rid)
|
||||||
|
|
||||||
if len(statuses) == 0:
|
if len(statuses) == 0:
|
||||||
self.current_request = rid
|
self.current_request = rid
|
||||||
# parse request
|
|
||||||
body = json.loads(req['body'])
|
|
||||||
|
|
||||||
binary = await self.conn.get_input_data(req['binary_data'])
|
binary = await self.conn.get_input_data(req['binary_data'])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue