forked from goodboy/tractor
Appease mypy
parent
a29924f330
commit
e561a4908f
|
@ -11,7 +11,7 @@ import importlib.util
|
|||
import inspect
|
||||
import uuid
|
||||
import typing
|
||||
from typing import Dict, List, Tuple, Any, Optional, Union
|
||||
from typing import List, Tuple, Any, Optional, Union
|
||||
from types import ModuleType
|
||||
import sys
|
||||
import os
|
||||
|
@ -49,7 +49,7 @@ async def _invoke(
|
|||
cid: str,
|
||||
chan: Channel,
|
||||
func: typing.Callable,
|
||||
kwargs: Dict[str, Any],
|
||||
kwargs: dict[str, Any],
|
||||
is_rpc: bool = True,
|
||||
task_status: TaskStatus[
|
||||
Union[trio.CancelScope, BaseException]
|
||||
|
@ -267,21 +267,21 @@ _lifetime_stack: ExitStack = ExitStack()
|
|||
|
||||
|
||||
async def try_ship_error_to_parent(
|
||||
actor: Actor,
|
||||
err: Exception,
|
||||
channel: Channel,
|
||||
err: Union[Exception, trio.MultiError],
|
||||
|
||||
) -> None:
|
||||
with trio.CancelScope(shield=True):
|
||||
try:
|
||||
# internal error so ship to parent without cid
|
||||
await actor._parent_chan.send(pack_error(err))
|
||||
await channel.send(pack_error(err))
|
||||
except (
|
||||
trio.ClosedResourceError,
|
||||
trio.BrokenResourceError,
|
||||
):
|
||||
log.error(
|
||||
f"Failed to ship error to parent "
|
||||
f"{actor._parent_chan.uid}, channel was closed"
|
||||
f"{channel.uid}, channel was closed"
|
||||
)
|
||||
|
||||
|
||||
|
@ -319,7 +319,7 @@ class Actor:
|
|||
_server_n: Optional[trio.Nursery] = None
|
||||
|
||||
# Information about `__main__` from parent
|
||||
_parent_main_data: Dict[str, str]
|
||||
_parent_main_data: dict[str, str]
|
||||
_parent_chan_cs: Optional[trio.CancelScope] = None
|
||||
|
||||
# syncs for setup/teardown sequences
|
||||
|
@ -357,7 +357,7 @@ class Actor:
|
|||
mods[name] = _get_mod_abspath(mod)
|
||||
|
||||
self.enable_modules = mods
|
||||
self._mods: Dict[str, ModuleType] = {}
|
||||
self._mods: dict[str, ModuleType] = {}
|
||||
|
||||
# TODO: consider making this a dynamically defined
|
||||
# @dataclass once we get py3.7
|
||||
|
@ -380,12 +380,12 @@ class Actor:
|
|||
self._ongoing_rpc_tasks = trio.Event()
|
||||
self._ongoing_rpc_tasks.set()
|
||||
# (chan, cid) -> (cancel_scope, func)
|
||||
self._rpc_tasks: Dict[
|
||||
self._rpc_tasks: dict[
|
||||
Tuple[Channel, str],
|
||||
Tuple[trio.CancelScope, typing.Callable, trio.Event]
|
||||
] = {}
|
||||
# map {uids -> {callids -> waiter queues}}
|
||||
self._cids2qs: Dict[
|
||||
self._cids2qs: dict[
|
||||
Tuple[Tuple[str, str], str],
|
||||
Tuple[
|
||||
trio.abc.SendChannel[Any],
|
||||
|
@ -396,7 +396,7 @@ class Actor:
|
|||
self._parent_chan: Optional[Channel] = None
|
||||
self._forkserver_info: Optional[
|
||||
Tuple[Any, Any, Any, Any, Any]] = None
|
||||
self._actoruid2nursery: Dict[str, 'ActorNursery'] = {} # type: ignore # noqa
|
||||
self._actoruid2nursery: dict[Optional[tuple[str, str]], 'ActorNursery'] = {} # type: ignore # noqa
|
||||
|
||||
async def wait_for_peer(
|
||||
self, uid: Tuple[str, str]
|
||||
|
@ -550,6 +550,7 @@ class Actor:
|
|||
cs.shield = True
|
||||
# Attempt to wait for the far end to close the channel
|
||||
# and bail after timeout (2-generals on closure).
|
||||
assert chan.msgstream
|
||||
async for msg in chan.msgstream.drain():
|
||||
# try to deliver any lingering msgs
|
||||
# before we destroy the channel.
|
||||
|
@ -616,7 +617,7 @@ class Actor:
|
|||
self,
|
||||
chan: Channel,
|
||||
cid: str,
|
||||
msg: Dict[str, Any],
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Push an RPC result to the local consumer's queue.
|
||||
"""
|
||||
|
@ -877,7 +878,7 @@ class Actor:
|
|||
# machinery not from an rpc task) to parent
|
||||
log.exception("Actor errored:")
|
||||
if self._parent_chan:
|
||||
await try_ship_error_to_parent(self, err)
|
||||
await try_ship_error_to_parent(self._parent_chan, err)
|
||||
|
||||
# if this is the `MainProcess` we expect the error broadcasting
|
||||
# above to trigger an error at consuming portal "checkpoints"
|
||||
|
@ -1078,7 +1079,7 @@ class Actor:
|
|||
)
|
||||
|
||||
if self._parent_chan:
|
||||
await try_ship_error_to_parent(self, err)
|
||||
await try_ship_error_to_parent(self._parent_chan, err)
|
||||
|
||||
# always!
|
||||
log.exception("Actor errored:")
|
||||
|
@ -1360,7 +1361,7 @@ class Arbiter(Actor):
|
|||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
self._registry: Dict[
|
||||
self._registry: dict[
|
||||
Tuple[str, str],
|
||||
Tuple[str, int],
|
||||
] = {}
|
||||
|
@ -1377,7 +1378,7 @@ class Arbiter(Actor):
|
|||
|
||||
async def get_registry(
|
||||
self
|
||||
) -> Dict[Tuple[str, str], Tuple[str, int]]:
|
||||
) -> dict[Tuple[str, str], Tuple[str, int]]:
|
||||
'''Return current name registry.
|
||||
|
||||
This method is async to allow for cross-actor invocation.
|
||||
|
|
|
@ -6,9 +6,10 @@ from __future__ import annotations
|
|||
import platform
|
||||
import struct
|
||||
import typing
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import (
|
||||
Any, Tuple, Optional,
|
||||
Type, Protocol, TypeVar
|
||||
Type, Protocol, TypeVar,
|
||||
)
|
||||
|
||||
from tricycle import BufferedReceiveStream
|
||||
|
@ -46,6 +47,7 @@ MsgType = TypeVar("MsgType")
|
|||
class MsgTransport(Protocol[MsgType]):
|
||||
|
||||
stream: trio.SocketStream
|
||||
drained: list[MsgType]
|
||||
|
||||
def __init__(self, stream: trio.SocketStream) -> None:
|
||||
...
|
||||
|
@ -63,6 +65,11 @@ class MsgTransport(Protocol[MsgType]):
|
|||
def connected(self) -> bool:
|
||||
...
|
||||
|
||||
# defining this sync otherwise it causes a mypy error because it
|
||||
# can't figure out it's a generator i guess?..?
|
||||
def drain(self) -> AsyncIterator[dict]:
|
||||
...
|
||||
|
||||
@property
|
||||
def laddr(self) -> Tuple[str, int]:
|
||||
...
|
||||
|
@ -94,9 +101,9 @@ class MsgpackTCPStream:
|
|||
self._send_lock = trio.StrictFIFOLock()
|
||||
|
||||
# public i guess?
|
||||
self.drained = []
|
||||
self.drained: list[dict] = []
|
||||
|
||||
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
|
||||
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
|
||||
"""Yield packets from the underlying stream.
|
||||
"""
|
||||
unpacker = msgpack.Unpacker(
|
||||
|
@ -159,7 +166,13 @@ class MsgpackTCPStream:
|
|||
async def recv(self) -> Any:
|
||||
return await self._agen.asend(None)
|
||||
|
||||
async def drain(self):
|
||||
async def drain(self) -> AsyncIterator[dict]:
|
||||
'''
|
||||
Drain the stream's remaining messages sent from
|
||||
the far end until the connection is closed by
|
||||
the peer.
|
||||
|
||||
'''
|
||||
try:
|
||||
async for msg in self._iter_packets():
|
||||
self.drained.append(msg)
|
||||
|
@ -196,7 +209,7 @@ class MsgspecTCPStream(MsgpackTCPStream):
|
|||
self.encode = msgspec.Encoder().encode
|
||||
self.decode = msgspec.Decoder().decode # dict[str, Any])
|
||||
|
||||
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
|
||||
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
|
||||
'''Yield packets from the underlying stream.
|
||||
|
||||
'''
|
||||
|
@ -458,9 +471,11 @@ class Channel:
|
|||
|
||||
async def _aiter_recv(
|
||||
self
|
||||
) -> typing.AsyncGenerator[Any, None]:
|
||||
"""Async iterate items from underlying stream.
|
||||
"""
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
'''
|
||||
Async iterate items from underlying stream.
|
||||
|
||||
'''
|
||||
assert self.msgstream
|
||||
while True:
|
||||
try:
|
||||
|
@ -490,9 +505,11 @@ class Channel:
|
|||
async def _connect_chan(
|
||||
host: str, port: int
|
||||
) -> typing.AsyncGenerator[Channel, None]:
|
||||
"""Create and connect a channel with disconnect on context manager
|
||||
'''
|
||||
Create and connect a channel with disconnect on context manager
|
||||
teardown.
|
||||
"""
|
||||
|
||||
'''
|
||||
chan = Channel((host, port))
|
||||
await chan.connect()
|
||||
yield chan
|
||||
|
|
|
@ -5,7 +5,11 @@ Machinery for actor process spawning using multiple backends.
|
|||
import sys
|
||||
import multiprocessing as mp
|
||||
import platform
|
||||
from typing import Any, Dict, Optional, Union, Callable
|
||||
from typing import (
|
||||
Any, Dict, Optional, Union, Callable,
|
||||
TypeVar,
|
||||
)
|
||||
from collections.abc import Awaitable, Coroutine
|
||||
|
||||
import trio
|
||||
from trio_typing import TaskStatus
|
||||
|
@ -41,6 +45,7 @@ from ._exceptions import ActorFailure
|
|||
|
||||
|
||||
log = get_logger('tractor')
|
||||
ProcessType = TypeVar('ProcessType', mp.Process, trio.Process)
|
||||
|
||||
# placeholder for an mp start context if so using that backend
|
||||
_ctx: Optional[mp.context.BaseContext] = None
|
||||
|
@ -185,10 +190,10 @@ async def do_hard_kill(
|
|||
|
||||
async def soft_wait(
|
||||
|
||||
proc: Union[mp.Process, trio.Process],
|
||||
proc: ProcessType,
|
||||
wait_func: Callable[
|
||||
Union[mp.Process, trio.Process],
|
||||
None,
|
||||
[ProcessType],
|
||||
Awaitable,
|
||||
],
|
||||
portal: Portal,
|
||||
|
||||
|
|
Loading…
Reference in New Issue