forked from goodboy/tractor
				
			Appease mypy
							parent
							
								
									a29924f330
								
							
						
					
					
						commit
						e561a4908f
					
				| 
						 | 
				
			
			@ -11,7 +11,7 @@ import importlib.util
 | 
			
		|||
import inspect
 | 
			
		||||
import uuid
 | 
			
		||||
import typing
 | 
			
		||||
from typing import Dict, List, Tuple, Any, Optional, Union
 | 
			
		||||
from typing import List, Tuple, Any, Optional, Union
 | 
			
		||||
from types import ModuleType
 | 
			
		||||
import sys
 | 
			
		||||
import os
 | 
			
		||||
| 
						 | 
				
			
			@ -49,7 +49,7 @@ async def _invoke(
 | 
			
		|||
    cid: str,
 | 
			
		||||
    chan: Channel,
 | 
			
		||||
    func: typing.Callable,
 | 
			
		||||
    kwargs: Dict[str, Any],
 | 
			
		||||
    kwargs: dict[str, Any],
 | 
			
		||||
    is_rpc: bool = True,
 | 
			
		||||
    task_status: TaskStatus[
 | 
			
		||||
        Union[trio.CancelScope, BaseException]
 | 
			
		||||
| 
						 | 
				
			
			@ -267,21 +267,21 @@ _lifetime_stack: ExitStack = ExitStack()
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
async def try_ship_error_to_parent(
 | 
			
		||||
    actor: Actor,
 | 
			
		||||
    err: Exception,
 | 
			
		||||
    channel: Channel,
 | 
			
		||||
    err: Union[Exception, trio.MultiError],
 | 
			
		||||
 | 
			
		||||
) -> None:
 | 
			
		||||
    with trio.CancelScope(shield=True):
 | 
			
		||||
        try:
 | 
			
		||||
            # internal error so ship to parent without cid
 | 
			
		||||
            await actor._parent_chan.send(pack_error(err))
 | 
			
		||||
            await channel.send(pack_error(err))
 | 
			
		||||
        except (
 | 
			
		||||
            trio.ClosedResourceError,
 | 
			
		||||
            trio.BrokenResourceError,
 | 
			
		||||
        ):
 | 
			
		||||
            log.error(
 | 
			
		||||
                f"Failed to ship error to parent "
 | 
			
		||||
                f"{actor._parent_chan.uid}, channel was closed"
 | 
			
		||||
                f"{channel.uid}, channel was closed"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -319,7 +319,7 @@ class Actor:
 | 
			
		|||
    _server_n: Optional[trio.Nursery] = None
 | 
			
		||||
 | 
			
		||||
    # Information about `__main__` from parent
 | 
			
		||||
    _parent_main_data: Dict[str, str]
 | 
			
		||||
    _parent_main_data: dict[str, str]
 | 
			
		||||
    _parent_chan_cs: Optional[trio.CancelScope] = None
 | 
			
		||||
 | 
			
		||||
    # syncs for setup/teardown sequences
 | 
			
		||||
| 
						 | 
				
			
			@ -357,7 +357,7 @@ class Actor:
 | 
			
		|||
            mods[name] = _get_mod_abspath(mod)
 | 
			
		||||
 | 
			
		||||
        self.enable_modules = mods
 | 
			
		||||
        self._mods: Dict[str, ModuleType] = {}
 | 
			
		||||
        self._mods: dict[str, ModuleType] = {}
 | 
			
		||||
 | 
			
		||||
        # TODO: consider making this a dynamically defined
 | 
			
		||||
        # @dataclass once we get py3.7
 | 
			
		||||
| 
						 | 
				
			
			@ -380,12 +380,12 @@ class Actor:
 | 
			
		|||
        self._ongoing_rpc_tasks = trio.Event()
 | 
			
		||||
        self._ongoing_rpc_tasks.set()
 | 
			
		||||
        # (chan, cid) -> (cancel_scope, func)
 | 
			
		||||
        self._rpc_tasks: Dict[
 | 
			
		||||
        self._rpc_tasks: dict[
 | 
			
		||||
            Tuple[Channel, str],
 | 
			
		||||
            Tuple[trio.CancelScope, typing.Callable, trio.Event]
 | 
			
		||||
        ] = {}
 | 
			
		||||
        # map {uids -> {callids -> waiter queues}}
 | 
			
		||||
        self._cids2qs: Dict[
 | 
			
		||||
        self._cids2qs: dict[
 | 
			
		||||
            Tuple[Tuple[str, str], str],
 | 
			
		||||
            Tuple[
 | 
			
		||||
                trio.abc.SendChannel[Any],
 | 
			
		||||
| 
						 | 
				
			
			@ -396,7 +396,7 @@ class Actor:
 | 
			
		|||
        self._parent_chan: Optional[Channel] = None
 | 
			
		||||
        self._forkserver_info: Optional[
 | 
			
		||||
            Tuple[Any, Any, Any, Any, Any]] = None
 | 
			
		||||
        self._actoruid2nursery: Dict[str, 'ActorNursery'] = {}  # type: ignore  # noqa
 | 
			
		||||
        self._actoruid2nursery: dict[Optional[tuple[str, str]], 'ActorNursery'] = {}  # type: ignore  # noqa
 | 
			
		||||
 | 
			
		||||
    async def wait_for_peer(
 | 
			
		||||
        self, uid: Tuple[str, str]
 | 
			
		||||
| 
						 | 
				
			
			@ -550,6 +550,7 @@ class Actor:
 | 
			
		|||
                    cs.shield = True
 | 
			
		||||
                    # Attempt to wait for the far end to close the channel
 | 
			
		||||
                    # and bail after timeout (2-generals on closure).
 | 
			
		||||
                    assert chan.msgstream
 | 
			
		||||
                    async for msg in chan.msgstream.drain():
 | 
			
		||||
                        # try to deliver any lingering msgs
 | 
			
		||||
                        # before we destroy the channel.
 | 
			
		||||
| 
						 | 
				
			
			@ -616,7 +617,7 @@ class Actor:
 | 
			
		|||
        self,
 | 
			
		||||
        chan: Channel,
 | 
			
		||||
        cid: str,
 | 
			
		||||
        msg: Dict[str, Any],
 | 
			
		||||
        msg: dict[str, Any],
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """Push an RPC result to the local consumer's queue.
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			@ -877,7 +878,7 @@ class Actor:
 | 
			
		|||
                # machinery not from an rpc task) to parent
 | 
			
		||||
                log.exception("Actor errored:")
 | 
			
		||||
                if self._parent_chan:
 | 
			
		||||
                    await try_ship_error_to_parent(self, err)
 | 
			
		||||
                    await try_ship_error_to_parent(self._parent_chan, err)
 | 
			
		||||
 | 
			
		||||
            # if this is the `MainProcess` we expect the error broadcasting
 | 
			
		||||
            # above to trigger an error at consuming portal "checkpoints"
 | 
			
		||||
| 
						 | 
				
			
			@ -1078,7 +1079,7 @@ class Actor:
 | 
			
		|||
                )
 | 
			
		||||
 | 
			
		||||
            if self._parent_chan:
 | 
			
		||||
                await try_ship_error_to_parent(self, err)
 | 
			
		||||
                await try_ship_error_to_parent(self._parent_chan, err)
 | 
			
		||||
 | 
			
		||||
            # always!
 | 
			
		||||
            log.exception("Actor errored:")
 | 
			
		||||
| 
						 | 
				
			
			@ -1360,7 +1361,7 @@ class Arbiter(Actor):
 | 
			
		|||
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
 | 
			
		||||
        self._registry: Dict[
 | 
			
		||||
        self._registry: dict[
 | 
			
		||||
            Tuple[str, str],
 | 
			
		||||
            Tuple[str, int],
 | 
			
		||||
        ] = {}
 | 
			
		||||
| 
						 | 
				
			
			@ -1377,7 +1378,7 @@ class Arbiter(Actor):
 | 
			
		|||
 | 
			
		||||
    async def get_registry(
 | 
			
		||||
        self
 | 
			
		||||
    ) -> Dict[Tuple[str, str], Tuple[str, int]]:
 | 
			
		||||
    ) -> dict[Tuple[str, str], Tuple[str, int]]:
 | 
			
		||||
        '''Return current name registry.
 | 
			
		||||
 | 
			
		||||
        This method is async to allow for cross-actor invocation.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6,9 +6,10 @@ from __future__ import annotations
 | 
			
		|||
import platform
 | 
			
		||||
import struct
 | 
			
		||||
import typing
 | 
			
		||||
from collections.abc import AsyncGenerator, AsyncIterator
 | 
			
		||||
from typing import (
 | 
			
		||||
    Any, Tuple, Optional,
 | 
			
		||||
    Type, Protocol, TypeVar
 | 
			
		||||
    Type, Protocol, TypeVar,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from tricycle import BufferedReceiveStream
 | 
			
		||||
| 
						 | 
				
			
			@ -46,6 +47,7 @@ MsgType = TypeVar("MsgType")
 | 
			
		|||
class MsgTransport(Protocol[MsgType]):
 | 
			
		||||
 | 
			
		||||
    stream: trio.SocketStream
 | 
			
		||||
    drained: list[MsgType]
 | 
			
		||||
 | 
			
		||||
    def __init__(self, stream: trio.SocketStream) -> None:
 | 
			
		||||
        ...
 | 
			
		||||
| 
						 | 
				
			
			@ -63,6 +65,11 @@ class MsgTransport(Protocol[MsgType]):
 | 
			
		|||
    def connected(self) -> bool:
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    # defining this sync otherwise it causes a mypy error because it
 | 
			
		||||
    # can't figure out it's a generator i guess?..?
 | 
			
		||||
    def drain(self) -> AsyncIterator[dict]:
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def laddr(self) -> Tuple[str, int]:
 | 
			
		||||
        ...
 | 
			
		||||
| 
						 | 
				
			
			@ -94,9 +101,9 @@ class MsgpackTCPStream:
 | 
			
		|||
        self._send_lock = trio.StrictFIFOLock()
 | 
			
		||||
 | 
			
		||||
        # public i guess?
 | 
			
		||||
        self.drained = []
 | 
			
		||||
        self.drained: list[dict] = []
 | 
			
		||||
 | 
			
		||||
    async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
 | 
			
		||||
    async def _iter_packets(self) -> AsyncGenerator[dict, None]:
 | 
			
		||||
        """Yield packets from the underlying stream.
 | 
			
		||||
        """
 | 
			
		||||
        unpacker = msgpack.Unpacker(
 | 
			
		||||
| 
						 | 
				
			
			@ -159,7 +166,13 @@ class MsgpackTCPStream:
 | 
			
		|||
    async def recv(self) -> Any:
 | 
			
		||||
        return await self._agen.asend(None)
 | 
			
		||||
 | 
			
		||||
    async def drain(self):
 | 
			
		||||
    async def drain(self) -> AsyncIterator[dict]:
 | 
			
		||||
        '''
 | 
			
		||||
        Drain the stream's remaining messages sent from
 | 
			
		||||
        the far end until the connection is closed by
 | 
			
		||||
        the peer.
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        try:
 | 
			
		||||
            async for msg in self._iter_packets():
 | 
			
		||||
                self.drained.append(msg)
 | 
			
		||||
| 
						 | 
				
			
			@ -196,7 +209,7 @@ class MsgspecTCPStream(MsgpackTCPStream):
 | 
			
		|||
        self.encode = msgspec.Encoder().encode
 | 
			
		||||
        self.decode = msgspec.Decoder().decode  # dict[str, Any])
 | 
			
		||||
 | 
			
		||||
    async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
 | 
			
		||||
    async def _iter_packets(self) -> AsyncGenerator[dict, None]:
 | 
			
		||||
        '''Yield packets from the underlying stream.
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
| 
						 | 
				
			
			@ -458,9 +471,11 @@ class Channel:
 | 
			
		|||
 | 
			
		||||
    async def _aiter_recv(
 | 
			
		||||
        self
 | 
			
		||||
    ) -> typing.AsyncGenerator[Any, None]:
 | 
			
		||||
        """Async iterate items from underlying stream.
 | 
			
		||||
        """
 | 
			
		||||
    ) -> AsyncGenerator[Any, None]:
 | 
			
		||||
        '''
 | 
			
		||||
        Async iterate items from underlying stream.
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        assert self.msgstream
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
| 
						 | 
				
			
			@ -490,9 +505,11 @@ class Channel:
 | 
			
		|||
async def _connect_chan(
 | 
			
		||||
    host: str, port: int
 | 
			
		||||
) -> typing.AsyncGenerator[Channel, None]:
 | 
			
		||||
    """Create and connect a channel with disconnect on context manager
 | 
			
		||||
    '''
 | 
			
		||||
    Create and connect a channel with disconnect on context manager
 | 
			
		||||
    teardown.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    chan = Channel((host, port))
 | 
			
		||||
    await chan.connect()
 | 
			
		||||
    yield chan
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,7 +5,11 @@ Machinery for actor process spawning using multiple backends.
 | 
			
		|||
import sys
 | 
			
		||||
import multiprocessing as mp
 | 
			
		||||
import platform
 | 
			
		||||
from typing import Any, Dict, Optional, Union, Callable
 | 
			
		||||
from typing import (
 | 
			
		||||
    Any, Dict, Optional, Union, Callable,
 | 
			
		||||
    TypeVar,
 | 
			
		||||
)
 | 
			
		||||
from collections.abc import Awaitable, Coroutine
 | 
			
		||||
 | 
			
		||||
import trio
 | 
			
		||||
from trio_typing import TaskStatus
 | 
			
		||||
| 
						 | 
				
			
			@ -41,6 +45,7 @@ from ._exceptions import ActorFailure
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
log = get_logger('tractor')
 | 
			
		||||
ProcessType = TypeVar('ProcessType', mp.Process, trio.Process)
 | 
			
		||||
 | 
			
		||||
# placeholder for an mp start context if so using that backend
 | 
			
		||||
_ctx: Optional[mp.context.BaseContext] = None
 | 
			
		||||
| 
						 | 
				
			
			@ -185,10 +190,10 @@ async def do_hard_kill(
 | 
			
		|||
 | 
			
		||||
async def soft_wait(
 | 
			
		||||
 | 
			
		||||
    proc: Union[mp.Process, trio.Process],
 | 
			
		||||
    proc: ProcessType,
 | 
			
		||||
    wait_func: Callable[
 | 
			
		||||
        Union[mp.Process, trio.Process],
 | 
			
		||||
        None,
 | 
			
		||||
        [ProcessType],
 | 
			
		||||
        Awaitable,
 | 
			
		||||
    ],
 | 
			
		||||
    portal: Portal,
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue