mirror of https://github.com/skygpu/skynet.git
				
				
				
Rework if statement to reduce indentation, add comment about logic

parent e09652eaae
commit cc7015eb03
@@ -177,94 +177,91 @@ class WorkerDaemon:
-            return False
-
-        results = [res['id'] for res in self._snap['results']]
-        if (
-            rid not in results
-            and
-            rid in self._snap['requests']
-        ):
-            statuses = self._snap['requests'][rid]
-            if len(statuses) == 0:
-                inputs = []
-                for _input in req['binary_data'].split(','):
-                    if _input:
-                        for _ in range(3):
-                            try:
-                                # use `GPUConnector` to IO with
-                                # storage layer to seed the compute
-                                # task.
-                                img = await self.conn.get_input_data(_input)
-                                inputs.append(img)
-                                break
-
-                            except BaseException:
-                                logging.exception(
-                                    'Model input error !?!\n'
-                                )
-
-                hash_str = (
-                    str(req['nonce'])
-                    +
-                    req['body']
-                    +
-                    req['binary_data']
-                )
-                logging.info(f'hashing: {hash_str}')
-                request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
-
-                # TODO: validate request
-
-                # perform work
-                logging.info(f'working on {body}')
-
-                resp = await self.conn.begin_work(rid)
-                if not resp or 'code' in resp:
-                    logging.info('probably being worked on already... skip.')
-
-                else:
-                    try:
-                        output_type = 'png'
-                        if 'output_type' in body['params']:
-                            output_type = body['params']['output_type']
-
-                        output = None
-                        output_hash = None
-                        match self.backend:
-                            case 'sync-on-thread':
-                                self.mm._should_cancel = self.should_cancel_work
-                                output_hash, output = await trio.to_thread.run_sync(
-                                    partial(
-                                        self.mm.compute_one,
-                                        rid,
-                                        body['method'], body['params'],
-                                        inputs=inputs
-                                    )
-                                )
-
-                            case _:
-                                raise DGPUComputeError(
-                                    f'Unsupported backend {self.backend}'
-                                )
-
-                        self._last_generation_ts: str = datetime.now().isoformat()
-                        self._last_benchmark: list[float] = self._benchmark
-                        self._benchmark: list[float] = []
-
-                        ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
-
-                        await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
-
-                    except BaseException as err:
-                        logging.exception('Failed to serve model request !?\n')
-                        # traceback.print_exc()  # TODO? <- replaced by above ya?
-                        await self.conn.cancel_work(rid, str(err))
-
-                    finally:
-                        return True
-
-        # TODO, i would inverse this case logic to avoid an indent
-        # level in above block ;)
-        else:
+        # if worker is already on that request or
+        # if worker has a stale status for that request
+        if rid in results or rid not in self._snap['requests']:
+            logging.info(f'request {rid} already being worked on, skip...')
+            return
+
+        statuses = self._snap['requests'][rid]
+        if len(statuses) == 0:
+            inputs = []
+            for _input in req['binary_data'].split(','):
+                if _input:
+                    for _ in range(3):
+                        try:
+                            # use `GPUConnector` to IO with
+                            # storage layer to seed the compute
+                            # task.
+                            img = await self.conn.get_input_data(_input)
+                            inputs.append(img)
+                            break
+
+                        except BaseException:
+                            logging.exception(
+                                'Model input error !?!\n'
+                            )
+
+            hash_str = (
+                str(req['nonce'])
+                +
+                req['body']
+                +
+                req['binary_data']
+            )
+            logging.info(f'hashing: {hash_str}')
+            request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
+
+            # TODO: validate request
+
+            # perform work
+            logging.info(f'working on {body}')
+
+            resp = await self.conn.begin_work(rid)
+            if not resp or 'code' in resp:
+                logging.info('probably being worked on already... skip.')
+
+            else:
+                try:
+                    output_type = 'png'
+                    if 'output_type' in body['params']:
+                        output_type = body['params']['output_type']
+
+                    output = None
+                    output_hash = None
+                    match self.backend:
+                        case 'sync-on-thread':
+                            self.mm._should_cancel = self.should_cancel_work
+                            output_hash, output = await trio.to_thread.run_sync(
+                                partial(
+                                    self.mm.compute_one,
+                                    rid,
+                                    body['method'], body['params'],
+                                    inputs=inputs
+                                )
+                            )
+
+                        case _:
+                            raise DGPUComputeError(
+                                f'Unsupported backend {self.backend}'
+                            )
+
+                    self._last_generation_ts: str = datetime.now().isoformat()
+                    self._last_benchmark: list[float] = self._benchmark
+                    self._benchmark: list[float] = []
+
+                    ipfs_hash = await self.conn.publish_on_ipfs(output, typ=output_type)
+
+                    await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
+
+                except BaseException as err:
+                    logging.exception('Failed to serve model request !?\n')
+                    # traceback.print_exc()  # TODO? <- replaced by above ya?
+                    await self.conn.cancel_work(rid, str(err))
+
+                finally:
+                    return True
+
+    # TODO, as per above on `.maybe_serve_one()`, it's likely a bit
+    # more *trionic* to define this all as a module level task-func
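The gist of the rework, restated outside the diff: the positive test is flipped via De Morgan into a skip-guard with an early return, so the long work body loses one indent level. Schematically:

    # before: all the work nests under the positive branch
    if rid not in results and rid in self._snap['requests']:
        ...  # the whole work body, one level deep
    else:
        logging.info(f'request {rid} already being worked on, skip...')

    # after: invert the condition, bail out early, keep the body flat
    if rid in results or rid not in self._snap['requests']:
        logging.info(f'request {rid} already being worked on, skip...')
        return
    ...  # the same work body, one level shallower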
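The input-seeding loop retries each fetch up to three times before moving on. One way to factor it into a helper, as a sketch: the standalone function and its `conn` parameter are extrapolations for illustration, not existing skynet API.

    import logging

    async def gather_inputs(conn, binary_data: str, attempts: int = 3) -> list:
        # same shape as the loop in the diff: skip empty ids, retry each
        # fetch up to `attempts` times, log and move on if all attempts fail
        inputs = []
        for _input in binary_data.split(','):
            if not _input:
                continue
            for _ in range(attempts):
                try:
                    inputs.append(await conn.get_input_data(_input))
                    break
                except BaseException:
                    logging.exception('Model input error !?!\n')
        return inputs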
				
			
			
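The request fingerprint is a plain sha256 over the concatenated nonce, body and binary payload. A self-contained reproduction, with invented request values for illustration:

    from hashlib import sha256

    # hypothetical request payload mirroring the fields used in the diff
    req = {
        'nonce': 0,
        'body': '{"method": "diffuse", "params": {"prompt": "a red fox"}}',
        'binary_data': '',
    }

    # same concatenation order as above: nonce, then body, then binary_data
    hash_str = str(req['nonce']) + req['body'] + req['binary_data']
    request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
    # request_hash is the hex digest later handed to submit_work()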
			
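The `sync-on-thread` backend hands the blocking model call to a worker thread so the trio scheduler stays responsive; `partial` is needed because `trio.to_thread.run_sync` forwards positional arguments but not keyword arguments. A runnable toy version, with a sleep standing in for the blocking `compute_one` call:

    import time
    from functools import partial

    import trio

    def compute_one(rid, method, params, inputs=None):
        # stand-in for the blocking GPU inference done by self.mm.compute_one
        time.sleep(0.1)
        return 'fake-output-hash', b'fake-output-bytes'

    async def main():
        # kwargs must be baked in via partial; run_sync won't forward them
        output_hash, output = await trio.to_thread.run_sync(
            partial(compute_one, 'rid-1', 'diffuse', {'prompt': 'a fox'}, inputs=[])
        )
        print(output_hash, len(output))

    trio.run(main)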
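On the closing TODO: a module-level task-func might look roughly like this, taking the daemon as an explicit parameter so a nursery can spawn it directly. Everything here beyond `maybe_serve_one` itself (the function names, `next_request`, `poll_interval`, and the assumed signature) is hypothetical, not existing skynet API:

    import trio

    async def serve_forever(daemon) -> None:
        # hypothetical module-level loop per the TODO above
        while True:
            req = await daemon.next_request()        # hypothetical helper
            await daemon.maybe_serve_one(req)        # method referenced in the TODO
            await trio.sleep(daemon.poll_interval)   # hypothetical attribute

    async def main(daemon) -> None:
        async with trio.open_nursery() as nursery:
            nursery.start_soon(serve_forever, daemon)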
	