Trying to make full suite pass with uds

Guillermo Rodriguez 2025-03-23 02:18:01 -03:00
parent 34a2f0c1f3
commit 9c8d23f41b
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
7 changed files with 44 additions and 19 deletions

View File

@ -10,6 +10,7 @@ pkgs.mkShell {
inherit nativeBuildInputs; inherit nativeBuildInputs;
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath nativeBuildInputs; LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath nativeBuildInputs;
TMPDIR = "/tmp";
shellHook = '' shellHook = ''
set -e set -e

View File

@ -9,7 +9,7 @@ async def main(service_name):
async with tractor.open_nursery() as an: async with tractor.open_nursery() as an:
await an.start_actor(service_name) await an.start_actor(service_name)
async with tractor.get_registry(('127.0.0.1', 1616)) as portal: async with tractor.get_registry() as portal:
print(f"Arbiter is listening on {portal.channel}") print(f"Arbiter is listening on {portal.channel}")
async with tractor.wait_for_actor(service_name) as sockaddr: async with tractor.wait_for_actor(service_name) as sockaddr:

View File

@ -66,6 +66,9 @@ def run_example_in_subproc(
# due to backpressure!!! # due to backpressure!!!
proc = testdir.popen( proc = testdir.popen(
cmdargs, cmdargs,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
**kwargs, **kwargs,
) )
assert not proc.returncode assert not proc.returncode
@ -119,10 +122,14 @@ def test_example(
code = ex.read() code = ex.read()
with run_example_in_subproc(code) as proc: with run_example_in_subproc(code) as proc:
proc.wait() err = None
err, _ = proc.stderr.read(), proc.stdout.read() try:
# print(f'STDERR: {err}') if not proc.poll():
# print(f'STDOUT: {out}') _, err = proc.communicate(timeout=15)
except subprocess.TimeoutExpired as e:
proc.kill()
err = e.stderr
# if we get some gnarly output let's aggregate and raise # if we get some gnarly output let's aggregate and raise
if err: if err:

View File

@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
import os
import tempfile import tempfile
from uuid import uuid4 from uuid import uuid4
from typing import ( from typing import (
@ -79,6 +80,9 @@ class Address(Protocol[
async def open_listener(self, **kwargs) -> ListenerType: async def open_listener(self, **kwargs) -> ListenerType:
... ...
async def close_listener(self):
...
class TCPAddress(Address[ class TCPAddress(Address[
str, str,
@ -162,6 +166,9 @@ class TCPAddress(Address[
self._host, self._port = listener.socket.getsockname()[:2] self._host, self._port = listener.socket.getsockname()[:2]
return listener return listener
async def close_listener(self):
...
class UDSAddress(Address[ class UDSAddress(Address[
None, None,
@ -195,8 +202,8 @@ class UDSAddress(Address[
return self._filepath return self._filepath
@classmethod @classmethod
def get_random(cls, _ns: None = None) -> UDSAddress: def get_random(cls, namespace: None = None) -> UDSAddress:
return UDSAddress(f'{tempfile.gettempdir()}/{uuid4().sock}') return UDSAddress(f'{tempfile.gettempdir()}/{uuid4()}.sock')
@classmethod @classmethod
def get_root(cls) -> Address: def get_root(cls) -> Address:
@ -214,22 +221,24 @@ class UDSAddress(Address[
return self._filepath == other._filepath return self._filepath == other._filepath
async def open_stream(self, **kwargs) -> trio.SocketStream: async def open_stream(self, **kwargs) -> trio.SocketStream:
stream = await trio.open_tcp_stream( stream = await trio.open_unix_socket(
self._filepath, self._filepath,
**kwargs **kwargs
) )
self._binded = True
return stream return stream
async def open_listener(self, **kwargs) -> trio.SocketListener: async def open_listener(self, **kwargs) -> trio.SocketListener:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.bind(self._filepath) await self._sock.bind(self._filepath)
sock.listen() self._sock.listen(1)
self._binded = True return trio.SocketListener(self._sock)
return trio.SocketListener(sock)
async def close_listener(self):
self._sock.close()
os.unlink(self._filepath)
preferred_transport = 'tcp' preferred_transport = 'uds'
_address_types = ( _address_types = (

View File

@ -31,7 +31,12 @@ def parse_uid(arg):
return str(name), str(uuid) # ensures str encoding return str(name), str(uuid) # ensures str encoding
def parse_ipaddr(arg): def parse_ipaddr(arg):
return literal_eval(arg) try:
return literal_eval(arg)
except (ValueError, SyntaxError):
# UDS: try to interpret as a straight up str
return arg
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -54,7 +54,7 @@ log = get_logger(__name__)
@acm @acm
async def get_registry(addr: AddressTypes) -> AsyncGenerator[ async def get_registry(addr: AddressTypes | None = None) -> AsyncGenerator[
Portal | LocalPortal | None, Portal | LocalPortal | None,
None, None,
]: ]:

View File

@ -77,8 +77,9 @@ from .ipc import Channel
from ._addr import ( from ._addr import (
AddressTypes, AddressTypes,
Address, Address,
TCPAddress,
wrap_address, wrap_address,
preferred_transport,
default_lo_addrs
) )
from ._context import ( from ._context import (
mk_context, mk_context,
@ -1181,7 +1182,7 @@ class Actor:
''' '''
if listen_addrs is None: if listen_addrs is None:
listen_addrs = [TCPAddress.get_random()] listen_addrs = default_lo_addrs([preferred_transport])
else: else:
listen_addrs: list[Address] = [ listen_addrs: list[Address] = [
@ -1217,6 +1218,8 @@ class Actor:
task_status.started(server_n) task_status.started(server_n)
finally: finally:
for addr in listen_addrs:
await addr.close_listener()
# signal the server is down since nursery above terminated # signal the server is down since nursery above terminated
self._server_down.set() self._server_down.set()