mirror of https://github.com/skygpu/skynet.git

Commit: Create msgspec struct for config
Parent: 5a3a43b3c6
Commit: ea3b35904c

skynet/cli.py — 270 changed lines
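The diff below is shown without +/- markers: removed load_key() lookups appear directly above the attribute accesses that replace them. The commit migrates config handling from a plain toml.load() dict to typed msgspec structs. A minimal sketch of the access-pattern change, using only names that appear in the diff:

# Before: load_skynet_toml() returned a plain dict and values were
# pulled out by dotted string keys, raising ConfigParsingError on a miss.
config = load_skynet_toml()
hf_token = load_key(config, 'skynet.dgpu.hf_token')

# After: load_skynet_toml() decodes the TOML straight into the Config
# msgspec.Struct, so sections and values are typed attributes.
config = load_skynet_toml()
hf_token = config.dgpu.hf_token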
@@ -13,7 +13,6 @@ from leap.protocol import (
from .config import (
    load_skynet_toml,
    load_key,
    set_hf_vars,
    ConfigParsingError,
)

@@ -49,9 +48,7 @@ def txt2img(*args, **kwargs):
    from . import utils  # TODO? why here, import cycle?

    config = load_skynet_toml()
    hf_token = load_key(config, 'skynet.dgpu.hf_token')
    hf_home = load_key(config, 'skynet.dgpu.hf_home')
    set_hf_vars(hf_token, hf_home)
    set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
    utils.txt2img(hf_token, **kwargs)

@@ -75,9 +72,7 @@ def txt2img(*args, **kwargs):
def img2img(model, prompt, input, output, strength, guidance, steps, seed):
    from . import utils
    config = load_skynet_toml()
    hf_token = load_key(config, 'skynet.dgpu.hf_token')
    hf_home = load_key(config, 'skynet.dgpu.hf_home')
    set_hf_vars(hf_token, hf_home)
    set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
    utils.img2img(
        hf_token,
        model=model,

@@ -105,9 +100,7 @@ def img2img(model, prompt, input, output, strength, guidance, steps, seed):
def inpaint(model, prompt, input, mask, output, strength, guidance, steps, seed):
    from . import utils
    config = load_skynet_toml()
    hf_token = load_key(config, 'skynet.dgpu.hf_token')
    hf_home = load_key(config, 'skynet.dgpu.hf_home')
    set_hf_vars(hf_token, hf_home)
    set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
    utils.inpaint(
        hf_token,
        model=model,

@@ -137,113 +130,15 @@ def upscale(input, output, model):
def download():
    from . import utils
    config = load_skynet_toml()
    hf_token = load_key(config, 'skynet.dgpu.hf_token')
    hf_home = load_key(config, 'skynet.dgpu.hf_home')
    set_hf_vars(hf_token, hf_home)
    utils.download_all_models(hf_token, hf_home)
    set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)
    utils.download_all_models(config.dgpu.hf_token, config.dgpu.hf_home)


@skynet.command()
@click.option(
    '--reward', '-r', default='20.0000 GPU')
@click.option('--jobs', '-j', default=1)
@click.option('--model', '-m', default='stabilityai/stable-diffusion-xl-base-1.0')
@click.option(
    '--prompt', '-p', default='a red old tractor in a sunny wheat field')
@click.option('--output', '-o', default='output.png')
@click.option('--width', '-w', default=1024)
@click.option('--height', '-h', default=1024)
@click.option('--guidance', '-g', default=10)
@click.option('--step', '-s', default=26)
@click.option('--seed', '-S', default=None)
@click.option('--upscaler', '-U', default='x4')
@click.option('--binary_data', '-b', default='')
@click.option('--strength', '-Z', default=None)
def enqueue(
    reward: str,
    jobs: int,
    **kwargs
):
    import trio
    from leap.cleos import CLEOS

    config = load_skynet_toml()

    key = load_key(config, 'skynet.user.key')
    account = load_key(config, 'skynet.user.account')
    permission = load_key(config, 'skynet.user.permission')
    node_url = load_key(config, 'skynet.user.node_url')

    cleos = CLEOS(None, None, url=node_url, remote=node_url)

    binary = kwargs['binary_data']
    if not kwargs['strength']:
        if binary:
            raise ValueError('strength -Z param required if binary data passed')

        del kwargs['strength']

    else:
        kwargs['strength'] = float(kwargs['strength'])

    async def enqueue_n_jobs():
        for i in range(jobs):
            if not kwargs['seed']:
                kwargs['seed'] = random.randint(0, 10e9)

            req = json.dumps({
                'method': 'diffuse',
                'params': kwargs
            })

            res = await cleos.a_push_action(
                'gpu.scd',
                'enqueue',
                {
                    'user': Name(account),
                    'request_body': req,
                    'binary_data': binary,
                    'reward': Asset.from_str(reward),
                    'min_verification': 1
                },
                account, key, permission,
            )
            print(res)

    trio.run(enqueue_n_jobs)
@skynet.command()
@click.option('--loglevel', '-l', default='INFO', help='Logging level')
def clean(
    loglevel: str,
):
    import trio
    from leap.cleos import CLEOS

    config = load_skynet_toml()
    key = load_key(config, 'skynet.user.key')
    account = load_key(config, 'skynet.user.account')
    permission = load_key(config, 'skynet.user.permission')
    node_url = load_key(config, 'skynet.user.node_url')

    logging.basicConfig(level=loglevel)
    cleos = CLEOS(None, None, url=node_url, remote=node_url)
    trio.run(
        partial(
            cleos.a_push_action,
            'gpu.scd',
            'clean',
            {},
            account, key, permission=permission
        )
    )

@skynet.command()
def queue():
    import requests
    config = load_skynet_toml()
    node_url = load_key(config, 'skynet.user.node_url')
    node_url = config.user.node_url
    resp = requests.post(
        f'{node_url}/v1/chain/get_table_rows',
        json={

@@ -260,7 +155,7 @@ def queue():
def status(request_id: int):
    import requests
    config = load_skynet_toml()
    node_url = load_key(config, 'skynet.user.node_url')
    node_url = config.user.node_url
    resp = requests.post(
        f'{node_url}/v1/chain/get_table_rows',
        json={

@@ -272,101 +167,6 @@ def status(request_id: int):
    )
    print(json.dumps(resp.json(), indent=4))

@skynet.command()
@click.argument('request-id')
def dequeue(request_id: int):
    import trio
    from leap.cleos import CLEOS

    config = load_skynet_toml()
    key = load_key(config, 'skynet.user.key')
    account = load_key(config, 'skynet.user.account')
    permission = load_key(config, 'skynet.user.permission')
    node_url = load_key(config, 'skynet.user.node_url')

    cleos = CLEOS(None, None, url=node_url, remote=node_url)
    res = trio.run(
        partial(
            cleos.a_push_action,
            'gpu.scd',
            'dequeue',
            {
                'user': Name(account),
                'request_id': int(request_id),
            },
            account, key, permission=permission
        )
    )
    print(res)


@skynet.command()
@click.option(
    '--token-contract', '-c', default='eosio.token')
@click.option(
    '--token-symbol', '-S', default='4,GPU')
def config(
    token_contract: str,
    token_symbol: str
):
    import trio
    from leap.cleos import CLEOS

    config = load_skynet_toml()

    key = load_key(config, 'skynet.user.key')
    account = load_key(config, 'skynet.user.account')
    permission = load_key(config, 'skynet.user.permission')
    node_url = load_key(config, 'skynet.user.node_url')

    cleos = CLEOS(None, None, url=node_url, remote=node_url)
    res = trio.run(
        partial(
            cleos.a_push_action,
            'gpu.scd',
            'config',
            {
                'token_contract': token_contract,
                'token_symbol': token_symbol,
            },
            account, key, permission=permission
        )
    )
    print(res)


@skynet.command()
@click.argument('quantity')
def deposit(quantity: str):
    import trio
    from leap.cleos import CLEOS
    from leap.sugar import asset_from_str

    config = load_skynet_toml()

    key = load_key(config, 'skynet.user.key')
    account = load_key(config, 'skynet.user.account')
    permission = load_key(config, 'skynet.user.permission')
    node_url = load_key(config, 'skynet.user.node_url')
    cleos = CLEOS(None, None, url=node_url, remote=node_url)

    res = trio.run(
        partial(
            cleos.a_push_action,
            'gpu.scd',
            'transfer',
            {
                'sender': Name(account),
                'recipient': Name('gpu.scd'),
                'amount': asset_from_str(quantity),
                'memo': f'{account} transferred {quantity} to gpu.scd'
            },
            account, key, permission=permission
        )
    )
    print(res)


@skynet.group()
def run(*args, **kwargs):
    pass

@@ -380,13 +180,6 @@ def db():
    container, passwd, host = db_params
    logging.info(('skynet', passwd, host))

@run.command()
def nodeos():
    from .nodeos import open_nodeos

    logging.basicConfig(filename='skynet-nodeos.log', level=logging.INFO)
    with open_nodeos(cleanup=False):
        ...

@run.command()
@click.option('--loglevel', '-l', default='INFO', help='Logging level')

@@ -405,14 +198,9 @@ def dgpu(
    logging.basicConfig(level=loglevel)

    config = load_skynet_toml(file_path=config_path)
    hf_token = load_key(config, 'skynet.dgpu.hf_token')
    hf_home = load_key(config, 'skynet.dgpu.hf_home')
    set_hf_vars(hf_token, hf_home)
    set_hf_vars(config.dgpu.hf_token, config.dgpu.hf_home)

    assert 'skynet' in config
    assert 'dgpu' in config['skynet']

    trio.run(open_dgpu_node, config['skynet']['dgpu'])
    trio.run(open_dgpu_node, config.dgpu)


@run.command()
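In the dgpu command above, the runtime asserts on the raw dict are dropped; with msgspec most of that checking moves to decode time in load_skynet_toml(). A small illustration of the behavior implied by the structs defined further down (hypothetical TOML strings and values, not from the repo):

import msgspec

from skynet.config import Config

# A [dgpu] table that is present but missing required fields (e.g. account)
# fails at decode time with msgspec.ValidationError.
bad = b"[dgpu]\nnode_url = 'https://example-node'\n"
try:
    msgspec.toml.decode(bad, type=Config)
except msgspec.ValidationError as e:
    print('rejected at load time:', e)

# A file with no [dgpu] table decodes fine, but config.dgpu is None,
# so trio.run(open_dgpu_node, config.dgpu) would fail at use time instead.
good = b"[user]\naccount = 'alice'\npermission = 'active'\nkey = '5K...'\nnode_url = 'https://example-node'\n"
config = msgspec.toml.decode(good, type=Config)
print(config.dgpu)  # -> None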
@@ -435,24 +223,24 @@ def telegram(
    logging.basicConfig(level=loglevel)

    config = load_skynet_toml()
    tg_token = load_key(config, 'skynet.telegram.tg_token')
    tg_token = config.telegram.tg_token

    key = load_key(config, 'skynet.telegram.key')
    account = load_key(config, 'skynet.telegram.account')
    permission = load_key(config, 'skynet.telegram.permission')
    node_url = load_key(config, 'skynet.telegram.node_url')
    hyperion_url = load_key(config, 'skynet.telegram.hyperion_url')
    key = config.telegram.key
    account = config.telegram.account
    permission = config.telegram.permission
    node_url = config.telegram.node_url
    hyperion_url = config.telegram.hyperion_url

    ipfs_url = load_key(config, 'skynet.telegram.ipfs_url')
    ipfs_url = config.telegram.ipfs_url

    try:
        explorer_domain = load_key(config, 'skynet.telegram.explorer_domain')
        explorer_domain = config.telegram.explorer_domain

    except ConfigParsingError:
        explorer_domain = DEFAULT_EXPLORER_DOMAIN

    try:
        ipfs_domain = load_key(config, 'skynet.telegram.ipfs_domain')
        ipfs_domain = config.telegram.ipfs_domain

    except ConfigParsingError:
        ipfs_domain = DEFAULT_IPFS_DOMAIN

@@ -498,24 +286,24 @@ def discord(
    logging.basicConfig(level=loglevel)

    config = load_skynet_toml()
    dc_token = load_key(config, 'skynet.discord.dc_token')
    dc_token = config.discord.dc_token

    key = load_key(config, 'skynet.discord.key')
    account = load_key(config, 'skynet.discord.account')
    permission = load_key(config, 'skynet.discord.permission')
    node_url = load_key(config, 'skynet.discord.node_url')
    hyperion_url = load_key(config, 'skynet.discord.hyperion_url')
    key = config.discord.key
    account = config.discord.account
    permission = config.discord.permission
    node_url = config.discord.node_url
    hyperion_url = config.discord.hyperion_url

    ipfs_url = load_key(config, 'skynet.discord.ipfs_url')
    ipfs_url = config.discord.ipfs_url

    try:
        explorer_domain = load_key(config, 'skynet.discord.explorer_domain')
        explorer_domain = config.discord.explorer_domain

    except ConfigParsingError:
        explorer_domain = DEFAULT_EXPLORER_DOMAIN

    try:
        ipfs_domain = load_key(config, 'skynet.discord.ipfs_domain')
        ipfs_domain = config.discord.ipfs_domain

    except ConfigParsingError:
        ipfs_domain = DEFAULT_IPFS_DOMAIN
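A note on the try/except fallbacks above: the new TelegramConfig and DiscordConfig structs (in config.py below) do not declare explorer_domain or ipfs_domain, and attribute access on a msgspec.Struct raises AttributeError rather than ConfigParsingError, so the fallback to the DEFAULT_* constants would not be reached this way. A minimal sketch of expressing the same optional keys as struct defaults, following the pattern DgpuConfig already uses for ipfs_domain (illustration only, not the committed code; it assumes DEFAULT_EXPLORER_DOMAIN lives in skynet.constants alongside DEFAULT_IPFS_DOMAIN):

import msgspec

from skynet.constants import DEFAULT_EXPLORER_DOMAIN, DEFAULT_IPFS_DOMAIN


class TelegramConfig(msgspec.Struct):
    account: str
    permission: str
    key: str
    node_url: str
    hyperion_url: str
    ipfs_url: str
    token: str
    # optional keys fall back to these defaults at decode time,
    # so no try/except is needed at the call site
    explorer_domain: str = DEFAULT_EXPLORER_DOMAIN
    ipfs_domain: str = DEFAULT_IPFS_DOMAIN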
@@ -549,8 +337,8 @@ def pinner(loglevel):
    from .ipfs.pinner import SkynetPinner

    config = load_skynet_toml()
    hyperion_url = load_key(config, 'skynet.pinner.hyperion_url')
    ipfs_url = load_key(config, 'skynet.pinner.ipfs_url')
    hyperion_url = config.pinner.hyperion_url
    ipfs_url = config.pinner.ipfs_url

    logging.basicConfig(level=loglevel)
    ipfs_node = AsyncIPFSHTTP(ipfs_url)
@@ -1,27 +1,70 @@
import os
import toml

from .constants import DEFAULT_CONFIG_PATH
import msgspec

from skynet.constants import DEFAULT_CONFIG_PATH, DEFAULT_IPFS_DOMAIN


class ConfigParsingError(BaseException):
    ...


def load_skynet_toml(file_path=DEFAULT_CONFIG_PATH) -> dict:
    config = toml.load(file_path)
    return config
class DgpuConfig(msgspec.Struct):
    account: str
    permission: str
    key: str
    node_url: str
    hyperion_url: str
    ipfs_url: str
    hf_token: str
    ipfs_domain: str = DEFAULT_IPFS_DOMAIN
    hf_home: str = 'hf_home'
    non_compete: set[str] = set()
    model_whitelist: set[str] = set()
    model_blacklist: set[str] = set()
    backend: str = 'sync-on-thread'
    api_bind: str = False
    tui: bool = False

class TelegramConfig(msgspec.Struct):
    account: str
    permission: str
    key: str
    node_url: str
    hyperion_url: str
    ipfs_url: str
    token: str

def load_key(config: dict, key: str) -> str:
    for skey in key.split('.'):
        if skey not in config:
            conf_keys = [k for k in config]
            raise ConfigParsingError(f'key \"{skey}\" not in {conf_keys}')
class DiscordConfig(msgspec.Struct):
    account: str
    permission: str
    key: str
    node_url: str
    hyperion_url: str
    ipfs_url: str
    token: str

        config = config[skey]
class PinnerConfig(msgspec.Struct):
    hyperion_url: str
    ipfs_url: str

    return config
class UserConfig(msgspec.Struct):
    account: str
    permission: str
    key: str
    node_url: str

class Config(msgspec.Struct):
    dgpu: DgpuConfig | None = None
    telegram: TelegramConfig | None = None
    discord: DiscordConfig | None = None
    pinner: PinnerConfig | None = None
    user: UserConfig | None = None

def load_skynet_toml(file_path=DEFAULT_CONFIG_PATH) -> Config:
    with open(file_path, 'r') as file:
        return msgspec.toml.decode(file.read(), type=Config)


def set_hf_vars(hf_token: str, hf_home: str):
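The hunks above rewrite the config module (skynet/config.py, the .config module cli.py imports). Since load_skynet_toml() now decodes straight into Config, whose fields are the top-level tables dgpu, telegram, discord, pinner and user, a worker config would look roughly like the sketch below. All values are hypothetical placeholders; the required fields come from DgpuConfig and everything else falls back to the struct defaults:

import msgspec

from skynet.config import Config, load_skynet_toml

# Hypothetical skynet.toml contents, shown inline for the example.
SAMPLE_TOML = '''
[dgpu]
account = "testworker1"
permission = "active"
key = "5K..."                       # private key, placeholder
node_url = "https://example-node"
hyperion_url = "https://example-hyperion"
ipfs_url = "http://127.0.0.1:5001"
hf_token = "hf_..."                 # HuggingFace token, placeholder
# optional: ipfs_domain, hf_home, non_compete, model_whitelist,
# model_blacklist, backend, api_bind, tui all have struct defaults
'''

config = msgspec.toml.decode(SAMPLE_TOML, type=Config)
print(config.dgpu.backend)       # -> 'sync-on-thread' (default)
print(config.dgpu.non_compete)   # -> set() (default)
print(config.telegram)           # -> None, section not present

# The CLI path does the same thing via the file on disk:
# config = load_skynet_toml(file_path='skynet.toml')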
@@ -3,16 +3,17 @@ import logging
import trio
import urwid

from hypercorn.config import Config
from hypercorn.config import Config as HCConfig
from hypercorn.trio import serve
from quart_trio import QuartTrio as Quart

from skynet.config import Config
from skynet.dgpu.tui import init_tui
from skynet.dgpu.daemon import WorkerDaemon
from skynet.dgpu.network import NetConnector


async def open_dgpu_node(config: dict) -> None:
async def open_dgpu_node(config: Config) -> None:
    '''
    Open a top level "GPU mgmt daemon", keep the
    `WorkerDaemon._snap: dict[str, list|dict]` table

@@ -23,16 +24,16 @@ async def open_dgpu_node(config: dict) -> None:
    logging.getLogger("httpx").setLevel(logging.WARNING)

    tui = None
    if config['tui']:
    if config.tui:
        tui = init_tui()

    conn = NetConnector(config)
    daemon = WorkerDaemon(conn, config)

    api: Quart|None = None
    if 'api_bind' in config:
        api_conf = Config()
        api_conf.bind = [config['api_bind']]
    if config.api_bind:
        api_conf = HCConfig()
        api_conf.bind = [config.api_bind]
        api: Quart = await daemon.generate_api()

    tn: trio.Nursery
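The api_bind gate above relies on DgpuConfig declaring api_bind: str = False, so the falsy default keeps the optional health API off unless a bind string is configured (hypercorn's Config is aliased to HCConfig because skynet's own Config is now imported here). A minimal sketch of the gating, assuming a host:port bind string (hypothetical value):

from hypercorn.config import Config as HCConfig

api_bind = '0.0.0.0:8080'   # hypothetical value of config.api_bind; the default is falsy

if api_bind:
    api_conf = HCConfig()
    api_conf.bind = [api_bind]   # hypercorn takes a list of bind strings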
@@ -10,6 +10,7 @@ import trio
from quart import jsonify
from quart_trio import QuartTrio as Quart

from skynet.config import DgpuConfig as Config
from skynet.constants import (
    MODELS,
    VERSION,

@@ -41,31 +42,10 @@ class WorkerDaemon:
    def __init__(
        self,
        conn: NetConnector,
        config: dict
        config: Config
    ):
        self.config = config
        self.conn: NetConnector = conn
        self.auto_withdraw = (
            config['auto_withdraw']
            if 'auto_withdraw' in config else False
        )

        self.account: str = config['account']

        self.non_compete = set()
        if 'non_compete' in config:
            self.non_compete = set(config['non_compete'])

        self.model_whitelist = set()
        if 'model_whitelist' in config:
            self.model_whitelist = set(config['model_whitelist'])

        self.model_blacklist = set()
        if 'model_blacklist' in config:
            self.model_blacklist = set(config['model_blacklist'])

        self.backend = 'sync-on-thread'
        if 'backend' in config:
            self.backend = config['backend']

        self._snap = {
            'queue': [],
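The block removed above copied each value out of the dict with an `if 'key' in config` default; those defaults now live on DgpuConfig itself, and the daemon keeps only self.config. A small sketch of the resulting behavior (required values are hypothetical placeholders; note DgpuConfig defines no auto_withdraw field, so that toggle has no struct counterpart):

from skynet.config import DgpuConfig

cfg = DgpuConfig(
    account='testworker1', permission='active', key='5K...',
    node_url='https://example-node', hyperion_url='https://example-hyperion',
    ipfs_url='http://127.0.0.1:5001', hf_token='hf_...',
)

# What the old __init__ guarded with `if 'x' in config` is now a struct default,
# read later as self.config.backend, self.config.non_compete, etc.
print(cfg.backend)       # 'sync-on-thread'
print(cfg.non_compete)   # set()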
@@ -107,10 +87,10 @@ class WorkerDaemon:
        competitors = set([
            status['worker']
            for status in self._snap['requests'][request_id]
            if status['worker'] != self.account
            if status['worker'] != self.config.account
        ])
        logging.info(f'competitors: {competitors}')
        should_cancel = bool(self.non_compete & competitors)
        should_cancel = bool(self.config.non_compete & competitors)
        logging.info(f'cancel: {should_cancel}')
        return should_cancel

@@ -141,7 +121,7 @@ class WorkerDaemon:
        @app.route('/')
        async def health():
            return jsonify(
                account=self.account,
                account=self.config.account,
                version=VERSION,
                last_generation_ts=self._last_generation_ts,
                last_generation_speed=self._get_benchmark_speed()

@@ -182,15 +162,19 @@ class WorkerDaemon:

        # only handle whitelisted models
        if (
            len(self.model_whitelist) > 0
            len(self.config.model_whitelist) > 0
            and
            model not in self.model_whitelist
            model not in self.config.model_whitelist
        ):
            logging.warning('model not whitelisted, skip...')
            return False

        # if blacklist contains model skip
        if model in self.model_blacklist:
        if (
            len(self.config.model_blacklist) > 0
            and
            model in self.config.model_blacklist
        ):
            logging.warning('model is blacklisted, skip...')
            return False

@@ -205,7 +189,7 @@ class WorkerDaemon:

        # skip if workers in non_compete already on it
        competitors = set((status['worker'] for status in statuses))
        if bool(self.non_compete & competitors):
        if bool(self.config.non_compete & competitors):
            logging.info('worker in configured non_compete list already working on request, skip...')
            return False

@@ -266,7 +250,7 @@ class WorkerDaemon:

        output = None
        output_hash = None
        match self.backend:
        match self.config.backend:
            case 'sync-on-thread':
                output_hash, output = await trio.to_thread.run_sync(
                    partial(

@@ -280,7 +264,7 @@ class WorkerDaemon:

            case _:
                raise DGPUComputeError(
                    f'Unsupported backend {self.backend}'
                    f'Unsupported backend {self.config.backend}'
                )

        maybe_update_tui(lambda tui: tui.set_progress(total_step))

@@ -316,9 +300,6 @@ class WorkerDaemon:
        await self._update_balance()
        try:
            while True:
                if self.auto_withdraw:
                    await self.conn.maybe_withdraw_all()

                queue = self._snap['queue']

                random.shuffle(queue)
@@ -14,6 +14,7 @@ from PIL import Image
from leap.cleos import CLEOS
from leap.protocol import Asset
from skynet.dgpu.tui import maybe_update_tui
from skynet.config import DgpuConfig as Config
from skynet.constants import (
    DEFAULT_IPFS_DOMAIN,
    GPU_CONTRACT_ABI,

@@ -58,32 +59,16 @@ class NetConnector:
    - CLEOS client

    '''
    def __init__(self, config: dict):
        # TODO, why these extra instance vars for an (unsynced)
        # copy of the `config` state?
        self.account = config['account']
        self.permission = config['permission']
        self.key = config['key']

        # TODO, neither of these instance vars are used anywhere in
        # methods? so why are they set on this type?
        self.node_url = config['node_url']
        self.hyperion_url = config['hyperion_url']

        self.cleos = CLEOS(endpoint=self.node_url)
    def __init__(self, config: Config):
        self.config = config
        self.cleos = CLEOS(endpoint=config.node_url)
        self.cleos.load_abi('gpu.scd', GPU_CONTRACT_ABI)

        self.ipfs_url = config['ipfs_url']

        self.ipfs_client = AsyncIPFSHTTP(self.ipfs_url)

        self.ipfs_domain = DEFAULT_IPFS_DOMAIN
        if 'ipfs_domain' in config:
            self.ipfs_domain = config['ipfs_domain']
        self.ipfs_client = AsyncIPFSHTTP(config.ipfs_url)

        self._wip_requests = {}

        maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.account))
        maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.config.account))
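The TODO comments deleted above asked why NetConnector held unsynced copies of config values; the answer in this commit is to keep the typed struct and read through it everywhere. A minimal wiring sketch mirroring how cli.py and open_dgpu_node use it (assumes a valid skynet.toml on disk and reachable services):

from skynet.config import load_skynet_toml
from skynet.dgpu.network import NetConnector

config = load_skynet_toml()        # Config struct
conn = NetConnector(config.dgpu)   # the connector takes the DgpuConfig section

# No more conn.account / conn.key copies; the helpers below sign actions with
# conn.config.account, conn.config.key and conn.config.permission at call time.
print(conn.config.account)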
    # blockchain helpers

@@ -135,8 +120,8 @@ class NetConnector:
            'gpu.scd', 'gpu.scd', 'users',
            index_position=1,
            key_type='name',
            lower_bound=self.account,
            upper_bound=self.account
            lower_bound=self.config.account,
            upper_bound=self.config.account
        ))

        if rows:

@@ -190,12 +175,12 @@ class NetConnector:
                'gpu.scd',
                'workbegin',
                list({
                    'worker': self.account,
                    'worker': self.config.account,
                    'request_id': request_id,
                    'max_workers': 2
                }.values()),
                self.account, self.key,
                permission=self.permission
                self.config.account, self.config.key,
                permission=self.config.permission
            )
        )

@@ -207,12 +192,12 @@ class NetConnector:
                'gpu.scd',
                'workcancel',
                list({
                    'worker': self.account,
                    'worker': self.config.account,
                    'request_id': request_id,
                    'reason': reason
                }.values()),
                self.account, self.key,
                permission=self.permission
                self.config.account, self.config.key,
                permission=self.config.permission
            )
        )

@@ -230,11 +215,11 @@ class NetConnector:
                'gpu.scd',
                'withdraw',
                list({
                    'user': self.account,
                    'user': self.config.account,
                    'quantity': Asset.from_str(balance)
                }.values()),
                self.account, self.key,
                permission=self.permission
                self.config.account, self.config.key,
                permission=self.config.permission
            )
        )

@@ -246,8 +231,8 @@ class NetConnector:
            'gpu.scd', 'gpu.scd', 'results',
            index_position=4,
            key_type='name',
            lower_bound=self.account,
            upper_bound=self.account
            lower_bound=self.config.account,
            upper_bound=self.config.account
            )
        )
        return rows

@@ -266,14 +251,14 @@ class NetConnector:
                'gpu.scd',
                'submit',
                list({
                    'worker': self.account,
                    'worker': self.config.account,
                    'request_id': request_id,
                    'request_hash': request_hash,
                    'result_hash': result_hash,
                    'ipfs_hash': ipfs_hash
                }.values()),
                self.account, self.key,
                permission=self.permission
                self.config.account, self.config.key,
                permission=self.config.permission
            )
        )

@@ -310,7 +295,7 @@ class NetConnector:
        consuming AI model.

        '''
        link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}'
        link = f'https://{self.config.ipfs_domain}/ipfs/{ipfs_hash}'

        res = await get_ipfs_file(link, timeout=1)
        if not res or res.status_code != 200: