mirror of https://github.com/skygpu/skynet.git
Make GPU work cancellable using trio threading APIs; also make Docker always reinstall the package, for easier development
parent
47d9f59dbe
commit
01c78b5d20
|
@ -3,4 +3,6 @@
|
|||
export VIRTUAL_ENV='/skynet/.venv'
|
||||
poetry env use $VIRTUAL_ENV/bin/python
|
||||
|
||||
poetry install
|
||||
|
||||
exec poetry run "$@"
|
||||
|
|
|
@ -3,12 +3,15 @@
|
|||
# Skynet Memory Manager
|
||||
|
||||
import gc
|
||||
from hashlib import sha256
|
||||
import json
|
||||
import logging
|
||||
|
||||
from hashlib import sha256
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
import trio
|
||||
import torch
|
||||
|
||||
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
|
||||
from skynet.dgpu.errors import DGPUComputeError
|
||||
|
||||
|
@ -122,10 +125,17 @@ class SkynetMM:
|
|||
|
||||
def compute_one(
|
||||
self,
|
||||
should_cancel_work,
|
||||
method: str,
|
||||
params: dict,
|
||||
binary: bytes | None = None
|
||||
):
|
||||
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor):
    """Per-step diffusion callback that aborts inference on request.

    Runs inside the worker thread spawned via ``trio.to_thread``; hops
    back into the trio event loop with ``trio.from_thread.run`` to ask
    the async ``should_cancel_work`` predicate whether a competitor has
    taken over the request.

    Raises:
        DGPUComputeError: when cancellation is requested, which unwinds
            the diffusers pipeline call and stops inference early.
    """
    should_raise = trio.from_thread.run(should_cancel_work)
    if should_raise:
        # logging.warn is deprecated since Python 3.3; use warning()
        logging.warning(f'cancelling work at step {step}')
        raise DGPUComputeError('Inference cancelled')
|
||||
|
||||
try:
|
||||
match method:
|
||||
case 'diffuse':
|
||||
|
@ -140,6 +150,8 @@ class SkynetMM:
|
|||
guidance_scale=guidance,
|
||||
num_inference_steps=step,
|
||||
generator=seed,
|
||||
callback=callback_fn,
|
||||
callback_steps=1,
|
||||
**extra_params
|
||||
).images[0]
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import logging
|
|||
import traceback
|
||||
|
||||
from hashlib import sha256
|
||||
from functools import partial
|
||||
|
||||
import trio
|
||||
|
||||
|
@ -26,6 +27,16 @@ class SkynetDGPUDaemon:
|
|||
config['auto_withdraw']
|
||||
if 'auto_withdraw' in config else False
|
||||
)
|
||||
self.non_compete = set(('testworker2', 'animus2.boid', 'animus1.boid'))
|
||||
self.current_request = None
|
||||
|
||||
async def should_cancel_work(self):
    """Decide whether the in-flight request should be abandoned.

    Fetches every status entry posted against ``self.current_request``
    and intersects the set of workers seen there with the
    ``self.non_compete`` worker set. The returned set is truthy
    exactly when a non-compete worker is already on the job.
    """
    statuses = await self.conn.get_status_by_request_id(self.current_request)
    active_workers = {entry['worker'] for entry in statuses}
    return self.non_compete & active_workers
|
||||
|
||||
async def serve_forever(self):
|
||||
try:
|
||||
|
@ -43,7 +54,7 @@ class SkynetDGPUDaemon:
|
|||
statuses = await self.conn.get_status_by_request_id(rid)
|
||||
|
||||
if len(statuses) == 0:
|
||||
|
||||
self.current_request = rid
|
||||
# parse request
|
||||
body = json.loads(req['body'])
|
||||
|
||||
|
@ -70,8 +81,13 @@ class SkynetDGPUDaemon:
|
|||
|
||||
else:
|
||||
try:
|
||||
img_sha, img_raw = self.mm.compute_one(
|
||||
body['method'], body['params'], binary=binary)
|
||||
img_sha, img_raw = await trio.to_thread.run_sync(
|
||||
partial(
|
||||
self.mm.compute_one,
|
||||
self.should_cancel_work,
|
||||
body['method'], body['params'], binary=binary
|
||||
)
|
||||
)
|
||||
|
||||
ipfs_hash = await self.conn.publish_on_ipfs(img_raw)
|
||||
|
||||
|
|
Loading…
Reference in New Issue