mirror of https://github.com/skygpu/skynet.git
				
				
				
			Further simplification of daemon code, remove NetConnector
							parent
							
								
									0a5c06e312
								
							
						
					
					
						commit
						7edca49e95
					
				| 
						 | 
				
			
			@ -4,10 +4,14 @@ from contextlib import asynccontextmanager as acm
 | 
			
		|||
import trio
 | 
			
		||||
import urwid
 | 
			
		||||
 | 
			
		||||
from leap import CLEOS
 | 
			
		||||
 | 
			
		||||
from skynet.config import Config
 | 
			
		||||
from skynet.ipfs import AsyncIPFSHTTP
 | 
			
		||||
from skynet.contract import GPUContractAPI
 | 
			
		||||
from skynet.dgpu.tui import init_tui, WorkerMonitor
 | 
			
		||||
from skynet.dgpu.daemon import dgpu_serve_forever
 | 
			
		||||
from skynet.dgpu.network import NetConnector, maybe_open_contract_state_mngr
 | 
			
		||||
from skynet.dgpu.network import maybe_open_contract_state_mngr
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@acm
 | 
			
		||||
| 
						 | 
				
			
			@ -19,17 +23,24 @@ async def open_worker(config: Config):
 | 
			
		|||
    if config.tui:
 | 
			
		||||
        tui = init_tui(config)
 | 
			
		||||
 | 
			
		||||
    conn = NetConnector(config)
 | 
			
		||||
    cleos = CLEOS(endpoint=config.node_url)
 | 
			
		||||
    cleos.import_key(config.account, config.key)
 | 
			
		||||
    abi = cleos.get_abi('gpu.scd')
 | 
			
		||||
    cleos.load_abi('gpu.scd', abi)
 | 
			
		||||
 | 
			
		||||
    ipfs_api = AsyncIPFSHTTP(config.ipfs_url)
 | 
			
		||||
 | 
			
		||||
    contract = GPUContractAPI(cleos)
 | 
			
		||||
    try:
 | 
			
		||||
        async with maybe_open_contract_state_mngr(conn) as state_mngr:
 | 
			
		||||
        async with maybe_open_contract_state_mngr(contract) as state_mngr:
 | 
			
		||||
            n: trio.Nursery
 | 
			
		||||
            async with trio.open_nursery() as n:
 | 
			
		||||
                if tui:
 | 
			
		||||
                    n.start_soon(tui.run)
 | 
			
		||||
 | 
			
		||||
                n.start_soon(dgpu_serve_forever, config, conn, state_mngr)
 | 
			
		||||
                n.start_soon(dgpu_serve_forever, config, contract, ipfs_api, state_mngr)
 | 
			
		||||
 | 
			
		||||
                yield conn, state_mngr
 | 
			
		||||
                yield contract, ipfs_api, state_mngr
 | 
			
		||||
 | 
			
		||||
                n.cancel_scope.cancel()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -8,20 +8,21 @@ from skynet.config import DgpuConfig as Config
 | 
			
		|||
from skynet.types import (
 | 
			
		||||
    BodyV0
 | 
			
		||||
)
 | 
			
		||||
from skynet.contract import GPUContractAPI
 | 
			
		||||
from skynet.constants import MODELS
 | 
			
		||||
from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_img
 | 
			
		||||
from skynet.dgpu.errors import DGPUComputeError
 | 
			
		||||
from skynet.dgpu.tui import maybe_update_tui, maybe_update_tui_async
 | 
			
		||||
from skynet.dgpu.compute import maybe_load_model, compute_one
 | 
			
		||||
