mirror of https://github.com/skygpu/skynet.git
Fixes to tui and compatibility with frontend formated requests found while running worker
parent
f60e582ad5
commit
63c849a41e
|
@ -43,14 +43,14 @@ def prepare_params_for_diffuse(
|
||||||
if 'flux' in params.model.lower():
|
if 'flux' in params.model.lower():
|
||||||
_params['max_sequence_length'] = 512
|
_params['max_sequence_length'] = 512
|
||||||
else:
|
else:
|
||||||
_params['strength'] = params.strength
|
_params['strength'] = float(params.strength)
|
||||||
|
|
||||||
case ModelMode.IMG2IMG:
|
case ModelMode.IMG2IMG:
|
||||||
image = crop_image(
|
image = crop_image(
|
||||||
inputs[0], params.width, params.height)
|
inputs[0], params.width, params.height)
|
||||||
|
|
||||||
_params['image'] = image
|
_params['image'] = image
|
||||||
_params['strength'] = params.strength
|
_params['strength'] = float(params.strength)
|
||||||
|
|
||||||
case ModelMode.TXT2IMG | ModelMode.DIFFUSE:
|
case ModelMode.TXT2IMG | ModelMode.DIFFUSE:
|
||||||
...
|
...
|
||||||
|
@ -60,7 +60,7 @@ def prepare_params_for_diffuse(
|
||||||
|
|
||||||
return (
|
return (
|
||||||
params.prompt,
|
params.prompt,
|
||||||
params.guidance,
|
float(params.guidance),
|
||||||
params.step,
|
params.step,
|
||||||
torch.manual_seed(int(params.seed)),
|
torch.manual_seed(int(params.seed)),
|
||||||
_params
|
_params
|
||||||
|
|
|
@ -33,14 +33,13 @@ async def maybe_serve_one(
|
||||||
conn: NetConnector,
|
conn: NetConnector,
|
||||||
state_mngr: ContractState,
|
state_mngr: ContractState,
|
||||||
):
|
):
|
||||||
|
logging.info(f'maybe serve request pi: {state_mngr.poll_index}')
|
||||||
req = state_mngr.first
|
req = state_mngr.first
|
||||||
|
|
||||||
# no requests in queue
|
# no requests in queue
|
||||||
if not req:
|
if not req:
|
||||||
return
|
return
|
||||||
|
|
||||||
logging.info(f'maybe serve request #{req.id}')
|
|
||||||
|
|
||||||
# parse request
|
# parse request
|
||||||
body = msgspec.json.decode(req.body, type=BodyV0)
|
body = msgspec.json.decode(req.body, type=BodyV0)
|
||||||
model = body.params.model
|
model = body.params.model
|
||||||
|
|
|
@ -20,6 +20,7 @@ from skynet.dgpu.tui import maybe_update_tui
|
||||||
from skynet.config import DgpuConfig as Config, load_skynet_toml
|
from skynet.config import DgpuConfig as Config, load_skynet_toml
|
||||||
from skynet.types import (
|
from skynet.types import (
|
||||||
ConfigV0,
|
ConfigV0,
|
||||||
|
AccountV0,
|
||||||
BodyV0,
|
BodyV0,
|
||||||
RequestV0,
|
RequestV0,
|
||||||
WorkerStatusV0,
|
WorkerStatusV0,
|
||||||
|
@ -128,7 +129,8 @@ class NetConnector:
|
||||||
index_position=1,
|
index_position=1,
|
||||||
key_type='name',
|
key_type='name',
|
||||||
lower_bound=self.config.account,
|
lower_bound=self.config.account,
|
||||||
upper_bound=self.config.account
|
upper_bound=self.config.account,
|
||||||
|
resp_cls=AccountV0
|
||||||
))
|
))
|
||||||
|
|
||||||
if rows:
|
if rows:
|
||||||
|
@ -334,6 +336,7 @@ class ContractState:
|
||||||
self._queue.append(req)
|
self._queue.append(req)
|
||||||
|
|
||||||
except msgspec.ValidationError:
|
except msgspec.ValidationError:
|
||||||
|
logging.exception(f'dropping req {req.id} due to:')
|
||||||
...
|
...
|
||||||
|
|
||||||
random.shuffle(self._queue)
|
random.shuffle(self._queue)
|
||||||
|
|
|
@ -78,6 +78,11 @@ class WorkerMonitor:
|
||||||
"""
|
"""
|
||||||
row_widgets = []
|
row_widgets = []
|
||||||
|
|
||||||
|
requests = sorted(
|
||||||
|
requests,
|
||||||
|
key=lambda r: r['id']
|
||||||
|
)
|
||||||
|
|
||||||
for req in requests:
|
for req in requests:
|
||||||
# Build a columns widget for the request row
|
# Build a columns widget for the request row
|
||||||
prompt = req['prompt'] if 'prompt' in req else 'UPSCALE'
|
prompt = req['prompt'] if 'prompt' in req else 'UPSCALE'
|
||||||
|
@ -159,11 +164,13 @@ class WorkerMonitor:
|
||||||
def network_update(self, state_mngr):
|
def network_update(self, state_mngr):
|
||||||
queue = [
|
queue = [
|
||||||
{
|
{
|
||||||
**r,
|
'id': r.id,
|
||||||
|
'user': r.user,
|
||||||
|
'reward': r.reward,
|
||||||
**(json.loads(r.body)['params']),
|
**(json.loads(r.body)['params']),
|
||||||
'workers': [s.worker for s in state_mngr._status_by_rid[r.id]]
|
'workers': [s.worker for s in state_mngr._status_by_rid[r.id]]
|
||||||
}
|
}
|
||||||
for r in state_mngr.queue
|
for r in state_mngr._queue
|
||||||
]
|
]
|
||||||
self.update_requests(queue)
|
self.update_requests(queue)
|
||||||
|
|
||||||
|
|
|
@ -79,10 +79,10 @@ class BodyV0Params(Struct):
|
||||||
model: str
|
model: str
|
||||||
seed: int
|
seed: int
|
||||||
step: int = 1
|
step: int = 1
|
||||||
guidance: float | None = None
|
guidance: str | float | None = None
|
||||||
width: int | None = None
|
width: int | None = None
|
||||||
height: int | None = None
|
height: int | None = None
|
||||||
strength: float | None = None
|
strength: str | float | None = None
|
||||||
output_type: str | None = 'png'
|
output_type: str | None = 'png'
|
||||||
upscaler: str | None = None
|
upscaler: str | None = None
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue