Trying to make full suite pass with uds
parent
34a2f0c1f3
commit
9c8d23f41b
|
@ -10,6 +10,7 @@ pkgs.mkShell {
|
|||
inherit nativeBuildInputs;
|
||||
|
||||
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath nativeBuildInputs;
|
||||
TMPDIR = "/tmp";
|
||||
|
||||
shellHook = ''
|
||||
set -e
|
||||
|
|
|
@ -9,7 +9,7 @@ async def main(service_name):
|
|||
async with tractor.open_nursery() as an:
|
||||
await an.start_actor(service_name)
|
||||
|
||||
async with tractor.get_registry(('127.0.0.1', 1616)) as portal:
|
||||
async with tractor.get_registry() as portal:
|
||||
print(f"Arbiter is listening on {portal.channel}")
|
||||
|
||||
async with tractor.wait_for_actor(service_name) as sockaddr:
|
||||
|
|
|
@ -66,6 +66,9 @@ def run_example_in_subproc(
|
|||
# due to backpressure!!!
|
||||
proc = testdir.popen(
|
||||
cmdargs,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
**kwargs,
|
||||
)
|
||||
assert not proc.returncode
|
||||
|
@ -119,10 +122,14 @@ def test_example(
|
|||
code = ex.read()
|
||||
|
||||
with run_example_in_subproc(code) as proc:
|
||||
proc.wait()
|
||||
err, _ = proc.stderr.read(), proc.stdout.read()
|
||||
# print(f'STDERR: {err}')
|
||||
# print(f'STDOUT: {out}')
|
||||
err = None
|
||||
try:
|
||||
if not proc.poll():
|
||||
_, err = proc.communicate(timeout=15)
|
||||
|
||||
except subprocess.TimeoutExpired as e:
|
||||
proc.kill()
|
||||
err = e.stderr
|
||||
|
||||
# if we get some gnarly output let's aggregate and raise
|
||||
if err:
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import tempfile
|
||||
from uuid import uuid4
|
||||
from typing import (
|
||||
|
@ -79,6 +80,9 @@ class Address(Protocol[
|
|||
async def open_listener(self, **kwargs) -> ListenerType:
|
||||
...
|
||||
|
||||
async def close_listener(self):
|
||||
...
|
||||
|
||||
|
||||
class TCPAddress(Address[
|
||||
str,
|
||||
|
@ -162,6 +166,9 @@ class TCPAddress(Address[
|
|||
self._host, self._port = listener.socket.getsockname()[:2]
|
||||
return listener
|
||||
|
||||
async def close_listener(self):
|
||||
...
|
||||
|
||||
|
||||
class UDSAddress(Address[
|
||||
None,
|
||||
|
@ -195,8 +202,8 @@ class UDSAddress(Address[
|
|||
return self._filepath
|
||||
|
||||
@classmethod
|
||||
def get_random(cls, _ns: None = None) -> UDSAddress:
|
||||
return UDSAddress(f'{tempfile.gettempdir()}/{uuid4().sock}')
|
||||
def get_random(cls, namespace: None = None) -> UDSAddress:
|
||||
return UDSAddress(f'{tempfile.gettempdir()}/{uuid4()}.sock')
|
||||
|
||||
@classmethod
|
||||
def get_root(cls) -> Address:
|
||||
|
@ -214,22 +221,24 @@ class UDSAddress(Address[
|
|||
return self._filepath == other._filepath
|
||||
|
||||
async def open_stream(self, **kwargs) -> trio.SocketStream:
|
||||
stream = await trio.open_tcp_stream(
|
||||
stream = await trio.open_unix_socket(
|
||||
self._filepath,
|
||||
**kwargs
|
||||
)
|
||||
self._binded = True
|
||||
return stream
|
||||
|
||||
async def open_listener(self, **kwargs) -> trio.SocketListener:
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
sock.bind(self._filepath)
|
||||
sock.listen()
|
||||
self._binded = True
|
||||
return trio.SocketListener(sock)
|
||||
self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
await self._sock.bind(self._filepath)
|
||||
self._sock.listen(1)
|
||||
return trio.SocketListener(self._sock)
|
||||
|
||||
async def close_listener(self):
|
||||
self._sock.close()
|
||||
os.unlink(self._filepath)
|
||||
|
||||
|
||||
preferred_transport = 'tcp'
|
||||
preferred_transport = 'uds'
|
||||
|
||||
|
||||
_address_types = (
|
||||
|
|
|
@ -31,7 +31,12 @@ def parse_uid(arg):
|
|||
return str(name), str(uuid) # ensures str encoding
|
||||
|
||||
def parse_ipaddr(arg):
|
||||
return literal_eval(arg)
|
||||
try:
|
||||
return literal_eval(arg)
|
||||
|
||||
except (ValueError, SyntaxError):
|
||||
# UDS: try to interpret as a straight up str
|
||||
return arg
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -54,7 +54,7 @@ log = get_logger(__name__)
|
|||
|
||||
|
||||
@acm
|
||||
async def get_registry(addr: AddressTypes) -> AsyncGenerator[
|
||||
async def get_registry(addr: AddressTypes | None = None) -> AsyncGenerator[
|
||||
Portal | LocalPortal | None,
|
||||
None,
|
||||
]:
|
||||
|
|
|
@ -77,8 +77,9 @@ from .ipc import Channel
|
|||
from ._addr import (
|
||||
AddressTypes,
|
||||
Address,
|
||||
TCPAddress,
|
||||
wrap_address,
|
||||
preferred_transport,
|
||||
default_lo_addrs
|
||||
)
|
||||
from ._context import (
|
||||
mk_context,
|
||||
|
@ -1181,7 +1182,7 @@ class Actor:
|
|||
|
||||
'''
|
||||
if listen_addrs is None:
|
||||
listen_addrs = [TCPAddress.get_random()]
|
||||
listen_addrs = default_lo_addrs([preferred_transport])
|
||||
|
||||
else:
|
||||
listen_addrs: list[Address] = [
|
||||
|
@ -1217,6 +1218,8 @@ class Actor:
|
|||
task_status.started(server_n)
|
||||
|
||||
finally:
|
||||
for addr in listen_addrs:
|
||||
await addr.close_listener()
|
||||
# signal the server is down since nursery above terminated
|
||||
self._server_down.set()
|
||||
|
||||
|
|
Loading…
Reference in New Issue