Add support for both protocol versions on worker daemon

rust_contract
Guillermo Rodriguez 2025-02-20 15:17:24 -03:00
parent 7edca49e95
commit 47dda50f32
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
6 changed files with 71 additions and 36 deletions

View File

@ -26,6 +26,7 @@ class DgpuConfig(msgspec.Struct):
poll_time: float = 0.5 # wait time for polling updates from contract
log_level: str = 'info'
log_file: str = 'dgpu.log' # log file path (only used when tui = true)
proto_version: int = 0
class FrontendConfig(msgspec.Struct):

View File

@ -5,10 +5,10 @@ from leap import CLEOS
from leap.protocol import Name
from skynet.types import (
ConfigV1,
AccountV1,
Config, ConfigV0, ConfigV1,
Account, AccountV0, AccountV1,
WorkerV0,
RequestV1,
Request, RequestV0, RequestV1,
BodyV0,
WorkerStatusV0,
ResultV0
@ -33,37 +33,39 @@ class WorkerStatusNotFound(BaseException):
class GPUContractAPI:
def __init__(self, cleos: CLEOS):
def __init__(self, cleos: CLEOS, proto_version: int = 0):
self.receiver = 'gpu.scd'
self._cleos = cleos
self.proto_version = proto_version
# views into data
async def get_config(self) -> Config:
    '''
    Fetch the first row of the contract's `config` table.

    The row is decoded as `ConfigV1` when `self.proto_version` is
    greater than 0 and as `ConfigV0` otherwise. This is the same
    threshold the other versioned table views of this class use
    (`get_users`, `get_queue`, `get_request`, `get_requests_since`);
    the previous `> 1` check meant a worker configured with
    `proto_version = 1` still decoded the row with the V0 layout.

    Raises `ConfigNotFound` when the table has no rows.
    '''
    rows = await self._cleos.aget_table(
        self.receiver, self.receiver, 'config',
        resp_cls=ConfigV1 if self.proto_version > 0 else ConfigV0
    )
    if len(rows) == 0:
        raise ConfigNotFound()
    return rows[0]
async def get_user(self, user: str) -> Account:
    '''
    Fetch the `users` table row for a single account.

    The query is bounded on both sides by `user` (key type `name`),
    and the first returned row is handed back. Rows are decoded as
    `AccountV1` when `self.proto_version` is greater than 0 and as
    `AccountV0` otherwise, matching `get_users`, which reads the
    same table; the previous `> 1` check made the two disagree for
    `proto_version = 1`.

    Raises `AccountNotFound(user)` when no row comes back.
    '''
    rows = await self._cleos.aget_table(
        self.receiver, self.receiver, 'users',
        key_type='name',
        lower_bound=user,
        upper_bound=user,
        resp_cls=AccountV1 if self.proto_version > 0 else AccountV0
    )
    if len(rows) == 0:
        raise AccountNotFound(user)
    return rows[0]
async def get_users(self) -> list[Account]:
    '''
    Return the rows of the contract's `users` table, decoded as
    `AccountV1` when `self.proto_version` is greater than 0 and as
    `AccountV0` otherwise.
    '''
    account_cls = AccountV1 if self.proto_version > 0 else AccountV0
    return await self._cleos.aget_table(
        self.receiver, self.receiver, 'users',
        resp_cls=account_cls
    )
async def get_worker(self, worker: str) -> WorkerV0:
rows = await self._cleos.aget_table(
@ -78,31 +80,32 @@ class GPUContractAPI:
return rows[0]
async def get_workers(self) -> list[WorkerV0]:
    '''
    Return the rows of the contract's `workers` table decoded as
    `WorkerV0`; this view does not depend on `self.proto_version`.
    '''
    workers = await self._cleos.aget_table(
        self.receiver, self.receiver, 'workers',
        resp_cls=WorkerV0
    )
    return workers
async def get_queue(self) -> list[Request]:
    '''
    Return the rows of the contract's `queue` table, decoded as
    `RequestV1` when `self.proto_version` is greater than 0 and as
    `RequestV0` otherwise.

    The return annotation is `list[Request]` because the result of
    `aget_table` is returned unchanged, exactly as in
    `get_requests_since`; the previous bare `Request` annotation
    described a single row.
    '''
    return await self._cleos.aget_table(
        self.receiver, self.receiver, 'queue',
        resp_cls=RequestV1 if self.proto_version > 0 else RequestV0
    )
async def get_request(self, request_id: int) -> Request:
    '''
    Fetch one row of the `queue` table by request id.

    The query uses `request_id` as both lower and upper bound and the
    first returned row is handed back, decoded as `RequestV1` when
    `self.proto_version` is greater than 0 and as `RequestV0`
    otherwise.

    Raises `RequestNotFound(request_id)` when no row comes back.
    '''
    request_cls = RequestV1 if self.proto_version > 0 else RequestV0
    matches = await self._cleos.aget_table(
        self.receiver, self.receiver, 'queue',
        lower_bound=request_id,
        upper_bound=request_id,
        resp_cls=request_cls
    )
    if len(matches) > 0:
        return matches[0]

    raise RequestNotFound(request_id)
async def get_requests_since(self, seconds: int) -> list[Request]:
    '''
    Return `queue` rows read through index position 2 (key type
    `i64`), starting at the current unix time minus `seconds`.

    NOTE(review): index 2 is presumably keyed by request timestamp,
    confirm against the contract's table definition.

    Rows are decoded as `RequestV1` when `self.proto_version` is
    greater than 0 and as `RequestV0` otherwise.
    '''
    earliest = int(time.time()) - seconds
    request_cls = RequestV1 if self.proto_version > 0 else RequestV0
    return await self._cleos.aget_table(
        self.receiver, self.receiver, 'queue',
        index_position=2,
        key_type='i64',
        lower_bound=earliest,
        resp_cls=request_cls
    )
async def get_statuses_for_request(self, request_id: int) -> list[WorkerStatusV0]:
@ -242,12 +245,17 @@ class GPUContractAPI:
worker: str,
request_id: int,
result_hash: str,
ipfs_hash: str
ipfs_hash: str,
request_hash: str | None = None
):
args = [worker, request_id, result_hash, ipfs_hash]
if request_hash:
args.insert(2, request_hash)
return await self._cleos.a_push_action(
self.receiver,
'submit',
[worker, request_id, result_hash, ipfs_hash],
args,
worker,
key=self._cleos.private_keys[worker]
)

View File

@ -1,4 +1,6 @@
import os
import logging
import contextlib
from functools import partial
import trio
@ -22,7 +24,7 @@ from skynet.dgpu.network import (
async def maybe_update_tui_balance(contract: GPUContractAPI):
async def _fn(tui):
# update balance
balance = await contract.get_user(tui.config.account).balance
balance = (await contract.get_user(tui.config.account)).balance
tui.set_header_text(new_balance=f'balance: {balance}')
await maybe_update_tui_async(_fn)
@ -126,16 +128,18 @@ async def maybe_serve_one(
used by torch each step of the inference, it will use a
trio.from_thread to unblock the main thread and pump the event loop
'''
output_hash, output = await trio.to_thread.run_sync(
partial(
compute_one,
model,
req.id,
mode, body.params,
inputs=inputs,
should_cancel=state_mngr.should_cancel_work,
)
)
with open(os.devnull, 'w') as devnull:
with contextlib.redirect_stdout(devnull):
output_hash, output = await trio.to_thread.run_sync(
partial(
compute_one,
model,
req.id,
mode, body.params,
inputs=inputs,
should_cancel=state_mngr.should_cancel_work,
)
)
case _:
raise DGPUComputeError(
@ -146,7 +150,12 @@ async def maybe_serve_one(
ipfs_hash = await ipfs_api.publish(output, type=output_type)
await contract.submit_work(config.account, req.id, output_hash, ipfs_hash)
maybe_request_hash = None
if config.proto_version == 0:
maybe_request_hash = req.hash_v0()
await contract.submit_work(
config.account, req.id, output_hash, ipfs_hash, request_hash=maybe_request_hash)
await state_mngr.update_state()

View File

@ -16,7 +16,7 @@ from skynet.config import load_skynet_toml
from skynet.contract import GPUContractAPI
from skynet.types import (
BodyV0,
RequestV1,
Request,
WorkerStatusV0,
ResultV0
)
@ -63,7 +63,7 @@ class ContractState:
self._config = load_skynet_toml().dgpu
self._poll_index = 0
self._queue: list[RequestV1] = []
self._queue: list[Request] = []
self._status_by_rid: dict[int, list[WorkerStatusV0]] = {}
self._results: list[ResultV0] = []
@ -139,7 +139,7 @@ class ContractState:
return len(self._queue)
@property
def first(self) -> RequestV1 | None:
def first(self) -> Request | None:
if len(self._queue) > 0:
return self._queue[0]

View File

@ -8,7 +8,7 @@ from skynet.config import DgpuConfig as Config
class WorkerMonitor:
def __init__(self):
def __init__(self, config: Config):
self.requests = []
self.header_info = {}
@ -63,6 +63,7 @@ class WorkerMonitor:
event_loop=self.event_loop,
unhandled_input=self._exit_on_q
)
self.config = config
def _create_listbox_body(self, requests):
"""
@ -197,7 +198,7 @@ def init_tui(config: Config):
global _tui
assert not _tui
setup_logging_for_tui(config)
_tui = WorkerMonitor()
_tui = WorkerMonitor(config)
return _tui

View File

@ -1,4 +1,5 @@
from enum import StrEnum
from hashlib import sha256
from msgspec import Struct
@ -58,6 +59,8 @@ class ConfigV1(Struct):
token_symbol: str
global_nonce: int
type Config = ConfigV0 | ConfigV1
'''
RequestV0
@ -122,6 +125,16 @@ class RequestV0(Struct):
binary_data: str
timestamp: str
def hash_v0(self) -> str:
    '''
    Return the hex encoded sha256 digest of this request.

    The digested text is the decimal form of `nonce` immediately
    followed by `body` and `binary_data`, with no separators, encoded
    as utf-8.
    '''
    pieces = (str(self.nonce), self.body, self.binary_data)
    digest = sha256(''.join(pieces).encode('utf-8'))
    return digest.hexdigest()
'''
RequestV1
@ -154,6 +167,8 @@ class RequestV1(Struct):
binary_data: str
timestamp: str
type Request = RequestV0 | RequestV1
'''
AccountV0
@ -200,6 +215,7 @@ class AccountV1(Struct):
user: str
balance: str
type Account = AccountV0 | AccountV1
'''
WorkerV0