Solve first-recv-cancelled by recursive `.receive()` on wake

live_on_air_from_tokio
Tyler Goodlet 2021-09-02 11:39:56 -04:00
parent 5881a82d2a
commit 2745a2b1dc
1 changed files with 28 additions and 21 deletions

View File

@ -202,13 +202,20 @@ class BroadcastReceiver(ReceiveChannel):
state.subs[sub_key] += 1
# NOTE: this should ONLY be set if the above task was *NOT*
# cancelled on the `._recv()` call otherwise sibling
# consumers will be awoken with a sequence of -1
# cancelled on the `._recv()` call.
event.set()
return value
except trio.Cancelled:
# handle cancelled specially otherwise sibling
# consumers will be awoken with a sequence of -1
# state.recv_ready = trio.Cancelled
if event.statistics().tasks_waiting:
event.set()
raise
finally:
# Reset receiver waiter task event for next blocking condition.
# this MUST be reset even if the above ``.recv()`` call
# was cancelled to avoid the next consumer from blocking on
@ -223,29 +230,29 @@ class BroadcastReceiver(ReceiveChannel):
_, ev = state.recv_ready
await ev.wait()
seq = state.subs[key]
assert seq > -1, f'Invalid sequence {seq}!?'
value = state.queue[seq]
state.subs[key] -= 1
return value
# NOTE: if we ever would like the behaviour where if the
# first task to recv on the underlying is cancelled but it
# still DOES trigger the ``.recv_ready``, event we'll likely need
# this logic:
# if seq > -1:
# # stuff from above..
# elif seq == -1:
# # XXX: In the case where the first task to allocate the
# # ``.recv_ready`` event is cancelled we will be woken with
# # a non-incremented sequence number and thus will read the
# # oldest value if we use that. Instead we need to detect if
# # we have not been incremented and then receive again.
# return await self.receive()
# else:
# raise ValueError(f'Invalid sequence {seq}!?')
if seq > -1:
# stuff from above..
seq = state.subs[key]
value = state.queue[seq]
state.subs[key] -= 1
return value
elif seq == -1:
# XXX: In the case where the first task to allocate the
# ``.recv_ready`` event is cancelled we will be woken with
# a non-incremented sequence number and thus will read the
# oldest value if we use that. Instead we need to detect if
# we have not been incremented and then receive again.
return await self.receive()
else:
raise ValueError(f'Invalid sequence {seq}!?')
@asynccontextmanager
async def subscribe(