mirror of https://github.com/skygpu/skynet.git
parent
7c27ee866a
commit
6c1799e342
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
import base64
|
import zlib
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
@ -24,9 +24,10 @@ from OpenSSL.crypto import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from .db import *
|
from .db import *
|
||||||
from .structs import *
|
|
||||||
from .constants import *
|
from .constants import *
|
||||||
|
|
||||||
|
from .protobuf import *
|
||||||
|
|
||||||
|
|
||||||
class SkynetDGPUOffline(BaseException):
|
class SkynetDGPUOffline(BaseException):
|
||||||
...
|
...
|
||||||
|
@ -117,22 +118,30 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
async def dgpu_image_streamer():
|
async def dgpu_image_streamer():
|
||||||
nonlocal wip_reqs, fin_reqs
|
nonlocal wip_reqs, fin_reqs
|
||||||
while True:
|
while True:
|
||||||
raw_msg = (await dgpu_bus.arecv()).decode()
|
raw_msg = await dgpu_bus.arecv()
|
||||||
logging.info(f'streamer got {len(raw_msg)} bytes.')
|
logging.info(f'streamer got {len(raw_msg)} bytes.')
|
||||||
msg = DGPUBusResponse(**json.loads(raw_msg))
|
msg = DGPUBusMessage()
|
||||||
|
msg.ParseFromString(raw_msg)
|
||||||
|
|
||||||
if security:
|
if security:
|
||||||
msg.verify(tls_whitelist[msg.cert])
|
verify_protobuf_msg(msg, tls_whitelist[msg.auth.cert])
|
||||||
|
|
||||||
if msg.rid not in wip_reqs:
|
rid = msg.rid
|
||||||
|
|
||||||
|
if rid not in wip_reqs:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
fin_reqs[msg.rid] = msg
|
if msg.method == 'binary-reply':
|
||||||
event = wip_reqs[msg.rid]
|
logging.info('bin reply, recv extra data')
|
||||||
event.set()
|
raw_img = await dgpu_bus.arecv()
|
||||||
del wip_reqs[msg.rid]
|
msg = (msg, raw_img)
|
||||||
|
|
||||||
async def dgpu_stream_one_img(req: ImageGenRequest):
|
fin_reqs[rid] = msg
|
||||||
|
event = wip_reqs[rid]
|
||||||
|
event.set()
|
||||||
|
del wip_reqs[rid]
|
||||||
|
|
||||||
|
async def dgpu_stream_one_img(req: Text2ImageParameters):
|
||||||
nonlocal wip_reqs, fin_reqs, next_worker
|
nonlocal wip_reqs, fin_reqs, next_worker
|
||||||
nid = get_next_worker()
|
nid = get_next_worker()
|
||||||
idx = list(nodes.keys()).index(nid)
|
idx = list(nodes.keys()).index(nid)
|
||||||
|
@ -144,19 +153,17 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
|
|
||||||
nodes[nid]['task'] = rid
|
nodes[nid]['task'] = rid
|
||||||
|
|
||||||
dgpu_req = DGPUBusRequest(
|
dgpu_req = DGPUBusMessage(
|
||||||
rid=rid,
|
rid=rid,
|
||||||
nid=nid,
|
nid=nid,
|
||||||
task='diffuse',
|
method='diffuse')
|
||||||
params=req.to_dict())
|
dgpu_req.params.update(req.to_dict())
|
||||||
|
|
||||||
logging.info(f'dgpu_bus req: {dgpu_req}')
|
|
||||||
|
|
||||||
if security:
|
if security:
|
||||||
dgpu_req.sign(tls_key, 'skynet')
|
dgpu_req.auth.cert = 'skynet'
|
||||||
|
dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key)
|
||||||
|
|
||||||
await dgpu_bus.asend(
|
await dgpu_bus.asend(dgpu_req.SerializeToString())
|
||||||
json.dumps(dgpu_req.to_dict()).encode())
|
|
||||||
|
|
||||||
with trio.move_on_after(4):
|
with trio.move_on_after(4):
|
||||||
await ack_event.wait()
|
await ack_event.wait()
|
||||||
|
@ -184,13 +191,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
|
|
||||||
nodes[nid]['task'] = None
|
nodes[nid]['task'] = None
|
||||||
|
|
||||||
img_resp = fin_reqs[rid]
|
resp = fin_reqs[rid]
|
||||||
del fin_reqs[rid]
|
del fin_reqs[rid]
|
||||||
|
if isinstance(resp, tuple):
|
||||||
|
meta, img = resp
|
||||||
|
return rid, img, meta.params
|
||||||
|
|
||||||
if 'error' in img_resp.params:
|
raise SkynetDGPUComputeError(MessageToDict(resp.params))
|
||||||
raise SkynetDGPUComputeError(img_resp.params['error'])
|
|
||||||
|
|
||||||
return rid, img_resp.params['img'], img_resp.params['meta']
|
|
||||||
|
|
||||||
async def handle_user_request(rpc_ctx, req):
|
async def handle_user_request(rpc_ctx, req):
|
||||||
try:
|
try:
|
||||||
|
@ -204,13 +212,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
logging.info('txt2img')
|
logging.info('txt2img')
|
||||||
user_config = {**(await get_user_config(conn, user))}
|
user_config = {**(await get_user_config(conn, user))}
|
||||||
del user_config['id']
|
del user_config['id']
|
||||||
user_config.update((k, req.params[k]) for k in req.params)
|
user_config.update(MessageToDict(req.params))
|
||||||
req = ImageGenRequest(**user_config)
|
|
||||||
|
req = Text2ImageParameters(**user_config)
|
||||||
rid, img, meta = await dgpu_stream_one_img(req)
|
rid, img, meta = await dgpu_stream_one_img(req)
|
||||||
logging.info(f'done streaming {rid}')
|
logging.info(f'done streaming {rid}')
|
||||||
result = {
|
result = {
|
||||||
'id': rid,
|
'id': rid,
|
||||||
'img': img,
|
'img': zlib.compress(img).hex(),
|
||||||
'meta': meta
|
'meta': meta
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -224,14 +233,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
prompt = await get_last_prompt_of(conn, user)
|
prompt = await get_last_prompt_of(conn, user)
|
||||||
|
|
||||||
if prompt:
|
if prompt:
|
||||||
req = ImageGenRequest(
|
req = Text2ImageParameters(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
**user_config
|
**user_config
|
||||||
)
|
)
|
||||||
rid, img, meta = await dgpu_stream_one_img(req)
|
rid, img, meta = await dgpu_stream_one_img(req)
|
||||||
result = {
|
result = {
|
||||||
'id': rid,
|
'id': rid,
|
||||||
'img': img,
|
'img': zlib.compress(img).hex(),
|
||||||
'meta': meta
|
'meta': meta
|
||||||
}
|
}
|
||||||
await update_user_stats(conn, user)
|
await update_user_stats(conn, user)
|
||||||
|
@ -289,14 +298,15 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
'message': str(e)
|
'message': str(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp = SkynetRPCResponse(result=result)
|
resp = SkynetRPCResponse()
|
||||||
|
resp.result.update(result)
|
||||||
|
|
||||||
if security:
|
if security:
|
||||||
resp.sign(tls_key, 'skynet')
|
resp.auth.cert = 'skynet'
|
||||||
|
resp.auth.sig = sign_protobuf_msg(resp, tls_key)
|
||||||
|
|
||||||
logging.info('sending response')
|
logging.info('sending response')
|
||||||
await rpc_ctx.asend(
|
await rpc_ctx.asend(resp.SerializeToString())
|
||||||
json.dumps(resp.to_dict()).encode())
|
|
||||||
rpc_ctx.close()
|
rpc_ctx.close()
|
||||||
logging.info('done')
|
logging.info('done')
|
||||||
|
|
||||||
|
@ -304,19 +314,17 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
nonlocal next_worker
|
nonlocal next_worker
|
||||||
while True:
|
while True:
|
||||||
ctx = sock.new_context()
|
ctx = sock.new_context()
|
||||||
msg = await ctx.arecv_msg()
|
req = SkynetRPCRequest()
|
||||||
|
req.ParseFromString(await ctx.arecv())
|
||||||
content = msg.bytes.decode()
|
|
||||||
req = SkynetRPCRequest(**json.loads(content))
|
|
||||||
|
|
||||||
if security:
|
if security:
|
||||||
if req.cert not in tls_whitelist:
|
if req.auth.cert not in tls_whitelist:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f'{req.cert} not in tls whitelist and security=True')
|
f'{req.cert} not in tls whitelist and security=True')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
req.verify(tls_whitelist[req.cert])
|
verify_protobuf_msg(req, tls_whitelist[req.auth.cert])
|
||||||
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
|
@ -345,14 +353,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
||||||
handle_user_request, ctx, req)
|
handle_user_request, ctx, req)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
resp = SkynetRPCResponse(
|
resp = SkynetRPCResponse()
|
||||||
result={'ok': result})
|
resp.result.update({'ok': result})
|
||||||
|
|
||||||
if security:
|
if security:
|
||||||
resp.sign(tls_key, 'skynet')
|
resp.auth.cert = 'skynet'
|
||||||
|
resp.auth.sig = sign_protobuf_msg(resp, tls_key)
|
||||||
|
|
||||||
await ctx.asend(
|
await ctx.asend(resp.SerializeToString())
|
||||||
json.dumps(resp.to_dict()).encode())
|
|
||||||
|
|
||||||
ctx.close()
|
ctx.close()
|
||||||
|
|
||||||
|
|
|
@ -59,8 +59,10 @@ ALTER TABLE skynet.user_config
|
||||||
|
|
||||||
|
|
||||||
def try_decode_uid(uid: str):
|
def try_decode_uid(uid: str):
|
||||||
if isinstance(uid, int):
|
try:
|
||||||
return None, uid
|
return None, int(uid)
|
||||||
|
except ValueError:
|
||||||
|
...
|
||||||
|
|
||||||
try:
|
try:
|
||||||
proto, uid = uid.split('+')
|
proto, uid = uid.split('+')
|
||||||
|
|
|
@ -8,6 +8,7 @@ import uuid
|
||||||
import base64
|
import base64
|
||||||
import random
|
import random
|
||||||
import logging
|
import logging
|
||||||
|
import traceback
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -34,9 +35,9 @@ from .utils import (
|
||||||
pipeline_for,
|
pipeline_for,
|
||||||
convert_from_cv2_to_image, convert_from_image_to_cv2
|
convert_from_cv2_to_image, convert_from_image_to_cv2
|
||||||
)
|
)
|
||||||
from .structs import *
|
from .protobuf import *
|
||||||
from .constants import *
|
|
||||||
from .frontend import open_skynet_rpc
|
from .frontend import open_skynet_rpc
|
||||||
|
from .constants import *
|
||||||
|
|
||||||
|
|
||||||
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
|
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
|
||||||
|
@ -92,7 +93,7 @@ async def open_dgpu_node(
|
||||||
logging.info('memory summary:')
|
logging.info('memory summary:')
|
||||||
logging.info('\n' + torch.cuda.memory_summary())
|
logging.info('\n' + torch.cuda.memory_summary())
|
||||||
|
|
||||||
async def gpu_compute_one(ireq: ImageGenRequest):
|
async def gpu_compute_one(ireq: Text2ImageParameters):
|
||||||
if ireq.algo not in models:
|
if ireq.algo not in models:
|
||||||
least_used = list(models.keys())[0]
|
least_used = list(models.keys())[0]
|
||||||
for model in models:
|
for model in models:
|
||||||
|
@ -110,10 +111,10 @@ async def open_dgpu_node(
|
||||||
try:
|
try:
|
||||||
image = models[ireq.algo]['pipe'](
|
image = models[ireq.algo]['pipe'](
|
||||||
ireq.prompt,
|
ireq.prompt,
|
||||||
width=ireq.width,
|
width=int(ireq.width),
|
||||||
height=ireq.height,
|
height=int(ireq.height),
|
||||||
guidance_scale=ireq.guidance,
|
guidance_scale=ireq.guidance,
|
||||||
num_inference_steps=ireq.step,
|
num_inference_steps=int(ireq.step),
|
||||||
generator=torch.Generator("cuda").manual_seed(ireq.seed)
|
generator=torch.Generator("cuda").manual_seed(ireq.seed)
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
|
@ -191,9 +192,8 @@ async def open_dgpu_node(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
msg = await dgpu_sock.arecv()
|
req = DGPUBusMessage()
|
||||||
req = DGPUBusRequest(
|
req.ParseFromString(await dgpu_sock.arecv())
|
||||||
**json.loads(msg.decode()))
|
|
||||||
|
|
||||||
if req.nid != unique_id:
|
if req.nid != unique_id:
|
||||||
logging.info(
|
logging.info(
|
||||||
|
@ -201,54 +201,60 @@ async def open_dgpu_node(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if security:
|
if security:
|
||||||
req.verify(skynet_cert)
|
verify_protobuf_msg(req, skynet_cert)
|
||||||
|
|
||||||
ack_resp = DGPUBusResponse(
|
ack_resp = DGPUBusMessage(
|
||||||
rid=req.rid,
|
rid=req.rid,
|
||||||
nid=req.nid,
|
nid=req.nid
|
||||||
params={'ack': {}}
|
|
||||||
)
|
)
|
||||||
|
ack_resp.params.update({'ack': {}})
|
||||||
|
|
||||||
if security:
|
if security:
|
||||||
ack_resp.sign(tls_key, cert_name)
|
ack_resp.auth.cert = cert_name
|
||||||
|
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
|
||||||
|
|
||||||
# send ack
|
# send ack
|
||||||
await dgpu_sock.asend(
|
await dgpu_sock.asend(ack_resp.SerializeToString())
|
||||||
json.dumps(ack_resp.to_dict()).encode())
|
|
||||||
|
|
||||||
logging.info(f'sent ack, processing {req.rid}...')
|
logging.info(f'sent ack, processing {req.rid}...')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
img_req = ImageGenRequest(**req.params)
|
img_req = Text2ImageParameters(**req.params)
|
||||||
if not img_req.seed:
|
if not img_req.seed:
|
||||||
img_req.seed = random.randint(0, 2 ** 64)
|
img_req.seed = random.randint(0, 2 ** 64)
|
||||||
|
|
||||||
img = await gpu_compute_one(img_req)
|
img = await gpu_compute_one(img_req)
|
||||||
img_resp = DGPUBusResponse(
|
img_resp = DGPUBusMessage(
|
||||||
rid=req.rid,
|
rid=req.rid,
|
||||||
nid=req.nid,
|
nid=req.nid,
|
||||||
params={
|
method='binary-reply'
|
||||||
'img': base64.b64encode(img).hex(),
|
|
||||||
'meta': img_req.to_dict()
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
img_resp.params.update({
|
||||||
|
'len': len(img),
|
||||||
|
'meta': img_req.to_dict()
|
||||||
|
})
|
||||||
|
|
||||||
except DGPUComputeError as e:
|
except DGPUComputeError as e:
|
||||||
img_resp = DGPUBusResponse(
|
traceback.print_exception(type(e), e, e.__traceback__)
|
||||||
|
img_resp = DGPUBusMessage(
|
||||||
rid=req.rid,
|
rid=req.rid,
|
||||||
nid=req.nid,
|
nid=req.nid
|
||||||
params={'error': str(e)}
|
|
||||||
)
|
)
|
||||||
|
img_resp.params.update({'error': str(e)})
|
||||||
|
|
||||||
|
|
||||||
if security:
|
if security:
|
||||||
img_resp.sign(tls_key, cert_name)
|
img_resp.auth.cert = cert_name
|
||||||
|
img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
|
||||||
raw_msg = json.dumps(img_resp.to_dict()).encode()
|
|
||||||
|
|
||||||
# send final image
|
# send final image
|
||||||
logging.info('sending img back...')
|
logging.info('sending img back...')
|
||||||
|
raw_msg = img_resp.SerializeToString()
|
||||||
await dgpu_sock.asend(raw_msg)
|
await dgpu_sock.asend(raw_msg)
|
||||||
logging.info(f'sent {len(raw_msg)} bytes.')
|
logging.info(f'sent {len(raw_msg)} bytes.')
|
||||||
|
if img_resp.method == 'binary-reply':
|
||||||
|
await dgpu_sock.asend(img)
|
||||||
|
logging.info(f'sent {len(img)} bytes.')
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logging.info('interrupt caught, stopping...')
|
logging.info('interrupt caught, stopping...')
|
||||||
|
|
|
@ -15,9 +15,13 @@ from OpenSSL.crypto import (
|
||||||
FILETYPE_PEM
|
FILETYPE_PEM
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..structs import SkynetRPCRequest, SkynetRPCResponse
|
from google.protobuf.struct_pb2 import Struct
|
||||||
|
|
||||||
from ..constants import *
|
from ..constants import *
|
||||||
|
|
||||||
|
from ..protobuf.auth import *
|
||||||
|
from ..protobuf.skynet_pb2 import SkynetRPCRequest, SkynetRPCResponse
|
||||||
|
|
||||||
|
|
||||||
class ConfigRequestFormatError(BaseException):
|
class ConfigRequestFormatError(BaseException):
|
||||||
...
|
...
|
||||||
|
@ -79,28 +83,26 @@ async def open_skynet_rpc(
|
||||||
async def _rpc_call(
|
async def _rpc_call(
|
||||||
method: str,
|
method: str,
|
||||||
params: dict = {},
|
params: dict = {},
|
||||||
uid: Optional[Union[int, str]] = None
|
uid: Optional[str] = None
|
||||||
):
|
):
|
||||||
req = SkynetRPCRequest(
|
req = SkynetRPCRequest()
|
||||||
uid=uid if uid else unique_id,
|
req.uid = uid if uid else unique_id
|
||||||
method=method,
|
req.method = method
|
||||||
params=params
|
req.params.update(params)
|
||||||
)
|
|
||||||
|
|
||||||
if security:
|
if security:
|
||||||
req.sign(tls_key, cert_name)
|
req.auth.cert = cert_name
|
||||||
|
req.auth.sig = sign_protobuf_msg(req, tls_key)
|
||||||
|
|
||||||
ctx = sock.new_context()
|
ctx = sock.new_context()
|
||||||
await ctx.asend(
|
await ctx.asend(req.SerializeToString())
|
||||||
json.dumps(
|
|
||||||
req.to_dict()).encode())
|
|
||||||
|
|
||||||
resp = SkynetRPCResponse(
|
resp = SkynetRPCResponse()
|
||||||
**json.loads((await ctx.arecv()).decode()))
|
resp.ParseFromString(await ctx.arecv())
|
||||||
ctx.close()
|
ctx.close()
|
||||||
|
|
||||||
if security:
|
if security:
|
||||||
resp.verify(skynet_cert)
|
verify_protobuf_msg(resp, skynet_cert)
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import base64
|
import zlib
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
@ -110,7 +110,7 @@ async def run_skynet_telegram(
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logging.info(resp.result['id'])
|
logging.info(resp.result['id'])
|
||||||
img_raw = base64.b64decode(bytes.fromhex(resp.result['img']))
|
img_raw = zlib.decompress(bytes.fromhex(resp.result['img']))
|
||||||
logging.info(f'got image of size: {len(img_raw)}')
|
logging.info(f'got image of size: {len(img_raw)}')
|
||||||
size = (512, 512)
|
size = (512, 512)
|
||||||
if resp.result['meta']['upscaler'] == 'x4':
|
if resp.result['meta']['upscaler'] == 'x4':
|
||||||
|
@ -141,7 +141,7 @@ async def run_skynet_telegram(
|
||||||
resp_txt = resp.result['message']
|
resp_txt = resp.result['message']
|
||||||
|
|
||||||
else:
|
else:
|
||||||
img_raw = base64.b64decode(bytes.fromhex(resp.result['img']))
|
img_raw = zlib.decompress(bytes.fromhex(resp.result['img']))
|
||||||
logging.info(f'got image of size: {len(img_raw)}')
|
logging.info(f'got image of size: {len(img_raw)}')
|
||||||
size = (512, 512)
|
size = (512, 512)
|
||||||
logging.info(resp.result['meta'])
|
logging.info(resp.result['meta'])
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
|
||||||
|
from google.protobuf.json_format import MessageToDict
|
||||||
|
|
||||||
|
from .auth import *
|
||||||
|
from .skynet_pb2 import *
|
||||||
|
|
||||||
|
|
||||||
|
class Struct:
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return asdict(self)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Text2ImageParameters(Struct):
|
||||||
|
algo: str
|
||||||
|
prompt: str
|
||||||
|
step: int
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
guidance: float
|
||||||
|
seed: Optional[int]
|
||||||
|
upscaler: Optional[str]
|
|
@ -0,0 +1,65 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from hashlib import sha256
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from google.protobuf.json_format import MessageToDict
|
||||||
|
from OpenSSL.crypto import PKey, X509, verify, sign
|
||||||
|
|
||||||
|
from .skynet_pb2 import *
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_msg_deterministic(msg):
|
||||||
|
descriptors = sorted(
|
||||||
|
type(msg).DESCRIPTOR.fields_by_name.items(),
|
||||||
|
key=lambda x: x[0]
|
||||||
|
)
|
||||||
|
shasum = sha256()
|
||||||
|
|
||||||
|
def hash_dict(d):
|
||||||
|
data = [
|
||||||
|
(key, val)
|
||||||
|
for (key, val) in d.items()
|
||||||
|
]
|
||||||
|
for key, val in sorted(data, key=lambda x: x[0]):
|
||||||
|
if not isinstance(val, dict):
|
||||||
|
shasum.update(key.encode())
|
||||||
|
shasum.update(json.dumps(val).encode())
|
||||||
|
else:
|
||||||
|
hash_dict(val)
|
||||||
|
|
||||||
|
for (field_name, field_descriptor) in descriptors:
|
||||||
|
if not field_descriptor.message_type:
|
||||||
|
shasum.update(field_name.encode())
|
||||||
|
|
||||||
|
value = getattr(msg, field_name)
|
||||||
|
|
||||||
|
if isinstance(value, bytes):
|
||||||
|
value = value.hex()
|
||||||
|
|
||||||
|
shasum.update(json.dumps(value).encode())
|
||||||
|
continue
|
||||||
|
|
||||||
|
if field_descriptor.message_type.name == 'Struct':
|
||||||
|
hash_dict(MessageToDict(getattr(msg, field_name)))
|
||||||
|
|
||||||
|
deterministic_msg = shasum.hexdigest()
|
||||||
|
|
||||||
|
return deterministic_msg
|
||||||
|
|
||||||
|
|
||||||
|
def sign_protobuf_msg(msg, key: PKey):
|
||||||
|
return sign(
|
||||||
|
key, serialize_msg_deterministic(msg), 'sha256').hex()
|
||||||
|
|
||||||
|
|
||||||
|
def verify_protobuf_msg(msg, cert: X509):
|
||||||
|
return verify(
|
||||||
|
cert,
|
||||||
|
bytes.fromhex(msg.auth.sig),
|
||||||
|
serialize_msg_deterministic(msg),
|
||||||
|
'sha256'
|
||||||
|
)
|
|
@ -0,0 +1,30 @@
|
||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package skynet;
|
||||||
|
|
||||||
|
import "google/protobuf/struct.proto";
|
||||||
|
|
||||||
|
message Auth {
|
||||||
|
string cert = 1;
|
||||||
|
string sig = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message SkynetRPCRequest {
|
||||||
|
string uid = 1;
|
||||||
|
string method = 2;
|
||||||
|
google.protobuf.Struct params = 3;
|
||||||
|
optional Auth auth = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message SkynetRPCResponse {
|
||||||
|
google.protobuf.Struct result = 1;
|
||||||
|
optional Auth auth = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message DGPUBusMessage {
|
||||||
|
string rid = 1;
|
||||||
|
string nid = 2;
|
||||||
|
string method = 3;
|
||||||
|
google.protobuf.Struct params = 4;
|
||||||
|
optional Auth auth = 5;
|
||||||
|
}
|
|
@ -0,0 +1,32 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
|
# source: skynet.proto
|
||||||
|
"""Generated protocol buffer code."""
|
||||||
|
from google.protobuf.internal import builder as _builder
|
||||||
|
from google.protobuf import descriptor as _descriptor
|
||||||
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||||
|
from google.protobuf import symbol_database as _symbol_database
|
||||||
|
# @@protoc_insertion_point(imports)
|
||||||
|
|
||||||
|
_sym_db = _symbol_database.Default()
|
||||||
|
|
||||||
|
|
||||||
|
from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
|
||||||
|
|
||||||
|
|
||||||
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cskynet.proto\x12\x06skynet\x1a\x1cgoogle/protobuf/struct.proto\"!\n\x04\x41uth\x12\x0c\n\x04\x63\x65rt\x18\x01 \x01(\t\x12\x0b\n\x03sig\x18\x02 \x01(\t\"\x82\x01\n\x10SkynetRPCRequest\x12\x0b\n\x03uid\x18\x01 \x01(\t\x12\x0e\n\x06method\x18\x02 \x01(\t\x12\'\n\x06params\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x04 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"f\n\x11SkynetRPCResponse\x12\'\n\x06result\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x02 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"\x8d\x01\n\x0e\x44GPUBusMessage\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x0b\n\x03nid\x18\x02 \x01(\t\x12\x0e\n\x06method\x18\x03 \x01(\t\x12\'\n\x06params\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x05 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_authb\x06proto3')
|
||||||
|
|
||||||
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||||
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'skynet_pb2', globals())
|
||||||
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||||
|
|
||||||
|
DESCRIPTOR._options = None
|
||||||
|
_AUTH._serialized_start=54
|
||||||
|
_AUTH._serialized_end=87
|
||||||
|
_SKYNETRPCREQUEST._serialized_start=90
|
||||||
|
_SKYNETRPCREQUEST._serialized_end=220
|
||||||
|
_SKYNETRPCRESPONSE._serialized_start=222
|
||||||
|
_SKYNETRPCRESPONSE._serialized_end=324
|
||||||
|
_DGPUBUSMESSAGE._serialized_start=327
|
||||||
|
_DGPUBUSMESSAGE._serialized_end=468
|
||||||
|
# @@protoc_insertion_point(module_scope)
|
|
@ -3,7 +3,7 @@
|
||||||
import io
|
import io
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import base64
|
import zlib
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
@ -12,10 +12,10 @@ from functools import partial
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
import pytest
|
import pytest
|
||||||
import tractor
|
|
||||||
import trio_asyncio
|
import trio_asyncio
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from google.protobuf.json_format import MessageToDict
|
||||||
|
|
||||||
from skynet.brain import SkynetDGPUComputeError
|
from skynet.brain import SkynetDGPUComputeError
|
||||||
from skynet.constants import *
|
from skynet.constants import *
|
||||||
|
@ -40,7 +40,7 @@ async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
|
||||||
_images = set()
|
_images = set()
|
||||||
async def check_request_img(
|
async def check_request_img(
|
||||||
i: int,
|
i: int,
|
||||||
uid: int = 0,
|
uid: str = '1',
|
||||||
width: int = 512,
|
width: int = 512,
|
||||||
height: int = 512,
|
height: int = 512,
|
||||||
expect_unique = True,
|
expect_unique = True,
|
||||||
|
@ -66,13 +66,13 @@ async def check_request_img(
|
||||||
})
|
})
|
||||||
|
|
||||||
if 'error' in res.result:
|
if 'error' in res.result:
|
||||||
raise SkynetDGPUComputeError(json.dumps(res.result))
|
raise SkynetDGPUComputeError(MessageToDict(res.result))
|
||||||
|
|
||||||
if upscaler == 'x4':
|
if upscaler == 'x4':
|
||||||
width *= 4
|
width *= 4
|
||||||
height *= 4
|
height *= 4
|
||||||
|
|
||||||
img_raw = base64.b64decode(bytes.fromhex(res.result['img']))
|
img_raw = zlib.decompress(bytes.fromhex(res.result['img']))
|
||||||
img_sha = sha256(img_raw).hexdigest()
|
img_sha = sha256(img_raw).hexdigest()
|
||||||
img = Image.frombytes(
|
img = Image.frombytes(
|
||||||
'RGB', (width, height), img_raw)
|
'RGB', (width, height), img_raw)
|
||||||
|
|
Loading…
Reference in New Issue