mirror of https://github.com/skygpu/skynet.git
				
				
				
			Add CI, finish structification, break smart contract table sync logic to its own class with context manager, add better test helpers
							parent
							
								
									08b6b983a2
								
							
						
					
					
						commit
						9eb46862ae
					
				| 
						 | 
				
			
			@ -0,0 +1,31 @@
 | 
			
		|||
# Continuous integration: runs the chain test-suite on every push.
name: CI

on: [push]

jobs:
  auto-tests:
    name: Pytest Tests
    runs-on: ubuntu-24.04
    timeout-minutes: 10
    steps:
      # current major of checkout; takes the same `submodules` input as v2
      - uses: actions/checkout@v4
        with:
          submodules: recursive

      - name: Install the latest version of uv
        uses: astral-sh/setup-uv@v5

      # the venv is reused as long as the lock file does not change
      - uses: actions/cache@v4
        name: Cache venv
        with:
          path: ./.venv
          key: venv-${{ hashFiles('uv.lock') }}

      - name: Install with dev
        run: uv sync

      - name: Run tests
        run: |
          uv run \
            pytest \
              tests/test_chain.py
 | 
			
		||||
| 
						 | 
				
			
			@ -1,5 +1,7 @@
 | 
			
		|||
from skynet.config import Config, DgpuConfig, set_config_override
 | 
			
		||||
from contextlib import asynccontextmanager as acm
 | 
			
		||||
 | 
			
		||||
from skynet.config import Config, DgpuConfig, set_config_override
 | 
			
		||||
from skynet.dgpu import open_worker
 | 
			
		||||
 | 
			
		||||
def override_dgpu_config(**kwargs) -> DgpuConfig:
 | 
			
		||||
    config = Config(
 | 
			
		||||
| 
						 | 
				
			
			@ -7,3 +9,24 @@ def override_dgpu_config(**kwargs) -> DgpuConfig:
 | 
			
		|||
    )
 | 
			
		||||
    set_config_override(config)
 | 
			
		||||
    return config.dgpu
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@acm
async def open_test_worker(
    cleos, ipfs_node,
    account: str = 'testworker',
    permission: str = 'active',
    hf_token: str = '',
    **kwargs
):
    '''
    Test helper: build a dgpu config override from the given test fixtures
    and open a worker with it.

    `cleos` supplies the private key of `account` and the node endpoint,
    `ipfs_node[1]` supplies the ipfs endpoint, any extra keyword argument is
    forwarded untouched to `override_dgpu_config`.

    Yields whatever `open_worker` yields.

    '''
    overrides = {
        'account': account,
        'permission': permission,
        'key': cleos.private_keys[account],
        'node_url': cleos.endpoint,
        'ipfs_url': ipfs_node[1].endpoint,
        'hf_token': hf_token,
    }

    # double unpacking still raises TypeError if kwargs repeats a key above
    worker_config = override_dgpu_config(**overrides, **kwargs)

    async with open_worker(worker_config) as worker:
        yield worker
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,7 +7,7 @@ import urwid
 | 
			
		|||
from skynet.config import Config
 | 
			
		||||
from skynet.dgpu.tui import init_tui
 | 
			
		||||
from skynet.dgpu.daemon import dgpu_serve_forever
 | 
			
		||||
from skynet.dgpu.network import NetConnector
 | 
			
		||||
from skynet.dgpu.network import NetConnector, maybe_open_contract_state_mngr
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@acm
 | 
			
		||||
| 
						 | 
				
			
			@ -20,21 +20,23 @@ async def open_worker(config: Config):
 | 
			
		|||
        tui = init_tui(config)
 | 
			
		||||
 | 
			
		||||
    conn = NetConnector(config)
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        n: trio.Nursery
 | 
			
		||||
        async with trio.open_nursery() as n:
 | 
			
		||||
            if tui:
 | 
			
		||||
                n.start_soon(tui.run)
 | 
			
		||||
        async with maybe_open_contract_state_mngr(conn) as state_mngr:
 | 
			
		||||
            n: trio.Nursery
 | 
			
		||||
            async with trio.open_nursery() as n:
 | 
			
		||||
                if tui:
 | 
			
		||||
                    n.start_soon(tui.run)
 | 
			
		||||
 | 
			
		||||
            n.start_soon(conn.iter_poll_update, config.poll_time)
 | 
			
		||||
                n.start_soon(dgpu_serve_forever, config, conn, state_mngr)
 | 
			
		||||
 | 
			
		||||
            yield conn
 | 
			
		||||
                yield conn, state_mngr
 | 
			
		||||
 | 
			
		||||
                n.cancel_scope.cancel()
 | 
			
		||||
 | 
			
		||||
    except *urwid.ExitMainLoop:
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def _dgpu_main(config: Config):
 | 
			
		||||
    async with open_worker(config) as conn:
 | 
			
		||||
        await dgpu_serve_forever(config, conn)
 | 
			
		||||
    async with open_worker(config):
 | 
			
		||||
        await trio.sleep_forever()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,7 +13,7 @@ import trio
 | 
			
		|||
import torch
 | 
			
		||||
 | 
			
		||||
from skynet.config import load_skynet_toml
 | 
			
		||||
