Make GPU work cancellable using trio threading APIs; also make Docker always reinstall the package for easier development

pull/26/head
Guillermo Rodriguez 2023-10-05 15:07:42 -03:00
parent 47d9f59dbe
commit 01c78b5d20
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
3 changed files with 34 additions and 4 deletions

View File

@ -3,4 +3,6 @@
export VIRTUAL_ENV='/skynet/.venv'
poetry env use $VIRTUAL_ENV/bin/python
poetry install
exec poetry run "$@"

View File

@ -3,12 +3,15 @@
# Skynet Memory Manager
import gc
from hashlib import sha256
import json
import logging
from hashlib import sha256
from diffusers import DiffusionPipeline
import trio
import torch
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
from skynet.dgpu.errors import DGPUComputeError
@ -122,10 +125,17 @@ class SkynetMM:
def compute_one(
self,
should_cancel_work,
method: str,
params: dict,
binary: bytes | None = None
):
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor):
    """Per-step pipeline callback that aborts inference when work is cancelled.

    Passed to the pipeline call as ``callback`` with ``callback_steps=1``.
    ``timestep`` and ``latents`` are accepted to satisfy the callback
    signature but are not used here.

    Raises:
        DGPUComputeError: if ``should_cancel_work`` returns a truthy value.
    """
    # This function runs on the worker thread; hand the async check back to
    # the trio run loop and block until it returns.
    should_raise = trio.from_thread.run(should_cancel_work)
    if should_raise:
        # logging.warn is a deprecated alias; logging.warning is the
        # supported spelling and emits the same record.
        logging.warning(f'cancelling work at step {step}')
        raise DGPUComputeError('Inference cancelled')
try:
match method:
case 'diffuse':
@ -140,6 +150,8 @@ class SkynetMM:
guidance_scale=guidance,
num_inference_steps=step,
generator=seed,
callback=callback_fn,
callback_steps=1,
**extra_params
).images[0]

View File

@ -5,6 +5,7 @@ import logging
import traceback
from hashlib import sha256
from functools import partial
import trio
@ -26,6 +27,16 @@ class SkynetDGPUDaemon:
config['auto_withdraw']
if 'auto_withdraw' in config else False
)
self.non_compete = set(('testworker2', 'animus2.boid', 'animus1.boid'))
self.current_request = None
async def should_cancel_work(self):
    """Return the non-compete workers that have a status on the current request.

    Fetches the statuses for ``self.current_request`` through ``self.conn``
    and intersects their ``'worker'`` values with ``self.non_compete``.
    The result is a set, so it is truthy exactly when at least one
    non-compete worker appears among the statuses.
    """
    statuses = await self.conn.get_status_by_request_id(self.current_request)
    active_workers = {entry['worker'] for entry in statuses}
    return self.non_compete & active_workers
async def serve_forever(self):
try:
@ -43,7 +54,7 @@ class SkynetDGPUDaemon:
statuses = await self.conn.get_status_by_request_id(rid)
if len(statuses) == 0:
self.current_request = rid
# parse request
body = json.loads(req['body'])
@ -70,8 +81,13 @@ class SkynetDGPUDaemon:
else:
try:
img_sha, img_raw = self.mm.compute_one(
body['method'], body['params'], binary=binary)
img_sha, img_raw = await trio.to_thread.run_sync(
partial(
self.mm.compute_one,
self.should_cancel_work,
body['method'], body['params'], binary=binary
)
)
ipfs_hash = await self.conn.publish_on_ipfs(img_raw)