mirror of https://github.com/skygpu/skynet.git

commit 20b377dd32 (pull/25/head)
parent 66d997c039

Add gpu worker tests to debug tractor lifecycle stuff also logging

Minor tweak to cuda docker image
Pin tractor branch
Add triton to cuda reqs

@@ -5,7 +5,7 @@ env DEBIAN_FRONTEND=noninteractive
 workdir /skynet
 
-copy requirements.* .
+copy requirements.* ./
 
 run pip install -U pip ninja
 run pip install -r requirements.cuda.0.txt

@@ -1,5 +1,6 @@
 pdbpp
 scipy
+triton
 accelerate
 transformers
 huggingface_hub

@@ -5,4 +5,4 @@ aiohttp
 msgspec
 trio_asyncio
 
-git+https://github.com/goodboy/tractor.git@master#egg=tractor
+git+https://github.com/goodboy/tractor.git@piker_pin#egg=tractor

@@ -93,8 +93,9 @@ async def open_dgpu_node(
         ):
+            logging.info(f'starting {dgpu_max_tasks} gpu workers')
             async with tractor.gather_contexts((
-                ctx.open_context(
+                portal.open_context(
                     open_gpu_worker, algo, 1.0 / dgpu_max_tasks)
                 for portal in portal_map.values()
             )) as contexts:
                 contexts = {i: ctx for i, ctx in enumerate(contexts)}
                 for i, ctx in contexts.items():
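
The fix above swaps ctx.open_context( for portal.open_context(: open_context() is a method on the Portal handles yielded by tractor.open_actor_cluster(), which is also how the new tests below call into the fake worker. A minimal sketch of that call shape follows; it assumes open_gpu_worker is importable from skynet_bot.gpu and passes its arguments by keyword, and it is illustrative only, not the committed open_dgpu_node body.

import tractor
from skynet_bot.gpu import open_gpu_worker   # worker entrypoint edited in the next hunk

async def open_workers(algo: str, dgpu_max_tasks: int):
    # sketch: one sub-actor per worker slot, one open_gpu_worker context per actor
    async with (
        tractor.open_actor_cluster(
            modules=['skynet_bot.gpu'],      # assumed module path for the worker code
            count=dgpu_max_tasks,
        ) as portal_map,                     # actor name -> Portal
        tractor.trionics.gather_contexts((
            portal.open_context(
                open_gpu_worker,
                start_algo=algo,
                mem_fraction=1.0 / dgpu_max_tasks,
            )
            for portal in portal_map.values()
        )) as contexts,                      # one (ctx, first_value) pair per worker
    ):
        ctx_map = {i: ctx for i, (ctx, _first) in enumerate(contexts)}
        ...  # dispatch requests over ctx_map while the worker contexts stay open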

@@ -45,11 +45,13 @@ async def open_gpu_worker(
     start_algo: str,
     mem_fraction: float
 ):
+    log = tractor.log.get_logger(name='gpu', _root_name='skynet')
+    log.info(f'starting gpu worker with algo {start_algo}...')
     current_algo = start_algo
     with torch.no_grad():
         pipe = pipeline_for(current_algo, mem_fraction)
         log.info('pipeline loaded')
         await ctx.started()
 
         async with ctx.open_stream() as bus:
             async for ireq in bus:
                 if ireq.algo != current_algo:
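
The tractor lifecycle this worker relies on: the callee must call await ctx.started() before the caller's portal.open_context() block yields, and only then do both sides exchange messages over a stream, with None used as a shutdown sentinel (the same convention the tests below use). A self-contained sketch of that round trip, with echo_worker standing in for the real pipeline-backed worker:

import trio
import tractor

@tractor.context
async def echo_worker(ctx: tractor.Context):
    await ctx.started()                    # unblocks the caller's open_context()
    async with ctx.open_stream() as bus:
        async for msg in bus:
            if msg is None:                # shutdown sentinel, as in the tests below
                break
            await bus.send(msg)

async def main():
    async with tractor.open_nursery() as an:
        portal = await an.start_actor('echo', enable_modules=[__name__])
        async with (
            portal.open_context(echo_worker) as (ctx, _first),
            ctx.open_stream() as stream,
        ):
            await stream.send('ping')
            async for reply in stream:
                assert reply == 'ping'
                break
            await stream.send(None)        # let echo_worker's loop exit
        await portal.cancel_actor()

if __name__ == '__main__':
    trio.run(main)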

test.sh (3 changed lines)

@@ -1,8 +1,9 @@
 docker run \
     -it \
     --rm \
     --gpus=all \
+    --mount type=bind,source="$(pwd)",target=/skynet \
     skynet:runtime-cuda \
     bash -c \
         "cd /skynet && pip install -e . && \
-        pytest tests/test_dgpu.py --log-cli-level=info"
+        pytest $1 --log-cli-level=info"
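
With the checkout bind-mounted and the test path passed through as $1, the target can now be chosen per run, e.g. bash test.sh tests/test_dgpu.py to run the dgpu tests inside the skynet:runtime-cuda container.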

@@ -6,8 +6,10 @@ import logging
 
 import trio
 import pynng
+import tractor
 import trio_asyncio
 
+from skynet_bot.gpu import open_gpu_worker
 from skynet_bot.dgpu import open_dgpu_node
 from skynet_bot.types import *
 from skynet_bot.brain import run_skynet

@@ -15,6 +17,110 @@ from skynet_bot.constants import *
 from skynet_bot.frontend import open_skynet_rpc, rpc_call
 
+
+@tractor.context
+async def open_fake_worker(
+    ctx: tractor.Context,
+    start_algo: str,
+    mem_fraction: float
+):
+    log = tractor.log.get_logger(name='gpu', _root_name='skynet')
+    log.info(f'starting gpu worker with algo {start_algo}...')
+    current_algo = start_algo
+    log.info('pipeline loaded')
+    await ctx.started()
+    async with ctx.open_stream() as bus:
+        async for ireq in bus:
+            if ireq:
+                await bus.send('hello!')
+            else:
+                break
+
+def test_gpu_worker():
+    log = tractor.log.get_logger(name='root', _root_name='skynet')
+    async def main():
+        async with (
+            tractor.open_nursery(debug_mode=True) as an,
+            trio.open_nursery() as n
+        ):
+            portal = await an.start_actor(
+                'gpu_worker',
+                enable_modules=[__name__],
+                debug_mode=True
+            )
+
+            log.info('portal opened')
+            async with (
+                portal.open_context(
+                    open_fake_worker,
+                    start_algo='midj',
+                    mem_fraction=0.6
+                ) as (ctx, _),
+                ctx.open_stream() as stream,
+            ):
+                log.info('opened worker sending req...')
+                ireq = ImageGenRequest(
+                    prompt='a red tractor on a wheat field',
+                    step=28,
+                    width=512, height=512,
+                    guidance=10, seed=None,
+                    algo='midj', upscaler=None)
+
+                await stream.send(ireq)
+                log.info('sent, await respnse')
+                async for msg in stream:
+                    log.info(f'got {msg}')
+                    break
+
+                assert msg == 'hello!'
+                await stream.send(None)
+                log.info('done.')
+
+            await portal.cancel_actor()
+
+    trio.run(main)
+
+
+def test_gpu_two_workers():
+    async def main():
+        outputs = []
+        async with (
+            tractor.open_actor_cluster(
+                modules=[__name__],
+                count=2,
+                names=[0, 1]) as portal_map,
+            tractor.trionics.gather_contexts((
+                portal.open_context(
+                    open_fake_worker,
+                    start_algo='midj',
+                    mem_fraction=0.333)
+                for portal in portal_map.values()
+            )) as contexts,
+            trio.open_nursery() as n
+        ):
+            ireq = ImageGenRequest(
+                prompt='a red tractor on a wheat field',
+                step=28,
+                width=512, height=512,
+                guidance=10, seed=None,
+                algo='midj', upscaler=None)
+
+            async def get_img(i):
+                ctx = contexts[i]
+                async with ctx.open_stream() as stream:
+                    await stream.send(ireq)
+                    async for img in stream:
+                        outputs[i] = img
+                        await portal_map[i].cancel_actor()
+
+            n.start_soon(get_img, 0)
+            n.start_soon(get_img, 1)
+
+
+        assert len(outputs) == 2
+
+    trio.run(main)
+
 
 def test_dgpu_simple():
     async def main():
         async with trio.open_nursery() as n:
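
For the collection step that test_gpu_two_workers sketches, a variant of get_img is shown below; it assumes the same contexts, portal_map, ireq and nursery n set up in that test, stores results in a dict keyed by worker index so each task can assign its result directly, and unpacks each gathered item as the (ctx, first_value) pair that portal.open_context() yields. A hypothetical helper, not part of the commit:

outputs: dict[int, object] = {}

async def get_img(i: int):
    ctx, _first = contexts[i]              # gather_contexts() yields (ctx, first) pairs
    async with ctx.open_stream() as stream:
        await stream.send(ireq)
        async for img in stream:
            outputs[i] = img
            break
        await stream.send(None)            # let open_fake_worker's loop break
    await portal_map[i].cancel_actor()

n.start_soon(get_img, 0)
n.start_soon(get_img, 1)

# once the nursery has drained both tasks:
# assert len(outputs) == 2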