mirror of https://github.com/skygpu/skynet.git
Move gpu lifecycle tests to its own file to make torch import optional
parent
139aea67b1
commit
c3852314a7
|
@ -17,110 +17,6 @@ from skynet_bot.constants import *
|
|||
from skynet_bot.frontend import open_skynet_rpc, rpc_call
|
||||
|
||||
|
||||
@tractor.context
|
||||
async def open_fake_worker(
|
||||
ctx: tractor.Context,
|
||||
start_algo: str,
|
||||
mem_fraction: float
|
||||
):
|
||||
log = tractor.log.get_logger(name='gpu', _root_name='skynet')
|
||||
log.info(f'starting gpu worker with algo {start_algo}...')
|
||||
current_algo = start_algo
|
||||
log.info('pipeline loaded')
|
||||
await ctx.started()
|
||||
async with ctx.open_stream() as bus:
|
||||
async for ireq in bus:
|
||||
if ireq:
|
||||
await bus.send('hello!')
|
||||
else:
|
||||
break
|
||||
|
||||
def test_gpu_worker():
|
||||
log = tractor.log.get_logger(name='root', _root_name='skynet')
|
||||
async def main():
|
||||
async with (
|
||||
tractor.open_nursery(debug_mode=True) as an,
|
||||
trio.open_nursery() as n
|
||||
):
|
||||
portal = await an.start_actor(
|
||||
'gpu_worker',
|
||||
enable_modules=[__name__],
|
||||
debug_mode=True
|
||||
)
|
||||
|
||||
log.info('portal opened')
|
||||
async with (
|
||||
portal.open_context(
|
||||
open_fake_worker,
|
||||
start_algo='midj',
|
||||
mem_fraction=0.6
|
||||
) as (ctx, _),
|
||||
ctx.open_stream() as stream,
|
||||
):
|
||||
log.info('opened worker sending req...')
|
||||
ireq = ImageGenRequest(
|
||||
prompt='a red tractor on a wheat field',
|
||||
step=28,
|
||||
width=512, height=512,
|
||||
guidance=10, seed=None,
|
||||
algo='midj', upscaler=None)
|
||||
|
||||
await stream.send(ireq)
|
||||
log.info('sent, await respnse')
|
||||
async for msg in stream:
|
||||
log.info(f'got {msg}')
|
||||
break
|
||||
|
||||
assert msg == 'hello!'
|
||||
await stream.send(None)
|
||||
log.info('done.')
|
||||
|
||||
await portal.cancel_actor()
|
||||
|
||||
trio.run(main)
|
||||
|
||||
|
||||
def test_gpu_two_workers():
|
||||
async def main():
|
||||
outputs = []
|
||||
async with (
|
||||
tractor.open_actor_cluster(
|
||||
modules=[__name__],
|
||||
count=2,
|
||||
names=[0, 1]) as portal_map,
|
||||
tractor.trionics.gather_contexts((
|
||||
portal.open_context(
|
||||
open_fake_worker,
|
||||
start_algo='midj',
|
||||
mem_fraction=0.333)
|
||||
for portal in portal_map.values()
|
||||
)) as contexts,
|
||||
trio.open_nursery() as n
|
||||
):
|
||||
ireq = ImageGenRequest(
|
||||
prompt='a red tractor on a wheat field',
|
||||
step=28,
|
||||
width=512, height=512,
|
||||
guidance=10, seed=None,
|
||||
algo='midj', upscaler=None)
|
||||
|
||||
async def get_img(i):
|
||||
ctx = contexts[i]
|
||||
async with ctx.open_stream() as stream:
|
||||
await stream.send(ireq)
|
||||
async for img in stream:
|
||||
outputs[i] = img
|
||||
await portal_map[i].cancel_actor()
|
||||
|
||||
n.start_soon(get_img, 0)
|
||||
n.start_soon(get_img, 1)
|
||||
|
||||
|
||||
assert len(outputs) == 2
|
||||
|
||||
trio.run(main)
|
||||
|
||||
|
||||
def test_dgpu_simple():
|
||||
async def main():
|
||||
async with trio.open_nursery() as n:
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
import trio
|
||||
import tractor
|
||||
|
||||
from skynet_bot.types import *
|
||||
|
||||
@tractor.context
|
||||
async def open_fake_worker(
|
||||
ctx: tractor.Context,
|
||||
start_algo: str,
|
||||
mem_fraction: float
|
||||
):
|
||||
log = tractor.log.get_logger(name='gpu', _root_name='skynet')
|
||||
log.info(f'starting gpu worker with algo {start_algo}...')
|
||||
current_algo = start_algo
|
||||
log.info('pipeline loaded')
|
||||
await ctx.started()
|
||||
async with ctx.open_stream() as bus:
|
||||
async for ireq in bus:
|
||||
if ireq:
|
||||
await bus.send('hello!')
|
||||
else:
|
||||
break
|
||||
|
||||
def test_gpu_worker():
|
||||
log = tractor.log.get_logger(name='root', _root_name='skynet')
|
||||
async def main():
|
||||
async with (
|
||||
tractor.open_nursery(debug_mode=True) as an,
|
||||
trio.open_nursery() as n
|
||||
):
|
||||
portal = await an.start_actor(
|
||||
'gpu_worker',
|
||||
enable_modules=[__name__],
|
||||
debug_mode=True
|
||||
)
|
||||
|
||||
log.info('portal opened')
|
||||
async with (
|
||||
portal.open_context(
|
||||
open_fake_worker,
|
||||
start_algo='midj',
|
||||
mem_fraction=0.6
|
||||
) as (ctx, _),
|
||||
ctx.open_stream() as stream,
|
||||
):
|
||||
log.info('opened worker sending req...')
|
||||
ireq = ImageGenRequest(
|
||||
prompt='a red tractor on a wheat field',
|
||||
step=28,
|
||||
width=512, height=512,
|
||||
guidance=10, seed=None,
|
||||
algo='midj', upscaler=None)
|
||||
|
||||
await stream.send(ireq)
|
||||
log.info('sent, await respnse')
|
||||
async for msg in stream:
|
||||
log.info(f'got {msg}')
|
||||
break
|
||||
|
||||
assert msg == 'hello!'
|
||||
await stream.send(None)
|
||||
log.info('done.')
|
||||
|
||||
await portal.cancel_actor()
|
||||
|
||||
trio.run(main)
|
||||
|
||||
|
||||
def test_gpu_two_workers():
|
||||
async def main():
|
||||
outputs = []
|
||||
async with (
|
||||
tractor.open_actor_cluster(
|
||||
modules=[__name__],
|
||||
count=2,
|
||||
names=[0, 1]) as portal_map,
|
||||
tractor.trionics.gather_contexts((
|
||||
portal.open_context(
|
||||
open_fake_worker,
|
||||
start_algo='midj',
|
||||
mem_fraction=0.333)
|
||||
for portal in portal_map.values()
|
||||
)) as contexts,
|
||||
trio.open_nursery() as n
|
||||
):
|
||||
ireq = ImageGenRequest(
|
||||
prompt='a red tractor on a wheat field',
|
||||
step=28,
|
||||
width=512, height=512,
|
||||
guidance=10, seed=None,
|
||||
algo='midj', upscaler=None)
|
||||
|
||||
async def get_img(i):
|
||||
ctx = contexts[i]
|
||||
async with ctx.open_stream() as stream:
|
||||
await stream.send(ireq)
|
||||
async for img in stream:
|
||||
outputs[i] = img
|
||||
await portal_map[i].cancel_actor()
|
||||
|
||||
n.start_soon(get_img, 0)
|
||||
n.start_soon(get_img, 1)
|
||||
|
||||
|
||||
assert len(outputs) == 2
|
||||
|
||||
trio.run(main)
|
Loading…
Reference in New Issue