diff --git a/tractor/msg.py b/tractor/msg.py index f778ac7..4184333 100644 --- a/tractor/msg.py +++ b/tractor/msg.py @@ -11,7 +11,6 @@ import trio import wrapt from .log import get_logger -from . import current_actor from ._streaming import Context __all__ = ['pub'] @@ -91,6 +90,7 @@ def modify_subs(topics2ctxs, topics, ctx): _pub_state: Dict[str, dict] = {} +_pubtask2lock: Dict[str, dict] = {} def pub( @@ -178,22 +178,22 @@ def pub( subscribers. If you are ok to have a new task running for every call to ``pub_service()`` then probably don't need this. """ - global _pub_state + global _pubtask2lock # handle the decorator not called with () case if wrapped is None: return partial(pub, tasks=tasks) - task2lock: Dict[Union[str, None], trio.StrictFIFOLock] = { - None: trio.StrictFIFOLock()} + task2lock: Dict[Union[str, None], trio.StrictFIFOLock] = {} for name in tasks: task2lock[name] = trio.StrictFIFOLock() @wrapt.decorator async def wrapper(agen, instance, args, kwargs): - # this is used to extract arguments properly as per - # the `wrapt` docs + + # XXX: this is used to extract arguments properly as per the + # `wrapt` docs async def _execute( ctx: Context, topics: Set[str], @@ -203,14 +203,22 @@ def pub( packetizer: Callable = None, **kwargs, ): - if tasks and task_name is None: + if task_name is None: + task_name = trio.lowlevel.current_task().name + + if tasks and task_name not in tasks: raise TypeError( f"{agen} must be called with a `task_name` named " - f"argument with a falue from {tasks}") + f"argument with a value from {tasks}") + + elif not tasks and not task2lock: + # add a default root-task lock if none defined + task2lock[task_name] = trio.StrictFIFOLock() + + _pubtask2lock.update(task2lock) topics = set(topics) - lockmap = _pub_state.setdefault('_pubtask2lock', task2lock) - lock = lockmap[task_name] + lock = _pubtask2lock[task_name] all_subs = _pub_state.setdefault('_subs', {}) topics2ctxs = all_subs.setdefault(task_name, {})