mirror of https://github.com/skygpu/skynet.git
Fixes to tui and compatibility with frontend formated requests found while running worker
parent
f60e582ad5
commit
63c849a41e
|
@ -43,14 +43,14 @@ def prepare_params_for_diffuse(
|
|||
if 'flux' in params.model.lower():
|
||||
_params['max_sequence_length'] = 512
|
||||
else:
|
||||
_params['strength'] = params.strength
|
||||
_params['strength'] = float(params.strength)
|
||||
|
||||
case ModelMode.IMG2IMG:
|
||||
image = crop_image(
|
||||
inputs[0], params.width, params.height)
|
||||
|
||||
_params['image'] = image
|
||||
_params['strength'] = params.strength
|
||||
_params['strength'] = float(params.strength)
|
||||
|
||||
case ModelMode.TXT2IMG | ModelMode.DIFFUSE:
|
||||
...
|
||||
|
@ -60,7 +60,7 @@ def prepare_params_for_diffuse(
|
|||
|
||||
return (
|
||||
params.prompt,
|
||||
params.guidance,
|
||||
float(params.guidance),
|
||||
params.step,
|
||||
torch.manual_seed(int(params.seed)),
|
||||
_params
|
||||
|
|
|
@ -33,14 +33,13 @@ async def maybe_serve_one(
|
|||
conn: NetConnector,
|
||||
state_mngr: ContractState,
|
||||
):
|
||||
logging.info(f'maybe serve request pi: {state_mngr.poll_index}')
|
||||
req = state_mngr.first
|
||||
|
||||
# no requests in queue
|
||||
if not req:
|
||||
return
|
||||
|
||||
logging.info(f'maybe serve request #{req.id}')
|
||||
|
||||
# parse request
|
||||
body = msgspec.json.decode(req.body, type=BodyV0)
|
||||
model = body.params.model
|
||||
|
|
|
@ -20,6 +20,7 @@ from skynet.dgpu.tui import maybe_update_tui
|
|||
from skynet.config import DgpuConfig as Config, load_skynet_toml
|
||||
from skynet.types import (
|
||||
ConfigV0,
|
||||
AccountV0,
|
||||
BodyV0,
|
||||
RequestV0,
|
||||
WorkerStatusV0,
|
||||
|
@ -128,7 +129,8 @@ class NetConnector:
|
|||
index_position=1,
|
||||
key_type='name',
|
||||
lower_bound=self.config.account,
|
||||
upper_bound=self.config.account
|
||||
upper_bound=self.config.account,
|
||||
resp_cls=AccountV0
|
||||
))
|
||||
|
||||
if rows:
|
||||
|
@ -334,6 +336,7 @@ class ContractState:
|
|||
self._queue.append(req)
|
||||
|
||||
except msgspec.ValidationError:
|
||||
logging.exception(f'dropping req {req.id} due to:')
|
||||
...
|
||||
|
||||
random.shuffle(self._queue)
|
||||
|
|
|
@ -78,6 +78,11 @@ class WorkerMonitor:
|
|||
"""
|
||||
row_widgets = []
|
||||
|
||||
requests = sorted(
|
||||
requests,
|
||||
key=lambda r: r['id']
|
||||
)
|
||||
|
||||
for req in requests:
|
||||
# Build a columns widget for the request row
|
||||
prompt = req['prompt'] if 'prompt' in req else 'UPSCALE'
|
||||
|
@ -159,11 +164,13 @@ class WorkerMonitor:
|
|||
def network_update(self, state_mngr):
|
||||
queue = [
|
||||
{
|
||||
**r,
|
||||
'id': r.id,
|
||||
'user': r.user,
|
||||
'reward': r.reward,
|
||||
**(json.loads(r.body)['params']),
|
||||
'workers': [s.worker for s in state_mngr._status_by_rid[r.id]]
|
||||
}
|
||||
for r in state_mngr.queue
|
||||
for r in state_mngr._queue
|
||||
]
|
||||
self.update_requests(queue)
|
||||
|
||||
|
|
|
@ -79,10 +79,10 @@ class BodyV0Params(Struct):
|
|||
model: str
|
||||
seed: int
|
||||
step: int = 1
|
||||
guidance: float | None = None
|
||||
guidance: str | float | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
strength: float | None = None
|
||||
strength: str | float | None = None
|
||||
output_type: str | None = 'png'
|
||||
upscaler: str | None = None
|
||||
|
||||
|
|
Loading…
Reference in New Issue