from skynet.dgpu.network import (
 | 
			
		||||
    NetConnector,
 | 
			
		||||
    ContractState,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def maybe_update_tui_balance(conn: NetConnector):
 | 
			
		||||
async def maybe_update_tui_balance(contract: GPUContractAPI):
 | 
			
		||||
    async def _fn(tui):
 | 
			
		||||
        # update balance
 | 
			
		||||
        balance = await conn.contract.get_user(tui.config.account).balance
 | 
			
		||||
        balance = await contract.get_user(tui.config.account).balance
 | 
			
		||||
        tui.set_header_text(new_balance=f'balance: {balance}')
 | 
			
		||||
 | 
			
		||||
    await maybe_update_tui_async(_fn)
 | 
			
		||||
| 
						 | 
				
			
			@ -29,7 +30,8 @@ async def maybe_update_tui_balance(conn: NetConnector):
 | 
			
		|||
 | 
			
		||||
async def maybe_serve_one(
 | 
			
		||||
    config: Config,
 | 
			
		||||
    conn: NetConnector,
 | 
			
		||||
    contract: GPUContractAPI,
 | 
			
		||||
    ipfs_api: AsyncIPFSHTTP,
 | 
			
		||||
    state_mngr: ContractState,
 | 
			
		||||
):
 | 
			
		||||
    logging.info(f'maybe serve request pi: {state_mngr.poll_index}')
 | 
			
		||||
| 
						 | 
				
			
			@ -90,7 +92,7 @@ async def maybe_serve_one(
 | 
			
		|||
                    # use `GPUConnector` to IO with
 | 
			
		||||
                    # storage layer to seed the compute
 | 
			
		||||
                    # task.
 | 
			
		||||
                    img = await conn.get_input_data(_input)
 | 
			
		||||
                    img = await get_ipfs_img(f'https://{config.ipfs_domain}/ipfs/{_input}')
 | 
			
		||||
                    inputs.append(img)
 | 
			
		||||
                    logging.info(f'retrieved {_input}!')
 | 
			
		||||
                    break
 | 
			
		||||
| 
						 | 
				
			
			@ -105,7 +107,7 @@ async def maybe_serve_one(
 | 
			
		|||
 | 
			
		||||
    # TODO: validate request
 | 
			
		||||
 | 
			
		||||
    resp = await conn.contract.accept_work(config.account, req.id)
 | 
			
		||||
    resp = await contract.accept_work(config.account, req.id)
 | 
			
		||||
    if not resp or 'code' in resp:
 | 
			
		||||
        logging.info('begin_work error, probably being worked on already... skip.')
 | 
			
		||||
        return
 | 
			
		||||
| 
						 | 
				
			
			@ -142,13 +144,13 @@ async def maybe_serve_one(
 | 
			
		|||
 | 
			
		||||
            maybe_update_tui(lambda tui: tui.set_progress(total_step))
 | 
			
		||||
 | 
			
		||||
            ipfs_hash = await conn.publish_on_ipfs(output, typ=output_type)
 | 
			
		||||
            ipfs_hash = await ipfs_api.publish(output, type=output_type)
 | 
			
		||||
 | 
			
		||||
            await conn.contract.submit_work(config.account, req.id, output_hash, ipfs_hash)
 | 
			
		||||
            await contract.submit_work(config.account, req.id, output_hash, ipfs_hash)
 | 
			
		||||
 | 
			
		||||
            await state_mngr.update_state()
 | 
			
		||||
 | 
			
		||||
            await maybe_update_tui_balance(conn)
 | 
			
		||||
            await maybe_update_tui_balance(contract)
 | 
			
		||||
 | 
			
		||||
            await state_mngr.update_state()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -160,15 +162,17 @@ async def maybe_serve_one(
 | 
			
		|||
            await state_mngr.update_state()
 | 
			
		||||
 | 
			
		||||
            if state_mngr.is_request_in_progress(req.id):
 | 
			
		||||
                await conn.contract.cancel_work(config.account, req.id, 'reason not provided')
 | 
			
		||||
                await contract.cancel_work(config.account, req.id, 'reason not provided')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def dgpu_serve_forever(
 | 
			
		||||
    config: Config,
 | 
			
		||||
    conn: NetConnector,
 | 
			
		||||
    contract: GPUContractAPI,
 | 
			
		||||
    ipfs_api: AsyncIPFSHTTP,
 | 
			
		||||
    state_mngr: ContractState
 | 
			
		||||
):
 | 
			
		||||
    await maybe_update_tui_balance(conn)
 | 
			
		||||
    await maybe_update_tui_balance(contract)
 | 
			
		||||
    maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=config.account))
 | 
			
		||||
 | 
			
		||||
    last_poll_idx = -1
 | 
			
		||||
    try:
 | 
			
		||||
| 
						 | 
				
			
			@ -180,7 +184,7 @@ async def dgpu_serve_forever(
 | 
			
		|||
 | 
			
		||||
            last_poll_idx = state_mngr.poll_index
 | 
			
		||||
 | 
			
		||||
            await maybe_serve_one(config, conn, state_mngr)
 | 
			
		||||
            await maybe_serve_one(config, contract, ipfs_api, state_mngr)
 | 
			
		||||
 | 
			
		||||
    except KeyboardInterrupt:
 | 
			
		||||
        ...
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,9 +1,7 @@
 | 
			
		|||
import io
 | 
			
		||||
import json
 | 
			
		||||
import time
 | 
			
		||||
import random
 | 
			
		||||
import logging
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from contextlib import asynccontextmanager as acm
 | 
			
		||||
from functools import partial
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -13,22 +11,15 @@ import anyio
 | 
			
		|||
import httpx
 | 
			
		||||
import outcome
 | 
			
		||||
import msgspec
 | 
			
		||||
from PIL import Image
 | 
			
		||||
from leap.cleos import CLEOS
 | 
			
		||||
from skynet.dgpu.tui import maybe_update_tui
 | 
			
		||||
from skynet.config import DgpuConfig as Config, load_skynet_toml
 | 
			
		||||
from skynet.config import load_skynet_toml
 | 
			
		||||
from skynet.contract import GPUContractAPI
 | 
			
		||||
from skynet.types import (
 | 
			
		||||
    BodyV0,
 | 
			
		||||
    RequestV1,
 | 
			
		||||
    WorkerStatusV0,
 | 
			
		||||
    ResultV0
 | 
			
		||||
)
 | 
			
		||||
from skynet.contract import GPUContractAPI
 | 
			
		||||
 | 
			
		||||
from skynet.ipfs import (
 | 
			
		||||
    AsyncIPFSHTTP,
 | 
			
		||||
    get_ipfs_file,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
REQUEST_UPDATE_TIME: int = 3
 | 
			
		||||
| 
						 | 
				
			
			@ -53,77 +44,6 @@ async def failable(fn: partial, ret_fail=None):
 | 
			
		|||
            return o.unwrap()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NetConnector:
 | 
			
		||||
    '''
 | 
			
		||||
    An API for connecting to and conducting various "high level"
 | 
			
		||||
    network-service operations in the skynet.
 | 
			
		||||
 | 
			
		||||
    - skynet user account creds
 | 
			
		||||
    - hyperion API
 | 
			
		||||
    - IPFs client
 | 
			
		||||
    - CLEOS client
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    def __init__(self, config: Config):
 | 
			
		||||
        self.config = config
 | 
			
		||||
        self.cleos = CLEOS(endpoint=config.node_url)
 | 
			
		||||
        self.cleos.import_key(config.account, config.key)
 | 
			
		||||
        abi = self.cleos.get_abi('gpu.scd')
 | 
			
		||||
        self.cleos.load_abi('gpu.scd', abi)
 | 
			
		||||
 | 
			
		||||
        self.contract = GPUContractAPI(self.cleos)
 | 
			
		||||
 | 
			
		||||
        self.ipfs_client = AsyncIPFSHTTP(config.ipfs_url)
 | 
			
		||||
 | 
			
		||||
        maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.config.account))
 | 
			
		||||
 | 
			
		||||
    # IPFS helpers
 | 
			
		||||
    async def publish_on_ipfs(self, raw, typ: str = 'png'):
 | 
			
		||||
        Path('ipfs-staging').mkdir(exist_ok=True)
 | 
			
		||||
        logging.info('publish_on_ipfs')
 | 
			
		||||
 | 
			
		||||
        target_file = ''
 | 
			
		||||
        match typ:
 | 
			
		||||
            case 'png':
 | 
			
		||||
                raw: Image
 | 
			
		||||
                target_file = 'ipfs-staging/image.png'
 | 
			
		||||
                raw.save(target_file)
 | 
			
		||||
 | 
			
		||||
            case _:
 | 
			
		||||
                raise ValueError(f'Unsupported output type: {typ}')
 | 
			
		||||
 | 
			
		||||
        file_info = await self.ipfs_client.add(Path(target_file))
 | 
			
		||||
        file_cid = file_info['Hash']
 | 
			
		||||
        logging.info(f'added file to ipfs, CID: {file_cid}')
 | 
			
		||||
 | 
			
		||||
        await self.ipfs_client.pin(file_cid)
 | 
			
		||||
        logging.info(f'pinned {file_cid}')
 | 
			
		||||
 | 
			
		||||
        return file_cid
 | 
			
		||||
 | 
			
		||||
    async def get_input_data(self, ipfs_hash: str) -> Image:
 | 
			
		||||
        '''
 | 
			
		||||
        Retrieve an input (image) from the IPFS layer.
 | 
			
		||||
 | 
			
		||||
        Normally used to retrieve seed (visual) content previously
 | 
			
		||||
        generated/validated by the network to be fed to some
 | 
			
		||||
        consuming AI model.
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        link = f'https://{self.config.ipfs_domain}/ipfs/{ipfs_hash}'
 | 
			
		||||
 | 
			
		||||
        res = await get_ipfs_file(link, timeout=1)
 | 
			
		||||
        if not res or res.status_code != 200:
 | 
			
		||||
            logging.warning(f'couldn\'t get ipfs binary data at {link}!')
 | 
			
		||||
 | 
			
		||||
        # attempt to decode as image
 | 
			
		||||
        input_data = Image.open(io.BytesIO(res.read()))
 | 
			
		||||
        logging.info('decoded as image successfully')
 | 
			
		||||
 | 
			
		||||
        return input_data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_reward_to_int(reward_str):
 | 
			
		||||
    int_part, decimal_part = (
 | 
			
		||||
        reward_str.split('.')[0],
 | 
			
		||||
| 
						 | 
				
			
			@ -134,8 +54,11 @@ def convert_reward_to_int(reward_str):
 | 
			
		|||
 | 
			
		||||
class ContractState:
 | 
			
		||||
 | 
			
		||||
    def __init__(self, conn: NetConnector):
 | 
			
		||||
        self._conn = conn
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        contract: GPUContractAPI
 | 
			
		||||
    ):
 | 
			
		||||
        self.contract = contract
 | 
			
		||||
 | 
			
		||||
        self._config = load_skynet_toml().dgpu
 | 
			
		||||
        self._poll_index = 0
 | 
			
		||||
| 
						 | 
				
			
			@ -151,10 +74,10 @@ class ContractState:
 | 
			
		|||
        return self._poll_index
 | 
			
		||||
 | 
			
		||||
    async def _fetch_results(self):
 | 
			
		||||
        self._results = await self._conn.contract.get_worker_results(self._config.account)
 | 
			
		||||
        self._results = await self.contract.get_worker_results(self._config.account)
 | 
			
		||||
 | 
			
		||||
    async def _fetch_statuses_for_id(self, rid: int):
 | 
			
		||||
        self._status_by_rid[rid] = await self._conn.contract.get_statuses_for_request(rid)
 | 
			
		||||
        self._status_by_rid[rid] = await self.contract.get_statuses_for_request(rid)
 | 
			
		||||
 | 
			
		||||
    async def update_state(self):
 | 
			
		||||
        '''
 | 
			
		||||
| 
						 | 
				
			
			@ -162,7 +85,7 @@ class ContractState:
 | 
			
		|||
 | 
			
		||||
        '''
 | 
			
		||||
        # raw queue from chain
 | 
			
		||||
        _queue = await self._conn.contract.get_requests_since(3600)
 | 
			
		||||
        _queue = await self.contract.get_requests_since(3600)
 | 
			
		||||
 | 
			
		||||
        # filter out invalids
 | 
			
		||||
        self._queue = []
 | 
			
		||||
| 
						 | 
				
			
			@ -261,7 +184,7 @@ class ContractState:
 | 
			
		|||
__state_mngr = None
 | 
			
		||||
 | 
			
		||||
@acm
 | 
			
		||||
async def maybe_open_contract_state_mngr(conn: NetConnector):
 | 
			
		||||
async def maybe_open_contract_state_mngr(contract: GPUContractAPI):
 | 
			
		||||
    global __state_mngr
 | 
			
		||||
 | 
			
		||||
    if __state_mngr:
 | 
			
		||||
| 
						 | 
				
			
			@ -270,7 +193,7 @@ async def maybe_open_contract_state_mngr(conn: NetConnector):
 | 
			
		|||
 | 
			
		||||
    config = load_skynet_toml().dgpu
 | 
			
		||||
 | 
			
		||||
    mngr = ContractState(conn)
 | 
			
		||||
    mngr = ContractState(contract)
 | 
			
		||||
    async with trio.open_nursery() as n:
 | 
			
		||||
        await mngr.update_state()
 | 
			
		||||
        n.start_soon(mngr._state_update_task, config.poll_time)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,8 +1,9 @@
 | 
			
		|||
import io
 | 
			
		||||
import logging
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import httpx
 | 
			
		||||
 | 
			
		||||
from PIL import Image
 | 
			
		||||
 | 
			
		||||
class IPFSClientException(Exception):
 | 
			
		||||
    ...
 | 
			
		||||
| 
						 | 
				
			
			@ -53,14 +54,37 @@ class AsyncIPFSHTTP:
 | 
			
		|||
            params=kwargs
 | 
			
		||||
        ))['Peers']
 | 
			
		||||
 | 
			
		||||
    async def publish(self, raw, type: str = 'png'):
 | 
			
		||||
        stage = Path('/tmp/ipfs-staging')
 | 
			
		||||
        stage.mkdir(exist_ok=True)
 | 
			
		||||
        logging.info('publish_on_ipfs')
 | 
			
		||||
 | 
			
		||||
async def get_ipfs_file(ipfs_link: str, timeout: int = 60 * 5):
 | 
			
		||||
        target_file = ''
 | 
			
		||||
        match type:
 | 
			
		||||
            case 'png':
 | 
			
		||||
                raw: Image
 | 
			
		||||
                target_file = stage / 'image.png'
 | 
			
		||||
                raw.save(target_file)
 | 
			
		||||
 | 
			
		||||
            case _:
 | 
			
		||||
                raise ValueError(f'Unsupported output type: {type}')
 | 
			
		||||
 | 
			
		||||
        file_info = await self.add(Path(target_file))
 | 
			
		||||
        file_cid = file_info['Hash']
 | 
			
		||||
        logging.info(f'added file to ipfs, CID: {file_cid}')
 | 
			
		||||
 | 
			
		||||
        await self.pin(file_cid)
 | 
			
		||||
        logging.info(f'pinned {file_cid}')
 | 
			
		||||
 | 
			
		||||
        return file_cid
 | 
			
		||||
 | 
			
		||||
async def get_ipfs_img(ipfs_link: str, timeout: int = 3) -> Image:
 | 
			
		||||
    logging.info(f'attempting to get image at {ipfs_link}')
 | 
			
		||||
    resp = None
 | 
			
		||||
    for _ in range(timeout):
 | 
			
		||||
        try:
 | 
			
		||||
            async with httpx.AsyncClient() as client:
 | 
			
		||||
                resp = await client.get(ipfs_link, timeout=3)
 | 
			
		||||
                resp = await client.get(ipfs_link, timeout=timeout)
 | 
			
		||||
 | 
			
		||||
        except httpx.RequestError as e:
 | 
			
		||||
            logging.warning(f'Request error: {e}')
 | 
			
		||||
| 
						 | 
				
			
			@ -71,6 +95,14 @@ async def get_ipfs_file(ipfs_link: str, timeout: int = 60 * 5):
 | 
			
		|||
    if resp:
 | 
			
		||||
        logging.info(f'status_code: {resp.status_code}')
 | 
			
		||||
    else:
 | 
			
		||||
        logging.error(f'timeout')
 | 
			
		||||
        logging.error('timeout')
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    if resp.status_code != 200:
 | 
			
		||||
        logging.warning(f'couldn\'t get ipfs binary data at {ipfs_link}!')
 | 
			
		||||
        return resp
 | 
			
		||||
 | 
			
		||||
    img = Image.open(io.BytesIO(resp.read()))
 | 
			
		||||
    logging.info('decoded img successfully')
 | 
			
		||||
 | 
			
		||||
    return img
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -265,6 +265,6 @@ async def test_full_flow(inject_mockers, skynet_cleos, ipfs_node):
 | 
			
		|||
    )
 | 
			
		||||
 | 
			
		||||
    # open worker and fill request
 | 
			
		||||
    async with open_test_worker(cleos, ipfs_node) as (_conn, state_mngr):
 | 
			
		||||
    async with open_test_worker(cleos, ipfs_node) as (_contract, _ipfs_api, state_mngr):
 | 
			
		||||
        while state_mngr.queue_len > 0:
 | 
			
		||||
            await trio.sleep(1)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue