mirror of https://github.com/skygpu/skynet.git
TODO
parent
8a415b450f
commit
f4592ae254
|
@ -116,7 +116,7 @@ def enqueue(
|
|||
key = load_key(config, 'skynet.user.key')
|
||||
account = load_key(config, 'skynet.user.account')
|
||||
permission = load_key(config, 'skynet.user.permission')
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
node_url = load_key(config, 'skynet.node_url')
|
||||
|
||||
cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
|
||||
|
@ -131,27 +131,38 @@ def enqueue(
|
|||
kwargs['strength'] = float(kwargs['strength'])
|
||||
|
||||
async def enqueue_n_jobs():
|
||||
for i in range(jobs):
|
||||
if not kwargs['seed']:
|
||||
kwargs['seed'] = random.randint(0, 10e9)
|
||||
actions = []
|
||||
for _ in range(jobs):
|
||||
if kwargs['seed']:
|
||||
seed = kwargs['seed']
|
||||
else:
|
||||
seed = random.randint(0, 10e9)
|
||||
|
||||
_kwargs = kwargs.copy()
|
||||
_kwargs['seed'] = seed
|
||||
|
||||
req = json.dumps({
|
||||
'method': 'diffuse',
|
||||
'params': kwargs
|
||||
'params': _kwargs
|
||||
})
|
||||
|
||||
res = await cleos.a_push_action(
|
||||
'telos.gpu',
|
||||
'enqueue',
|
||||
{
|
||||
actions.append({
|
||||
'account': 'telos.gpu',
|
||||
'name': 'enqueue',
|
||||
'data': {
|
||||
'user': Name(account),
|
||||
'request_body': req,
|
||||
'binary_data': binary,
|
||||
'reward': asset_from_str(reward),
|
||||
'min_verification': 1
|
||||
},
|
||||
account, key, permission,
|
||||
)
|
||||
'authorization': [{
|
||||
'actor': account,
|
||||
'permission': permission
|
||||
}]
|
||||
})
|
||||
|
||||
res = await cleos.a_push_actions(actions, key)
|
||||
print(res)
|
||||
|
||||
trio.run(enqueue_n_jobs)
|
||||
|
@ -169,7 +180,7 @@ def clean(
|
|||
key = load_key(config, 'skynet.user.key')
|
||||
account = load_key(config, 'skynet.user.account')
|
||||
permission = load_key(config, 'skynet.user.permission')
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
node_url = load_key(config, 'skynet.node_url')
|
||||
|
||||
logging.basicConfig(level=loglevel)
|
||||
cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
|
@ -187,7 +198,7 @@ def clean(
|
|||
def queue():
|
||||
import requests
|
||||
config = load_skynet_toml()
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
node_url = load_key(config, 'skynet.node_url')
|
||||
resp = requests.post(
|
||||
f'{node_url}/v1/chain/get_table_rows',
|
||||
json={
|
||||
|
@ -204,7 +215,7 @@ def queue():
|
|||
def status(request_id: int):
|
||||
import requests
|
||||
config = load_skynet_toml()
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
node_url = load_key(config, 'skynet.node_url')
|
||||
resp = requests.post(
|
||||
f'{node_url}/v1/chain/get_table_rows',
|
||||
json={
|
||||
|
@ -226,7 +237,7 @@ def dequeue(request_id: int):
|
|||
key = load_key(config, 'skynet.user.key')
|
||||
account = load_key(config, 'skynet.user.account')
|
||||
permission = load_key(config, 'skynet.user.permission')
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
node_url = load_key(config, 'skynet.node_url')
|
||||
|
||||
cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
res = trio.run(
|
||||
|
@ -261,7 +272,7 @@ def config(
|
|||
key = load_key(config, 'skynet.user.key')
|
||||
account = load_key(config, 'skynet.user.account')
|
||||
permission = load_key(config, 'skynet.user.permission')
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
node_url = load_key(config, 'skynet.node_url')
|
||||
|
||||
cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
res = trio.run(
|
||||
|
@ -290,7 +301,7 @@ def deposit(quantity: str):
|
|||
key = load_key(config, 'skynet.user.key')
|
||||
account = load_key(config, 'skynet.user.account')
|
||||
permission = load_key(config, 'skynet.user.permission')
|
||||
node_url = load_key(config, 'skynet.user.node_url')
|
||||
node_url = load_key(config, 'skynet.node_url')
|
||||
cleos = CLEOS(None, None, url=node_url, remote=node_url)
|
||||
|
||||
res = trio.run(
|
||||
|
@ -352,7 +363,7 @@ def dgpu(
|
|||
assert 'skynet' in config
|
||||
assert 'dgpu' in config['skynet']
|
||||
|
||||
trio.run(open_dgpu_node, config['skynet']['dgpu'])
|
||||
trio.run(open_dgpu_node, config['skynet'])
|
||||
|
||||
|
||||
@run.command()
|
||||
|
@ -375,30 +386,30 @@ def telegram(
|
|||
logging.basicConfig(level=loglevel)
|
||||
|
||||
config = load_skynet_toml()
|
||||
tg_token = load_key(config, 'skynet.telegram.tg_token')
|
||||
tg_token = load_key(config, 'skynet.telegram.token')
|
||||
|
||||
key = load_key(config, 'skynet.telegram.key')
|
||||
account = load_key(config, 'skynet.telegram.account')
|
||||
permission = load_key(config, 'skynet.telegram.permission')
|
||||
node_url = load_key(config, 'skynet.telegram.node_url')
|
||||
hyperion_url = load_key(config, 'skynet.telegram.hyperion_url')
|
||||
node_url = load_key(config, 'skynet.node_url')
|
||||
hyperion_url = load_key(config, 'skynet.hyperion_url')
|
||||
|
||||
try:
|
||||
ipfs_gateway_url = load_key(config, 'skynet.telegram.ipfs_gateway_url')
|
||||
ipfs_gateway_url = load_key(config, 'skynet.ipfs_gateway_url')
|
||||
|
||||
except ConfigParsingError:
|
||||
ipfs_gateway_url = None
|
||||
|
||||
ipfs_url = load_key(config, 'skynet.telegram.ipfs_url')
|
||||
ipfs_url = load_key(config, 'skynet.ipfs_url')
|
||||
|
||||
try:
|
||||
explorer_domain = load_key(config, 'skynet.telegram.explorer_domain')
|
||||
explorer_domain = load_key(config, 'skynet.explorer_domain')
|
||||
|
||||
except ConfigParsingError:
|
||||
explorer_domain = DEFAULT_EXPLORER_DOMAIN
|
||||
|
||||
try:
|
||||
ipfs_domain = load_key(config, 'skynet.telegram.ipfs_domain')
|
||||
ipfs_domain = load_key(config, 'skynet.ipfs_domain')
|
||||
|
||||
except ConfigParsingError:
|
||||
ipfs_domain = DEFAULT_IPFS_DOMAIN
|
||||
|
@ -445,25 +456,25 @@ def discord(
|
|||
logging.basicConfig(level=loglevel)
|
||||
|
||||
config = load_skynet_toml()
|
||||
dc_token = load_key(config, 'skynet.discord.dc_token')
|
||||
dc_token = load_key(config, 'skynet.discord.token')
|
||||
|
||||
key = load_key(config, 'skynet.discord.key')
|
||||
account = load_key(config, 'skynet.discord.account')
|
||||
permission = load_key(config, 'skynet.discord.permission')
|
||||
node_url = load_key(config, 'skynet.discord.node_url')
|
||||
hyperion_url = load_key(config, 'skynet.discord.hyperion_url')
|
||||
node_url = load_key(config, 'skynet.node_url')
|
||||
hyperion_url = load_key(config, 'skynet.hyperion_url')
|
||||
|
||||
ipfs_gateway_url = load_key(config, 'skynet.discord.ipfs_gateway_url')
|
||||
ipfs_url = load_key(config, 'skynet.discord.ipfs_url')
|
||||
ipfs_gateway_url = load_key(config, 'skynet.ipfs_gateway_url')
|
||||
ipfs_url = load_key(config, 'skynet.ipfs_url')
|
||||
|
||||
try:
|
||||
explorer_domain = load_key(config, 'skynet.discord.explorer_domain')
|
||||
explorer_domain = load_key(config, 'skynet.explorer_domain')
|
||||
|
||||
except ConfigParsingError:
|
||||
explorer_domain = DEFAULT_EXPLORER_DOMAIN
|
||||
|
||||
try:
|
||||
ipfs_domain = load_key(config, 'skynet.discord.ipfs_domain')
|
||||
ipfs_domain = load_key(config, 'skynet.ipfs_domain')
|
||||
|
||||
except ConfigParsingError:
|
||||
ipfs_domain = DEFAULT_IPFS_DOMAIN
|
||||
|
@ -509,8 +520,8 @@ def pinner(loglevel):
|
|||
from .ipfs.pinner import SkynetPinner
|
||||
|
||||
config = load_skynet_toml()
|
||||
hyperion_url = load_key(config, 'skynet.pinner.hyperion_url')
|
||||
ipfs_url = load_key(config, 'skynet.pinner.ipfs_url')
|
||||
hyperion_url = load_key(config, 'skynet.hyperion_url')
|
||||
ipfs_url = load_key(config, 'skynet.ipfs_url')
|
||||
|
||||
logging.basicConfig(level=loglevel)
|
||||
ipfs_node = AsyncIPFSHTTP(ipfs_url)
|
||||
|
|
|
@ -11,18 +11,18 @@ from skynet.dgpu.network import SkynetGPUConnector
|
|||
|
||||
|
||||
async def open_dgpu_node(config: dict):
|
||||
conn = SkynetGPUConnector(config)
|
||||
mm = SkynetMM(config)
|
||||
daemon = SkynetDGPUDaemon(mm, conn, config)
|
||||
conn = SkynetGPUConnector({**config, **config['dgpu']})
|
||||
mm = SkynetMM(config['dgpu'])
|
||||
daemon = SkynetDGPUDaemon(mm, conn, config['dgpu'])
|
||||
|
||||
api = None
|
||||
if 'api_bind' in config:
|
||||
if 'api_bind' in config['dgpu']:
|
||||
api_conf = Config()
|
||||
api_conf.bind = [config['api_bind']]
|
||||
api = await daemon.generate_api()
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(daemon.snap_updater_task)
|
||||
n.start_soon(conn.data_updater_task)
|
||||
|
||||
if api:
|
||||
n.start_soon(serve, api, api_conf)
|
||||
|
|
|
@ -6,6 +6,7 @@ import gc
|
|||
import logging
|
||||
|
||||
from hashlib import sha256
|
||||
from typing import Any
|
||||
import zipfile
|
||||
from PIL import Image
|
||||
from diffusers import DiffusionPipeline
|
||||
|
@ -21,11 +22,16 @@ from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_ima
|
|||
|
||||
def prepare_params_for_diffuse(
|
||||
params: dict,
|
||||
input_type: str,
|
||||
binary = None
|
||||
inputs: list[tuple[Any, str]],
|
||||
):
|
||||
_params = {}
|
||||
if binary != None:
|
||||
|
||||
if len(inputs) > 1:
|
||||
raise DGPUComputeError('sorry binary_inputs > 1 not implemented yet')
|
||||
|
||||
if len(inputs) == 0:
|
||||
binary, input_type = inputs[0]
|
||||
|
||||
match input_type:
|
||||
case 'png':
|
||||
image = crop_image(
|
||||
|
@ -34,9 +40,6 @@ def prepare_params_for_diffuse(
|
|||
_params['image'] = image
|
||||
_params['strength'] = float(params['strength'])
|
||||
|
||||
case 'none':
|
||||
...
|
||||
|
||||
case _:
|
||||
raise DGPUComputeError(f'Unknown input_type {input_type}')
|
||||
|
||||
|
@ -144,8 +147,7 @@ class SkynetMM:
|
|||
request_id: int,
|
||||
method: str,
|
||||
params: dict,
|
||||
input_type: str = 'png',
|
||||
binary: bytes | None = None
|
||||
inputs: list[tuple[Any, str]]
|
||||
):
|
||||
def maybe_cancel_work(step, *args, **kwargs):
|
||||
if self._should_cancel:
|
||||
|
@ -165,8 +167,7 @@ class SkynetMM:
|
|||
try:
|
||||
match method:
|
||||
case 'diffuse':
|
||||
arguments = prepare_params_for_diffuse(
|
||||
params, input_type, binary=binary)
|
||||
arguments = prepare_params_for_diffuse(params, inputs)
|
||||
prompt, guidance, step, seed, upscaler, extra_params = arguments
|
||||
model = self.get_model(params['model'], 'image' in extra_params)
|
||||
|
||||
|
|
|
@ -44,6 +44,10 @@ class SkynetDGPUDaemon:
|
|||
config['auto_withdraw']
|
||||
if 'auto_withdraw' in config else False
|
||||
)
|
||||
self.max_concurrent = (
|
||||
config['max_concurrent']
|
||||
if 'max_concurrent' in config else 0
|
||||
)
|
||||
|
||||
self.account = config['account']
|
||||
|
||||
|
@ -63,12 +67,6 @@ class SkynetDGPUDaemon:
|
|||
if 'backend' in config:
|
||||
self.backend = config['backend']
|
||||
|
||||
self._snap = {
|
||||
'queue': [],
|
||||
'requests': {},
|
||||
'my_results': []
|
||||
}
|
||||
|
||||
self._benchmark = []
|
||||
self._last_benchmark = None
|
||||
self._last_generation_ts = None
|
||||
|
@ -90,18 +88,10 @@ class SkynetDGPUDaemon:
|
|||
|
||||
async def should_cancel_work(self, request_id: int):
|
||||
self._benchmark.append(time.time())
|
||||
competitors = set([
|
||||
status['worker']
|
||||
for status in self._snap['requests'][request_id]
|
||||
if status['worker'] != self.account
|
||||
])
|
||||
return bool(self.non_compete & competitors)
|
||||
|
||||
|
||||
async def snap_updater_task(self):
|
||||
while True:
|
||||
self._snap = await self.conn.get_full_queue_snapshot()
|
||||
await trio.sleep(1)
|
||||
competitors = self.conn.get_competitors_for_request(request_id)
|
||||
if competitors == None:
|
||||
return True
|
||||
return bool(self.non_compete & set(competitors))
|
||||
|
||||
async def generate_api(self):
|
||||
app = Quart(__name__)
|
||||
|
@ -117,23 +107,21 @@ class SkynetDGPUDaemon:
|
|||
|
||||
return app
|
||||
|
||||
async def serve_forever(self):
|
||||
try:
|
||||
while True:
|
||||
if self.auto_withdraw:
|
||||
await self.conn.maybe_withdraw_all()
|
||||
|
||||
queue = self._snap['queue']
|
||||
def find_best_requests(self) -> list[dict]:
|
||||
queue = self.conn.get_queue()
|
||||
|
||||
for _ in range(3):
|
||||
random.shuffle(queue)
|
||||
|
||||
queue = sorted(
|
||||
queue,
|
||||
key=lambda req: convert_reward_to_int(req['reward']),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
requests = []
|
||||
for req in queue:
|
||||
rid = req['id']
|
||||
rid = req['nonce']
|
||||
|
||||
# parse request
|
||||
body = json.loads(req['body'])
|
||||
|
@ -153,19 +141,47 @@ class SkynetDGPUDaemon:
|
|||
if model in self.model_blacklist:
|
||||
continue
|
||||
|
||||
my_results = [res['id'] for res in self._snap['my_results']]
|
||||
if rid not in my_results and rid in self._snap['requests']:
|
||||
statuses = self._snap['requests'][rid]
|
||||
my_results = [res['id'] for res in self.conn.get_my_results()]
|
||||
|
||||
if len(statuses) == 0:
|
||||
binary, input_type = await self.conn.get_input_data(req['binary_data'])
|
||||
# if this worker already on it
|
||||
if rid in my_results:
|
||||
continue
|
||||
|
||||
status = self.conn.get_status_for_request(rid)
|
||||
if status == None:
|
||||
continue
|
||||
|
||||
if self.non_compete & set(self.conn.get_competitors_for_request(rid)):
|
||||
continue
|
||||
|
||||
if len(status) > self.max_concurrent:
|
||||
continue
|
||||
|
||||
requests.append(req)
|
||||
|
||||
return requests
|
||||
|
||||
async def serve_forever(self):
|
||||
try:
|
||||
while True:
|
||||
if self.auto_withdraw:
|
||||
await self.conn.maybe_withdraw_all()
|
||||
|
||||
requests = self.find_best_requests()
|
||||
|
||||
if len(requests) > 0:
|
||||
request = requests[0]
|
||||
rid = request['nonce']
|
||||
body = json.loads(request['body'])
|
||||
|
||||
inputs = await self.conn.get_inputs(request['binary_inputs'])
|
||||
|
||||
hash_str = (
|
||||
str(req['nonce'])
|
||||
str(request['nonce'])
|
||||
+
|
||||
req['body']
|
||||
request['body']
|
||||
+
|
||||
req['binary_data']
|
||||
''.join([_in for _in in request['binary_inputs']])
|
||||
)
|
||||
logging.info(f'hashing: {hash_str}')
|
||||
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
|
||||
|
@ -195,13 +211,13 @@ class SkynetDGPUDaemon:
|
|||
self.mm.compute_one,
|
||||
rid,
|
||||
body['method'], body['params'],
|
||||
input_type=input_type,
|
||||
binary=binary
|
||||
inputs=inputs
|
||||
)
|
||||
)
|
||||
|
||||
case _:
|
||||
raise DGPUComputeError(f'Unsupported backend {self.backend}')
|
||||
|
||||
self._last_generation_ts = datetime.now().isoformat()
|
||||
self._last_benchmark = self._benchmark
|
||||
self._benchmark = []
|
||||
|
@ -214,12 +230,6 @@ class SkynetDGPUDaemon:
|
|||
traceback.print_exc()
|
||||
await self.conn.cancel_work(rid, str(e))
|
||||
|
||||
finally:
|
||||
break
|
||||
|
||||
else:
|
||||
logging.info(f'request {rid} already beign worked on, skip...')
|
||||
|
||||
await trio.sleep(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
|
|
|
@ -7,6 +7,7 @@ import logging
|
|||
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
from typing import Any, Coroutine
|
||||
|
||||
import asks
|
||||
import trio
|
||||
|
@ -25,7 +26,10 @@ from skynet.dgpu.errors import DGPUComputeError
|
|||
REQUEST_UPDATE_TIME = 3
|
||||
|
||||
|
||||
async def failable(fn: partial, ret_fail=None):
|
||||
async def failable(
|
||||
fn: partial[Coroutine[Any, Any, Any]],
|
||||
ret_fail: Any | None = None
|
||||
) -> Any:
|
||||
try:
|
||||
return await fn()
|
||||
|
||||
|
@ -42,6 +46,7 @@ async def failable(fn: partial, ret_fail=None):
|
|||
class SkynetGPUConnector:
|
||||
|
||||
def __init__(self, config: dict):
|
||||
self.contract = config['contract']
|
||||
self.account = Name(config['account'])
|
||||
self.permission = config['permission']
|
||||
self.key = config['key']
|
||||
|
@ -63,27 +68,89 @@ class SkynetGPUConnector:
|
|||
if 'ipfs_domain' in config:
|
||||
self.ipfs_domain = config['ipfs_domain']
|
||||
|
||||
self._wip_requests = {}
|
||||
self._update_delta = 1
|
||||
self._cache: dict[str, tuple[float, Any]] = {}
|
||||
|
||||
# blockchain helpers
|
||||
async def _cache_set(
|
||||
self,
|
||||
fn: partial[Coroutine[Any, Any, Any]],
|
||||
key: str
|
||||
) -> Any:
|
||||
now = time.time()
|
||||
val = await fn()
|
||||
self._cache[key] = (now, val)
|
||||
|
||||
async def get_work_requests_last_hour(self):
|
||||
return val
|
||||
|
||||
def _cache_get(self, key: str, default: Any = None) -> Any:
|
||||
if key in self._cache:
|
||||
return self._cache[key][1]
|
||||
|
||||
else:
|
||||
return default
|
||||
|
||||
async def data_updater_task(self):
|
||||
while True:
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(
|
||||
self._cache_set, self._get_work_requests_last_hour, 'queue')
|
||||
n.start_soon(
|
||||
self._cache_set, self._find_my_results, 'my_results')
|
||||
|
||||
await trio.sleep(self._update_delta)
|
||||
|
||||
def get_queue(self):
|
||||
return self._cache_get('queue', default=[])
|
||||
|
||||
def get_my_results(self):
|
||||
return self._cache_get('my_results', default=[])
|
||||
|
||||
def get_status_for_request(self, request_id: int) -> list[dict] | None:
|
||||
request: dict | None = next((
|
||||
req
|
||||
for req in self.get_queue()
|
||||
if req['id'] == request_id), None)
|
||||
|
||||
if request:
|
||||
return request['status']
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_competitors_for_request(self, request_id: int) -> list[str] | None:
|
||||
status = self.get_status_for_request(request_id)
|
||||
if not status:
|
||||
return None
|
||||
|
||||
return [
|
||||
s['worker']
|
||||
for s in status
|
||||
if s['worker'] != self.account
|
||||
]
|
||||
|
||||
async def _get_work_requests_last_hour(self) -> list[dict]:
|
||||
logging.info('get_work_requests_last_hour')
|
||||
return await failable(
|
||||
partial(
|
||||
self.cleos.aget_table,
|
||||
'telos.gpu', 'telos.gpu', 'queue',
|
||||
self.contract, self.contract, 'queue',
|
||||
index_position=2,
|
||||
key_type='i64',
|
||||
lower_bound=int(time.time()) - 3600
|
||||
), ret_fail=[])
|
||||
|
||||
async def get_status_by_request_id(self, request_id: int):
|
||||
logging.info('get_status_by_request_id')
|
||||
async def _find_my_results(self):
|
||||
logging.info('find_my_results')
|
||||
return await failable(
|
||||
partial(
|
||||
self.cleos.aget_table,
|
||||
'telos.gpu', request_id, 'status'), ret_fail=[])
|
||||
self.contract, self.contract, 'results',
|
||||
index_position=4,
|
||||
key_type='name',
|
||||
lower_bound=self.account,
|
||||
upper_bound=self.account
|
||||
)
|
||||
)
|
||||
|
||||
async def get_global_config(self):
|
||||
logging.info('get_global_config')
|
||||
|
@ -114,36 +181,6 @@ class SkynetGPUConnector:
|
|||
else:
|
||||
return None
|
||||
|
||||
async def get_competitors_for_req(self, request_id: int) -> set:
|
||||
competitors = [
|
||||
status['worker']
|
||||
for status in
|
||||
(await self.get_status_by_request_id(request_id))
|
||||
if status['worker'] != self.account
|
||||
]
|
||||
logging.info(f'competitors: {competitors}')
|
||||
return set(competitors)
|
||||
|
||||
|
||||
async def get_full_queue_snapshot(self):
|
||||
snap = {
|
||||
'requests': {},
|
||||
'my_results': []
|
||||
}
|
||||
|
||||
snap['queue'] = await self.get_work_requests_last_hour()
|
||||
|
||||
async def _run_and_save(d, key: str, fn, *args, **kwargs):
|
||||
d[key] = await fn(*args, **kwargs)
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
n.start_soon(_run_and_save, snap, 'my_results', self.find_my_results)
|
||||
for req in snap['queue']:
|
||||
n.start_soon(
|
||||
_run_and_save, snap['requests'], req['id'], self.get_status_by_request_id, req['id'])
|
||||
|
||||
return snap
|
||||
|
||||
async def begin_work(self, request_id: int):
|
||||
logging.info('begin_work')
|
||||
return await failable(
|
||||
|
@ -200,19 +237,6 @@ class SkynetGPUConnector:
|
|||
)
|
||||
)
|
||||
|
||||
async def find_my_results(self):
|
||||
logging.info('find_my_results')
|
||||
return await failable(
|
||||
partial(
|
||||
self.cleos.aget_table,
|
||||
'telos.gpu', 'telos.gpu', 'results',
|
||||
index_position=4,
|
||||
key_type='name',
|
||||
lower_bound=self.account,
|
||||
upper_bound=self.account
|
||||
)
|
||||
)
|
||||
|
||||
async def submit_work(
|
||||
self,
|
||||
request_id: int,
|
||||
|
@ -268,15 +292,11 @@ class SkynetGPUConnector:
|
|||
return file_cid
|
||||
|
||||
async def get_input_data(self, ipfs_hash: str) -> tuple[bytes, str]:
|
||||
input_type = 'none'
|
||||
|
||||
if ipfs_hash == '':
|
||||
return b'', input_type
|
||||
|
||||
results = {}
|
||||
ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}'
|
||||
ipfs_link_legacy = ipfs_link + '/image.png'
|
||||
|
||||
input_type = 'unknown'
|
||||
async with trio.open_nursery() as n:
|
||||
async def get_and_set_results(link: str):
|
||||
res = await get_ipfs_file(link, timeout=1)
|
||||
|
@ -310,3 +330,14 @@ class SkynetGPUConnector:
|
|||
raise DGPUComputeError('Couldn\'t gather input data from ipfs')
|
||||
|
||||
return input_data, input_type
|
||||
|
||||
async def get_inputs(self, links: list[str]) -> list[tuple[bytes, str]]:
|
||||
results = {}
|
||||
async def _get_input(link: str) -> None:
|
||||
results[link] = await self.get_input_data(link)
|
||||
|
||||
async with trio.open_nursery() as n:
|
||||
for link in links:
|
||||
n.start_soon(_get_input, link)
|
||||
|
||||
return [results[link] for link in links]
|
||||
|
|
Loading…
Reference in New Issue