mirror of https://github.com/skygpu/skynet.git
Make gpu work cancellable using trio threading apis!, also make docker always reinstall package for easier development

parent 47d9f59dbe
commit 01c78b5d20
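The change follows a single pattern: the blocking diffusers call is pushed onto a worker thread with trio.to_thread.run_sync, and from inside that thread the per-step callback asks the daemon whether to abort via trio.from_thread.run, which may only be called from threads that trio itself spawned. A minimal, self-contained sketch of that round trip (blocking_inference and its loop are illustrative stand-ins, not skynet code):

import trio


async def should_cancel_work() -> bool:
    # Stand-in for SkynetDGPUDaemon.should_cancel_work(); the real check
    # looks at competitor statuses for the request currently being worked.
    return False


def blocking_inference(should_cancel, steps: int = 5) -> str:
    # Runs inside the worker thread; the commit hooks this check into the
    # diffusers per-step callback instead of a plain loop like this one.
    for step in range(steps):
        if trio.from_thread.run(should_cancel):
            raise RuntimeError(f'cancelled at step {step}')
    return 'done'


async def main():
    result = await trio.to_thread.run_sync(blocking_inference, should_cancel_work)
    print(result)


trio.run(main)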
Docker startup script (always run poetry install on container start):

@@ -3,4 +3,6 @@
 export VIRTUAL_ENV='/skynet/.venv'
 poetry env use $VIRTUAL_ENV/bin/python
 
+poetry install
+
 exec poetry run "$@"
Skynet memory manager (class SkynetMM):

@@ -3,12 +3,15 @@
 # Skynet Memory Manager
 
 import gc
-from hashlib import sha256
 import json
 import logging
 
+from hashlib import sha256
 from diffusers import DiffusionPipeline
 
+import trio
 import torch
 
 from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
 from skynet.dgpu.errors import DGPUComputeError
 
@@ -122,10 +125,17 @@ class SkynetMM:
 
     def compute_one(
         self,
+        should_cancel_work,
         method: str,
         params: dict,
         binary: bytes | None = None
     ):
+        def callback_fn(step: int, timestep: int, latents: torch.FloatTensor):
+            should_raise = trio.from_thread.run(should_cancel_work)
+            if should_raise:
+                logging.warn(f'cancelling work at step {step}')
+                raise DGPUComputeError('Inference cancelled')
+
         try:
             match method:
                 case 'diffuse':
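callback_fn above gets wired into the pipeline call in the next hunk through the older diffusers callback / callback_steps hooks (newer releases moved to callback_on_step_end): the pipeline invokes the callback synchronously on every denoising step, so an exception raised there unwinds straight out of the pipe(...) call and aborts generation. A hedged sketch of that contract; the model id, prompt and InferenceCancelled are illustrative, while the commit itself raises DGPUComputeError inside SkynetMM.compute_one:

import torch
from diffusers import DiffusionPipeline


class InferenceCancelled(Exception):
    pass


def make_callback(should_stop):
    # Matches the callback(step, timestep, latents) signature used by the
    # pre-callback_on_step_end diffusers API; invoked every callback_steps
    # steps inside the denoising loop.
    def callback_fn(step: int, timestep: int, latents: torch.FloatTensor):
        if should_stop():
            raise InferenceCancelled(f'cancelled at step {step}')
    return callback_fn


pipe = DiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5')

try:
    image = pipe(
        'an astronaut riding a horse',
        num_inference_steps=20,
        callback=make_callback(lambda: False),  # swap in a real predicate
        callback_steps=1,
    ).images[0]
except InferenceCancelled:
    image = None  # partial latents are simply discarded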
@@ -140,6 +150,8 @@ class SkynetMM:
                         guidance_scale=guidance,
                         num_inference_steps=step,
                         generator=seed,
+                        callback=callback_fn,
+                        callback_steps=1,
                         **extra_params
                     ).images[0]
 
DGPU daemon (class SkynetDGPUDaemon):

@@ -5,6 +5,7 @@ import logging
 import traceback
 
 from hashlib import sha256
+from functools import partial
 
 import trio
 
@@ -26,6 +27,16 @@ class SkynetDGPUDaemon:
             config['auto_withdraw']
             if 'auto_withdraw' in config else False
         )
+        self.non_compete = set(('testworker2', 'animus2.boid', 'animus1.boid'))
+        self.current_request = None
+
+    async def should_cancel_work(self):
+        competitors = set((
+            status['worker']
+            for status in
+            (await self.conn.get_status_by_request_id(self.current_request))
+        ))
+        return self.non_compete & competitors
 
     async def serve_forever(self):
         try:
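should_cancel_work() intersects the hard-coded non_compete set with the workers that have already posted a status for the request this daemon is working on; it returns the resulting set, and a non-empty set is truthy, which is what the "if should_raise:" check in callback_fn relies on. A small sketch of that set logic with hypothetical status rows (the worker names are the ones from the diff):

non_compete = {'testworker2', 'animus2.boid', 'animus1.boid'}

# Hypothetical rows shaped like the output of get_status_by_request_id(...)
statuses = [
    {'worker': 'animus1.boid', 'status': 'began work'},
    {'worker': 'unrelated.worker', 'status': 'began work'},
]

competitors = {status['worker'] for status in statuses}
overlap = non_compete & competitors   # {'animus1.boid'}

print(bool(overlap))  # True -> the in-flight inference gets cancelled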
@@ -43,7 +54,7 @@ class SkynetDGPUDaemon:
                         statuses = await self.conn.get_status_by_request_id(rid)
 
                         if len(statuses) == 0:
+                            self.current_request = rid
                             # parse request
                             body = json.loads(req['body'])
 
@@ -70,8 +81,13 @@ class SkynetDGPUDaemon:
 
                             else:
                                 try:
-                                    img_sha, img_raw = self.mm.compute_one(
-                                        body['method'], body['params'], binary=binary)
+                                    img_sha, img_raw = await trio.to_thread.run_sync(
+                                        partial(
+                                            self.mm.compute_one,
+                                            self.should_cancel_work,
+                                            body['method'], body['params'], binary=binary
+                                        )
+                                    )
 
                                     ipfs_hash = await self.conn.publish_on_ipfs(img_raw)
 
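On the daemon side the blocking SkynetMM.compute_one call is shipped to a worker thread. functools.partial is what carries the keyword argument binary=binary across, because trio.to_thread.run_sync only forwards positional arguments to the target (extra keywords are treated as its own options). A self-contained sketch of that shape, with blocking_job standing in for compute_one:

from functools import partial

import trio


def blocking_job(method: str, params: dict, binary: bytes | None = None) -> str:
    # Stand-in for SkynetMM.compute_one; runs on trio's worker thread.
    return f'{method} params={sorted(params)} got_binary={binary is not None}'


async def main():
    result = await trio.to_thread.run_sync(
        partial(blocking_job, 'diffuse', {'prompt': 'a red cube'}, binary=b'\x00')
    )
    print(result)  # diffuse params=['prompt'] got_binary=True


trio.run(main)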