Expect multi-output fsps to yield a `dict` of history arrays
parent
8118a57b9a
commit
1aae40cdeb
|
@ -20,7 +20,10 @@ core task logic for processing chains
|
||||||
'''
|
'''
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import AsyncIterator, Callable, Optional
|
from typing import (
|
||||||
|
AsyncIterator, Callable, Optional,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyqtgraph as pg
|
import pyqtgraph as pg
|
||||||
|
@ -101,28 +104,61 @@ async def fsp_compute(
|
||||||
|
|
||||||
# Conduct a single iteration of fsp with historical bars input
|
# Conduct a single iteration of fsp with historical bars input
|
||||||
# and get historical output
|
# and get historical output
|
||||||
|
history_output: Union[
|
||||||
|
dict[str, np.ndarray], # multi-output case
|
||||||
|
np.ndarray, # single output case
|
||||||
|
]
|
||||||
history_output = await out_stream.__anext__()
|
history_output = await out_stream.__anext__()
|
||||||
|
|
||||||
func_name = func.__name__
|
func_name = func.__name__
|
||||||
profiler(f'{func_name} generated history')
|
profiler(f'{func_name} generated history')
|
||||||
|
|
||||||
# build struct array with an 'index' field to push as history
|
# build struct array with an 'index' field to push as history
|
||||||
history = np.zeros(
|
|
||||||
len(history_output),
|
|
||||||
dtype=dst.array.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: push using a[['f0', 'f1', .., 'fn']] = .. syntax no?
|
# TODO: push using a[['f0', 'f1', .., 'fn']] = .. syntax no?
|
||||||
# if the output array is multi-field then push
|
# if the output array is multi-field then push
|
||||||
# each respective field.
|
# each respective field.
|
||||||
fields = getattr(history.dtype, 'fields', None)
|
# await tractor.breakpoint()
|
||||||
if fields:
|
fields = getattr(dst.array.dtype, 'fields', None).copy()
|
||||||
|
fields.pop('index')
|
||||||
|
# TODO: nptyping here!
|
||||||
|
history: Optional[np.ndarray] = None
|
||||||
|
if fields and len(fields) > 1 and fields:
|
||||||
|
if not isinstance(history_output, dict):
|
||||||
|
raise ValueError(
|
||||||
|
f'`{func_name}` is a multi-output FSP and should yield a '
|
||||||
|
'`dict[str, np.ndarray]` for history'
|
||||||
|
)
|
||||||
|
|
||||||
for key in fields.keys():
|
for key in fields.keys():
|
||||||
if key in history.dtype.fields:
|
if key in history_output:
|
||||||
history[func_name] = history_output
|
output = history_output[key]
|
||||||
|
|
||||||
|
if history is None:
|
||||||
|
# using the first output, determine
|
||||||
|
# the length of the struct-array that
|
||||||
|
# will be pushed to shm.
|
||||||
|
history = np.zeros(
|
||||||
|
len(output),
|
||||||
|
dtype=dst.array.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
if output is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
history[key] = output
|
||||||
|
|
||||||
# single-key output stream
|
# single-key output stream
|
||||||
else:
|
else:
|
||||||
|
if not isinstance(history_output, np.ndarray):
|
||||||
|
raise ValueError(
|
||||||
|
f'`{func_name}` is a single output FSP and should yield an '
|
||||||
|
'`np.ndarray` for history'
|
||||||
|
)
|
||||||
|
history = np.zeros(
|
||||||
|
len(history_output),
|
||||||
|
dtype=dst.array.dtype
|
||||||
|
)
|
||||||
history[func_name] = history_output
|
history[func_name] = history_output
|
||||||
|
|
||||||
# TODO: XXX:
|
# TODO: XXX:
|
||||||
|
|
Loading…
Reference in New Issue