From 47dda50f32769dfd80656c2af38e10de6c8e9ab7 Mon Sep 17 00:00:00 2001 From: Guillermo Rodriguez Date: Thu, 20 Feb 2025 15:17:24 -0300 Subject: [PATCH] Add support for both protocol versions on worker daemon --- skynet/config.py | 1 + skynet/contract.py | 46 +++++++++++++++++++++++++----------------- skynet/dgpu/daemon.py | 33 +++++++++++++++++++----------- skynet/dgpu/network.py | 6 +++--- skynet/dgpu/tui.py | 5 +++-- skynet/types.py | 16 +++++++++++++++ 6 files changed, 71 insertions(+), 36 deletions(-) diff --git a/skynet/config.py b/skynet/config.py index dd454b7..310bf56 100755 --- a/skynet/config.py +++ b/skynet/config.py @@ -26,6 +26,7 @@ class DgpuConfig(msgspec.Struct): poll_time: float = 0.5 # wait time for polling updates from contract log_level: str = 'info' log_file: str = 'dgpu.log' # log file path (only used when tui = true) + proto_version: int = 0 class FrontendConfig(msgspec.Struct): diff --git a/skynet/contract.py b/skynet/contract.py index 11e0412..67ed461 100644 --- a/skynet/contract.py +++ b/skynet/contract.py @@ -5,10 +5,10 @@ from leap import CLEOS from leap.protocol import Name from skynet.types import ( - ConfigV1, - AccountV1, + Config, ConfigV0, ConfigV1, + Account, AccountV0, AccountV1, WorkerV0, - RequestV1, + Request, RequestV0, RequestV1, BodyV0, WorkerStatusV0, ResultV0 @@ -33,37 +33,39 @@ class WorkerStatusNotFound(BaseException): class GPUContractAPI: - def __init__(self, cleos: CLEOS): + def __init__(self, cleos: CLEOS, proto_version: int = 0): self.receiver = 'gpu.scd' self._cleos = cleos + self.proto_version = proto_version # views into data - async def get_config(self) -> ConfigV1: + async def get_config(self) -> Config: rows = await self._cleos.aget_table( self.receiver, self.receiver, 'config', - resp_cls=ConfigV1 + resp_cls=ConfigV1 if self.proto_version > 1 else ConfigV0 ) if len(rows) == 0: raise ConfigNotFound() return rows[0] - async def get_user(self, user: str) -> AccountV1: + async def get_user(self, user: str) -> Account: rows = await self._cleos.aget_table( self.receiver, self.receiver, 'users', key_type='name', lower_bound=user, upper_bound=user, - resp_cls=AccountV1 + resp_cls=AccountV1 if self.proto_version > 1 else AccountV0 ) if len(rows) == 0: raise AccountNotFound(user) return rows[0] - async def get_users(self) -> list[AccountV1]: - return await self._cleos.aget_table(self.receiver, self.receiver, 'users', resp_cls=AccountV1) + async def get_users(self) -> list[Account]: + return await self._cleos.aget_table( + self.receiver, self.receiver, 'users', resp_cls=AccountV1 if self.proto_version > 0 else AccountV0) async def get_worker(self, worker: str) -> WorkerV0: rows = await self._cleos.aget_table( @@ -78,31 +80,32 @@ class GPUContractAPI: return rows[0] - async def get_workers(self) -> list[AccountV1]: + async def get_workers(self) -> list[WorkerV0]: return await self._cleos.aget_table(self.receiver, self.receiver, 'workers', resp_cls=WorkerV0) - async def get_queue(self) -> RequestV1: - return await self._cleos.aget_table(self.receiver, self.receiver, 'queue', resp_cls=RequestV1) + async def get_queue(self) -> Request: + return await self._cleos.aget_table( + self.receiver, self.receiver, 'queue', resp_cls=RequestV1 if self.proto_version > 0 else RequestV0) - async def get_request(self, request_id: int) -> RequestV1: + async def get_request(self, request_id: int) -> Request: rows = await self._cleos.aget_table( self.receiver, self.receiver, 'queue', lower_bound=request_id, upper_bound=request_id, - resp_cls=RequestV1 + resp_cls=RequestV1 if self.proto_version > 0 else RequestV0 ) if len(rows) == 0: raise RequestNotFound(request_id) return rows[0] - async def get_requests_since(self, seconds: int) -> list[RequestV1]: + async def get_requests_since(self, seconds: int) -> list[Request]: return await self._cleos.aget_table( self.receiver, self.receiver, 'queue', index_position=2, key_type='i64', lower_bound=int(time.time()) - seconds, - resp_cls=RequestV1 + resp_cls=RequestV1 if self.proto_version > 0 else RequestV0 ) async def get_statuses_for_request(self, request_id: int) -> list[WorkerStatusV0]: @@ -242,12 +245,17 @@ class GPUContractAPI: worker: str, request_id: int, result_hash: str, - ipfs_hash: str + ipfs_hash: str, + request_hash: str | None = None ): + args = [worker, request_id, result_hash, ipfs_hash] + if request_hash: + args.insert(2, request_hash) + return await self._cleos.a_push_action( self.receiver, 'submit', - [worker, request_id, result_hash, ipfs_hash], + args, worker, key=self._cleos.private_keys[worker] ) diff --git a/skynet/dgpu/daemon.py b/skynet/dgpu/daemon.py index 6d62231..89138d6 100755 --- a/skynet/dgpu/daemon.py +++ b/skynet/dgpu/daemon.py @@ -1,4 +1,6 @@ +import os import logging +import contextlib from functools import partial import trio @@ -22,7 +24,7 @@ from skynet.dgpu.network import ( async def maybe_update_tui_balance(contract: GPUContractAPI): async def _fn(tui): # update balance - balance = await contract.get_user(tui.config.account).balance + balance = (await contract.get_user(tui.config.account)).balance tui.set_header_text(new_balance=f'balance: {balance}') await maybe_update_tui_async(_fn) @@ -126,16 +128,18 @@ async def maybe_serve_one( used by torch each step of the inference, it will use a trio.from_thread to unblock the main thread and pump the event loop ''' - output_hash, output = await trio.to_thread.run_sync( - partial( - compute_one, - model, - req.id, - mode, body.params, - inputs=inputs, - should_cancel=state_mngr.should_cancel_work, - ) - ) + with open(os.devnull, 'w') as devnull: + with contextlib.redirect_stdout(devnull): + output_hash, output = await trio.to_thread.run_sync( + partial( + compute_one, + model, + req.id, + mode, body.params, + inputs=inputs, + should_cancel=state_mngr.should_cancel_work, + ) + ) case _: raise DGPUComputeError( @@ -146,7 +150,12 @@ async def maybe_serve_one( ipfs_hash = await ipfs_api.publish(output, type=output_type) - await contract.submit_work(config.account, req.id, output_hash, ipfs_hash) + maybe_request_hash = None + if config.proto_version == 0: + maybe_request_hash = req.hash_v0() + + await contract.submit_work( + config.account, req.id, output_hash, ipfs_hash, request_hash=maybe_request_hash) await state_mngr.update_state() diff --git a/skynet/dgpu/network.py b/skynet/dgpu/network.py index 52ff9a7..0ee616a 100755 --- a/skynet/dgpu/network.py +++ b/skynet/dgpu/network.py @@ -16,7 +16,7 @@ from skynet.config import load_skynet_toml from skynet.contract import GPUContractAPI from skynet.types import ( BodyV0, - RequestV1, + Request, WorkerStatusV0, ResultV0 ) @@ -63,7 +63,7 @@ class ContractState: self._config = load_skynet_toml().dgpu self._poll_index = 0 - self._queue: list[RequestV1] = [] + self._queue: list[Request] = [] self._status_by_rid: dict[int, list[WorkerStatusV0]] = {} self._results: list[ResultV0] = [] @@ -139,7 +139,7 @@ class ContractState: return len(self._queue) @property - def first(self) -> RequestV1 | None: + def first(self) -> Request | None: if len(self._queue) > 0: return self._queue[0] diff --git a/skynet/dgpu/tui.py b/skynet/dgpu/tui.py index ce8db5c..67976ee 100644 --- a/skynet/dgpu/tui.py +++ b/skynet/dgpu/tui.py @@ -8,7 +8,7 @@ from skynet.config import DgpuConfig as Config class WorkerMonitor: - def __init__(self): + def __init__(self, config: Config): self.requests = [] self.header_info = {} @@ -63,6 +63,7 @@ class WorkerMonitor: event_loop=self.event_loop, unhandled_input=self._exit_on_q ) + self.config = config def _create_listbox_body(self, requests): """ @@ -197,7 +198,7 @@ def init_tui(config: Config): global _tui assert not _tui setup_logging_for_tui(config) - _tui = WorkerMonitor() + _tui = WorkerMonitor(config) return _tui diff --git a/skynet/types.py b/skynet/types.py index 4b44c1d..cff31b4 100644 --- a/skynet/types.py +++ b/skynet/types.py @@ -1,4 +1,5 @@ from enum import StrEnum +from hashlib import sha256 from msgspec import Struct @@ -58,6 +59,8 @@ class ConfigV1(Struct): token_symbol: str global_nonce: int +type Config = ConfigV0 | ConfigV1 + ''' RequestV0 @@ -122,6 +125,16 @@ class RequestV0(Struct): binary_data: str timestamp: str + def hash_v0(self) -> str: + hash_str = ( + str(self.nonce) + + + self.body + + + self.binary_data + ) + return sha256(hash_str.encode('utf-8')).hexdigest() + ''' RequestV1 @@ -154,6 +167,8 @@ class RequestV1(Struct): binary_data: str timestamp: str +type Request = RequestV0 | RequestV1 + ''' AccountV0 @@ -200,6 +215,7 @@ class AccountV1(Struct): user: str balance: str +type Account = AccountV0 | AccountV1 ''' WorkerV0