Add gpu worker tests to debug tractor lifecycle issues, and add logging

Minor tweak to cuda docker image
Pin tractor branch
Add triton to cuda reqs
pull/2/head
Guillermo Rodriguez 2022-12-11 08:32:25 -03:00
parent 74d2426793
commit 139aea67b1
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
7 changed files with 116 additions and 5 deletions

View File

@ -5,7 +5,7 @@ env DEBIAN_FRONTEND=noninteractive
workdir /skynet
copy requirements.* .
copy requirements.* ./
run pip install -U pip ninja
run pip install -r requirements.cuda.0.txt

View File

@ -1,5 +1,6 @@
pdbpp
scipy
triton
accelerate
transformers
huggingface_hub

View File

@ -5,4 +5,4 @@ aiohttp
msgspec
trio_asyncio
git+https://github.com/goodboy/tractor.git@master#egg=tractor
git+https://github.com/goodboy/tractor.git@piker_pin#egg=tractor

View File

@ -93,8 +93,9 @@ async def open_dgpu_node(
):
logging.info(f'starting {dgpu_max_tasks} gpu workers')
async with tractor.gather_contexts((
ctx.open_context(
portal.open_context(
open_gpu_worker, algo, 1.0 / dgpu_max_tasks)
for portal in portal_map.values()
)) as contexts:
contexts = {i: ctx for i, ctx in enumerate(contexts)}
for i, ctx in contexts.items():

View File

@ -45,11 +45,13 @@ async def open_gpu_worker(
start_algo: str,
mem_fraction: float
):
log = tractor.log.get_logger(name='gpu', _root_name='skynet')
log.info(f'starting gpu worker with algo {start_algo}...')
current_algo = start_algo
with torch.no_grad():
pipe = pipeline_for(current_algo, mem_fraction)
log.info('pipeline loaded')
await ctx.started()
async with ctx.open_stream() as bus:
async for ireq in bus:
if ireq.algo != current_algo:

View File

@ -1,8 +1,9 @@
docker run \
-it \
--rm \
--gpus=all \
--mount type=bind,source="$(pwd)",target=/skynet \
skynet:runtime-cuda \
bash -c \
"cd /skynet && pip install -e . && \
pytest tests/test_dgpu.py --log-cli-level=info"
pytest $1 --log-cli-level=info"

View File

@ -6,8 +6,10 @@ import logging
import trio
import pynng
import tractor
import trio_asyncio
from skynet_bot.gpu import open_gpu_worker
from skynet_bot.dgpu import open_dgpu_node
from skynet_bot.types import *
from skynet_bot.brain import run_skynet
@ -15,6 +17,110 @@ from skynet_bot.constants import *
from skynet_bot.frontend import open_skynet_rpc, rpc_call
@tractor.context
async def open_fake_worker(
    ctx: tractor.Context,
    start_algo: str,
    mem_fraction: float
):
    """Stand-in worker context that loads no pipeline.

    Logs its startup, signals ``ctx.started()``, then opens the context
    stream and replies ``'hello!'`` to every truthy message it receives.
    The first falsy message ends the loop and the function returns.

    Args:
        ctx: context handle supplied through ``@tractor.context``.
        start_algo: algorithm name; only logged and stored locally.
        mem_fraction: accepted for signature parity; never read here.
    """
    logger = tractor.log.get_logger(name='gpu', _root_name='skynet')
    logger.info(f'starting gpu worker with algo {start_algo}...')

    # Assigned but not read anywhere else in this function.
    active_algo = start_algo

    logger.info('pipeline loaded')
    await ctx.started()

    async with ctx.open_stream() as stream:
        async for request in stream:
            if not request:
                # A falsy message is the signal to stop serving.
                break
            await stream.send('hello!')
def test_gpu_worker():
    """Round-trip a single request through one ``open_fake_worker`` actor.

    Starts an actor named ``'gpu_worker'`` with this module enabled, opens
    a context running ``open_fake_worker`` plus a stream on it, sends one
    ``ImageGenRequest`` and checks that the first reply is ``'hello!'``.
    ``None`` is then sent so the worker leaves its receive loop, and the
    actor is cancelled before the actor nursery exits.
    """
    logger = tractor.log.get_logger(name='root', _root_name='skynet')

    async def _session():
        async with (
            tractor.open_nursery(debug_mode=True) as actor_nursery,
            trio.open_nursery() as _task_nursery
        ):
            worker_portal = await actor_nursery.start_actor(
                'gpu_worker',
                enable_modules=[__name__],
                debug_mode=True
            )
            logger.info('portal opened')

            async with worker_portal.open_context(
                open_fake_worker,
                start_algo='midj',
                mem_fraction=0.6
            ) as (ctx, _first):
                async with ctx.open_stream() as stream:
                    logger.info('opened worker sending req...')

                    request = ImageGenRequest(
                        prompt='a red tractor on a wheat field',
                        step=28,
                        width=512, height=512,
                        guidance=10, seed=None,
                        algo='midj', upscaler=None)

                    await stream.send(request)
                    logger.info('sent, await respnse')

                    # Only the first reply matters; stop after it.
                    async for reply in stream:
                        logger.info(f'got {reply}')
                        break

                    assert reply == 'hello!'

                    # Falsy message makes the fake worker exit its loop.
                    await stream.send(None)
                    logger.info('done.')

            await worker_portal.cancel_actor()

    trio.run(_session)
def test_gpu_two_workers():
    """Send one request to each of two ``open_fake_worker`` actors.

    Opens a two-actor cluster with this module enabled, enters an
    ``open_fake_worker`` context on every portal concurrently, then runs
    one task per worker that sends the same ``ImageGenRequest`` and records
    the first reply.  The test passes when both workers have replied.
    """
    async def main():
        # Replies are stored by worker index and may arrive in any order,
        # so this is a dict: index assignment on an empty list would raise
        # IndexError.
        outputs = {}

        async with (
            tractor.open_actor_cluster(
                modules=[__name__],
                count=2,
                names=[0, 1]) as portal_map,
            tractor.trionics.gather_contexts((
                portal.open_context(
                    open_fake_worker,
                    start_algo='midj',
                    mem_fraction=0.333)
                for portal in portal_map.values()
            )) as contexts,
            trio.open_nursery() as n
        ):
            ireq = ImageGenRequest(
                prompt='a red tractor on a wheat field',
                step=28,
                width=512, height=512,
                guidance=10, seed=None,
                algo='midj', upscaler=None)

            async def get_img(i):
                # Each gathered entry is the `(ctx, first)` pair yielded by
                # `portal.open_context`, not a bare context.
                ctx, _ = contexts[i]
                async with ctx.open_stream() as stream:
                    await stream.send(ireq)
                    async for img in stream:
                        outputs[i] = img
                        # The fake worker sends one reply per request and
                        # then waits for the next one; without this break
                        # the loop would never finish.
                        break

                # NOTE(review): cancelling while the gathered context is
                # still open mirrors the original call - confirm this is
                # the intended teardown order.
                await portal_map[i].cancel_actor()

            n.start_soon(get_img, 0)
            n.start_soon(get_img, 1)

        # The task nursery has exited here, so both tasks have completed.
        assert len(outputs) == 2

    trio.run(main)
def test_dgpu_simple():
async def main():
async with trio.open_nursery() as n: