mirror of https://github.com/skygpu/skynet.git
parent
7c27ee866a
commit
6c1799e342
|
@ -2,7 +2,7 @@
|
|||
|
||||
import json
|
||||
import uuid
|
||||
import base64
|
||||
import zlib
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
|
@ -24,9 +24,10 @@ from OpenSSL.crypto import (
|
|||
)
|
||||
|
||||
from .db import *
|
||||
from .structs import *
|
||||
from .constants import *
|
||||
|
||||
from .protobuf import *
|
||||
|
||||
|
||||
class SkynetDGPUOffline(BaseException):
|
||||
...
|
||||
|
@ -117,22 +118,30 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
async def dgpu_image_streamer():
|
||||
nonlocal wip_reqs, fin_reqs
|
||||
while True:
|
||||
raw_msg = (await dgpu_bus.arecv()).decode()
|
||||
raw_msg = await dgpu_bus.arecv()
|
||||
logging.info(f'streamer got {len(raw_msg)} bytes.')
|
||||
msg = DGPUBusResponse(**json.loads(raw_msg))
|
||||
msg = DGPUBusMessage()
|
||||
msg.ParseFromString(raw_msg)
|
||||
|
||||
if security:
|
||||
msg.verify(tls_whitelist[msg.cert])
|
||||
verify_protobuf_msg(msg, tls_whitelist[msg.auth.cert])
|
||||
|
||||
if msg.rid not in wip_reqs:
|
||||
rid = msg.rid
|
||||
|
||||
if rid not in wip_reqs:
|
||||
continue
|
||||
|
||||
fin_reqs[msg.rid] = msg
|
||||
event = wip_reqs[msg.rid]
|
||||
event.set()
|
||||
del wip_reqs[msg.rid]
|
||||
if msg.method == 'binary-reply':
|
||||
logging.info('bin reply, recv extra data')
|
||||
raw_img = await dgpu_bus.arecv()
|
||||
msg = (msg, raw_img)
|
||||
|
||||
async def dgpu_stream_one_img(req: ImageGenRequest):
|
||||
fin_reqs[rid] = msg
|
||||
event = wip_reqs[rid]
|
||||
event.set()
|
||||
del wip_reqs[rid]
|
||||
|
||||
async def dgpu_stream_one_img(req: Text2ImageParameters):
|
||||
nonlocal wip_reqs, fin_reqs, next_worker
|
||||
nid = get_next_worker()
|
||||
idx = list(nodes.keys()).index(nid)
|
||||
|
@ -144,19 +153,17 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
|
||||
nodes[nid]['task'] = rid
|
||||
|
||||
dgpu_req = DGPUBusRequest(
|
||||
dgpu_req = DGPUBusMessage(
|
||||
rid=rid,
|
||||
nid=nid,
|
||||
task='diffuse',
|
||||
params=req.to_dict())
|
||||
|
||||
logging.info(f'dgpu_bus req: {dgpu_req}')
|
||||
method='diffuse')
|
||||
dgpu_req.params.update(req.to_dict())
|
||||
|
||||
if security:
|
||||
dgpu_req.sign(tls_key, 'skynet')
|
||||
dgpu_req.auth.cert = 'skynet'
|
||||
dgpu_req.auth.sig = sign_protobuf_msg(dgpu_req, tls_key)
|
||||
|
||||
await dgpu_bus.asend(
|
||||
json.dumps(dgpu_req.to_dict()).encode())
|
||||
await dgpu_bus.asend(dgpu_req.SerializeToString())
|
||||
|
||||
with trio.move_on_after(4):
|
||||
await ack_event.wait()
|
||||
|
@ -184,13 +191,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
|
||||
nodes[nid]['task'] = None
|
||||
|
||||
img_resp = fin_reqs[rid]
|
||||
resp = fin_reqs[rid]
|
||||
del fin_reqs[rid]
|
||||
if isinstance(resp, tuple):
|
||||
meta, img = resp
|
||||
return rid, img, meta.params
|
||||
|
||||
if 'error' in img_resp.params:
|
||||
raise SkynetDGPUComputeError(img_resp.params['error'])
|
||||
raise SkynetDGPUComputeError(MessageToDict(resp.params))
|
||||
|
||||
return rid, img_resp.params['img'], img_resp.params['meta']
|
||||
|
||||
async def handle_user_request(rpc_ctx, req):
|
||||
try:
|
||||
|
@ -204,13 +212,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
logging.info('txt2img')
|
||||
user_config = {**(await get_user_config(conn, user))}
|
||||
del user_config['id']
|
||||
user_config.update((k, req.params[k]) for k in req.params)
|
||||
req = ImageGenRequest(**user_config)
|
||||
user_config.update(MessageToDict(req.params))
|
||||
|
||||
req = Text2ImageParameters(**user_config)
|
||||
rid, img, meta = await dgpu_stream_one_img(req)
|
||||
logging.info(f'done streaming {rid}')
|
||||
result = {
|
||||
'id': rid,
|
||||
'img': img,
|
||||
'img': zlib.compress(img).hex(),
|
||||
'meta': meta
|
||||
}
|
||||
|
||||
|
@ -224,14 +233,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
prompt = await get_last_prompt_of(conn, user)
|
||||
|
||||
if prompt:
|
||||
req = ImageGenRequest(
|
||||
req = Text2ImageParameters(
|
||||
prompt=prompt,
|
||||
**user_config
|
||||
)
|
||||
rid, img, meta = await dgpu_stream_one_img(req)
|
||||
result = {
|
||||
'id': rid,
|
||||
'img': img,
|
||||
'img': zlib.compress(img).hex(),
|
||||
'meta': meta
|
||||
}
|
||||
await update_user_stats(conn, user)
|
||||
|
@ -289,14 +298,15 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
'message': str(e)
|
||||
}
|
||||
|
||||
resp = SkynetRPCResponse(result=result)
|
||||
resp = SkynetRPCResponse()
|
||||
resp.result.update(result)
|
||||
|
||||
if security:
|
||||
resp.sign(tls_key, 'skynet')
|
||||
resp.auth.cert = 'skynet'
|
||||
resp.auth.sig = sign_protobuf_msg(resp, tls_key)
|
||||
|
||||
logging.info('sending response')
|
||||
await rpc_ctx.asend(
|
||||
json.dumps(resp.to_dict()).encode())
|
||||
await rpc_ctx.asend(resp.SerializeToString())
|
||||
rpc_ctx.close()
|
||||
logging.info('done')
|
||||
|
||||
|
@ -304,19 +314,17 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
nonlocal next_worker
|
||||
while True:
|
||||
ctx = sock.new_context()
|
||||
msg = await ctx.arecv_msg()
|
||||
|
||||
content = msg.bytes.decode()
|
||||
req = SkynetRPCRequest(**json.loads(content))
|
||||
req = SkynetRPCRequest()
|
||||
req.ParseFromString(await ctx.arecv())
|
||||
|
||||
if security:
|
||||
if req.cert not in tls_whitelist:
|
||||
if req.auth.cert not in tls_whitelist:
|
||||
logging.warning(
|
||||
f'{req.cert} not in tls whitelist and security=True')
|
||||
continue
|
||||
|
||||
try:
|
||||
req.verify(tls_whitelist[req.cert])
|
||||
verify_protobuf_msg(req, tls_whitelist[req.auth.cert])
|
||||
|
||||
except ValueError:
|
||||
logging.warning(
|
||||
|
@ -345,14 +353,14 @@ async def open_rpc_service(sock, dgpu_bus, db_pool, tls_whitelist, tls_key):
|
|||
handle_user_request, ctx, req)
|
||||
continue
|
||||
|
||||
resp = SkynetRPCResponse(
|
||||
result={'ok': result})
|
||||
resp = SkynetRPCResponse()
|
||||
resp.result.update({'ok': result})
|
||||
|
||||
if security:
|
||||
resp.sign(tls_key, 'skynet')
|
||||
resp.auth.cert = 'skynet'
|
||||
resp.auth.sig = sign_protobuf_msg(resp, tls_key)
|
||||
|
||||
await ctx.asend(
|
||||
json.dumps(resp.to_dict()).encode())
|
||||
await ctx.asend(resp.SerializeToString())
|
||||
|
||||
ctx.close()
|
||||
|
||||
|
|
|
@ -59,8 +59,10 @@ ALTER TABLE skynet.user_config
|
|||
|
||||
|
||||
def try_decode_uid(uid: str):
|
||||
if isinstance(uid, int):
|
||||
return None, uid
|
||||
try:
|
||||
return None, int(uid)
|
||||
except ValueError:
|
||||
...
|
||||
|
||||
try:
|
||||
proto, uid = uid.split('+')
|
||||
|
|
|
@ -8,6 +8,7 @@ import uuid
|
|||
import base64
|
||||
import random
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
|
@ -34,9 +35,9 @@ from .utils import (
|
|||
pipeline_for,
|
||||
convert_from_cv2_to_image, convert_from_image_to_cv2
|
||||
)
|
||||
from .structs import *
|
||||
from .constants import *
|
||||
from .protobuf import *
|
||||
from .frontend import open_skynet_rpc
|
||||
from .constants import *
|
||||
|
||||
|
||||
def init_upscaler(model_path: str = 'weights/RealESRGAN_x4plus.pth'):
|
||||
|
@ -92,7 +93,7 @@ async def open_dgpu_node(
|
|||
logging.info('memory summary:')
|
||||
logging.info('\n' + torch.cuda.memory_summary())
|
||||
|
||||
async def gpu_compute_one(ireq: ImageGenRequest):
|
||||
async def gpu_compute_one(ireq: Text2ImageParameters):
|
||||
if ireq.algo not in models:
|
||||
least_used = list(models.keys())[0]
|
||||
for model in models:
|
||||
|
@ -110,10 +111,10 @@ async def open_dgpu_node(
|
|||
try:
|
||||
image = models[ireq.algo]['pipe'](
|
||||
ireq.prompt,
|
||||
width=ireq.width,
|
||||
height=ireq.height,
|
||||
width=int(ireq.width),
|
||||
height=int(ireq.height),
|
||||
guidance_scale=ireq.guidance,
|
||||
num_inference_steps=ireq.step,
|
||||
num_inference_steps=int(ireq.step),
|
||||
generator=torch.Generator("cuda").manual_seed(ireq.seed)
|
||||
).images[0]
|
||||
|
||||
|
@ -191,9 +192,8 @@ async def open_dgpu_node(
|
|||
|
||||
try:
|
||||
while True:
|
||||
msg = await dgpu_sock.arecv()
|
||||
req = DGPUBusRequest(
|
||||
**json.loads(msg.decode()))
|
||||
req = DGPUBusMessage()
|
||||
req.ParseFromString(await dgpu_sock.arecv())
|
||||
|
||||
if req.nid != unique_id:
|
||||
logging.info(
|
||||
|
@ -201,54 +201,60 @@ async def open_dgpu_node(
|
|||
continue
|
||||
|
||||
if security:
|
||||
req.verify(skynet_cert)
|
||||
verify_protobuf_msg(req, skynet_cert)
|
||||
|
||||
ack_resp = DGPUBusResponse(
|
||||
ack_resp = DGPUBusMessage(
|
||||
rid=req.rid,
|
||||
nid=req.nid,
|
||||
params={'ack': {}}
|
||||
nid=req.nid
|
||||
)
|
||||
ack_resp.params.update({'ack': {}})
|
||||
|
||||
if security:
|
||||
ack_resp.sign(tls_key, cert_name)
|
||||
ack_resp.auth.cert = cert_name
|
||||
ack_resp.auth.sig = sign_protobuf_msg(ack_resp, tls_key)
|
||||
|
||||
# send ack
|
||||
await dgpu_sock.asend(
|
||||
json.dumps(ack_resp.to_dict()).encode())
|
||||
await dgpu_sock.asend(ack_resp.SerializeToString())
|
||||
|
||||
logging.info(f'sent ack, processing {req.rid}...')
|
||||
|
||||
try:
|
||||
img_req = ImageGenRequest(**req.params)
|
||||
img_req = Text2ImageParameters(**req.params)
|
||||
if not img_req.seed:
|
||||
img_req.seed = random.randint(0, 2 ** 64)
|
||||
|
||||
img = await gpu_compute_one(img_req)
|
||||
img_resp = DGPUBusResponse(
|
||||
img_resp = DGPUBusMessage(
|
||||
rid=req.rid,
|
||||
nid=req.nid,
|
||||
params={
|
||||
'img': base64.b64encode(img).hex(),
|
||||
'meta': img_req.to_dict()
|
||||
}
|
||||
method='binary-reply'
|
||||
)
|
||||
img_resp.params.update({
|
||||
'len': len(img),
|
||||
'meta': img_req.to_dict()
|
||||
})
|
||||
|
||||
except DGPUComputeError as e:
|
||||
img_resp = DGPUBusResponse(
|
||||
traceback.print_exception(type(e), e, e.__traceback__)
|
||||
img_resp = DGPUBusMessage(
|
||||
rid=req.rid,
|
||||
nid=req.nid,
|
||||
params={'error': str(e)}
|
||||
nid=req.nid
|
||||
)
|
||||
img_resp.params.update({'error': str(e)})
|
||||
|
||||
|
||||
if security:
|
||||
img_resp.sign(tls_key, cert_name)
|
||||
|
||||
raw_msg = json.dumps(img_resp.to_dict()).encode()
|
||||
img_resp.auth.cert = cert_name
|
||||
img_resp.auth.sig = sign_protobuf_msg(img_resp, tls_key)
|
||||
|
||||
# send final image
|
||||
logging.info('sending img back...')
|
||||
raw_msg = img_resp.SerializeToString()
|
||||
await dgpu_sock.asend(raw_msg)
|
||||
logging.info(f'sent {len(raw_msg)} bytes.')
|
||||
if img_resp.method == 'binary-reply':
|
||||
await dgpu_sock.asend(img)
|
||||
logging.info(f'sent {len(img)} bytes.')
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logging.info('interrupt caught, stopping...')
|
||||
|
|
|
@ -15,9 +15,13 @@ from OpenSSL.crypto import (
|
|||
FILETYPE_PEM
|
||||
)
|
||||
|
||||
from ..structs import SkynetRPCRequest, SkynetRPCResponse
|
||||
from google.protobuf.struct_pb2 import Struct
|
||||
|
||||
from ..constants import *
|
||||
|
||||
from ..protobuf.auth import *
|
||||
from ..protobuf.skynet_pb2 import SkynetRPCRequest, SkynetRPCResponse
|
||||
|
||||
|
||||
class ConfigRequestFormatError(BaseException):
|
||||
...
|
||||
|
@ -79,28 +83,26 @@ async def open_skynet_rpc(
|
|||
async def _rpc_call(
|
||||
method: str,
|
||||
params: dict = {},
|
||||
uid: Optional[Union[int, str]] = None
|
||||
uid: Optional[str] = None
|
||||
):
|
||||
req = SkynetRPCRequest(
|
||||
uid=uid if uid else unique_id,
|
||||
method=method,
|
||||
params=params
|
||||
)
|
||||
req = SkynetRPCRequest()
|
||||
req.uid = uid if uid else unique_id
|
||||
req.method = method
|
||||
req.params.update(params)
|
||||
|
||||
if security:
|
||||
req.sign(tls_key, cert_name)
|
||||
req.auth.cert = cert_name
|
||||
req.auth.sig = sign_protobuf_msg(req, tls_key)
|
||||
|
||||
ctx = sock.new_context()
|
||||
await ctx.asend(
|
||||
json.dumps(
|
||||
req.to_dict()).encode())
|
||||
await ctx.asend(req.SerializeToString())
|
||||
|
||||
resp = SkynetRPCResponse(
|
||||
**json.loads((await ctx.arecv()).decode()))
|
||||
resp = SkynetRPCResponse()
|
||||
resp.ParseFromString(await ctx.arecv())
|
||||
ctx.close()
|
||||
|
||||
if security:
|
||||
resp.verify(skynet_cert)
|
||||
verify_protobuf_msg(resp, skynet_cert)
|
||||
|
||||
return resp
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import io
|
||||
import base64
|
||||
import zlib
|
||||
import logging
|
||||
|
||||
from datetime import datetime
|
||||
|
@ -110,7 +110,7 @@ async def run_skynet_telegram(
|
|||
|
||||
else:
|
||||
logging.info(resp.result['id'])
|
||||
img_raw = base64.b64decode(bytes.fromhex(resp.result['img']))
|
||||
img_raw = zlib.decompress(bytes.fromhex(resp.result['img']))
|
||||
logging.info(f'got image of size: {len(img_raw)}')
|
||||
size = (512, 512)
|
||||
if resp.result['meta']['upscaler'] == 'x4':
|
||||
|
@ -141,7 +141,7 @@ async def run_skynet_telegram(
|
|||
resp_txt = resp.result['message']
|
||||
|
||||
else:
|
||||
img_raw = base64.b64decode(bytes.fromhex(resp.result['img']))
|
||||
img_raw = zlib.decompress(bytes.fromhex(resp.result['img']))
|
||||
logging.info(f'got image of size: {len(img_raw)}')
|
||||
size = (512, 512)
|
||||
logging.info(resp.result['meta'])
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
|
||||
from .auth import *
|
||||
from .skynet_pb2 import *
|
||||
|
||||
|
||||
class Struct:
|
||||
|
||||
def to_dict(self):
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Text2ImageParameters(Struct):
|
||||
algo: str
|
||||
prompt: str
|
||||
step: int
|
||||
width: int
|
||||
height: int
|
||||
guidance: float
|
||||
seed: Optional[int]
|
||||
upscaler: Optional[str]
|
|
@ -0,0 +1,65 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from hashlib import sha256
|
||||
from collections import OrderedDict
|
||||
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
from OpenSSL.crypto import PKey, X509, verify, sign
|
||||
|
||||
from .skynet_pb2 import *
|
||||
|
||||
|
||||
def serialize_msg_deterministic(msg):
|
||||
descriptors = sorted(
|
||||
type(msg).DESCRIPTOR.fields_by_name.items(),
|
||||
key=lambda x: x[0]
|
||||
)
|
||||
shasum = sha256()
|
||||
|
||||
def hash_dict(d):
|
||||
data = [
|
||||
(key, val)
|
||||
for (key, val) in d.items()
|
||||
]
|
||||
for key, val in sorted(data, key=lambda x: x[0]):
|
||||
if not isinstance(val, dict):
|
||||
shasum.update(key.encode())
|
||||
shasum.update(json.dumps(val).encode())
|
||||
else:
|
||||
hash_dict(val)
|
||||
|
||||
for (field_name, field_descriptor) in descriptors:
|
||||
if not field_descriptor.message_type:
|
||||
shasum.update(field_name.encode())
|
||||
|
||||
value = getattr(msg, field_name)
|
||||
|
||||
if isinstance(value, bytes):
|
||||
value = value.hex()
|
||||
|
||||
shasum.update(json.dumps(value).encode())
|
||||
continue
|
||||
|
||||
if field_descriptor.message_type.name == 'Struct':
|
||||
hash_dict(MessageToDict(getattr(msg, field_name)))
|
||||
|
||||
deterministic_msg = shasum.hexdigest()
|
||||
|
||||
return deterministic_msg
|
||||
|
||||
|
||||
def sign_protobuf_msg(msg, key: PKey):
|
||||
return sign(
|
||||
key, serialize_msg_deterministic(msg), 'sha256').hex()
|
||||
|
||||
|
||||
def verify_protobuf_msg(msg, cert: X509):
|
||||
return verify(
|
||||
cert,
|
||||
bytes.fromhex(msg.auth.sig),
|
||||
serialize_msg_deterministic(msg),
|
||||
'sha256'
|
||||
)
|
|
@ -0,0 +1,30 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package skynet;
|
||||
|
||||
import "google/protobuf/struct.proto";
|
||||
|
||||
message Auth {
|
||||
string cert = 1;
|
||||
string sig = 2;
|
||||
}
|
||||
|
||||
message SkynetRPCRequest {
|
||||
string uid = 1;
|
||||
string method = 2;
|
||||
google.protobuf.Struct params = 3;
|
||||
optional Auth auth = 4;
|
||||
}
|
||||
|
||||
message SkynetRPCResponse {
|
||||
google.protobuf.Struct result = 1;
|
||||
optional Auth auth = 2;
|
||||
}
|
||||
|
||||
message DGPUBusMessage {
|
||||
string rid = 1;
|
||||
string nid = 2;
|
||||
string method = 3;
|
||||
google.protobuf.Struct params = 4;
|
||||
optional Auth auth = 5;
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: skynet.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cskynet.proto\x12\x06skynet\x1a\x1cgoogle/protobuf/struct.proto\"!\n\x04\x41uth\x12\x0c\n\x04\x63\x65rt\x18\x01 \x01(\t\x12\x0b\n\x03sig\x18\x02 \x01(\t\"\x82\x01\n\x10SkynetRPCRequest\x12\x0b\n\x03uid\x18\x01 \x01(\t\x12\x0e\n\x06method\x18\x02 \x01(\t\x12\'\n\x06params\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x04 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"f\n\x11SkynetRPCResponse\x12\'\n\x06result\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x02 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_auth\"\x8d\x01\n\x0e\x44GPUBusMessage\x12\x0b\n\x03rid\x18\x01 \x01(\t\x12\x0b\n\x03nid\x18\x02 \x01(\t\x12\x0e\n\x06method\x18\x03 \x01(\t\x12\'\n\x06params\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x1f\n\x04\x61uth\x18\x05 \x01(\x0b\x32\x0c.skynet.AuthH\x00\x88\x01\x01\x42\x07\n\x05_authb\x06proto3')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'skynet_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_AUTH._serialized_start=54
|
||||
_AUTH._serialized_end=87
|
||||
_SKYNETRPCREQUEST._serialized_start=90
|
||||
_SKYNETRPCREQUEST._serialized_end=220
|
||||
_SKYNETRPCRESPONSE._serialized_start=222
|
||||
_SKYNETRPCRESPONSE._serialized_end=324
|
||||
_DGPUBUSMESSAGE._serialized_start=327
|
||||
_DGPUBUSMESSAGE._serialized_end=468
|
||||
# @@protoc_insertion_point(module_scope)
|
|
@ -3,7 +3,7 @@
|
|||
import io
|
||||
import time
|
||||
import json
|
||||
import base64
|
||||
import zlib
|
||||
import logging
|
||||
|
||||
from typing import Optional
|
||||
|
@ -12,10 +12,10 @@ from functools import partial
|
|||
|
||||
import trio
|
||||
import pytest
|
||||
import tractor
|
||||
import trio_asyncio
|
||||
|
||||
from PIL import Image
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
|
||||
from skynet.brain import SkynetDGPUComputeError
|
||||
from skynet.constants import *
|
||||
|
@ -40,7 +40,7 @@ async def wait_for_dgpus(rpc, amount: int, timeout: float = 30.0):
|
|||
_images = set()
|
||||
async def check_request_img(
|
||||
i: int,
|
||||
uid: int = 0,
|
||||
uid: str = '1',
|
||||
width: int = 512,
|
||||
height: int = 512,
|
||||
expect_unique = True,
|
||||
|
@ -66,13 +66,13 @@ async def check_request_img(
|
|||
})
|
||||
|
||||
if 'error' in res.result:
|
||||
raise SkynetDGPUComputeError(json.dumps(res.result))
|
||||
raise SkynetDGPUComputeError(MessageToDict(res.result))
|
||||
|
||||
if upscaler == 'x4':
|
||||
width *= 4
|
||||
height *= 4
|
||||
|
||||
img_raw = base64.b64decode(bytes.fromhex(res.result['img']))
|
||||
img_raw = zlib.decompress(bytes.fromhex(res.result['img']))
|
||||
img_sha = sha256(img_raw).hexdigest()
|
||||
img = Image.frombytes(
|
||||
'RGB', (width, height), img_raw)
|
||||
|
|
Loading…
Reference in New Issue