Fix CLI entrypoints to use the new config, improve competitor cancel logic, and point the default docker image at the py311 image

pull/26/head
Guillermo Rodriguez 2023-10-07 12:32:00 -03:00
parent 5437af4d05
commit cc4a4b5189
GPG Key ID: EC3AB66D5D83B392
6 changed files with 108 additions and 64 deletions

View File

@@ -5,3 +5,7 @@ docker build \
 docker build \
     -t guilledk/skynet:runtime-cuda-py311 \
     -f docker/Dockerfile.runtime+cuda-py311 .
+docker build \
+    -t guilledk/skynet:runtime-cuda \
+    -f docker/Dockerfile.runtime+cuda-py311 .

View File

@@ -33,8 +33,8 @@ def txt2img(*args, **kwargs):
     from . import utils
     config = load_skynet_toml()
-    hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
-    hf_home = load_key(config, 'skynet.dgpu', 'hf_home')
+    hf_token = load_key(config, 'skynet.dgpu.hf_token')
+    hf_home = load_key(config, 'skynet.dgpu.hf_home')
     set_hf_vars(hf_token, hf_home)
     utils.txt2img(hf_token, **kwargs)
@@ -51,8 +51,8 @@ def txt2img(*args, **kwargs):
 def img2img(model, prompt, input, output, strength, guidance, steps, seed):
     from . import utils
     config = load_skynet_toml()
-    hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
-    hf_home = load_key(config, 'skynet.dgpu', 'hf_home')
+    hf_token = load_key(config, 'skynet.dgpu.hf_token')
+    hf_home = load_key(config, 'skynet.dgpu.hf_home')
     set_hf_vars(hf_token, hf_home)
     utils.img2img(
         hf_token,
@@ -82,8 +82,8 @@ def upscale(input, output, model):
 def download():
     from . import utils
     config = load_skynet_toml()
-    hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
-    hf_home = load_key(config, 'skynet.dgpu', 'hf_home')
+    hf_token = load_key(config, 'skynet.dgpu.hf_token')
+    hf_home = load_key(config, 'skynet.dgpu.hf_home')
     set_hf_vars(hf_token, hf_home)
     utils.download_all_models(hf_token)
@@ -112,10 +112,10 @@ def enqueue(
     config = load_skynet_toml()
-    key = load_key(config, 'skynet.user', 'key')
-    account = load_key(config, 'skynet.user', 'account')
-    permission = load_key(config, 'skynet.user', 'permission')
-    node_url = load_key(config, 'skynet.user', 'node_url')
+    key = load_key(config, 'skynet.user.key')
+    account = load_key(config, 'skynet.user.account')
+    permission = load_key(config, 'skynet.user.permission')
+    node_url = load_key(config, 'skynet.user.node_url')
     cleos = CLEOS(None, None, url=node_url, remote=node_url)
@@ -156,10 +156,10 @@ def clean(
     from leap.cleos import CLEOS
     config = load_skynet_toml()
-    key = load_key(config, 'skynet.user', 'key')
-    account = load_key(config, 'skynet.user', 'account')
-    permission = load_key(config, 'skynet.user', 'permission')
-    node_url = load_key(config, 'skynet.user', 'node_url')
+    key = load_key(config, 'skynet.user.key')
+    account = load_key(config, 'skynet.user.account')
+    permission = load_key(config, 'skynet.user.permission')
+    node_url = load_key(config, 'skynet.user.node_url')
     logging.basicConfig(level=loglevel)
     cleos = CLEOS(None, None, url=node_url, remote=node_url)
@@ -177,7 +177,7 @@ def clean(
 def queue():
     import requests
     config = load_skynet_toml()
-    node_url = load_key(config, 'skynet.user', 'node_url')
+    node_url = load_key(config, 'skynet.user.node_url')
     resp = requests.post(
         f'{node_url}/v1/chain/get_table_rows',
         json={
@@ -194,7 +194,7 @@ def queue():
 def status(request_id: int):
     import requests
     config = load_skynet_toml()
-    node_url = load_key(config, 'skynet.user', 'node_url')
+    node_url = load_key(config, 'skynet.user.node_url')
     resp = requests.post(
         f'{node_url}/v1/chain/get_table_rows',
         json={
@@ -213,10 +213,10 @@ def dequeue(request_id: int):
     from leap.cleos import CLEOS
     config = load_skynet_toml()
-    key = load_key(config, 'skynet.user', 'key')
-    account = load_key(config, 'skynet.user', 'account')
-    permission = load_key(config, 'skynet.user', 'permission')
-    node_url = load_key(config, 'skynet.user', 'node_url')
+    key = load_key(config, 'skynet.user.key')
+    account = load_key(config, 'skynet.user.account')
+    permission = load_key(config, 'skynet.user.permission')
+    node_url = load_key(config, 'skynet.user.node_url')
     cleos = CLEOS(None, None, url=node_url, remote=node_url)
     res = trio.run(
@@ -248,10 +248,10 @@ def config(
     config = load_skynet_toml()
-    key = load_key(config, 'skynet.user', 'key')
-    account = load_key(config, 'skynet.user', 'account')
-    permission = load_key(config, 'skynet.user', 'permission')
-    node_url = load_key(config, 'skynet.user', 'node_url')
+    key = load_key(config, 'skynet.user.key')
+    account = load_key(config, 'skynet.user.account')
+    permission = load_key(config, 'skynet.user.permission')
+    node_url = load_key(config, 'skynet.user.node_url')
     cleos = CLEOS(None, None, url=node_url, remote=node_url)
     res = trio.run(
@@ -277,10 +277,10 @@ def deposit(quantity: str):
     config = load_skynet_toml()
-    key = load_key(config, 'skynet.user', 'key')
-    account = load_key(config, 'skynet.user', 'account')
-    permission = load_key(config, 'skynet.user', 'permission')
-    node_url = load_key(config, 'skynet.user', 'node_url')
+    key = load_key(config, 'skynet.user.key')
+    account = load_key(config, 'skynet.user.account')
+    permission = load_key(config, 'skynet.user.permission')
+    node_url = load_key(config, 'skynet.user.node_url')
     cleos = CLEOS(None, None, url=node_url, remote=node_url)
     res = trio.run(
@@ -365,21 +365,21 @@ def telegram(
     logging.basicConfig(level=loglevel)
     config = load_skynet_toml()
-    tg_token = load_key(config, 'skynet.telegram', 'tg_token')
+    tg_token = load_key(config, 'skynet.telegram.tg_token')
-    key = load_key(config, 'skynet.telegram', 'key')
-    account = load_key(config, 'skynet.telegram', 'account')
-    permission = load_key(config, 'skynet.telegram', 'permission')
-    node_url = load_key(config, 'skynet.telegram', 'node_url')
-    hyperion_url = load_key(config, 'skynet.telegram', 'hyperion_url')
+    key = load_key(config, 'skynet.telegram.key')
+    account = load_key(config, 'skynet.telegram.account')
+    permission = load_key(config, 'skynet.telegram.permission')
+    node_url = load_key(config, 'skynet.telegram.node_url')
+    hyperion_url = load_key(config, 'skynet.telegram.hyperion_url')
     try:
-        ipfs_gateway_url = load_key(config, 'skynet.telegram', 'ipfs_gateway_url')
+        ipfs_gateway_url = load_key(config, 'skynet.telegram.ipfs_gateway_url')
     except ConfigParsingError:
         ipfs_gateway_url = None
-    ipfs_url = load_key(config, 'skynet.telegram', 'ipfs_url')
+    ipfs_url = load_key(config, 'skynet.telegram.ipfs_url')
     async def _async_main():
         frontend = SkynetTelegramFrontend(
@@ -421,16 +421,16 @@ def discord(
     logging.basicConfig(level=loglevel)
     config = load_skynet_toml()
-    dc_token = load_key(config, 'skynet.discord', 'dc_token')
+    dc_token = load_key(config, 'skynet.discord.dc_token')
-    key = load_key(config, 'skynet.discord', 'key')
-    account = load_key(config, 'skynet.discord', 'account')
-    permission = load_key(config, 'skynet.discord', 'permission')
-    node_url = load_key(config, 'skynet.discord', 'node_url')
-    hyperion_url = load_key(config, 'skynet.discord', 'hyperion_url')
+    key = load_key(config, 'skynet.discord.key')
+    account = load_key(config, 'skynet.discord.account')
+    permission = load_key(config, 'skynet.discord.permission')
+    node_url = load_key(config, 'skynet.discord.node_url')
+    hyperion_url = load_key(config, 'skynet.discord.hyperion_url')
-    ipfs_gateway_url = load_key(config, 'skynet.discord', 'ipfs_gateway_url')
-    ipfs_url = load_key(config, 'skynet.discord', 'ipfs_url')
+    ipfs_gateway_url = load_key(config, 'skynet.discord.ipfs_gateway_url')
+    ipfs_url = load_key(config, 'skynet.discord.ipfs_url')
     async def _async_main():
         frontend = SkynetDiscordFrontend(
@@ -471,8 +471,8 @@ def pinner(loglevel):
     from .ipfs.pinner import SkynetPinner
     config = load_skynet_toml()
-    hyperion_url = load_key(config, 'skynet.pinner', 'hyperion_url')
-    ipfs_url = load_key(config, 'skynet.pinner', 'ipfs_url')
+    hyperion_url = load_key(config, 'skynet.pinner.hyperion_url')
+    ipfs_url = load_key(config, 'skynet.pinner.ipfs_url')
     logging.basicConfig(level=loglevel)
     ipfs_node = AsyncIPFSHTTP(ipfs_url)
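The hunks above switch every `load_key` call from a `(section, key)` pair to a single dotted path such as `'skynet.user.node_url'`. The helper below is not the project's implementation, only a minimal hedged sketch of how a dotted-path lookup over the parsed TOML dict could behave (the `toml` dependency, the `'skynet.toml'` default path, and the local `ConfigParsingError` stub are assumptions for illustration):

```python
# Hypothetical sketch only: the real load_skynet_toml/load_key live in the
# skynet package and may differ; 'skynet.toml' is an assumed default path.
import toml


class ConfigParsingError(Exception):
    ...


def load_skynet_toml(path: str = 'skynet.toml') -> dict:
    # parse the whole config file into a nested dict
    return toml.load(path)


def load_key(config: dict, key: str):
    # walk the nested dict one dotted segment at a time, e.g.
    # 'skynet.user.node_url' -> config['skynet']['user']['node_url']
    node = config
    for part in key.split('.'):
        if not isinstance(node, dict) or part not in node:
            raise ConfigParsingError(f'key not found in config: {key}')
        node = node[part]
    return node
```

Under that reading, `load_key(config, 'skynet.user.node_url')` resolves the `node_url` value inside a `[skynet.user]` table, which is the shape all of the updated call sites above assume.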

View File

@@ -13,7 +13,7 @@ import trio
 import torch
 from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
-from skynet.dgpu.errors import DGPUComputeError
+from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled
 from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
@@ -132,16 +132,19 @@ class SkynetMM:
     def compute_one(
         self,
+        request_id: int,
         should_cancel_work,
         method: str,
         params: dict,
         binary: bytes | None = None
     ):
-        def callback_fn(step: int, timestep: int, latents: torch.FloatTensor):
-            should_raise = trio.from_thread.run(should_cancel_work)
+        def maybe_cancel_work(step, *args, **kwargs):
+            should_raise = trio.from_thread.run(should_cancel_work, request_id)
             if should_raise:
                 logging.warn(f'cancelling work at step {step}')
-                raise DGPUComputeError('Inference cancelled')
+                raise DGPUInferenceCancelled()
+        maybe_cancel_work(0)
         try:
             match method:
@@ -157,7 +160,7 @@ class SkynetMM:
                         guidance_scale=guidance,
                         num_inference_steps=step,
                         generator=seed,
-                        callback=callback_fn,
+                        callback=maybe_cancel_work,
                         callback_steps=2,
                         **extra_params
                     ).images[0]
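`compute_one` runs on a worker thread (the daemon drives it through `trio.to_thread.run_sync`) while `should_cancel_work` is an async method, so the pipeline callback has to hop back into the trio event loop with `trio.from_thread.run`. Here is a stripped-down, hedged sketch of that pattern; `fake_pipeline`, `InferenceCancelled`, and the simplified `compute_one` are illustrative stand-ins, not skynet code:

```python
# Illustrative sketch of the cancel-callback pattern, not the skynet implementation.
import trio


class InferenceCancelled(Exception):
    ...


async def should_cancel_work(request_id: int) -> bool:
    # stand-in for an async status query; skynet asks its connector instead
    return request_id == 42


def fake_pipeline(steps: int, callback):
    # mimics diffusers invoking `callback(step, ...)` every few steps
    for step in range(steps):
        callback(step)
    return 'image-bytes'


def compute_one(request_id: int, should_cancel):
    def maybe_cancel_work(step, *args, **kwargs):
        # we are on a worker thread here, so re-enter the trio loop to await
        if trio.from_thread.run(should_cancel, request_id):
            raise InferenceCancelled(f'cancelled at step {step}')

    maybe_cancel_work(0)  # bail out early if someone already beat us to it
    return fake_pipeline(4, maybe_cancel_work)


async def main():
    # run the blocking compute on a thread; its callback can still await
    result = await trio.to_thread.run_sync(compute_one, 7, should_cancel_work)
    print(result)


trio.run(main)
```

Calling `maybe_cancel_work(0)` before the pipeline even starts mirrors the diff's early-exit: if a competitor already took the request, no GPU time is spent.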

View File

@@ -40,15 +40,9 @@ class SkynetDGPUDaemon:
         if 'model_blacklist' in config:
             self.model_blacklist = set(config['model_blacklist'])
-        self.current_request = None
-    async def should_cancel_work(self):
-        competitors = set((
-            status['worker']
-            for status in
-            (await self.conn.get_status_by_request_id(self.current_request))
-        ))
-        return self.non_compete & competitors
+    async def should_cancel_work(self, request_id: int):
+        competitors = await self.conn.get_competitors_for_req(request_id)
+        return bool(self.non_compete & competitors)
     async def serve_forever(self):
         try:
@@ -79,7 +73,7 @@ class SkynetDGPUDaemon:
                     statuses = await self.conn.get_status_by_request_id(rid)
                     if len(statuses) == 0:
-                        self.current_request = rid
+                        self.conn.monitor_request(rid)
                         binary = await self.conn.get_input_data(req['binary_data'])
@@ -107,6 +101,7 @@ class SkynetDGPUDaemon:
                             img_sha, img_raw = await trio.to_thread.run_sync(
                                 partial(
                                     self.mm.compute_one,
+                                    rid,
                                     self.should_cancel_work,
                                     body['method'], body['params'], binary=binary
                                 )
@@ -115,11 +110,13 @@ class SkynetDGPUDaemon:
                             ipfs_hash = await self.conn.publish_on_ipfs(img_raw)
                             await self.conn.submit_work(rid, request_hash, img_sha, ipfs_hash)
-                            break
                         except BaseException as e:
                             traceback.print_exc()
                             await self.conn.cancel_work(rid, str(e))
+                        finally:
+                            self.conn.forget_request(rid)
+                            break
                     else:

View File

@@ -3,3 +3,6 @@
 class DGPUComputeError(BaseException):
     ...
+
+class DGPUInferenceCancelled(BaseException):
+    ...

View File

@@ -1,13 +1,15 @@
 #!/usr/bin/python
-from functools import partial
 import io
 import json
-from pathlib import Path
+import time
+import logging
+from pathlib import Path
+from functools import partial
 import asks
 import trio
+import anyio
 from PIL import Image
@@ -20,6 +22,9 @@ from skynet.dgpu.errors import DGPUComputeError
 from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
+REQUEST_UPDATE_TIME = 3
 async def failable(fn: partial, ret_fail=None):
     try:
         return await fn()
@@ -54,6 +59,8 @@ class SkynetGPUConnector:
         self.ipfs_client = AsyncIPFSHTTP(self.ipfs_url)
+        self._wip_requests = {}
     # blockchain helpers
     async def get_work_requests_last_hour(self):
@@ -103,6 +110,36 @@ class SkynetGPUConnector:
         else:
             return None
+    def monitor_request(self, request_id: int):
+        logging.info(f'begin monitoring request: {request_id}')
+        self._wip_requests[request_id] = {
+            'last_update': None,
+            'competitors': set()
+        }
+    async def maybe_update_request(self, request_id: int):
+        now = time.time()
+        stats = self._wip_requests[request_id]
+        if (not stats['last_update'] or
+            (now - stats['last_update']) > REQUEST_UPDATE_TIME):
+            stats['competitors'] = [
+                status['worker']
+                for status in
+                (await self.get_status_by_request_id(request_id))
+                if status['worker'] != self.account
+            ]
+            stats['last_update'] = now
+    async def get_competitors_for_req(self, request_id: int) -> set:
+        await self.maybe_update_request(request_id)
+        competitors = set(self._wip_requests[request_id]['competitors'])
+        logging.info(f'competitors: {competitors}')
+        return competitors
+    def forget_request(self, request_id: int):
+        logging.info(f'end monitoring request: {request_id}')
+        del self._wip_requests[request_id]
     async def begin_work(self, request_id: int):
         logging.info('begin_work')
         return await failable(
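To make the new competitor-tracking flow concrete end to end, here is a hedged, self-contained sketch of how `monitor_request`, `get_competitors_for_req`, and `forget_request` are meant to interact, with refreshes throttled to `REQUEST_UPDATE_TIME` seconds. The `CompetitorCache` and `worker_loop` names and the stubbed `get_status_by_request_id` are illustrative assumptions, not the connector itself:

```python
# Hedged illustration of the monitor/poll/forget lifecycle; the stubbed
# get_status_by_request_id stands in for the real on-chain status-table query.
import time

import trio

REQUEST_UPDATE_TIME = 3  # seconds between competitor refreshes


class CompetitorCache:
    def __init__(self, account: str):
        self.account = account
        self._wip_requests: dict[int, dict] = {}

    async def get_status_by_request_id(self, request_id: int) -> list[dict]:
        # stub: the real connector reads worker statuses from the contract
        return [{'worker': 'someotherworker'}, {'worker': self.account}]

    def monitor_request(self, request_id: int):
        self._wip_requests[request_id] = {'last_update': None, 'competitors': set()}

    async def get_competitors_for_req(self, request_id: int) -> set:
        stats = self._wip_requests[request_id]
        now = time.time()
        # refresh at most once per REQUEST_UPDATE_TIME to avoid hammering the node
        if not stats['last_update'] or now - stats['last_update'] > REQUEST_UPDATE_TIME:
            stats['competitors'] = {
                s['worker']
                for s in await self.get_status_by_request_id(request_id)
                if s['worker'] != self.account
            }
            stats['last_update'] = now
        return set(stats['competitors'])

    def forget_request(self, request_id: int):
        self._wip_requests.pop(request_id, None)


async def worker_loop(conn: CompetitorCache, non_compete: set, rid: int):
    conn.monitor_request(rid)
    try:
        # same check the daemon's should_cancel_work now performs
        if non_compete & await conn.get_competitors_for_req(rid):
            print('a non-compete worker already picked this up, cancelling')
    finally:
        # always drop the monitoring state, success or failure
        conn.forget_request(rid)


trio.run(worker_loop, CompetitorCache('me'), {'someotherworker'}, 1)
```

Keeping `forget_request` in a `finally` block, as the daemon hunk does, guarantees the per-request cache entry is cleaned up even when inference fails or is cancelled.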