async def dgpu_image_streamer():
    """Pump replies from the dgpu bus into ``fin_reqs`` and wake waiters.

    Runs forever.  Each iteration reads one ``DGPUBusMessage`` header from
    ``dgpu_bus``.  When the header's ``method`` is ``'binary-reply'`` the
    sender follows it with one extra raw frame (the image bytes), which is
    stored together with the header as a ``(msg, raw_img)`` tuple; any other
    reply is stored as the bare message.

    ``wip_reqs`` maps a request id to the event its requester is waiting on;
    the entry is removed and the event is set once the reply has been
    placed in ``fin_reqs``.

    NOTE: ``wip_reqs`` / ``fin_reqs`` are only mutated in place (never
    rebound), so no ``nonlocal`` declaration is required here.
    """
    while True:
        raw_msg = await dgpu_bus.arecv()
        logging.info(f'streamer got {len(raw_msg)} bytes.')
        msg = DGPUBusMessage()
        msg.ParseFromString(raw_msg)

        # The binary payload frame must be consumed *before* any decision to
        # drop the message.  Otherwise a discarded 'binary-reply' would leave
        # its image frame on the bus and the next iteration would try to
        # parse raw image bytes as a protobuf header.
        raw_img = None
        if msg.method == 'binary-reply':
            logging.info('bin reply, recv extra data')
            raw_img = await dgpu_bus.arecv()

        if security:
            # Mirror the RPC loop: an unknown cert is logged and dropped
            # instead of raising KeyError and killing the streamer task.
            if msg.auth.cert not in tls_whitelist:
                logging.warning(
                    f'{msg.auth.cert} not in tls whitelist and security=True')
                continue

            verify_protobuf_msg(msg, tls_whitelist[msg.auth.cert])

        rid = msg.rid

        # Nobody is waiting for this reply any more; discard it.
        if rid not in wip_reqs:
            continue

        fin_reqs[rid] = msg if raw_img is None else (msg, raw_img)

        # Publish the result first, then wake the waiter.
        event = wip_reqs.pop(rid)
        event.set()
