Switch to protobuf

pull/1/head v0.1a6
Guillermo Rodriguez 2023-01-06 14:36:50 -03:00
parent 7c27ee866a
commit 6c1799e342
No known key found for this signature in database
GPG Key ID: EC3AB66D5D83B392
10 changed files with 267 additions and 95 deletions

View File

@ -2,7 +2,7 @@
import json import json
import uuid import uuid
import base64 import zlib
import logging import logging
import traceback import traceback
@ -24,9 +24,10 @@ from OpenSSL.crypto import (
) )
from .db import * from .db import *
from .structs import *
from .constants import * from .constants import *
from .protobuf import *
class SkynetDGPUOffline(BaseException): class SkynetDGPUOffline(BaseException):
... ...
@ -117,22 +118,30 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
async def dgpu_image_streamer(): async def dgpu_image_streamer():
nonlocal wip_reqs, fin_reqs nonlocal wip_reqs, fin_reqs
while True: while True:
raw_msg = (await dgpu_bus.arecv()).decode() raw_msg = await dgpu_bus.arecv()
logging.info(f'streamer got {len(raw_msg)} bytes.') logging.info(f'streamer got {len(raw_msg)} bytes.')
msg = DGPUBusResponse(**json.loads(raw_msg)) msg = DGPUBusMessage()
msg.ParseFromString(raw_msg)
if security: if security:
msg.verify(tls_whitelist[msg.cert]) verify_protobuf_msg(msg, tls_whitelist[msg.auth.cert])
if msg.rid not in wip_reqs: rid = msg.rid
if rid not in wip_reqs:
continue continue
fin_reqs[msg.rid] = msg if msg.method == 'binary-reply':
event = wip_reqs[msg.rid] logging.info('bin reply, recv extra data')
event.set() raw_img = await dgpu_bus.arecv()
del wip_reqs[msg.rid] msg = (msg, raw_img)
async def dgpu_stream_one_img(req: ImageGenRequest): fin_reqs[rid] = msg
event = wip_reqs[rid]
event.set()
del wip_reqs[rid]
async def dgpu_stream_one_img(req: Text2ImageParameters):
nonlocal wip_reqs, fin_reqs, next_worker nonlocal wip_reqs, fin_reqs, next_worker
nid = get_next_worker() nid = get_next_worker()
idx = list(nodes.keys()).index(nid) idx = list(nodes.keys()).index(nid)
@ -144,19 +153,17 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
nodes[nid]['task'] = rid nodes[nid]['task'] = rid
dgpu_req = DGPUBusRequest( dgpu_req = DGPUBusMessage(
rid=rid, rid=rid,
nid=nid, nid=nid,
task='diffuse', method='diffuse')
params=req.to_dict()) dgpu_req.params.update(req.to_dict())
logging.info(f'dgpu_bus req: {dgpu_req}')
if security: if security:
dgpu_req.sign(tls_key, 'skynet') dgpu_req.auth.cert = 'skynet'
dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key)
await dgpu_bus.asend( await dgpu_bus.asend(dgpu_req.SerializeToString())
json.dumps(dgpu_req.to_dict()).encode())
with trio.move_on_after(4): with trio.move_on_after(4):
await ack_event.wait() await ack_event.wait()
@ -184,13 +191,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
nodes[nid]['task'] = None nodes[nid]['task'] = None
img_resp = fin_reqs[rid] resp = fin_reqs[rid]
del fin_reqs[rid] del fin_reqs[rid]
if isinstance(resp, tuple):
meta, img = resp
return rid, img, meta.params
if 'error' in img_resp.params: raise SkynetDGPUComputeError(MessageToDict(resp.params))
raise SkynetDGPUComputeError(img_resp.params['error'])
return rid, img_resp.params['img'], img_resp.params['meta']
async def handle_user_request(rpc_ctx, req): async def handle_user_request(rpc_ctx, req):
try: try:
@ -204,13 +212,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
logging.info('txt2img') logging.info('txt2img')
user_config = {**(await get_user_config(conn, user))} user_config = {**(await get_user_config(conn, user))}
del user_config['id'] del user_config['id']
user_config.update((k, req.params[k]) for k in req.params) user_config.update(MessageToDict(req.params))
req = ImageGenRequest(**user_config)
req = Text2ImageParameters(**user_config)
rid, img, meta = await dgpu_stream_one_img(req) rid, img, meta = await dgpu_stream_one_img(req)
logging.info(f'done streaming {rid}') logging.info(f'done streaming {rid}')
result = { result = {
'id': rid, 'id': rid,
'img': img, 'img': zlib.compress(img).hex(),
'meta': meta 'meta': meta
} }
@ -224,14 +233,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
prompt = await get_last_prompt_of(conn, user) prompt = await get_last_prompt_of(conn, user)
if prompt: if prompt:
req = ImageGenRequest( req = Text2ImageParameters(
prompt=prompt, prompt=prompt,
**user_config **user_config
) )
rid, img, meta = await dgpu_stream_one_img(req) rid, img, meta = await dgpu_stream_one_img(req)
result = { result = {
'id': rid, 'id': rid,
'img': img, 'img': zlib.compress(img).hex(),
'meta': meta 'meta': meta
} }
await update_user_stats(conn, user) await update_user_stats(conn, user)
@ -289,14 +298,15 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
'message': str(e) 'message': str(e)
} }
resp = SkynetRPCResponse(result=result) resp = SkynetRPCResponse()
resp.result.update(result)
if security: if security:
resp.sign(tls_key, 'skynet') resp.auth.cert = 'skynet'
resp.auth.sig = sign_protobuf_msg(resp, tls_key)
logging.info('sending response') logging.info('sending response')
await rpc_ctx.asend( await rpc_ctx.asend(resp.SerializeToString())
json.dumps(resp.to_dict()).encode())
rpc_ctx.close() rpc_ctx.close()
logging.info('done') logging.info('done')
@ -304,19 +314,17 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
nonlocal next_worker nonlocal next_worker
while True: while True:
ctx = sock.new_context() ctx = sock.new_context()
msg = await ctx.arecv_msg() req = SkynetRPCRequest()
req.ParseFromString(await ctx.arecv())
content = msg.bytes.decode()
req = SkynetRPCRequest(**json.loads(content))
if security: if security:
if req.cert not in tls_whitelist: if req.auth.cert not in tls_whitelist:
logging.warning( logging.warning(
f'{req.cert} not in tls whitelist and security=True') f'{req.cert} not in tls whitelist and security=True')
continue continue
try: try:
req.verify(tls_whitelist[req.cert]) verify_protobuf_msg(req, tls_whitelist[req.auth.cert])
except ValueError: except ValueError:
logging.warning( logging.warning(
@ -345,14 +353,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
handle_user_request, ctx, req) handle_user_request, ctx, req)
continue continue
resp = SkynetRPCResponse( resp = SkynetRPCResponse()
result={'ok': result}) resp.result.update({'ok': result})
if security: if security:
resp.sign(tls_key, 'skynet') resp.auth.cert = 'skynet'
resp.auth.sig = sign_protobuf_msg(resp, tls_key)
await ctx.asend( await ctx.asend(resp.SerializeToString())
json.dumps(resp.to_dict()).encode())
ctx.close() ctx.close()

View File

@ -59,8 +59,10 @@ ALTER TABLE skynet.user_config
def try_decode_uid(uid: str): def try_decode_uid(uid: str):
if isinstance(uid, int): try:
return None, uid return None, int(uid)
except ValueError:
...
try: try:
proto, uid = uid.split('+') proto, uid = uid.split('+')

View File

@ -8,6 +8,7 @@ import uuid
import base64 import base64
import random import random
import logging import logging
import traceback
from typing import List, Optional from typing import List, Optional
from pathlib import Path from pathlib import Path
@ -34,9 +35,9 @@ from .utils import (
pipeline_for, pipeline_for,
convert_from_cv2_to_image, convert_from_image_to_cv2 convert_from_cv2_to_image, convert_from_image_to_cv2
) )
from .structs import * from .protobuf import *
from .constants import *
from .frontend import open_skynet_rpc from .frontend import open_skynet_rpc
from .constants import *
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'): def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
@ -92,7 +93,7 @@ async def open_dgpu_node(
logging.info('memory summary:') logging.info('memory summary:')
logging.info('\n' + torch.cuda.memory_summary()) logging.info('\n' + torch.cuda.memory_summary())
async def gpu_compute_one(ireq: ImageGenRequest): async def gpu_compute_one(ireq: Text2ImageParameters):
if ireq.algo not in models: if ireq.algo not in models:
least_used = list(models.keys())[0] least_used = list(models.keys())[0]
for model in models: for model in models:
@ -110,10 +111,10 @@ async def open_dgpu_node(
try: try:
image = models[ireq.algo]['pipe']( image = models[ireq.algo]['pipe'](
ireq.prompt, ireq.prompt,
width=ireq.width, width=int(ireq.width),
height=ireq.height, height=int(ireq.height),
guidance_scale=ireq.guidance, guidance_scale=ireq.guidance,
num_inference_steps=ireq.step, num_inference_steps=int(ireq.step),
generator=torch.Generator("cuda").manual_seed(ireq.seed) generator=torch.Generator("cuda").manual_seed(ireq.seed)
).images[0] ).images[0]
@ -191,9 +192,8 @@ async def open_dgpu_node(
try: try:
while True: while True:
msg = await dgpu_sock.arecv() req = DGPUBusMessage()
req = DGPUBusRequest( req.ParseFromString(await dgpu_sock.arecv())
**json.loads(msg.decode()))
if req.nid != unique_id: if req.nid != unique_id:
logging.info( logging.info(
@ -201,54 +201,60 @@ async def open_dgpu_node(
continue continue
if security: if security:
req.verify(skynet_cert) verify_protobuf_msg(req, skynet_cert)
ack_resp = DGPUBusResponse( ack_resp = DGPUBusMessage(
rid=req.rid, rid=req.rid,
nid=req.nid, nid=req.nid
params={'ack': {}}
) )
ack_resp.params.update({'ack': {}})
if security: if security:
ack_resp.sign(tls_key, cert_name) ack_resp.auth.cert = cert_name
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
# send ack # send ack
await dgpu_sock.asend( await dgpu_sock.asend(ack_resp.SerializeToString())
json.dumps(ack_resp.to_dict()).encode())
logging.info(f'sent ack, processing {req.rid}...') logging.info(f'sent ack, processing {req.rid}...')
try: try:
img_req = ImageGenRequest(**req.params) img_req = Text2ImageParameters(**req.params)
if not img_req.seed: if not img_req.seed:
img_req.seed = random.randint(0, 2 ** 64) img_req.seed = random.randint(0, 2 ** 64)
img = await gpu_compute_one(img_req) img = await gpu_compute_one(img_req)
img_resp = DGPUBusResponse( img_resp = DGPUBusMessage(
rid=req.rid, rid=req.rid,
nid=req.nid, nid=req.nid,
params={ method='binary-reply'
'img': base64.b64encode(img).hex(),
'meta': img_req.to_dict()
}
) )
img_resp.params.update({
'len': len(img),
'meta': img_req.to_dict()
})
except DGPUComputeError as e: except DGPUComputeError as e:
img_resp = DGPUBusResponse( traceback.print_exception(type(e), e, e.__traceback__)
img_resp = DGPUBusMessage(
rid=req.rid, rid=req.rid,
nid=req.nid, nid=req.nid
params={'error': str(e)}
) )
img_resp.params.update({'error': str(e)})
if security: if security:
img_resp.sign(tls_key, cert_name) img_resp.auth.cert = cert_name
img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
raw_msg = json.dumps(img_resp.to_dict()).encode()
# send final image # send final image
logging.info('sending img back...') logging.info('sending img back...')
raw_msg = img_resp.SerializeToString()
await dgpu_sock.asend(raw_msg) await dgpu_sock.asend(raw_msg)
logging.info(f'sent {len(raw_msg)} bytes.') logging.info(f'sent {len(raw_msg)} bytes.')
if img_resp.method == 'binary-reply':
await dgpu_sock.asend(img)
logging.info(f'sent {len(img)} bytes.')
except KeyboardInterrupt: except KeyboardInterrupt:
logging.info('interrupt caught, stopping...') logging.info('interrupt caught, stopping...')

View File

@ -15,9 +15,13 @@ from OpenSSL.crypto import (
FILETYPE_PEM FILETYPE_PEM
) )
from ..structs import SkynetRPCRequest, SkynetRPCResponse from google.protobuf.struct_pb2 import Struct
from ..constants import * from ..constants import *
from ..protobuf.auth import *
from ..protobuf.skynet_pb2 import SkynetRPCRequest, SkynetRPCResponse
class ConfigRequestFormatError(BaseException): class ConfigRequestFormatError(BaseException):
... ...
@ -79,28 +83,26 @@ async def open_skynet_rpc(
async def _rpc_call( async def _rpc_call(
method: str, method: str,
params: dict = {}, params: dict = {},
uid: Optional[Union[int, str]] = None uid: Optional[str] = None
): ):
req = SkynetRPCRequest( req = SkynetRPCRequest()
uid=uid if uid else unique_id, req.uid = uid if uid else unique_id
method=method, req.method = method
params=params req.params.update(params)
)
if security: if security:
req.sign(tls_key, cert_name) req.auth.cert = cert_name
req.auth.sig = sign_protobuf_msg(req, tls_key)
ctx = sock.new_context() ctx = sock.new_context()
await ctx.asend( await ctx.asend(req.SerializeToString())
json.dumps(
req.to_dict()).encode())
resp = SkynetRPCResponse( resp = SkynetRPCResponse()
**json.loads((await ctx.arecv()).decode())) resp.ParseFromString(await ctx.arecv())
ctx.close() ctx.close()
if security: if security:
resp.verify(skynet_cert) verify_protobuf_msg(resp, skynet_cert)
return resp return resp

View File

@ -1,7 +1,7 @@
#!/usr/bin/python #!/usr/bin/python
import io import io
import base64 import zlib
import logging import logging
from datetime import datetime from datetime import datetime
@ -110,7 +110,7 @@ async def run_skynet_telegram(
else: else:
logging.info(resp.result['id']) logging.info(resp.result['id'])
img_raw = base64.b64decode(bytes.fromhex(resp.result['img'])) img_raw = zlib.decompress(bytes.fromhex(resp.result['img']))
logging.info(f'got image of size: {len(img_raw)}') logging.info(f'got image of size: {len(img_raw)}')
size = (512, 512) size = (512, 512)
if resp.result['meta']['upscaler'] == 'x4': if resp.result['meta']['upscaler'] == 'x4':
@ -141,7 +141,7 @@ async def run_skynet_telegram(
resp_txt = resp.result['message'] resp_txt = resp.result['message']
else: else:
img_raw = base64.b64decode(bytes.fromhex(resp.result['img'])) img_raw = zlib.decompress(bytes.fromhex(resp.result['img']))
logging.info(f'got image of size: {len(img_raw)}') logging.info(f'got image of size: {len(img_raw)}')
size = (512, 512) size = (512, 512)
logging.info(resp.result['meta']) logging.info(resp.result['meta'])

View File

@ -0,0 +1,27 @@
#!/usr/bin/python
from typing import Optional
from dataclasses import dataclass, asdict
from google.protobuf.json_format import MessageToDict
from .auth import *
from .skynet_pb2 import *
class Struct:
    """Mixin for dataclasses that adds a plain-dict export."""

    def to_dict(self) -> dict:
        """Return the dataclass fields of this instance as a dict.

        Delegates to ``dataclasses.asdict``, so nested dataclasses are
        converted recursively and a ``TypeError`` is raised when the
        instance is not a dataclass.
        """
        as_mapping = asdict(self)
        return as_mapping
@dataclass
class Text2ImageParameters(Struct):
    """Parameters describing one text-to-image generation request.

    ``seed`` and ``upscaler`` are optional and default to ``None`` so a
    request that omits them can still be constructed; the dgpu worker
    picks a random seed when ``seed`` is falsy.
    """

    # key used by the worker to select the model pipeline
    algo: str
    # text prompt handed to the pipeline
    prompt: str
    # number of inference steps
    step: int
    # output image width
    width: int
    # output image height
    height: int
    # guidance scale passed to the pipeline
    guidance: float
    # generator seed; None lets the worker choose one
    seed: Optional[int] = None
    # upscaler name (callers compare against 'x4'); None when unset
    upscaler: Optional[str] = None

View File

@ -0,0 +1,65 @@
#!/usr/bin/python
import json
import logging
from hashlib import sha256
from collections import OrderedDict
from google.protobuf.json_format import MessageToDict
from OpenSSL.crypto import PKey, X509, verify, sign
from .skynet_pb2 import *
def _hash_struct_dict(shasum, d: dict) -> None:
    """Feed dict *d* into *shasum*, visiting keys in sorted order."""
    for key in sorted(d):
        val = d[key]
        # Hash every key, including keys whose value is a nested dict.
        # Otherwise a nested object could be moved under a different key
        # without changing the digest (and therefore the signature).
        shasum.update(key.encode())
        if isinstance(val, dict):
            _hash_struct_dict(shasum, val)
        else:
            # sort_keys keeps dicts nested inside lists order-independent;
            # for scalars and flat lists the output is unchanged.
            shasum.update(json.dumps(val, sort_keys=True).encode())


def serialize_msg_deterministic(msg):
    """Return a deterministic sha256 hex digest of a protobuf message.

    Fields are visited in sorted name order:

    * non-message fields contribute their name followed by the JSON
      encoding of their value (``bytes`` values are hex-encoded first);
    * ``google.protobuf.Struct`` fields are converted with
      ``MessageToDict`` and hashed key by key in sorted order,
      recursing into nested dicts;
    * any other message-typed field (e.g. ``auth``) is skipped, so the
      signature carried inside the message is not part of the digest.

    NOTE: signer and verifier must run the same version of this function;
    the digest of messages with nested Struct values depends on it.

    :param msg: a generated protobuf message instance.
    :returns: hex string of the sha256 digest.
    """
    descriptors = sorted(
        type(msg).DESCRIPTOR.fields_by_name.items(),
        key=lambda item: item[0]
    )
    shasum = sha256()

    for field_name, field_descriptor in descriptors:
        message_type = field_descriptor.message_type

        if not message_type:
            # plain scalar field: name + JSON value
            shasum.update(field_name.encode())
            value = getattr(msg, field_name)
            if isinstance(value, bytes):
                # bytes are not JSON serializable
                value = value.hex()
            shasum.update(json.dumps(value).encode())
            continue

        if message_type.name == 'Struct':
            _hash_struct_dict(
                shasum, MessageToDict(getattr(msg, field_name)))

    return shasum.hexdigest()
def sign_protobuf_msg(msg, key: PKey):
    """Sign the deterministic digest of *msg* and return the signature as hex.

    :param msg: protobuf message to sign.
    :param key: private key used for signing.
    :returns: hex string of the sha256-based signature, suitable for
        storing in ``msg.auth.sig``.
    """
    # pyOpenSSL expects bytes; the hex digest is ASCII so this is the same
    # payload the library would derive from the str, minus the deprecated
    # implicit text conversion.
    payload = serialize_msg_deterministic(msg).encode()
    return sign(key, payload, 'sha256').hex()
def verify_protobuf_msg(msg, cert: X509):
    """Check ``msg.auth.sig`` against the deterministic digest of *msg*.

    :param msg: protobuf message carrying a hex signature in ``auth.sig``.
    :param cert: certificate of the expected signer.
    :returns: whatever ``OpenSSL.crypto.verify`` returns (``None`` on
        success per the pyOpenSSL docs).
    :raises ValueError: if ``auth.sig`` is not valid hex.
    :raises OpenSSL.crypto.Error: presumably, when the signature does not
        match -- confirm against the installed pyOpenSSL version.
    """
    signature = bytes.fromhex(msg.auth.sig)
    # pass bytes explicitly; must mirror the payload built when signing
    payload = serialize_msg_deterministic(msg).encode()
    return verify(cert, signature, payload, 'sha256')

View File

@ -0,0 +1,30 @@
syntax = "proto3";
package skynet;
import "google/protobuf/struct.proto";
// Signature envelope attached to every signed message.
message Auth {
// name of the signer's certificate (e.g. "skynet"); receivers use it to
// look the certificate up in their TLS whitelist
string cert = 1;
// hex-encoded signature over the deterministic digest of the message
string sig = 2;
}
// Request sent by a frontend or worker to the skynet RPC service.
message SkynetRPCRequest {
// id of the caller (user id or node unique id), always as a string
string uid = 1;
// name of the RPC method to invoke
string method = 2;
// free-form method arguments
google.protobuf.Struct params = 3;
// present only when security is enabled
optional Auth auth = 4;
}
// Reply to a SkynetRPCRequest.
message SkynetRPCResponse {
// free-form result payload (e.g. {"ok": ...} or an error description)
google.protobuf.Struct result = 1;
// present only when security is enabled
optional Auth auth = 2;
}
// Message exchanged on the dgpu bus, used for both requests and replies.
message DGPUBusMessage {
// request id this message belongs to
string rid = 1;
// id of the worker node the message is addressed to / originates from
string nid = 2;
// action or reply kind, e.g. "diffuse" or "binary-reply" (a binary-reply
// is followed by a second raw frame with the image bytes)
string method = 3;
// free-form parameters or reply metadata
google.protobuf.Struct params = 4;
// present only when security is enabled
optional Auth auth = 5;
}

View File

@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: skynet.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cskynet.proto\x12\x06skynet\x1a\x1cgoogle/protobuf/struct.proto\"!\n\x04\x41uth\x12\x0c\n\x04\x63\x65rt\x18\x01 \x01(\t\x12\x0b\n\x03sig\x18\x02 \x01(\t\"\x82\x01\n\x10SkynetRPCRequest\x12\x0b\n\x03uid\x18\x01 \x01(\t\x12\x0e\n\x06method\x18\x02 \x01(\t\x12\'\n\x06params\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x04 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"f\n\x11SkynetRPCResponse\x12\'\n\x06result\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x02 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"\x8d\x01\n\x0e\x44GPUBusMessage\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x0b\n\x03nid\x18\x02 \x01(\t\x12\x0e\n\x06method\x18\x03 \x01(\t\x12\'\n\x06params\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x05 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_authb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'skynet_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_AUTH._serialized_start=54
_AUTH._serialized_end=87
_SKYNETRPCREQUEST._serialized_start=90
_SKYNETRPCREQUEST._serialized_end=220
_SKYNETRPCRESPONSE._serialized_start=222
_SKYNETRPCRESPONSE._serialized_end=324
_DGPUBUSMESSAGE._serialized_start=327
_DGPUBUSMESSAGE._serialized_end=468
# @@protoc_insertion_point(module_scope)

View File

@ -3,7 +3,7 @@
import io import io
import time import time
import json import json
import base64 import zlib
import logging import logging
from typing import Optional from typing import Optional
@ -12,10 +12,10 @@ from functools import partial
import trio import trio
import pytest import pytest
import tractor
import trio_asyncio import trio_asyncio
from PIL import Image from PIL import Image
from google.protobuf.json_format import MessageToDict
from skynet.brain import SkynetDGPUComputeError from skynet.brain import SkynetDGPUComputeError
from skynet.constants import * from skynet.constants import *
@ -40,7 +40,7 @@ async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
_images = set() _images = set()
async def check_request_img( async def check_request_img(
i: int, i: int,
uid: int = 0, uid: str = '1',
width: int = 512, width: int = 512,
height: int = 512, height: int = 512,
expect_unique = True, expect_unique = True,
@ -66,13 +66,13 @@ async def check_request_img(
}) })
if 'error' in res.result: if 'error' in res.result:
raise SkynetDGPUComputeError(json.dumps(res.result)) raise SkynetDGPUComputeError(MessageToDict(res.result))
if upscaler == 'x4': if upscaler == 'x4':
width *= 4 width *= 4
height *= 4 height *= 4
img_raw = base64.b64decode(bytes.fromhex(res.result['img'])) img_raw = zlib.decompress(bytes.fromhex(res.result['img']))
img_sha = sha256(img_raw).hexdigest() img_sha = sha256(img_raw).hexdigest()
img = Image.frombytes( img = Image.frombytes(
'RGB', (width, height), img_raw) 'RGB', (width, height), img_raw)