Compare commits

18 Commits

master ... windows_bi
| Author | SHA1 | Date |
|---|---|---|
| | 2be69bb9fb | |
| | 8f25f2d2fa | |
| | 8297e765c1 | |
| | 0090a57681 | |
| | 582eae699a | |
| | b0bcb430bf | |
| | 7eb76e8d97 | |
| | 55bea3ca17 | |
| | e4216b0691 | |
| | f0ceb9a811 | |
| | 9793851134 | |
| | 8cbe519d41 | |
| | 613e613b4c | |
| | 5ff5e7a6ef | |
| | a166a62b31 | |
| | 265120afd9 | |
| | ae6aa75bcd | |
| | 0027115589 | |
@@ -0,0 +1,212 @@
"""
Bidirectional streaming and context API.

"""
import pytest
import trio
import tractor

# from conftest import tractor_test

# TODO: test endofchannel semantics / cancellation / error cases:
# 3 possible outcomes:
# - normal termination: far end relays a stop message with
# final value as in async gen from ``return <val>``.

# possible outcomes:
# - normal termination: far end returns
# - premature close: far end relays a stop message to tear down stream
# - cancel: far end raises `ContextCancelled`

# future possible outcomes
# - restart request: far end raises `ContextRestart`


_state: bool = False


@tractor.context
async def simple_setup_teardown(

    ctx: tractor.Context,
    data: int,

) -> None:

    # startup phase
    global _state
    _state = True

    # signal to parent that we're up
    await ctx.started(data + 1)

    try:
        # block until cancelled
        await trio.sleep_forever()
    finally:
        _state = False


async def assert_state(value: bool):
    global _state
    assert _state == value


@pytest.mark.parametrize(
    'error_parent',
    [False, True],
)
def test_simple_context(error_parent):

    async def main():

        async with tractor.open_nursery() as n:

            portal = await n.start_actor(
                'simple_context',
                enable_modules=[__name__],
            )

            async with portal.open_context(
                simple_setup_teardown,
                data=10,
            ) as (ctx, sent):

                assert sent == 11

                await portal.run(assert_state, value=True)

            # after cancellation
            await portal.run(assert_state, value=False)

            if error_parent:
                raise ValueError

            # shut down daemon
            await portal.cancel_actor()

    if error_parent:
        try:
            trio.run(main)
        except ValueError:
            pass
    else:
        trio.run(main)


@tractor.context
async def simple_rpc(

    ctx: tractor.Context,
    data: int,

) -> None:
    """Test a small ping-pong server.

    """
    # signal to parent that we're up
    await ctx.started(data + 1)

    print('opening stream in callee')
    async with ctx.open_stream() as stream:

        count = 0
        while True:
            try:
                await stream.receive() == 'ping'
            except trio.EndOfChannel:
                assert count == 10
                break
            else:
                print('pong')
                await stream.send('pong')
                count += 1


@tractor.context
async def simple_rpc_with_forloop(

    ctx: tractor.Context,
    data: int,

) -> None:
    """Same as previous test but using ``async for`` syntax/api.

    """

    # signal to parent that we're up
    await ctx.started(data + 1)

    print('opening stream in callee')
    async with ctx.open_stream() as stream:

        count = 0
        async for msg in stream:

            assert msg == 'ping'
            print('pong')
            await stream.send('pong')
            count += 1

        else:
            assert count == 10


@pytest.mark.parametrize(
    'use_async_for',
    [True, False],
)
@pytest.mark.parametrize(
    'server_func',
    [simple_rpc, simple_rpc_with_forloop],
)
def test_simple_rpc(server_func, use_async_for):
    """The simplest request response pattern.

    """
    async def main():
        async with tractor.open_nursery() as n:

            portal = await n.start_actor(
                'rpc_server',
                enable_modules=[__name__],
            )

            async with portal.open_context(
                server_func,  # taken from pytest parameterization
                data=10,
            ) as (ctx, sent):

                assert sent == 11

                async with ctx.open_stream() as stream:

                    if use_async_for:

                        count = 0
                        # receive msgs using async for style
                        print('ping')
                        await stream.send('ping')

                        async for msg in stream:
                            assert msg == 'pong'
                            print('ping')
                            await stream.send('ping')
                            count += 1

                            if count >= 9:
                                break

                    else:
                        # classic send/receive style
                        for _ in range(10):

                            print('ping')
                            await stream.send('ping')
                            assert await stream.receive() == 'pong'

                # stream should terminate here

            await portal.cancel_actor()

    trio.run(main)
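
Taken together, the tests above exercise the new bidirectional context API added on this branch. Below is a condensed, illustrative sketch of that caller/callee pattern; the ``echo_server`` function and the ``'echo'`` actor name are invented for illustration, and only APIs that appear in this diff (``@tractor.context``, ``ctx.started()``, ``ctx.open_stream()``, ``Portal.open_context()``) are used:

import trio
import tractor


@tractor.context
async def echo_server(
    ctx: tractor.Context,
    data: int,
) -> None:
    # callee side: sync with the caller by delivering a first value
    await ctx.started(data + 1)

    # then open the bidirectional stream and echo whatever arrives
    async with ctx.open_stream() as stream:
        async for msg in stream:
            await stream.send(msg)


async def main():
    async with tractor.open_nursery() as n:
        portal = await n.start_actor('echo', enable_modules=[__name__])

        # caller side: open the context first, then the stream
        async with portal.open_context(echo_server, data=10) as (ctx, first):
            assert first == 11  # the value passed to ``ctx.started()``

            async with ctx.open_stream() as stream:
                await stream.send('ping')
                assert await stream.receive() == 'ping'

        await portal.cancel_actor()


if __name__ == '__main__':
    trio.run(main)
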
@@ -0,0 +1,220 @@
"""
Advanced streaming patterns using bidirectional streams and contexts.

"""
import itertools
from typing import Set, Dict, List

import trio
import tractor


_registry: Dict[str, Set[tractor.ReceiveMsgStream]] = {
    'even': set(),
    'odd': set(),
}


async def publisher(

    seed: int = 0,

) -> None:

    global _registry

    def is_even(i):
        return i % 2 == 0

    for val in itertools.count(seed):

        sub = 'even' if is_even(val) else 'odd'

        for sub_stream in _registry[sub]:
            await sub_stream.send(val)

        # throttle send rate to ~4Hz
        # making it readable to a human user
        await trio.sleep(1/4)


@tractor.context
async def subscribe(

    ctx: tractor.Context,

) -> None:

    global _registry

    # syn caller
    await ctx.started(None)

    async with ctx.open_stream() as stream:

        # update subs list as consumer requests
        async for new_subs in stream:

            new_subs = set(new_subs)
            remove = new_subs - _registry.keys()

            print(f'setting sub to {new_subs} for {ctx.chan.uid}')

            # remove old subs
            for sub in remove:
                _registry[sub].remove(stream)

            # add new subs for consumer
            for sub in new_subs:
                _registry[sub].add(stream)


async def consumer(

    subs: List[str],

) -> None:

    uid = tractor.current_actor().uid

    async with tractor.wait_for_actor('publisher') as portal:
        async with portal.open_context(subscribe) as (ctx, first):
            async with ctx.open_stream() as stream:

                # flip between the provided subs dynamically
                if len(subs) > 1:

                    for sub in itertools.cycle(subs):
                        print(f'setting dynamic sub to {sub}')
                        await stream.send([sub])

                        count = 0
                        async for value in stream:
                            print(f'{uid} got: {value}')
                            if count > 5:
                                break
                            count += 1

                else:  # static sub

                    await stream.send(subs)
                    async for value in stream:
                        print(f'{uid} got: {value}')


def test_dynamic_pub_sub():

    global _registry

    from multiprocessing import cpu_count
    cpus = cpu_count()

    async def main():
        async with tractor.open_nursery() as n:

            # name of this actor will be same as target func
            await n.run_in_actor(publisher)

            for i, sub in zip(
                range(cpus - 2),
                itertools.cycle(_registry.keys())
            ):
                await n.run_in_actor(
                    consumer,
                    name=f'consumer_{sub}',
                    subs=[sub],
                )

            # make one dynamic subscriber
            await n.run_in_actor(
                consumer,
                name='consumer_dynamic',
                subs=list(_registry.keys()),
            )

            # block until cancelled by user
            with trio.fail_after(10):
                await trio.sleep_forever()

    try:
        trio.run(main)
    except trio.TooSlowError:
        pass


@tractor.context
async def one_task_streams_and_one_handles_reqresp(

    ctx: tractor.Context,

) -> None:

    await ctx.started()

    async with ctx.open_stream() as stream:

        async def pingpong():
            '''Run a simple req/response service.

            '''
            async for msg in stream:
                print('rpc server ping')
                assert msg == 'ping'
                print('rpc server pong')
                await stream.send('pong')

        async with trio.open_nursery() as n:
            n.start_soon(pingpong)

            for _ in itertools.count():
                await stream.send('yo')
                await trio.sleep(0.01)


def test_reqresp_ontopof_streaming():
    '''Test a subactor that both streams with one task and
    spawns another which handles a small requests-response
    dialogue over the same bidir-stream.

    '''
    async def main():

        with trio.move_on_after(2):
            async with tractor.open_nursery() as n:

                # name of this actor will be same as target func
                portal = await n.start_actor(
                    'dual_tasks',
                    enable_modules=[__name__]
                )

                # flat to make sure we get at least one pong
                got_pong: bool = False

                async with portal.open_context(
                    one_task_streams_and_one_handles_reqresp,

                ) as (ctx, first):

                    assert first is None

                    async with ctx.open_stream() as stream:

                        await stream.send('ping')

                        async for msg in stream:
                            print(f'client received: {msg}')

                            assert msg in {'pong', 'yo'}

                            if msg == 'pong':
                                got_pong = True
                                await stream.send('ping')
                                print('client sent ping')

        assert got_pong

    try:
        trio.run(main)
    except trio.TooSlowError:
        pass
@@ -338,6 +338,8 @@ async def test_respawn_consumer_task(
                        print("all values streamed, BREAKING")
                        break

                cs.cancel()

        # TODO: this is justification for a
        # ``ActorNursery.stream_from_actor()`` helper?
        await portal.cancel_actor()
@@ -5,7 +5,13 @@ tractor: An actor model micro-framework built on
from trio import MultiError

from ._ipc import Channel
from ._streaming import Context, stream
from ._streaming import (
    Context,
    ReceiveMsgStream,
    MsgStream,
    stream,
    context,
)
from ._discovery import get_arbiter, find_actor, wait_for_actor
from ._trionics import open_nursery
from ._state import current_actor, is_root_process

@@ -33,7 +39,7 @@ __all__ = [
    'run',
    'run_daemon',
    'stream',
    'wait_for_actor',
    'context',
    'to_asyncio',
    'wait_for_actor',
]
@@ -14,6 +14,7 @@ from types import ModuleType
import sys
import os
from contextlib import ExitStack
import warnings

import trio  # type: ignore
from trio_typing import TaskStatus

@@ -57,13 +58,37 @@ async def _invoke(
    treat_as_gen = False
    cs = None
    cancel_scope = trio.CancelScope()
    ctx = Context(chan, cid, cancel_scope)
    ctx = Context(chan, cid, _cancel_scope=cancel_scope)
    context = False

    if getattr(func, '_tractor_stream_function', False):
        # handle decorated ``@tractor.stream`` async functions
        sig = inspect.signature(func)
        params = sig.parameters

        # compat with old api
        kwargs['ctx'] = ctx

        if 'ctx' in params:
            warnings.warn(
                "`@tractor.stream decorated funcs should now declare "
                "a `stream`  arg, `ctx` is now designated for use with "
                "@tractor.context",
                DeprecationWarning,
                stacklevel=2,
            )

        elif 'stream' in params:
            assert 'stream' in params
            kwargs['stream'] = ctx

        treat_as_gen = True

    elif getattr(func, '_tractor_context_function', False):
        # handle decorated ``@tractor.context`` async function
        kwargs['ctx'] = ctx
        context = True

    # errors raised inside this block are propgated back to caller
    try:
        if not (

@@ -101,26 +126,41 @@ async def _invoke(
            # `StopAsyncIteration` system here for returning a final
            # value if desired
            await chan.send({'stop': True, 'cid': cid})

        # one way @stream func that gets treated like an async gen
        elif treat_as_gen:
            await chan.send({'functype': 'asyncgen', 'cid': cid})
            # XXX: the async-func may spawn further tasks which push
            # back values like an async-generator would but must
            # manualy construct the response dict-packet-responses as
            # above
            with cancel_scope as cs:
                task_status.started(cs)
                await coro

            if not cs.cancelled_caught:
                # task was not cancelled so we can instruct the
                # far end async gen to tear down
                await chan.send({'stop': True, 'cid': cid})

        elif context:
            # context func with support for bi-dir streaming
            await chan.send({'functype': 'context', 'cid': cid})

            with cancel_scope as cs:
                task_status.started(cs)
                await chan.send({'return': await coro, 'cid': cid})

            # if cs.cancelled_caught:
            #     # task was cancelled so relay to the cancel to caller
            #     await chan.send({'return': await coro, 'cid': cid})

        else:
            if treat_as_gen:
                await chan.send({'functype': 'asyncgen', 'cid': cid})
                # XXX: the async-func may spawn further tasks which push
                # back values like an async-generator would but must
                # manualy construct the response dict-packet-responses as
                # above
                with cancel_scope as cs:
                    task_status.started(cs)
                    await coro
                if not cs.cancelled_caught:
                    # task was not cancelled so we can instruct the
                    # far end async gen to tear down
                    await chan.send({'stop': True, 'cid': cid})
            else:
                # regular async function
                await chan.send({'functype': 'asyncfunc', 'cid': cid})
                with cancel_scope as cs:
                    task_status.started(cs)
                    await chan.send({'return': await coro, 'cid': cid})
            # regular async function
            await chan.send({'functype': 'asyncfunc', 'cid': cid})
            with cancel_scope as cs:
                task_status.started(cs)
                await chan.send({'return': await coro, 'cid': cid})

    except (Exception, trio.MultiError) as err:
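
For reference, here is the per-call message flow which the updated ``_invoke()`` establishes for each function type, summarized from the hunk above (an illustrative summary only, not an exhaustive protocol description):

# Sketch (taken from the hunk above) of the response packets ``_invoke()`` sends:
#
#   function decorated with ``@tractor.stream`` (one-way streaming):
#       {'functype': 'asyncgen', 'cid': cid}    # announce the response type
#       ...                                     # per-value packets pushed by the task
#       {'stop': True, 'cid': cid}              # only if the task wasn't cancelled
#
#   function decorated with ``@tractor.context`` (bidirectional):
#       {'functype': 'context', 'cid': cid}     # announce the response type
#       {'return': <coro result>, 'cid': cid}   # final value once the coro completes
#
#   plain async function:
#       {'functype': 'asyncfunc', 'cid': cid}
#       {'return': <coro result>, 'cid': cid}
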
@@ -404,10 +444,10 @@ class Actor:
        send_chan, recv_chan = self._cids2qs[(actorid, cid)]
        assert send_chan.cid == cid  # type: ignore

        if 'stop' in msg:
            log.debug(f"{send_chan} was terminated at remote end")
            # indicate to consumer that far end has stopped
            return await send_chan.aclose()
        # if 'stop' in msg:
        #     log.debug(f"{send_chan} was terminated at remote end")
        #     # indicate to consumer that far end has stopped
        #     return await send_chan.aclose()

        try:
            log.debug(f"Delivering {msg} from {actorid} to caller {cid}")

@@ -415,6 +455,12 @@ class Actor:
            await send_chan.send(msg)

        except trio.BrokenResourceError:
            # TODO: what is the right way to handle the case where the
            # local task has already sent a 'stop' / StopAsyncInteration
            # to the other side but and possibly has closed the local
            # feeder mem chan? Do we wait for some kind of ack or just
            # let this fail silently and bubble up (currently)?

            # XXX: local consumer has closed their side
            # so cancel the far end streaming task
            log.warning(f"{send_chan} consumer is already closed")

@@ -477,11 +523,14 @@ class Actor:
                task_status.started(loop_cs)
                async for msg in chan:
                    if msg is None:  # loop terminate sentinel

                        log.debug(
                            f"Cancelling all tasks for {chan} from {chan.uid}")
                        for (channel, cid) in self._rpc_tasks:

                        for (channel, cid) in self._rpc_tasks.copy():
                            if channel is chan:
                                await self._cancel_task(cid, channel)

                        log.debug(
                                f"Msg loop signalled to terminate for"
                                f" {chan} from {chan.uid}")

@@ -494,6 +543,7 @@ class Actor:
                    if cid:
                        # deliver response to local caller/waiter
                        await self._push_result(chan, cid, msg)

                        log.debug(
                            f"Waiting on next msg for {chan} from {chan.uid}")
                        continue
@@ -1,13 +1,13 @@
"""
Multi-core debugging for da peeps!

"""
import bdb
import sys
from functools import partial
from contextlib import asynccontextmanager
from typing import Awaitable, Tuple, Optional, Callable, AsyncIterator
from typing import Tuple, Optional, Callable, AsyncIterator

from async_generator import aclosing
import tractor
import trio

@@ -31,14 +31,21 @@ log = get_logger(__name__)

__all__ = ['breakpoint', 'post_mortem']


# TODO: wrap all these in a static global class: ``DebugLock`` maybe?

# placeholder for function to set a ``trio.Event`` on debugger exit
_pdb_release_hook: Optional[Callable] = None

# actor-wide variable pointing to current task name using debugger
_in_debug = False
_local_task_in_debug: Optional[str] = None

# actor tree-wide actor uid that supposedly has the tty lock
_global_actor_in_debug: Optional[Tuple[str, str]] = None

# lock in root actor preventing multi-access to local tty
_debug_lock = trio.StrictFIFOLock()
_debug_lock: trio.StrictFIFOLock = trio.StrictFIFOLock()
_pdb_complete: Optional[trio.Event] = None

# XXX: set by the current task waiting on the root tty lock
# and must be cancelled if this actor is cancelled via message
@@ -61,19 +68,19 @@ class PdbwTeardown(pdbpp.Pdb):
    # TODO: figure out how to dissallow recursive .set_trace() entry
    # since that'll cause deadlock for us.
    def set_continue(self):
        global _in_debug
        try:
            super().set_continue()
        finally:
            _in_debug = False
            global _local_task_in_debug
            _local_task_in_debug = None
            _pdb_release_hook()

    def set_quit(self):
        global _in_debug
        try:
            super().set_quit()
        finally:
            _in_debug = False
            global _local_task_in_debug
            _local_task_in_debug = None
            _pdb_release_hook()


@@ -119,18 +126,22 @@ async def _acquire_debug_lock(uid: Tuple[str, str]) -> AsyncIterator[None]:
    """Acquire a actor local FIFO lock meant to mutex entry to a local
    debugger entry point to avoid tty clobbering by multiple processes.
    """
    task_name = trio.lowlevel.current_task().name
    try:
        log.debug(
            f"Attempting to acquire TTY lock, remote task: {task_name}:{uid}")
        await _debug_lock.acquire()
    global _debug_lock, _global_actor_in_debug

    task_name = trio.lowlevel.current_task().name

    log.debug(
        f"Attempting to acquire TTY lock, remote task: {task_name}:{uid}")

    async with _debug_lock:

        # _debug_lock._uid = uid
        _global_actor_in_debug = uid
        log.debug(f"TTY lock acquired, remote task: {task_name}:{uid}")
        yield

    finally:
        _debug_lock.release()
        log.debug(f"TTY lock released, remote task: {task_name}:{uid}")
    _global_actor_in_debug = None
    log.debug(f"TTY lock released, remote task: {task_name}:{uid}")


# @contextmanager
@@ -144,118 +155,160 @@ async def _acquire_debug_lock(uid: Tuple[str, str]) -> AsyncIterator[None]:
#         signal.signal(signal.SIGINT, prior_handler)


@tractor.context
async def _hijack_stdin_relay_to_child(

    ctx: tractor.Context,
    subactor_uid: Tuple[str, str]
) -> AsyncIterator[str]:

) -> None:

    global _pdb_complete

    task_name = trio.lowlevel.current_task().name

    # TODO: when we get to true remote debugging
    # this will deliver stdin data
    log.warning(f"Actor {subactor_uid} is WAITING on stdin hijack lock")
    # this will deliver stdin data?

    log.debug(
        "Attempting to acquire TTY lock, "
        f"remote task: {task_name}:{subactor_uid}"
    )

    log.debug(f"Actor {subactor_uid} is WAITING on stdin hijack lock")

    async with _acquire_debug_lock(subactor_uid):
        log.warning(f"Actor {subactor_uid} ACQUIRED stdin hijack lock")

        # with _disable_sigint():
        with trio.CancelScope(shield=True):

        # indicate to child that we've locked stdio
        yield 'Locked'
            # indicate to child that we've locked stdio
            await ctx.started('Locked')
            log.runtime(  # type: ignore
                f"Actor {subactor_uid} ACQUIRED stdin hijack lock")

        # wait for cancellation of stream by child
        # indicating debugger is dis-engaged
        await trio.sleep_forever()
        # wait for unlock pdb by child
        async with ctx.open_stream() as stream:
            assert await stream.receive() == 'Unlock'

    log.debug(
        f"TTY lock released, remote task: {task_name}:{subactor_uid}")

    log.debug(f"Actor {subactor_uid} RELEASED stdin hijack lock")


# XXX: We only make this sync in case someone wants to
# overload the ``breakpoint()`` built-in.
def _breakpoint(debug_func) -> Awaitable[None]:
async def _breakpoint(debug_func) -> None:
    """``tractor`` breakpoint entry for engaging pdb machinery
    in subactors.

    """
    actor = tractor.current_actor()
    do_unlock = trio.Event()
    task_name = trio.lowlevel.current_task().name

    global _pdb_complete, _pdb_release_hook
    global _local_task_in_debug, _global_actor_in_debug

    async def wait_for_parent_stdin_hijack(
        task_status=trio.TASK_STATUS_IGNORED
    ):
        global _debugger_request_cs

        with trio.CancelScope() as cs:
            _debugger_request_cs = cs

            try:
                async with get_root() as portal:
                        async with portal.open_stream_from(
                            tractor._debug._hijack_stdin_relay_to_child,
                            subactor_uid=actor.uid,
                        ) as stream:

                                # block until first yield above
                                async for val in stream:
                    # this syncs to child's ``Context.started()`` call.
                    async with portal.open_context(

                                    assert val == 'Locked'
                                    task_status.started()
                        tractor._debug._hijack_stdin_relay_to_child,
                        subactor_uid=actor.uid,

                                    # with trio.CancelScope(shield=True):
                                    await do_unlock.wait()
                    ) as (ctx, val):

                        assert val == 'Locked'

                        async with ctx.open_stream() as stream:

                            # unblock local caller
                            task_status.started()

                            await _pdb_complete.wait()
                            await stream.send('Unlock')

                                    # trigger cancellation of remote stream
                                    break
            finally:
                log.debug(f"Exiting debugger for actor {actor}")
                global _in_debug
                _in_debug = False
                global _local_task_in_debug
                _local_task_in_debug = None
                log.debug(f"Child {actor} released parent stdio lock")

    async def _bp():
        """Async breakpoint which schedules a parent stdio lock, and once complete
        enters the ``pdbpp`` debugging console.
        """
        task_name = trio.lowlevel.current_task().name
    if not _pdb_complete or _pdb_complete.is_set():
        _pdb_complete = trio.Event()

        global _in_debug

        # TODO: need a more robust check for the "root" actor
        if actor._parent_chan and not is_root_process():
            if _in_debug:
                if _in_debug == task_name:
                    # this task already has the lock and is
                    # likely recurrently entering a breakpoint
                    return

                # if **this** actor is already in debug mode block here
                # waiting for the control to be released - this allows
                # support for recursive entries to `tractor.breakpoint()`
                log.warning(
                    f"Actor {actor.uid} already has a debug lock, waiting...")
                await do_unlock.wait()
                await trio.sleep(0.1)

            # assign unlock callback for debugger teardown hooks
            global _pdb_release_hook
            _pdb_release_hook = do_unlock.set

            # mark local actor as "in debug mode" to avoid recurrent
            # entries/requests to the root process
            _in_debug = task_name

            # this **must** be awaited by the caller and is done using the
            # root nursery so that the debugger can continue to run without
            # being restricted by the scope of a new task nursery.
            await actor._service_n.start(wait_for_parent_stdin_hijack)

        elif is_root_process():
            # we also wait in the root-parent for any child that
            # may have the tty locked prior
            if _debug_lock.locked():  # root process already has it; ignore
    # TODO: need a more robust check for the "root" actor
    if actor._parent_chan and not is_root_process():
        if _local_task_in_debug:
            if _local_task_in_debug == task_name:
                # this task already has the lock and is
                # likely recurrently entering a breakpoint
                return
            await _debug_lock.acquire()
            _pdb_release_hook = _debug_lock.release

        # block here one (at the appropriate frame *up* where
        # ``breakpoint()`` was awaited and begin handling stdio
        log.debug("Entering the synchronous world of pdb")
        debug_func(actor)
            # if **this** actor is already in debug mode block here
            # waiting for the control to be released - this allows
            # support for recursive entries to `tractor.breakpoint()`
            log.warning(f"{actor.uid} already has a debug lock, waiting...")

    # user code **must** await this!
    return _bp()
            await _pdb_complete.wait()
            await trio.sleep(0.1)

        # mark local actor as "in debug mode" to avoid recurrent
        # entries/requests to the root process
        _local_task_in_debug = task_name

        # assign unlock callback for debugger teardown hooks
        _pdb_release_hook = _pdb_complete.set

        # this **must** be awaited by the caller and is done using the
        # root nursery so that the debugger can continue to run without
        # being restricted by the scope of a new task nursery.
        await actor._service_n.start(wait_for_parent_stdin_hijack)

    elif is_root_process():

        # we also wait in the root-parent for any child that
        # may have the tty locked prior
        global _debug_lock

        # TODO: wait, what about multiple root tasks acquiring
        # it though.. shrug?
        # root process (us) already has it; ignore
        if _global_actor_in_debug == actor.uid:
            return

        # XXX: since we need to enter pdb synchronously below,
        # we have to release the lock manually from pdb completion
        # callbacks. Can't think of a nicer way then this atm.
        await _debug_lock.acquire()

        _global_actor_in_debug = actor.uid
        _local_task_in_debug = task_name

        # the lock must be released on pdb completion
        def teardown():
            global _pdb_complete, _debug_lock
            global _global_actor_in_debug, _local_task_in_debug

            _debug_lock.release()
            _global_actor_in_debug = None
            _local_task_in_debug = None
            _pdb_complete.set()

        _pdb_release_hook = teardown

    # block here one (at the appropriate frame *up* where
    # ``breakpoint()`` was awaited and begin handling stdio
    log.debug("Entering the synchronous world of pdb")
    debug_func(actor)


def _mk_pdb():
@@ -276,7 +329,7 @@ def _set_trace(actor=None):
    pdb = _mk_pdb()

    if actor is not None:
        log.runtime(f"\nAttaching pdb to actor: {actor.uid}\n")
        log.runtime(f"\nAttaching pdb to actor: {actor.uid}\n")  # type: ignore

        pdb.set_trace(
            # start 2 levels up in user code
@@ -285,8 +338,8 @@ def _set_trace(actor=None):

    else:
        # we entered the global ``breakpoint()`` built-in from sync code
        global _in_debug, _pdb_release_hook
        _in_debug = 'sync'
        global _local_task_in_debug, _pdb_release_hook
        _local_task_in_debug = 'sync'

        def nuttin():
            pass
@@ -312,11 +312,20 @@ class Portal:

        ctx = Context(self.channel, cid, _portal=self)
        try:
            async with ReceiveMsgStream(ctx, recv_chan, self) as rchan:
            # deliver receive only stream
            async with ReceiveMsgStream(ctx, recv_chan) as rchan:
                self._streams.add(rchan)
                yield rchan

        finally:

            # cancel the far end task on consumer close
            # NOTE: this is a special case since we assume that if using
            # this ``.open_fream_from()`` api, the stream is one a one
            # time use and we couple the far end tasks's lifetime to
            # the consumer's scope; we don't ever send a `'stop'`
            # message right now since there shouldn't be a reason to
            # stop and restart the stream, right?
            try:
                await ctx.cancel()
            except trio.ClosedResourceError:
@@ -326,17 +335,64 @@ class Portal:

            self._streams.remove(rchan)

    # @asynccontextmanager
    # async def open_context(
    #     self,
    #     func: Callable,
    #     **kwargs,
    # ) -> Context:
    #     # TODO
    #     elif resptype == 'context':  # context manager style setup/teardown
    #         # TODO likely not here though
    #         raise NotImplementedError
    @asynccontextmanager
    async def open_context(
        self,
        func: Callable,
        **kwargs,
    ) -> AsyncGenerator[Tuple[Context, Any], None]:
        """Open an inter-actor task context.

        This is a synchronous API which allows for deterministic
        setup/teardown of a remote task. The yielded ``Context`` further
        allows for opening bidirectional streams - see
        ``Context.open_stream()``.

        """
        # conduct target func method structural checks
        if not inspect.iscoroutinefunction(func) and (
            getattr(func, '_tractor_contex_function', False)
        ):
            raise TypeError(
                f'{func} must be an async generator function!')

        fn_mod_path, fn_name = func_deats(func)


        recv_chan: trio.ReceiveMemoryChannel = None
        try:
            cid, recv_chan, functype, first_msg = await self._submit(
                fn_mod_path, fn_name, kwargs)

            assert functype == 'context'
            msg = await recv_chan.receive()

            try:
                # the "first" value here is delivered by the callee's
                # ``Context.started()`` call.
                first = msg['started']

            except KeyError:
                assert msg.get('cid'), ("Received internal error at context?")

                if msg.get('error'):
                    # raise the error message
                    raise unpack_error(msg, self.channel)
                else:
                    raise

            # deliver context instance and .started() msg value in open
            # tuple.
            ctx = Context(self.channel, cid, _portal=self)
            try:
                yield ctx, first

            finally:
                await ctx.cancel()

        finally:
            if recv_chan is not None:
                await recv_chan.aclose()

@dataclass
class LocalPortal:
			@ -1,19 +1,211 @@
 | 
			
		|||
"""
 | 
			
		||||
Message stream types and APIs.
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
import inspect
 | 
			
		||||
from contextlib import contextmanager  # , asynccontextmanager
 | 
			
		||||
from contextlib import contextmanager, asynccontextmanager
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import Any, Iterator, Optional
 | 
			
		||||
from typing import (
 | 
			
		||||
    Any, Iterator, Optional, Callable,
 | 
			
		||||
    AsyncGenerator,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
import warnings
 | 
			
		||||
 | 
			
		||||
import trio
 | 
			
		||||
 | 
			
		||||
from ._ipc import Channel
 | 
			
		||||
from ._exceptions import unpack_error
 | 
			
		||||
from ._state import current_actor
 | 
			
		||||
from .log import get_logger
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
log = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO: generic typing like trio's receive channel
 | 
			
		||||
# but with msgspec messages?
 | 
			
		||||
# class ReceiveChannel(AsyncResource, Generic[ReceiveType]):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReceiveMsgStream(trio.abc.ReceiveChannel):
 | 
			
		||||
    """A wrapper around a ``trio._channel.MemoryReceiveChannel`` with
 | 
			
		||||
    special behaviour for signalling stream termination across an
 | 
			
		||||
    inter-actor ``Channel``. This is the type returned to a local task
 | 
			
		||||
    which invoked a remote streaming function using `Portal.run()`.
 | 
			
		||||
 | 
			
		||||
    Termination rules:
 | 
			
		||||
    - if the local task signals stop iteration a cancel signal is
 | 
			
		||||
      relayed to the remote task indicating to stop streaming
 | 
			
		||||
    - if the remote task signals the end of a stream, raise a
 | 
			
		||||
      ``StopAsyncIteration`` to terminate the local ``async for``
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        ctx: 'Context',  # typing: ignore # noqa
 | 
			
		||||
        rx_chan: trio.abc.ReceiveChannel,
 | 
			
		||||
        shield: bool = False,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        self._ctx = ctx
 | 
			
		||||
        self._rx_chan = rx_chan
 | 
			
		||||
        self._shielded = shield
 | 
			
		||||
 | 
			
		||||
    # delegate directly to underlying mem channel
 | 
			
		||||
    def receive_nowait(self):
 | 
			
		||||
        msg = self._rx_chan.receive_nowait()
 | 
			
		||||
        return msg['yield']
 | 
			
		||||
 | 
			
		||||
    async def receive(self):
 | 
			
        try:
            msg = await self._rx_chan.receive()
            return msg['yield']

        except KeyError:
            # internal error should never get here
            assert msg.get('cid'), ("Received internal error at portal?")

            # TODO: handle 2 cases with 3.10 match syntax
            # - 'stop'
            # - 'error'
            # possibly just handle msg['stop'] here!

            if msg.get('stop'):
                log.debug(f"{self} was stopped at remote end")

                # when the send is closed we assume the stream has
                # terminated and signal this local iterator to stop
                await self.aclose()
                raise trio.EndOfChannel

            # TODO: test that shows stream raising an expected error!!!
            elif msg.get('error'):
                # raise the error message
                raise unpack_error(msg, self._ctx.chan)

            else:
                raise

        except (trio.ClosedResourceError, StopAsyncIteration):
            # XXX: this indicates that a `stop` message was
            # sent by the far side of the underlying channel.
            # Currently this is triggered by calling ``.aclose()`` on
            # the send side of the channel inside
            # ``Actor._push_result()``, but maybe it should be done here
            # to avoid exposing the internal mem chan closing mechanism?
            # In theory we could instead do some flushing of the channel
            # if needed to ensure all consumers are complete before
            # triggering closure too early.

            # Locally, we want to close this stream gracefully, by
            # terminating any local consumer tasks deterministically.
            # We **don't** want to be closing this send channel and not
            # relaying a final value to remaining consumers who may not
            # have been scheduled to receive it yet.

            # lots of testing to do here

            # when the send is closed we assume the stream has
            # terminated and signal this local iterator to stop
            await self.aclose()

            # await self._ctx.send_stop()
            raise StopAsyncIteration

        except trio.Cancelled:
            # relay cancels to the remote task
            await self.aclose()
            raise

    @contextmanager
    def shield(
        self
    ) -> Iterator['ReceiveMsgStream']:  # noqa
        """Shield this stream's underlying channel such that a local consumer task
        can be cancelled (and possibly restarted) using ``trio.Cancelled``.

        Note that "shielding" here guards against relaying
        a ``'stop'`` message to the far end of the stream, thus keeping
        the stream machinery active and ready for further use; it does
        not have anything to do with an internal ``trio.CancelScope``.

        """
        self._shielded = True
        yield self
        self._shielded = False

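    # Example (sketch): a consumer task that may be cancelled and
    # restarted can wrap its receive loop in ``shield()`` so the far
    # end keeps streaming; ``stream`` is assumed to be a live
    # ``ReceiveMsgStream`` instance handed out by a portal call:
    #
    #     with stream.shield():
    #         async for msg in stream:
    #             ...
    #     # a ``trio.Cancelled`` here won't relay a 'stop' to the far end
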
    async def aclose(self):
        """Cancel associated remote actor task and local memory channel
        on close.

        """
        # TODO: proper adherence to trio's `.aclose()` semantics:
        # https://trio.readthedocs.io/en/stable/reference-io.html#trio.abc.AsyncResource.aclose
        rx_chan = self._rx_chan

        if rx_chan._closed:
            log.warning(f"{self} is already closed")
            return

        # TODO: broadcasting to multiple consumers
        # stats = rx_chan.statistics()
        # if stats.open_receive_channels > 1:
        #     # if we've been cloned don't kill the stream
        #     log.debug(
        #       "there are still consumers running keeping stream alive")
        #     return

        if self._shielded:
            log.warning(f"{self} is shielded, portal channel being kept alive")
            return

        # NOTE: this is super subtle IPC messaging stuff:
        # Relay stop iteration to far end **iff** we're
        # in bidirectional mode. If we're only streaming
        # *from* one side then that side **won't** have an
        # entry in `Actor._cids2qs` (maybe it should though?).
        # So any `yield` or `stop` msgs sent from the caller side
        # will cause key errors on the callee side since there is
        # no entry for a local feeder mem chan since the callee task
        # isn't expecting messages to be sent by the caller.
        # Thus, we must check that this context DOES NOT
        # have a portal reference to ensure this is indeed the callee
        # side and can relay a 'stop'. In the bidirectional case,
        # `Context.open_stream()` will create the `Actor._cids2qs`
        # entry from a call to `Actor.get_memchans()`.
        if not self._ctx._portal:
            # only for 2-way streams can we send
            # a stop from the callee side
            await self._ctx.send_stop()

        # close the local mem chan
        await rx_chan.aclose()

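    # Example (sketch): ``trio.abc.ReceiveChannel`` is an
    # ``AsyncResource``, so consumer code normally lets ``async with``
    # drive this method (and the callee-side 'stop' relay) rather than
    # calling it directly; ``stream`` and ``process`` below are
    # illustrative names, not part of this module:
    #
    #     async with stream:
    #         async for value in stream:
    #             process(value)
    #     # the feeder mem chan is closed here
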
    # TODO: but make it broadcasting to consumers
    # def clone(self):
    #     """Clone this receive channel allowing for multi-task
    #     consumption from the same channel.

    #     """
    #     return ReceiveStream(
    #         self._cid,
    #         self._rx_chan.clone(),
    #         self._portal,
    #     )


class MsgStream(ReceiveMsgStream, trio.abc.Channel):
    """
    Bidirectional message stream for use within an inter-actor
    ``Context``.

    """
    async def send(
        self,
        data: Any
    ) -> None:
        await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid})

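    # Example (sketch): each ``send()`` ships the same packet shape that
    # ``receive()`` unwraps above; e.g. ``await stream.send('pong')`` on
    # a context with cid ``'abc'`` (illustrative value) delivers
    #
    #     {'yield': 'pong', 'cid': 'abc'}
    #
    # to the far end's feeder mem chan.
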
@dataclass(frozen=True)
class Context:
    """An IAC (inter-actor communication) context.

@@ -31,6 +223,10 @@ class Context:
    chan: Channel
    cid: str

    # TODO: should we have separate types for caller vs. callee
    # side contexts? The caller always opens a portal whereas the callee
    # is always responding back through a context-stream

    # only set on the caller side
    _portal: Optional['Portal'] = None    # type: ignore # noqa

@@ -57,46 +253,110 @@ class Context:
        timeout quickly to sidestep 2-generals...

        """
        assert self._portal, (
            "No portal found, this is likely a callee side context")
        if self._portal:  # caller side:
            if not self._portal:
                raise RuntimeError(
                    "No portal found, this is likely a callee side context"
                )

        cid = self.cid
        with trio.move_on_after(0.5) as cs:
            cs.shield = True
            log.warning(
                f"Cancelling stream {cid} to "
                f"{self._portal.channel.uid}")

            # NOTE: we're telling the far end actor to cancel a task
            # corresponding to *this actor*. The far end local channel
            # instance is passed to `Actor._cancel_task()` implicitly.
            await self._portal.run_from_ns('self', '_cancel_task', cid=cid)

        if cs.cancelled_caught:
            # XXX: there's no way to know if the remote task was indeed
            # cancelled in the case where the connection is broken or
            # some other network error occurred.
            if not self._portal.channel.connected():
            cid = self.cid
            with trio.move_on_after(0.5) as cs:
                cs.shield = True
                log.warning(
                    "May have failed to cancel remote task "
                    f"{cid} for {self._portal.channel.uid}")
                    f"Cancelling stream {cid} to "
                    f"{self._portal.channel.uid}")

                # NOTE: we're telling the far end actor to cancel a task
                # corresponding to *this actor*. The far end local channel
                # instance is passed to `Actor._cancel_task()` implicitly.
                await self._portal.run_from_ns('self', '_cancel_task', cid=cid)

            if cs.cancelled_caught:
                # XXX: there's no way to know if the remote task was indeed
                # cancelled in the case where the connection is broken or
                # some other network error occurred.
                # if not self._portal.channel.connected():
                if not self.chan.connected():
                    log.warning(
                        "May have failed to cancel remote task "
                        f"{cid} for {self._portal.channel.uid}")
        else:
            # ensure callee side
            assert self._cancel_scope
            # TODO: should we have an explicit cancel message
            # or is relaying the local `trio.Cancelled` as an
            # {'error': trio.Cancelled, cid: "blah"} enough?
            # This probably gets into the discussion in
            # https://github.com/goodboy/tractor/issues/36
            self._cancel_scope.cancel()

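    # Example (sketch): caller-side teardown through the context; the
    # 2-tuple unpacking and ``some_ctx_func`` are assumptions about the
    # surrounding ``Portal.open_context()`` API, not shown in this hunk:
    #
    #     async with portal.open_context(some_ctx_func) as (ctx, first):
    #         ...
    #         await ctx.cancel()  # bounded (0.5s), shielded remote cancel
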
    # TODO: do we need a restart api?
    # async def restart(self) -> None:
    #     # TODO
    #     pass

    # @asynccontextmanager
    # async def open_stream(
    #     self,
    # ) -> AsyncContextManager:
    #     # TODO
    #     pass
    @asynccontextmanager
    async def open_stream(
        self,
        shield: bool = False,
    ) -> AsyncGenerator[MsgStream, None]:
        # TODO

        actor = current_actor()

        # here we create a mem chan that corresponds to the
        # far end caller / callee.

        # NOTE: in one way streaming this only happens on the
        # caller side inside `Actor.send_cmd()` so if you try
        # to send a stop from the caller to the callee in the
        # single-direction-stream case you'll get a lookup error
        # currently.
        _, recv_chan = actor.get_memchans(
            self.chan.uid,
            self.cid
        )

        async with MsgStream(
            ctx=self,
            rx_chan=recv_chan,
            shield=shield,
        ) as rchan:

            if self._portal:
                self._portal._streams.add(rchan)

            try:
                yield rchan

            except trio.EndOfChannel:
                raise

            else:
                # signal ``StopAsyncIteration`` on far end.
                await self.send_stop()

            finally:
                if self._portal:
                    self._portal._streams.remove(rchan)

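    # Example (sketch): caller-side use of the two-way stream; ``worker``
    # is a hypothetical ``@tractor.context`` function and the 2-tuple
    # unpacking assumes the value relayed by the callee's ``started()``:
    #
    #     async with portal.open_context(worker) as (ctx, first):
    #         async with ctx.open_stream() as stream:
    #             await stream.send('ping')
    #             resp = await stream.receive()
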
    async def started(self, value: Optional[Any] = None) -> None:

        if self._portal:
            raise RuntimeError(
                f"Caller side context {self} can not call started!")

        await self.chan.send({'started': value, 'cid': self.cid})


def stream(func):
def stream(func: Callable) -> Callable:
    """Mark an async function as a streaming routine with ``@stream``.

    """
    func._tractor_stream_function = True
    # annotate
    # TODO: apply whatever solution ``mypy`` ends up picking for this:
    # https://github.com/python/mypy/issues/2087#issuecomment-769266912
    func._tractor_stream_function = True  # type: ignore

    sig = inspect.signature(func)
    params = sig.parameters
    if 'stream' not in params and 'ctx' in params:

@@ -114,147 +374,26 @@ def stream(func):
    ):
        raise TypeError(
            "The first argument to the stream function "
            f"{func.__name__} must be `ctx: tractor.Context`"
            f"{func.__name__} must be `ctx: tractor.Context` "
            "(Or ``to_trio`` if using ``asyncio`` in guest mode)."
        )
    return func


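# Example (sketch): a signature accepted by the check above; the body
# assumes the one-way streaming ``Context.send_yield()`` helper, which
# is not shown in this diff:
#
#     @tractor.stream
#     async def stream_squares(ctx: tractor.Context, limit: int) -> None:
#         for i in range(limit):
#             await ctx.send_yield(i ** 2)
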
class ReceiveMsgStream(trio.abc.ReceiveChannel):
    """A wrapper around a ``trio._channel.MemoryReceiveChannel`` with
    special behaviour for signalling stream termination across an
    inter-actor ``Channel``. This is the type returned to a local task
    which invoked a remote streaming function using `Portal.run()`.

    Termination rules:
    - if the local task signals stop iteration a cancel signal is
      relayed to the remote task indicating to stop streaming
    - if the remote task signals the end of a stream, raise a
      ``StopAsyncIteration`` to terminate the local ``async for``
def context(func: Callable) -> Callable:
    """Mark an async function as a streaming routine with ``@context``.

    """
    def __init__(
        self,
        ctx: Context,
        rx_chan: trio.abc.ReceiveChannel,
        portal: 'Portal',  # type: ignore # noqa
    ) -> None:
        self._ctx = ctx
        self._rx_chan = rx_chan
        self._portal = portal
        self._shielded = False
    # annotate
    # TODO: apply whatever solution ``mypy`` ends up picking for this:
    # https://github.com/python/mypy/issues/2087#issuecomment-769266912
    func._tractor_context_function = True  # type: ignore

    # delegate directly to underlying mem channel
    def receive_nowait(self):
        return self._rx_chan.receive_nowait()

    async def receive(self):
        try:
            msg = await self._rx_chan.receive()
            return msg['yield']

        except KeyError:
            # internal error should never get here
            assert msg.get('cid'), ("Received internal error at portal?")

            # TODO: handle 2 cases with 3.10 match syntax
            # - 'stop'
            # - 'error'
            # possibly just handle msg['stop'] here!

            # TODO: test that shows stream raising an expected error!!!
            if msg.get('error'):
                # raise the error message
                raise unpack_error(msg, self._portal.channel)

        except (trio.ClosedResourceError, StopAsyncIteration):
            # XXX: this indicates that a `stop` message was
            # sent by the far side of the underlying channel.
            # Currently this is triggered by calling ``.aclose()`` on
            # the send side of the channel inside
            # ``Actor._push_result()``, but maybe it should be put here?
            # to avoid exposing the internal mem chan closing mechanism?
            # in theory we could instead do some flushing of the channel
            # if needed to ensure all consumers are complete before
            # triggering closure too early?

            # Locally, we want to close this stream gracefully, by
            # terminating any local consumers tasks deterministically.
            # We **don't** want to be closing this send channel and not
            # relaying a final value to remaining consumers who may not
            # have been scheduled to receive it yet?

            # lots of testing to do here

            # when the send is closed we assume the stream has
            # terminated and signal this local iterator to stop
            await self.aclose()
            raise StopAsyncIteration

        except trio.Cancelled:
            # relay cancels to the remote task
            await self.aclose()
            raise

    @contextmanager
    def shield(
        self
    ) -> Iterator['ReceiveMsgStream']:  # noqa
        """Shield this stream's underlying channel such that a local consumer task
        can be cancelled (and possibly restarted) using ``trio.Cancelled``.

        """
        self._shielded = True
        yield self
        self._shielded = False

    async def aclose(self):
        """Cancel associated remote actor task and local memory channel
        on close.
        """
        rx_chan = self._rx_chan

        if rx_chan._closed:
            log.warning(f"{self} is already closed")
            return

        # stats = rx_chan.statistics()
        # if stats.open_receive_channels > 1:
        #     # if we've been cloned don't kill the stream
        #     log.debug(
        #       "there are still consumers running keeping stream alive")
        #     return

        if self._shielded:
            log.warning(f"{self} is shielded, portal channel being kept alive")
            return

        # close the local mem chan
        rx_chan.close()

        # cancel surrounding IPC context
        await self._ctx.cancel()

    # TODO: but make it broadcasting to consumers
    # def clone(self):
    #     """Clone this receive channel allowing for multi-task
    #     consumption from the same channel.

    #     """
    #     return ReceiveStream(
    #         self._cid,
    #         self._rx_chan.clone(),
    #         self._portal,
    #     )


# class MsgStream(ReceiveMsgStream, trio.abc.Channel):
#     """
#     Bidirectional message stream for use within an inter-actor actor
#     ``Context```.

#     """
#     async def send(
#         self,
#         data: Any
#     ) -> None:
#         await self._ctx.chan.send({'yield': data, 'cid': self._ctx.cid})
    sig = inspect.signature(func)
    params = sig.parameters
    if 'ctx' not in params:
        raise TypeError(
            "The first argument to the context function "
            f"{func.__name__} must be `ctx: tractor.Context`"
        )
    return func
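
# Example (sketch): the signature check above fires at decoration time
# for a function missing the ``ctx`` parameter (illustrative names):
#
#     @tractor.context
#     async def no_ctx(data: int) -> None:  # raises TypeError
#         ...
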

@@ -357,7 +357,8 @@ async def open_nursery(
    try:
        if actor is None and is_main_process():

            # if we are the parent process start the actor runtime implicitly
            # if we are the parent process start the
            # actor runtime implicitly
            log.info("Starting actor runtime!")

            # mark us for teardown on exit

@@ -376,7 +377,6 @@ async def open_nursery(
            async with _open_and_supervise_one_cancels_all_nursery(
                actor
            ) as anursery:

                yield anursery

    finally:
			