sign_protobuf_msg(dgpu_req, tls_key) - await dgpu_bus.asend( - json.dumps(dgpu_req.to_dict()).encode()) + await dgpu_bus.asend(dgpu_req.SerializeToString()) with trio.move_on_after(4): await ack_event.wait() @@ -184,13 +191,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): nodes[nid]['task'] = None - img_resp = fin_reqs[rid] + resp = fin_reqs[rid] del fin_reqs[rid] + if isinstance(resp, tuple): + meta, img = resp + return rid, img, meta.params - if 'error' in img_resp.params: - raise SkynetDGPUComputeError(img_resp.params['error']) + raise SkynetDGPUComputeError(MessageToDict(resp.params)) - return rid, img_resp.params['img'], img_resp.params['meta'] async def handle_user_request(rpc_ctx, req): try: @@ -204,13 +212,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): logging.info('txt2img') user_config = {**(await get_user_config(conn, user))} del user_config['id'] - user_config.update((k, req.params[k]) for k in req.params) - req = ImageGenRequest(**user_config) + user_config.update(MessageToDict(req.params)) + + req = Text2ImageParameters(**user_config) rid, img, meta = await dgpu_stream_one_img(req) logging.info(f'done streaming {rid}') result = { 'id': rid, - 'img': img, + 'img': zlib.compress(img).hex(), 'meta': meta } @@ -224,14 +233,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): prompt = await get_last_prompt_of(conn, user) if prompt: - req = ImageGenRequest( + req = Text2ImageParameters( prompt=prompt, **user_config ) rid, img, meta = await dgpu_stream_one_img(req) result = { 'id': rid, - 'img': img, + 'img': zlib.compress(img).hex(), 'meta': meta } await update_user_stats(conn, user) @@ -289,14 +298,15 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): 'message': str(e) } - resp = SkynetRPCResponse(result=result) + resp = SkynetRPCResponse() + resp.result.update(result) if security: - resp.sign(tls_key, 'skynet') + resp.auth.cert 
= 'skynet' + resp.auth.sig = sign_protobuf_msg(resp, tls_key) logging.info('sending response') - await rpc_ctx.asend( - json.dumps(resp.to_dict()).encode()) + await rpc_ctx.asend(resp.SerializeToString()) rpc_ctx.close() logging.info('done') @@ -304,19 +314,17 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): nonlocal next_worker while True: ctx = sock.new_context() - msg = await ctx.arecv_msg() - - content = msg.bytes.decode() - req = SkynetRPCRequest(**json.loads(content)) + req = SkynetRPCRequest() + req.ParseFromString(await ctx.arecv()) if security: - if req.cert not in tls_whitelist: + if req.auth.cert not in tls_whitelist: logging.warning( f'{req.cert} not in tls whitelist and security=True') continue try: - req.verify(tls_whitelist[req.cert]) + verify_protobuf_msg(req, tls_whitelist[req.auth.cert]) except ValueError: logging.warning( @@ -345,14 +353,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key): handle_user_request, ctx, req) continue - resp = SkynetRPCResponse( - result={'ok': result}) + resp = SkynetRPCResponse() + resp.result.update({'ok': result}) if security: - resp.sign(tls_key, 'skynet') + resp.auth.cert = 'skynet' + resp.auth.sig = sign_protobuf_msg(resp, tls_key) - await ctx.asend( - json.dumps(resp.to_dict()).encode()) + await ctx.asend(resp.SerializeToString()) ctx.close() diff --git a/skynet/db.py b/skynet/db.py index f803397..1446a62 100644 --- a/skynet/db.py +++ b/skynet/db.py @@ -59,8 +59,10 @@ ALTER TABLE skynet.user_config def try_decode_uid(uid: str): - if isinstance(uid, int): - return None, uid + try: + return None, int(uid) + except ValueError: + ... 
try: proto, uid = uid.split('+') diff --git a/skynet/dgpu.py b/skynet/dgpu.py index a73a5b6..975c78b 100644 --- a/skynet/dgpu.py +++ b/skynet/dgpu.py @@ -8,6 +8,7 @@ import uuid import base64 import random import logging +import traceback from typing import List, Optional from pathlib import Path @@ -34,9 +35,9 @@ from .utils import ( pipeline_for, convert_from_cv2_to_image, convert_from_image_to_cv2 ) -from .structs import * -from .constants import * +from .protobuf import * from .frontend import open_skynet_rpc +from .constants import * def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'): @@ -92,7 +93,7 @@ async def open_dgpu_node( logging.info('memory summary:') logging.info('\n' + torch.cuda.memory_summary()) - async def gpu_compute_one(ireq: ImageGenRequest): + async def gpu_compute_one(ireq: Text2ImageParameters): if ireq.algo not in models: least_used = list(models.keys())[0] for model in models: @@ -110,10 +111,10 @@ async def open_dgpu_node( try: image = models[ireq.algo]['pipe']( ireq.prompt, - width=ireq.width, - height=ireq.height, + width=int(ireq.width), + height=int(ireq.height), guidance_scale=ireq.guidance, - num_inference_steps=ireq.step, + num_inference_steps=int(ireq.step), generator=torch.Generator("cuda").manual_seed(ireq.seed) ).images[0] @@ -191,9 +192,8 @@ async def open_dgpu_node( try: while True: - msg = await dgpu_sock.arecv() - req = DGPUBusRequest( - **json.loads(msg.decode())) + req = DGPUBusMessage() + req.ParseFromString(await dgpu_sock.arecv()) if req.nid != unique_id: logging.info( @@ -201,54 +201,60 @@ async def open_dgpu_node( continue if security: - req.verify(skynet_cert) + verify_protobuf_msg(req, skynet_cert) - ack_resp = DGPUBusResponse( + ack_resp = DGPUBusMessage( rid=req.rid, - nid=req.nid, - params={'ack': {}} + nid=req.nid ) + ack_resp.params.update({'ack': {}}) if security: - ack_resp.sign(tls_key, cert_name) + ack_resp.auth.cert = cert_name + ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key) # 
send ack - await dgpu_sock.asend( - json.dumps(ack_resp.to_dict()).encode()) + await dgpu_sock.asend(ack_resp.SerializeToString()) logging.info(f'sent ack, processing {req.rid}...') try: - img_req = ImageGenRequest(**req.params) + img_req = Text2ImageParameters(**req.params) if not img_req.seed: img_req.seed = random.randint(0, 2 ** 64) img = await gpu_compute_one(img_req) - img_resp = DGPUBusResponse( + img_resp = DGPUBusMessage( rid=req.rid, nid=req.nid, - params={ - 'img': base64.b64encode(img).hex(), - 'meta': img_req.to_dict() - } + method='binary-reply' ) + img_resp.params.update({ + 'len': len(img), + 'meta': img_req.to_dict() + }) except DGPUComputeError as e: - img_resp = DGPUBusResponse( + traceback.print_exception(type(e), e, e.__traceback__) + img_resp = DGPUBusMessage( rid=req.rid, - nid=req.nid, - params={'error': str(e)} + nid=req.nid ) + img_resp.params.update({'error': str(e)}) + if security: - img_resp.sign(tls_key, cert_name) - - raw_msg = json.dumps(img_resp.to_dict()).encode() + img_resp.auth.cert = cert_name + img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key) # send final image logging.info('sending img back...') + raw_msg = img_resp.SerializeToString() await dgpu_sock.asend(raw_msg) logging.info(f'sent {len(raw_msg)} bytes.') + if img_resp.method == 'binary-reply': + await dgpu_sock.asend(img) + logging.info(f'sent {len(img)} bytes.') except KeyboardInterrupt: logging.info('interrupt caught, stopping...') diff --git a/skynet/frontend/__init__.py b/skynet/frontend/__init__.py index 06a9623..0ea798a 100644 --- a/skynet/frontend/__init__.py +++ b/skynet/frontend/__init__.py @@ -15,9 +15,13 @@ from OpenSSL.crypto import ( FILETYPE_PEM ) -from ..structs import SkynetRPCRequest, SkynetRPCResponse +from google.protobuf.struct_pb2 import Struct + from ..constants import * +from ..protobuf.auth import * +from ..protobuf.skynet_pb2 import SkynetRPCRequest, SkynetRPCResponse + class ConfigRequestFormatError(BaseException): ... 
async def _rpc_call(
    method: str,
    params: Optional[dict] = None,
    uid: Optional[str] = None
):
    """Send one ``SkynetRPCRequest`` and return the parsed response.

    Args:
        method: RPC method name, copied into ``req.method``.
        params: Optional mapping merged into the request's ``params``
            Struct.  ``None`` (the default) sends no parameters.
        uid: Caller id; falls back to the enclosing ``unique_id`` when
            empty or ``None``.

    Returns:
        The ``SkynetRPCResponse`` parsed from the reply bytes.

    When ``security`` is enabled the request is signed with ``tls_key``
    under ``cert_name`` before sending, and the response is checked with
    ``verify_protobuf_msg`` against ``skynet_cert`` before being returned.
    """
    req = SkynetRPCRequest()
    req.uid = uid if uid else unique_id
    req.method = method
    # ``None`` default avoids sharing one mutable dict between calls.
    req.params.update(params or {})

    if security:
        req.auth.cert = cert_name
        req.auth.sig = sign_protobuf_msg(req, tls_key)

    ctx = sock.new_context()
    try:
        await ctx.asend(req.SerializeToString())

        resp = SkynetRPCResponse()
        resp.ParseFromString(await ctx.arecv())
    finally:
        # Release the context even when send/recv/parse raises.
        ctx.close()

    if security:
        verify_protobuf_msg(resp, skynet_cert)

    return resp
def serialize_msg_deterministic(msg):
    """Return a reproducible SHA-256 hex digest of the signable part of *msg*.

    The digest is what ``sign_protobuf_msg`` signs and what
    ``verify_protobuf_msg`` checks, so both peers must compute it with the
    same version of this function.

    Covered fields, taken from ``type(msg).DESCRIPTOR.fields_by_name``:

    * scalar fields (``message_type`` is falsy) -- ``bytes`` values are
      hex-encoded so they can be represented in JSON;
    * sub-messages whose type name is ``'Struct'`` -- converted with
      ``MessageToDict``.

    Any other sub-message is skipped; in particular the auth block is left
    out so that storing the signature does not change the digest.

    The collected fields are encoded as a single canonical JSON document
    (sorted keys at every nesting level, fixed separators) and hashed once.
    Compared with feeding key/value fragments to the hash one by one, this
    guarantees that:

    * names of nested dict keys are part of the digest;
    * key/value boundaries are unambiguous (JSON is self-delimiting), so
      ``{'a': 12}`` and ``{'a1': 2}`` can no longer collide;
    * dicts nested inside lists are ordered deterministically.

    Args:
        msg: A protobuf message instance (anything whose class exposes a
            ``DESCRIPTOR.fields_by_name`` mapping).

    Returns:
        The digest as a 64-character lowercase hex string.
    """
    signable = {}

    for field_name, field_descriptor in (
        type(msg).DESCRIPTOR.fields_by_name.items()
    ):
        message_type = field_descriptor.message_type

        if not message_type:
            value = getattr(msg, field_name)

            if isinstance(value, bytes):
                value = value.hex()

            signable[field_name] = value

        elif message_type.name == 'Struct':
            signable[field_name] = MessageToDict(getattr(msg, field_name))

    # sort_keys makes the encoding independent of field/map iteration order.
    canonical = json.dumps(signable, sort_keys=True, separators=(',', ':'))

    return sha256(canonical.encode()).hexdigest()
serialize_msg_deterministic(msg), + 'sha256' + ) diff --git a/skynet/protobuf/skynet.proto b/skynet/protobuf/skynet.proto new file mode 100644 index 0000000..6e66274 --- /dev/null +++ b/skynet/protobuf/skynet.proto @@ -0,0 +1,30 @@ +syntax = "proto3"; + +package skynet; + +import "google/protobuf/struct.proto"; + +message Auth { + string cert = 1; + string sig = 2; +} + +message SkynetRPCRequest { + string uid = 1; + string method = 2; + google.protobuf.Struct params = 3; + optional Auth auth = 4; +} + +message SkynetRPCResponse { + google.protobuf.Struct result = 1; + optional Auth auth = 2; +} + +message DGPUBusMessage { + string rid = 1; + string nid = 2; + string method = 3; + google.protobuf.Struct params = 4; + optional Auth auth = 5; +} diff --git a/skynet/protobuf/skynet_pb2.py b/skynet/protobuf/skynet_pb2.py new file mode 100644 index 0000000..dd7db33 --- /dev/null +++ b/skynet/protobuf/skynet_pb2.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: skynet.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cskynet.proto\x12\x06skynet\x1a\x1cgoogle/protobuf/struct.proto\"!\n\x04\x41uth\x12\x0c\n\x04\x63\x65rt\x18\x01 \x01(\t\x12\x0b\n\x03sig\x18\x02 \x01(\t\"\x82\x01\n\x10SkynetRPCRequest\x12\x0b\n\x03uid\x18\x01 \x01(\t\x12\x0e\n\x06method\x18\x02 \x01(\t\x12\'\n\x06params\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x04 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"f\n\x11SkynetRPCResponse\x12\'\n\x06result\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x02 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"\x8d\x01\n\x0e\x44GPUBusMessage\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x0b\n\x03nid\x18\x02 \x01(\t\x12\x0e\n\x06method\x18\x03 \x01(\t\x12\'\n\x06params\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x05 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_authb\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'skynet_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _AUTH._serialized_start=54 + _AUTH._serialized_end=87 + _SKYNETRPCREQUEST._serialized_start=90 + _SKYNETRPCREQUEST._serialized_end=220 + _SKYNETRPCRESPONSE._serialized_start=222 + _SKYNETRPCRESPONSE._serialized_end=324 + _DGPUBUSMESSAGE._serialized_start=327 + _DGPUBUSMESSAGE._serialized_end=468 +# 
@@protoc_insertion_point(module_scope) diff --git a/tests/test_dgpu.py b/tests/test_dgpu.py index a580a27..4699156 100644 --- a/tests/test_dgpu.py +++ b/tests/test_dgpu.py @@ -3,7 +3,7 @@ import io import time import json -import base64 +import zlib import logging from typing import Optional @@ -12,10 +12,10 @@ from functools import partial import trio import pytest -import tractor import trio_asyncio from PIL import Image +from google.protobuf.json_format import MessageToDict from skynet.brain import SkynetDGPUComputeError from skynet.constants import * @@ -40,7 +40,7 @@ async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0): _images = set() async def check_request_img( i: int, - uid: int = 0, + uid: str = '1', width: int = 512, height: int = 512, expect_unique = True, @@ -66,13 +66,13 @@ async def check_request_img( }) if 'error' in res.result: - raise SkynetDGPUComputeError(json.dumps(res.result)) + raise SkynetDGPUComputeError(MessageToDict(res.result)) if upscaler == 'x4': width *= 4 height *= 4 - img_raw = base64.b64decode(bytes.fromhex(res.result['img'])) + img_raw = zlib.decompress(bytes.fromhex(res.result['img'])) img_sha = sha256(img_raw).hexdigest() img = Image.frombytes( 'RGB', (width, height), img_raw)