Appease mypy

acked_backup
Tyler Goodlet 2021-12-02 12:34:27 -05:00
parent a29924f330
commit e561a4908f
3 changed files with 53 additions and 30 deletions

View File

@ -11,7 +11,7 @@ import importlib.util
import inspect
import uuid
import typing
from typing import Dict, List, Tuple, Any, Optional, Union
from typing import List, Tuple, Any, Optional, Union
from types import ModuleType
import sys
import os
@ -49,7 +49,7 @@ async def _invoke(
cid: str,
chan: Channel,
func: typing.Callable,
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
is_rpc: bool = True,
task_status: TaskStatus[
Union[trio.CancelScope, BaseException]
@ -267,21 +267,21 @@ _lifetime_stack: ExitStack = ExitStack()
async def try_ship_error_to_parent(
actor: Actor,
err: Exception,
channel: Channel,
err: Union[Exception, trio.MultiError],
) -> None:
with trio.CancelScope(shield=True):
try:
# internal error so ship to parent without cid
await actor._parent_chan.send(pack_error(err))
await channel.send(pack_error(err))
except (
trio.ClosedResourceError,
trio.BrokenResourceError,
):
log.error(
f"Failed to ship error to parent "
f"{actor._parent_chan.uid}, channel was closed"
f"{channel.uid}, channel was closed"
)
@ -319,7 +319,7 @@ class Actor:
_server_n: Optional[trio.Nursery] = None
# Information about `__main__` from parent
_parent_main_data: Dict[str, str]
_parent_main_data: dict[str, str]
_parent_chan_cs: Optional[trio.CancelScope] = None
# syncs for setup/teardown sequences
@ -357,7 +357,7 @@ class Actor:
mods[name] = _get_mod_abspath(mod)
self.enable_modules = mods
self._mods: Dict[str, ModuleType] = {}
self._mods: dict[str, ModuleType] = {}
# TODO: consider making this a dynamically defined
# @dataclass once we get py3.7
@ -380,12 +380,12 @@ class Actor:
self._ongoing_rpc_tasks = trio.Event()
self._ongoing_rpc_tasks.set()
# (chan, cid) -> (cancel_scope, func)
self._rpc_tasks: Dict[
self._rpc_tasks: dict[
Tuple[Channel, str],
Tuple[trio.CancelScope, typing.Callable, trio.Event]
] = {}
# map {uids -> {callids -> waiter queues}}
self._cids2qs: Dict[
self._cids2qs: dict[
Tuple[Tuple[str, str], str],
Tuple[
trio.abc.SendChannel[Any],
@ -396,7 +396,7 @@ class Actor:
self._parent_chan: Optional[Channel] = None
self._forkserver_info: Optional[
Tuple[Any, Any, Any, Any, Any]] = None
self._actoruid2nursery: Dict[str, 'ActorNursery'] = {} # type: ignore # noqa
self._actoruid2nursery: dict[Optional[tuple[str, str]], 'ActorNursery'] = {} # type: ignore # noqa
async def wait_for_peer(
self, uid: Tuple[str, str]
@ -550,6 +550,7 @@ class Actor:
cs.shield = True
# Attempt to wait for the far end to close the channel
# and bail after timeout (2-generals on closure).
assert chan.msgstream
async for msg in chan.msgstream.drain():
# try to deliver any lingering msgs
# before we destroy the channel.
@ -616,7 +617,7 @@ class Actor:
self,
chan: Channel,
cid: str,
msg: Dict[str, Any],
msg: dict[str, Any],
) -> None:
"""Push an RPC result to the local consumer's queue.
"""
@ -877,7 +878,7 @@ class Actor:
# machinery not from an rpc task) to parent
log.exception("Actor errored:")
if self._parent_chan:
await try_ship_error_to_parent(self, err)
await try_ship_error_to_parent(self._parent_chan, err)
# if this is the `MainProcess` we expect the error broadcasting
# above to trigger an error at consuming portal "checkpoints"
@ -1078,7 +1079,7 @@ class Actor:
)
if self._parent_chan:
await try_ship_error_to_parent(self, err)
await try_ship_error_to_parent(self._parent_chan, err)
# always!
log.exception("Actor errored:")
@ -1360,7 +1361,7 @@ class Arbiter(Actor):
def __init__(self, *args, **kwargs):
self._registry: Dict[
self._registry: dict[
Tuple[str, str],
Tuple[str, int],
] = {}
@ -1377,7 +1378,7 @@ class Arbiter(Actor):
async def get_registry(
self
) -> Dict[Tuple[str, str], Tuple[str, int]]:
) -> dict[Tuple[str, str], Tuple[str, int]]:
'''Return current name registry.
This method is async to allow for cross-actor invocation.

View File

@ -6,9 +6,10 @@ from __future__ import annotations
import platform
import struct
import typing
from collections.abc import AsyncGenerator, AsyncIterator
from typing import (
Any, Tuple, Optional,
Type, Protocol, TypeVar
Type, Protocol, TypeVar,
)
from tricycle import BufferedReceiveStream
@ -46,6 +47,7 @@ MsgType = TypeVar("MsgType")
class MsgTransport(Protocol[MsgType]):
stream: trio.SocketStream
drained: list[MsgType]
def __init__(self, stream: trio.SocketStream) -> None:
...
@ -63,6 +65,11 @@ class MsgTransport(Protocol[MsgType]):
def connected(self) -> bool:
...
# defining this sync otherwise it causes a mypy error because it
# can't figure out it's a generator i guess?..?
def drain(self) -> AsyncIterator[dict]:
...
@property
def laddr(self) -> Tuple[str, int]:
...
@ -94,9 +101,9 @@ class MsgpackTCPStream:
self._send_lock = trio.StrictFIFOLock()
# public i guess?
self.drained = []
self.drained: list[dict] = []
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
"""Yield packets from the underlying stream.
"""
unpacker = msgpack.Unpacker(
@ -159,7 +166,13 @@ class MsgpackTCPStream:
async def recv(self) -> Any:
return await self._agen.asend(None)
async def drain(self):
async def drain(self) -> AsyncIterator[dict]:
'''
Drain the stream's remaining messages sent from
the far end until the connection is closed by
the peer.
'''
try:
async for msg in self._iter_packets():
self.drained.append(msg)
@ -196,7 +209,7 @@ class MsgspecTCPStream(MsgpackTCPStream):
self.encode = msgspec.Encoder().encode
self.decode = msgspec.Decoder().decode # dict[str, Any])
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
async def _iter_packets(self) -> AsyncGenerator[dict, None]:
'''Yield packets from the underlying stream.
'''
@ -458,9 +471,11 @@ class Channel:
async def _aiter_recv(
self
) -> typing.AsyncGenerator[Any, None]:
"""Async iterate items from underlying stream.
"""
) -> AsyncGenerator[Any, None]:
'''
Async iterate items from underlying stream.
'''
assert self.msgstream
while True:
try:
@ -490,9 +505,11 @@ class Channel:
async def _connect_chan(
host: str, port: int
) -> typing.AsyncGenerator[Channel, None]:
"""Create and connect a channel with disconnect on context manager
'''
Create and connect a channel with disconnect on context manager
teardown.
"""
'''
chan = Channel((host, port))
await chan.connect()
yield chan

View File

@ -5,7 +5,11 @@ Machinery for actor process spawning using multiple backends.
import sys
import multiprocessing as mp
import platform
from typing import Any, Dict, Optional, Union, Callable
from typing import (
Any, Dict, Optional, Union, Callable,
TypeVar,
)
from collections.abc import Awaitable, Coroutine
import trio
from trio_typing import TaskStatus
@ -41,6 +45,7 @@ from ._exceptions import ActorFailure
log = get_logger('tractor')
ProcessType = TypeVar('ProcessType', mp.Process, trio.Process)
# placeholder for an mp start context if so using that backend
_ctx: Optional[mp.context.BaseContext] = None
@ -185,10 +190,10 @@ async def do_hard_kill(
async def soft_wait(
proc: Union[mp.Process, trio.Process],
proc: ProcessType,
wait_func: Callable[
Union[mp.Process, trio.Process],
None,
[ProcessType],
Awaitable,
],
portal: Portal,