Trying to make full suite pass with uds

Guillermo Rodriguez 2025-03-23 02:18:01 -03:00
parent 34a2f0c1f3
commit 9c8d23f41b
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
7 changed files with 44 additions and 19 deletions

View File

@ -10,6 +10,7 @@ pkgs.mkShell {
inherit nativeBuildInputs;
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath nativeBuildInputs;
TMPDIR = "/tmp";
shellHook = ''
set -e

View File

@ -9,7 +9,7 @@ async def main(service_name):
async with tractor.open_nursery() as an:
await an.start_actor(service_name)
async with tractor.get_registry(('127.0.0.1', 1616)) as portal:
async with tractor.get_registry() as portal:
print(f"Arbiter is listening on {portal.channel}")
async with tractor.wait_for_actor(service_name) as sockaddr:

View File

@ -66,6 +66,9 @@ def run_example_in_subproc(
# due to backpressure!!!
proc = testdir.popen(
cmdargs,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
**kwargs,
)
assert not proc.returncode
@ -119,10 +122,14 @@ def test_example(
code = ex.read()
with run_example_in_subproc(code) as proc:
proc.wait()
err, _ = proc.stderr.read(), proc.stdout.read()
# print(f'STDERR: {err}')
# print(f'STDOUT: {out}')
err = None
try:
if not proc.poll():
_, err = proc.communicate(timeout=15)
except subprocess.TimeoutExpired as e:
proc.kill()
err = e.stderr
# if we get some gnarly output let's aggregate and raise
if err:

View File

@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
import os
import tempfile
from uuid import uuid4
from typing import (
@ -79,6 +80,9 @@ class Address(Protocol[
async def open_listener(self, **kwargs) -> ListenerType:
...
async def close_listener(self):
...
class TCPAddress(Address[
str,
@ -162,6 +166,9 @@ class TCPAddress(Address[
self._host, self._port = listener.socket.getsockname()[:2]
return listener
async def close_listener(self):
...
class UDSAddress(Address[
None,
@ -195,8 +202,8 @@ class UDSAddress(Address[
return self._filepath
@classmethod
def get_random(cls, _ns: None = None) -> UDSAddress:
return UDSAddress(f'{tempfile.gettempdir()}/{uuid4().sock}')
def get_random(cls, namespace: None = None) -> UDSAddress:
return UDSAddress(f'{tempfile.gettempdir()}/{uuid4()}.sock')
@classmethod
def get_root(cls) -> Address:
@ -214,22 +221,24 @@ class UDSAddress(Address[
return self._filepath == other._filepath
async def open_stream(self, **kwargs) -> trio.SocketStream:
stream = await trio.open_tcp_stream(
stream = await trio.open_unix_socket(
self._filepath,
**kwargs
)
self._binded = True
return stream
async def open_listener(self, **kwargs) -> trio.SocketListener:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.bind(self._filepath)
sock.listen()
self._binded = True
return trio.SocketListener(sock)
self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
await self._sock.bind(self._filepath)
self._sock.listen(1)
return trio.SocketListener(self._sock)
async def close_listener(self):
self._sock.close()
os.unlink(self._filepath)
preferred_transport = 'tcp'
preferred_transport = 'uds'
_address_types = (

View File

@ -31,7 +31,12 @@ def parse_uid(arg):
return str(name), str(uuid) # ensures str encoding
def parse_ipaddr(arg):
return literal_eval(arg)
try:
return literal_eval(arg)
except (ValueError, SyntaxError):
# UDS: try to interpret as a straight up str
return arg
if __name__ == "__main__":

View File

@ -54,7 +54,7 @@ log = get_logger(__name__)
@acm
async def get_registry(addr: AddressTypes) -> AsyncGenerator[
async def get_registry(addr: AddressTypes | None = None) -> AsyncGenerator[
Portal | LocalPortal | None,
None,
]:

View File

@ -77,8 +77,9 @@ from .ipc import Channel
from ._addr import (
AddressTypes,
Address,
TCPAddress,
wrap_address,
preferred_transport,
default_lo_addrs
)
from ._context import (
mk_context,
@ -1181,7 +1182,7 @@ class Actor:
'''
if listen_addrs is None:
listen_addrs = [TCPAddress.get_random()]
listen_addrs = default_lo_addrs([preferred_transport])
else:
listen_addrs: list[Address] = [
@ -1217,6 +1218,8 @@ class Actor:
task_status.started(server_n)
finally:
for addr in listen_addrs:
await addr.close_listener()
# signal the server is down since nursery above terminated
self._server_down.set()