wip
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import numpy as np
|
||||
from typing import List, Optional, Tuple, Any
|
||||
from typing import Any
|
||||
|
||||
from cereal import log
|
||||
|
||||
@@ -12,7 +12,7 @@ class NPQueue:
|
||||
def __len__(self) -> int:
|
||||
return len(self.arr)
|
||||
|
||||
def append(self, pt: List[float]) -> None:
|
||||
def append(self, pt: list[float]) -> None:
|
||||
if len(self.arr) < self.maxlen:
|
||||
self.arr = np.append(self.arr, [pt], axis=0)
|
||||
else:
|
||||
@@ -21,7 +21,7 @@ class NPQueue:
|
||||
|
||||
|
||||
class PointBuckets:
|
||||
def __init__(self, x_bounds: List[Tuple[float, float]], min_points: List[float], min_points_total: int, points_per_bucket: int, rowsize: int) -> None:
|
||||
def __init__(self, x_bounds: list[tuple[float, float]], min_points: list[float], min_points_total: int, points_per_bucket: int, rowsize: int) -> None:
|
||||
self.x_bounds = x_bounds
|
||||
self.buckets = {bounds: NPQueue(maxlen=points_per_bucket, rowsize=rowsize) for bounds in x_bounds}
|
||||
self.buckets_min_points = dict(zip(x_bounds, min_points, strict=True))
|
||||
@@ -41,13 +41,13 @@ class PointBuckets:
|
||||
def add_point(self, x: float, y: float, bucket_val: float) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_points(self, num_points: Optional[int] = None) -> Any:
|
||||
def get_points(self, num_points: int = None) -> Any:
|
||||
points = np.vstack([x.arr for x in self.buckets.values()])
|
||||
if num_points is None:
|
||||
return points
|
||||
return points[np.random.choice(np.arange(len(points)), min(len(points), num_points), replace=False)]
|
||||
|
||||
def load_points(self, points: List[List[float]]) -> None:
|
||||
def load_points(self, points: list[list[float]]) -> None:
|
||||
for point in points:
|
||||
self.add_point(*point)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user