`Task` is hashable, so key on it

live_on_air_from_tokio
Tyler Goodlet 2021-08-08 19:58:12 -04:00
parent 6a2c3da1bb
commit 1af7dbb732
1 changed files with 13 additions and 16 deletions

View File

@ -48,15 +48,12 @@ class BroadcastReceiver(ReceiveChannel):
async def receive(self): async def receive(self):
task: Task task: Task = current_task()
task = current_task()
# check that task does not already have a value it can receive # check that task does not already have a value it can receive
# immediately and/or that it has lagged. # immediately and/or that it has lagged.
key = id(task)
# print(key)
try: try:
seq = self._subs[key] seq = self._subs[task]
except KeyError: except KeyError:
raise RuntimeError( raise RuntimeError(
f'Task {task.name} is not registerd as subscriber') f'Task {task.name} is not registerd as subscriber')
@ -69,12 +66,12 @@ class BroadcastReceiver(ReceiveChannel):
# decrement to the last value and expect # decrement to the last value and expect
# consumer to either handle the ``Lagged`` and come back # consumer to either handle the ``Lagged`` and come back
# or bail out on it's own (thus un-subscribing) # or bail out on it's own (thus un-subscribing)
self._subs[key] = self._queue.maxlen - 1 self._subs[task] = self._queue.maxlen - 1
# this task was overrun by the producer side # this task was overrun by the producer side
raise Lagged(f'Task {task.name} was overrun') raise Lagged(f'Task {task.name} was overrun')
self._subs[key] -= 1 self._subs[task] -= 1
return value return value
if self._value_received is None: if self._value_received is None:
@ -97,7 +94,7 @@ class BroadcastReceiver(ReceiveChannel):
subs = self._subs.copy() subs = self._subs.copy()
# don't decerement the sequence # for this task since we # don't decerement the sequence # for this task since we
# already retreived the last value # already retreived the last value
subs.pop(key) subs.pop(task)
for sub_key, seq in subs.items(): for sub_key, seq in subs.items():
self._subs[sub_key] += 1 self._subs[sub_key] += 1
@ -109,10 +106,10 @@ class BroadcastReceiver(ReceiveChannel):
else: else:
await self._value_received.wait() await self._value_received.wait()
seq = self._subs[key] seq = self._subs[task]
assert seq > -1, 'Internal error?' assert seq > -1, 'Internal error?'
self._subs[key] -= 1 self._subs[task] -= 1
return self._queue[0] return self._queue[0]
# @asynccontextmanager # @asynccontextmanager
@ -120,14 +117,14 @@ class BroadcastReceiver(ReceiveChannel):
def subscribe( def subscribe(
self, self,
) -> BroadcastReceiver: ) -> BroadcastReceiver:
key = id(current_task()) task: task = current_task()
self._subs[key] = -1 self._subs[task] = -1
# XXX: we only use this clone for closure tracking # XXX: we only use this clone for closure tracking
clone = self._clones[key] = self._rx.clone() clone = self._clones[task] = self._rx.clone()
try: try:
yield self yield self
finally: finally:
self._subs.pop(key) self._subs.pop(task)
clone.close() clone.close()
# TODO: do we need anything here? # TODO: do we need anything here?
@ -135,8 +132,8 @@ class BroadcastReceiver(ReceiveChannel):
# the underlying rx channel, but couldn't we just # the underlying rx channel, but couldn't we just
# use ``.clone()``s trackign then? # use ``.clone()``s trackign then?
async def aclose(self) -> None: async def aclose(self) -> None:
key = id(current_task()) task: Task = current_task()
await self._clones[key].aclose() await self._clones[task].aclose()
def broadcast_receiver( def broadcast_receiver(