Close all portal created async gens on shutdown
							parent
							
								
									db85e13657
								
							
						
					
					
						commit
						eb6e82f577
					
				|  | @ -65,6 +65,7 @@ class Portal: | ||||||
|         self._expect_result: Optional[ |         self._expect_result: Optional[ | ||||||
|             Tuple[str, Any, str, Dict[str, Any]] |             Tuple[str, Any, str, Dict[str, Any]] | ||||||
|         ] = None |         ] = None | ||||||
|  |         self._agens: Set(AsyncGenerator) = set() | ||||||
| 
 | 
 | ||||||
|     async def aclose(self) -> None: |     async def aclose(self) -> None: | ||||||
|         log.debug(f"Closing {self}") |         log.debug(f"Closing {self}") | ||||||
|  | @ -142,14 +143,16 @@ class Portal: | ||||||
|                 except GeneratorExit: |                 except GeneratorExit: | ||||||
|                     # for now this msg cancels an ongoing remote task |                     # for now this msg cancels an ongoing remote task | ||||||
|                     await self.channel.send({'cancel': True, 'cid': cid}) |                     await self.channel.send({'cancel': True, 'cid': cid}) | ||||||
|                     log.debug( |                     log.warn( | ||||||
|                         f"Cancelling async gen call {cid} to " |                         f"Cancelling async gen call {cid} to " | ||||||
|                         f"{self.channel.uid}") |                         f"{self.channel.uid}") | ||||||
|                     raise |                     raise | ||||||
| 
 | 
 | ||||||
|             # TODO: use AsyncExitStack to aclose() all agens |             # TODO: use AsyncExitStack to aclose() all agens | ||||||
|             # on teardown |             # on teardown | ||||||
|             return yield_from_q() |             agen = yield_from_q() | ||||||
|  |             self._agens.add(agen) | ||||||
|  |             return agen | ||||||
| 
 | 
 | ||||||
|         elif resptype == 'return': |         elif resptype == 'return': | ||||||
|             msg = await q.get() |             msg = await q.get() | ||||||
|  | @ -269,13 +272,18 @@ async def open_portal( | ||||||
| 
 | 
 | ||||||
|         nursery.start_soon(actor._process_messages, channel) |         nursery.start_soon(actor._process_messages, channel) | ||||||
|         portal = Portal(channel) |         portal = Portal(channel) | ||||||
|         yield portal |         try: | ||||||
|  |             yield portal | ||||||
|  |         finally: | ||||||
|  |             # tear down all async generators | ||||||
|  |             for agen in portal._agens: | ||||||
|  |                 await agen.aclose() | ||||||
| 
 | 
 | ||||||
|         # cancel remote channel-msg loop |             # cancel remote channel-msg loop | ||||||
|         if channel.connected(): |             if channel.connected(): | ||||||
|             await portal.close() |                 await portal.close() | ||||||
| 
 | 
 | ||||||
|         # cancel background msg loop task |             # cancel background msg loop task | ||||||
|         nursery.cancel_scope.cancel() |             nursery.cancel_scope.cancel() | ||||||
|         if was_connected: |             if was_connected: | ||||||
|             await channel.aclose() |                 await channel.aclose() | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue