Add support for both protocol versions on worker daemon

rust_contract
Guillermo Rodriguez 2025-02-20 15:17:24 -03:00
parent 7edca49e95
commit 47dda50f32
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
6 changed files with 71 additions and 36 deletions

View File

@ -26,6 +26,7 @@ class DgpuConfig(msgspec.Struct):
poll_time: float = 0.5 # wait time for polling updates from contract poll_time: float = 0.5 # wait time for polling updates from contract
log_level: str = 'info' log_level: str = 'info'
log_file: str = 'dgpu.log' # log file path (only used when tui = true) log_file: str = 'dgpu.log' # log file path (only used when tui = true)
proto_version: int = 0
class FrontendConfig(msgspec.Struct): class FrontendConfig(msgspec.Struct):

View File

@ -5,10 +5,10 @@ from leap import CLEOS
from leap.protocol import Name from leap.protocol import Name
from skynet.types import ( from skynet.types import (
ConfigV1, Config, ConfigV0, ConfigV1,
AccountV1, Account, AccountV0, AccountV1,
WorkerV0, WorkerV0,
RequestV1, Request, RequestV0, RequestV1,
BodyV0, BodyV0,
WorkerStatusV0, WorkerStatusV0,
ResultV0 ResultV0
@ -33,37 +33,39 @@ class WorkerStatusNotFound(BaseException):
class GPUContractAPI: class GPUContractAPI:
def __init__(self, cleos: CLEOS, proto_version: int = 0):
    '''
    Store the chain client and the protocol version used by the table
    readers of this class to pick the response struct to decode into.

    The contract account name is fixed to 'gpu.scd'.
    '''
    self._cleos = cleos
    self.proto_version = proto_version
    self.receiver = 'gpu.scd'
# views into data # views into data
async def get_config(self) -> Config:
    '''
    Read the first row of the contract's 'config' table.

    Rows are decoded as ConfigV1 when self.proto_version is above 0 and
    as ConfigV0 otherwise.

    Raises ConfigNotFound when the table comes back empty.
    '''
    # The threshold used to be `> 1`, which disagreed with the `> 0` switch
    # used by the queue/users readers of this class: a worker configured
    # with proto_version = 1 decoded the config with the V0 schema.
    # NOTE(review): assumes ConfigV1 is the proto_version 1 layout - confirm.
    resp_cls = ConfigV1 if self.proto_version > 0 else ConfigV0

    rows = await self._cleos.aget_table(
        self.receiver, self.receiver, 'config',
        resp_cls=resp_cls
    )

    if len(rows) == 0:
        raise ConfigNotFound()

    return rows[0]
async def get_user(self, user: str) -> Account:
    '''
    Read the row for `user` from the contract's 'users' table.

    The query bounds both ends of a 'name'-keyed range to `user`, and the
    first returned row is handed back. Rows are decoded as AccountV1 when
    self.proto_version is above 0 and as AccountV0 otherwise.

    Raises AccountNotFound(user) when no row is returned.
    '''
    # The threshold used to be `> 1`, while get_users() reads the very same
    # table switching at `> 0`; with proto_version = 1 the two methods
    # decoded identical rows into different structs.
    # NOTE(review): assumes AccountV1 is the proto_version 1 layout - confirm.
    resp_cls = AccountV1 if self.proto_version > 0 else AccountV0

    rows = await self._cleos.aget_table(
        self.receiver, self.receiver, 'users',
        key_type='name',
        lower_bound=user,
        upper_bound=user,
        resp_cls=resp_cls
    )

    if len(rows) == 0:
        raise AccountNotFound(user)

    return rows[0]
async def get_users(self) -> list[Account]:
    '''
    Read every row of the contract's 'users' table, decoded as AccountV1
    when self.proto_version is above 0 and as AccountV0 otherwise.
    '''
    if self.proto_version > 0:
        account_cls = AccountV1
    else:
        account_cls = AccountV0

    return await self._cleos.aget_table(
        self.receiver, self.receiver, 'users',
        resp_cls=account_cls
    )
async def get_worker(self, worker: str) -> WorkerV0: async def get_worker(self, worker: str) -> WorkerV0:
rows = await self._cleos.aget_table( rows = await self._cleos.aget_table(
@ -78,31 +80,32 @@ class GPUContractAPI:
return rows[0] return rows[0]
async def get_workers(self) -> list[WorkerV0]:
    '''
    Read every row of the contract's 'workers' table as WorkerV0 structs.
    '''
    worker_rows = await self._cleos.aget_table(
        self.receiver, self.receiver, 'workers',
        resp_cls=WorkerV0
    )
    return worker_rows
async def get_queue(self) -> list[Request]:
    '''
    Read every row of the contract's 'queue' table, decoded as RequestV1
    when self.proto_version is above 0 and as RequestV0 otherwise.

    The value returned is whatever aget_table yields for the table, i.e.
    the collection of rows rather than a single request, hence the list
    annotation (it was previously annotated as a single Request).
    '''
    resp_cls = RequestV1 if self.proto_version > 0 else RequestV0
    return await self._cleos.aget_table(
        self.receiver, self.receiver, 'queue',
        resp_cls=resp_cls
    )
async def get_request(self, request_id: int) -> Request:
    '''
    Read the row for `request_id` from the contract's 'queue' table.

    Both range bounds of the query are set to `request_id` and the first
    returned row is handed back, decoded as RequestV1 when
    self.proto_version is above 0 and as RequestV0 otherwise.

    Raises RequestNotFound(request_id) when no row is returned.
    '''
    if self.proto_version > 0:
        request_cls = RequestV1
    else:
        request_cls = RequestV0

    matches = await self._cleos.aget_table(
        self.receiver, self.receiver, 'queue',
        lower_bound=request_id,
        upper_bound=request_id,
        resp_cls=request_cls
    )

    if len(matches) == 0:
        raise RequestNotFound(request_id)

    return matches[0]
async def get_requests_since(self, seconds: int) -> list[Request]:
    '''
    Query the 'queue' table through index position 2 (an 'i64' key) with a
    lower bound of "now minus `seconds`", where now is the current unix
    time truncated to an integer.

    Rows are decoded as RequestV1 when self.proto_version is above 0 and
    as RequestV0 otherwise.
    '''
    # presumably index 2 is keyed by request timestamp - verify against
    # the contract's table definition
    oldest_allowed = int(time.time()) - seconds

    if self.proto_version > 0:
        request_cls = RequestV1
    else:
        request_cls = RequestV0

    return await self._cleos.aget_table(
        self.receiver, self.receiver, 'queue',
        index_position=2,
        key_type='i64',
        lower_bound=oldest_allowed,
        resp_cls=request_cls
    )
async def get_statuses_for_request(self, request_id: int) -> list[WorkerStatusV0]: async def get_statuses_for_request(self, request_id: int) -> list[WorkerStatusV0]:
@ -242,12 +245,17 @@ class GPUContractAPI:
worker: str, worker: str,
request_id: int, request_id: int,
result_hash: str, result_hash: str,
ipfs_hash: str ipfs_hash: str,
request_hash: str | None = None
): ):
args = [worker, request_id, result_hash, ipfs_hash]
if request_hash:
args.insert(2, request_hash)
return await self._cleos.a_push_action( return await self._cleos.a_push_action(
self.receiver, self.receiver,
'submit', 'submit',
[worker, request_id, result_hash, ipfs_hash], args,
worker, worker,
key=self._cleos.private_keys[worker] key=self._cleos.private_keys[worker]
) )

View File

@ -1,4 +1,6 @@
import os
import logging import logging
import contextlib
from functools import partial from functools import partial
import trio import trio
@ -22,7 +24,7 @@ from skynet.dgpu.network import (
async def maybe_update_tui_balance(contract: GPUContractAPI):
    '''
    Build a callback that looks up the account named by
    `tui.config.account` through `contract` and writes its balance into
    the tui header, then hand that callback to maybe_update_tui_async.
    '''
    async def _refresh_balance(tui):
        # the coroutine has to be awaited before `.balance` is read
        account = await contract.get_user(tui.config.account)
        tui.set_header_text(new_balance=f'balance: {account.balance}')

    await maybe_update_tui_async(_refresh_balance)
@ -126,16 +128,18 @@ async def maybe_serve_one(
used by torch each step of the inference, it will use a used by torch each step of the inference, it will use a
trio.from_thread to unblock the main thread and pump the event loop trio.from_thread to unblock the main thread and pump the event loop
''' '''
output_hash, output = await trio.to_thread.run_sync( with open(os.devnull, 'w') as devnull:
partial( with contextlib.redirect_stdout(devnull):
compute_one, output_hash, output = await trio.to_thread.run_sync(
model, partial(
req.id, compute_one,
mode, body.params, model,
inputs=inputs, req.id,
should_cancel=state_mngr.should_cancel_work, mode, body.params,
) inputs=inputs,
) should_cancel=state_mngr.should_cancel_work,
)
)
case _: case _:
raise DGPUComputeError( raise DGPUComputeError(
@ -146,7 +150,12 @@ async def maybe_serve_one(
ipfs_hash = await ipfs_api.publish(output, type=output_type) ipfs_hash = await ipfs_api.publish(output, type=output_type)
await contract.submit_work(config.account, req.id, output_hash, ipfs_hash) maybe_request_hash = None
if config.proto_version == 0:
maybe_request_hash = req.hash_v0()
await contract.submit_work(
config.account, req.id, output_hash, ipfs_hash, request_hash=maybe_request_hash)
await state_mngr.update_state() await state_mngr.update_state()

View File

@ -16,7 +16,7 @@ from skynet.config import load_skynet_toml
from skynet.contract import GPUContractAPI from skynet.contract import GPUContractAPI
from skynet.types import ( from skynet.types import (
BodyV0, BodyV0,
RequestV1, Request,
WorkerStatusV0, WorkerStatusV0,
ResultV0 ResultV0
) )
@ -63,7 +63,7 @@ class ContractState:
self._config = load_skynet_toml().dgpu self._config = load_skynet_toml().dgpu
self._poll_index = 0 self._poll_index = 0
self._queue: list[RequestV1] = [] self._queue: list[Request] = []
self._status_by_rid: dict[int, list[WorkerStatusV0]] = {} self._status_by_rid: dict[int, list[WorkerStatusV0]] = {}
self._results: list[ResultV0] = [] self._results: list[ResultV0] = []
@ -139,7 +139,7 @@ class ContractState:
return len(self._queue) return len(self._queue)
@property @property
def first(self) -> RequestV1 | None: def first(self) -> Request | None:
if len(self._queue) > 0: if len(self._queue) > 0:
return self._queue[0] return self._queue[0]

View File

@ -8,7 +8,7 @@ from skynet.config import DgpuConfig as Config
class WorkerMonitor: class WorkerMonitor:
def __init__(self): def __init__(self, config: Config):
self.requests = [] self.requests = []
self.header_info = {} self.header_info = {}
@ -63,6 +63,7 @@ class WorkerMonitor:
event_loop=self.event_loop, event_loop=self.event_loop,
unhandled_input=self._exit_on_q unhandled_input=self._exit_on_q
) )
self.config = config
def _create_listbox_body(self, requests): def _create_listbox_body(self, requests):
""" """
@ -197,7 +198,7 @@ def init_tui(config: Config):
global _tui global _tui
assert not _tui assert not _tui
setup_logging_for_tui(config) setup_logging_for_tui(config)
_tui = WorkerMonitor() _tui = WorkerMonitor(config)
return _tui return _tui

View File

@ -1,4 +1,5 @@
from enum import StrEnum from enum import StrEnum
from hashlib import sha256
from msgspec import Struct from msgspec import Struct
@ -58,6 +59,8 @@ class ConfigV1(Struct):
token_symbol: str token_symbol: str
global_nonce: int global_nonce: int
type Config = ConfigV0 | ConfigV1
''' '''
RequestV0 RequestV0
@ -122,6 +125,16 @@ class RequestV0(Struct):
binary_data: str binary_data: str
timestamp: str timestamp: str
def hash_v0(self) -> str:
    '''
    Return the hex sha256 digest of this request's nonce (stringified),
    body and binary_data concatenated in that order and encoded as utf-8.
    '''
    pieces = (str(self.nonce), self.body, self.binary_data)
    digest = sha256(''.join(pieces).encode('utf-8'))
    return digest.hexdigest()
''' '''
RequestV1 RequestV1
@ -154,6 +167,8 @@ class RequestV1(Struct):
binary_data: str binary_data: str
timestamp: str timestamp: str
type Request = RequestV0 | RequestV1
''' '''
AccountV0 AccountV0
@ -200,6 +215,7 @@ class AccountV1(Struct):
user: str user: str
balance: str balance: str
type Account = AccountV0 | AccountV1
''' '''
WorkerV0 WorkerV0