diff --git a/skynet/dgpu/daemon.py b/skynet/dgpu/daemon.py index 1c2f365..3e8f26b 100644 --- a/skynet/dgpu/daemon.py +++ b/skynet/dgpu/daemon.py @@ -28,6 +28,15 @@ class SkynetDGPUDaemon: if 'auto_withdraw' in config else False ) self.non_compete = set(config['non_compete']) + + self.model_whitelist = set() + if 'model_whitelist' in config: + self.model_whitelist = set(config['model_whitelist']) + + self.model_blacklist = set() + if 'model_blacklist' in config: + self.model_blacklist = set(config['model_blacklist']) + self.current_request = None async def should_cancel_work(self): @@ -49,14 +58,25 @@ class SkynetDGPUDaemon: for req in queue: rid = req['id'] + # parse request + body = json.loads(req['body']) + model = body['params']['model'] + + # if whitelist enabled and model not in it continue + if (len(self.model_whitelist) > 0 and + not model in self.model_whitelist): + continue + + # if blacklist contains model skip + if model in self.model_blacklist: + continue + my_results = [res['id'] for res in (await self.conn.find_my_results())] if rid not in my_results: statuses = await self.conn.get_status_by_request_id(rid) if len(statuses) == 0: self.current_request = rid - # parse request - body = json.loads(req['body']) binary = await self.conn.get_input_data(req['binary_data'])