Don't wake sibling bcast consumers on a cancelled call

tokio_backup
Tyler Goodlet 2021-08-31 18:30:06 -04:00
parent 71a4f8aaa9
commit c3665801a5
1 changed files with 37 additions and 5 deletions

View File

@ -111,7 +111,7 @@ class BroadcastReceiver(ReceiveChannel):
self._recv = receive_afunc or rx_chan.receive
self._closed: bool = False
async def receive(self):
async def receive(self) -> ReceiveType:
key = self.key
state = self._state
@ -169,9 +169,11 @@ class BroadcastReceiver(ReceiveChannel):
event = trio.Event()
state.recv_ready = key, event
# if we're cancelled here it should be
# fine to bail without affecting any other consumers
# right?
try:
value = await self._recv()
# items with lower indices are "newer"
state.queue.appendleft(value)
@ -193,21 +195,51 @@ class BroadcastReceiver(ReceiveChannel):
):
state.subs[sub_key] += 1
# NOTE: this should ONLY be set if the above task was *NOT*
# cancelled on the `._recv()` call otherwise sibling
# consumers will be awoken with a sequence of -1
event.set()
return value
finally:
# reset receiver waiter task event for next blocking condition
event.set()
# Reset receiver waiter task event for next blocking condition.
# this MUST be reset even if the above ``.recv()`` call
# was cancelled to avoid the next consumer from blocking on
# an event that won't be set!
state.recv_ready = None
# This task is all caught up and ready to receive the latest
# value, so queue sched it on the internal event.
else:
seq = state.subs[key]
assert seq == -1 # sanity
_, ev = state.recv_ready
await ev.wait()
seq = state.subs[key]
assert seq > -1, f'Invalid sequence {seq}!?'
value = state.queue[seq]
state.subs[key] -= 1
return state.queue[seq]
return value
# NOTE: if we ever would like the behaviour where if the
# first task to recv on the underlying is cancelled but it
# still DOES trigger the ``.recv_ready``, event we'll likely need
# this logic:
# if seq > -1:
# # stuff from above..
# elif seq == -1:
# # XXX: In the case where the first task to allocate the
# # ``.recv_ready`` event is cancelled we will be woken with
# # a non-incremented sequence number and thus will read the
# # oldest value if we use that. Instead we need to detect if
# # we have not been incremented and then receive again.
# return await self.receive()
# else:
# raise ValueError(f'Invalid sequence {seq}!?')
@asynccontextmanager
async def subscribe(