mirror of https://github.com/skygpu/skynet.git
Fix cancellation system and provide a reason for the raise
parent
399299c62b
commit
722bc4af57
|
@ -114,7 +114,7 @@ class ModelMngr:
|
||||||
name, mode, cache_dir=self.cache_dir)
|
name, mode, cache_dir=self.cache_dir)
|
||||||
self._model_mode = mode
|
self._model_mode = mode
|
||||||
self._model_name = name
|
self._model_name = name
|
||||||
logging.info('{name} loaded!')
|
logging.info(f'{name} loaded!')
|
||||||
self.log_debug_info()
|
self.log_debug_info()
|
||||||
|
|
||||||
def compute_one(
|
def compute_one(
|
||||||
|
@ -125,10 +125,14 @@ class ModelMngr:
|
||||||
inputs: list[bytes] = []
|
inputs: list[bytes] = []
|
||||||
):
|
):
|
||||||
def maybe_cancel_work(step, *args, **kwargs):
|
def maybe_cancel_work(step, *args, **kwargs):
|
||||||
|
'''This is a callback function that gets invoked every inference step,
|
||||||
|
we need to raise an exception here if we need to cancel work
|
||||||
|
'''
|
||||||
if self._should_cancel:
|
if self._should_cancel:
|
||||||
should_raise = trio.from_thread.run(self._should_cancel, request_id)
|
should_raise = trio.from_thread.run(self._should_cancel, request_id)
|
||||||
if should_raise:
|
if should_raise:
|
||||||
logging.warn(f'CANCELLING work at step {step}')
|
logging.warning(f'CANCELLING work at step {step}')
|
||||||
|
raise DGPUInferenceCancelled('network cancel')
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
|
@ -100,12 +100,16 @@ class WorkerDaemon:
|
||||||
|
|
||||||
async def should_cancel_work(self, request_id: int):
|
async def should_cancel_work(self, request_id: int):
|
||||||
self._benchmark.append(time.time())
|
self._benchmark.append(time.time())
|
||||||
|
logging.info('should cancel work?')
|
||||||
|
if request_id not in self._snap['requests']:
|
||||||
|
logging.info(f'request #{request_id} no longer in queue, likely its been filled by another worker, cancelling work...')
|
||||||
|
return True
|
||||||
|
|
||||||
competitors = set([
|
competitors = set([
|
||||||
status['worker']
|
status['worker']
|
||||||
for status in self._snap['requests'][request_id]
|
for status in self._snap['requests'][request_id]
|
||||||
if status['worker'] != self.account
|
if status['worker'] != self.account
|
||||||
])
|
])
|
||||||
logging.info('should cancel work?')
|
|
||||||
logging.info(f'competitors: {competitors}')
|
logging.info(f'competitors: {competitors}')
|
||||||
should_cancel = bool(self.non_compete & competitors)
|
should_cancel = bool(self.non_compete & competitors)
|
||||||
logging.info(f'cancel: {should_cancel}')
|
logging.info(f'cancel: {should_cancel}')
|
||||||
|
@ -274,8 +278,11 @@ class WorkerDaemon:
|
||||||
await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
|
await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
|
||||||
|
|
||||||
except BaseException as err:
|
except BaseException as err:
|
||||||
logging.exception('Failed to serve model request !?\n')
|
if 'network cancel' not in str(err):
|
||||||
await self.conn.cancel_work(rid, str(err))
|
logging.exception('Failed to serve model request !?\n')
|
||||||
|
|
||||||
|
if rid in self._snap['requests']:
|
||||||
|
await self.conn.cancel_work(rid, 'reason not provided')
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
return True
|
return True
|
||||||
|
|
Loading…
Reference in New Issue