diff --git a/certs/brain.cert b/certs/brain.cert new file mode 100644 index 0000000..d5d7e49 --- /dev/null +++ b/certs/brain.cert @@ -0,0 +1,33 @@ +-----BEGIN CERTIFICATE----- +MIIFxDCCA6wCAQAwDQYJKoZIhvcNAQENBQAwgacxCzAJBgNVBAYTAlVZMRMwEQYD +VQQIDApNb250ZXZpZGVvMRMwEQYDVQQHDApNb250ZXZpZGVvMRowGAYDVQQKDBFz +a3luZXQtZm91bmRhdGlvbjENMAsGA1UECwwEbm9uZTEcMBoGA1UEAwwTR3VpbGxl +cm1vIFJvZHJpZ3VlejElMCMGCSqGSIb3DQEJARYWZ3VpbGxlcm1vckBmaW5nLmVk +dS51eTAeFw0yMjEyMTExNDM3NDVaFw0zMjEyMDgxNDM3NDVaMIGnMQswCQYDVQQG +EwJVWTETMBEGA1UECAwKTW9udGV2aWRlbzETMBEGA1UEBwwKTW9udGV2aWRlbzEa +MBgGA1UECgwRc2t5bmV0LWZvdW5kYXRpb24xDTALBgNVBAsMBG5vbmUxHDAaBgNV +BAMME0d1aWxsZXJtbyBSb2RyaWd1ZXoxJTAjBgkqhkiG9w0BCQEWFmd1aWxsZXJt +b3JAZmluZy5lZHUudXkwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQCu +HdqGPtsqtYqfIilVdq0MmqfEn9g4T+uglfWjRF2gWV3uQCuXDv1O61XfIIyaDQXl +VRqT36txtM8rvn213746SwK0jx9+ln5jD3EDbL4WZv1qvp4/jqA+UPKXFXnD3he+ +pRpcDMu4IpYKuoPl667IW/auFSSy3TIWhIZb8ghqxzb2e2i6/OhzIWKHeFIKvbEA +EB6Z63wy3O0ACY7RVhHu0wzyzqUW1t1VNsbZvO9Xmmqm2EWZBJp0TFph3Z9kOR/g +0Ik7kxMLrGIfhV5/1gPQlNr3ADebGJnaMdGCBUi+pqeZcVnGY45fjOJREaD3aTRG +ohZM0Td40K7paDVjUvQ9rPgKoDMsCWpu8IPdc4LB0hONIO2KycFb49cd8zNWsetj +kHXxL9IVgORxfGmVyOtNGotS5RX6R+qwsll3qUmX4XjwvQMAMvATcSkY26CWdCDM +vGFp+0REbVyDfJ9pwU7ZkAxiWeAoiesGfEWyRLsl0fFkaHgHG+oPCH9IO63TVnCq +E6NGRQpHfJ5oV4ZihUfWjSFxOJqdFM3xfzk/2YGzQUgKVBsbuQTWPKxE0aSwt1Cf +Ug4+C0RSDMmrquRmhRn/BWsSRl+2m17rt1axTA4pEVGcHHyKSowEFQ68spD1Lm2K +iU/LCPBh4REzexwjP+onwHALXoxIEOLiy2lEdYgWnwIDAQABMA0GCSqGSIb3DQEB +DQUAA4ICAQBtTZb6PJJQXtF90MD4Hcgj+phKkbtHVZyM198Giw3I9f2PgjDECKb9 +I7JLzCUgpexKk1TNso2FPNoVlcE4yMO0I0EauoKcwZ1w9GXsXOGwPHvB9hrItaLs +s7Qxf+IVgKO4y5Tv+8WO4lhgShWa4fW3L7Dpk0XK4INoAAxZLbEdekf2GGqTUGzD +SrfvtE8h6JT+gR4lsAvdsRjJIKYacsqhKjtV0reA6v99NthDcpwaStrAaFmtJkD3 +6G3JVU0JyMBlR1GetN0w42BjVHJ2l7cPm405lE2ymFwcl7C8VozXXi4wmfVN+xlh +NOVSbl/QUiMUyt44XPhPCbgopxLqhqtvGzBl+ldF1AR4aaukXjvS/8VtFZ3cfx7n +n5NYxvPnq3kwlFNHgppt+u1leGrzxuesGNQENQd3shO/S9T4I92hAdk2MRTivIfv +m74u6RCtHqDviiOFzF7zcqO37wCrb1dnfS1N4I6/rCf6XtxlRGa8Cp9z4DTKjwAC +5z5irJb+LSJkFXA/zIFpBjjKBdyhjYGuXrbJWdL81kTcYRqjE99XfZaTU8L43qVd +TUaIvQGTtx8k7WGmeTRHk6SauCaXSfeXwYTpEZpictUI/uWo/KJRDL/aE8HmBeH3 +pr+cfDu7erTLH+GG5ZROrILf4929Jd7OF4a0nHUnZcycBS0CjGHVHA== +-----END CERTIFICATE----- diff --git a/certs/testing.key b/certs/testing.key new file mode 100644 index 0000000..72402d7 --- /dev/null +++ b/certs/testing.key @@ -0,0 +1,52 @@ +-----BEGIN PRIVATE KEY----- +MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQCyAuCwwnoENeYe +B0159sH47zedmRaxcUmC/qmVdUptzOxIHpUCSAIy+hoR5UOhnRsmjj7Y0kUWtlwj +bHAKHcuUn4sqLBb0nl6kH79DzP/4YCQM3GEIXzE6wy/zmnYrHz53Ci7DzmMcRM3n +MwXDVPPpKXzpmI/yassKxSltBKgbh65U3oOheiuFygOlAkT4fUaXX5Bf9DECZBsj +ewf9WvHzLGN2eQt/YWYxJMstgAecHLlRmLbKoYD/P+O0K1ybmhMDItcXE49kNC4s +Rvq7MUt8B0bi8SlRxv5plAbZBiyMilrxf3yCCgYaTsqtt3x+CSrAWjzYIzEzD5aZ +1+s5O2jsqPYkbTvA4NT/hDnWHkkr7YcBRwQn1iMe2tMUTTsWotIYWH87++BzDAWG +3ZBkqNZ4mUdA3usk2ZPO0BwWNxlb0AqOlAJUYSoCsm3nBPT08rVvumQ44hup6XPW +L5KIDyL5+Fl8RDgDF8cpCfrijdL+U+GoHmmJYM6zMkrGqD7BD+WJgw9plgbaWUBI +q4aimXF4PrBJAAX5IRyZK+EDDH0AREL3qoZIQVvJR+yGIKTixpyVKtj6jm1OY4Go +iXxRLaFrc4ucT9+PxRHo9zYtNIijub4eXuU5nveswptmCsNa4spTO2XCkHh6IE0Z +B4oALC4lrC279WY+3TaOpv/roGzG9QIDAQABAoICABfpXGFMs7MzwkYvrkU/KO3V +bwppHAFDOcqyMU7K7e/d4ly1rvJwKyDJ3mKfrKay7Ii7UXndP5E+IcD9ufcXQCzQ +rug/+pLAC0UkoT6W9PNaMWgrhOU+VDs+fjHM19QRuFmpMSr1jZ6ofLgdGchpSvJR +CQnKh9uFDjfTethoEw96Tv1GKTcHAChSleFpHUv7wqsRbTABJJbbokGb2duQhzD7 +uh3vQzodzT+2CjeBxoPpNS40GKm+FA6KzdLP2FAWhuNESibmu7uMFCpicR+1ZBxe ++zNU4xCsbamk9rPZqSD1HM4/1RZqs53TuP9TcbzvDPfAUgKpMjICWrUuVIHgQcb/ +H3lJbsusZccFkl+B4arncUu7oyYWsw+OLHq/khja1RrJu6/PDDfcqY0cSAAsCKJf +ChiHVyVbhZ6b9g1MdYLNPlcJrpgCVX+PisqLqY/RqQGIln6D0sBK1+MC6TjFW3zA +ca3Dhun18JBZ73mmlGj7LoOUojtnnxy5YVUdB75tdo5BqilGR1nLurJupg9Nkgeq +C7nbA+rZ93MKHptayko91nc7yLzsMRV8PDFhE2UhZWRZfJ5yAW/IaJBZpvTvSYM3 +5lTgAn1o34mnykuNC3sK5tbCAMb0YbCJtmotRwBIqlFHqbH+TK07CW2lnEkqZ8ID +YFTpAJlgKgsdhsd5ZCkpAoIBAQDQMvn4iBKvnhCeRUV/6AOHcOsgwJkV/G61Gz/G +F0mx0kPsaPugNX1VzF15R+vN1kbk3sQ9bDP6FfsX7jp2EjRqGEb9mJ8BoIbSHLJ4 +dDT7M90TMMYepCVoFMC03Hh30vxH3QokgV3E1lakXCwl1dheRz5czT0BL9VuBkpG +x8vGpVfX4VqLliOWK72wEYdfohUTynb2OkRP/e6woBRxb3hYLqpN7nVHVRiMFBgG ++AvpLNv/oSYBOXj9oRBOwVLZaPV8N1p4Pv7WXL+B7E47Z9rUYNzGFf+2iM1uDdrO +xHkAocgMM/sL81sJaj1khoYRLC8IpAxBG8NqRP6xzeGcLVLHAoIBAQDa4ZdEDvqA +gJmJ4vgivIX7/zv7/q9c/nkNsnPiXjMys6HRdwroQjT7wrxO5/jJX9EDjM98dSFg +1HFJWJulpmDMpIzzwC6DLxZWd+EEqG4Pyv50VGmGuwmqDwWAP7v/pMPwUEvlsGYZ +Tvlebr4jze9vz8MiRw3qBp0ASWpDWgySt3zm0gDWRaxqvZbdqlLvK/YTta+4ySay +dfkqMG4SGM2m7Rc6H+DKqhwADoyd3oVrFD7QWCZTUUm414TgFFk+uils8Pms6ulG +u+mZT29Jaq8UzoXLOmf+tX2K07oA98y0HfrGMAto3+c0x9ArIPrtwHuUGJiTdt3V +ShBPP9AzaBxjAoIBAQCF+3gwP2k/CQqKv+t035t9yuYVgrxBkNyxweJtmUj8nWLG +vdzIggOxdj3lMaqHIVEoMk+5c2uTkhevk8ideSOv7wWoZ1JUWrjIeF1F9QqvafXo +RqgIyfukmk5VVdhUzDs8B/xh97qfVIwXY5Wpl4+RRGnWkOGkZOMF1hhwqlzx7i+0 +prp9P9aQ6n880lr66TSFMvMRi/ewPqsfkTT2txSMMyO32TAyAoo0gy3fNjt8CDlf +rZXmjdTV65OyCulFLi1kjb6zyV54FuHLO4Yw5qnFqLwK4ddY4XrKSzI3g+qWxIYX +jFAPpcE9MthlW8jlPjjaZ6/XKoW8WsBJLkP1HJm7AoIBAAm9J+HbWMIG9s3vz2Kc +SMnhnWWk+2CD4hb97bIQxu5ml7ieN1oGOB1LmN1Z7PPo03/47/J1s7p/OVsuGh7Q +vFXerHbcAjXMDo5iXxy58cu6GIBMkTVxdQigCnqeW1sQlbdHm1jo9GID5YySGNu2 ++gRbli8cQj47dRjiK1w70XtltqT+ixL9nqJRNTk/rtj9d8GAwATUzmf6X8/Ev+EG +QYA/5Fyttm7OCtjlzNPpZr5Q9EqI4YurfkA/NqZRwXbNCbLTNgi/mwmOquIraqQ1 +nvyqA8H7I01t/dwDd687V1xcSSAwWxGbhMoQae7BVOjnO5hnT8Kf81beKMOd70Ga +TEkCggEAI8ICJvOBouBO92330s8smVhxPi9tRCnOZ0mg5MoR8EJydbOrcRIap1w7 +Ai0CTR6ziOgMaDbT52ouZ1u0l6izYAdBdeSaPOiiTLx8vEE+U7SpNR3zCesPtZB3 +uvGOY2mVwyfZH2SUc4cs+uzDnAGhPqC7/RSFPMoctXf46YpGc9auyjdesE395KLX +L043DaE9/ng9B1jCnhu5TUyiUtAluHvRGQC32og6id2KUEhmhGCl5vj2KIVoDmI2 +NpeBLCKuaBNi/rOG3zyHLjg1wCYidjE7vwjY6UyemjbW48LI8KN6Sl5rQdaDu+bG +lWI2XLI4C2zqDBVmEL2MuzL0FrWivQ== +-----END PRIVATE KEY----- diff --git a/certs/whitelist/testing.cert b/certs/whitelist/testing.cert new file mode 100644 index 0000000..8c0aa19 --- /dev/null +++ b/certs/whitelist/testing.cert @@ -0,0 +1,33 @@ +-----BEGIN CERTIFICATE----- +MIIFxDCCA6wCAQIwDQYJKoZIhvcNAQENBQAwgacxCzAJBgNVBAYTAlVZMRMwEQYD +VQQIDApNb250ZXZpZGVvMRMwEQYDVQQHDApNb250ZXZpZGVvMRowGAYDVQQKDBFz +a3luZXQtZm91bmRhdGlvbjENMAsGA1UECwwEbm9uZTEcMBoGA1UEAwwTR3VpbGxl +cm1vIFJvZHJpZ3VlejElMCMGCSqGSIb3DQEJARYWZ3VpbGxlcm1vckBmaW5nLmVk +dS51eTAeFw0yMjEyMTExNTE1MDNaFw0zMjEyMDgxNTE1MDNaMIGnMQswCQYDVQQG +EwJVWTETMBEGA1UECAwKTW9udGV2aWRlbzETMBEGA1UEBwwKTW9udGV2aWRlbzEa +MBgGA1UECgwRc2t5bmV0LWZvdW5kYXRpb24xDTALBgNVBAsMBG5vbmUxHDAaBgNV +BAMME0d1aWxsZXJtbyBSb2RyaWd1ZXoxJTAjBgkqhkiG9w0BCQEWFmd1aWxsZXJt +b3JAZmluZy5lZHUudXkwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQCy +AuCwwnoENeYeB0159sH47zedmRaxcUmC/qmVdUptzOxIHpUCSAIy+hoR5UOhnRsm +jj7Y0kUWtlwjbHAKHcuUn4sqLBb0nl6kH79DzP/4YCQM3GEIXzE6wy/zmnYrHz53 +Ci7DzmMcRM3nMwXDVPPpKXzpmI/yassKxSltBKgbh65U3oOheiuFygOlAkT4fUaX +X5Bf9DECZBsjewf9WvHzLGN2eQt/YWYxJMstgAecHLlRmLbKoYD/P+O0K1ybmhMD +ItcXE49kNC4sRvq7MUt8B0bi8SlRxv5plAbZBiyMilrxf3yCCgYaTsqtt3x+CSrA +WjzYIzEzD5aZ1+s5O2jsqPYkbTvA4NT/hDnWHkkr7YcBRwQn1iMe2tMUTTsWotIY +WH87++BzDAWG3ZBkqNZ4mUdA3usk2ZPO0BwWNxlb0AqOlAJUYSoCsm3nBPT08rVv +umQ44hup6XPWL5KIDyL5+Fl8RDgDF8cpCfrijdL+U+GoHmmJYM6zMkrGqD7BD+WJ +gw9plgbaWUBIq4aimXF4PrBJAAX5IRyZK+EDDH0AREL3qoZIQVvJR+yGIKTixpyV +Ktj6jm1OY4GoiXxRLaFrc4ucT9+PxRHo9zYtNIijub4eXuU5nveswptmCsNa4spT +O2XCkHh6IE0ZB4oALC4lrC279WY+3TaOpv/roGzG9QIDAQABMA0GCSqGSIb3DQEB +DQUAA4ICAQBic+3ipdfvmCThWkDjVs97tkbUUNjGXH95okwI0Jbft0iRivVM16Xb +hqGquQK4OvYoSTHTmsMH19/dMj0W/Bd4IUYKl64rG8YJUbjDbO1y7a+wF2TaONyn +z0k3zRCky+IwxqYf9Ppw7s2/cXlt3fOEg0kBr4EooXd+bFCx/+JQIxU3vfL8cDQK +dp55vkh+ROt8eR7ai1FiAC8J1prswyT092ktco2fP0MI4uQ3iQfl07NyI68UV1E5 +aIsOPU3SKMtxz5FLm8JEUVhZRJZJWQ/o/iB/2cdn4PDBGkrBhgU6ysMPNX51RlCM +aHRsMyoO2mFfIlm0jW0C5lZ6nKHuA1sXPFz1YxzpvnRgRlHUlfoKf1wpCeF+5Qz+ +qylArHPSu69CA38wLCzJ3wWTaGVL1nuH1UPR2Pg71HGBYqLCD2XGa8iLShO1DKl7 +1bAeHOvzryngYq35rky1L3cIquinAwCP4QKocJK3DJAD5lPqhpzO1f2/1BmWV9Ri +ZRrRkM/9AxePxGZEmnoQbwKsQs/bY+jGU2fRzqijxRPoX9ogX5Te/Ko0mQh1slbX +4bL9NIipHPgpNeZRmRUnu4z00UJNGrI/qGaont3eMH1V65WGz9VMYnmCxkmsg45e +skrauB/Ly9DRRZBddDwAQF8RIbpqPsfQTuEjF0sGdYH3LaClGbA/cA== +-----END CERTIFICATE----- diff --git a/requirements.txt b/requirements.txt index 1fee2d2..8220873 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ pynng triopg aiohttp msgspec +pyOpenSSL trio_asyncio git+https://github.com/goodboy/tractor.git@piker_pin#egg=tractor diff --git a/scripts/generate_cert.py b/scripts/generate_cert.py new file mode 100644 index 0000000..621e4b1 --- /dev/null +++ b/scripts/generate_cert.py @@ -0,0 +1,44 @@ +#!/usr/bin/python + +'''Self signed x509 certificate generator + +can look at generated file using openssl: + openssl x509 -inform pem -in selfsigned.crt -noout -text''' +import sys + +from OpenSSL import crypto, SSL + +from skynet_bot.constants import DEFAULT_CERTS_DIR + + +def input_or_skip(txt, default): + i = input(f'[default: {default}]: {txt}') + if len(i) == 0: + return default + else: + return i + + +if __name__ == '__main__': + # create a key pair + k = crypto.PKey() + k.generate_key(crypto.TYPE_RSA, 4096) + # create a self-signed cert + cert = crypto.X509() + cert.get_subject().C = input('country name two char ISO code (example: US): ') + cert.get_subject().ST = input('state or province name (example: Texas): ') + cert.get_subject().L = input('locality name (example: Dallas): ') + cert.get_subject().O = input('organization name: ') + cert.get_subject().OU = input_or_skip('organizational unit name: ', 'none') + cert.get_subject().CN = input('common name: ') + cert.get_subject().emailAddress = input('email address: ') + cert.set_serial_number(int(input_or_skip('numberic serial number: ', 0))) + cert.gmtime_adj_notBefore(int(input_or_skip('amount of seconds until cert is valid: ', 0))) + cert.gmtime_adj_notAfter(int(input_or_skip('amount of seconds until cert expires: ', 10*365*24*60*60))) + cert.set_issuer(cert.get_subject()) + cert.set_pubkey(k) + cert.sign(k, 'sha512') + with open(f'{DEFAULT_CERTS_DIR}/{sys.argv[1]}.cert', "wt") as f: + f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8")) + with open(f'{DEFAULT_CERTS_DIR}/{sys.argv[1]}.key', "wt") as f: + f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8")) diff --git a/skynet_bot/brain.py b/skynet_bot/brain.py index 37f00dc..79b62d2 100644 --- a/skynet_bot/brain.py +++ b/skynet_bot/brain.py @@ -6,6 +6,7 @@ import base64 import logging from uuid import UUID +from pathlib import Path from functools import partial from collections import OrderedDict @@ -13,6 +14,8 @@ import trio import pynng import trio_asyncio +from pynng import TLSConfig + from .db import * from .types import * from .constants import * @@ -241,11 +244,33 @@ async def run_skynet( db_host: str = DB_HOST, rpc_address: str = DEFAULT_RPC_ADDR, dgpu_address: str = DEFAULT_DGPU_ADDR, - task_status = trio.TASK_STATUS_IGNORED + task_status = trio.TASK_STATUS_IGNORED, + security: bool = True ): logging.basicConfig(level=logging.INFO) logging.info('skynet is starting') + tls_config = None + if security: + # load tls certs + certs_dir = Path(DEFAULT_CERTS_DIR).resolve() + tls_key = (certs_dir / DEFAULT_CERT_SKYNET_PRIV).read_text() + tls_cert = (certs_dir / DEFAULT_CERT_SKYNET_PUB).read_text() + tls_whitelist = [ + (cert_path).read_text() + for cert_path in (certs_dir / 'whitelist').glob('*.cert')] + + logging.info(f'tls_key: {tls_key}') + logging.info(f'tls_cert: {tls_cert}') + logging.info(f'tls_whitelist len: {len(tls_whitelist)}') + + rpc_address = 'tls+' + rpc_address + dgpu_address = 'tls+' + dgpu_address + tls_config = TLSConfig( + TLSConfig.MODE_SERVER, + own_key_string=tls_key, + own_cert_string=tls_cert) + async with ( trio.open_nursery() as n, open_database_connection( @@ -253,9 +278,16 @@ async def run_skynet( ): logging.info('connected to db.') with ( - pynng.Rep0(listen=rpc_address) as rpc_sock, - pynng.Bus0(listen=dgpu_address) as dgpu_bus + pynng.Rep0() as rpc_sock, + pynng.Bus0() as dgpu_bus ): + if security: + rpc_sock.tls_config = tls_config + dgpu_bus.tls_config = tls_config + + rpc_sock.listen(rpc_address) + dgpu_bus.listen(dgpu_address) + n.start_soon( rpc_service, rpc_sock, dgpu_bus, db_pool) task_status.started() diff --git a/skynet_bot/constants.py b/skynet_bot/constants.py index 0c0c03b..4fe4439 100644 --- a/skynet_bot/constants.py +++ b/skynet_bot/constants.py @@ -113,6 +113,12 @@ DEFAULT_ALGO = 'midj' DEFAULT_ROLE = 'pleb' DEFAULT_UPSCALER = None +DEFAULT_CERTS_DIR = 'certs' +DEFAULT_CERT_WHITELIST_DIR = 'whitelist' +DEFAULT_CERT_SKYNET_PUB = 'brain.cert' +DEFAULT_CERT_SKYNET_PRIV = 'brain.key' +DEFAULT_CERT_DGPU = 'dgpu.key' + DEFAULT_RPC_ADDR = 'tcp://127.0.0.1:41000' DEFAULT_DGPU_ADDR = 'tcp://127.0.0.1:41069' diff --git a/skynet_bot/dgpu.py b/skynet_bot/dgpu.py index 8c019e4..9f1fde3 100644 --- a/skynet_bot/dgpu.py +++ b/skynet_bot/dgpu.py @@ -16,10 +16,13 @@ from .frontend import rpc_call async def open_dgpu_node( + cert_name: str, + key_name: Optional[str], rpc_address: str = DEFAULT_RPC_ADDR, dgpu_address: str = DEFAULT_DGPU_ADDR, dgpu_max_tasks: int = DEFAULT_DGPU_MAX_TASKS, - initial_algos: str = DEFAULT_INITAL_ALGOS + initial_algos: str = DEFAULT_INITAL_ALGOS, + security: bool = True ): logging.basicConfig(level=logging.INFO) @@ -65,7 +68,11 @@ async def open_dgpu_node( return img - async with open_skynet_rpc() as rpc_call: + async with open_skynet_rpc( + security=security, + cert_name=cert_name, + key_name=key_name + ) as rpc_call: with pynng.Bus0(dial=dgpu_address) as dgpu_sock: async def _process_dgpu_req(req: DGPUBusRequest): img = await gpu_compute_one( diff --git a/skynet_bot/frontend/__init__.py b/skynet_bot/frontend/__init__.py index 4e728f9..62ac0af 100644 --- a/skynet_bot/frontend/__init__.py +++ b/skynet_bot/frontend/__init__.py @@ -2,11 +2,14 @@ import json -from typing import Union +from typing import Union, Optional +from pathlib import Path from contextlib import asynccontextmanager as acm import pynng +from pynng import TLSConfig + from ..types import SkynetRPCRequest, SkynetRPCResponse from ..constants import * @@ -48,8 +51,33 @@ async def rpc_call( @acm -async def open_skynet_rpc(rpc_address: str = DEFAULT_RPC_ADDR): - with pynng.Req0(dial=rpc_address) as sock: +async def open_skynet_rpc( + rpc_address: str = DEFAULT_RPC_ADDR, + security: bool = False, + cert_name: Optional[str] = None, + key_name: Optional[str] = None +): + tls_config = None + if security: + # load tls certs + if not key_name: + key_name = certs_name + certs_dir = Path(DEFAULT_CERTS_DIR).resolve() + skynet_cert = (certs_dir / 'brain.cert').read_text() + tls_cert = (certs_dir / f'{cert_name}.cert').read_text() + tls_key = (certs_dir / f'{key_name}.key').read_text() + rpc_address = 'tls+' + rpc_address + tls_config = TLSConfig( + TLSConfig.MODE_CLIENT, + own_key_string=tls_key, + own_cert_string=tls_cert, + ca_string=skynet_cert) + + with pynng.Req0() as sock: + if security: + sock.tls_config = tls_config + + sock.dial(rpc_address) async def _rpc_call(*args, **kwargs): return await rpc_call(sock, *args, **kwargs) diff --git a/skynet_bot/frontend/telegram.py b/skynet_bot/frontend/telegram.py index 211cc48..1e0b0ab 100644 --- a/skynet_bot/frontend/telegram.py +++ b/skynet_bot/frontend/telegram.py @@ -24,7 +24,9 @@ async def run_skynet_telegram( logging.basicConfig(level=logging.INFO) bot = AsyncTeleBot(tg_token) - with open_skynet_rpc() as rpc_call: + with open_skynet_rpc( + security=True, cert_name='telegram-frontend' + ) as rpc_call: async def _rpc_call( uid: int, @@ -69,7 +71,7 @@ async def run_skynet_telegram( 'config', {'attr': attr, 'val': val}) except BaseException as e: - reply_text = e.message + reply_text = str(e.value) finally: await bot.reply_to(message, reply_txt) diff --git a/skynet_bot/utils.py b/skynet_bot/utils.py new file mode 100644 index 0000000..8a60885 --- /dev/null +++ b/skynet_bot/utils.py @@ -0,0 +1,2 @@ +from OpenSSL.crypto import load_publickey, FILETYPE_PEM, verify, X509 + diff --git a/tests/test_skynet.py b/tests/test_skynet.py index f1520f8..5c9367a 100644 --- a/tests/test_skynet.py +++ b/tests/test_skynet.py @@ -3,6 +3,8 @@ import logging import trio +import pynng +import pytest import trio_asyncio from skynet_bot.types import * @@ -10,8 +12,20 @@ from skynet_bot.brain import run_skynet from skynet_bot.frontend import open_skynet_rpc +async def test_skynet_attempt_insecure(skynet_running): + with pytest.raises(pynng.exceptions.NNGException) as e: + async with open_skynet_rpc(): + ... + + assert str(e.value) == 'Connection shutdown' + + async def test_skynet_dgpu_connection_simple(skynet_running): - async with open_skynet_rpc() as rpc_call: + async with open_skynet_rpc( + security=True, + cert_name='whitelist/testing', + key_name='testing' + ) as rpc_call: # check 0 nodes are connected res = await rpc_call('dgpu-0', 'dgpu_workers') logging.info(res)