diff --git a/tests/test_rpc.py b/tests/test_rpc.py new file mode 100644 index 0000000..ec286f5 --- /dev/null +++ b/tests/test_rpc.py @@ -0,0 +1,80 @@ +""" +RPC related +""" +import pytest +import tractor +import trio + + +async def sleep_back_actor( + actor_name, + func_name, + func_defined, + exposed_mods, +): + async with tractor.find_actor(actor_name) as portal: + try: + await portal.run(__name__, func_name) + except tractor.RemoteActorError as err: + if not func_defined: + expect = AttributeError + if not exposed_mods: + expect = tractor.ModuleNotExposed + + assert err.type is expect + raise + + +async def short_sleep(): + await trio.sleep(0) + + +@pytest.mark.parametrize( + 'to_call', [ + ([], 'short_sleep'), + ([__name__], 'short_sleep'), + ([__name__], 'fake_func'), + ], + ids=['no_mods', 'this_mod', 'this_mod_bad_func'], +) +def test_rpc_errors(arb_addr, to_call): + """Test errors when making various RPC requests to an actor + that either doesn't have the requested module exposed or doesn't define + the named function. + """ + exposed_mods, funcname = to_call + func_defined = globals().get(funcname, False) + + async def main(): + actor = tractor.current_actor() + assert actor.is_arbiter + + # spawn a subactor which calls us back + async with tractor.open_nursery() as n: + await n.run_in_actor( + 'subactor', + sleep_back_actor, + actor_name=actor.name, + # function from this module the subactor will invoke + # when it RPCs back to this actor + func_name=funcname, + exposed_mods=exposed_mods, + func_defined=True if func_defined else False, + ) + + def run(): + tractor.run( + main, + arbiter_addr=arb_addr, + rpc_module_paths=exposed_mods, + ) + + # handle both parameterized cases + if exposed_mods and func_defined: + run() + else: + # underlying errors are propogated upwards (yet) + with pytest.raises(tractor.RemoteActorError) as err: + run() + + assert err.value.type is tractor.RemoteActorError diff --git a/tractor/__init__.py b/tractor/__init__.py index 0667bd5..d8d6489 100644 --- a/tractor/__init__.py +++ b/tractor/__init__.py @@ -17,7 +17,7 @@ from ._actor import ( ) from ._trionics import open_nursery from ._state import current_actor -from ._exceptions import RemoteActorError +from ._exceptions import RemoteActorError, ModuleNotExposed __all__ = [ @@ -29,6 +29,7 @@ __all__ = [ 'Channel', 'MultiError', 'RemoteActorError', + 'ModuleNotExposed', ] diff --git a/tractor/_actor.py b/tractor/_actor.py index 750487a..d9a5ee1 100644 --- a/tractor/_actor.py +++ b/tractor/_actor.py @@ -15,7 +15,7 @@ from async_generator import asynccontextmanager, aclosing from ._ipc import Channel, _connect_chan from .log import get_console_log, get_logger -from ._exceptions import pack_error, InternalActorError +from ._exceptions import pack_error, InternalActorError, ModuleNotExposed from ._portal import ( Portal, open_portal, @@ -236,6 +236,12 @@ class Actor: # self._mods.pop('test_discovery') # TODO: how to test the above? + def _get_rpc_func(self, ns, funcname): + try: + return getattr(self._mods[ns], funcname) + except KeyError as err: + raise ModuleNotExposed(*err.args) + async def _stream_handler( self, stream: trio.SocketStream, @@ -398,7 +404,14 @@ class Actor: if ns == 'self': func = getattr(self, funcname) else: - func = getattr(self._mods[ns], funcname) + # complain to client about restricted modules + try: + func = self._get_rpc_func(ns, funcname) + except (ModuleNotExposed, AttributeError) as err: + err_msg = pack_error(err) + err_msg['cid'] = cid + await chan.send(err_msg) + continue # spin up a task for the requested function log.debug(f"Spawning task for {func}") diff --git a/tractor/_exceptions.py b/tractor/_exceptions.py index 7244396..efedcf0 100644 --- a/tractor/_exceptions.py +++ b/tractor/_exceptions.py @@ -1,16 +1,28 @@ """ Our classy exception set. """ +import importlib import builtins import traceback +_this_mod = importlib.import_module(__name__) + + class RemoteActorError(Exception): # TODO: local recontruction of remote exception deats "Remote actor exception bundled locally" def __init__(self, message, type_str, **msgdata): super().__init__(message) - self.type = getattr(builtins, type_str, Exception) + for ns in [builtins, _this_mod]: + try: + self.type = getattr(ns, type_str) + break + except AttributeError: + continue + else: + self.type = Exception + self.msgdata = msgdata # TODO: a trio.MultiError.catch like context manager @@ -27,6 +39,10 @@ class NoResult(RuntimeError): "No final result is expected for this actor" +class ModuleNotExposed(RuntimeError): + "The requested module is not exposed for RPC" + + def pack_error(exc): """Create an "error message" for tranmission over a channel (aka the wire).