Further simplification of daemon code, remove NetConnector

rust_contract
Guillermo Rodriguez 2025-02-13 20:53:57 -03:00
parent 0a5c06e312
commit 7edca49e95
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
5 changed files with 83 additions and 113 deletions

View File

@ -4,10 +4,14 @@ from contextlib import asynccontextmanager as acm
import trio import trio
import urwid import urwid
from leap import CLEOS
from skynet.config import Config from skynet.config import Config
from skynet.ipfs import AsyncIPFSHTTP
from skynet.contract import GPUContractAPI
from skynet.dgpu.tui import init_tui, WorkerMonitor from skynet.dgpu.tui import init_tui, WorkerMonitor
from skynet.dgpu.daemon import dgpu_serve_forever from skynet.dgpu.daemon import dgpu_serve_forever
from skynet.dgpu.network import NetConnector, maybe_open_contract_state_mngr from skynet.dgpu.network import maybe_open_contract_state_mngr
@acm @acm
@ -19,17 +23,24 @@ async def open_worker(config: Config):
if config.tui: if config.tui:
tui = init_tui(config) tui = init_tui(config)
conn = NetConnector(config) cleos = CLEOS(endpoint=config.node_url)
cleos.import_key(config.account, config.key)
abi = cleos.get_abi('gpu.scd')
cleos.load_abi('gpu.scd', abi)
ipfs_api = AsyncIPFSHTTP(config.ipfs_url)
contract = GPUContractAPI(cleos)
try: try:
async with maybe_open_contract_state_mngr(conn) as state_mngr: async with maybe_open_contract_state_mngr(contract) as state_mngr:
n: trio.Nursery n: trio.Nursery
async with trio.open_nursery() as n: async with trio.open_nursery() as n:
if tui: if tui:
n.start_soon(tui.run) n.start_soon(tui.run)
n.start_soon(dgpu_serve_forever, config, conn, state_mngr) n.start_soon(dgpu_serve_forever, config, contract, ipfs_api, state_mngr)
yield conn, state_mngr yield contract, ipfs_api, state_mngr
n.cancel_scope.cancel() n.cancel_scope.cancel()

View File

