mirror of https://github.com/skygpu/skynet.git
Make GPU work cancellable using trio threading APIs; also make Docker always reinstall the package for easier development
parent
47d9f59dbe
commit
01c78b5d20
|
@ -3,4 +3,6 @@
|
||||||
export VIRTUAL_ENV='/skynet/.venv'
|
export VIRTUAL_ENV='/skynet/.venv'
|
||||||
poetry env use $VIRTUAL_ENV/bin/python
|
poetry env use $VIRTUAL_ENV/bin/python
|
||||||
|
|
||||||
|
poetry install
|
||||||
|
|
||||||
exec poetry run "$@"
|
exec poetry run "$@"
|
||||||
|
|
|
@ -3,12 +3,15 @@
|
||||||
# Skynet Memory Manager
|
# Skynet Memory Manager
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
from hashlib import sha256
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from hashlib import sha256
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
|
|
||||||
|
import trio
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
|
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
|
||||||
from skynet.dgpu.errors import DGPUComputeError
|
from skynet.dgpu.errors import DGPUComputeError
|
||||||
|
|
||||||
|
@ -122,10 +125,17 @@ class SkynetMM:
|
||||||
|
|
||||||
def compute_one(
|
def compute_one(
|
||||||
self,
|
self,
|
||||||
|
should_cancel_work,
|
||||||
method: str,
|
method: str,
|
||||||
params: dict,
|
params: dict,
|
||||||
binary: bytes | None = None
|
binary: bytes | None = None
|
||||||
):
|
):
|
||||||
|
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor):
    '''Per-step pipeline callback that aborts inference when asked to.

    Runs the async ``should_cancel_work`` callable via
    ``trio.from_thread.run`` and raises if its result is truthy.

    step: step index reported by the pipeline; only used in the log line.
    timestep, latents: accepted to satisfy the callback signature, unused.

    Raises DGPUComputeError('Inference cancelled') when the check is truthy.
    '''
    # NOTE(review): presumably this executes in the worker thread started
    # by trio.to_thread.run_sync, which is what trio.from_thread.run needs
    # in order to hop back into the trio loop -- confirm against the caller.
    should_raise = trio.from_thread.run(should_cancel_work)
    if should_raise:
        # logging.warn is a deprecated alias; logging.warning is the
        # supported spelling and emits the same record.
        logging.warning(f'cancelling work at step {step}')
        raise DGPUComputeError('Inference cancelled')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
match method:
|
match method:
|
||||||
case 'diffuse':
|
case 'diffuse':
|
||||||
|
@ -140,6 +150,8 @@ class SkynetMM:
|
||||||
guidance_scale=guidance,
|
guidance_scale=guidance,
|
||||||
num_inference_steps=step,
|
num_inference_steps=step,
|
||||||
generator=seed,
|
generator=seed,
|
||||||
|
callback=callback_fn,
|
||||||
|
callback_steps=1,
|
||||||
**extra_params
|
**extra_params
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import logging
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
|
@ -26,6 +27,16 @@ class SkynetDGPUDaemon:
|
||||||
config['auto_withdraw']
|
config['auto_withdraw']
|
||||||
if 'auto_withdraw' in config else False
|
if 'auto_withdraw' in config else False
|
||||||
)
|
)
|
||||||
|
self.non_compete = set(('testworker2', 'animus2.boid', 'animus1.boid'))
|
||||||
|
self.current_request = None
|
||||||
|
|
||||||
|
async def should_cancel_work(self):
    '''Tell whether a non-compete worker already reported on the request.

    Fetches the statuses of ``self.current_request`` through
    ``self.conn.get_status_by_request_id`` and gathers the ``'worker'``
    field of each one.

    Returns the intersection of ``self.non_compete`` and those workers.
    The result is a set, not a bool: empty (falsy) when none of the
    non-compete workers appears in the statuses, non-empty (truthy)
    otherwise, so callers can use it directly as a cancel flag.
    '''
    statuses = await self.conn.get_status_by_request_id(self.current_request)
    # A set comprehension builds the same set as set(<generator>) without
    # the intermediate generator object.
    competitors = {status['worker'] for status in statuses}
    return self.non_compete & competitors
|
||||||
|
|
||||||
async def serve_forever(self):
|
async def serve_forever(self):
|
||||||
try:
|
try:
|
||||||
|
@ -43,7 +54,7 @@ class SkynetDGPUDaemon:
|
||||||
statuses = await self.conn.get_status_by_request_id(rid)
|
statuses = await self.conn.get_status_by_request_id(rid)
|
||||||
|
|
||||||
if len(statuses) == 0:
|
if len(statuses) == 0:
|
||||||
|
self.current_request = rid
|
||||||
# parse request
|
# parse request
|
||||||
body = json.loads(req['body'])
|
body = json.loads(req['body'])
|
||||||
|
|
||||||
|
@ -70,8 +81,13 @@ class SkynetDGPUDaemon:
|
||||||
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
img_sha, img_raw = self.mm.compute_one(
|
img_sha, img_raw = await trio.to_thread.run_sync(
|
||||||
body['method'], body['params'], binary=binary)
|
partial(
|
||||||
|
self.mm.compute_one,
|
||||||
|
self.should_cancel_work,
|
||||||
|
body['method'], body['params'], binary=binary
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
ipfs_hash = await self.conn.publish_on_ipfs(img_raw)
|
ipfs_hash = await self.conn.publish_on_ipfs(img_raw)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue