From 01c78b5d20cb8fe2cab73fbd37944be3f443f5f3 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Thu, 5 Oct 2023 15:07:42 -0300 Subject: [PATCH] Make gpu work cancellable using trio threading apis!, also make docker always reinstall package for easier development --- docker/entrypoint.sh | 2 ++ skynet/dgpu/compute.py | 14 +++++++++++++- skynet/dgpu/daemon.py | 22 +++++++++++++++++++--- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 788341a..80cc2ce 100755 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -3,4 +3,6 @@ export VIRTUAL_ENV='/skynet/.venv' poetry env use $VIRTUAL_ENV/bin/python +poetry install + exec poetry run "$@" diff --git a/skynet/dgpu/compute.py b/skynet/dgpu/compute.py index 069af47..9620c00 100644 --- a/skynet/dgpu/compute.py +++ b/skynet/dgpu/compute.py @@ -3,12 +3,15 @@ # Skynet Memory Manager import gc -from hashlib import sha256 import json import logging + +from hashlib import sha256 from diffusers import DiffusionPipeline +import trio import torch + from skynet.constants import DEFAULT_INITAL_MODELS, MODELS from skynet.dgpu.errors import DGPUComputeError @@ -122,10 +125,17 @@ class SkynetMM: def compute_one( self, + should_cancel_work, method: str, params: dict, binary: bytes | None = None ): + def callback_fn(step: int, timestep: int, latents: torch.FloatTensor): + should_raise = trio.from_thread.run(should_cancel_work) + if should_raise: + logging.warn(f'cancelling work at step {step}') + raise DGPUComputeError('Inference cancelled') + try: match method: case 'diffuse': @@ -140,6 +150,8 @@ class SkynetMM: guidance_scale=guidance, num_inference_steps=step, generator=seed, + callback=callback_fn, + callback_steps=1, **extra_params ).images[0] diff --git a/skynet/dgpu/daemon.py b/skynet/dgpu/daemon.py index bd0f1f9..ce3eee1 100644 --- a/skynet/dgpu/daemon.py +++ b/skynet/dgpu/daemon.py @@ -5,6 +5,7 @@ import logging import traceback from hashlib import sha256 +from functools import partial import trio @@ -26,6 +27,16 @@ class SkynetDGPUDaemon: config['auto_withdraw'] if 'auto_withdraw' in config else False ) + self.non_compete = set(('testworker2', 'animus2.boid', 'animus1.boid')) + self.current_request = None + + async def should_cancel_work(self): + competitors = set(( + status['worker'] + for status in + (await self.conn.get_status_by_request_id(self.current_request)) + )) + return self.non_compete & competitors async def serve_forever(self): try: @@ -43,7 +54,7 @@ class SkynetDGPUDaemon: statuses = await self.conn.get_status_by_request_id(rid) if len(statuses) == 0: - + self.current_request = rid # parse request body = json.loads(req['body']) @@ -70,8 +81,13 @@ class SkynetDGPUDaemon: else: try: - img_sha, img_raw = self.mm.compute_one( - body['method'], body['params'], binary=binary) + img_sha, img_raw = await trio.to_thread.run_sync( + partial( + self.mm.compute_one, + self.should_cancel_work, + body['method'], body['params'], binary=binary + ) + ) ipfs_hash = await self.conn.publish_on_ipfs(img_raw)