from skynet.types import ModelMode, BodyV0, BodyV0Params
 | 
			
		||||
from skynet.types import ModelMode, BodyV0Params
 | 
			
		||||
from skynet.dgpu.tui import maybe_update_tui
 | 
			
		||||
from skynet.dgpu.errors import (
 | 
			
		||||
    DGPUComputeError,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,4 @@
 | 
			
		|||
import logging
 | 
			
		||||
import random
 | 
			
		||||
from functools import partial
 | 
			
		||||
from hashlib import sha256
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -8,22 +7,16 @@ import msgspec
 | 
			
		|||
 | 
			
		||||
from skynet.config import DgpuConfig as Config
 | 
			
		||||
from skynet.types import (
 | 
			
		||||
    RequestV0,
 | 
			
		||||
    BodyV0
 | 
			
		||||
)
 | 
			
		||||
from skynet.constants import MODELS
 | 
			
		||||
from skynet.dgpu.errors import DGPUComputeError
 | 
			
		||||
from skynet.dgpu.tui import maybe_update_tui, maybe_update_tui_async
 | 
			
		||||
from skynet.dgpu.compute import maybe_load_model, compute_one
 | 
			
		||||
from skynet.dgpu.network import NetConnector
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_reward_to_int(reward_str):
 | 
			
		||||
    int_part, decimal_part = (
 | 
			
		||||
        reward_str.split('.')[0],
 | 
			
		||||
        reward_str.split('.')[1].split(' ')[0]
 | 
			
		||||
    )
 | 
			
		||||
    return int(int_part + decimal_part)
 | 
			
		||||
from skynet.dgpu.network import (
 | 
			
		||||
    NetConnector,
 | 
			
		||||
    ContractState,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def maybe_update_tui_balance(conn: NetConnector):
 | 
			
		||||
| 
						 | 
				
			
			@ -38,8 +31,14 @@ async def maybe_update_tui_balance(conn: NetConnector):
 | 
			
		|||
async def maybe_serve_one(
 | 
			
		||||
    config: Config,
 | 
			
		||||
    conn: NetConnector,
 | 
			
		||||
    req: RequestV0,
 | 
			
		||||
    state_mngr: ContractState,
 | 
			
		||||
):
 | 
			
		||||
    req = state_mngr.first
 | 
			
		||||
 | 
			
		||||
    # no requests in queue
 | 
			
		||||
    if not req:
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    logging.info(f'maybe serve request #{req.id}')
 | 
			
		||||
 | 
			
		||||
    # parse request
 | 
			
		||||
| 
						 | 
				
			
			@ -69,18 +68,13 @@ async def maybe_serve_one(
 | 
			
		|||
        logging.warning('model not blacklisted!, skip...')
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    results = [res['request_id'] for res in conn._tables['results']]
 | 
			
		||||
 | 
			
		||||
    # if worker already produced a result for this request
 | 
			
		||||
    if req.id in results:
 | 
			
		||||
    if state_mngr.is_request_filled(req.id):
 | 
			
		||||
        logging.info(f'worker already submitted a result for request #{req.id}, skip...')
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    statuses = conn._tables['requests'][req.id]
 | 
			
		||||
 | 
			
		||||
    # skip if workers in non_compete already on it
 | 
			
		||||
    competitors = set((status['worker'] for status in statuses))
 | 
			
		||||
    if bool(config.non_compete & competitors):
 | 
			
		||||
    if state_mngr.should_compete_for_id(req.id):
 | 
			
		||||
        logging.info('worker in configured non_compete list already working on request, skip...')
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -146,7 +140,7 @@ async def maybe_serve_one(
 | 
			
		|||
                            req.id,
 | 
			
		||||
                            mode, body.params,
 | 
			
		||||
                            inputs=inputs,
 | 
			
		||||
                            should_cancel=conn.should_cancel_work,
 | 
			
		||||
                            should_cancel=state_mngr.should_cancel_work,
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -168,34 +162,28 @@ async def maybe_serve_one(
 | 
			
		|||
            if 'network cancel' not in str(err):
 | 
			
		||||
                logging.exception('Failed to serve model request !?\n')
 | 
			
		||||
 | 
			
		||||
            if req.id in conn._tables['requests']:
 | 
			
		||||
            if state_mngr.is_request_in_progress(req.id):
 | 
			
		||||
                await conn.cancel_work(req.id, 'reason not provided')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def dgpu_serve_forever(config: Config, conn: NetConnector):
 | 
			
		||||
async def dgpu_serve_forever(
 | 
			
		||||
    config: Config,
 | 
			
		||||
    conn: NetConnector,
 | 
			
		||||
    state_mngr: ContractState
 | 
			
		||||
):
 | 
			
		||||
    await maybe_update_tui_balance(conn)
 | 
			
		||||
 | 
			
		||||
    last_poll_idx = -1
 | 
			
		||||
    try:
 | 
			
		||||
        while True:
 | 
			
		||||
            await conn.wait_data_update()
 | 
			
		||||
            if conn.poll_index == last_poll_idx:
 | 
			
		||||
            await state_mngr.wait_data_update()
 | 
			
		||||
            if state_mngr.poll_index == last_poll_idx:
 | 
			
		||||
                await trio.sleep(config.poll_time)
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            last_poll_idx = conn.poll_index
 | 
			
		||||
            last_poll_idx = state_mngr.poll_index
 | 
			
		||||
 | 
			
		||||
            queue = conn._tables['queue']
 | 
			
		||||
 | 
			
		||||
            random.shuffle(queue)
 | 
			
		||||
            queue = sorted(
 | 
			
		||||
                queue,
 | 
			
		||||
                key=lambda req: convert_reward_to_int(req['reward']),
 | 
			
		||||
                reverse=True
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            if len(queue) > 0:
 | 
			
		||||
                await maybe_serve_one(config, conn, queue[0])
 | 
			
		||||
            await maybe_serve_one(config, conn, state_mngr)
 | 
			
		||||
 | 
			
		||||
    except KeyboardInterrupt:
 | 
			
		||||
        ...
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,8 +1,10 @@
 | 
			
		|||
import io
 | 
			
		||||
import json
 | 
			
		||||
import time
 | 
			
		||||
import random
 | 
			
		||||
import logging
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from contextlib import asynccontextmanager as acm
 | 
			
		||||
from functools import partial
 | 
			
		||||
 | 
			
		||||
import trio
 | 
			
		||||
| 
						 | 
				
			
			@ -10,12 +12,19 @@ import leap
 | 
			
		|||
import anyio
 | 
			
		||||
import httpx
 | 
			
		||||
import outcome
 | 
			
		||||
import msgspec
 | 
			
		||||
from PIL import Image
 | 
			
		||||
from leap.cleos import CLEOS
 | 
			
		||||
from leap.protocol import Asset
 | 
			
		||||
from skynet.dgpu.tui import maybe_update_tui
 | 
			
		||||
from skynet.config import DgpuConfig as Config
 | 
			
		||||
from skynet.types import RequestV0
 | 
			
		||||
from skynet.config import DgpuConfig as Config, load_skynet_toml
 | 
			
		||||
from skynet.types import (
 | 
			
		||||
    ConfigV0,
 | 
			
		||||
    BodyV0,
 | 
			
		||||
    RequestV0,
 | 
			
		||||
    WorkerStatusV0,
 | 
			
		||||
    ResultV0
 | 
			
		||||
)
 | 
			
		||||
from skynet.constants import GPU_CONTRACT_ABI
 | 
			
		||||
 | 
			
		||||
from skynet.ipfs import (
 | 
			
		||||
| 
						 | 
				
			
			@ -64,15 +73,6 @@ class NetConnector:
 | 
			
		|||
 | 
			
		||||
        self.ipfs_client = AsyncIPFSHTTP(config.ipfs_url)
 | 
			
		||||
 | 
			
		||||
        # poll_index is used to detect stale data
 | 
			
		||||
        self.poll_index = 0
 | 
			
		||||
        self._tables = {
 | 
			
		||||
            'queue': [],
 | 
			
		||||
            'requests': {},
 | 
			
		||||
            'results': []
 | 
			
		||||
        }
 | 
			
		||||
        self._data_event = trio.Event()
 | 
			
		||||
 | 
			
		||||
        maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.config.account))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -93,22 +93,23 @@ class NetConnector:
 | 
			
		|||
        logging.info(f'found {len(rows)} requests on queue')
 | 
			
		||||
        return rows
 | 
			
		||||
 | 
			
		||||
    async def get_status_by_request_id(self, request_id: int):
 | 
			
		||||
    async def get_status_by_request_id(self, request_id: int) -> list[WorkerStatusV0]:
 | 
			
		||||
        logging.info('get_status_by_request_id')
 | 
			
		||||
        rows = await failable(
 | 
			
		||||
            partial(
 | 
			
		||||
                self.cleos.aget_table,
 | 
			
		||||
                'gpu.scd', request_id, 'status'), ret_fail=[])
 | 
			
		||||
                'gpu.scd', request_id, 'status', resp_cls=WorkerStatusV0), ret_fail=[])
 | 
			
		||||
 | 
			
		||||
        logging.info(f'found status for workers: {[r["worker"] for r in rows]}')
 | 
			
		||||
        logging.info(f'found status for workers: {[r.worker for r in rows]}')
 | 
			
		||||
        return rows
 | 
			
		||||
 | 
			
		||||
    async def get_global_config(self):
 | 
			
		||||
    async def get_global_config(self) -> ConfigV0:
 | 
			
		||||
        logging.info('get_global_config')
 | 
			
		||||
        rows = await failable(
 | 
			
		||||
            partial(
 | 
			
		||||
                self.cleos.aget_table,
 | 
			
		||||
                'gpu.scd', 'gpu.scd', 'config'))
 | 
			
		||||
                'gpu.scd', 'gpu.scd', 'config',
 | 
			
		||||
                resp_cls=ConfigV0))
 | 
			
		||||
 | 
			
		||||
        if rows:
 | 
			
		||||
            cfg = rows[0]
 | 
			
		||||
| 
						 | 
				
			
			@ -118,7 +119,7 @@ class NetConnector:
 | 
			
		|||
            logging.error('global config not found, is the contract initialized?')
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
    async def get_worker_balance(self):
 | 
			
		||||
    async def get_worker_balance(self) -> str:
 | 
			
		||||
        logging.info('get_worker_balance')
 | 
			
		||||
        rows = await failable(
 | 
			
		||||
            partial(
 | 
			
		||||
| 
						 | 
				
			
			@ -131,73 +132,13 @@ class NetConnector:
 | 
			
		|||
            ))
 | 
			
		||||
 | 
			
		||||
        if rows:
 | 
			
		||||
            b = rows[0]['balance']
 | 
			
		||||
            b = rows[0].balance
 | 
			
		||||
            logging.info(f'balance: {b}')
 | 
			
		||||
            return b
 | 
			
		||||
        else:
 | 
			
		||||
            logging.info('no balance info found')
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
    async def get_full_queue_snapshot(self):
 | 
			
		||||
        '''
 | 
			
		||||
        Get a "snapshot" of current contract table state
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        snap = {
 | 
			
		||||
            'requests': {},
 | 
			
		||||
            'results': []
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        snap['queue'] = await self.get_work_requests_last_hour()
 | 
			
		||||
 | 
			
		||||
        async def _run_and_save(d, key: str, fn, *args, **kwargs):
 | 
			
		||||
            d[key] = await fn(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
        async with trio.open_nursery() as n:
 | 
			
		||||
            n.start_soon(_run_and_save, snap, 'results', self.find_results)
 | 
			
		||||
            for req in snap['queue']:
 | 
			
		||||
                n.start_soon(
 | 
			
		||||
                    _run_and_save, snap['requests'], req['id'], self.get_status_by_request_id, req['id'])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        maybe_update_tui(lambda tui: tui.network_update(snap))
 | 
			
		||||
 | 
			
		||||
        return snap
 | 
			
		||||
 | 
			
		||||
    async def wait_data_update(self):
 | 
			
		||||
        await self._data_event.wait()
 | 
			
		||||
 | 
			
		||||
    async def iter_poll_update(self, poll_time: float):
 | 
			
		||||
        '''
 | 
			
		||||
        Long running task, polls gpu contract tables latest table rows,
 | 
			
		||||
        awakes any self._data_event waiters
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        while True:
 | 
			
		||||
            start_time = time.time()
 | 
			
		||||
            self._tables = await self.get_full_queue_snapshot()
 | 
			
		||||
            elapsed = time.time() - start_time
 | 
			
		||||
            self._data_event.set()
 | 
			
		||||
            await trio.sleep(max(poll_time - elapsed, 0.1))
 | 
			
		||||
            self._data_event = trio.Event()
 | 
			
		||||
            self.poll_index += 1
 | 
			
		||||
 | 
			
		||||
    async def should_cancel_work(self, request_id: int) -> bool:
 | 
			
		||||
        logging.info('should cancel work?')
 | 
			
		||||
        if request_id not in self._tables['requests']:
 | 
			
		||||
            logging.info(f'request #{request_id} no longer in queue, likely its been filled by another worker, cancelling work...')
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
        competitors = set([
 | 
			
		||||
            status['worker']
 | 
			
		||||
            for status in self._tables['requests'][request_id]
 | 
			
		||||
            if status['worker'] != self.config.account
 | 
			
		||||
        ])
 | 
			
		||||
        logging.info(f'competitors: {competitors}')
 | 
			
		||||
        should_cancel = bool(self.config.non_compete & competitors)
 | 
			
		||||
        logging.info(f'cancel: {should_cancel}')
 | 
			
		||||
        return should_cancel
 | 
			
		||||
 | 
			
		||||
    async def begin_work(self, request_id: int):
 | 
			
		||||
        '''
 | 
			
		||||
        Publish to the bc that the worker is beginning a model-computation
 | 
			
		||||
| 
						 | 
				
			
			@ -259,7 +200,7 @@ class NetConnector:
 | 
			
		|||
                )
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    async def find_results(self):
 | 
			
		||||
    async def find_results(self) -> list[ResultV0]:
 | 
			
		||||
        logging.info('find_results')
 | 
			
		||||
        rows = await failable(
 | 
			
		||||
            partial(
 | 
			
		||||
| 
						 | 
				
			
			@ -268,7 +209,8 @@ class NetConnector:
 | 
			
		|||
                index_position=4,
 | 
			
		||||
                key_type='name',
 | 
			
		||||
                lower_bound=self.config.account,
 | 
			
		||||
                upper_bound=self.config.account
 | 
			
		||||
                upper_bound=self.config.account,
 | 
			
		||||
                resp_cls=ResultV0
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        return rows
 | 
			
		||||
| 
						 | 
				
			
			@ -342,3 +284,157 @@ class NetConnector:
 | 
			
		|||
        logging.info('decoded as image successfully')
 | 
			
		||||
 | 
			
		||||
        return input_data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_reward_to_int(reward_str):
    '''
    Convert an asset string into an integer by dropping the decimal point
    and the symbol, ie: "1.0000 TLOS" -> 10000.

    Assets with zero precision carry no decimal point ("10 GPU") and are
    converted from their integer part alone (10).

    Raises IndexError on an empty string and ValueError when the amount is
    not numeric.

    '''
    # first whitespace separated token is the amount, the rest is the symbol
    amount = reward_str.split()[0]

    # the decimal part is optional: it is missing on zero precision assets
    int_part, *decimals = amount.split('.')
    decimal_part = decimals[0] if decimals else ''

    return int(int_part + decimal_part)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ContractState:
    '''
    Local snapshot of the gpu contract tables a worker cares about: the
    request queue, the worker statuses of each queued request and the results
    found by `NetConnector.find_results`.

    The snapshot is refreshed by `update_state`, normally driven by the long
    running `_state_update_task`, consumers wait for fresh data with
    `wait_data_update` and read it through the views & predicates below.

    '''

    def __init__(self, conn: NetConnector):
        self._conn = conn

        # incremented after every poll cycle, used to detect stale data
        self._poll_index = 0

        self._queue: list[RequestV0] = []
        self._status_by_rid: dict[int, list[WorkerStatusV0]] = {}
        self._results: list[ResultV0] = []

        self._new_data = trio.Event()

    @property
    def poll_index(self) -> int:
        return self._poll_index

    async def _fetch_results(self):
        self._results = await self._conn.find_results()

    async def _fetch_statuses_for_id(self, rid: int, into: dict | None = None):
        '''
        Fetch worker statuses for request `rid` and store them on `into`,
        or directly on the live status table when `into` is not provided.

        '''
        statuses = await self._conn.get_status_by_request_id(rid)
        target = self._status_by_rid if into is None else into
        target[rid] = statuses

    async def update_state(self):
        '''
        Get a "snapshot" of current contract table state

        Queue and statuses are built on the side and swapped in together once
        all fetches are done, this way:

            - readers never see a queued request without its statuses
            - requests that left the queue also leave the status table,
              which is what `should_cancel_work` relies on

        '''
        # raw queue from chain
        raw_queue = await self._conn.get_work_requests_last_hour()

        # filter out invalids
        valid: list[RequestV0] = []
        for req in raw_queue:
            try:
                msgspec.json.decode(req.body, type=BodyV0)

            # DecodeError is the base of ValidationError, so this covers
            # both malformed json and json not matching the BodyV0 schema
            except msgspec.DecodeError:
                logging.debug(f'request #{req.id} has an invalid body, ignoring...')
                continue

            valid.append(req)

        # shuffle first so requests with equal reward end up in random order
        # (sort is stable), then highest reward first
        random.shuffle(valid)
        valid.sort(
            key=lambda req: convert_reward_to_int(req.reward),
            reverse=True
        )

        statuses: dict[int, list[WorkerStatusV0]] = {}
        async with trio.open_nursery() as n:
            n.start_soon(self._fetch_results)
            for req in valid:
                n.start_soon(
                    self._fetch_statuses_for_id, req.id, statuses)

        self._queue = valid
        self._status_by_rid = statuses

        maybe_update_tui(lambda tui: tui.network_update(self))

    async def wait_data_update(self):
        await self._new_data.wait()

    async def _state_update_task(self, poll_time: float):
        '''
        Long running task, polls gpu contract tables latest table rows,
        awakes any self._new_data waiters

        '''
        while True:
            start_time = time.time()
            await self.update_state()
            elapsed = time.time() - start_time

            self._new_data.set()

            # time spent updating counts towards the poll period
            await trio.sleep(max(poll_time - elapsed, 0.1))

            # events can't be cleared, arm a fresh one for the next cycle
            self._new_data = trio.Event()
            self._poll_index += 1

    # views into data

    @property
    def queue(self) -> list[RequestV0]:
        # read by the tui on network_update
        return self._queue

    @property
    def queue_len(self) -> int:
        return len(self._queue)

    @property
    def first(self) -> RequestV0 | None:
        return self._queue[0] if self._queue else None

    def competitors_for_id(self, request_id: int) -> set[str]:
        '''
        Accounts, other than our own, with a status on `request_id`, empty
        when the request is unknown.

        '''
        return {
            status.worker
            for status in self._status_by_rid.get(request_id, [])
            if status.worker != self._conn.config.account
        }

    # predicates

    def is_request_filled(self, request_id: int) -> bool:
        return any(
            result.request_id == request_id
            for result in self._results
        )

    def is_request_in_progress(self, request_id: int) -> bool:
        return request_id in self._status_by_rid

    def should_compete_for_id(self, request_id: int) -> bool:
        '''
        NOTE: despite the name, True means a worker on the configured
        non_compete set already has a status on the request, callers skip
        (or cancel) the request on True.

        '''
        return bool(
            self._conn.config.non_compete &
            self.competitors_for_id(request_id)
        )

    async def should_cancel_work(self, request_id: int) -> bool:
        logging.info('should cancel work?')
        if request_id not in self._status_by_rid:
            logging.info(f'request #{request_id} no longer in queue, likely its been filled by another worker, cancelling work...')
            return True

        should_cancel = self.should_compete_for_id(request_id)
        logging.info(f'cancel: {should_cancel}')
        return should_cancel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__state_mngr = None
 | 
			
		||||
 | 
			
		||||
@acm
async def maybe_open_contract_state_mngr(conn: NetConnector):
    '''
    Open the `ContractState` manager for `conn`, or re-use the one already
    opened by an enclosing call.

    The first caller creates the manager, does an initial `update_state` and
    starts the background poll task, nested callers get the same instance.
    When the first caller exits the poll task is cancelled and the module
    level reference is dropped, so a later call creates a fresh manager
    instead of receiving one whose poll task is no longer running.

    '''
    global __state_mngr

    # already open: share it, its lifetime belongs to whoever opened it
    if __state_mngr is not None:
        yield __state_mngr
        return

    config = load_skynet_toml().dgpu

    mngr = ContractState(conn)
    async with trio.open_nursery() as n:
        # populate once before handing it out
        await mngr.update_state()
        n.start_soon(mngr._state_update_task, config.poll_time)
        __state_mngr = mngr
        try:
            yield mngr

        finally:
            __state_mngr = None
            n.cancel_scope.cancel()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,7 +2,6 @@ import json
 | 
			
		|||
import logging
 | 
			
		||||
import warnings
 | 
			
		||||
 | 
			
		||||
import trio
 | 
			
		||||
import urwid
 | 
			
		||||
 | 
			
		||||
from skynet.config import DgpuConfig as Config
 | 
			
		||||
| 
						 | 
				
			
			@ -157,14 +156,14 @@ class WorkerMonitor:
 | 
			
		|||
        if new_balance is not None:
 | 
			
		||||
            self.balance_widget.set_text(new_balance)
 | 
			
		||||
 | 
			
		||||
    def network_update(self, snapshot: dict):
 | 
			
		||||
    def network_update(self, state_mngr):
 | 
			
		||||
        queue = [
 | 
			
		||||
            {
 | 
			
		||||
                **r,
 | 
			
		||||
                **(json.loads(r['body'])['params']),
 | 
			
		||||
                'workers': [s['worker'] for s in snapshot['requests'][r['id']]]
 | 
			
		||||
                **(json.loads(r.body)['params']),
 | 
			
		||||
                'workers': [s.worker for s in state_mngr._status_by_rid[r.id]]
 | 
			
		||||
            }
 | 
			
		||||
            for r in snapshot['queue']
 | 
			
		||||
            for r in state_mngr.queue
 | 
			
		||||
        ]
 | 
			
		||||
        self.update_requests(queue)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,7 +1,6 @@
 | 
			
		|||
import io
 | 
			
		||||
import os
 | 
			
		||||
import sys
 | 
			
		||||
import time
 | 
			
		||||
import random
 | 
			
		||||
import logging
 | 
			
		||||
import importlib
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										183
									
								
								skynet/types.py
								
								
								
								
							
							
						
						
									
										183
									
								
								skynet/types.py
								
								
								
								
							| 
						 | 
				
			
			@ -17,6 +17,62 @@ class ModelDesc(Struct):
 | 
			
		|||
    attrs: dict  # additional mode specific attrs
 | 
			
		||||
    tags: list[ModelMode]
 | 
			
		||||
 | 
			
		||||
# smart contract types
 | 
			
		||||
# https://github.com/guilledk/telos.contracts/blob/gpu_contracts/contracts/telos.gpu/include/telos.gpu/telos.gpu.hpp
 | 
			
		||||
 | 
			
		||||
'''
 | 
			
		||||
ConfigV0
 | 
			
		||||
 | 
			
		||||
singleton containing global info about system, definition:
 | 
			
		||||
 | 
			
		||||
```c++
 | 
			
		||||
struct [[eosio::table]] global_configuration_struct {
 | 
			
		||||
    name token_contract;
 | 
			
		||||
    symbol token_symbol;
 | 
			
		||||
} global_config_row;
 | 
			
		||||
 | 
			
		||||
typedef eosio::singleton<"config"_n, global_configuration_struct> global_config;
 | 
			
		||||
```
 | 
			
		||||
'''
 | 
			
		||||
 | 
			
		||||
class ConfigV0(Struct):
    # must be a Struct like the other contract row types, it is passed as
    # `resp_cls` when reading the `config` table

    # c++ `name token_contract`
    token_contract: str
    # c++ `symbol token_symbol`
    token_symbol: str
 | 
			
		||||
 | 
			
		||||
'''
 | 
			
		||||
RequestV0
 | 
			
		||||
 | 
			
		||||
a request placed on the queue, definition:
 | 
			
		||||
 | 
			
		||||
scope: get_self()
 | 
			
		||||
 | 
			
		||||
```c++
 | 
			
		||||
struct [[eosio::table]] work_request_struct {
 | 
			
		||||
    uint64_t id;
 | 
			
		||||
    name user;
 | 
			
		||||
    asset reward;
 | 
			
		||||
    uint32_t min_verification;
 | 
			
		||||
 | 
			
		||||
    uint64_t nonce;
 | 
			
		||||
 | 
			
		||||
    string body;
 | 
			
		||||
    string binary_data;
 | 
			
		||||
 | 
			
		||||
    time_point_sec timestamp;
 | 
			
		||||
 | 
			
		||||
    uint64_t primary_key() const { return id; }
 | 
			
		||||
    uint64_t by_time() const { return (uint64_t)timestamp.sec_since_epoch(); }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
typedef eosio::multi_index<
 | 
			
		||||
    "queue"_n,
 | 
			
		||||
    work_request_struct,
 | 
			
		||||
    indexed_by<
 | 
			
		||||
        "bytime"_n, const_mem_fun<work_request_struct, uint64_t, &work_request_struct::by_time>
 | 
			
		||||
    >
 | 
			
		||||
> work_queue;
 | 
			
		||||
```
 | 
			
		||||
'''
 | 
			
		||||
 | 
			
		||||
class BodyV0Params(Struct):
 | 
			
		||||
    prompt: str
 | 
			
		||||
| 
						 | 
				
			
			@ -45,3 +101,130 @@ class RequestV0(Struct):
 | 
			
		|||
    body: str
 | 
			
		||||
    binary_data: str
 | 
			
		||||
    timestamp: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
'''
 | 
			
		||||
AccountV0
 | 
			
		||||
 | 
			
		||||
a user account, users must deposit tokens in order to enqueue requests, definition:
 | 
			
		||||
 | 
			
		||||
scope: get_self()
 | 
			
		||||
 | 
			
		||||
```c++
 | 
			
		||||
struct [[eosio::table]] account {
 | 
			
		||||
    name user;
 | 
			
		||||
    asset    balance;
 | 
			
		||||
    uint64_t nonce;
 | 
			
		||||
 | 
			
		||||
    uint64_t primary_key()const { return user.value; }
 | 
			
		||||
};
 | 
			
		||||
typedef eosio::multi_index<"users"_n, account> users;
 | 
			
		||||
```
 | 
			
		||||
'''
 | 
			
		||||
 | 
			
		||||
class AccountV0(Struct):
    # row of the `users` table

    # c++ `name user`
    user: str
    # c++ `asset balance`, kept in its string form
    balance: str
    # c++ `uint64_t nonce`
    nonce: int
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
'''
 | 
			
		||||
WorkerV0
 | 
			
		||||
 | 
			
		||||
info about a registered worker, definition:
 | 
			
		||||
 | 
			
		||||
scope: get_self()
 | 
			
		||||
 | 
			
		||||
```c++
 | 
			
		||||
struct [[eosio::table]] worker {
 | 
			
		||||
    name account;
 | 
			
		||||
    time_point_sec joined;
 | 
			
		||||
    time_point_sec left;
 | 
			
		||||
    string url;
 | 
			
		||||
 | 
			
		||||
    uint64_t primary_key()const { return account.value; }
 | 
			
		||||
};
 | 
			
		||||
typedef eosio::multi_index<"workers"_n, worker> workers;
 | 
			
		||||
```
 | 
			
		||||
'''
 | 
			
		||||
 | 
			
		||||
class WorkerV0(Struct):
 | 
			
		||||
    account: str
 | 
			
		||||
    joined: str
 | 
			
		||||
    left: str
 | 
			
		||||
    url: str
 | 
			
		||||
 | 
			
		||||
'''
 | 
			
		||||
WorkerStatusV0
 | 
			
		||||
 | 
			
		||||
a worker's status for a fill that is currently in progress, definition:
 | 
			
		||||
 | 
			
		||||
scope: request id
 | 
			
		||||
 | 
			
		||||
```c++
 | 
			
		||||
struct [[eosio::table]] worker_status_struct {
 | 
			
		||||
    name worker;
 | 
			
		||||
    string status;
 | 
			
		||||
    time_point_sec started;
 | 
			
		||||
 | 
			
		||||
    uint64_t primary_key() const { return worker.value; }
 | 
			
		||||
};
 | 
			
		||||
```
 | 
			
		||||
'''
 | 
			
		||||
 | 
			
		||||
class WorkerStatusV0(Struct):
 | 
			
		||||
    worker: str
 | 
			
		||||
    status: str
 | 
			
		||||
    started: str
 | 
			
		||||
 | 
			
		||||
'''
 | 
			
		||||
ResultV0
 | 
			
		||||
 | 
			
		||||
a submitted result related to a request, definition:
 | 
			
		||||
 | 
			
		||||
scope: get_self()
 | 
			
		||||
 | 
			
		||||
```c++
 | 
			
		||||
struct [[eosio::table]] work_result_struct {
 | 
			
		||||
    uint64_t id;
 | 
			
		||||
    uint64_t request_id;
 | 
			
		||||
    name user;
 | 
			
		||||
    name worker;
 | 
			
		||||
    checksum256 result_hash;
 | 
			
		||||
    string ipfs_hash;
 | 
			
		||||
    time_point_sec submited;
 | 
			
		||||
 | 
			
		||||
    uint64_t primary_key() const { return id; }
 | 
			
		||||
    uint64_t by_request_id() const { return request_id; }
 | 
			
		||||
    checksum256 by_result_hash() const { return result_hash; }
 | 
			
		||||
    uint64_t by_worker() const { return worker.value; }
 | 
			
		||||
    uint64_t by_time() const { return (uint64_t)submited.sec_since_epoch(); }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
typedef eosio::multi_index<
 | 
			
		||||
    "results"_n,
 | 
			
		||||
    work_result_struct,
 | 
			
		||||
    indexed_by<
 | 
			
		||||
        "byreqid"_n, const_mem_fun<work_result_struct, uint64_t, &work_result_struct::by_request_id>
 | 
			
		||||
    >,
 | 
			
		||||
    indexed_by<
 | 
			
		||||
        "byresult"_n, const_mem_fun<work_result_struct, checksum256, &work_result_struct::by_result_hash>
 | 
			
		||||
    >,
 | 
			
		||||
    indexed_by<
 | 
			
		||||
        "byworker"_n, const_mem_fun<work_result_struct, uint64_t, &work_result_struct::by_worker>
 | 
			
		||||
    >,
 | 
			
		||||
    indexed_by<
 | 
			
		||||
        "bytime"_n, const_mem_fun<work_result_struct, uint64_t, &work_result_struct::by_time>
 | 
			
		||||
    >
 | 
			
		||||
> work_results;
 | 
			
		||||
```
 | 
			
		||||
'''
 | 
			
		||||
 | 
			
		||||
class ResultV0(Struct):
 | 
			
		||||
    id: int
 | 
			
		||||
    request_id: int
 | 
			
		||||
    user: str
 | 
			
		||||
    worker: str
 | 
			
		||||
    result_hash: str
 | 
			
		||||
    ipfs_hash: str
 | 
			
		||||
    submited: str
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,9 +19,9 @@ def postgres_db():
 | 
			
		|||
def skynet_cleos(cleos_bs):
 | 
			
		||||
    cleos = cleos_bs
 | 
			
		||||
 | 
			
		||||
    priv, pub = cleos.create_key_pair()
 | 
			
		||||
    cleos.import_key('gpu.scd', priv)
 | 
			
		||||
    cleos.new_account('gpu.scd', ram=4200000, key=pub)
 | 
			
		||||
    # priv, pub = cleos.create_key_pair()
 | 
			
		||||
    # cleos.import_key('gpu.scd', priv)
 | 
			
		||||
    cleos.new_account('gpu.scd', ram=4200000)
 | 
			
		||||
 | 
			
		||||
    cleos.deploy_contract_from_path(
 | 
			
		||||
        'gpu.scd',
 | 
			
		||||
| 
						 | 
				
			
			@ -36,6 +36,8 @@ def skynet_cleos(cleos_bs):
 | 
			
		|||
        'gpu.scd'
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    cleos.new_account('testworker')
 | 
			
		||||
 | 
			
		||||
    yield cleos
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -52,3 +54,18 @@ def inject_mockers():
 | 
			
		|||
    )
 | 
			
		||||
 | 
			
		||||
    yield
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture(scope='session')
 | 
			
		||||
def ipfs_node(dockerctl):
 | 
			
		||||
    rpc_port = 15001
 | 
			
		||||
    with dockerctl.run(
 | 
			
		||||
        'ipfs/go-ipfs:latest',
 | 
			
		||||
        name='skynet-ipfs',
 | 
			
		||||
        ports={
 | 
			
		||||
            '8080/tcp': 18080,
 | 
			
		||||
            '4001/tcp': 14001,
 | 
			
		||||
            '5001/tcp': ('127.0.0.1', rpc_port)
 | 
			
		||||
        }
 | 
			
		||||
    ) as cntr:
 | 
			
		||||
        yield cntr, AsyncIPFSHTTP(f'http://127.0.0.1:{rpc_port}')
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,12 +1,12 @@
 | 
			
		|||
import trio
 | 
			
		||||
from msgspec import json
 | 
			
		||||
 | 
			
		||||
from skynet.types import BodyV0, BodyV0Params
 | 
			
		||||
from skynet.dgpu.network import NetConnector
 | 
			
		||||
 | 
			
		||||
from skynet._testing import override_dgpu_config
 | 
			
		||||
from skynet._testing import open_test_worker
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_enqueue(skynet_cleos):
 | 
			
		||||
async def test_full_flow(inject_mockers, skynet_cleos, ipfs_node):
 | 
			
		||||
    cleos = skynet_cleos
 | 
			
		||||
 | 
			
		||||
    # create account and deposit tokens into gpu
 | 
			
		||||
| 
						 | 
				
			
			@ -46,20 +46,7 @@ async def test_enqueue(skynet_cleos):
 | 
			
		|||
        key=cleos.private_keys[account]
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    config = override_dgpu_config(
 | 
			
		||||
        account='testworker1',
 | 
			
		||||
        permission='active',
 | 
			
		||||
        key='',
 | 
			
		||||
        node_url=cleos.endpoint,
 | 
			
		||||
        ipfs_url='http://127.0.0.1:5001',
 | 
			
		||||
        hf_token=''
 | 
			
		||||
    )
 | 
			
		||||
    net = NetConnector(config)
 | 
			
		||||
    queue = await net.get_work_requests_last_hour()
 | 
			
		||||
 | 
			
		||||
    assert len(queue) == 1
 | 
			
		||||
 | 
			
		||||
    req = queue[0]
 | 
			
		||||
    body = json.decode(req.body, type=BodyV0)
 | 
			
		||||
 | 
			
		||||
    assert og_body == body
 | 
			
		||||
    # open worker and fill request
 | 
			
		||||
    async with open_test_worker(cleos, ipfs_node) as (_conn, state_mngr):
 | 
			
		||||
        while state_mngr.queue_len > 0:
 | 
			
		||||
            await trio.sleep(1)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue