Switch to protobuf

pull/1/head v0.1a6
Guillermo Rodriguez 2023-01-06 14:36:50 -03:00
parent 7c27ee866a
commit 6c1799e342
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
10 changed files with 267 additions and 95 deletions

View File

@ -2,7 +2,7 @@
import json
import uuid
import base64
import zlib
import logging
import traceback
@ -24,9 +24,10 @@ from OpenSSL.crypto import (
)
from .db import *
from .structs import *
from .constants import *
from .protobuf import *
class SkynetDGPUOffline(BaseException):
...
@ -117,22 +118,30 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
async def dgpu_image_streamer():
nonlocal wip_reqs, fin_reqs
while True:
raw_msg = (await dgpu_bus.arecv()).decode()
raw_msg = await dgpu_bus.arecv()
logging.info(f'streamer got {len(raw_msg)} bytes.')
msg = DGPUBusResponse(**json.loads(raw_msg))
msg = DGPUBusMessage()
msg.ParseFromString(raw_msg)
if security:
msg.verify(tls_whitelist[msg.cert])
verify_protobuf_msg(msg, tls_whitelist[msg.auth.cert])
if msg.rid not in wip_reqs:
rid = msg.rid
if rid not in wip_reqs:
continue
fin_reqs[msg.rid] = msg
event = wip_reqs[msg.rid]
event.set()
del wip_reqs[msg.rid]
if msg.method == 'binary-reply':
logging.info('bin reply, recv extra data')
raw_img = await dgpu_bus.arecv()
msg = (msg, raw_img)
async def dgpu_stream_one_img(req: ImageGenRequest):
fin_reqs[rid] = msg
event = wip_reqs[rid]
event.set()
del wip_reqs[rid]
async def dgpu_stream_one_img(req: Text2ImageParameters):
nonlocal wip_reqs, fin_reqs, next_worker
nid = get_next_worker()
idx = list(nodes.keys()).index(nid)
@ -144,19 +153,17 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
nodes[nid]['task'] = rid
dgpu_req = DGPUBusRequest(
dgpu_req = DGPUBusMessage(
rid=rid,
nid=nid,
task='diffuse',
params=req.to_dict())
logging.info(f'dgpu_bus req: {dgpu_req}')
method='diffuse')
dgpu_req.params.update(req.to_dict())
if security:
dgpu_req.sign(tls_key, 'skynet')
dgpu_req.auth.cert = 'skynet'
dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key)
await dgpu_bus.asend(
json.dumps(dgpu_req.to_dict()).encode())
await dgpu_bus.asend(dgpu_req.SerializeToString())
with trio.move_on_after(4):
await ack_event.wait()
@ -184,13 +191,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
nodes[nid]['task'] = None
img_resp = fin_reqs[rid]
resp = fin_reqs[rid]
del fin_reqs[rid]
if isinstance(resp, tuple):
meta, img = resp
return rid, img, meta.params
if 'error' in img_resp.params:
raise SkynetDGPUComputeError(img_resp.params['error'])
raise SkynetDGPUComputeError(MessageToDict(resp.params))
return rid, img_resp.params['img'], img_resp.params['meta']
async def handle_user_request(rpc_ctx, req):
try:
@ -204,13 +212,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
logging.info('txt2img')
user_config = {**(await get_user_config(conn, user))}
del user_config['id']
user_config.update((k, req.params[k]) for k in req.params)
req = ImageGenRequest(**user_config)
user_config.update(MessageToDict(req.params))
req = Text2ImageParameters(**user_config)
rid, img, meta = await dgpu_stream_one_img(req)
logging.info(f'done streaming {rid}')
result = {
'id': rid,
'img': img,
'img': zlib.compress(img).hex(),
'meta': meta
}
@ -224,14 +233,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
prompt = await get_last_prompt_of(conn, user)
if prompt:
req = ImageGenRequest(
req = Text2ImageParameters(
prompt=prompt,
**user_config
)
rid, img, meta = await dgpu_stream_one_img(req)
result = {
'id': rid,
'img': img,
'img': zlib.compress(img).hex(),
'meta': meta
}
await update_user_stats(conn, user)
@ -289,14 +298,15 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
'message': str(e)
}
resp = SkynetRPCResponse(result=result)
resp = SkynetRPCResponse()
resp.result.update(result)
if security:
resp.sign(tls_key, 'skynet')
resp.auth.cert = 'skynet'
resp.auth.sig = sign_protobuf_msg(resp, tls_key)
logging.info('sending response')
await rpc_ctx.asend(
json.dumps(resp.to_dict()).encode())
await rpc_ctx.asend(resp.SerializeToString())
rpc_ctx.close()
logging.info('done')
@ -304,19 +314,17 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
nonlocal next_worker
while True:
ctx = sock.new_context()
msg = await ctx.arecv_msg()
content = msg.bytes.decode()
req = SkynetRPCRequest(**json.loads(content))
req = SkynetRPCRequest()
req.ParseFromString(await ctx.arecv())
if security:
if req.cert not in tls_whitelist:
if req.auth.cert not in tls_whitelist:
logging.warning(
f'{req.cert} not in tls whitelist and security=True')
continue
try:
req.verify(tls_whitelist[req.cert])
verify_protobuf_msg(req, tls_whitelist[req.auth.cert])
except ValueError:
logging.warning(
@ -345,14 +353,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
handle_user_request, ctx, req)
continue
resp = SkynetRPCResponse(
result={'ok': result})
resp = SkynetRPCResponse()
resp.result.update({'ok': result})
if security:
resp.sign(tls_key, 'skynet')
resp.auth.cert = 'skynet'
resp.auth.sig = sign_protobuf_msg(resp, tls_key)
await ctx.asend(
json.dumps(resp.to_dict()).encode())
await ctx.asend(resp.SerializeToString())
ctx.close()

View File

@ -59,8 +59,10 @@ ALTER TABLE skynet.user_config
def try_decode_uid(uid: str):
if isinstance(uid, int):
return None, uid
try:
return None, int(uid)
except ValueError:
...
try:
proto, uid = uid.split('+')

View File

@ -8,6 +8,7 @@ import uuid
import base64
import random
import logging
import traceback
from typing import List, Optional
from pathlib import Path
@ -34,9 +35,9 @@ from .utils import (
pipeline_for,
convert_from_cv2_to_image, convert_from_image_to_cv2
)
from .structs import *
from .constants import *
from .protobuf import *
from .frontend import open_skynet_rpc
from .constants import *
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
@ -92,7 +93,7 @@ async def open_dgpu_node(
logging.info('memory summary:')
logging.info('\n' + torch.cuda.memory_summary())
async def gpu_compute_one(ireq: ImageGenRequest):
async def gpu_compute_one(ireq: Text2ImageParameters):
if ireq.algo not in models:
least_used = list(models.keys())[0]
for model in models:
@ -110,10 +111,10 @@ async def open_dgpu_node(
try:
image = models[ireq.algo]['pipe'](
ireq.prompt,
width=ireq.width,
height=ireq.height,
width=int(ireq.width),
height=int(ireq.height),
guidance_scale=ireq.guidance,
num_inference_steps=ireq.step,
num_inference_steps=int(ireq.step),
generator=torch.Generator("cuda").manual_seed(ireq.seed)
).images[0]
@ -191,9 +192,8 @@ async def open_dgpu_node(
try:
while True:
msg = await dgpu_sock.arecv()
req = DGPUBusRequest(
**json.loads(msg.decode()))
req = DGPUBusMessage()
req.ParseFromString(await dgpu_sock.arecv())
if req.nid != unique_id:
logging.info(
@ -201,54 +201,60 @@ async def open_dgpu_node(
continue
if security:
req.verify(skynet_cert)
verify_protobuf_msg(req, skynet_cert)
ack_resp = DGPUBusResponse(
ack_resp = DGPUBusMessage(
rid=req.rid,
nid=req.nid,
params={'ack': {}}
nid=req.nid
)
ack_resp.params.update({'ack': {}})
if security:
ack_resp.sign(tls_key, cert_name)
ack_resp.auth.cert = cert_name
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
# send ack
await dgpu_sock.asend(
json.dumps(ack_resp.to_dict()).encode())
await dgpu_sock.asend(ack_resp.SerializeToString())
logging.info(f'sent ack, processing {req.rid}...')
try:
img_req = ImageGenRequest(**req.params)
img_req = Text2ImageParameters(**req.params)
if not img_req.seed:
img_req.seed = random.randint(0, 2 ** 64)
img = await gpu_compute_one(img_req)
img_resp = DGPUBusResponse(
img_resp = DGPUBusMessage(
rid=req.rid,
nid=req.nid,
params={
'img': base64.b64encode(img).hex(),
'meta': img_req.to_dict()
}
method='binary-reply'
)
img_resp.params.update({
'len': len(img),
'meta': img_req.to_dict()
})
except DGPUComputeError as e:
img_resp = DGPUBusResponse(
traceback.print_exception(type(e), e, e.__traceback__)
img_resp = DGPUBusMessage(
rid=req.rid,
nid=req.nid,
params={'error': str(e)}
nid=req.nid
)
img_resp.params.update({'error': str(e)})
if security:
img_resp.sign(tls_key, cert_name)
raw_msg = json.dumps(img_resp.to_dict()).encode()
img_resp.auth.cert = cert_name
img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
# send final image
logging.info('sending img back...')
raw_msg = img_resp.SerializeToString()
await dgpu_sock.asend(raw_msg)
logging.info(f'sent {len(raw_msg)} bytes.')
if img_resp.method == 'binary-reply':
await dgpu_sock.asend(img)
logging.info(f'sent {len(img)} bytes.')
except KeyboardInterrupt:
logging.info('interrupt caught, stopping...')

View File

@ -15,9 +15,13 @@ from OpenSSL.crypto import (
FILETYPE_PEM
)
from ..structs import SkynetRPCRequest, SkynetRPCResponse
from google.protobuf.struct_pb2 import Struct
from ..constants import *
from ..protobuf.auth import *
from ..protobuf.skynet_pb2 import SkynetRPCRequest, SkynetRPCResponse
class ConfigRequestFormatError(BaseException):
...
@ -79,28 +83,26 @@ async def open_skynet_rpc(
async def _rpc_call(
method: str,
params: dict = {},
uid: Optional[Union[int, str]] = None
uid: Optional[str] = None
):
req = SkynetRPCRequest(
uid=uid if uid else unique_id,
method=method,
params=params
)
req = SkynetRPCRequest()
req.uid = uid if uid else unique_id
req.method = method
req.params.update(params)
if security:
req.sign(tls_key, cert_name)
req.auth.cert = cert_name
req.auth.sig = sign_protobuf_msg(req, tls_key)
ctx = sock.new_context()
await ctx.asend(
json.dumps(
req.to_dict()).encode())
await ctx.asend(req.SerializeToString())
resp = SkynetRPCResponse(
**json.loads((await ctx.arecv()).decode()))
resp = SkynetRPCResponse()
resp.ParseFromString(await ctx.arecv())
ctx.close()
if security:
resp.verify(skynet_cert)
verify_protobuf_msg(resp, skynet_cert)
return resp

View File

@ -1,7 +1,7 @@
#!/usr/bin/python
import io
import base64
import zlib
import logging
from datetime import datetime
@ -110,7 +110,7 @@ async def run_skynet_telegram(
else:
logging.info(resp.result['id'])
img_raw = base64.b64decode(bytes.fromhex(resp.result['img']))
img_raw = zlib.decompress(bytes.fromhex(resp.result['img']))
logging.info(f'got image of size: {len(img_raw)}')
size = (512, 512)
if resp.result['meta']['upscaler'] == 'x4':
@ -141,7 +141,7 @@ async def run_skynet_telegram(
resp_txt = resp.result['message']
else:
img_raw = base64.b64decode(bytes.fromhex(resp.result['img']))
img_raw = zlib.decompress(bytes.fromhex(resp.result['img']))
logging.info(f'got image of size: {len(img_raw)}')
size = (512, 512)
logging.info(resp.result['meta'])

View File

@ -0,0 +1,27 @@
#!/usr/bin/python
from typing import Optional
from dataclasses import dataclass, asdict
from google.protobuf.json_format import MessageToDict
from .auth import *
from .skynet_pb2 import *
class Struct:
def to_dict(self):
return asdict(self)
@dataclass
class Text2ImageParameters(Struct):
algo: str
prompt: str
step: int
width: int
height: int
guidance: float
seed: Optional[int]
upscaler: Optional[str]

View File

@ -0,0 +1,65 @@
#!/usr/bin/python
import json
import logging
from hashlib import sha256
from collections import OrderedDict
from google.protobuf.json_format import MessageToDict
from OpenSSL.crypto import PKey, X509, verify, sign
from .skynet_pb2 import *
def serialize_msg_deterministic(msg):
descriptors = sorted(
type(msg).DESCRIPTOR.fields_by_name.items(),
key=lambda x: x[0]
)
shasum = sha256()
def hash_dict(d):
data = [
(key, val)
for (key, val) in d.items()
]
for key, val in sorted(data, key=lambda x: x[0]):
if not isinstance(val, dict):
shasum.update(key.encode())
shasum.update(json.dumps(val).encode())
else:
hash_dict(val)
for (field_name, field_descriptor) in descriptors:
if not field_descriptor.message_type:
shasum.update(field_name.encode())
value = getattr(msg, field_name)
if isinstance(value, bytes):
value = value.hex()
shasum.update(json.dumps(value).encode())
continue
if field_descriptor.message_type.name == 'Struct':
hash_dict(MessageToDict(getattr(msg, field_name)))
deterministic_msg = shasum.hexdigest()
return deterministic_msg
def sign_protobuf_msg(msg, key: PKey):
return sign(
key, serialize_msg_deterministic(msg), 'sha256').hex()
def verify_protobuf_msg(msg, cert: X509):
return verify(
cert,
bytes.fromhex(msg.auth.sig),
serialize_msg_deterministic(msg),
'sha256'
)

View File

@ -0,0 +1,30 @@
syntax = "proto3";
package skynet;
import "google/protobuf/struct.proto";
message Auth {
string cert = 1;
string sig = 2;
}
message SkynetRPCRequest {
string uid = 1;
string method = 2;
google.protobuf.Struct params = 3;
optional Auth auth = 4;
}
message SkynetRPCResponse {
google.protobuf.Struct result = 1;
optional Auth auth = 2;
}
message DGPUBusMessage {
string rid = 1;
string nid = 2;
string method = 3;
google.protobuf.Struct params = 4;
optional Auth auth = 5;
}

View File

@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: skynet.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cskynet.proto\x12\x06skynet\x1a\x1cgoogle/protobuf/struct.proto\"!\n\x04\x41uth\x12\x0c\n\x04\x63\x65rt\x18\x01 \x01(\t\x12\x0b\n\x03sig\x18\x02 \x01(\t\"\x82\x01\n\x10SkynetRPCRequest\x12\x0b\n\x03uid\x18\x01 \x01(\t\x12\x0e\n\x06method\x18\x02 \x01(\t\x12\'\n\x06params\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x04 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"f\n\x11SkynetRPCResponse\x12\'\n\x06result\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x02 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"\x8d\x01\n\x0e\x44GPUBusMessage\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x0b\n\x03nid\x18\x02 \x01(\t\x12\x0e\n\x06method\x18\x03 \x01(\t\x12\'\n\x06params\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x05 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_authb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'skynet_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_AUTH._serialized_start=54
_AUTH._serialized_end=87
_SKYNETRPCREQUEST._serialized_start=90
_SKYNETRPCREQUEST._serialized_end=220
_SKYNETRPCRESPONSE._serialized_start=222
_SKYNETRPCRESPONSE._serialized_end=324
_DGPUBUSMESSAGE._serialized_start=327
_DGPUBUSMESSAGE._serialized_end=468
# @@protoc_insertion_point(module_scope)

View File

@ -3,7 +3,7 @@
import io
import time
import json
import base64
import zlib
import logging
from typing import Optional
@ -12,10 +12,10 @@ from functools import partial
import trio
import pytest
import tractor
import trio_asyncio
from PIL import Image
from google.protobuf.json_format import MessageToDict
from skynet.brain import SkynetDGPUComputeError
from skynet.constants import *
@ -40,7 +40,7 @@ async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
_images = set()
async def check_request_img(
i: int,
uid: int = 0,
uid: str = '1',
width: int = 512,
height: int = 512,
expect_unique = True,
@ -66,13 +66,13 @@ async def check_request_img(
})
if 'error' in res.result:
raise SkynetDGPUComputeError(json.dumps(res.result))
raise SkynetDGPUComputeError(MessageToDict(res.result))
if upscaler == 'x4':
width *= 4
height *= 4
img_raw = base64.b64decode(bytes.fromhex(res.result['img']))
img_raw = zlib.decompress(bytes.fromhex(res.result['img']))
img_sha = sha256(img_raw).hexdigest()
img = Image.frombytes(
'RGB', (width, height), img_raw)