Rewrite `slice_from_time()` using `numba`

Gives approx a 3-4x speedup using plain old iterate-with-for-loop style
though still not really happy with this .5 to 1 ms latency..

Move the core `@njit` part to a `_slice_from_time()` with a pure python
func with orig name around it. Also, drop the output `mask` array since
we can generally just use the slices in the caller to accomplish the
same input array slicing, duh..
pre_viz_calls
Tyler Goodlet 2022-12-03 15:36:13 -05:00
parent 84b6ec07d8
commit 7e7748ce6b
1 changed files with 92 additions and 74 deletions

View File

@ -33,6 +33,7 @@ from ._m4 import ds_m4
from .._profile import ( from .._profile import (
Profiler, Profiler,
pg_profile_enabled, pg_profile_enabled,
ms_slower_then,
) )
@ -269,6 +270,87 @@ def ohlc_flatten(
return x, flat return x, flat
@njit
def _slice_from_time(
arr: np.ndarray,
start_t: float,
stop_t: float,
) -> tuple[
tuple[int, int],
tuple[int, int],
np.ndarray | None,
]:
'''
Slice an input struct array to a time range and return the absolute
and "readable" slices for that array as well as the indexing mask
for the caller to use to slice the input array if needed.
'''
times = arr['time']
index = arr['index']
if (
start_t < 0
or start_t >= stop_t
):
return (
(
index[0],
index[-1],
),
(
0,
len(arr),
),
)
# TODO: if we can ensure each time field has a uniform
# step we can instead do some arithmetic to determine
# the equivalent index like we used to?
# return array[
# lbar - ifirst:
# (rbar - ifirst) + 1
# ]
read_i_0: int = 0
read_i_last: int = 0
for i in range(times.shape[0]):
time = times[i]
if time >= start_t:
read_i_0 = i
break
for i in range(read_i_0, times.shape[0]):
time = times[i]
if time > stop_t:
read_i_last = time
break
abs_i_0 = int(index[0]) + read_i_0
abs_i_last = int(index[0]) + read_i_last
if read_i_last == 0:
read_i_last = times.shape[0]
abs_slc = (
int(abs_i_0),
int(abs_i_last),
)
read_slc = (
int(read_i_0),
int(read_i_last),
)
# also return the readable data from the timerange
return (
abs_slc,
read_slc,
)
def slice_from_time( def slice_from_time(
arr: np.ndarray, arr: np.ndarray,
start_t: float, start_t: float,
@ -279,93 +361,29 @@ def slice_from_time(
slice, slice,
np.ndarray | None, np.ndarray | None,
]: ]:
'''
Slice an input struct array to a time range and return the absolute
and "readable" slices for that array as well as the indexing mask
for the caller to use to slice the input array if needed.
'''
profiler = Profiler( profiler = Profiler(
msg='slice_from_time()', msg='slice_from_time()',
disabled=not pg_profile_enabled(), disabled=not pg_profile_enabled(),
ms_threshold=4, ms_threshold=ms_slower_then,
# ms_threshold=ms_slower_then,
) )
times = arr['time'] (
index = arr['index'] abs_slc_tuple,
read_slc_tuple,
if ( ) = _slice_from_time(
start_t < 0 arr,
or start_t >= stop_t start_t,
): stop_t,
return (
slice(
index[0],
index[-1],
),
slice(
0,
len(arr),
),
None,
)
# use advanced indexing to map the
# time range to the index range.
mask: np.ndarray = np.where(
(times >= start_t)
&
(times < stop_t)
)
profiler('advanced indexing slice')
# TODO: if we can ensure each time field has a uniform
# step we can instead do some arithmetic to determine
# the equivalent index like we used to?
# return array[
# lbar - ifirst:
# (rbar - ifirst) + 1
# ]
i_by_t = index[mask]
try:
i_0 = i_by_t[0]
i_last = i_by_t[-1]
i_first_read = index[0]
except IndexError:
if (
start_t < times[0]
or stop_t >= times[-1]
):
return (
slice(
index[0],
index[-1],
),
slice(
0,
len(arr),
),
None,
)
abs_slc = slice(i_0, i_last)
# slice data by offset from the first index
# available in the passed datum set.
read_slc = slice(
i_0 - i_first_read,
i_last - i_first_read + 1,
) )
abs_slc = slice(*abs_slc_tuple)
read_slc = slice(*read_slc_tuple)
profiler( profiler(
'slicing complete' 'slicing complete'
f'{start_t} -> {abs_slc.start} | {read_slc.start}\n' f'{start_t} -> {abs_slc.start} | {read_slc.start}\n'
f'{stop_t} -> {abs_slc.stop} | {read_slc.stop}\n' f'{stop_t} -> {abs_slc.stop} | {read_slc.stop}\n'
) )
# also return the readable data from the timerange
return ( return (
abs_slc, abs_slc,
read_slc, read_slc,
mask,
) )