protocol_v2
Guillermo Rodriguez 2023-10-12 14:52:29 -03:00
parent 8a415b450f
commit f4592ae254
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
5 changed files with 265 additions and 212 deletions

View File

@@ -116,7 +116,7 @@ def enqueue(
     key = load_key(config, 'skynet.user.key')
     account = load_key(config, 'skynet.user.account')
     permission = load_key(config, 'skynet.user.permission')
-    node_url = load_key(config, 'skynet.user.node_url')
+    node_url = load_key(config, 'skynet.node_url')

     cleos = CLEOS(None, None, url=node_url, remote=node_url)
@@ -131,28 +131,39 @@ def enqueue(
         kwargs['strength'] = float(kwargs['strength'])

     async def enqueue_n_jobs():
-        for i in range(jobs):
-            if not kwargs['seed']:
-                kwargs['seed'] = random.randint(0, 10e9)
+        actions = []
+        for _ in range(jobs):
+            if kwargs['seed']:
+                seed = kwargs['seed']
+            else:
+                seed = random.randint(0, 10e9)
+
+            _kwargs = kwargs.copy()
+            _kwargs['seed'] = seed

             req = json.dumps({
                 'method': 'diffuse',
-                'params': kwargs
+                'params': _kwargs
             })
-            res = await cleos.a_push_action(
-                'telos.gpu',
-                'enqueue',
-                {
+            actions.append({
+                'account': 'telos.gpu',
+                'name': 'enqueue',
+                'data': {
                     'user': Name(account),
                     'request_body': req,
                     'binary_data': binary,
                     'reward': asset_from_str(reward),
                     'min_verification': 1
                 },
-                account, key, permission,
-            )
-            print(res)
+                'authorization': [{
+                    'actor': account,
+                    'permission': permission
+                }]
+            })
+
+        res = await cleos.a_push_actions(actions, key)
+        print(res)

     trio.run(enqueue_n_jobs)
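
Two things change in this hunk besides the config key. First, seeds: the old loop assigned kwargs['seed'] once and then reused the mutated dict, so every job after the first inherited the same random seed; the new loop draws a seed per job and writes it into a private copy. A minimal sketch of that per-job copy, using a hypothetical job_kwargs helper (int(10e9) is used here because random.randint expects integer bounds):

import random

def job_kwargs(kwargs: dict, jobs: int) -> list[dict]:
    out = []
    for _ in range(jobs):
        seed = kwargs['seed'] if kwargs['seed'] else random.randint(0, int(10e9))
        out.append({**kwargs, 'seed': seed})  # copy per job; caller's dict untouched
    return out

Second, submission: instead of one a_push_action round-trip per job, the actions are collected and pushed through a single a_push_actions call, so all N enqueues land in one signed transaction.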
@@ -169,7 +180,7 @@ def clean(
     key = load_key(config, 'skynet.user.key')
     account = load_key(config, 'skynet.user.account')
     permission = load_key(config, 'skynet.user.permission')
-    node_url = load_key(config, 'skynet.user.node_url')
+    node_url = load_key(config, 'skynet.node_url')

     logging.basicConfig(level=loglevel)
     cleos = CLEOS(None, None, url=node_url, remote=node_url)
@@ -187,7 +198,7 @@ def clean(
 def queue():
     import requests
     config = load_skynet_toml()
-    node_url = load_key(config, 'skynet.user.node_url')
+    node_url = load_key(config, 'skynet.node_url')
     resp = requests.post(
         f'{node_url}/v1/chain/get_table_rows',
         json={
@@ -204,7 +215,7 @@ def queue():
 def status(request_id: int):
     import requests
     config = load_skynet_toml()
-    node_url = load_key(config, 'skynet.user.node_url')
+    node_url = load_key(config, 'skynet.node_url')
     resp = requests.post(
         f'{node_url}/v1/chain/get_table_rows',
         json={
@@ -226,7 +237,7 @@ def dequeue(request_id: int):
     key = load_key(config, 'skynet.user.key')
     account = load_key(config, 'skynet.user.account')
     permission = load_key(config, 'skynet.user.permission')
-    node_url = load_key(config, 'skynet.user.node_url')
+    node_url = load_key(config, 'skynet.node_url')

     cleos = CLEOS(None, None, url=node_url, remote=node_url)
     res = trio.run(
@@ -261,7 +272,7 @@ def config(
     key = load_key(config, 'skynet.user.key')
     account = load_key(config, 'skynet.user.account')
     permission = load_key(config, 'skynet.user.permission')
-    node_url = load_key(config, 'skynet.user.node_url')
+    node_url = load_key(config, 'skynet.node_url')

     cleos = CLEOS(None, None, url=node_url, remote=node_url)
     res = trio.run(
@@ -290,7 +301,7 @@ def deposit(quantity: str):
     key = load_key(config, 'skynet.user.key')
     account = load_key(config, 'skynet.user.account')
     permission = load_key(config, 'skynet.user.permission')
-    node_url = load_key(config, 'skynet.user.node_url')
+    node_url = load_key(config, 'skynet.node_url')

     cleos = CLEOS(None, None, url=node_url, remote=node_url)
     res = trio.run(
@@ -352,7 +363,7 @@ def dgpu(
     assert 'skynet' in config
     assert 'dgpu' in config['skynet']

-    trio.run(open_dgpu_node, config['skynet']['dgpu'])
+    trio.run(open_dgpu_node, config['skynet'])


 @run.command()
@@ -375,30 +386,30 @@ def telegram(
     logging.basicConfig(level=loglevel)

     config = load_skynet_toml()
-    tg_token = load_key(config, 'skynet.telegram.tg_token')
+    tg_token = load_key(config, 'skynet.telegram.token')

     key = load_key(config, 'skynet.telegram.key')
     account = load_key(config, 'skynet.telegram.account')
     permission = load_key(config, 'skynet.telegram.permission')
-    node_url = load_key(config, 'skynet.telegram.node_url')
-    hyperion_url = load_key(config, 'skynet.telegram.hyperion_url')
+    node_url = load_key(config, 'skynet.node_url')
+    hyperion_url = load_key(config, 'skynet.hyperion_url')

     try:
-        ipfs_gateway_url = load_key(config, 'skynet.telegram.ipfs_gateway_url')
+        ipfs_gateway_url = load_key(config, 'skynet.ipfs_gateway_url')
     except ConfigParsingError:
         ipfs_gateway_url = None

-    ipfs_url = load_key(config, 'skynet.telegram.ipfs_url')
+    ipfs_url = load_key(config, 'skynet.ipfs_url')

     try:
-        explorer_domain = load_key(config, 'skynet.telegram.explorer_domain')
+        explorer_domain = load_key(config, 'skynet.explorer_domain')
     except ConfigParsingError:
         explorer_domain = DEFAULT_EXPLORER_DOMAIN

     try:
-        ipfs_domain = load_key(config, 'skynet.telegram.ipfs_domain')
+        ipfs_domain = load_key(config, 'skynet.ipfs_domain')
     except ConfigParsingError:
         ipfs_domain = DEFAULT_IPFS_DOMAIN
@@ -445,25 +456,25 @@ def discord(
     logging.basicConfig(level=loglevel)

     config = load_skynet_toml()
-    dc_token = load_key(config, 'skynet.discord.dc_token')
+    dc_token = load_key(config, 'skynet.discord.token')

     key = load_key(config, 'skynet.discord.key')
     account = load_key(config, 'skynet.discord.account')
     permission = load_key(config, 'skynet.discord.permission')
-    node_url = load_key(config, 'skynet.discord.node_url')
-    hyperion_url = load_key(config, 'skynet.discord.hyperion_url')
+    node_url = load_key(config, 'skynet.node_url')
+    hyperion_url = load_key(config, 'skynet.hyperion_url')

-    ipfs_gateway_url = load_key(config, 'skynet.discord.ipfs_gateway_url')
-    ipfs_url = load_key(config, 'skynet.discord.ipfs_url')
+    ipfs_gateway_url = load_key(config, 'skynet.ipfs_gateway_url')
+    ipfs_url = load_key(config, 'skynet.ipfs_url')

     try:
-        explorer_domain = load_key(config, 'skynet.discord.explorer_domain')
+        explorer_domain = load_key(config, 'skynet.explorer_domain')
     except ConfigParsingError:
         explorer_domain = DEFAULT_EXPLORER_DOMAIN

     try:
-        ipfs_domain = load_key(config, 'skynet.discord.ipfs_domain')
+        ipfs_domain = load_key(config, 'skynet.ipfs_domain')
     except ConfigParsingError:
         ipfs_domain = DEFAULT_IPFS_DOMAIN
@@ -509,8 +520,8 @@ def pinner(loglevel):
     from .ipfs.pinner import SkynetPinner

     config = load_skynet_toml()
-    hyperion_url = load_key(config, 'skynet.pinner.hyperion_url')
-    ipfs_url = load_key(config, 'skynet.pinner.ipfs_url')
+    hyperion_url = load_key(config, 'skynet.hyperion_url')
+    ipfs_url = load_key(config, 'skynet.ipfs_url')

     logging.basicConfig(level=loglevel)
     ipfs_node = AsyncIPFSHTTP(ipfs_url)
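
The thread running through all of these hunks: shared endpoint settings move out of the per-service sections ('skynet.user.*', 'skynet.telegram.*', 'skynet.discord.*', 'skynet.pinner.*') into the top-level 'skynet' scope, and the bot tokens drop their redundant prefixes (tg_token and dc_token become token). A hedged sketch of the dict load_skynet_toml() would yield after this commit — URLs are placeholders, and only keys visible in the hunks above are shown:

config = {
    'skynet': {
        # global, shared by every subcommand and service:
        'node_url': 'https://node.example',
        'hyperion_url': 'https://hyperion.example',
        'ipfs_url': 'http://127.0.0.1:5001',
        # credentials stay per-service:
        'user': {'key': '...', 'account': '...', 'permission': 'active'},
        'telegram': {'token': '...', 'key': '...', 'account': '...', 'permission': '...'},
        'discord': {'token': '...', 'key': '...', 'account': '...', 'permission': '...'},
    }
}
# node_url = load_key(config, 'skynet.node_url')  # dotted-path lookup into the dict above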

View File

@@ -11,18 +11,18 @@ from skynet.dgpu.network import SkynetGPUConnector

 async def open_dgpu_node(config: dict):
-    conn = SkynetGPUConnector(config)
-    mm = SkynetMM(config)
-    daemon = SkynetDGPUDaemon(mm, conn, config)
+    conn = SkynetGPUConnector({**config, **config['dgpu']})
+    mm = SkynetMM(config['dgpu'])
+    daemon = SkynetDGPUDaemon(mm, conn, config['dgpu'])

     api = None
-    if 'api_bind' in config:
+    if 'api_bind' in config['dgpu']:
         api_conf = Config()
         api_conf.bind = [config['api_bind']]
         api = await daemon.generate_api()

     async with trio.open_nursery() as n:
-        n.start_soon(daemon.snap_updater_task)
+        n.start_soon(conn.data_updater_task)
         if api:
             n.start_soon(serve, api, api_conf)
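
Why the reshuffle: endpoint settings now live at the top of the [skynet] table (see the cli.py changes above), so the connector needs both the global keys and its own [skynet.dgpu] section, while SkynetMM and the daemon only consume dgpu-local settings. The dict unpacking gives the connector that flat, merged view, with dgpu keys shadowing same-named globals; a tiny illustration with placeholder values:

config = {'node_url': 'https://node.example', 'dgpu': {'account': 'gpuworker1'}}
merged = {**config, **config['dgpu']}
assert merged['node_url'] == 'https://node.example'  # global key survives
assert merged['account'] == 'gpuworker1'             # dgpu key lifted to the top level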

View File

@@ -6,6 +6,7 @@ import gc
 import logging

 from hashlib import sha256
+from typing import Any
 import zipfile

 from PIL import Image
 from diffusers import DiffusionPipeline
@@ -21,11 +22,16 @@ from skynet.utils import crop_image, convert_from_cv2_to_image, convert_from_ima
 def prepare_params_for_diffuse(
     params: dict,
-    input_type: str,
-    binary = None
+    inputs: list[tuple[Any, str]],
 ):
     _params = {}
-    if binary != None:
+    if len(inputs) > 1:
+        raise DGPUComputeError('sorry binary_inputs > 1 not implemented yet')
+
+    if len(inputs) > 0:
+        binary, input_type = inputs[0]
         match input_type:
             case 'png':
                 image = crop_image(
@@ -34,9 +40,6 @@ def prepare_params_for_diffuse(
                 _params['image'] = image
                 _params['strength'] = float(params['strength'])

-            case 'none':
-                ...
-
             case _:
                 raise DGPUComputeError(f'Unknown input_type {input_type}')
@@ -144,8 +147,7 @@ class SkynetMM:
         request_id: int,
         method: str,
         params: dict,
-        input_type: str = 'png',
-        binary: bytes | None = None
+        inputs: list[tuple[Any, str]]
     ):
         def maybe_cancel_work(step, *args, **kwargs):
             if self._should_cancel:
@@ -165,8 +167,7 @@ class SkynetMM:
         try:
             match method:
                 case 'diffuse':
-                    arguments = prepare_params_for_diffuse(
-                        params, input_type, binary=binary)
+                    arguments = prepare_params_for_diffuse(params, inputs)
                     prompt, guidance, step, seed, upscaler, extra_params = arguments
                     model = self.get_model(params['model'], 'image' in extra_params)
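
With this change the compute entrypoint takes any number of attached inputs as (binary, input_type) pairs, even though prepare_params_for_diffuse still rejects more than one. A hypothetical call under the new signature (png_bytes and params stand in for data the daemon fetches from IPFS and the request body):

params = {}                     # request-body 'params' dict (model, prompt, width, ...)
png_bytes = b'...'              # raw image bytes fetched from IPFS
inputs = [(png_bytes, 'png')]   # [] means plain text-to-image
arguments = prepare_params_for_diffuse(params, inputs)
prompt, guidance, step, seed, upscaler, extra_params = arguments  # 6-tuple, as unpacked above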

View File

@@ -44,6 +44,10 @@ class SkynetDGPUDaemon:
             config['auto_withdraw']
             if 'auto_withdraw' in config else False
         )
+        self.max_concurrent = (
+            config['max_concurrent']
+            if 'max_concurrent' in config else 0
+        )

         self.account = config['account']
@@ -63,12 +67,6 @@ class SkynetDGPUDaemon:
         if 'backend' in config:
             self.backend = config['backend']

-        self._snap = {
-            'queue': [],
-            'requests': {},
-            'my_results': []
-        }
-
         self._benchmark = []
         self._last_benchmark = None
         self._last_generation_ts = None
@@ -90,18 +88,10 @@ class SkynetDGPUDaemon:
     async def should_cancel_work(self, request_id: int):
         self._benchmark.append(time.time())
-        competitors = set([
-            status['worker']
-            for status in self._snap['requests'][request_id]
-            if status['worker'] != self.account
-        ])
-        return bool(self.non_compete & competitors)
-
-    async def snap_updater_task(self):
-        while True:
-            self._snap = await self.conn.get_full_queue_snapshot()
-            await trio.sleep(1)
+        competitors = self.conn.get_competitors_for_request(request_id)
+        if competitors == None:
+            return True
+
+        return bool(self.non_compete & set(competitors))

     async def generate_api(self):
         app = Quart(__name__)
@@ -117,108 +107,128 @@ class SkynetDGPUDaemon:

         return app

+    def find_best_requests(self) -> list[dict]:
+        queue = self.conn.get_queue()
+
+        for _ in range(3):
+            random.shuffle(queue)
+
+        queue = sorted(
+            queue,
+            key=lambda req: convert_reward_to_int(req['reward']),
+            reverse=True
+        )
+
+        requests = []
+        for req in queue:
+            rid = req['nonce']
+
+            # parse request
+            body = json.loads(req['body'])
+            model = body['params']['model']
+
+            # if model not known
+            if model not in MODELS:
+                logging.warning(f'Unknown model {model}')
+                continue
+
+            # if whitelist enabled and model not in it continue
+            if (len(self.model_whitelist) > 0 and
+                not model in self.model_whitelist):
+                continue
+
+            # if blacklist contains model skip
+            if model in self.model_blacklist:
+                continue
+
+            my_results = [res['id'] for res in self.conn.get_my_results()]
+
+            # if this worker already on it
+            if rid in my_results:
+                continue
+
+            status = self.conn.get_status_for_request(rid)
+            if status == None:
+                continue
+
+            if self.non_compete & set(self.conn.get_competitors_for_request(rid)):
+                continue
+
+            if len(status) > self.max_concurrent:
+                continue
+
+            requests.append(req)
+
+        return requests
+
     async def serve_forever(self):
         try:
             while True:
                 if self.auto_withdraw:
                     await self.conn.maybe_withdraw_all()

-                queue = self._snap['queue']
-                random.shuffle(queue)
-                queue = sorted(
-                    queue,
-                    key=lambda req: convert_reward_to_int(req['reward']),
-                    reverse=True
-                )
-
-                for req in queue:
-                    rid = req['id']
-
-                    # parse request
-                    body = json.loads(req['body'])
-                    model = body['params']['model']
-
-                    # if model not known
-                    if model not in MODELS:
-                        logging.warning(f'Unknown model {model}')
-                        continue
-
-                    # if whitelist enabled and model not in it continue
-                    if (len(self.model_whitelist) > 0 and
-                        not model in self.model_whitelist):
-                        continue
-
-                    # if blacklist contains model skip
-                    if model in self.model_blacklist:
-                        continue
-
-                    my_results = [res['id'] for res in self._snap['my_results']]
-                    if rid not in my_results and rid in self._snap['requests']:
-                        statuses = self._snap['requests'][rid]
-
-                        if len(statuses) == 0:
-                            binary, input_type = await self.conn.get_input_data(req['binary_data'])
-
-                            hash_str = (
-                                str(req['nonce'])
-                                +
-                                req['body']
-                                +
-                                req['binary_data']
-                            )
-                            logging.info(f'hashing: {hash_str}')
-                            request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
-
-                            # TODO: validate request
-
-                            # perform work
-                            logging.info(f'working on {body}')
-
-                            resp = await self.conn.begin_work(rid)
-                            if 'code' in resp:
-                                logging.info(f'probably being worked on already... skip.')
-
-                            else:
-                                try:
-                                    output_type = 'png'
-                                    if 'output_type' in body['params']:
-                                        output_type = body['params']['output_type']
-
-                                    output = None
-                                    output_hash = None
-                                    match self.backend:
-                                        case 'sync-on-thread':
-                                            self.mm._should_cancel = self.should_cancel_work
-                                            output_hash, output = await trio.to_thread.run_sync(
-                                                partial(
-                                                    self.mm.compute_one,
-                                                    rid,
-                                                    body['method'], body['params'],
-                                                    input_type=input_type,
-                                                    binary=binary
-                                                )
-                                            )
-
-                                        case _:
-                                            raise DGPUComputeError(f'Unsupported backend {self.backend}')
-
-                                    self._last_generation_ts = datetime.now().isoformat()
-                                    self._last_benchmark = self._benchmark
-                                    self._benchmark = []
-
-                                    ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
-
-                                    await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
-
-                                except BaseException as e:
-                                    traceback.print_exc()
-                                    await self.conn.cancel_work(rid, str(e))
-
-                                finally:
-                                    break
-
-                    else:
-                        logging.info(f'request {rid} already being worked on, skip...')
+                requests = self.find_best_requests()
+                if len(requests) > 0:
+                    request = requests[0]
+                    rid = request['nonce']
+                    body = json.loads(request['body'])
+
+                    inputs = await self.conn.get_inputs(request['binary_inputs'])
+
+                    hash_str = (
+                        str(request['nonce'])
+                        +
+                        request['body']
+                        +
+                        ''.join([_in for _in in request['binary_inputs']])
+                    )
+                    logging.info(f'hashing: {hash_str}')
+                    request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
+
+                    # TODO: validate request
+
+                    # perform work
+                    logging.info(f'working on {body}')
+
+                    resp = await self.conn.begin_work(rid)
+                    if 'code' in resp:
+                        logging.info(f'probably being worked on already... skip.')
+
+                    else:
+                        try:
+                            output_type = 'png'
+                            if 'output_type' in body['params']:
+                                output_type = body['params']['output_type']
+
+                            output = None
+                            output_hash = None
+                            match self.backend:
+                                case 'sync-on-thread':
+                                    self.mm._should_cancel = self.should_cancel_work
+                                    output_hash, output = await trio.to_thread.run_sync(
+                                        partial(
+                                            self.mm.compute_one,
+                                            rid,
+                                            body['method'], body['params'],
+                                            inputs=inputs
+                                        )
+                                    )
+
+                                case _:
+                                    raise DGPUComputeError(f'Unsupported backend {self.backend}')
+
+                            self._last_generation_ts = datetime.now().isoformat()
+                            self._last_benchmark = self._benchmark
+                            self._benchmark = []
+
+                            ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
+
+                            await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
+
+                        except BaseException as e:
+                            traceback.print_exc()
+                            await self.conn.cancel_work(rid, str(e))

                 await trio.sleep(1)
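
The serve_forever rewrite pulls request selection into find_best_requests: rank the cached queue by reward (after a shuffle, so equal rewards tie-break randomly) and drop anything this worker can't or shouldn't take; the loop then works only the top candidate. A condensed sketch of the pipeline, with hypothetical reward_to_int and eligible callables standing in for the module's convert_reward_to_int and the model/whitelist/blacklist/own-results/competitor/max_concurrent checks spelled out above:

import random

def rank_queue(queue: list[dict], reward_to_int, eligible) -> list[dict]:
    random.shuffle(queue)  # tie-break equal rewards
    queue.sort(key=lambda r: reward_to_int(r['reward']), reverse=True)
    return [r for r in queue if eligible(r)]

Note the concurrency guard: len(status) > self.max_concurrent skips a request once more than max_concurrent workers already report status on it, and the default of 0 reproduces the old behavior of only picking up requests nobody has started.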

View File

@@ -7,6 +7,7 @@ import logging

 from pathlib import Path
 from functools import partial
+from typing import Any, Coroutine

 import asks
 import trio
@@ -25,7 +26,10 @@ from skynet.dgpu.errors import DGPUComputeError
 REQUEST_UPDATE_TIME = 3


-async def failable(fn: partial, ret_fail=None):
+async def failable(
+    fn: partial[Coroutine[Any, Any, Any]],
+    ret_fail: Any | None = None
+) -> Any:
     try:
         return await fn()
@@ -42,6 +46,7 @@ async def failable(
 class SkynetGPUConnector:

     def __init__(self, config: dict):
+        self.contract = config['contract']
         self.account = Name(config['account'])
         self.permission = config['permission']
         self.key = config['key']
@@ -63,27 +68,89 @@ class SkynetGPUConnector:
         if 'ipfs_domain' in config:
             self.ipfs_domain = config['ipfs_domain']

-        self._wip_requests = {}
+        self._update_delta = 1
+        self._cache: dict[str, tuple[float, Any]] = {}

-    # blockchain helpers
+    async def _cache_set(
+        self,
+        fn: partial[Coroutine[Any, Any, Any]],
+        key: str
+    ) -> Any:
+        now = time.time()
+        val = await fn()
+        self._cache[key] = (now, val)
+
+        return val
+
+    def _cache_get(self, key: str, default: Any = None) -> Any:
+        if key in self._cache:
+            return self._cache[key][1]
+
+        else:
+            return default
+
+    async def data_updater_task(self):
+        while True:
+            async with trio.open_nursery() as n:
+                n.start_soon(
+                    self._cache_set, self._get_work_requests_last_hour, 'queue')
+                n.start_soon(
+                    self._cache_set, self._find_my_results, 'my_results')
+
+            await trio.sleep(self._update_delta)
+
+    def get_queue(self):
+        return self._cache_get('queue', default=[])
+
+    def get_my_results(self):
+        return self._cache_get('my_results', default=[])
+
+    def get_status_for_request(self, request_id: int) -> list[dict] | None:
+        request: dict | None = next((
+            req
+            for req in self.get_queue()
+            if req['id'] == request_id), None)
+
+        if request:
+            return request['status']
+
+        else:
+            return None
+
+    def get_competitors_for_request(self, request_id: int) -> list[str] | None:
+        status = self.get_status_for_request(request_id)
+        if not status:
+            return None
+
+        return [
+            s['worker']
+            for s in status
+            if s['worker'] != self.account
+        ]

-    async def get_work_requests_last_hour(self):
+    async def _get_work_requests_last_hour(self) -> list[dict]:
         logging.info('get_work_requests_last_hour')
         return await failable(
             partial(
                 self.cleos.aget_table,
-                'telos.gpu', 'telos.gpu', 'queue',
+                self.contract, self.contract, 'queue',
                 index_position=2,
                 key_type='i64',
                 lower_bound=int(time.time()) - 3600
             ), ret_fail=[])

-    async def get_status_by_request_id(self, request_id: int):
-        logging.info('get_status_by_request_id')
+    async def _find_my_results(self):
+        logging.info('find_my_results')
         return await failable(
             partial(
                 self.cleos.aget_table,
-                'telos.gpu', request_id, 'status'), ret_fail=[])
+                self.contract, self.contract, 'results',
+                index_position=4,
+                key_type='name',
+                lower_bound=self.account,
+                upper_bound=self.account
+            )
+        )

     async def get_global_config(self):
         logging.info('get_global_config')
@@ -114,36 +181,6 @@ class SkynetGPUConnector:
         else:
             return None

-    async def get_competitors_for_req(self, request_id: int) -> set:
-        competitors = [
-            status['worker']
-            for status in
-            (await self.get_status_by_request_id(request_id))
-            if status['worker'] != self.account
-        ]
-        logging.info(f'competitors: {competitors}')
-        return set(competitors)
-
-    async def get_full_queue_snapshot(self):
-        snap = {
-            'requests': {},
-            'my_results': []
-        }
-
-        snap['queue'] = await self.get_work_requests_last_hour()
-
-        async def _run_and_save(d, key: str, fn, *args, **kwargs):
-            d[key] = await fn(*args, **kwargs)
-
-        async with trio.open_nursery() as n:
-            n.start_soon(_run_and_save, snap, 'my_results', self.find_my_results)
-            for req in snap['queue']:
-                n.start_soon(
-                    _run_and_save, snap['requests'], req['id'], self.get_status_by_request_id, req['id'])
-
-        return snap
-
     async def begin_work(self, request_id: int):
         logging.info('begin_work')
         return await failable(
@@ -200,19 +237,6 @@ class SkynetGPUConnector:
             )
         )

-    async def find_my_results(self):
-        logging.info('find_my_results')
-        return await failable(
-            partial(
-                self.cleos.aget_table,
-                'telos.gpu', 'telos.gpu', 'results',
-                index_position=4,
-                key_type='name',
-                lower_bound=self.account,
-                upper_bound=self.account
-            )
-        )
-
     async def submit_work(
         self,
         request_id: int,
@@ -268,15 +292,11 @@ class SkynetGPUConnector:

         return file_cid

     async def get_input_data(self, ipfs_hash: str) -> tuple[bytes, str]:
-        input_type = 'none'
-        if ipfs_hash == '':
-            return b'', input_type
-
         results = {}
         ipfs_link = f'https://{self.ipfs_domain}/ipfs/{ipfs_hash}'
         ipfs_link_legacy = ipfs_link + '/image.png'

+        input_type = 'unknown'
         async with trio.open_nursery() as n:
             async def get_and_set_results(link: str):
                 res = await get_ipfs_file(link, timeout=1)
@@ -310,3 +330,14 @@ class SkynetGPUConnector:
             raise DGPUComputeError('Couldn\'t gather input data from ipfs')

         return input_data, input_type
+
+    async def get_inputs(self, links: list[str]) -> list[tuple[bytes, str]]:
+        results = {}
+        async def _get_input(link: str) -> None:
+            results[link] = await self.get_input_data(link)
+
+        async with trio.open_nursery() as n:
+            for link in links:
+                n.start_soon(_get_input, link)
+
+        return [results[link] for link in links]
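
The connector now centers on a poll-and-cache pattern: data_updater_task is the single writer, refreshing the 'queue' and 'my_results' entries every _update_delta seconds, and the getters the daemon calls (get_queue, get_my_results, get_status_for_request, get_competitors_for_request) become synchronous lookups that are at worst one refresh interval stale. A hedged usage sketch, assuming a constructed conn:

import trio

async def main(conn):
    async with trio.open_nursery() as n:
        n.start_soon(conn.data_updater_task)   # single writer: polls the chain tables
        await trio.sleep(conn._update_delta)   # give the first refresh a chance to land
        queue = conn.get_queue()               # no await; returns [] until the cache warms
        n.cancel_scope.cancel()                # stop the poller when done

Also worth noting: get_inputs fetches all links concurrently in a nursery but indexes results by link, so the returned list preserves input order and pair i still corresponds to binary_inputs[i].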