Fix cancellation system and provide a reason for the raise

guilles_counter_review
Guillermo Rodriguez 2025-02-03 21:08:54 -03:00
parent 399299c62b
commit 722bc4af57
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
2 changed files with 16 additions and 5 deletions

View File

@ -114,7 +114,7 @@ class ModelMngr:
name, mode, cache_dir=self.cache_dir) name, mode, cache_dir=self.cache_dir)
self._model_mode = mode self._model_mode = mode
self._model_name = name self._model_name = name
logging.info('{name} loaded!') logging.info(f'{name} loaded!')
self.log_debug_info() self.log_debug_info()
def compute_one( def compute_one(
@ -125,10 +125,14 @@ class ModelMngr:
inputs: list[bytes] = [] inputs: list[bytes] = []
): ):
def maybe_cancel_work(step, *args, **kwargs): def maybe_cancel_work(step, *args, **kwargs):
'''This is a callback function that gets invoked every inference step,
we need to raise an exception here if we need to cancel work
'''
if self._should_cancel: if self._should_cancel:
should_raise = trio.from_thread.run(self._should_cancel, request_id) should_raise = trio.from_thread.run(self._should_cancel, request_id)
if should_raise: if should_raise:
logging.warn(f'CANCELLING work at step {step}') logging.warning(f'CANCELLING work at step {step}')
raise DGPUInferenceCancelled('network cancel')
return {} return {}

View File

@ -100,12 +100,16 @@ class WorkerDaemon:
async def should_cancel_work(self, request_id: int): async def should_cancel_work(self, request_id: int):
self._benchmark.append(time.time()) self._benchmark.append(time.time())
logging.info('should cancel work?')
if request_id not in self._snap['requests']:
logging.info(f'request #{request_id} no longer in queue, likely its been filled by another worker, cancelling work...')
return True
competitors = set([ competitors = set([
status['worker'] status['worker']
for status in self._snap['requests'][request_id] for status in self._snap['requests'][request_id]
if status['worker'] != self.account if status['worker'] != self.account
]) ])
logging.info('should cancel work?')
logging.info(f'competitors: {competitors}') logging.info(f'competitors: {competitors}')
should_cancel = bool(self.non_compete & competitors) should_cancel = bool(self.non_compete & competitors)
logging.info(f'cancel: {should_cancel}') logging.info(f'cancel: {should_cancel}')
@ -274,8 +278,11 @@ class WorkerDaemon:
await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash) await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
except BaseException as err: except BaseException as err:
logging.exception('Failed to serve model request !?\n') if 'network cancel' not in str(err):
await self.conn.cancel_work(rid, str(err)) logging.exception('Failed to serve model request !?\n')
if rid in self._snap['requests']:
await self.conn.cancel_work(rid, 'reason not provided')
finally: finally:
return True return True