Fix cli entrypoints to use new config, improve competitor cancel logic and add default docker image to py311 image

pull/26/head
Guillermo Rodriguez 2023-10-07 12:32:00 -03:00
parent 5437af4d05
commit cc4a4b5189
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
6 changed files with 108 additions and 64 deletions

View File

@ -5,3 +5,7 @@ docker build \
docker build \ docker build \
-t guilledk/skynet:runtime-cuda-py311 \ -t guilledk/skynet:runtime-cuda-py311 \
-f docker/Dockerfile.runtime+cuda-py311 . -f docker/Dockerfile.runtime+cuda-py311 .
docker build \
-t guilledk/skynet:runtime-cuda \
-f docker/Dockerfile.runtime+cuda-py311 .

View File

@ -33,8 +33,8 @@ def txt2img(*args, **kwargs):
from . import utils from . import utils
config = load_skynet_toml() config = load_skynet_toml()
hf_token = load_key(config, 'skynet.dgpu', 'hf_token') hf_token = load_key(config, 'skynet.dgpu.hf_token')
hf_home = load_key(config, 'skynet.dgpu', 'hf_home') hf_home = load_key(config, 'skynet.dgpu.hf_home')
set_hf_vars(hf_token, hf_home) set_hf_vars(hf_token, hf_home)
utils.txt2img(hf_token, **kwargs) utils.txt2img(hf_token, **kwargs)
@ -51,8 +51,8 @@ def txt2img(*args, **kwargs):
def img2img(model, prompt, input, output, strength, guidance, steps, seed): def img2img(model, prompt, input, output, strength, guidance, steps, seed):
from . import utils from . import utils
config = load_skynet_toml() config = load_skynet_toml()
hf_token = load_key(config, 'skynet.dgpu', 'hf_token') hf_token = load_key(config, 'skynet.dgpu.hf_token')
hf_home = load_key(config, 'skynet.dgpu', 'hf_home') hf_home = load_key(config, 'skynet.dgpu.hf_home')
set_hf_vars(hf_token, hf_home) set_hf_vars(hf_token, hf_home)
utils.img2img( utils.img2img(
hf_token, hf_token,
@ -82,8 +82,8 @@ def upscale(input, output, model):
def download(): def download():
from . import utils from . import utils
config = load_skynet_toml() config = load_skynet_toml()
hf_token = load_key(config, 'skynet.dgpu', 'hf_token') hf_token = load_key(config, 'skynet.dgpu.hf_token')
hf_home = load_key(config, 'skynet.dgpu', 'hf_home') hf_home = load_key(config, 'skynet.dgpu.hf_home')
set_hf_vars(hf_token, hf_home) set_hf_vars(hf_token, hf_home)
utils.download_all_models(hf_token) utils.download_all_models(hf_token)
@ -112,10 +112,10 @@ def enqueue(
config = load_skynet_toml() config = load_skynet_toml()
key = load_key(config, 'skynet.user', 'key') key = load_key(config, 'skynet.user.key')
account = load_key(config, 'skynet.user', 'account') account = load_key(config, 'skynet.user.account')
permission = load_key(config, 'skynet.user', 'permission') permission = load_key(config, 'skynet.user.permission')
node_url = load_key(config, 'skynet.user', 'node_url') node_url = load_key(config, 'skynet.user.node_url')
cleos = CLEOS(None, None, url=node_url, remote=node_url) cleos = CLEOS(None, None, url=node_url, remote=node_url)
@ -156,10 +156,10 @@ def clean(
from leap.cleos import CLEOS from leap.cleos import CLEOS
config = load_skynet_toml() config = load_skynet_toml()
key = load_key(config, 'skynet.user', 'key') key = load_key(config, 'skynet.user.key')
account = load_key(config, 'skynet.user', 'account') account = load_key(config, 'skynet.user.account')
permission = load_key(config, 'skynet.user', 'permission') permission = load_key(config, 'skynet.user.permission')
node_url = load_key(config, 'skynet.user', 'node_url') node_url = load_key(config, 'skynet.user.node_url')
logging.basicConfig(level=loglevel) logging.basicConfig(level=loglevel)
cleos = CLEOS(None, None, url=node_url, remote=node_url) cleos = CLEOS(None, None, url=node_url, remote=node_url)
@ -177,7 +177,7 @@ def clean(
def queue(): def queue():
import requests import requests
config = load_skynet_toml() config = load_skynet_toml()
node_url = load_key(config, 'skynet.user', 'node_url') node_url = load_key(config, 'skynet.user.node_url')
resp = requests.post( resp = requests.post(
f'{node_url}/v1/chain/get_table_rows', f'{node_url}/v1/chain/get_table_rows',
json={ json={
@ -194,7 +194,7 @@ def queue():
def status(request_id: int): def status(request_id: int):
import requests import requests
config = load_skynet_toml() config = load_skynet_toml()
node_url = load_key(config, 'skynet.user', 'node_url') node_url = load_key(config, 'skynet.user.node_url')
resp = requests.post( resp = requests.post(
f'{node_url}/v1/chain/get_table_rows', f'{node_url}/v1/chain/get_table_rows',
json={ json={
@ -213,10 +213,10 @@ def dequeue(request_id: int):
from leap.cleos import CLEOS from leap.cleos import CLEOS
config = load_skynet_toml() config = load_skynet_toml()
key = load_key(config, 'skynet.user', 'key') key = load_key(config, 'skynet.user.key')
account = load_key(config, 'skynet.user', 'account') account = load_key(config, 'skynet.user.account')
permission = load_key(config, 'skynet.user', 'permission') permission = load_key(config, 'skynet.user.permission')
node_url = load_key(config, 'skynet.user', 'node_url') node_url = load_key(config, 'skynet.user.node_url')
cleos = CLEOS(None, None, url=node_url, remote=node_url) cleos = CLEOS(None, None, url=node_url, remote=node_url)
res = trio.run( res = trio.run(
@ -248,10 +248,10 @@ def config(
config = load_skynet_toml() config = load_skynet_toml()
key = load_key(config, 'skynet.user', 'key') key = load_key(config, 'skynet.user.key')
account = load_key(config, 'skynet.user', 'account') account = load_key(config, 'skynet.user.account')
permission = load_key(config, 'skynet.user', 'permission') permission = load_key(config, 'skynet.user.permission')
node_url = load_key(config, 'skynet.user', 'node_url') node_url = load_key(config, 'skynet.user.node_url')
cleos = CLEOS(None, None, url=node_url, remote=node_url) cleos = CLEOS(None, None, url=node_url, remote=node_url)
res = trio.run( res = trio.run(
@ -277,10 +277,10 @@ def deposit(quantity: str):
config = load_skynet_toml() config = load_skynet_toml()
key = load_key(config, 'skynet.user', 'key') key = load_key(config, 'skynet.user.key')
account = load_key(config, 'skynet.user', 'account') account = load_key(config, 'skynet.user.account')
permission = load_key(config, 'skynet.user', 'permission') permission = load_key(config, 'skynet.user.permission')
node_url = load_key(config, 'skynet.user', 'node_url') node_url = load_key(config, 'skynet.user.node_url')
cleos = CLEOS(None, None, url=node_url, remote=node_url) cleos = CLEOS(None, None, url=node_url, remote=node_url)
res = trio.run( res = trio.run(
@ -365,21 +365,21 @@ def telegram(
logging.basicConfig(level=loglevel) logging.basicConfig(level=loglevel)
config = load_skynet_toml() config = load_skynet_toml()
tg_token = load_key(config, 'skynet.telegram', 'tg_token') tg_token = load_key(config, 'skynet.telegram.tg_token')
key = load_key(config, 'skynet.telegram', 'key') key = load_key(config, 'skynet.telegram.key')
account = load_key(config, 'skynet.telegram', 'account') account = load_key(config, 'skynet.telegram.account')
permission = load_key(config, 'skynet.telegram', 'permission') permission = load_key(config, 'skynet.telegram.permission')
node_url = load_key(config, 'skynet.telegram', 'node_url') node_url = load_key(config, 'skynet.telegram.node_url')
hyperion_url = load_key(config, 'skynet.telegram', 'hyperion_url') hyperion_url = load_key(config, 'skynet.telegram.hyperion_url')
try: try:
ipfs_gateway_url = load_key(config, 'skynet.telegram', 'ipfs_gateway_url') ipfs_gateway_url = load_key(config, 'skynet.telegram.ipfs_gateway_url')
except ConfigParsingError: except ConfigParsingError:
ipfs_gateway_url = None ipfs_gateway_url = None
ipfs_url = load_key(config, 'skynet.telegram', 'ipfs_url') ipfs_url = load_key(config, 'skynet.telegram.ipfs_url')
async def _async_main(): async def _async_main():
frontend = SkynetTelegramFrontend( frontend = SkynetTelegramFrontend(
@ -421,16 +421,16 @@ def discord(
logging.basicConfig(level=loglevel) logging.basicConfig(level=loglevel)
config = load_skynet_toml() config = load_skynet_toml()
dc_token = load_key(config, 'skynet.discord', 'dc_token') dc_token = load_key(config, 'skynet.discord.dc_token')
key = load_key(config, 'skynet.discord', 'key') key = load_key(config, 'skynet.discord.key')
account = load_key(config, 'skynet.discord', 'account') account = load_key(config, 'skynet.discord.account')
permission = load_key(config, 'skynet.discord', 'permission') permission = load_key(config, 'skynet.discord.permission')
node_url = load_key(config, 'skynet.discord', 'node_url') node_url = load_key(config, 'skynet.discord.node_url')
hyperion_url = load_key(config, 'skynet.discord', 'hyperion_url') hyperion_url = load_key(config, 'skynet.discord.hyperion_url')
ipfs_gateway_url = load_key(config, 'skynet.discord', 'ipfs_gateway_url') ipfs_gateway_url = load_key(config, 'skynet.discord.ipfs_gateway_url')
ipfs_url = load_key(config, 'skynet.discord', 'ipfs_url') ipfs_url = load_key(config, 'skynet.discord.ipfs_url')
async def _async_main(): async def _async_main():
frontend = SkynetDiscordFrontend( frontend = SkynetDiscordFrontend(
@ -471,8 +471,8 @@ def pinner(loglevel):
from .ipfs.pinner import SkynetPinner from .ipfs.pinner import SkynetPinner
config = load_skynet_toml() config = load_skynet_toml()
hyperion_url = load_key(config, 'skynet.pinner', 'hyperion_url') hyperion_url = load_key(config, 'skynet.pinner.hyperion_url')
ipfs_url = load_key(config, 'skynet.pinner', 'ipfs_url') ipfs_url = load_key(config, 'skynet.pinner.ipfs_url')
logging.basicConfig(level=loglevel) logging.basicConfig(level=loglevel)
ipfs_node = AsyncIPFSHTTP(ipfs_url) ipfs_node = AsyncIPFSHTTP(ipfs_url)

View File

@ -13,7 +13,7 @@ import trio
import torch import torch
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
from skynet.dgpu.errors import DGPUComputeError from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled
from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
@ -132,16 +132,19 @@ class SkynetMM:
def compute_one( def compute_one(
self, self,
request_id: int,
should_cancel_work, should_cancel_work,
method: str, method: str,
params: dict, params: dict,
binary: bytes | None = None binary: bytes | None = None
): ):
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor): def maybe_cancel_work(step, *args, **kwargs):
should_raise = trio.from_thread.run(should_cancel_work) should_raise = trio.from_thread.run(should_cancel_work, request_id)
if should_raise: if should_raise:
logging.warn(f'cancelling work at step {step}') logging.warn(f'cancelling work at step {step}')
raise DGPUComputeError('Inference cancelled') raise DGPUInferenceCancelled()
maybe_cancel_work(0)
try: try:
match method: match method:
@ -157,7 +160,7 @@ class SkynetMM:
guidance_scale=guidance, guidance_scale=guidance,
num_inference_steps=step, num_inference_steps=step,
generator=seed, generator=seed,
callback=callback_fn, callback=maybe_cancel_work,
callback_steps=2, callback_steps=2,
**extra_params **extra_params
).images[0] ).images[0]

View File

@ -40,15 +40,9 @@ class SkynetDGPUDaemon:
if 'model_blacklist' in config: if 'model_blacklist' in config:
self.model_blacklist = set(config['model_blacklist']) self.model_blacklist = set(config['model_blacklist'])
self.current_request = None async def should_cancel_work(self, request_id: int):
competitors = await self.conn.get_competitors_for_req(request_id)
async def should_cancel_work(self): return bool(self.non_compete & competitors)
competitors = set((
status['worker']
for status in
(await self.conn.get_status_by_request_id(self.current_request))
))
return self.non_compete & competitors
async def serve_forever(self): async def serve_forever(self):
try: try:
@ -79,7 +73,7 @@ class SkynetDGPUDaemon:
statuses = await self.conn.get_status_by_request_id(rid) statuses = await self.conn.get_status_by_request_id(rid)
if len(statuses) == 0: if len(statuses) == 0:
self.current_request = rid self.conn.monitor_request(rid)
binary = await self.conn.get_input_data(req['binary_data']) binary = await self.conn.get_input_data(req['binary_data'])
@ -107,6 +101,7 @@ class SkynetDGPUDaemon:
img_sha, img_raw = await trio.to_thread.run_sync( img_sha, img_raw = await trio.to_thread.run_sync(
partial( partial(
self.mm.compute_one, self.mm.compute_one,
rid,
self.should_cancel_work, self.should_cancel_work,
body['method'], body['params'], binary=binary body['method'], body['params'], binary=binary
) )
@ -115,11 +110,13 @@ class SkynetDGPUDaemon:
ipfs_hash = await self.conn.publish_on_ipfs(img_raw) ipfs_hash = await self.conn.publish_on_ipfs(img_raw)
await self.conn.submit_work(rid, request_hash, img_sha, ipfs_hash) await self.conn.submit_work(rid, request_hash, img_sha, ipfs_hash)
break
except BaseException as e: except BaseException as e:
traceback.print_exc() traceback.print_exc()
await self.conn.cancel_work(rid, str(e)) await self.conn.cancel_work(rid, str(e))
finally:
self.conn.forget_request(rid)
break break
else: else:

View File

@ -3,3 +3,6 @@
class DGPUComputeError(BaseException): class DGPUComputeError(BaseException):
... ...
class DGPUInferenceCancelled(BaseException):
...

View File

@ -1,13 +1,15 @@
#!/usr/bin/python #!/usr/bin/python
from functools import partial
import io import io
import json import json
from pathlib import Path
import time import time
import logging import logging
from pathlib import Path
from functools import partial
import asks import asks
import trio
import anyio import anyio
from PIL import Image from PIL import Image
@ -20,6 +22,9 @@ from skynet.dgpu.errors import DGPUComputeError
from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
REQUEST_UPDATE_TIME = 3
async def failable(fn: partial, ret_fail=None): async def failable(fn: partial, ret_fail=None):
try: try:
return await fn() return await fn()
@ -54,6 +59,8 @@ class SkynetGPUConnector:
self.ipfs_client = AsyncIPFSHTTP(self.ipfs_url) self.ipfs_client = AsyncIPFSHTTP(self.ipfs_url)
self._wip_requests = {}
# blockchain helpers # blockchain helpers
async def get_work_requests_last_hour(self): async def get_work_requests_last_hour(self):
@ -103,6 +110,36 @@ class SkynetGPUConnector:
else: else:
return None return None
def monitor_request(self, request_id: int):
logging.info(f'begin monitoring request: {request_id}')
self._wip_requests[request_id] = {
'last_update': None,
'competitors': set()
}
async def maybe_update_request(self, request_id: int):
now = time.time()
stats = self._wip_requests[request_id]
if (not stats['last_update'] or
(now - stats['last_update']) > REQUEST_UPDATE_TIME):
stats['competitors'] = [
status['worker']
for status in
(await self.get_status_by_request_id(request_id))
if status['worker'] != self.account
]
stats['last_update'] = now
async def get_competitors_for_req(self, request_id: int) -> set:
await self.maybe_update_request(request_id)
competitors = set(self._wip_requests[request_id]['competitors'])
logging.info(f'competitors: {competitors}')
return competitors
def forget_request(self, request_id: int):
logging.info(f'end monitoring request: {request_id}')
del self._wip_requests[request_id]
async def begin_work(self, request_id: int): async def begin_work(self, request_id: int):
logging.info('begin_work') logging.info('begin_work')
return await failable( return await failable(