Add CI, finish structification, break smart contract table sync logic out into its own class with a context manager, add better test helpers

structify
Guillermo Rodriguez 2025-02-11 16:20:26 -03:00
parent 08b6b983a2
commit 9eb46862ae
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
11 changed files with 482 additions and 157 deletions

.github/workflows/ci.yml
View File

@@ -0,0 +1,31 @@
name: CI
on: [push]
jobs:
auto-tests:
name: Pytest Tests
runs-on: ubuntu-24.04
timeout-minutes: 10
steps:
- uses: actions/checkout@v2
with:
submodules: recursive
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v5
- uses: actions/cache@v3
name: Cache venv
with:
path: ./.venv
key: venv-${{ hashFiles('uv.lock') }}
- name: Install with dev
run: uv sync
- name: Run tests
run: |
uv run \
pytest \
tests/test_chain.py
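
The cache step keys the cached `.venv` on the lockfile, so dependency changes invalidate it automatically. A rough sketch of what `${{ hashFiles('uv.lock') }}` computes (GitHub's `hashFiles` is SHA-256 based; the exact hashing of multiple files differs in detail):

```python
# hedged sketch: approximate equivalent of the cache key
# venv-${{ hashFiles('uv.lock') }} -- reused only while uv.lock is unchanged
import hashlib

with open('uv.lock', 'rb') as f:
    key = 'venv-' + hashlib.sha256(f.read()).hexdigest()
print(key)
```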

View File

@@ -1,5 +1,7 @@
from skynet.config import Config, DgpuConfig, set_config_override
from contextlib import asynccontextmanager as acm
from skynet.config import Config, DgpuConfig, set_config_override
from skynet.dgpu import open_worker
def override_dgpu_config(**kwargs) -> DgpuConfig:
config = Config(
@@ -7,3 +9,24 @@ def override_dgpu_config(**kwargs) -> DgpuConfig:
)
set_config_override(config)
return config.dgpu
@acm
async def open_test_worker(
cleos, ipfs_node,
account: str = 'testworker',
permission: str = 'active',
hf_token: str = '',
**kwargs
):
config = override_dgpu_config(
account=account,
permission=permission,
key=cleos.private_keys[account],
node_url=cleos.endpoint,
ipfs_url=ipfs_node[1].endpoint,
hf_token=hf_token,
**kwargs
)
async with open_worker(config) as worker:
yield worker
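
A minimal usage sketch for the new helper, assuming the `inject_mockers`, `skynet_cleos` and `ipfs_node` fixtures from tests/conftest.py (the test name and assertion are hypothetical):

```python
from skynet._testing import open_test_worker

async def test_worker_boots(inject_mockers, skynet_cleos, ipfs_node):
    # yields the NetConnector plus the ContractState manager,
    # with the worker already polling the test chain
    async with open_test_worker(skynet_cleos, ipfs_node) as (conn, state_mngr):
        assert state_mngr.queue_len == 0  # nothing enqueued yet
```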

View File

@@ -7,7 +7,7 @@ import urwid
from skynet.config import Config
from skynet.dgpu.tui import init_tui
from skynet.dgpu.daemon import dgpu_serve_forever
from skynet.dgpu.network import NetConnector
from skynet.dgpu.network import NetConnector, maybe_open_contract_state_mngr
@acm
@@ -20,21 +20,23 @@ async def open_worker(config: Config):
tui = init_tui(config)
conn = NetConnector(config)
try:
n: trio.Nursery
async with trio.open_nursery() as n:
if tui:
n.start_soon(tui.run)
async with maybe_open_contract_state_mngr(conn) as state_mngr:
n: trio.Nursery
async with trio.open_nursery() as n:
if tui:
n.start_soon(tui.run)
n.start_soon(conn.iter_poll_update, config.poll_time)
n.start_soon(dgpu_serve_forever, config, conn, state_mngr)
yield conn
yield conn, state_mngr
n.cancel_scope.cancel()
except *urwid.ExitMainLoop:
...
async def _dgpu_main(config: Config):
async with open_worker(config) as conn:
await dgpu_serve_forever(config, conn)
async with open_worker(config):
await trio.sleep_forever()
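
`open_worker` now yields a `(conn, state_mngr)` pair instead of just the connector. A short consumption sketch (function name hypothetical):

```python
# hedged sketch: unpacking the new two-element yield of open_worker
async def run(config):
    async with open_worker(config) as (conn, state_mngr):
        # dgpu_serve_forever is already running inside the nursery;
        # callers can observe chain state through state_mngr
        await state_mngr.wait_data_update()
```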

View File

@@ -13,7 +13,7 @@ import trio
import torch
from skynet.config import load_skynet_toml
from skynet.types import ModelMode, BodyV0, BodyV0Params
from skynet.types import ModelMode, BodyV0Params
from skynet.dgpu.tui import maybe_update_tui
from skynet.dgpu.errors import (
DGPUComputeError,

View File

@@ -1,5 +1,4 @@
import logging
import random
from functools import partial
from hashlib import sha256
@@ -8,22 +7,16 @@ import msgspec
from skynet.config import DgpuConfig as Config
from skynet.types import (
RequestV0,
BodyV0
)
from skynet.constants import MODELS
from skynet.dgpu.errors import DGPUComputeError
from skynet.dgpu.tui import maybe_update_tui, maybe_update_tui_async
from skynet.dgpu.compute import maybe_load_model, compute_one
from skynet.dgpu.network import NetConnector
def convert_reward_to_int(reward_str):
int_part, decimal_part = (
reward_str.split('.')[0],
reward_str.split('.')[1].split(' ')[0]
)
return int(int_part + decimal_part)
from skynet.dgpu.network import (
NetConnector,
ContractState,
)
async def maybe_update_tui_balance(conn: NetConnector):
@@ -38,8 +31,14 @@ async def maybe_update_tui_balance(conn: NetConnector):
async def maybe_serve_one(
config: Config,
conn: NetConnector,
req: RequestV0,
state_mngr: ContractState,
):
req = state_mngr.first
# no requests in queue
if not req:
return
logging.info(f'maybe serve request #{req.id}')
# parse request
@@ -69,18 +68,13 @@ async def maybe_serve_one(
logging.warning('model not blacklisted!, skip...')
return
results = [res['request_id'] for res in conn._tables['results']]
# if worker already produced a result for this request
if req.id in results:
if state_mngr.is_request_filled(req.id):
logging.info(f'worker already submitted a result for request #{req.id}, skip...')
return
statuses = conn._tables['requests'][req.id]
# skip if workers in non_compete already on it
competitors = set((status['worker'] for status in statuses))
if bool(config.non_compete & competitors):
if state_mngr.should_compete_for_id(req.id):
logging.info('worker in configured non_compete list already working on request, skip...')
return
@@ -146,7 +140,7 @@ async def maybe_serve_one(
req.id,
mode, body.params,
inputs=inputs,
should_cancel=conn.should_cancel_work,
should_cancel=state_mngr.should_cancel_work,
)
)
@@ -168,34 +162,28 @@ async def maybe_serve_one(
if 'network cancel' not in str(err):
logging.exception('Failed to serve model request !?\n')
if req.id in conn._tables['requests']:
if state_mngr.is_request_in_progress(req.id):
await conn.cancel_work(req.id, 'reason not provided')
async def dgpu_serve_forever(config: Config, conn: NetConnector):
async def dgpu_serve_forever(
config: Config,
conn: NetConnector,
state_mngr: ContractState
):
await maybe_update_tui_balance(conn)
last_poll_idx = -1
try:
while True:
await conn.wait_data_update()
if conn.poll_index == last_poll_idx:
await state_mngr.wait_data_update()
if state_mngr.poll_index == last_poll_idx:
await trio.sleep(config.poll_time)
continue
last_poll_idx = conn.poll_index
last_poll_idx = state_mngr.poll_index
queue = conn._tables['queue']
random.shuffle(queue)
queue = sorted(
queue,
key=lambda req: convert_reward_to_int(req['reward']),
reverse=True
)
if len(queue) > 0:
await maybe_serve_one(config, conn, queue[0])
await maybe_serve_one(config, conn, state_mngr)
except KeyboardInterrupt:
...
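
The loop above wakes on every data event but only does work when the poller has actually advanced. A distilled sketch of that guard, with a hypothetical `handle` callback standing in for `maybe_serve_one`:

```python
import trio

# hedged sketch of the poll-index guard in dgpu_serve_forever:
# wait_data_update() may wake more than once per snapshot, so the
# index comparison limits each snapshot to one serve pass
async def serve_loop(state_mngr, poll_time: float, handle):
    last_poll_idx = -1
    while True:
        await state_mngr.wait_data_update()
        if state_mngr.poll_index == last_poll_idx:
            await trio.sleep(poll_time)  # stale wakeup, back off
            continue
        last_poll_idx = state_mngr.poll_index
        await handle()
```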

View File

@@ -1,8 +1,10 @@
import io
import json
import time
import random
import logging
from pathlib import Path
from contextlib import asynccontextmanager as acm
from functools import partial
import trio
@@ -10,12 +12,19 @@ import leap
import anyio
import httpx
import outcome
import msgspec
from PIL import Image
from leap.cleos import CLEOS
from leap.protocol import Asset
from skynet.dgpu.tui import maybe_update_tui
from skynet.config import DgpuConfig as Config
from skynet.types import RequestV0
from skynet.config import DgpuConfig as Config, load_skynet_toml
from skynet.types import (
ConfigV0,
BodyV0,
RequestV0,
WorkerStatusV0,
ResultV0
)
from skynet.constants import GPU_CONTRACT_ABI
from skynet.ipfs import (
@@ -64,15 +73,6 @@ class NetConnector:
self.ipfs_client = AsyncIPFSHTTP(config.ipfs_url)
# poll_index is used to detect stale data
self.poll_index = 0
self._tables = {
'queue': [],
'requests': {},
'results': []
}
self._data_event = trio.Event()
maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.config.account))
@@ -93,22 +93,23 @@ class NetConnector:
logging.info(f'found {len(rows)} requests on queue')
return rows
async def get_status_by_request_id(self, request_id: int):
async def get_status_by_request_id(self, request_id: int) -> list[WorkerStatusV0]:
logging.info('get_status_by_request_id')
rows = await failable(
partial(
self.cleos.aget_table,
'gpu.scd', request_id, 'status'), ret_fail=[])
'gpu.scd', request_id, 'status', resp_cls=WorkerStatusV0), ret_fail=[])
logging.info(f'found status for workers: {[r["worker"] for r in rows]}')
logging.info(f'found status for workers: {[r.worker for r in rows]}')
return rows
async def get_global_config(self):
async def get_global_config(self) -> ConfigV0:
logging.info('get_global_config')
rows = await failable(
partial(
self.cleos.aget_table,
'gpu.scd', 'gpu.scd', 'config'))
'gpu.scd', 'gpu.scd', 'config',
resp_cls=ConfigV0))
if rows:
cfg = rows[0]
@@ -118,7 +119,7 @@ class NetConnector:
logging.error('global config not found, is the contract initialized?')
return None
async def get_worker_balance(self):
async def get_worker_balance(self) -> str:
logging.info('get_worker_balance')
rows = await failable(
partial(
@@ -131,73 +132,13 @@ class NetConnector:
))
if rows:
b = rows[0]['balance']
b = rows[0].balance
logging.info(f'balance: {b}')
return b
else:
logging.info('no balance info found')
return None
async def get_full_queue_snapshot(self):
'''
Get a "snapshot" of current contract table state
'''
snap = {
'requests': {},
'results': []
}
snap['queue'] = await self.get_work_requests_last_hour()
async def _run_and_save(d, key: str, fn, *args, **kwargs):
d[key] = await fn(*args, **kwargs)
async with trio.open_nursery() as n:
n.start_soon(_run_and_save, snap, 'results', self.find_results)
for req in snap['queue']:
n.start_soon(
_run_and_save, snap['requests'], req['id'], self.get_status_by_request_id, req['id'])
maybe_update_tui(lambda tui: tui.network_update(snap))
return snap
async def wait_data_update(self):
await self._data_event.wait()
async def iter_poll_update(self, poll_time: float):
'''
Long running task, polls gpu contract tables latest table rows,
awakes any self._data_event waiters
'''
while True:
start_time = time.time()
self._tables = await self.get_full_queue_snapshot()
elapsed = time.time() - start_time
self._data_event.set()
await trio.sleep(max(poll_time - elapsed, 0.1))
self._data_event = trio.Event()
self.poll_index += 1
async def should_cancel_work(self, request_id: int) -> bool:
logging.info('should cancel work?')
if request_id not in self._tables['requests']:
logging.info(f"request #{request_id} no longer in queue, likely it's been filled by another worker, cancelling work...")
return True
competitors = set([
status['worker']
for status in self._tables['requests'][request_id]
if status['worker'] != self.config.account
])
logging.info(f'competitors: {competitors}')
should_cancel = bool(self.config.non_compete & competitors)
logging.info(f'cancel: {should_cancel}')
return should_cancel
async def begin_work(self, request_id: int):
'''
Publish to the bc that the worker is beginning a model-computation
@@ -259,7 +200,7 @@ class NetConnector:
)
)
async def find_results(self):
async def find_results(self) -> list[ResultV0]:
logging.info('find_results')
rows = await failable(
partial(
@@ -268,7 +209,8 @@ class NetConnector:
index_position=4,
key_type='name',
lower_bound=self.config.account,
upper_bound=self.config.account
upper_bound=self.config.account,
resp_cls=ResultV0
)
)
return rows
@@ -342,3 +284,157 @@ class NetConnector:
logging.info('decoded as image successfully')
return input_data
def convert_reward_to_int(reward_str):
int_part, decimal_part = (
reward_str.split('.')[0],
reward_str.split('.')[1].split(' ')[0]
)
return int(int_part + decimal_part)
class ContractState:
def __init__(self, conn: NetConnector):
self._conn = conn
self._poll_index = 0
self._queue: list[RequestV0] = []
self._status_by_rid: dict[int, list[WorkerStatusV0]] = {}
self._results: list[ResultV0] = []
self._new_data = trio.Event()
@property
def poll_index(self) -> int:
return self._poll_index
async def _fetch_results(self):
self._results = await self._conn.find_results()
async def _fetch_statuses_for_id(self, rid: int):
self._status_by_rid[rid] = await self._conn.get_status_by_request_id(rid)
async def update_state(self):
'''
Get a "snapshot" of current contract table state
'''
# raw queue from chain
_queue = await self._conn.get_work_requests_last_hour()
# filter out invalids
self._queue = []
for req in _queue:
try:
msgspec.json.decode(req.body, type=BodyV0)
self._queue.append(req)
except msgspec.ValidationError:
...
random.shuffle(self._queue)
self._queue = sorted(
self._queue,
key=lambda req: convert_reward_to_int(req.reward),
reverse=True
)
async with trio.open_nursery() as n:
n.start_soon(self._fetch_results)
for req in self._queue:
n.start_soon(
self._fetch_statuses_for_id, req.id)
maybe_update_tui(lambda tui: tui.network_update(self))
async def wait_data_update(self):
await self._new_data.wait()
async def _state_update_task(self, poll_time: float):
'''
Long running task: polls the gpu contract's latest table rows and
wakes any self._new_data waiters
'''
while True:
start_time = time.time()
await self.update_state()
elapsed = time.time() - start_time
self._new_data.set()
await trio.sleep(max(poll_time - elapsed, 0.1))
self._new_data = trio.Event()
self._poll_index += 1
# views into data
@property
def queue_len(self) -> int:
return len(self._queue)
@property
def first(self) -> RequestV0 | None:
if len(self._queue) > 0:
return self._queue[0]
else:
return None
def competitors_for_id(self, request_id: int) -> set[str]:
return set((
status.worker
for status in self._status_by_rid[request_id]
if status.worker != self._conn.config.account
))
# predicates
def is_request_filled(self, request_id: int) -> bool:
return request_id in [
result.request_id for result in self._results
]
def is_request_in_progress(self, request_id: int) -> bool:
return request_id in self._status_by_rid
def should_compete_for_id(self, request_id: int) -> bool:
return bool(
self._conn.config.non_compete &
self.competitors_for_id(request_id)
)
async def should_cancel_work(self, request_id: int) -> bool:
logging.info('should cancel work?')
if request_id not in self._status_by_rid:
logging.info(f"request #{request_id} no longer in queue, likely it's been filled by another worker, cancelling work...")
return True
should_cancel = self.should_compete_for_id(request_id)
logging.info(f'cancel: {should_cancel}')
return should_cancel
__state_mngr = None
@acm
async def maybe_open_contract_state_mngr(conn: NetConnector):
global __state_mngr
if __state_mngr:
yield __state_mngr
return
config = load_skynet_toml().dgpu
mngr = ContractState(conn)
async with trio.open_nursery() as n:
await mngr.update_state()
n.start_soon(mngr._state_update_task, config.poll_time)
__state_mngr = mngr
yield mngr
n.cancel_scope.cancel()
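
`convert_reward_to_int` flattens an EOSIO-style asset string into its smallest units so the queue can be sorted numerically by reward; since `sorted` is stable, the preceding `random.shuffle` only randomizes the order of equal-reward requests. Worked examples, assuming the contract's 4-decimal asset format:

```python
# hedged examples of the reward sort key
assert convert_reward_to_int('1.0000 GPU') == 10000
assert convert_reward_to_int('0.5000 GPU') == 5000

# two requests both paying '1.0000 GPU' keep their shuffled relative
# order, while still ranking above any '0.5000 GPU' request
```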

View File

@@ -2,7 +2,6 @@ import json
import logging
import warnings
import trio
import urwid
from skynet.config import DgpuConfig as Config
@@ -157,14 +156,14 @@ class WorkerMonitor:
if new_balance is not None:
self.balance_widget.set_text(new_balance)
def network_update(self, snapshot: dict):
def network_update(self, state_mngr):
queue = [
{
**r,
**(json.loads(r['body'])['params']),
'workers': [s['worker'] for s in snapshot['requests'][r['id']]]
**(json.loads(r.body)['params']),
'workers': [s.worker for s in state_mngr._status_by_rid[r.id]]
}
for r in snapshot['queue']
for r in state_mngr._queue
]
self.update_requests(queue)

View File

@@ -1,7 +1,6 @@
import io
import os
import sys
import time
import random
import logging
import importlib

View File

@@ -17,6 +17,62 @@ class ModelDesc(Struct):
attrs: dict # additional mode specific attrs
tags: list[ModelMode]
# smart contract types
# https://github.com/guilledk/telos.contracts/blob/gpu_contracts/contracts/telos.gpu/include/telos.gpu/telos.gpu.hpp
'''
ConfigV0
singleton containing global info about the system, definition:
```c++
struct [[eosio::table]] global_configuration_struct {
name token_contract;
symbol token_symbol;
} global_config_row;
typedef eosio::singleton<"config"_n, global_configuration_struct> global_config;
```
'''
class ConfigV0(Struct):
token_contract: str
token_symbol: str
'''
RequestV0
a request placed on the queue, definition:
scope: get_self()
```c++
struct [[eosio::table]] work_request_struct {
uint64_t id;
name user;
asset reward;
uint32_t min_verification;
uint64_t nonce;
string body;
string binary_data;
time_point_sec timestamp;
uint64_t primary_key() const { return id; }
uint64_t by_time() const { return (uint64_t)timestamp.sec_since_epoch(); }
};
typedef eosio::multi_index<
"queue"_n,
work_request_struct,
indexed_by<
"bytime"_n, const_mem_fun<work_request_struct, uint64_t, &work_request_struct::by_time>
>
> work_queue;
```
'''
class BodyV0Params(Struct):
prompt: str
@@ -45,3 +101,130 @@ class RequestV0(Struct):
body: str
binary_data: str
timestamp: str
'''
AccountV0
a user account; users must deposit tokens in order to enqueue requests, definition:
scope: get_self()
```c++
struct [[eosio::table]] account {
name user;
asset balance;
uint64_t nonce;
uint64_t primary_key()const { return user.value; }
};
typedef eosio::multi_index<"users"_n, account> users;
```
'''
class AccountV0(Struct):
user: str
balance: str
nonce: int
'''
WorkerV0
a registered worker's info, definition:
scope: get_self()
```c++
struct [[eosio::table]] worker {
name account;
time_point_sec joined;
time_point_sec left;
string url;
uint64_t primary_key()const { return account.value; }
};
typedef eosio::multi_index<"workers"_n, worker> workers;
```
'''
class WorkerV0(Struct):
account: str
joined: str
left: str
url: str
'''
WorkerStatusV0
a worker's status related to a currently in-progress fill, definition:
scope: request id
```c++
struct [[eosio::table]] worker_status_struct {
name worker;
string status;
time_point_sec started;
uint64_t primary_key() const { return worker.value; }
};
```
'''
class WorkerStatusV0(Struct):
worker: str
status: str
started: str
'''
ResultV0
a submitted result related to a request, definition:
scope: get_self()
```c++
struct [[eosio::table]] work_result_struct {
uint64_t id;
uint64_t request_id;
name user;
name worker;
checksum256 result_hash;
string ipfs_hash;
time_point_sec submited;
uint64_t primary_key() const { return id; }
uint64_t by_request_id() const { return request_id; }
checksum256 by_result_hash() const { return result_hash; }
uint64_t by_worker() const { return worker.value; }
uint64_t by_time() const { return (uint64_t)submited.sec_since_epoch(); }
};
typedef eosio::multi_index<
"results"_n,
work_result_struct,
indexed_by<
"byreqid"_n, const_mem_fun<work_result_struct, uint64_t, &work_result_struct::by_request_id>
>,
indexed_by<
"byresult"_n, const_mem_fun<work_result_struct, checksum256, &work_result_struct::by_result_hash>
>,
indexed_by<
"byworker"_n, const_mem_fun<work_result_struct, uint64_t, &work_result_struct::by_worker>
>,
indexed_by<
"bytime"_n, const_mem_fun<work_result_struct, uint64_t, &work_result_struct::by_time>
>
> work_results;
```
'''
class ResultV0(Struct):
id: int
request_id: int
user: str
worker: str
result_hash: str
ipfs_hash: str
submited: str
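
These Structs mirror the contract tables one-to-one, which is what lets `aget_table(..., resp_cls=...)` hand back typed rows. A decoding sketch with hypothetical row values, assuming a msgspec version that provides `msgspec.convert` (note the `submited` field keeps the contract's spelling):

```python
import msgspec
from skynet.types import ResultV0

# hypothetical row, shaped like the chain API output
row = {
    'id': 0,
    'request_id': 42,
    'user': 'alice',
    'worker': 'testworker',
    'result_hash': 'ab' * 32,
    'ipfs_hash': 'QmPlaceholder',  # placeholder value
    'submited': '2025-02-11T19:20:26',
}
result = msgspec.convert(row, type=ResultV0)
assert result.request_id == 42 and result.worker == 'testworker'
```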

View File

@@ -19,9 +19,9 @@ def postgres_db():
def skynet_cleos(cleos_bs):
cleos = cleos_bs
priv, pub = cleos.create_key_pair()
cleos.import_key('gpu.scd', priv)
cleos.new_account('gpu.scd', ram=4200000, key=pub)
# priv, pub = cleos.create_key_pair()
# cleos.import_key('gpu.scd', priv)
cleos.new_account('gpu.scd', ram=4200000)
cleos.deploy_contract_from_path(
'gpu.scd',
@@ -36,6 +36,8 @@ def skynet_cleos(cleos_bs):
'gpu.scd'
)
cleos.new_account('testworker')
yield cleos
@@ -52,3 +54,18 @@ def inject_mockers():
)
yield
@pytest.fixture(scope='session')
def ipfs_node(dockerctl):
rpc_port = 15001
with dockerctl.run(
'ipfs/go-ipfs:latest',
name='skynet-ipfs',
ports={
'8080/tcp': 18080,
'4001/tcp': 14001,
'5001/tcp': ('127.0.0.1', rpc_port)
}
) as cntr:
yield cntr, AsyncIPFSHTTP(f'http://127.0.0.1:{rpc_port}')

View File

@@ -1,12 +1,12 @@
import trio
from msgspec import json
from skynet.types import BodyV0, BodyV0Params
from skynet.dgpu.network import NetConnector
from skynet._testing import override_dgpu_config
from skynet._testing import open_test_worker
async def test_enqueue(skynet_cleos):
async def test_full_flow(inject_mockers, skynet_cleos, ipfs_node):
cleos = skynet_cleos
# create account and deposit tokens into gpu
@@ -46,20 +46,7 @@ async def test_enqueue(skynet_cleos):
key=cleos.private_keys[account]
)
config = override_dgpu_config(
account='testworker1',
permission='active',
key='',
node_url=cleos.endpoint,
ipfs_url='http://127.0.0.1:5001',
hf_token=''
)
net = NetConnector(config)
queue = await net.get_work_requests_last_hour()
assert len(queue) == 1
req = queue[0]
body = json.decode(req.body, type=BodyV0)
assert og_body == body
# open worker and fill request
async with open_test_worker(cleos, ipfs_node) as (_conn, state_mngr):
while state_mngr.queue_len > 0:
await trio.sleep(1)