Go back to hard-coded index field

Turns out https://github.com/numba/numba/issues/8622 is real
and the suggested `numba.literally` hack doesn't seem to work..
multichartz
Tyler Goodlet 2022-11-25 13:25:38 -05:00
parent e4a5dc55de
commit 1e586e7c85
1 changed files with 26 additions and 12 deletions

View File

@ -26,15 +26,20 @@ from typing import (
import msgspec import msgspec
import numpy as np import numpy as np
from numpy.lib import recfunctions as rfn from numpy.lib import recfunctions as rfn
from numba import njit, float64, int64 # , optional from numba import (
# import pyqtgraph as pg types,
# from PyQt5 import QtGui njit,
# from PyQt5.QtCore import QLineF, QPointF float64,
int64,
optional,
)
from numba.core.types.misc import StringLiteral
# from numba.extending import as_numba_type
from ._sharedmem import ( from ._sharedmem import (
ShmArray, ShmArray,
) )
# from .._profile import pg_profile_enabled, ms_slower_then # from ._source import numba_ohlc_dtype
from ._compression import ( from ._compression import (
ds_m4, ds_m4,
) )
@ -514,11 +519,17 @@ class OHLCBarsFmtr(IncrementalFormatter):
@staticmethod @staticmethod
@njit( @njit(
# TODO: for now need to construct this manually for readonly # NOTE: need to construct this manually for readonly
# arrays, see https://github.com/numba/numba/issues/4511 # arrays, see https://github.com/numba/numba/issues/4511
# ntypes.tuple((float64[:], float64[:], float64[:]))( # (
# numba_ohlc_dtype[::1], # contiguous # types.Array(
# numba_ohlc_dtype,
# 1,
# 'C',
# readonly=True,
# ),
# int64, # int64,
# types.unicode_type,
# optional(float64), # optional(float64),
# ), # ),
nogil=True nogil=True
@ -527,7 +538,7 @@ class OHLCBarsFmtr(IncrementalFormatter):
data: np.ndarray, data: np.ndarray,
start: int64, start: int64,
bar_gap: float64 = 0.43, bar_gap: float64 = 0.43,
index_field: str = 'index', # index_field: str,
) -> tuple[ ) -> tuple[
np.ndarray, np.ndarray,
@ -540,8 +551,10 @@ class OHLCBarsFmtr(IncrementalFormatter):
''' '''
size = int(data.shape[0] * 6) size = int(data.shape[0] * 6)
# XXX: see this for why the dtype might have to be defined outside
# the routine.
# https://github.com/numba/numba/issues/4098#issuecomment-493914533
x = np.zeros( x = np.zeros(
# data,
shape=size, shape=size,
dtype=float64, dtype=float64,
) )
@ -559,7 +572,8 @@ class OHLCBarsFmtr(IncrementalFormatter):
high = q['high'] high = q['high']
low = q['low'] low = q['low']
close = q['close'] close = q['close']
index = float64(q[index_field]) # index = float64(q[index_field])
index = float64(q['index'])
istart = i * 6 istart = i * 6
istop = istart + 6 istop = istart + 6
@ -615,8 +629,8 @@ class OHLCBarsFmtr(IncrementalFormatter):
x, y, c = self.path_arrays_from_ohlc( x, y, c = self.path_arrays_from_ohlc(
array, array,
start, start,
# self.index_field,
bar_gap=w, bar_gap=w,
index_field=self.index_field,
) )
return x, y, c return x, y, c