mirror of https://github.com/skygpu/skynet.git
Further simplification of daemon code, remove NetConnector
parent
0a5c06e312
commit
7edca49e95
|
@ -4,10 +4,14 @@ from contextlib import asynccontextmanager as acm
|
|||
import trio
|
||||
import urwid
|
||||
|
||||
from leap import CLEOS
|
||||
|
||||
from skynet.config import Config
|
||||
from skynet.ipfs import AsyncIPFSHTTP
|
||||
from skynet.contract import GPUContractAPI
|
||||
from skynet.dgpu.tui import init_tui, WorkerMonitor
|
||||
from skynet.dgpu.daemon import dgpu_serve_forever
|
||||
from skynet.dgpu.network import NetConnector, maybe_open_contract_state_mngr
|
||||
from skynet.dgpu.network import maybe_open_contract_state_mngr
|
||||
|
||||
|
||||
@acm
|
||||
|
@ -19,17 +23,24 @@ async def open_worker(config: Config):
|
|||
if config.tui:
|
||||
tui = init_tui(config)
|
||||
|
||||
conn = NetConnector(config)
|
||||
cleos = CLEOS(endpoint=config.node_url)
|
||||
cleos.import_key(config.account, config.key)
|
||||
abi = cleos.get_abi('gpu.scd')
|
||||
cleos.load_abi('gpu.scd', abi)
|
||||
|
||||
ipfs_api = AsyncIPFSHTTP(config.ipfs_url)
|
||||
|
||||
contract = GPUContractAPI(cleos)
|
||||
try:
|
||||
async with maybe_open_contract_state_mngr(conn) as state_mngr:
|
||||
async with maybe_open_contract_state_mngr(contract) as state_mngr:
|
||||
n: trio.Nursery
|
||||
async with trio.open_nursery() as n:
|
||||
if tui:
|
||||
n.start_soon(tui.run)
|
||||
|
||||
n.start_soon(dgpu_serve_forever, config, conn, state_mngr)
|
||||
n.start_soon(dgpu_serve_forever, config, contract, ipfs_api, state_mngr)
|
||||
|
||||
yield conn, state_mngr
|
||||
yield contract, ipfs_api, state_mngr
|
||||
|
||||
n.cancel_scope.cancel()
|
||||
|
||||
|
|
|
@ -8,20 +8,21 @@ from skynet.config import DgpuConfig as Config
|
|||
from skynet.types import (
|
||||
BodyV0
|
||||
)
|
||||
from skynet.contract import GPUContractAPI
|
||||
from skynet.constants import MODELS
|
||||
from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_img
|
||||
from skynet.dgpu.errors import DGPUComputeError
|
||||
from skynet.dgpu.tui import maybe_update_tui, maybe_update_tui_async
|
||||
from skynet.dgpu.compute import maybe_load_model, compute_one
|
||||
from skynet.dgpu.network import (
|
||||
NetConnector,
|
||||
ContractState,
|
||||
)
|
||||
|
||||
|
||||
async def maybe_update_tui_balance(conn: NetConnector):
|
||||
async def maybe_update_tui_balance(contract: GPUContractAPI):
|
||||
async def _fn(tui):
|
||||
# update balance
|
||||
balance = await conn.contract.get_user(tui.config.account).balance
|
||||
balance = await contract.get_user(tui.config.account).balance
|
||||
tui.set_header_text(new_balance=f'balance: {balance}')
|
||||
|
||||
await maybe_update_tui_async(_fn)
|
||||
|
@ -29,7 +30,8 @@ async def maybe_update_tui_balance(conn: NetConnector):
|
|||
|
||||
async def maybe_serve_one(
|
||||
config: Config,
|
||||
conn: NetConnector,
|
||||
contract: GPUContractAPI,
|
||||
ipfs_api: AsyncIPFSHTTP,
|
||||
state_mngr: ContractState,
|
||||
):
|
||||
logging.info(f'maybe serve request pi: {state_mngr.poll_index}')
|
||||
|
@ -90,7 +92,7 @@ async def maybe_serve_one(
|
|||
# use `GPUConnector` to IO with
|
||||
# storage layer to seed the compute
|
||||
# task.
|
||||
img = await conn.get_input_data(_input)
|
||||
img = await get_ipfs_img(f'https://{config.ipfs_domain}/ipfs/{_input}')
|
||||
inputs.append(img)
|
||||
logging.info(f'retrieved {_input}!')
|
||||
break
|
||||
|
@ -105,7 +107,7 @@ async def maybe_serve_one(
|
|||
|
||||
# TODO: validate request
|
||||
|
||||
resp = await conn.contract.accept_work(config.account, req.id)
|
||||
resp = await contract.accept_work(config.account, req.id)
|
||||
if not resp or 'code' in resp:
|
||||
logging.info('begin_work error, probably being worked on already... skip.')
|
||||
return
|
||||
|
@ -142,13 +144,13 @@ async def maybe_serve_one(
|
|||
|
||||
maybe_update_tui(lambda tui: tui.set_progress(total_step))
|
||||
|
||||
ipfs_hash = await conn.publish_on_ipfs(output, typ=output_type)
|
||||
ipfs_hash = await ipfs_api.publish(output, type=output_type)
|
||||
|
||||
await conn.contract.submit_work(config.account, req.id, output_hash, ipfs_hash)
|
||||
await contract.submit_work(config.account, req.id, output_hash, ipfs_hash)
|
||||
|
||||
await state_mngr.update_state()
|
||||
|
||||
await maybe_update_tui_balance(conn)
|
||||
await maybe_update_tui_balance(contract)
|
||||
|
||||
await state_mngr.update_state()
|
||||
|
||||
|
@ -160,15 +162,17 @@ async def maybe_serve_one(
|
|||
await state_mngr.update_state()
|
||||
|
||||
if state_mngr.is_request_in_progress(req.id):
|
||||
await conn.contract.cancel_work(config.account, req.id, 'reason not provided')
|
||||
await contract.cancel_work(config.account, req.id, 'reason not provided')
|
||||
|
||||
|
||||
async def dgpu_serve_forever(
|
||||
config: Config,
|
||||
conn: NetConnector,
|
||||
contract: GPUContractAPI,
|
||||
ipfs_api: AsyncIPFSHTTP,
|
||||
state_mngr: ContractState
|
||||
):
|
||||
await maybe_update_tui_balance(conn)
|
||||
await maybe_update_tui_balance(contract)
|
||||
maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=config.account))
|
||||
|
||||
last_poll_idx = -1
|
||||
try:
|
||||
|
@ -180,7 +184,7 @@ async def dgpu_serve_forever(
|
|||
|
||||
last_poll_idx = state_mngr.poll_index
|
||||
|
||||
await maybe_serve_one(config, conn, state_mngr)
|
||||
await maybe_serve_one(config, contract, ipfs_api, state_mngr)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
...
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
import io
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from contextlib import asynccontextmanager as acm
|
||||
from functools import partial
|
||||
|
||||
|
@ -13,22 +11,15 @@ import anyio
|
|||
import httpx
|
||||
import outcome
|
||||
import msgspec
|
||||
from PIL import Image
|
||||
from leap.cleos import CLEOS
|
||||
from skynet.dgpu.tui import maybe_update_tui
|
||||
from skynet.config import DgpuConfig as Config, load_skynet_toml
|
||||
from skynet.config import load_skynet_toml
|
||||
from skynet.contract import GPUContractAPI
|
||||
from skynet.types import (
|
||||
BodyV0,
|
||||
RequestV1,
|
||||
WorkerStatusV0,
|
||||
ResultV0
|
||||
)
|
||||
from skynet.contract import GPUContractAPI
|
||||
|
||||
from skynet.ipfs import (
|
||||
AsyncIPFSHTTP,
|
||||
get_ipfs_file,
|
||||
)
|
||||
|
||||
|
||||
REQUEST_UPDATE_TIME: int = 3
|
||||
|
@ -53,77 +44,6 @@ async def failable(fn: partial, ret_fail=None):
|
|||
return o.unwrap()
|
||||
|
||||
|
||||
class NetConnector:
|
||||
'''
|
||||
An API for connecting to and conducting various "high level"
|
||||
network-service operations in the skynet.
|
||||
|
||||
- skynet user account creds
|
||||
- hyperion API
|
||||
- IPFS client
|
||||
- CLEOS client
|
||||
|
||||
'''
|
||||
def __init__(self, config: Config):
|
||||
self.config = config
|
||||
self.cleos = CLEOS(endpoint=config.node_url)
|
||||
self.cleos.import_key(config.account, config.key)
|
||||
abi = self.cleos.get_abi('gpu.scd')
|
||||
self.cleos.load_abi('gpu.scd', abi)
|
||||
|
||||
self.contract = GPUContractAPI(self.cleos)
|
||||
|
||||
self.ipfs_client = AsyncIPFSHTTP(config.ipfs_url)
|
||||
|
||||
maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.config.account))
|
||||
|
||||
# IPFS helpers
|
||||
async def publish_on_ipfs(self, raw, typ: str = 'png'):
|
||||
Path('ipfs-staging').mkdir(exist_ok=True)
|
||||
logging.info('publish_on_ipfs')
|
||||
|
||||
target_file = ''
|
||||
match typ:
|
||||
case 'png':
|
||||
raw: Image
|
||||
target_file = 'ipfs-staging/image.png'
|
||||
raw.save(target_file)
|
||||
|
||||
case _:
|
||||
raise ValueError(f'Unsupported output type: {typ}')
|
||||
|
||||
file_info = await self.ipfs_client.add(Path(target_file))
|
||||
file_cid = file_info['Hash']
|
||||
logging.info(f'added file to ipfs, CID: {file_cid}')
|
||||
|
||||
await self.ipfs_client.pin(file_cid)
|
||||
logging.info(f'pinned {file_cid}')
|
||||
|
||||
return file_cid
|
||||
|
||||
async def get_input_data(self, ipfs_hash: str) -> Image:
|
||||
'''
|
||||
Retrieve an input (image) from the IPFS layer.
|
||||
|
||||
Normally used to retrieve seed (visual) content previously
|
||||
generated/validated by the network to be fed to some
|
||||
consuming AI model.
|
||||
|
||||
'''
|
||||
link = f'https://{self.config.ipfs_domain}/ipfs/{ipfs_hash}'
|
||||
|
||||
res = await get_ipfs_file(link, timeout=1)
|
||||
if not res or res.status_code != 200:
|
||||
logging.warning(f'couldn\'t get ipfs binary data at {link}!')
|
||||
|
||||
# attempt to decode as image
|
||||
input_data = Image.open(io.BytesIO(res.read()))
|
||||
logging.info('decoded as image successfully')
|
||||
|
||||
return input_data
|
||||
|
||||
|
||||
|
||||
def convert_reward_to_int(reward_str):
|
||||
int_part, decimal_part = (
|
||||
reward_str.split('.')[0],
|
||||
|
@ -134,8 +54,11 @@ def convert_reward_to_int(reward_str):
|
|||
|
||||
class ContractState:
|
||||
|
||||
def __init__(self, conn: NetConnector):
|
||||
self._conn = conn
|
||||
def __init__(
|
||||
self,
|
||||
contract: GPUContractAPI
|
||||
):
|
||||
self.contract = contract
|
||||
|
||||
self._config = load_skynet_toml().dgpu
|
||||
self._poll_index = 0
|
||||
|
@ -151,10 +74,10 @@ class ContractState:
|
|||
return self._poll_index
|
||||
|
||||
async def _fetch_results(self):
|
||||
self._results = await self._conn.contract.get_worker_results(self._config.account)
|
||||
self._results = await self.contract.get_worker_results(self._config.account)
|
||||
|
||||
async def _fetch_statuses_for_id(self, rid: int):
|
||||
self._status_by_rid[rid] = await self._conn.contract.get_statuses_for_request(rid)
|
||||
self._status_by_rid[rid] = await self.contract.get_statuses_for_request(rid)
|
||||
|
||||
async def update_state(self):
|
||||
'''
|
||||
|
@ -162,7 +85,7 @@ class ContractState:
|
|||
|
||||
'''
|
||||
# raw queue from chain
|
||||
_queue = await self._conn.contract.get_requests_since(3600)
|
||||
_queue = await self.contract.get_requests_since(3600)
|
||||
|
||||
# filter out invalids
|
||||
self._queue = []
|
||||
|
@ -261,7 +184,7 @@ class ContractState:
|
|||
__state_mngr = None
|
||||
|
||||
@acm
|
||||
async def maybe_open_contract_state_mngr(conn: NetConnector):
|
||||
async def maybe_open_contract_state_mngr(contract: GPUContractAPI):
|
||||
global __state_mngr
|
||||
|
||||
if __state_mngr:
|
||||
|
@ -270,7 +193,7 @@ async def maybe_open_contract_state_mngr(conn: NetConnector):
|
|||
|
||||
config = load_skynet_toml().dgpu
|
||||
|
||||
mngr = ContractState(conn)
|
||||
mngr = ContractState(contract)
|
||||
async with trio.open_nursery() as n:
|
||||
await mngr.update_state()
|
||||
n.start_soon(mngr._state_update_task, config.poll_time)
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
from PIL import Image
|
||||
|
||||
class IPFSClientException(Exception):
|
||||
...
|
||||
|
@ -53,14 +54,37 @@ class AsyncIPFSHTTP:
|
|||
params=kwargs
|
||||
))['Peers']
|
||||
|
||||
async def publish(self, raw, type: str = 'png'):
|
||||
stage = Path('/tmp/ipfs-staging')
|
||||
stage.mkdir(exist_ok=True)
|
||||
logging.info('publish_on_ipfs')
|
||||
|
||||
async def get_ipfs_file(ipfs_link: str, timeout: int = 60 * 5):
|
||||
target_file = ''
|
||||
match type:
|
||||
case 'png':
|
||||
raw: Image
|
||||
target_file = stage / 'image.png'
|
||||
raw.save(target_file)
|
||||
|
||||
case _:
|
||||
raise ValueError(f'Unsupported output type: {type}')
|
||||
|
||||
file_info = await self.add(Path(target_file))
|
||||
file_cid = file_info['Hash']
|
||||
logging.info(f'added file to ipfs, CID: {file_cid}')
|
||||
|
||||
await self.pin(file_cid)
|
||||
logging.info(f'pinned {file_cid}')
|
||||
|
||||
return file_cid
|
||||
|
||||
async def get_ipfs_img(ipfs_link: str, timeout: int = 3) -> Image:
|
||||
logging.info(f'attempting to get image at {ipfs_link}')
|
||||
resp = None
|
||||
for _ in range(timeout):
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(ipfs_link, timeout=3)
|
||||
resp = await client.get(ipfs_link, timeout=timeout)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logging.warning(f'Request error: {e}')
|
||||
|
@ -71,6 +95,14 @@ async def get_ipfs_file(ipfs_link: str, timeout: int = 60 * 5):
|
|||
if resp:
|
||||
logging.info(f'status_code: {resp.status_code}')
|
||||
else:
|
||||
logging.error(f'timeout')
|
||||
logging.error('timeout')
|
||||
return None
|
||||
|
||||
return resp
|
||||
if resp.status_code != 200:
|
||||
logging.warning(f'couldn\'t get ipfs binary data at {ipfs_link}!')
|
||||
return resp
|
||||
|
||||
img = Image.open(io.BytesIO(resp.read()))
|
||||
logging.info('decoded img successfully')
|
||||
|
||||
return img
|
||||
|
|
|
@ -265,6 +265,6 @@ async def test_full_flow(inject_mockers, skynet_cleos, ipfs_node):
|
|||
)
|
||||
|
||||
# open worker and fill request
|
||||
async with open_test_worker(cleos, ipfs_node) as (_conn, state_mngr):
|
||||
async with open_test_worker(cleos, ipfs_node) as (_contract, _ipfs_api, state_mngr):
|
||||
while state_mngr.queue_len > 0:
|
||||
await trio.sleep(1)
|
||||
|
|
Loading…
Reference in New Issue