mirror of https://github.com/skygpu/skynet.git
Add gpu worker tests to debug tractor lifecycle stuff also logging
Minor tweak to cuda docker image Pin tractor branch Add triton to cuda reqspull/2/head
parent
74d2426793
commit
139aea67b1
|
@ -5,7 +5,7 @@ env DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
workdir /skynet
|
workdir /skynet
|
||||||
|
|
||||||
copy requirements.* .
|
copy requirements.* ./
|
||||||
|
|
||||||
run pip install -U pip ninja
|
run pip install -U pip ninja
|
||||||
run pip install -r requirements.cuda.0.txt
|
run pip install -r requirements.cuda.0.txt
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
pdbpp
|
pdbpp
|
||||||
scipy
|
scipy
|
||||||
|
triton
|
||||||
accelerate
|
accelerate
|
||||||
transformers
|
transformers
|
||||||
huggingface_hub
|
huggingface_hub
|
||||||
|
|
|
@ -5,4 +5,4 @@ aiohttp
|
||||||
msgspec
|
msgspec
|
||||||
trio_asyncio
|
trio_asyncio
|
||||||
|
|
||||||
git+https://github.com/goodboy/tractor.git@master#egg=tractor
|
git+https://github.com/goodboy/tractor.git@piker_pin#egg=tractor
|
||||||
|
|
|
@ -93,8 +93,9 @@ async def open_dgpu_node(
|
||||||
):
|
):
|
||||||
logging.info(f'starting {dgpu_max_tasks} gpu workers')
|
logging.info(f'starting {dgpu_max_tasks} gpu workers')
|
||||||
async with tractor.gather_contexts((
|
async with tractor.gather_contexts((
|
||||||
ctx.open_context(
|
portal.open_context(
|
||||||
open_gpu_worker, algo, 1.0 / dgpu_max_tasks)
|
open_gpu_worker, algo, 1.0 / dgpu_max_tasks)
|
||||||
|
for portal in portal_map.values()
|
||||||
)) as contexts:
|
)) as contexts:
|
||||||
contexts = {i: ctx for i, ctx in enumerate(contexts)}
|
contexts = {i: ctx for i, ctx in enumerate(contexts)}
|
||||||
for i, ctx in contexts.items():
|
for i, ctx in contexts.items():
|
||||||
|
|
|
@ -45,11 +45,13 @@ async def open_gpu_worker(
|
||||||
start_algo: str,
|
start_algo: str,
|
||||||
mem_fraction: float
|
mem_fraction: float
|
||||||
):
|
):
|
||||||
|
log = tractor.log.get_logger(name='gpu', _root_name='skynet')
|
||||||
|
log.info(f'starting gpu worker with algo {start_algo}...')
|
||||||
current_algo = start_algo
|
current_algo = start_algo
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pipe = pipeline_for(current_algo, mem_fraction)
|
pipe = pipeline_for(current_algo, mem_fraction)
|
||||||
|
log.info('pipeline loaded')
|
||||||
await ctx.started()
|
await ctx.started()
|
||||||
|
|
||||||
async with ctx.open_stream() as bus:
|
async with ctx.open_stream() as bus:
|
||||||
async for ireq in bus:
|
async for ireq in bus:
|
||||||
if ireq.algo != current_algo:
|
if ireq.algo != current_algo:
|
||||||
|
|
3
test.sh
3
test.sh
|
@ -1,8 +1,9 @@
|
||||||
docker run \
|
docker run \
|
||||||
-it \
|
-it \
|
||||||
--rm \
|
--rm \
|
||||||
|
--gpus=all \
|
||||||
--mount type=bind,source="$(pwd)",target=/skynet \
|
--mount type=bind,source="$(pwd)",target=/skynet \
|
||||||
skynet:runtime-cuda \
|
skynet:runtime-cuda \
|
||||||
bash -c \
|
bash -c \
|
||||||
"cd /skynet && pip install -e . && \
|
"cd /skynet && pip install -e . && \
|
||||||
pytest tests/test_dgpu.py --log-cli-level=info"
|
pytest $1 --log-cli-level=info"
|
||||||
|
|
|
@ -6,8 +6,10 @@ import logging
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
import pynng
|
import pynng
|
||||||
|
import tractor
|
||||||
import trio_asyncio
|
import trio_asyncio
|
||||||
|
|
||||||
|
from skynet_bot.gpu import open_gpu_worker
|
||||||
from skynet_bot.dgpu import open_dgpu_node
|
from skynet_bot.dgpu import open_dgpu_node
|
||||||
from skynet_bot.types import *
|
from skynet_bot.types import *
|
||||||
from skynet_bot.brain import run_skynet
|
from skynet_bot.brain import run_skynet
|
||||||
|
@ -15,6 +17,110 @@ from skynet_bot.constants import *
|
||||||
from skynet_bot.frontend import open_skynet_rpc, rpc_call
|
from skynet_bot.frontend import open_skynet_rpc, rpc_call
|
||||||
|
|
||||||
|
|
||||||
|
@tractor.context
|
||||||
|
async def open_fake_worker(
|
||||||
|
ctx: tractor.Context,
|
||||||
|
start_algo: str,
|
||||||
|
mem_fraction: float
|
||||||
|
):
|
||||||
|
log = tractor.log.get_logger(name='gpu', _root_name='skynet')
|
||||||
|
log.info(f'starting gpu worker with algo {start_algo}...')
|
||||||
|
current_algo = start_algo
|
||||||
|
log.info('pipeline loaded')
|
||||||
|
await ctx.started()
|
||||||
|
async with ctx.open_stream() as bus:
|
||||||
|
async for ireq in bus:
|
||||||
|
if ireq:
|
||||||
|
await bus.send('hello!')
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
def test_gpu_worker():
|
||||||
|
log = tractor.log.get_logger(name='root', _root_name='skynet')
|
||||||
|
async def main():
|
||||||
|
async with (
|
||||||
|
tractor.open_nursery(debug_mode=True) as an,
|
||||||
|
trio.open_nursery() as n
|
||||||
|
):
|
||||||
|
portal = await an.start_actor(
|
||||||
|
'gpu_worker',
|
||||||
|
enable_modules=[__name__],
|
||||||
|
debug_mode=True
|
||||||
|
)
|
||||||
|
|
||||||
|
log.info('portal opened')
|
||||||
|
async with (
|
||||||
|
portal.open_context(
|
||||||
|
open_fake_worker,
|
||||||
|
start_algo='midj',
|
||||||
|
mem_fraction=0.6
|
||||||
|
) as (ctx, _),
|
||||||
|
ctx.open_stream() as stream,
|
||||||
|
):
|
||||||
|
log.info('opened worker sending req...')
|
||||||
|
ireq = ImageGenRequest(
|
||||||
|
prompt='a red tractor on a wheat field',
|
||||||
|
step=28,
|
||||||
|
width=512, height=512,
|
||||||
|
guidance=10, seed=None,
|
||||||
|
algo='midj', upscaler=None)
|
||||||
|
|
||||||
|
await stream.send(ireq)
|
||||||
|
log.info('sent, await respnse')
|
||||||
|
async for msg in stream:
|
||||||
|
log.info(f'got {msg}')
|
||||||
|
break
|
||||||
|
|
||||||
|
assert msg == 'hello!'
|
||||||
|
await stream.send(None)
|
||||||
|
log.info('done.')
|
||||||
|
|
||||||
|
await portal.cancel_actor()
|
||||||
|
|
||||||
|
trio.run(main)
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpu_two_workers():
|
||||||
|
async def main():
|
||||||
|
outputs = []
|
||||||
|
async with (
|
||||||
|
tractor.open_actor_cluster(
|
||||||
|
modules=[__name__],
|
||||||
|
count=2,
|
||||||
|
names=[0, 1]) as portal_map,
|
||||||
|
tractor.trionics.gather_contexts((
|
||||||
|
portal.open_context(
|
||||||
|
open_fake_worker,
|
||||||
|
start_algo='midj',
|
||||||
|
mem_fraction=0.333)
|
||||||
|
for portal in portal_map.values()
|
||||||
|
)) as contexts,
|
||||||
|
trio.open_nursery() as n
|
||||||
|
):
|
||||||
|
ireq = ImageGenRequest(
|
||||||
|
prompt='a red tractor on a wheat field',
|
||||||
|
step=28,
|
||||||
|
width=512, height=512,
|
||||||
|
guidance=10, seed=None,
|
||||||
|
algo='midj', upscaler=None)
|
||||||
|
|
||||||
|
async def get_img(i):
|
||||||
|
ctx = contexts[i]
|
||||||
|
async with ctx.open_stream() as stream:
|
||||||
|
await stream.send(ireq)
|
||||||
|
async for img in stream:
|
||||||
|
outputs[i] = img
|
||||||
|
await portal_map[i].cancel_actor()
|
||||||
|
|
||||||
|
n.start_soon(get_img, 0)
|
||||||
|
n.start_soon(get_img, 1)
|
||||||
|
|
||||||
|
|
||||||
|
assert len(outputs) == 2
|
||||||
|
|
||||||
|
trio.run(main)
|
||||||
|
|
||||||
|
|
||||||
def test_dgpu_simple():
|
def test_dgpu_simple():
|
||||||
async def main():
|
async def main():
|
||||||
async with trio.open_nursery() as n:
|
async with trio.open_nursery() as n:
|
||||||
|
|
Loading…
Reference in New Issue