mirror of https://github.com/skygpu/skynet.git
				
				
				
			Add whitelist & blacklist
							parent
							
								
									ad1a9ef9ea
								
							
						
					
					
						commit
						342dd9ac1c
					
				| 
						 | 
					@ -28,6 +28,15 @@ class SkynetDGPUDaemon:
 | 
				
			||||||
            if 'auto_withdraw' in config else False
 | 
					            if 'auto_withdraw' in config else False
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.non_compete = set(config['non_compete'])
 | 
					        self.non_compete = set(config['non_compete'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.model_whitelist = set()
 | 
				
			||||||
 | 
					        if 'model_whitelist' in config:
 | 
				
			||||||
 | 
					            self.model_whitelist = set(config['model_whitelist'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.model_blacklist = set()
 | 
				
			||||||
 | 
					        if 'model_blacklist' in config:
 | 
				
			||||||
 | 
					            self.model_blacklist = set(config['model_blacklist'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.current_request = None
 | 
					        self.current_request = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def should_cancel_work(self):
 | 
					    async def should_cancel_work(self):
 | 
				
			||||||
| 
						 | 
					@ -49,14 +58,25 @@ class SkynetDGPUDaemon:
 | 
				
			||||||
                for req in queue:
 | 
					                for req in queue:
 | 
				
			||||||
                    rid = req['id']
 | 
					                    rid = req['id']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    # parse request
 | 
				
			||||||
 | 
					                    body = json.loads(req['body'])
 | 
				
			||||||
 | 
					                    model = body['params']['model']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    # if whitelist enabled and model not in it continue
 | 
				
			||||||
 | 
					                    if (len(self.model_whitelist) > 0 and
 | 
				
			||||||
 | 
					                        not model in self.model_whitelist):
 | 
				
			||||||
 | 
					                        continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    # if blacklist contains model skip
 | 
				
			||||||
 | 
					                    if model in self.model_blacklist:
 | 
				
			||||||
 | 
					                        continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    my_results = [res['id'] for res in (await self.conn.find_my_results())]
 | 
					                    my_results = [res['id'] for res in (await self.conn.find_my_results())]
 | 
				
			||||||
                    if rid not in my_results:
 | 
					                    if rid not in my_results:
 | 
				
			||||||
                        statuses = await self.conn.get_status_by_request_id(rid)
 | 
					                        statuses = await self.conn.get_status_by_request_id(rid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        if len(statuses) == 0:
 | 
					                        if len(statuses) == 0:
 | 
				
			||||||
                            self.current_request = rid
 | 
					                            self.current_request = rid
 | 
				
			||||||
                            # parse request
 | 
					 | 
				
			||||||
                            body = json.loads(req['body'])
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                            binary = await self.conn.get_input_data(req['binary_data'])
 | 
					                            binary = await self.conn.get_input_data(req['binary_data'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue