mirror of https://github.com/skygpu/skynet.git
				
				
				
			Fix cancellation system and provide a reason for the raise
							parent
							
								
									a5dbe5ab12
								
							
						
					
					
						commit
						4c9be4e63e
					
				| 
						 | 
					@ -114,7 +114,7 @@ class ModelMngr:
 | 
				
			||||||
            name, mode, cache_dir=self.cache_dir)
 | 
					            name, mode, cache_dir=self.cache_dir)
 | 
				
			||||||
        self._model_mode = mode
 | 
					        self._model_mode = mode
 | 
				
			||||||
        self._model_name = name
 | 
					        self._model_name = name
 | 
				
			||||||
        logging.info('{name} loaded!')
 | 
					        logging.info(f'{name} loaded!')
 | 
				
			||||||
        self.log_debug_info()
 | 
					        self.log_debug_info()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def compute_one(
 | 
					    def compute_one(
 | 
				
			||||||
| 
						 | 
					@ -125,10 +125,14 @@ class ModelMngr:
 | 
				
			||||||
        inputs: list[bytes] = []
 | 
					        inputs: list[bytes] = []
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        def maybe_cancel_work(step, *args, **kwargs):
 | 
					        def maybe_cancel_work(step, *args, **kwargs):
 | 
				
			||||||
 | 
					            '''This is a callback function that gets invoked every inference step,
 | 
				
			||||||
 | 
					            we need to raise an exception here if we need to cancel work
 | 
				
			||||||
 | 
					            '''
 | 
				
			||||||
            if self._should_cancel:
 | 
					            if self._should_cancel:
 | 
				
			||||||
                should_raise = trio.from_thread.run(self._should_cancel, request_id)
 | 
					                should_raise = trio.from_thread.run(self._should_cancel, request_id)
 | 
				
			||||||
                if should_raise:
 | 
					                if should_raise:
 | 
				
			||||||
                    logging.warn(f'CANCELLING work at step {step}')
 | 
					                    logging.warning(f'CANCELLING work at step {step}')
 | 
				
			||||||
 | 
					                    raise DGPUInferenceCancelled('network cancel')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            return {}
 | 
					            return {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -100,12 +100,16 @@ class WorkerDaemon:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def should_cancel_work(self, request_id: int):
 | 
					    async def should_cancel_work(self, request_id: int):
 | 
				
			||||||
        self._benchmark.append(time.time())
 | 
					        self._benchmark.append(time.time())
 | 
				
			||||||
 | 
					        logging.info('should cancel work?')
 | 
				
			||||||
 | 
					        if request_id not in self._snap['requests']:
 | 
				
			||||||
 | 
					            logging.info(f'request #{request_id} no longer in queue, likely its been filled by another worker, cancelling work...')
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        competitors = set([
 | 
					        competitors = set([
 | 
				
			||||||
            status['worker']
 | 
					            status['worker']
 | 
				
			||||||
            for status in self._snap['requests'][request_id]
 | 
					            for status in self._snap['requests'][request_id]
 | 
				
			||||||
            if status['worker'] != self.account
 | 
					            if status['worker'] != self.account
 | 
				
			||||||
        ])
 | 
					        ])
 | 
				
			||||||
        logging.info('should cancel work?')
 | 
					 | 
				
			||||||
        logging.info(f'competitors: {competitors}')
 | 
					        logging.info(f'competitors: {competitors}')
 | 
				
			||||||
        should_cancel = bool(self.non_compete & competitors)
 | 
					        should_cancel = bool(self.non_compete & competitors)
 | 
				
			||||||
        logging.info(f'cancel: {should_cancel}')
 | 
					        logging.info(f'cancel: {should_cancel}')
 | 
				
			||||||
| 
						 | 
					@ -274,8 +278,11 @@ class WorkerDaemon:
 | 
				
			||||||
                await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
 | 
					                await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            except BaseException as err:
 | 
					            except BaseException as err:
 | 
				
			||||||
 | 
					                if 'network cancel' not in str(err):
 | 
				
			||||||
                    logging.exception('Failed to serve model request !?\n')
 | 
					                    logging.exception('Failed to serve model request !?\n')
 | 
				
			||||||
                await self.conn.cancel_work(rid, str(err))
 | 
					
 | 
				
			||||||
 | 
					                if rid in self._snap['requests']:
 | 
				
			||||||
 | 
					                    await self.conn.cancel_work(rid, 'reason not provided')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            finally:
 | 
					            finally:
 | 
				
			||||||
                return True
 | 
					                return True
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue