#!/usr/bin/python

import trio
import json
import uuid
import logging

import pynng
import tractor

from . import gpu
from .gpu import open_gpu_worker
from .types import *
from .constants import *
from .frontend import rpc_call


async def open_dgpu_node(
    cert_name: str,
    key_name: Optional[str],
    rpc_address: str = DEFAULT_RPC_ADDR,
    dgpu_address: str = DEFAULT_DGPU_ADDR,
    dgpu_max_tasks: int = DEFAULT_DGPU_MAX_TASKS,
    initial_algos: list[str] = DEFAULT_INITAL_ALGOS,
    security: bool = True
):
    """Run a dgpu node: register it over RPC and serve image requests.

    The node announces itself with a ``dgpu_online`` RPC call, starts
    ``dgpu_max_tasks`` tractor actors running ``open_gpu_worker``, then
    loops on the pynng bus socket.  Every bus message addressed to this
    node is handed to a worker; the reply written back to the bus is the
    request id bytes followed by the image.  After the receive loop ends
    (``KeyboardInterrupt``) a ``dgpu_offline`` RPC call is made.

    Args:
        cert_name: Forwarded to ``open_skynet_rpc``.
        key_name: Forwarded to ``open_skynet_rpc``.
        rpc_address: Accepted for interface compatibility; not read by
            this function.
        dgpu_address: Address the ``pynng.Bus0`` socket dials.
        dgpu_max_tasks: Number of gpu workers / concurrent task slots.
        initial_algos: Algorithm each worker starts with, one entry per
            worker (worker ``i`` starts with ``initial_algos[i]``).
        security: Forwarded to ``open_skynet_rpc``.

    Raises:
        ValueError: If ``initial_algos`` has fewer entries than
            ``dgpu_max_tasks``, or (from a request task) when a request
            arrives while every worker is busy.
        AssertionError: If an RPC response does not contain ``'ok'``.
    """
    logging.basicConfig(level=logging.INFO)

    name = uuid.uuid4()

    # workers[i] is the algo currently loaded on worker i; tasks[i] is
    # either None (idle) or the in-flight {'res', 'event'} record.
    workers = initial_algos.copy()
    tasks = [None for _ in range(dgpu_max_tasks)]

    # Both lists are indexed by worker id, so every worker needs an algo.
    if len(workers) < dgpu_max_tasks:
        raise ValueError(
            f'initial_algos has {len(workers)} entries but '
            f'{dgpu_max_tasks} gpu workers were requested')

    portal_map: dict[int, tractor.Portal]
    contexts: dict[int, tractor.Context]

    def get_next_worker(need_algo: str) -> int:
        """Return the id of an idle worker, preferring a matching algo."""
        first_idle = None
        for wid, (algo, task) in enumerate(zip(workers, tasks)):
            if task is not None:
                continue  # busy

            if algo == need_algo:
                return wid

            if first_idle is None:
                first_idle = wid

        if first_idle is None:
            raise ValueError('no idle gpu worker available')

        return first_idle

    async def gpu_streamer(
        ctx: tractor.Context,
        nid: int
    ):
        """Publish each image streamed by worker ``nid`` to its waiter."""
        async with ctx.open_stream() as stream:
            async for img in stream:
                tasks[nid]['res'] = img
                tasks[nid]['event'].set()

    async def gpu_compute_one(ireq: ImageGenRequest):
        """Send ``ireq`` to a free worker and wait for its image."""
        wid = get_next_worker(ireq.algo)
        event = trio.Event()

        workers[wid] = ireq.algo
        tasks[wid] = {
            'res': None, 'event': event}

        # Must target the worker we just reserved: its gpu_streamer is the
        # one that will set tasks[wid]['event'].
        # NOTE(review): confirm tractor.Context exposes `.send`; sending
        # may need to go through the stream from `ctx.open_stream()`.
        await contexts[wid].send(ireq)

        await event.wait()

        img = tasks[wid]['res']
        tasks[wid] = None
        return img

    # NOTE(review): `open_skynet_rpc` is not among the explicit imports
    # above; presumably provided by one of the star imports - verify.
    async with open_skynet_rpc(
        security=security,
        cert_name=cert_name,
        key_name=key_name
    ) as rpc_call:
        with pynng.Bus0(dial=dgpu_address) as dgpu_sock:
            async def _process_dgpu_req(req: DGPUBusRequest):
                """Compute one request and reply with rid bytes + image."""
                img = await gpu_compute_one(
                    ImageGenRequest(**req.params))
                await dgpu_sock.asend(
                    bytes.fromhex(req.rid) + img)

            res = await rpc_call(
                name.hex, 'dgpu_online', {'max_tasks': dgpu_max_tasks})
            logging.info(res)
            assert 'ok' in res.result

            async with (
                tractor.open_actor_cluster(
                    modules=['skynet_bot.gpu'],
                    count=dgpu_max_tasks,
                    names=list(range(dgpu_max_tasks))
                ) as portal_map,
                trio.open_nursery() as n
            ):
                logging.info(f'starting {dgpu_max_tasks} gpu workers')
                async with tractor.gather_contexts((
                    portal.open_context(
                        open_gpu_worker, algo, 1.0 / dgpu_max_tasks)
                    # worker i starts with workers[i] as its algo
                    for portal, algo in zip(portal_map.values(), workers)
                )) as gathered:
                    # NOTE(review): confirm each gathered item is a bare
                    # Context and not a (ctx, first_value) pair.
                    contexts = dict(enumerate(gathered))
                    for wid, ctx in contexts.items():
                        n.start_soon(
                            gpu_streamer, ctx, wid)

                    try:
                        while True:
                            msg = await dgpu_sock.arecv()
                            req = DGPUBusRequest(
                                **json.loads(msg.decode()))

                            # the bus is shared; skip other nodes' requests
                            if req.nid != name.hex:
                                continue

                            logging.info(f'dgpu: {name}, req: {req}')
                            n.start_soon(
                                _process_dgpu_req, req)

                    except KeyboardInterrupt:
                        ...

            res = await rpc_call(name.hex, 'dgpu_offline')
            logging.info(res)
            assert 'ok' in res.result