Appease mypy

acked_backup
Tyler Goodlet 2021-12-02 12:34:27 -05:00
parent a29924f330
commit e561a4908f
3 changed files with 53 additions and 30 deletions

View File

@ -11,7 +11,7 @@ import importlib.util
import inspect import inspect
import uuid import uuid
import typing import typing
from typing import Dict, List, Tuple, Any, Optional, Union from typing import List, Tuple, Any, Optional, Union
from types import ModuleType from types import ModuleType
import sys import sys
import os import os
@ -49,7 +49,7 @@ async def _invoke(
cid: str, cid: str,
chan: Channel, chan: Channel,
func: typing.Callable, func: typing.Callable,
kwargs: Dict[str, Any], kwargs: dict[str, Any],
is_rpc: bool = True, is_rpc: bool = True,
task_status: TaskStatus[ task_status: TaskStatus[
Union[trio.CancelScope, BaseException] Union[trio.CancelScope, BaseException]
@ -267,21 +267,21 @@ _lifetime_stack: ExitStack = ExitStack()
async def try_ship_error_to_parent( async def try_ship_error_to_parent(
actor: Actor, channel: Channel,
err: Exception, err: Union[Exception, trio.MultiError],
) -> None: ) -> None:
with trio.CancelScope(shield=True): with trio.CancelScope(shield=True):
try: try:
# internal error so ship to parent without cid # internal error so ship to parent without cid
await actor._parent_chan.send(pack_error(err)) await channel.send(pack_error(err))
except ( except (
trio.ClosedResourceError, trio.ClosedResourceError,
trio.BrokenResourceError, trio.BrokenResourceError,
): ):
log.error( log.error(
f"Failed to ship error to parent " f"Failed to ship error to parent "
f"{actor._parent_chan.uid}, channel was closed" f"{channel.uid}, channel was closed"
) )
@ -319,7 +319,7 @@ class Actor:
_server_n: Optional[trio.Nursery] = None _server_n: Optional[trio.Nursery] = None
# Information about `__main__` from parent # Information about `__main__` from parent
_parent_main_data: Dict[str, str] _parent_main_data: dict[str, str]
_parent_chan_cs: Optional[trio.CancelScope] = None _parent_chan_cs: Optional[trio.CancelScope] = None
# syncs for setup/teardown sequences # syncs for setup/teardown sequences
@ -357,7 +357,7 @@ class Actor:
mods[name] = _get_mod_abspath(mod) mods[name] = _get_mod_abspath(mod)
self.enable_modules = mods self.enable_modules = mods
self._mods: Dict[str, ModuleType] = {} self._mods: dict[str, ModuleType] = {}
# TODO: consider making this a dynamically defined # TODO: consider making this a dynamically defined
# @dataclass once we get py3.7 # @dataclass once we get py3.7
@ -380,12 +380,12 @@ class Actor:
self._ongoing_rpc_tasks = trio.Event() self._ongoing_rpc_tasks = trio.Event()
self._ongoing_rpc_tasks.set() self._ongoing_rpc_tasks.set()
# (chan, cid) -> (cancel_scope, func) # (chan, cid) -> (cancel_scope, func)
self._rpc_tasks: Dict[ self._rpc_tasks: dict[
Tuple[Channel, str], Tuple[Channel, str],
Tuple[trio.CancelScope, typing.Callable, trio.Event] Tuple[trio.CancelScope, typing.Callable, trio.Event]
] = {} ] = {}
# map {uids -> {callids -> waiter queues}} # map {uids -> {callids -> waiter queues}}
self._cids2qs: Dict[ self._cids2qs: dict[
Tuple[Tuple[str, str], str], Tuple[Tuple[str, str], str],
Tuple[ Tuple[
trio.abc.SendChannel[Any], trio.abc.SendChannel[Any],
@ -396,7 +396,7 @@ class Actor:
self._parent_chan: Optional[Channel] = None self._parent_chan: Optional[Channel] = None
self._forkserver_info: Optional[ self._forkserver_info: Optional[
Tuple[Any, Any, Any, Any, Any]] = None Tuple[Any, Any, Any, Any, Any]] = None
self._actoruid2nursery: Dict[str, 'ActorNursery'] = {} # type: ignore # noqa self._actoruid2nursery: dict[Optional[tuple[str, str]], 'ActorNursery'] = {} # type: ignore # noqa
async def wait_for_peer( async def wait_for_peer(
self, uid: Tuple[str, str] self, uid: Tuple[str, str]
@ -550,6 +550,7 @@ class Actor:
cs.shield = True cs.shield = True
# Attempt to wait for the far end to close the channel # Attempt to wait for the far end to close the channel
# and bail after timeout (2-generals on closure). # and bail after timeout (2-generals on closure).
assert chan.msgstream
async for msg in chan.msgstream.drain(): async for msg in chan.msgstream.drain():
# try to deliver any lingering msgs # try to deliver any lingering msgs
# before we destroy the channel. # before we destroy the channel.
@ -616,7 +617,7 @@ class Actor:
self, self,
chan: Channel, chan: Channel,
cid: str, cid: str,
msg: Dict[str, Any], msg: dict[str, Any],
) -> None: ) -> None:
"""Push an RPC result to the local consumer's queue. """Push an RPC result to the local consumer's queue.
""" """
@ -877,7 +878,7 @@ class Actor:
# machinery not from an rpc task) to parent # machinery not from an rpc task) to parent
log.exception("Actor errored:") log.exception("Actor errored:")
if self._parent_chan: if self._parent_chan:
await try_ship_error_to_parent(self, err) await try_ship_error_to_parent(self._parent_chan, err)
# if this is the `MainProcess` we expect the error broadcasting # if this is the `MainProcess` we expect the error broadcasting
# above to trigger an error at consuming portal "checkpoints" # above to trigger an error at consuming portal "checkpoints"
@ -1078,7 +1079,7 @@ class Actor:
) )
if self._parent_chan: if self._parent_chan:
await try_ship_error_to_parent(self, err) await try_ship_error_to_parent(self._parent_chan, err)
# always! # always!
log.exception("Actor errored:") log.exception("Actor errored:")
@ -1360,7 +1361,7 @@ class Arbiter(Actor):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self._registry: Dict[ self._registry: dict[
Tuple[str, str], Tuple[str, str],
Tuple[str, int], Tuple[str, int],
] = {} ] = {}
@ -1377,7 +1378,7 @@ class Arbiter(Actor):
async def get_registry( async def get_registry(
self self
) -> Dict[Tuple[str, str], Tuple[str, int]]: ) -> dict[Tuple[str, str], Tuple[str, int]]:
'''Return current name registry. '''Return current name registry.
This method is async to allow for cross-actor invocation. This method is async to allow for cross-actor invocation.

View File

@ -6,9 +6,10 @@ from __future__ import annotations
import platform import platform
import struct import struct
import typing import typing
from collections.abc import AsyncGenerator, AsyncIterator
from typing import ( from typing import (
Any, Tuple, Optional, Any, Tuple, Optional,
Type, Protocol, TypeVar Type, Protocol, TypeVar,
) )
from tricycle import BufferedReceiveStream from tricycle import BufferedReceiveStream
@ -46,6 +47,7 @@ MsgType = TypeVar("MsgType")
class MsgTransport(Protocol[MsgType]): class MsgTransport(Protocol[MsgType]):
stream: trio.SocketStream stream: trio.SocketStream
drained: list[MsgType]
def __init__(self, stream: trio.SocketStream) -> None: def __init__(self, stream: trio.SocketStream) -> None:
... ...
@ -63,6 +65,11 @@ class MsgTransport(Protocol[MsgType]):
def connected(self) -> bool: def connected(self) -> bool:
... ...
# defining this sync otherwise it causes a mypy error because it
# can't figure out it's a generator i guess?..?
def drain(self) -> AsyncIterator[dict]:
...
@property @property
def laddr(self) -> Tuple[str, int]: def laddr(self) -> Tuple[str, int]:
... ...
@ -94,9 +101,9 @@ class MsgpackTCPStream:
self._send_lock = trio.StrictFIFOLock() self._send_lock = trio.StrictFIFOLock()
# public i guess? # public i guess?
self.drained = [] self.drained: list[dict] = []
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: async def _iter_packets(self) -> AsyncGenerator[dict, None]:
"""Yield packets from the underlying stream. """Yield packets from the underlying stream.
""" """
unpacker = msgpack.Unpacker( unpacker = msgpack.Unpacker(
@ -159,7 +166,13 @@ class MsgpackTCPStream:
async def recv(self) -> Any: async def recv(self) -> Any:
return await self._agen.asend(None) return await self._agen.asend(None)
async def drain(self): async def drain(self) -> AsyncIterator[dict]:
'''
Drain the stream's remaining messages sent from
the far end until the connection is closed by
the peer.
'''
try: try:
async for msg in self._iter_packets(): async for msg in self._iter_packets():
self.drained.append(msg) self.drained.append(msg)
@ -196,7 +209,7 @@ class MsgspecTCPStream(MsgpackTCPStream):
self.encode = msgspec.Encoder().encode self.encode = msgspec.Encoder().encode
self.decode = msgspec.Decoder().decode # dict[str, Any]) self.decode = msgspec.Decoder().decode # dict[str, Any])
async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]: async def _iter_packets(self) -> AsyncGenerator[dict, None]:
'''Yield packets from the underlying stream. '''Yield packets from the underlying stream.
''' '''
@ -458,9 +471,11 @@ class Channel:
async def _aiter_recv( async def _aiter_recv(
self self
) -> typing.AsyncGenerator[Any, None]: ) -> AsyncGenerator[Any, None]:
"""Async iterate items from underlying stream. '''
""" Async iterate items from underlying stream.
'''
assert self.msgstream assert self.msgstream
while True: while True:
try: try:
@ -490,9 +505,11 @@ class Channel:
async def _connect_chan( async def _connect_chan(
host: str, port: int host: str, port: int
) -> typing.AsyncGenerator[Channel, None]: ) -> typing.AsyncGenerator[Channel, None]:
"""Create and connect a channel with disconnect on context manager '''
Create and connect a channel with disconnect on context manager
teardown. teardown.
"""
'''
chan = Channel((host, port)) chan = Channel((host, port))
await chan.connect() await chan.connect()
yield chan yield chan

View File

@ -5,7 +5,11 @@ Machinery for actor process spawning using multiple backends.
import sys import sys
import multiprocessing as mp import multiprocessing as mp
import platform import platform
from typing import Any, Dict, Optional, Union, Callable from typing import (
Any, Dict, Optional, Union, Callable,
TypeVar,
)
from collections.abc import Awaitable, Coroutine
import trio import trio
from trio_typing import TaskStatus from trio_typing import TaskStatus
@ -41,6 +45,7 @@ from ._exceptions import ActorFailure
log = get_logger('tractor') log = get_logger('tractor')
ProcessType = TypeVar('ProcessType', mp.Process, trio.Process)
# placeholder for an mp start context if so using that backend # placeholder for an mp start context if so using that backend
_ctx: Optional[mp.context.BaseContext] = None _ctx: Optional[mp.context.BaseContext] = None
@ -185,10 +190,10 @@ async def do_hard_kill(
async def soft_wait( async def soft_wait(
proc: Union[mp.Process, trio.Process], proc: ProcessType,
wait_func: Callable[ wait_func: Callable[
Union[mp.Process, trio.Process], [ProcessType],
None, Awaitable,
], ],
portal: Portal, portal: Portal,