mirror of https://github.com/skygpu/skynet.git
Add support for both protocol versions on worker daemon
parent
7edca49e95
commit
47dda50f32
|
@ -26,6 +26,7 @@ class DgpuConfig(msgspec.Struct):
|
||||||
poll_time: float = 0.5 # wait time for polling updates from contract
|
poll_time: float = 0.5 # wait time for polling updates from contract
|
||||||
log_level: str = 'info'
|
log_level: str = 'info'
|
||||||
log_file: str = 'dgpu.log' # log file path (only used when tui = true)
|
log_file: str = 'dgpu.log' # log file path (only used when tui = true)
|
||||||
|
proto_version: int = 0
|
||||||
|
|
||||||
|
|
||||||
class FrontendConfig(msgspec.Struct):
|
class FrontendConfig(msgspec.Struct):
|
||||||
|
|
|
@ -5,10 +5,10 @@ from leap import CLEOS
|
||||||
from leap.protocol import Name
|
from leap.protocol import Name
|
||||||
|
|
||||||
from skynet.types import (
|
from skynet.types import (
|
||||||
ConfigV1,
|
Config, ConfigV0, ConfigV1,
|
||||||
AccountV1,
|
Account, AccountV0, AccountV1,
|
||||||
WorkerV0,
|
WorkerV0,
|
||||||
RequestV1,
|
Request, RequestV0, RequestV1,
|
||||||
BodyV0,
|
BodyV0,
|
||||||
WorkerStatusV0,
|
WorkerStatusV0,
|
||||||
ResultV0
|
ResultV0
|
||||||
|
@ -33,37 +33,39 @@ class WorkerStatusNotFound(BaseException):
|
||||||
|
|
||||||
class GPUContractAPI:
|
class GPUContractAPI:
|
||||||
|
|
||||||
def __init__(self, cleos: CLEOS, proto_version: int = 0):
    '''
    Bind the wrapper to a chain client and a contract protocol version.

    :param cleos: client object the methods of this class call
        ``aget_table`` and ``a_push_action`` on.
    :param proto_version: compared against 0 by the table readers to pick
        between the V0 and V1 row types; defaults to 0.
    '''
    self._cleos = cleos
    self.proto_version = proto_version
    # passed as the first two arguments of every table query in this class
    # and as the target of pushed actions
    self.receiver = 'gpu.scd'
|
||||||
|
|
||||||
# views into data
|
# views into data
|
||||||
|
|
||||||
async def get_config(self) -> Config:
    '''
    Return the first row of the contract's 'config' table.

    The row is decoded as ConfigV1 when ``self.proto_version`` is greater
    than 0 and as ConfigV0 otherwise, the same threshold ``get_users`` and
    the queue readers apply.

    :raises ConfigNotFound: when the query yields no rows.
    '''
    # The threshold used to be `> 1`, which made protocol version 1 decode
    # rows with the V0 struct while the rest of the class already treated
    # it as V1.
    resp_cls = ConfigV1 if self.proto_version > 0 else ConfigV0
    rows = await self._cleos.aget_table(
        self.receiver, self.receiver, 'config',
        resp_cls=resp_cls
    )
    if not rows:
        raise ConfigNotFound()

    return rows[0]
|
||||||
|
|
||||||
async def get_user(self, user: str) -> Account:
    '''
    Return the 'users' table row for a single account.

    ``user`` is sent as both the lower and the upper bound of a
    name-keyed query and the first row that comes back is returned.
    Rows are decoded as AccountV1 when ``self.proto_version`` is greater
    than 0 and as AccountV0 otherwise, matching ``get_users``.

    :param user: account name to look up.
    :raises AccountNotFound: carrying ``user`` when the query yields no rows.
    '''
    # The threshold used to be `> 1`, so with protocol version 1 this
    # method decoded the same table with a different struct than
    # get_users does.
    resp_cls = AccountV1 if self.proto_version > 0 else AccountV0
    rows = await self._cleos.aget_table(
        self.receiver, self.receiver, 'users',
        key_type='name',
        lower_bound=user,
        upper_bound=user,
        resp_cls=resp_cls
    )
    if not rows:
        raise AccountNotFound(user)

    return rows[0]
|
||||||
|
|
||||||
async def get_users(self) -> list[Account]:
    '''
    Query the 'users' table without bounds and return what comes back.

    Rows are decoded as AccountV1 when ``self.proto_version`` is greater
    than 0, AccountV0 otherwise.
    '''
    if self.proto_version > 0:
        row_type = AccountV1
    else:
        row_type = AccountV0

    return await self._cleos.aget_table(
        self.receiver,
        self.receiver,
        'users',
        resp_cls=row_type
    )
|
||||||
|
|
||||||
async def get_worker(self, worker: str) -> WorkerV0:
|
async def get_worker(self, worker: str) -> WorkerV0:
|
||||||
rows = await self._cleos.aget_table(
|
rows = await self._cleos.aget_table(
|
||||||
|
@ -78,31 +80,32 @@ class GPUContractAPI:
|
||||||
|
|
||||||
return rows[0]
|
return rows[0]
|
||||||
|
|
||||||
async def get_workers(self) -> list[WorkerV0]:
    '''
    Query the 'workers' table without bounds, decoding rows as WorkerV0.
    '''
    workers = await self._cleos.aget_table(
        self.receiver,
        self.receiver,
        'workers',
        resp_cls=WorkerV0
    )
    return workers
|
||||||
|
|
||||||
async def get_queue(self) -> list[Request]:
    '''
    Query the 'queue' table without bounds and return the resulting rows.

    Rows are decoded as RequestV1 when ``self.proto_version`` is greater
    than 0, RequestV0 otherwise.

    The return annotation is a list: the result of ``aget_table`` is
    returned unindexed, the same way ``get_requests_since`` returns it.
    '''
    resp_cls = RequestV1 if self.proto_version > 0 else RequestV0
    return await self._cleos.aget_table(
        self.receiver, self.receiver, 'queue',
        resp_cls=resp_cls
    )
|
||||||
|
|
||||||
async def get_request(self, request_id: int) -> Request:
    '''
    Look up one row of the 'queue' table by id.

    ``request_id`` is used as both lower and upper bound of the query and
    the first returned row is handed back, decoded as RequestV1 when
    ``self.proto_version`` is above 0 and as RequestV0 otherwise.

    :raises RequestNotFound: carrying ``request_id`` when no row comes back.
    '''
    row_type = RequestV1 if self.proto_version > 0 else RequestV0
    matches = await self._cleos.aget_table(
        self.receiver,
        self.receiver,
        'queue',
        lower_bound=request_id,
        upper_bound=request_id,
        resp_cls=row_type
    )
    if len(matches) != 0:
        return matches[0]

    raise RequestNotFound(request_id)
|
||||||
|
|
||||||
async def get_requests_since(self, seconds: int) -> list[Request]:
    '''
    Query the 'queue' table through its second index with a lower bound
    of "now minus ``seconds``".

    The bound is the current unix time truncated to an int, minus
    ``seconds``, sent with an i64 key type. Rows are decoded as RequestV1
    when ``self.proto_version`` is above 0 and as RequestV0 otherwise.
    '''
    oldest_wanted = int(time.time()) - seconds
    row_type = RequestV1 if self.proto_version > 0 else RequestV0

    return await self._cleos.aget_table(
        self.receiver,
        self.receiver,
        'queue',
        index_position=2,
        key_type='i64',
        lower_bound=oldest_wanted,
        resp_cls=row_type
    )
|
||||||
|
|
||||||
async def get_statuses_for_request(self, request_id: int) -> list[WorkerStatusV0]:
|
async def get_statuses_for_request(self, request_id: int) -> list[WorkerStatusV0]:
|
||||||
|
@ -242,12 +245,17 @@ class GPUContractAPI:
|
||||||
worker: str,
|
worker: str,
|
||||||
request_id: int,
|
request_id: int,
|
||||||
result_hash: str,
|
result_hash: str,
|
||||||
ipfs_hash: str
|
ipfs_hash: str,
|
||||||
|
request_hash: str | None = None
|
||||||
):
|
):
|
||||||
|
args = [worker, request_id, result_hash, ipfs_hash]
|
||||||
|
if request_hash:
|
||||||
|
args.insert(2, request_hash)
|
||||||
|
|
||||||
return await self._cleos.a_push_action(
|
return await self._cleos.a_push_action(
|
||||||
self.receiver,
|
self.receiver,
|
||||||
'submit',
|
'submit',
|
||||||
[worker, request_id, result_hash, ipfs_hash],
|
args,
|
||||||
worker,
|
worker,
|
||||||
key=self._cleos.private_keys[worker]
|
key=self._cleos.private_keys[worker]
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
import contextlib
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
|
@ -22,7 +24,7 @@ from skynet.dgpu.network import (
|
||||||
async def maybe_update_tui_balance(contract: GPUContractAPI):
    '''
    Hand maybe_update_tui_async a callback that refreshes the balance
    shown in the TUI header.

    The callback reads the account row for ``tui.config.account`` through
    ``contract.get_user`` and writes its ``balance`` into the header text.
    '''
    async def _refresh_balance(tui):
        # await the lookup first, then read the attribute off the row
        account_row = await contract.get_user(tui.config.account)
        tui.set_header_text(new_balance=f'balance: {account_row.balance}')

    await maybe_update_tui_async(_refresh_balance)
|
||||||
|
@ -126,6 +128,8 @@ async def maybe_serve_one(
|
||||||
used by torch each step of the inference, it will use a
|
used by torch each step of the inference, it will use a
|
||||||
trio.from_thread to unblock the main thread and pump the event loop
|
trio.from_thread to unblock the main thread and pump the event loop
|
||||||
'''
|
'''
|
||||||
|
with open(os.devnull, 'w') as devnull:
|
||||||
|
with contextlib.redirect_stdout(devnull):
|
||||||
output_hash, output = await trio.to_thread.run_sync(
|
output_hash, output = await trio.to_thread.run_sync(
|
||||||
partial(
|
partial(
|
||||||
compute_one,
|
compute_one,
|
||||||
|
@ -146,7 +150,12 @@ async def maybe_serve_one(
|
||||||
|
|
||||||
ipfs_hash = await ipfs_api.publish(output, type=output_type)
|
ipfs_hash = await ipfs_api.publish(output, type=output_type)
|
||||||
|
|
||||||
await contract.submit_work(config.account, req.id, output_hash, ipfs_hash)
|
maybe_request_hash = None
|
||||||
|
if config.proto_version == 0:
|
||||||
|
maybe_request_hash = req.hash_v0()
|
||||||
|
|
||||||
|
await contract.submit_work(
|
||||||
|
config.account, req.id, output_hash, ipfs_hash, request_hash=maybe_request_hash)
|
||||||
|
|
||||||
await state_mngr.update_state()
|
await state_mngr.update_state()
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ from skynet.config import load_skynet_toml
|
||||||
from skynet.contract import GPUContractAPI
|
from skynet.contract import GPUContractAPI
|
||||||
from skynet.types import (
|
from skynet.types import (
|
||||||
BodyV0,
|
BodyV0,
|
||||||
RequestV1,
|
Request,
|
||||||
WorkerStatusV0,
|
WorkerStatusV0,
|
||||||
ResultV0
|
ResultV0
|
||||||
)
|
)
|
||||||
|
@ -63,7 +63,7 @@ class ContractState:
|
||||||
self._config = load_skynet_toml().dgpu
|
self._config = load_skynet_toml().dgpu
|
||||||
self._poll_index = 0
|
self._poll_index = 0
|
||||||
|
|
||||||
self._queue: list[RequestV1] = []
|
self._queue: list[Request] = []
|
||||||
self._status_by_rid: dict[int, list[WorkerStatusV0]] = {}
|
self._status_by_rid: dict[int, list[WorkerStatusV0]] = {}
|
||||||
self._results: list[ResultV0] = []
|
self._results: list[ResultV0] = []
|
||||||
|
|
||||||
|
@ -139,7 +139,7 @@ class ContractState:
|
||||||
return len(self._queue)
|
return len(self._queue)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def first(self) -> RequestV1 | None:
|
def first(self) -> Request | None:
|
||||||
if len(self._queue) > 0:
|
if len(self._queue) > 0:
|
||||||
return self._queue[0]
|
return self._queue[0]
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ from skynet.config import DgpuConfig as Config
|
||||||
|
|
||||||
|
|
||||||
class WorkerMonitor:
|
class WorkerMonitor:
|
||||||
def __init__(self):
|
def __init__(self, config: Config):
|
||||||
self.requests = []
|
self.requests = []
|
||||||
self.header_info = {}
|
self.header_info = {}
|
||||||
|
|
||||||
|
@ -63,6 +63,7 @@ class WorkerMonitor:
|
||||||
event_loop=self.event_loop,
|
event_loop=self.event_loop,
|
||||||
unhandled_input=self._exit_on_q
|
unhandled_input=self._exit_on_q
|
||||||
)
|
)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
def _create_listbox_body(self, requests):
|
def _create_listbox_body(self, requests):
|
||||||
"""
|
"""
|
||||||
|
@ -197,7 +198,7 @@ def init_tui(config: Config):
|
||||||
global _tui
|
global _tui
|
||||||
assert not _tui
|
assert not _tui
|
||||||
setup_logging_for_tui(config)
|
setup_logging_for_tui(config)
|
||||||
_tui = WorkerMonitor()
|
_tui = WorkerMonitor(config)
|
||||||
return _tui
|
return _tui
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
from hashlib import sha256
|
||||||
|
|
||||||
from msgspec import Struct
|
from msgspec import Struct
|
||||||
|
|
||||||
|
@ -58,6 +59,8 @@ class ConfigV1(Struct):
|
||||||
token_symbol: str
|
token_symbol: str
|
||||||
global_nonce: int
|
global_nonce: int
|
||||||
|
|
||||||
|
type Config = ConfigV0 | ConfigV1
|
||||||
|
|
||||||
'''
|
'''
|
||||||
RequestV0
|
RequestV0
|
||||||
|
|
||||||
|
@ -122,6 +125,16 @@ class RequestV0(Struct):
|
||||||
binary_data: str
|
binary_data: str
|
||||||
timestamp: str
|
timestamp: str
|
||||||
|
|
||||||
|
def hash_v0(self) -> str:
    '''
    Return the hex sha256 digest of nonce, body and binary_data
    concatenated in that order.

    The nonce is converted with ``str`` first; the joined text is
    encoded as utf-8 before hashing.
    '''
    pieces = (str(self.nonce), self.body, self.binary_data)
    digest = sha256(''.join(pieces).encode('utf-8'))
    return digest.hexdigest()
|
||||||
|
|
||||||
'''
|
'''
|
||||||
RequestV1
|
RequestV1
|
||||||
|
|
||||||
|
@ -154,6 +167,8 @@ class RequestV1(Struct):
|
||||||
binary_data: str
|
binary_data: str
|
||||||
timestamp: str
|
timestamp: str
|
||||||
|
|
||||||
|
type Request = RequestV0 | RequestV1
|
||||||
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
AccountV0
|
AccountV0
|
||||||
|
@ -200,6 +215,7 @@ class AccountV1(Struct):
|
||||||
user: str
|
user: str
|
||||||
balance: str
|
balance: str
|
||||||
|
|
||||||
|
type Account = AccountV0 | AccountV1
|
||||||
|
|
||||||
'''
|
'''
|
||||||
WorkerV0
|
WorkerV0
|
||||||
|
|
Loading…
Reference in New Issue