@ -8,20 +8,21 @@ from skynet.config import DgpuConfig as Config
from skynet.types import ( from skynet.types import (
BodyV0 BodyV0
) )
from skynet.contract import GPUContractAPI
from skynet.constants import MODELS from skynet.constants import MODELS
from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_img
from skynet.dgpu.errors import DGPUComputeError from skynet.dgpu.errors import DGPUComputeError
from skynet.dgpu.tui import maybe_update_tui, maybe_update_tui_async from skynet.dgpu.tui import maybe_update_tui, maybe_update_tui_async
from skynet.dgpu.compute import maybe_load_model, compute_one from skynet.dgpu.compute import maybe_load_model, compute_one
from skynet.dgpu.network import ( from skynet.dgpu.network import (
NetConnector,
ContractState, ContractState,
) )
async def maybe_update_tui_balance(conn: NetConnector): async def maybe_update_tui_balance(contract: GPUContractAPI):
async def _fn(tui): async def _fn(tui):
# update balance # update balance
balance = await conn.contract.get_user(tui.config.account).balance balance = await contract.get_user(tui.config.account).balance
tui.set_header_text(new_balance=f'balance: {balance}') tui.set_header_text(new_balance=f'balance: {balance}')
await maybe_update_tui_async(_fn) await maybe_update_tui_async(_fn)
@ -29,7 +30,8 @@ async def maybe_update_tui_balance(conn: NetConnector):
async def maybe_serve_one( async def maybe_serve_one(
config: Config, config: Config,
conn: NetConnector, contract: GPUContractAPI,
ipfs_api: AsyncIPFSHTTP,
state_mngr: ContractState, state_mngr: ContractState,
): ):
logging.info(f'maybe serve request pi: {state_mngr.poll_index}') logging.info(f'maybe serve request pi: {state_mngr.poll_index}')
@ -90,7 +92,7 @@ async def maybe_serve_one(
# user `GPUConnector` to IO with # user `GPUConnector` to IO with
# storage layer to seed the compute # storage layer to seed the compute
# task. # task.
img = await conn.get_input_data(_input) img = await get_ipfs_img(f'https://{config.ipfs_domain}/ipfs/{_input}')
inputs.append(img) inputs.append(img)
logging.info(f'retrieved {_input}!') logging.info(f'retrieved {_input}!')
break break
@ -105,7 +107,7 @@ async def maybe_serve_one(
# TODO: validate request # TODO: validate request
resp = await conn.contract.accept_work(config.account, req.id) resp = await contract.accept_work(config.account, req.id)
if not resp or 'code' in resp: if not resp or 'code' in resp:
logging.info('begin_work error, probably being worked on already... skip.') logging.info('begin_work error, probably being worked on already... skip.')
return return
@ -142,13 +144,13 @@ async def maybe_serve_one(
maybe_update_tui(lambda tui: tui.set_progress(total_step)) maybe_update_tui(lambda tui: tui.set_progress(total_step))
ipfs_hash = await conn.publish_on_ipfs(output, typ=output_type) ipfs_hash = await ipfs_api.publish(output, type=output_type)
await conn.contract.submit_work(config.account, req.id, output_hash, ipfs_hash) await contract.submit_work(config.account, req.id, output_hash, ipfs_hash)
await state_mngr.update_state() await state_mngr.update_state()
await maybe_update_tui_balance(conn) await maybe_update_tui_balance(contract)
await state_mngr.update_state() await state_mngr.update_state()
@ -160,15 +162,17 @@ async def maybe_serve_one(
await state_mngr.update_state() await state_mngr.update_state()
if state_mngr.is_request_in_progress(req.id): if state_mngr.is_request_in_progress(req.id):
await conn.contract.cancel_work(config.account, req.id, 'reason not provided') await contract.cancel_work(config.account, req.id, 'reason not provided')
async def dgpu_serve_forever( async def dgpu_serve_forever(
config: Config, config: Config,
conn: NetConnector, contract: GPUContractAPI,
ipfs_api: AsyncIPFSHTTP,
state_mngr: ContractState state_mngr: ContractState
): ):
await maybe_update_tui_balance(conn) await maybe_update_tui_balance(contract)
maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=config.account))
last_poll_idx = -1 last_poll_idx = -1
try: try:
@ -180,7 +184,7 @@ async def dgpu_serve_forever(
last_poll_idx = state_mngr.poll_index last_poll_idx = state_mngr.poll_index
await maybe_serve_one(config, conn, state_mngr) await maybe_serve_one(config, contract, ipfs_api, state_mngr)
except KeyboardInterrupt: except KeyboardInterrupt:
... ...

View File

@ -1,9 +1,7 @@
import io
import json import json
import time import time
import random import random
import logging import logging
from pathlib import Path
from contextlib import asynccontextmanager as acm from contextlib import asynccontextmanager as acm
from functools import partial from functools import partial
@ -13,22 +11,15 @@ import anyio
import httpx import httpx
import outcome import outcome
import msgspec import msgspec
from PIL import Image
from leap.cleos import CLEOS
from skynet.dgpu.tui import maybe_update_tui from skynet.dgpu.tui import maybe_update_tui
from skynet.config import DgpuConfig as Config, load_skynet_toml from skynet.config import load_skynet_toml
from skynet.contract import GPUContractAPI
from skynet.types import ( from skynet.types import (
BodyV0, BodyV0,
RequestV1, RequestV1,
WorkerStatusV0, WorkerStatusV0,
ResultV0 ResultV0
) )
from skynet.contract import GPUContractAPI
from skynet.ipfs import (
AsyncIPFSHTTP,
get_ipfs_file,
)
REQUEST_UPDATE_TIME: int = 3 REQUEST_UPDATE_TIME: int = 3
@ -53,77 +44,6 @@ async def failable(fn: partial, ret_fail=None):
return o.unwrap() return o.unwrap()
class NetConnector:
'''
An API for connecting to and conducting various "high level"
network-service operations in the skynet.
- skynet user account creds
- hyperion API
- IPFs client
- CLEOS client
'''
def __init__(self, config: Config):
self.config = config
self.cleos = CLEOS(endpoint=config.node_url)
self.cleos.import_key(config.account, config.key)
abi = self.cleos.get_abi('gpu.scd')
self.cleos.load_abi('gpu.scd', abi)
self.contract = GPUContractAPI(self.cleos)
self.ipfs_client = AsyncIPFSHTTP(config.ipfs_url)
maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.config.account))
# IPFS helpers
async def publish_on_ipfs(self, raw, typ: str = 'png'):
Path('ipfs-staging').mkdir(exist_ok=True)
logging.info('publish_on_ipfs')
target_file = ''
match typ:
case 'png':
raw: Image
target_file = 'ipfs-staging/image.png'
raw.save(target_file)
case _:
raise ValueError(f'Unsupported output type: {typ}')
file_info = await self.ipfs_client.add(Path(target_file))
file_cid = file_info['Hash']
logging.info(f'added file to ipfs, CID: {file_cid}')
await self.ipfs_client.pin(file_cid)
logging.info(f'pinned {file_cid}')
return file_cid
async def get_input_data(self, ipfs_hash: str) -> Image:
'''
Retrieve an input (image) from the IPFs layer.
Normally used to retreive seed (visual) content previously
generated/validated by the network to be fed to some
consuming AI model.
'''
link = f'https://{self.config.ipfs_domain}/ipfs/{ipfs_hash}'
res = await get_ipfs_file(link, timeout=1)
if not res or res.status_code != 200:
logging.warning(f'couldn\'t get ipfs binary data at {link}!')
# attempt to decode as image
input_data = Image.open(io.BytesIO(res.read()))
logging.info('decoded as image successfully')
return input_data
def convert_reward_to_int(reward_str): def convert_reward_to_int(reward_str):
int_part, decimal_part = ( int_part, decimal_part = (
reward_str.split('.')[0], reward_str.split('.')[0],
@ -134,8 +54,11 @@ def convert_reward_to_int(reward_str):
class ContractState: class ContractState:
def __init__(self, conn: NetConnector): def __init__(
self._conn = conn self,
contract: GPUContractAPI
):
self.contract = contract
self._config = load_skynet_toml().dgpu self._config = load_skynet_toml().dgpu
self._poll_index = 0 self._poll_index = 0
@ -151,10 +74,10 @@ class ContractState:
return self._poll_index return self._poll_index
async def _fetch_results(self): async def _fetch_results(self):
self._results = await self._conn.contract.get_worker_results(self._config.account) self._results = await self.contract.get_worker_results(self._config.account)
async def _fetch_statuses_for_id(self, rid: int): async def _fetch_statuses_for_id(self, rid: int):
self._status_by_rid[rid] = await self._conn.contract.get_statuses_for_request(rid) self._status_by_rid[rid] = await self.contract.get_statuses_for_request(rid)
async def update_state(self): async def update_state(self):
''' '''
@ -162,7 +85,7 @@ class ContractState:
''' '''
# raw queue from chain # raw queue from chain
_queue = await self._conn.contract.get_requests_since(3600) _queue = await self.contract.get_requests_since(3600)
# filter out invalids # filter out invalids
self._queue = [] self._queue = []
@ -261,7 +184,7 @@ class ContractState:
__state_mngr = None __state_mngr = None
@acm @acm
async def maybe_open_contract_state_mngr(conn: NetConnector): async def maybe_open_contract_state_mngr(contract: GPUContractAPI):
global __state_mngr global __state_mngr
if __state_mngr: if __state_mngr:
@ -270,7 +193,7 @@ async def maybe_open_contract_state_mngr(conn: NetConnector):
config = load_skynet_toml().dgpu config = load_skynet_toml().dgpu
mngr = ContractState(conn) mngr = ContractState(contract)
async with trio.open_nursery() as n: async with trio.open_nursery() as n:
await mngr.update_state() await mngr.update_state()
n.start_soon(mngr._state_update_task, config.poll_time) n.start_soon(mngr._state_update_task, config.poll_time)

View File

@ -1,8 +1,9 @@
import io
import logging import logging
from pathlib import Path from pathlib import Path
import httpx import httpx
from PIL import Image
class IPFSClientException(Exception): class IPFSClientException(Exception):
... ...
@ -53,14 +54,37 @@ class AsyncIPFSHTTP:
params=kwargs params=kwargs
))['Peers'] ))['Peers']
async def publish(self, raw, type: str = 'png'):
stage = Path('/tmp/ipfs-staging')
stage.mkdir(exist_ok=True)
logging.info('publish_on_ipfs')
async def get_ipfs_file(ipfs_link: str, timeout: int = 60 * 5): target_file = ''
match type:
case 'png':
raw: Image
target_file = stage / 'image.png'
raw.save(target_file)
case _:
raise ValueError(f'Unsupported output type: {type}')
file_info = await self.add(Path(target_file))
file_cid = file_info['Hash']
logging.info(f'added file to ipfs, CID: {file_cid}')
await self.pin(file_cid)
logging.info(f'pinned {file_cid}')
return file_cid
async def get_ipfs_img(ipfs_link: str, timeout: int = 3) -> Image:
logging.info(f'attempting to get image at {ipfs_link}') logging.info(f'attempting to get image at {ipfs_link}')
resp = None resp = None
for _ in range(timeout): for _ in range(timeout):
try: try:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
resp = await client.get(ipfs_link, timeout=3) resp = await client.get(ipfs_link, timeout=timeout)
except httpx.RequestError as e: except httpx.RequestError as e:
logging.warning(f'Request error: {e}') logging.warning(f'Request error: {e}')
@ -71,6 +95,14 @@ async def get_ipfs_file(ipfs_link: str, timeout: int = 60 * 5):
if resp: if resp:
logging.info(f'status_code: {resp.status_code}') logging.info(f'status_code: {resp.status_code}')
else: else:
logging.error(f'timeout') logging.error('timeout')
return None
return resp if resp.status_code != 200:
logging.warning(f'couldn\'t get ipfs binary data at {ipfs_link}!')
return resp
img = Image.open(io.BytesIO(resp.read()))
logging.info('decoded img successfully')
return img

View File

@ -265,6 +265,6 @@ async def test_full_flow(inject_mockers, skynet_cleos, ipfs_node):
) )
# open worker and fill request # open worker and fill request
async with open_test_worker(cleos, ipfs_node) as (_conn, state_mngr): async with open_test_worker(cleos, ipfs_node) as (_contract, _ipfs_api, state_mngr):
while state_mngr.queue_len > 0: while state_mngr.queue_len > 0:
await trio.sleep(1) await trio.sleep(1)