mirror of https://github.com/skygpu/skynet.git
Add support for both protocol versions on worker daemon
parent
7edca49e95
commit
47dda50f32
|
@ -26,6 +26,7 @@ class DgpuConfig(msgspec.Struct):
|
|||
poll_time: float = 0.5 # wait time for polling updates from contract
|
||||
log_level: str = 'info'
|
||||
log_file: str = 'dgpu.log' # log file path (only used when tui = true)
|
||||
proto_version: int = 0
|
||||
|
||||
|
||||
class FrontendConfig(msgspec.Struct):
|
||||
|
|
|
@ -5,10 +5,10 @@ from leap import CLEOS
|
|||
from leap.protocol import Name
|
||||
|
||||
from skynet.types import (
|
||||
ConfigV1,
|
||||
AccountV1,
|
||||
Config, ConfigV0, ConfigV1,
|
||||
Account, AccountV0, AccountV1,
|
||||
WorkerV0,
|
||||
RequestV1,
|
||||
Request, RequestV0, RequestV1,
|
||||
BodyV0,
|
||||
WorkerStatusV0,
|
||||
ResultV0
|
||||
|
@ -33,37 +33,39 @@ class WorkerStatusNotFound(BaseException):
|
|||
|
||||
class GPUContractAPI:
|
||||
|
||||
def __init__(self, cleos: CLEOS, proto_version: int = 0):
    '''
    Store the cleos client and the contract protocol version to speak.

    :param cleos: client used for the table reads and action pushes.
    :param proto_version: read by the table getters to choose between the
        V0 and V1 response classes; defaults to 0.
    '''
    self.proto_version = proto_version
    self._cleos = cleos
    # account name passed as both code and scope on table reads
    self.receiver = 'gpu.scd'
|
||||
|
||||
# views into data
|
||||
|
||||
async def get_config(self) -> Config:
    '''
    Read the first row of the contract's ``config`` table.

    Rows are decoded as ``ConfigV1`` when ``self.proto_version`` is above 0
    and as ``ConfigV0`` otherwise, the same cut-off ``get_users``,
    ``get_queue``, ``get_request`` and ``get_requests_since`` apply.

    :raises ConfigNotFound: when the table read returns no rows.
    '''
    # NOTE(review): cut-off was `> 1`, which never selects ConfigV1 for
    # proto_version 1 -- confirm against the deployed contract versions
    config_cls = ConfigV1 if self.proto_version > 0 else ConfigV0
    rows = await self._cleos.aget_table(
        self.receiver, self.receiver, 'config',
        resp_cls=config_cls
    )
    if not rows:
        raise ConfigNotFound()

    return rows[0]
|
||||
|
||||
async def get_user(self, user: str) -> Account:
    '''
    Look up a single row of the ``users`` table by account name.

    The read is bounded to ``user`` on both ends with ``name`` keys. Rows
    are decoded as ``AccountV1`` when ``self.proto_version`` is above 0 and
    as ``AccountV0`` otherwise, matching what ``get_users`` does for the
    same table.

    :param user: account name to look up.
    :raises AccountNotFound: when the table read returns no rows.
    '''
    # cut-off was `> 1`, disagreeing with get_users (`> 0`) on the same table
    account_cls = AccountV1 if self.proto_version > 0 else AccountV0
    rows = await self._cleos.aget_table(
        self.receiver, self.receiver, 'users',
        key_type='name',
        lower_bound=user,
        upper_bound=user,
        resp_cls=account_cls
    )
    if not rows:
        raise AccountNotFound(user)

    return rows[0]
|
||||
|
||||
async def get_users(self) -> list[Account]:
    '''
    Fetch the rows of the ``users`` table, decoded as ``AccountV1`` when
    ``self.proto_version`` is above 0 and as ``AccountV0`` otherwise.
    '''
    account_cls = AccountV0
    if self.proto_version > 0:
        account_cls = AccountV1

    return await self._cleos.aget_table(
        self.receiver, self.receiver, 'users',
        resp_cls=account_cls
    )
|
||||
|
||||
async def get_worker(self, worker: str) -> WorkerV0:
|
||||
rows = await self._cleos.aget_table(
|
||||
|
@ -78,31 +80,32 @@ class GPUContractAPI:
|
|||
|
||||
return rows[0]
|
||||
|
||||
async def get_workers(self) -> list[WorkerV0]:
    '''
    Fetch the rows of the ``workers`` table, decoded as ``WorkerV0``.
    '''
    contract = self.receiver
    return await self._cleos.aget_table(
        contract, contract, 'workers',
        resp_cls=WorkerV0
    )
|
||||
|
||||
async def get_queue(self) -> list[Request]:
    '''
    Fetch the rows of the ``queue`` table.

    Rows are decoded as ``RequestV1`` when ``self.proto_version`` is above 0
    and as ``RequestV0`` otherwise.

    :returns: whatever ``aget_table`` returns for the table; sibling getters
        index that result as a list of rows, hence the ``list[Request]``
        annotation rather than a single ``Request``.
    '''
    request_cls = RequestV1 if self.proto_version > 0 else RequestV0
    return await self._cleos.aget_table(
        self.receiver, self.receiver, 'queue',
        resp_cls=request_cls
    )
|
||||
|
||||
async def get_request(self, request_id: int) -> Request:
    '''
    Look up one ``queue`` row by bounding the table read to ``request_id``
    on both ends.

    Rows are decoded as ``RequestV1`` when ``self.proto_version`` is above 0
    and as ``RequestV0`` otherwise.

    :param request_id: id used as both lower and upper bound of the read.
    :raises RequestNotFound: when the table read returns no rows.
    '''
    request_cls = RequestV1 if self.proto_version > 0 else RequestV0
    matches = await self._cleos.aget_table(
        self.receiver, self.receiver, 'queue',
        lower_bound=request_id,
        upper_bound=request_id,
        resp_cls=request_cls
    )
    if len(matches) > 0:
        return matches[0]

    raise RequestNotFound(request_id)
|
||||
|
||||
async def get_requests_since(self, seconds: int) -> list[Request]:
    '''
    Read ``queue`` rows through index position 2 (``i64`` keys), with the
    lower bound set to the current unix time minus ``seconds``.

    Rows are decoded as ``RequestV1`` when ``self.proto_version`` is above 0
    and as ``RequestV0`` otherwise.

    :param seconds: how far back from now the lower bound is placed.
    '''
    # NOTE(review): index 2 is presumably the request timestamp -- confirm
    # against the contract's table definition
    oldest = int(time.time()) - seconds
    request_cls = RequestV1 if self.proto_version > 0 else RequestV0
    return await self._cleos.aget_table(
        self.receiver, self.receiver, 'queue',
        index_position=2,
        key_type='i64',
        lower_bound=oldest,
        resp_cls=request_cls
    )
|
||||
|
||||
async def get_statuses_for_request(self, request_id: int) -> list[WorkerStatusV0]:
|
||||
|
@ -242,12 +245,17 @@ class GPUContractAPI:
|
|||
worker: str,
|
||||
request_id: int,
|
||||
result_hash: str,
|
||||
ipfs_hash: str
|
||||
ipfs_hash: str,
|
||||
request_hash: str | None = None
|
||||
):
|
||||
args = [worker, request_id, result_hash, ipfs_hash]
|
||||
if request_hash:
|
||||
args.insert(2, request_hash)
|
||||
|
||||
return await self._cleos.a_push_action(
|
||||
self.receiver,
|
||||
'submit',
|
||||
[worker, request_id, result_hash, ipfs_hash],
|
||||
args,
|
||||
worker,
|
||||
key=self._cleos.private_keys[worker]
|
||||
)
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import os
|
||||
import logging
|
||||
import contextlib
|
||||
from functools import partial
|
||||
|
||||
import trio
|
||||
|
@ -22,7 +24,7 @@ from skynet.dgpu.network import (
|
|||
async def maybe_update_tui_balance(contract: GPUContractAPI):
    '''
    Refresh the balance shown in the TUI header from the contract.

    The inner callback is handed to ``maybe_update_tui_async``, which
    supplies the ``tui`` object it operates on.

    :param contract: API used to look up the row for ``tui.config.account``.
    '''
    async def _fn(tui):
        # await the lookup first, then read ``balance`` off the returned row;
        # ``await contract.get_user(...).balance`` binds the attribute access
        # to the un-awaited coroutine object and fails
        account = await contract.get_user(tui.config.account)
        tui.set_header_text(new_balance=f'balance: {account.balance}')

    await maybe_update_tui_async(_fn)
|
||||
|
@ -126,6 +128,8 @@ async def maybe_serve_one(
|
|||
used by torch each step of the inference, it will use a
|
||||
trio.from_thread to unblock the main thread and pump the event loop
|
||||
'''
|
||||
with open(os.devnull, 'w') as devnull:
|
||||
with contextlib.redirect_stdout(devnull):
|
||||
output_hash, output = await trio.to_thread.run_sync(
|
||||
partial(
|
||||
compute_one,
|
||||
|
@ -146,7 +150,12 @@ async def maybe_serve_one(
|
|||
|
||||
ipfs_hash = await ipfs_api.publish(output, type=output_type)
|
||||
|
||||
await contract.submit_work(config.account, req.id, output_hash, ipfs_hash)
|
||||
maybe_request_hash = None
|
||||
if config.proto_version == 0:
|
||||
maybe_request_hash = req.hash_v0()
|
||||
|
||||
await contract.submit_work(
|
||||
config.account, req.id, output_hash, ipfs_hash, request_hash=maybe_request_hash)
|
||||
|
||||
await state_mngr.update_state()
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ from skynet.config import load_skynet_toml
|
|||
from skynet.contract import GPUContractAPI
|
||||
from skynet.types import (
|
||||
BodyV0,
|
||||
RequestV1,
|
||||
Request,
|
||||
WorkerStatusV0,
|
||||
ResultV0
|
||||
)
|
||||
|
@ -63,7 +63,7 @@ class ContractState:
|
|||
self._config = load_skynet_toml().dgpu
|
||||
self._poll_index = 0
|
||||
|
||||
self._queue: list[RequestV1] = []
|
||||
self._queue: list[Request] = []
|
||||
self._status_by_rid: dict[int, list[WorkerStatusV0]] = {}
|
||||
self._results: list[ResultV0] = []
|
||||
|
||||
|
@ -139,7 +139,7 @@ class ContractState:
|
|||
return len(self._queue)
|
||||
|
||||
@property
|
||||
def first(self) -> RequestV1 | None:
|
||||
def first(self) -> Request | None:
|
||||
if len(self._queue) > 0:
|
||||
return self._queue[0]
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from skynet.config import DgpuConfig as Config
|
|||
|
||||
|
||||
class WorkerMonitor:
|
||||
def __init__(self):
|
||||
def __init__(self, config: Config):
|
||||
self.requests = []
|
||||
self.header_info = {}
|
||||
|
||||
|
@ -63,6 +63,7 @@ class WorkerMonitor:
|
|||
event_loop=self.event_loop,
|
||||
unhandled_input=self._exit_on_q
|
||||
)
|
||||
self.config = config
|
||||
|
||||
def _create_listbox_body(self, requests):
|
||||
"""
|
||||
|
@ -197,7 +198,7 @@ def init_tui(config: Config):
|
|||
global _tui
|
||||
assert not _tui
|
||||
setup_logging_for_tui(config)
|
||||
_tui = WorkerMonitor()
|
||||
_tui = WorkerMonitor(config)
|
||||
return _tui
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from enum import StrEnum
|
||||
from hashlib import sha256
|
||||
|
||||
from msgspec import Struct
|
||||
|
||||
|
@ -58,6 +59,8 @@ class ConfigV1(Struct):
|
|||
token_symbol: str
|
||||
global_nonce: int
|
||||
|
||||
type Config = ConfigV0 | ConfigV1
|
||||
|
||||
'''
|
||||
RequestV0
|
||||
|
||||
|
@ -122,6 +125,16 @@ class RequestV0(Struct):
|
|||
binary_data: str
|
||||
timestamp: str
|
||||
|
||||
def hash_v0(self) -> str:
    '''
    Return the hex sha256 digest of this request.

    The preimage is ``str(self.nonce)``, ``self.body`` and
    ``self.binary_data`` concatenated in that order, utf-8 encoded.
    '''
    # field order is part of the digest: nonce, then body, then binary_data
    preimage = ''.join((str(self.nonce), self.body, self.binary_data))
    return sha256(preimage.encode('utf-8')).hexdigest()
|
||||
|
||||
'''
|
||||
RequestV1
|
||||
|
||||
|
@ -154,6 +167,8 @@ class RequestV1(Struct):
|
|||
binary_data: str
|
||||
timestamp: str
|
||||
|
||||
type Request = RequestV0 | RequestV1
|
||||
|
||||
|
||||
'''
|
||||
AccountV0
|
||||
|
@ -200,6 +215,7 @@ class AccountV1(Struct):
|
|||
user: str
|
||||
balance: str
|
||||
|
||||
type Account = AccountV0 | AccountV1
|
||||
|
||||
'''
|
||||
WorkerV0
|
||||
|
|
Loading…
Reference in New Issue