mirror of https://github.com/skygpu/skynet.git
				
				
				
			Fix cli entrypoints to use new config, improve competitor cancel logic and add default docker image to py311 image
							parent
							
								
									5437af4d05
								
							
						
					
					
						commit
						cc4a4b5189
					
				| 
						 | 
					@ -5,3 +5,7 @@ docker build \
 | 
				
			||||||
docker build \
 | 
					docker build \
 | 
				
			||||||
    -t guilledk/skynet:runtime-cuda-py311 \
 | 
					    -t guilledk/skynet:runtime-cuda-py311 \
 | 
				
			||||||
    -f docker/Dockerfile.runtime+cuda-py311 .
 | 
					    -f docker/Dockerfile.runtime+cuda-py311 .
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					docker build \
 | 
				
			||||||
 | 
					    -t guilledk/skynet:runtime-cuda \
 | 
				
			||||||
 | 
					    -f docker/Dockerfile.runtime+cuda-py311 .
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -33,8 +33,8 @@ def txt2img(*args, **kwargs):
 | 
				
			||||||
    from . import utils
 | 
					    from . import utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    config = load_skynet_toml()
 | 
					    config = load_skynet_toml()
 | 
				
			||||||
    hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
 | 
					    hf_token = load_key(config, 'skynet.dgpu.hf_token')
 | 
				
			||||||
    hf_home = load_key(config, 'skynet.dgpu', 'hf_home')
 | 
					    hf_home = load_key(config, 'skynet.dgpu.hf_home')
 | 
				
			||||||
    set_hf_vars(hf_token, hf_home)
 | 
					    set_hf_vars(hf_token, hf_home)
 | 
				
			||||||
    utils.txt2img(hf_token, **kwargs)
 | 
					    utils.txt2img(hf_token, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -51,8 +51,8 @@ def txt2img(*args, **kwargs):
 | 
				
			||||||
def img2img(model, prompt, input, output, strength, guidance, steps, seed):
 | 
					def img2img(model, prompt, input, output, strength, guidance, steps, seed):
 | 
				
			||||||
    from . import utils
 | 
					    from . import utils
 | 
				
			||||||
    config = load_skynet_toml()
 | 
					    config = load_skynet_toml()
 | 
				
			||||||
    hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
 | 
					    hf_token = load_key(config, 'skynet.dgpu.hf_token')
 | 
				
			||||||
    hf_home = load_key(config, 'skynet.dgpu', 'hf_home')
 | 
					    hf_home = load_key(config, 'skynet.dgpu.hf_home')
 | 
				
			||||||
    set_hf_vars(hf_token, hf_home)
 | 
					    set_hf_vars(hf_token, hf_home)
 | 
				
			||||||
    utils.img2img(
 | 
					    utils.img2img(
 | 
				
			||||||
        hf_token,
 | 
					        hf_token,
 | 
				
			||||||
| 
						 | 
					@ -82,8 +82,8 @@ def upscale(input, output, model):
 | 
				
			||||||
def download():
 | 
					def download():
 | 
				
			||||||
    from . import utils
 | 
					    from . import utils
 | 
				
			||||||
    config = load_skynet_toml()
 | 
					    config = load_skynet_toml()
 | 
				
			||||||
    hf_token = load_key(config, 'skynet.dgpu', 'hf_token')
 | 
					    hf_token = load_key(config, 'skynet.dgpu.hf_token')
 | 
				
			||||||
    hf_home = load_key(config, 'skynet.dgpu', 'hf_home')
 | 
					    hf_home = load_key(config, 'skynet.dgpu.hf_home')
 | 
				
			||||||
    set_hf_vars(hf_token, hf_home)
 | 
					    set_hf_vars(hf_token, hf_home)
 | 
				
			||||||
    utils.download_all_models(hf_token)
 | 
					    utils.download_all_models(hf_token)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -112,10 +112,10 @@ def enqueue(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    config = load_skynet_toml()
 | 
					    config = load_skynet_toml()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    key = load_key(config, 'skynet.user', 'key')
 | 
					    key = load_key(config, 'skynet.user.key')
 | 
				
			||||||
    account = load_key(config, 'skynet.user', 'account')
 | 
					    account = load_key(config, 'skynet.user.account')
 | 
				
			||||||
    permission = load_key(config, 'skynet.user', 'permission')
 | 
					    permission = load_key(config, 'skynet.user.permission')
 | 
				
			||||||
    node_url = load_key(config, 'skynet.user', 'node_url')
 | 
					    node_url = load_key(config, 'skynet.user.node_url')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cleos = CLEOS(None, None, url=node_url, remote=node_url)
 | 
					    cleos = CLEOS(None, None, url=node_url, remote=node_url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -156,10 +156,10 @@ def clean(
 | 
				
			||||||
    from leap.cleos import CLEOS
 | 
					    from leap.cleos import CLEOS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    config = load_skynet_toml()
 | 
					    config = load_skynet_toml()
 | 
				
			||||||
    key = load_key(config, 'skynet.user', 'key')
 | 
					    key = load_key(config, 'skynet.user.key')
 | 
				
			||||||
    account = load_key(config, 'skynet.user', 'account')
 | 
					    account = load_key(config, 'skynet.user.account')
 | 
				
			||||||
    permission = load_key(config, 'skynet.user', 'permission')
 | 
					    permission = load_key(config, 'skynet.user.permission')
 | 
				
			||||||
    node_url = load_key(config, 'skynet.user', 'node_url')
 | 
					    node_url = load_key(config, 'skynet.user.node_url')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    logging.basicConfig(level=loglevel)
 | 
					    logging.basicConfig(level=loglevel)
 | 
				
			||||||
    cleos = CLEOS(None, None, url=node_url, remote=node_url)
 | 
					    cleos = CLEOS(None, None, url=node_url, remote=node_url)
 | 
				
			||||||
| 
						 | 
					@ -177,7 +177,7 @@ def clean(
 | 
				
			||||||
def queue():
 | 
					def queue():
 | 
				
			||||||
    import requests
 | 
					    import requests
 | 
				
			||||||
    config = load_skynet_toml()
 | 
					    config = load_skynet_toml()
 | 
				
			||||||
    node_url = load_key(config, 'skynet.user', 'node_url')
 | 
					    node_url = load_key(config, 'skynet.user.node_url')
 | 
				
			||||||
    resp = requests.post(
 | 
					    resp = requests.post(
 | 
				
			||||||
        f'{node_url}/v1/chain/get_table_rows',
 | 
					        f'{node_url}/v1/chain/get_table_rows',
 | 
				
			||||||
        json={
 | 
					        json={
 | 
				
			||||||
| 
						 | 
					@ -194,7 +194,7 @@ def queue():
 | 
				
			||||||
def status(request_id: int):
 | 
					def status(request_id: int):
 | 
				
			||||||
    import requests
 | 
					    import requests
 | 
				
			||||||
    config = load_skynet_toml()
 | 
					    config = load_skynet_toml()
 | 
				
			||||||
    node_url = load_key(config, 'skynet.user', 'node_url')
 | 
					    node_url = load_key(config, 'skynet.user.node_url')
 | 
				
			||||||
    resp = requests.post(
 | 
					    resp = requests.post(
 | 
				
			||||||
        f'{node_url}/v1/chain/get_table_rows',
 | 
					        f'{node_url}/v1/chain/get_table_rows',
 | 
				
			||||||
        json={
 | 
					        json={
 | 
				
			||||||
| 
						 | 
					@ -213,10 +213,10 @@ def dequeue(request_id: int):
 | 
				
			||||||
    from leap.cleos import CLEOS
 | 
					    from leap.cleos import CLEOS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    config = load_skynet_toml()
 | 
					    config = load_skynet_toml()
 | 
				
			||||||
    key = load_key(config, 'skynet.user', 'key')
 | 
					    key = load_key(config, 'skynet.user.key')
 | 
				
			||||||
    account = load_key(config, 'skynet.user', 'account')
 | 
					    account = load_key(config, 'skynet.user.account')
 | 
				
			||||||
    permission = load_key(config, 'skynet.user', 'permission')
 | 
					    permission = load_key(config, 'skynet.user.permission')
 | 
				
			||||||
    node_url = load_key(config, 'skynet.user', 'node_url')
 | 
					    node_url = load_key(config, 'skynet.user.node_url')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cleos = CLEOS(None, None, url=node_url, remote=node_url)
 | 
					    cleos = CLEOS(None, None, url=node_url, remote=node_url)
 | 
				
			||||||
    res = trio.run(
 | 
					    res = trio.run(
 | 
				
			||||||
| 
						 | 
					@ -248,10 +248,10 @@ def config(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    config = load_skynet_toml()
 | 
					    config = load_skynet_toml()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    key = load_key(config, 'skynet.user', 'key')
 | 
					    key = load_key(config, 'skynet.user.key')
 | 
				
			||||||
    account = load_key(config, 'skynet.user', 'account')
 | 
					    account = load_key(config, 'skynet.user.account')
 | 
				
			||||||
    permission = load_key(config, 'skynet.user', 'permission')
 | 
					    permission = load_key(config, 'skynet.user.permission')
 | 
				
			||||||
    node_url = load_key(config, 'skynet.user', 'node_url')
 | 
					    node_url = load_key(config, 'skynet.user.node_url')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cleos = CLEOS(None, None, url=node_url, remote=node_url)
 | 
					    cleos = CLEOS(None, None, url=node_url, remote=node_url)
 | 
				
			||||||
    res = trio.run(
 | 
					    res = trio.run(
 | 
				
			||||||
| 
						 | 
					@ -277,10 +277,10 @@ def deposit(quantity: str):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    config = load_skynet_toml()
 | 
					    config = load_skynet_toml()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    key = load_key(config, 'skynet.user', 'key')
 | 
					    key = load_key(config, 'skynet.user.key')
 | 
				
			||||||
    account = load_key(config, 'skynet.user', 'account')
 | 
					    account = load_key(config, 'skynet.user.account')
 | 
				
			||||||
    permission = load_key(config, 'skynet.user', 'permission')
 | 
					    permission = load_key(config, 'skynet.user.permission')
 | 
				
			||||||
    node_url = load_key(config, 'skynet.user', 'node_url')
 | 
					    node_url = load_key(config, 'skynet.user.node_url')
 | 
				
			||||||
    cleos = CLEOS(None, None, url=node_url, remote=node_url)
 | 
					    cleos = CLEOS(None, None, url=node_url, remote=node_url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    res = trio.run(
 | 
					    res = trio.run(
 | 
				
			||||||
| 
						 | 
					@ -365,21 +365,21 @@ def telegram(
 | 
				
			||||||
    logging.basicConfig(level=loglevel)
 | 
					    logging.basicConfig(level=loglevel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    config = load_skynet_toml()
 | 
					    config = load_skynet_toml()
 | 
				
			||||||
    tg_token = load_key(config, 'skynet.telegram', 'tg_token')
 | 
					    tg_token = load_key(config, 'skynet.telegram.tg_token')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    key = load_key(config, 'skynet.telegram', 'key')
 | 
					    key = load_key(config, 'skynet.telegram.key')
 | 
				
			||||||
    account = load_key(config, 'skynet.telegram', 'account')
 | 
					    account = load_key(config, 'skynet.telegram.account')
 | 
				
			||||||
    permission = load_key(config, 'skynet.telegram', 'permission')
 | 
					    permission = load_key(config, 'skynet.telegram.permission')
 | 
				
			||||||
    node_url = load_key(config, 'skynet.telegram', 'node_url')
 | 
					    node_url = load_key(config, 'skynet.telegram.node_url')
 | 
				
			||||||
    hyperion_url = load_key(config, 'skynet.telegram', 'hyperion_url')
 | 
					    hyperion_url = load_key(config, 'skynet.telegram.hyperion_url')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        ipfs_gateway_url = load_key(config, 'skynet.telegram', 'ipfs_gateway_url')
 | 
					        ipfs_gateway_url = load_key(config, 'skynet.telegram.ipfs_gateway_url')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    except ConfigParsingError:
 | 
					    except ConfigParsingError:
 | 
				
			||||||
        ipfs_gateway_url = None
 | 
					        ipfs_gateway_url = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    ipfs_url = load_key(config, 'skynet.telegram', 'ipfs_url')
 | 
					    ipfs_url = load_key(config, 'skynet.telegram.ipfs_url')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def _async_main():
 | 
					    async def _async_main():
 | 
				
			||||||
        frontend = SkynetTelegramFrontend(
 | 
					        frontend = SkynetTelegramFrontend(
 | 
				
			||||||
| 
						 | 
					@ -421,16 +421,16 @@ def discord(
 | 
				
			||||||
    logging.basicConfig(level=loglevel)
 | 
					    logging.basicConfig(level=loglevel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    config = load_skynet_toml()
 | 
					    config = load_skynet_toml()
 | 
				
			||||||
    dc_token = load_key(config, 'skynet.discord', 'dc_token')
 | 
					    dc_token = load_key(config, 'skynet.discord.dc_token')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    key = load_key(config, 'skynet.discord', 'key')
 | 
					    key = load_key(config, 'skynet.discord.key')
 | 
				
			||||||
    account = load_key(config, 'skynet.discord', 'account')
 | 
					    account = load_key(config, 'skynet.discord.account')
 | 
				
			||||||
    permission = load_key(config, 'skynet.discord', 'permission')
 | 
					    permission = load_key(config, 'skynet.discord.permission')
 | 
				
			||||||
    node_url = load_key(config, 'skynet.discord', 'node_url')
 | 
					    node_url = load_key(config, 'skynet.discord.node_url')
 | 
				
			||||||
    hyperion_url = load_key(config, 'skynet.discord', 'hyperion_url')
 | 
					    hyperion_url = load_key(config, 'skynet.discord.hyperion_url')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    ipfs_gateway_url = load_key(config, 'skynet.discord', 'ipfs_gateway_url')
 | 
					    ipfs_gateway_url = load_key(config, 'skynet.discord.ipfs_gateway_url')
 | 
				
			||||||
    ipfs_url = load_key(config, 'skynet.discord', 'ipfs_url')
 | 
					    ipfs_url = load_key(config, 'skynet.discord.ipfs_url')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def _async_main():
 | 
					    async def _async_main():
 | 
				
			||||||
        frontend = SkynetDiscordFrontend(
 | 
					        frontend = SkynetDiscordFrontend(
 | 
				
			||||||
| 
						 | 
					@ -471,8 +471,8 @@ def pinner(loglevel):
 | 
				
			||||||
    from .ipfs.pinner import SkynetPinner
 | 
					    from .ipfs.pinner import SkynetPinner
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    config = load_skynet_toml()
 | 
					    config = load_skynet_toml()
 | 
				
			||||||
    hyperion_url = load_key(config, 'skynet.pinner', 'hyperion_url')
 | 
					    hyperion_url = load_key(config, 'skynet.pinner.hyperion_url')
 | 
				
			||||||
    ipfs_url = load_key(config, 'skynet.pinner', 'ipfs_url')
 | 
					    ipfs_url = load_key(config, 'skynet.pinner.ipfs_url')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    logging.basicConfig(level=loglevel)
 | 
					    logging.basicConfig(level=loglevel)
 | 
				
			||||||
    ipfs_node = AsyncIPFSHTTP(ipfs_url)
 | 
					    ipfs_node = AsyncIPFSHTTP(ipfs_url)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -13,7 +13,7 @@ import trio
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
 | 
					from skynet.constants import DEFAULT_INITAL_MODELS, MODELS
 | 
				
			||||||
from skynet.dgpu.errors import DGPUComputeError
 | 
					from skynet.dgpu.errors import DGPUComputeError, DGPUInferenceCancelled
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
 | 
					from skynet.utils import convert_from_bytes_and_crop, convert_from_cv2_to_image, convert_from_image_to_cv2, convert_from_img_to_bytes, init_upscaler, pipeline_for
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -132,16 +132,19 @@ class SkynetMM:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def compute_one(
 | 
					    def compute_one(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
 | 
					        request_id: int,
 | 
				
			||||||
        should_cancel_work,
 | 
					        should_cancel_work,
 | 
				
			||||||
        method: str,
 | 
					        method: str,
 | 
				
			||||||
        params: dict,
 | 
					        params: dict,
 | 
				
			||||||
        binary: bytes | None = None
 | 
					        binary: bytes | None = None
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        def callback_fn(step: int, timestep: int, latents: torch.FloatTensor):
 | 
					        def maybe_cancel_work(step, *args, **kwargs):
 | 
				
			||||||
            should_raise = trio.from_thread.run(should_cancel_work)
 | 
					            should_raise = trio.from_thread.run(should_cancel_work, request_id)
 | 
				
			||||||
            if should_raise:
 | 
					            if should_raise:
 | 
				
			||||||
                logging.warn(f'cancelling work at step {step}')
 | 
					                logging.warn(f'cancelling work at step {step}')
 | 
				
			||||||
                raise DGPUComputeError('Inference cancelled')
 | 
					                raise DGPUInferenceCancelled()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        maybe_cancel_work(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            match method:
 | 
					            match method:
 | 
				
			||||||
| 
						 | 
					@ -157,7 +160,7 @@ class SkynetMM:
 | 
				
			||||||
                        guidance_scale=guidance,
 | 
					                        guidance_scale=guidance,
 | 
				
			||||||
                        num_inference_steps=step,
 | 
					                        num_inference_steps=step,
 | 
				
			||||||
                        generator=seed,
 | 
					                        generator=seed,
 | 
				
			||||||
                        callback=callback_fn,
 | 
					                        callback=maybe_cancel_work,
 | 
				
			||||||
                        callback_steps=2,
 | 
					                        callback_steps=2,
 | 
				
			||||||
                        **extra_params
 | 
					                        **extra_params
 | 
				
			||||||
                    ).images[0]
 | 
					                    ).images[0]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -40,15 +40,9 @@ class SkynetDGPUDaemon:
 | 
				
			||||||
        if 'model_blacklist' in config:
 | 
					        if 'model_blacklist' in config:
 | 
				
			||||||
            self.model_blacklist = set(config['model_blacklist'])
 | 
					            self.model_blacklist = set(config['model_blacklist'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.current_request = None
 | 
					    async def should_cancel_work(self, request_id: int):
 | 
				
			||||||
 | 
					        competitors = await self.conn.get_competitors_for_req(request_id)
 | 
				
			||||||
    async def should_cancel_work(self):
 | 
					        return bool(self.non_compete & competitors)
 | 
				
			||||||
        competitors = set((
 | 
					 | 
				
			||||||
            status['worker']
 | 
					 | 
				
			||||||
            for status in
 | 
					 | 
				
			||||||
            (await self.conn.get_status_by_request_id(self.current_request))
 | 
					 | 
				
			||||||
        ))
 | 
					 | 
				
			||||||
        return self.non_compete & competitors
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def serve_forever(self):
 | 
					    async def serve_forever(self):
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
| 
						 | 
					@ -79,7 +73,7 @@ class SkynetDGPUDaemon:
 | 
				
			||||||
                        statuses = await self.conn.get_status_by_request_id(rid)
 | 
					                        statuses = await self.conn.get_status_by_request_id(rid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        if len(statuses) == 0:
 | 
					                        if len(statuses) == 0:
 | 
				
			||||||
                            self.current_request = rid
 | 
					                            self.conn.monitor_request(rid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                            binary = await self.conn.get_input_data(req['binary_data'])
 | 
					                            binary = await self.conn.get_input_data(req['binary_data'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -107,6 +101,7 @@ class SkynetDGPUDaemon:
 | 
				
			||||||
                                    img_sha, img_raw = await trio.to_thread.run_sync(
 | 
					                                    img_sha, img_raw = await trio.to_thread.run_sync(
 | 
				
			||||||
                                        partial(
 | 
					                                        partial(
 | 
				
			||||||
                                            self.mm.compute_one,
 | 
					                                            self.mm.compute_one,
 | 
				
			||||||
 | 
					                                            rid,
 | 
				
			||||||
                                            self.should_cancel_work,
 | 
					                                            self.should_cancel_work,
 | 
				
			||||||
                                            body['method'], body['params'], binary=binary
 | 
					                                            body['method'], body['params'], binary=binary
 | 
				
			||||||
                                        )
 | 
					                                        )
 | 
				
			||||||
| 
						 | 
					@ -115,11 +110,13 @@ class SkynetDGPUDaemon:
 | 
				
			||||||
                                    ipfs_hash = await self.conn.publish_on_ipfs(img_raw)
 | 
					                                    ipfs_hash = await self.conn.publish_on_ipfs(img_raw)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                                    await self.conn.submit_work(rid, request_hash, img_sha, ipfs_hash)
 | 
					                                    await self.conn.submit_work(rid, request_hash, img_sha, ipfs_hash)
 | 
				
			||||||
                                    break
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                                except BaseException as e:
 | 
					                                except BaseException as e:
 | 
				
			||||||
                                    traceback.print_exc()
 | 
					                                    traceback.print_exc()
 | 
				
			||||||
                                    await self.conn.cancel_work(rid, str(e))
 | 
					                                    await self.conn.cancel_work(rid, str(e))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                                finally:
 | 
				
			||||||
 | 
					                                    self.conn.forget_request(rid)
 | 
				
			||||||
                                    break
 | 
					                                    break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    else:
 | 
					                    else:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,3 +3,6 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DGPUComputeError(BaseException):
 | 
					class DGPUComputeError(BaseException):
 | 
				
			||||||
    ...
 | 
					    ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DGPUInferenceCancelled(BaseException):
 | 
				
			||||||
 | 
					    ...
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,13 +1,15 @@
 | 
				
			||||||
#!/usr/bin/python
 | 
					#!/usr/bin/python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from functools import partial
 | 
					 | 
				
			||||||
import io
 | 
					import io
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
from pathlib import Path
 | 
					 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from pathlib import Path
 | 
				
			||||||
 | 
					from functools import partial
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import asks
 | 
					import asks
 | 
				
			||||||
 | 
					import trio
 | 
				
			||||||
import anyio
 | 
					import anyio
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from PIL import Image
 | 
					from PIL import Image
 | 
				
			||||||
| 
						 | 
					@ -20,6 +22,9 @@ from skynet.dgpu.errors import DGPUComputeError
 | 
				
			||||||
from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
 | 
					from skynet.ipfs import AsyncIPFSHTTP, get_ipfs_file
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					REQUEST_UPDATE_TIME = 3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def failable(fn: partial, ret_fail=None):
 | 
					async def failable(fn: partial, ret_fail=None):
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        return await fn()
 | 
					        return await fn()
 | 
				
			||||||
| 
						 | 
					@ -54,6 +59,8 @@ class SkynetGPUConnector:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.ipfs_client = AsyncIPFSHTTP(self.ipfs_url)
 | 
					        self.ipfs_client = AsyncIPFSHTTP(self.ipfs_url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._wip_requests = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # blockchain helpers
 | 
					    # blockchain helpers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def get_work_requests_last_hour(self):
 | 
					    async def get_work_requests_last_hour(self):
 | 
				
			||||||
| 
						 | 
					@ -103,6 +110,36 @@ class SkynetGPUConnector:
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def monitor_request(self, request_id: int):
 | 
				
			||||||
 | 
					        logging.info(f'begin monitoring request: {request_id}')
 | 
				
			||||||
 | 
					        self._wip_requests[request_id] = {
 | 
				
			||||||
 | 
					            'last_update': None,
 | 
				
			||||||
 | 
					            'competitors': set()
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def maybe_update_request(self, request_id: int):
 | 
				
			||||||
 | 
					        now = time.time()
 | 
				
			||||||
 | 
					        stats = self._wip_requests[request_id]
 | 
				
			||||||
 | 
					        if (not stats['last_update'] or
 | 
				
			||||||
 | 
					            (now - stats['last_update']) > REQUEST_UPDATE_TIME):
 | 
				
			||||||
 | 
					            stats['competitors'] = [
 | 
				
			||||||
 | 
					                status['worker']
 | 
				
			||||||
 | 
					                for status in
 | 
				
			||||||
 | 
					                (await self.get_status_by_request_id(request_id))
 | 
				
			||||||
 | 
					                if status['worker'] != self.account
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					            stats['last_update'] = now
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def get_competitors_for_req(self, request_id: int) -> set:
 | 
				
			||||||
 | 
					        await self.maybe_update_request(request_id)
 | 
				
			||||||
 | 
					        competitors = set(self._wip_requests[request_id]['competitors'])
 | 
				
			||||||
 | 
					        logging.info(f'competitors: {competitors}')
 | 
				
			||||||
 | 
					        return competitors
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forget_request(self, request_id: int):
 | 
				
			||||||
 | 
					        logging.info(f'end monitoring request: {request_id}')
 | 
				
			||||||
 | 
					        del self._wip_requests[request_id]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def begin_work(self, request_id: int):
 | 
					    async def begin_work(self, request_id: int):
 | 
				
			||||||
        logging.info('begin_work')
 | 
					        logging.info('begin_work')
 | 
				
			||||||
        return await failable(
 | 
					        return await failable(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue