mirror of https://github.com/skygpu/skynet.git
Fix CLI entrypoints to use the new config, improve competitor cancel logic, and also tag the py311 image as the default docker image
parent
5437af4d05
commit
cc4a4b5189
|
@ -5,3 +5,7 @@ docker build \
|
|||
docker build \
|
||||
-t guilledk/skynet:runtime-cuda-py311 \
|
||||
-f docker/Dockerfile.runtime+cuda-py311 .
|
||||
|
||||
docker build \
|
||||
-t guilledk/skynet:runtime-cuda \
|
||||
-f docker/Dockerfile.runtime+cuda-py311 .
|
||||
|
|
|
@ -33,8 +33,8 @@ def txt2img(*args, **kwargs):
|
|||
from . import utils
|
||||
|
||||
config = load_skynet_toml()
|
||||
hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
|
||||
hf_home = load_key(config, 'skynet.dgpu', 'hf_home')
|
||||
hf_token = load_key(config, 'skynet.dgpu.hf_token')
|
||||
hf_home = load_key(config, 'skynet.dgpu.hf_home')
|
||||
set_hf_vars(hf_token, hf_home)
|
||||
utils.txt2img(hf_token, **kwargs)
|
||||
|
||||
|
@ -51,8 +51,8 @@ def txt2img(*args, **kwargs):
|
|||
def img2img(model, prompt, input, output, strength, guidance, steps, seed):
|
||||
from . import utils
|
||||
config = load_skynet_toml()
|
||||
hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
|
||||
hf_home = load_key(config, 'skynet.dgpu', 'hf_home')
|
||||
hf_token = load_key(config, 'skynet.dgpu.hf_token')
|
||||
hf_home = load_key(config, 'skynet.dgpu.hf_home')
|
||||
set_hf_vars(hf_token, hf_home)
|
||||
utils.img2img(
|
||||
hf_token,
|
||||
|
@ -82,8 +82,8 @@ def upscale(input, output, model):
|
|||
def download():
|
||||
from . import utils
|
||||
config = load_skynet_toml()
|
||||
hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
|
||||
hf_home = load_key(config, 'skynet.dgpu', 'hf_home')
|
||||
hf_token = load_key(config, 'skynet.dgpu.hf_token')
|
||||
hf_home = load_key(config, 'skynet.dgpu.hf_home')
|
||||
set_hf_vars(hf_token, hf_home)
|
||||
utils.download_all_models(hf_token)
|
||||
|
||||
|
@ -112,10 +112,10 @@ def enqueue(
|
|||
|
||||
config = load_skynet_toml()
|
||||
|
||||
key = load_key(config, 'skynet.user', 'key')
|
||||
account = load_key(config, 'skynet.user', 'account')
|
||||
permission = load_key(config, 'skynet.user', 'permission')
|
||||
node_url = load_key(config, 'skynet.user', 'node_url')
|
||||
key = load_key(config, 'skynet.user.key')
|
||||
account = load_key(config, 'skynet.user.account')
|
||||
permission = load_key(config, 'skynet.user.permission')
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
|
||||
cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
|
||||
|
@ -156,10 +156,10 @@ def clean(
|
|||
from leap.cleos import CLEOS
|
||||
|
||||
config = load_skynet_toml()
|
||||
key = load_key(config, 'skynet.user', 'key')
|
||||
account = load_key(config, 'skynet.user', 'account')
|
||||
permission = load_key(config, 'skynet.user', 'permission')
|
||||
node_url = load_key(config, 'skynet.user', 'node_url')
|
||||
key = load_key(config, 'skynet.user.key')
|
||||
account = load_key(config, 'skynet.user.account')
|
||||
permission = load_key(config, 'skynet.user.permission')
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
|
||||
logging.basicConfig(level=loglevel)
|
||||
cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
|
@ -177,7 +177,7 @@ def clean(
|
|||
def queue():
|
||||
import requests
|
||||
config = load_skynet_toml()
|
||||
node_url = load_key(config, 'skynet.user', 'node_url')
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
resp = requests.post(
|
||||
f'{node_url}/v1/chain/get_table_rows',
|
||||
json={
|
||||
|
@ -194,7 +194,7 @@ def queue():
|
|||
def status(request_id: int):
|
||||
import requests
|
||||
config = load_skynet_toml()
|
||||
node_url = load_key(config, 'skynet.user', 'node_url')
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
resp = requests.post(
|
||||
f'{node_url}/v1/chain/get_table_rows',
|
||||
json={
|
||||
|
@ -213,10 +213,10 @@ def dequeue(request_id: int):
|
|||
from leap.cleos import CLEOS
|
||||
|
||||
config = load_skynet_toml()
|
||||
key = load_key(config, 'skynet.user', 'key')
|
||||
account = load_key(config, 'skynet.user', 'account')
|
||||
permission = load_key(config, 'skynet.user', 'permission')
|
||||
node_url = load_key(config, 'skynet.user', 'node_url')
|
||||
key = load_key(config, 'skynet.user.key')
|
||||
account = load_key(config, 'skynet.user.account')
|
||||
permission = load_key(config, 'skynet.user.permission')
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
|
||||
cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
res = trio.run(
|
||||
|
@ -248,10 +248,10 @@ def config(
|
|||
|
||||
config = load_skynet_toml()
|
||||
|
||||
key = load_key(config, 'skynet.user', 'key')
|
||||
account = load_key(config, 'skynet.user', 'account')
|
||||
permission = load_key(config, 'skynet.user', 'permission')
|
||||
node_url = load_key(config, 'skynet.user', 'node_url')
|
||||
key = load_key(config, 'skynet.user.key')
|
||||
account = load_key(config, 'skynet.user.account')
|
||||
permission = load_key(config, 'skynet.user.permission')
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
|
||||
cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
res = trio.run(
|
||||
|
@ -277,10 +277,10 @@ def deposit(quantity: str):
|
|||
|
||||
config = load_skynet_toml()
|
||||
|
||||
key = load_key(config, 'skynet.user', 'key')
|
||||
account = load_key(config, 'skynet.user', 'account')
|
||||
permission = load_key(config, 'skynet.user', 'permission')
|
||||
node_url = load_key(config, 'skynet.user', 'node_url')
|
||||
key = load_key(config, 'skynet.user.key')
|
||||
account = load_key(config, 'skynet.user.account')
|
||||
permission = load_key(config, 'skynet.user.permission')
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
|
||||
res = trio.run(
|
||||
|
@ -365,21 +365,21 @@ def telegram(
|
|||
logging.basicConfig(level=loglevel)
|
||||
|
||||
config = load_skynet_toml()
|
||||
tg_token = load_key(config, 'skynet.telegram', 'tg_token')
|
||||
tg_token = load_key(config, 'skynet.telegram.tg_token')
|
||||
|
||||
key = load_key(config, 'skynet.telegram', 'key')
|
||||
account = load_key(config, 'skynet.telegram', 'account')
|
||||
permission = load_key(config, 'skynet.telegram', 'permission')
|
||||
node_url = load_key(config, 'skynet.telegram', 'node_url')
|
||||
hyperion_url = load_key(config, 'skynet.telegram', 'hyperion_url')
|
||||
key = load_key(config, 'skynet.telegram.key')
|
||||
account = load_key(config, 'skynet.telegram.account')
|
||||
permission = load_key(config, 'skynet.telegram.permission')
|
||||
node_url = load_key(config, 'skynet.telegram.node_url')
|
||||
hyperion_url = load_key(config, 'skynet.telegram.hyperion_url')
|
||||
|
||||
try:
|
||||
ipfs_gateway_url = load_key(config, 'skynet.telegram', 'ipfs_gateway_url')
|
||||
ipfs_gateway_url = load_key(config, 'skynet.telegram.ipfs_gateway_url')
|
||||
|
||||
except ConfigParsingError:
|
||||
ipfs_gateway_url = None
|
||||
|
||||
ipfs_url = load_key(config, 'skynet.telegram', 'ipfs_url')
|
||||
ipfs_url = load_key(config, 'skynet.telegram.ipfs_url')
|
||||
|
||||
async def _async_main():
|
||||
frontend = SkynetTelegramFrontend(
|
||||
|
@ -421,16 +421,16 @@ def discord(
|
|||
logging.basicConfig(level=loglevel)
|
||||
|
||||
config = load_skynet_toml()
|
||||
dc_token = load_key(config, 'skynet.discord', 'dc_token')
|
||||
dc_token = load_key(config, 'skynet.discord.dc_token')
|
||||
|
||||
key = load_key(config, 'skynet.discord', 'key')
|
||||
account = load_key(config, 'skynet.discord', 'account')
|
||||
permission = load_key(config, 'skynet.discord', 'permission')
|
||||
node_url = load_key(config, 'skynet.discord', 'node_url')
|
||||
hyperion_url = load_key(config, 'skynet.discord', 'hyperion_url')
|
||||
key = load_key(config, 'skynet.discord.key')
|
||||
account = load_key(config, 'skynet.discord.account')
|
||||
permission = load_key(config, 'skynet.discord.permission')
|
||||
node_url = load_key(config, 'skynet.discord.node_url')
|
||||
hyperion_url = load_key(config, 'skynet.discord.hyperion_url')
|
||||
|
||||
ipfs_gateway_url = load_key(config, 'skynet.discord', 'ipfs_gateway_url')
|
||||
ipfs_url = load_key(config, 'skynet.discord', 'ipfs_url')
|
||||
ipfs_gateway_url = load_key(config, 'skynet.discord.ipfs_gateway_url')
|
||||
ipfs_url = load_key(config, 'skynet.discord.ipfs_url')
|
||||
|
||||
async def _async_main():
|
||||
frontend = SkynetDiscordFrontend(
|
||||
|
@ -471,8 +471,8 @@ def pinner(loglevel):
|
|||
from .ipfs.pinner import SkynetPinner
|
||||
|
||||
config = load_skynet_toml()
|
||||
hyperion_url = load_key(config, 'skynet.pinner', 'hyperion_url')
|
||||
ipfs_url = load_key(config, 'skynet.pinner', 'ipfs_url')
|
||||
hyperion_url = load_key(config, 'skynet.pinner.hyperion_url')
|
||||
ipfs_url = load_key(config, 'skynet.pinner.ipfs_url')
|
||||
|
||||
logging.basicConfig(level=loglevel)
|
||||
ipfs_node = AsyncIPFSHTTP(ipfs_url)
|
||||
|
|
|
@ -13,7 +13,7 @@ import trio
|
|||
import torch
|
||||
|
||||
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
|
||||
from skynet.dgpu.errors import DGPUComputeError
|
||||
from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled
|
||||
|
||||
from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
|
||||
|
||||
|
@ -132,16 +132,19 @@ class SkynetMM:
|
|||
|
||||
def compute_one(
|
||||
self,
|
||||
request_id: int,
|
||||
should_cancel_work,
|
||||
method: str,
|
||||
params: dict,
|
||||
binary: bytes | None = None
|
||||
):
|
||||
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor):
|
||||
should_raise = trio.from_thread.run(should_cancel_work)
|
||||
def maybe_cancel_work(step, *args, **kwargs):
|
||||
should_raise = trio.from_thread.run(should_cancel_work, request_id)
|
||||
if should_raise:
|
||||
logging.warn(f'cancelling work at step {step}')
|
||||
raise DGPUComputeError('Inference cancelled')
|
||||
raise DGPUInferenceCancelled()
|
||||
|
||||
maybe_cancel_work(0)
|
||||
|
||||
try:
|
||||
match method:
|
||||
|
@ -157,7 +160,7 @@ class SkynetMM:
|
|||
guidance_scale=guidance,
|
||||
num_inference_steps=step,
|
||||
generator=seed,
|
||||
callback=callback_fn,
|
||||
callback=maybe_cancel_work,
|
||||
callback_steps=2,
|
||||
**extra_params
|
||||
).images[0]
|
||||
|
|
|
@ -40,15 +40,9 @@ class SkynetDGPUDaemon:
|
|||
if 'model_blacklist' in config:
|
||||
self.model_blacklist = set(config['model_blacklist'])
|
||||
|
||||
self.current_request = None
|
||||
|
||||
async def should_cancel_work(self):
|
||||
competitors = set((
|
||||
status['worker']
|
||||
for status in
|
||||
(await self.conn.get_status_by_request_id(self.current_request))
|
||||
))
|
||||
return self.non_compete & competitors
|
||||
async def should_cancel_work(self, request_id: int):
|
||||
competitors = await self.conn.get_competitors_for_req(request_id)
|
||||
return bool(self.non_compete & competitors)
|
||||
|
||||
async def serve_forever(self):
|
||||
try:
|
||||
|
@ -79,7 +73,7 @@ class SkynetDGPUDaemon:
|
|||
statuses = await self.conn.get_status_by_request_id(rid)
|
||||
|
||||
if len(statuses) == 0:
|
||||
self.current_request = rid
|
||||
self.conn.monitor_request(rid)
|
||||
|
||||
binary = await self.conn.get_input_data(req['binary_data'])
|
||||
|
||||
|
@ -107,6 +101,7 @@ class SkynetDGPUDaemon:
|
|||
img_sha, img_raw = await trio.to_thread.run_sync(
|
||||
partial(
|
||||
self.mm.compute_one,
|
||||
rid,
|
||||
self.should_cancel_work,
|
||||
body['method'], body['params'], binary=binary
|
||||
)
|
||||
|
@ -115,11 +110,13 @@ class SkynetDGPUDaemon:
|
|||
ipfs_hash = await self.conn.publish_on_ipfs(img_raw)
|
||||
|
||||
await self.conn.submit_work(rid, request_hash, img_sha, ipfs_hash)
|
||||
break
|
||||
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
await self.conn.cancel_work(rid, str(e))
|
||||
|
||||
finally:
|
||||
self.conn.forget_request(rid)
|
||||
break
|
||||
|
||||
else:
|
||||
|
|
|
@ -3,3 +3,6 @@
|
|||
|
||||
class DGPUComputeError(BaseException):
    """Error raised by the dgpu compute path.

    NOTE(review): derives from ``BaseException``, so ``except Exception``
    handlers do not catch it; presumably intentional -- confirm before
    changing the base class.
    """


class DGPUInferenceCancelled(BaseException):
    """Raised when an in-progress inference is cancelled.

    Constructing it without arguments yields the message
    ``'Inference cancelled'``, so code that reports ``str(exc)`` as the
    cancellation reason never sends an empty string.  Explicit arguments
    are passed through unchanged.

    NOTE(review): derives from ``BaseException`` like ``DGPUComputeError``;
    kept as-is so existing handlers behave the same.
    """

    DEFAULT_MESSAGE = 'Inference cancelled'

    def __init__(self, *args):
        # fall back to the default reason only when the caller gave none
        super().__init__(*(args or (self.DEFAULT_MESSAGE,)))
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
from functools import partial
|
||||
import io
|
||||
import json
|
||||
from pathlib import Path
|
||||
import time
|
||||
import logging
|
||||
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
import asks
|
||||
import trio
|
||||
import anyio
|
||||
|
||||
from PIL import Image
|
||||
|
@ -20,6 +22,9 @@ from skynet.dgpu.errors import DGPUComputeError
|
|||
from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
|
||||
|
||||
|
||||
REQUEST_UPDATE_TIME = 3
|
||||
|
||||
|
||||
async def failable(fn: partial, ret_fail=None):
|
||||
try:
|
||||
return await fn()
|
||||
|
@ -54,6 +59,8 @@ class SkynetGPUConnector:
|
|||
|
||||
self.ipfs_client = AsyncIPFSHTTP(self.ipfs_url)
|
||||
|
||||
self._wip_requests = {}
|
||||
|
||||
# blockchain helpers
|
||||
|
||||
async def get_work_requests_last_hour(self):
|
||||
|
@ -103,6 +110,36 @@ class SkynetGPUConnector:
|
|||
else:
|
||||
return None
|
||||
|
||||
def monitor_request(self, request_id: int):
    """Start tracking competitor state for ``request_id``.

    Creates (or resets) the bookkeeping entry in ``self._wip_requests``
    with no recorded refresh time and an empty competitor set.
    """
    logging.info(f'begin monitoring request: {request_id}')
    self._wip_requests[request_id] = {
        'last_update': None,
        'competitors': set()
    }


async def maybe_update_request(self, request_id: int):
    """Refresh the cached competitor set for ``request_id`` if it is stale.

    The cache is refreshed when it has never been filled, or when more than
    ``REQUEST_UPDATE_TIME`` seconds have passed since the last refresh.
    Workers equal to ``self.account`` are excluded from the result.

    Raises:
        KeyError: if ``request_id`` is not currently being monitored.
    """
    # monotonic clock: the elapsed-time check must not be affected by
    # wall-clock adjustments
    now = time.monotonic()
    stats = self._wip_requests[request_id]
    last_update = stats['last_update']

    # compare against None explicitly; a monotonic reading of 0.0 is a
    # valid timestamp, not "never updated"
    is_stale = (
        last_update is None
        or (now - last_update) > REQUEST_UPDATE_TIME
    )
    if not is_stale:
        return

    statuses = await self.get_status_by_request_id(request_id)

    # keep the same container type that monitor_request() initialises
    stats['competitors'] = {
        status['worker']
        for status in statuses
        if status['worker'] != self.account
    }
    stats['last_update'] = now


async def get_competitors_for_req(self, request_id: int) -> set:
    """Return the set of other workers seen on ``request_id``.

    Triggers a cache refresh when needed and returns a copy, so callers
    cannot mutate the cached set.
    """
    await self.maybe_update_request(request_id)
    competitors = set(self._wip_requests[request_id]['competitors'])
    logging.info(f'competitors: {competitors}')
    return competitors


def forget_request(self, request_id: int):
    """Stop tracking ``request_id``.

    Safe to call for an id that is not (or no longer) monitored: this is
    used as cleanup, and a ``KeyError`` raised here would hide the error
    that triggered the cleanup.
    """
    logging.info(f'end monitoring request: {request_id}')
    self._wip_requests.pop(request_id, None)
|
||||
|
||||
async def begin_work(self, request_id: int):
|
||||
logging.info('begin_work')
|
||||
return await failable(
|
||||
|
|
Loading…
Reference in New Issue