From 722bc4af57d5904f3692b342d1dcef0b54773233 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Mon, 3 Feb 2025 21:08:54 -0300 Subject: [PATCH] Fix cancellation system and provide a reason for the raise --- skynet/dgpu/compute.py | 8 ++++++-- skynet/dgpu/daemon.py | 13 ++++++++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/skynet/dgpu/compute.py b/skynet/dgpu/compute.py index 8b10bd7..56403a1 100755 --- a/skynet/dgpu/compute.py +++ b/skynet/dgpu/compute.py @@ -114,7 +114,7 @@ class ModelMngr: name, mode, cache_dir=self.cache_dir) self._model_mode = mode self._model_name = name - logging.info('{name} loaded!') + logging.info(f'{name} loaded!') self.log_debug_info() def compute_one( @@ -125,10 +125,14 @@ class ModelMngr: inputs: list[bytes] = [] ): def maybe_cancel_work(step, *args, **kwargs): + '''This is a callback function that gets invoked every inference step, + we need to raise an exception here if we need to cancel work + ''' if self._should_cancel: should_raise = trio.from_thread.run(self._should_cancel, request_id) if should_raise: - logging.warn(f'CANCELLING work at step {step}') + logging.warning(f'CANCELLING work at step {step}') + raise DGPUInferenceCancelled('network cancel') return {} diff --git a/skynet/dgpu/daemon.py b/skynet/dgpu/daemon.py index 137259b..bfcab79 100755 --- a/skynet/dgpu/daemon.py +++ b/skynet/dgpu/daemon.py @@ -100,12 +100,16 @@ class WorkerDaemon: async def should_cancel_work(self, request_id: int): self._benchmark.append(time.time()) + logging.info('should cancel work?') + if request_id not in self._snap['requests']: + logging.info(f'request #{request_id} no longer in queue, likely its been filled by another worker, cancelling work...') + return True + competitors = set([ status['worker'] for status in self._snap['requests'][request_id] if status['worker'] != self.account ]) - logging.info('should cancel work?') logging.info(f'competitors: {competitors}') should_cancel = bool(self.non_compete & competitors) logging.info(f'cancel: {should_cancel}') @@ -274,8 +278,11 @@ class WorkerDaemon: await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash) except BaseException as err: - logging.exception('Failed to serve model request !?\n') - await self.conn.cancel_work(rid, str(err)) + if 'network cancel' not in str(err): + logging.exception('Failed to serve model request !?\n') + + if rid in self._snap['requests']: + await self.conn.cancel_work(rid, 'reason not provided') finally: return True