openpilot v0.9.6 release

date: 2024-01-12T10:13:37
master commit: ba792d576a49a0899b88a753fa1c52956bedf9e6
This commit is contained in:
FrogAi
2024-01-12 22:39:28 -07:00
commit 08e9fb1edc
1881 changed files with 653708 additions and 0 deletions

51
tools/lib/README.md Normal file
View File

@@ -0,0 +1,51 @@
## LogReader
Route is a class for conveniently accessing all the [logs](/system/loggerd/) from your routes. The LogReader class reads the non-video logs, i.e. rlog.bz2 and qlog.bz2. There's also a matching FrameReader class for reading the videos.
```python
from openpilot.tools.lib.route import Route
from openpilot.tools.lib.logreader import LogReader
r = Route("a2a0ccea32023010|2023-07-27--13-01-19")
# get a list of paths for the route's rlog files
print(r.log_paths())
# and road camera (fcamera.hevc) files
print(r.camera_paths())
# setup a LogReader to read the route's first rlog
lr = LogReader(r.log_paths()[0])
# print out all the messages in the log
import codecs
codecs.register_error("strict", codecs.backslashreplace_errors)
for msg in lr:
print(msg)
# setup a LogReader for the route's second qlog
lr = LogReader(r.log_paths()[1])
# print all the steering angles values from the log
for msg in lr:
if msg.which() == "carState":
print(msg.carState.steeringAngleDeg)
```
### MultiLogIterator
`MultiLogIterator` is similar to `LogReader`, but reads multiple logs.
```python
from openpilot.tools.lib.route import Route
from openpilot.tools.lib.logreader import MultiLogIterator
# setup a MultiLogIterator to read all the logs in the route
r = Route("a2a0ccea32023010|2023-07-27--13-01-19")
lr = MultiLogIterator(r.log_paths())
# print all the steering angles values from all the logs in the route
for msg in lr:
if msg.which() == "carState":
print(msg.carState.steeringAngleDeg)
```

0
tools/lib/__init__.py Normal file
View File

34
tools/lib/api.py Normal file
View File

@@ -0,0 +1,34 @@
import os
import requests
API_HOST = os.getenv('API_HOST', 'https://api.commadotai.com')
class CommaApi():
def __init__(self, token=None):
self.session = requests.Session()
self.session.headers['User-agent'] = 'OpenpilotTools'
if token:
self.session.headers['Authorization'] = 'JWT ' + token
def request(self, method, endpoint, **kwargs):
resp = self.session.request(method, API_HOST + '/' + endpoint, **kwargs)
resp_json = resp.json()
if isinstance(resp_json, dict) and resp_json.get('error'):
if resp.status_code in [401, 403]:
raise UnauthorizedError('Unauthorized. Authenticate with tools/lib/auth.py')
e = APIError(str(resp.status_code) + ":" + resp_json.get('description', str(resp_json['error'])))
e.status_code = resp.status_code
raise e
return resp_json
def get(self, endpoint, **kwargs):
return self.request('GET', endpoint, **kwargs)
def post(self, endpoint, **kwargs):
return self.request('POST', endpoint, **kwargs)
class APIError(Exception):
pass
class UnauthorizedError(Exception):
pass

145
tools/lib/auth.py Executable file
View File

@@ -0,0 +1,145 @@
#!/usr/bin/env python3
"""
Usage::
usage: auth.py [-h] [{google,apple,github,jwt}] [jwt]
Login to your comma account
positional arguments:
{google,apple,github,jwt}
jwt
optional arguments:
-h, --help show this help message and exit
Examples::
./auth.py # Log in with google account
./auth.py github # Log in with GitHub Account
./auth.py jwt ey......hw # Log in with a JWT from https://jwt.comma.ai, for use in CI
"""
import argparse
import sys
import pprint
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any, Dict
from urllib.parse import parse_qs, urlencode
from openpilot.tools.lib.api import APIError, CommaApi, UnauthorizedError
from openpilot.tools.lib.auth_config import set_token, get_token
PORT = 3000
class ClientRedirectServer(HTTPServer):
query_params: Dict[str, Any] = {}
class ClientRedirectHandler(BaseHTTPRequestHandler):
def do_GET(self):
if not self.path.startswith('/auth'):
self.send_response(204)
return
query = self.path.split('?', 1)[-1]
query_parsed = parse_qs(query, keep_blank_values=True)
self.server.query_params = query_parsed
self.send_response(200)
self.send_header('Content-type', 'text/plain')
self.end_headers()
self.wfile.write(b'Return to the CLI to continue')
def log_message(self, *args):
pass # this prevent http server from dumping messages to stdout
def auth_redirect_link(method):
provider_id = {
'google': 'g',
'apple': 'a',
'github': 'h',
}[method]
params = {
'redirect_uri': f"https://api.comma.ai/v2/auth/{provider_id}/redirect/",
'state': f'service,localhost:{PORT}',
}
if method == 'google':
params.update({
'type': 'web_server',
'client_id': '45471411055-ornt4svd2miog6dnopve7qtmh5mnu6id.apps.googleusercontent.com',
'response_type': 'code',
'scope': 'https://www.googleapis.com/auth/userinfo.email',
'prompt': 'select_account',
})
return 'https://accounts.google.com/o/oauth2/auth?' + urlencode(params)
elif method == 'github':
params.update({
'client_id': '28c4ecb54bb7272cb5a4',
'scope': 'read:user',
})
return 'https://github.com/login/oauth/authorize?' + urlencode(params)
elif method == 'apple':
params.update({
'client_id': 'ai.comma.login',
'response_type': 'code',
'response_mode': 'form_post',
'scope': 'name email',
})
return 'https://appleid.apple.com/auth/authorize?' + urlencode(params)
else:
raise NotImplementedError(f"no redirect implemented for method {method}")
def login(method):
oauth_uri = auth_redirect_link(method)
web_server = ClientRedirectServer(('localhost', PORT), ClientRedirectHandler)
print(f'To sign in, use your browser and navigate to {oauth_uri}')
webbrowser.open(oauth_uri, new=2)
while True:
web_server.handle_request()
if 'code' in web_server.query_params:
break
elif 'error' in web_server.query_params:
print('Authentication Error: "{}". Description: "{}" '.format(
web_server.query_params['error'],
web_server.query_params.get('error_description')), file=sys.stderr)
break
try:
auth_resp = CommaApi().post('v2/auth/', data={'code': web_server.query_params['code'], 'provider': web_server.query_params['provider']})
set_token(auth_resp['access_token'])
except APIError as e:
print(f'Authentication Error: {e}', file=sys.stderr)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Login to your comma account')
parser.add_argument('method', default='google', const='google', nargs='?', choices=['google', 'apple', 'github', 'jwt'])
parser.add_argument('jwt', nargs='?')
args = parser.parse_args()
if args.method == 'jwt':
if args.jwt is None:
print("method JWT selected, but no JWT was provided")
exit(1)
set_token(args.jwt)
else:
login(args.method)
try:
me = CommaApi(token=get_token()).get('/v1/me')
print("Authenticated!")
pprint.pprint(me)
except UnauthorizedError:
print("Got invalid JWT")
exit(1)

29
tools/lib/auth_config.py Normal file
View File

@@ -0,0 +1,29 @@
import json
import os
from openpilot.system.hardware.hw import Paths
class MissingAuthConfigError(Exception):
pass
def get_token():
try:
with open(os.path.join(Paths.config_root(), 'auth.json')) as f:
auth = json.load(f)
return auth['access_token']
except Exception:
return None
def set_token(token):
os.makedirs(Paths.config_root(), exist_ok=True)
with open(os.path.join(Paths.config_root(), 'auth.json'), 'w') as f:
json.dump({'access_token': token}, f)
def clear_token():
try:
os.unlink(os.path.join(Paths.config_root(), 'auth.json'))
except FileNotFoundError:
pass

63
tools/lib/bootlog.py Normal file
View File

@@ -0,0 +1,63 @@
import datetime
import functools
import re
from typing import List, Optional
from openpilot.tools.lib.auth_config import get_token
from openpilot.tools.lib.api import CommaApi
from openpilot.tools.lib.helpers import RE, timestamp_to_datetime
@functools.total_ordering
class Bootlog:
def __init__(self, url: str):
self._url = url
r = re.search(RE.BOOTLOG_NAME, url)
if not r:
raise Exception(f"Unable to parse: {url}")
self._dongle_id = r.group('dongle_id')
self._timestamp = r.group('timestamp')
@property
def url(self) -> str:
return self._url
@property
def dongle_id(self) -> str:
return self._dongle_id
@property
def timestamp(self) -> str:
return self._timestamp
@property
def datetime(self) -> datetime.datetime:
return timestamp_to_datetime(self._timestamp)
def __str__(self):
return f"{self._dongle_id}|{self._timestamp}"
def __eq__(self, b) -> bool:
if not isinstance(b, Bootlog):
return False
return self.datetime == b.datetime
def __lt__(self, b) -> bool:
if not isinstance(b, Bootlog):
return False
return self.datetime < b.datetime
def get_bootlog_from_id(bootlog_id: str) -> Optional[Bootlog]:
# TODO: implement an API endpoint for this
bl = Bootlog(bootlog_id)
for b in get_bootlogs(bl.dongle_id):
if b == bl:
return b
return None
def get_bootlogs(dongle_id: str) -> List[Bootlog]:
api = CommaApi(get_token())
r = api.get(f'v1/devices/{dongle_id}/bootlogs')
return [Bootlog(b) for b in r]

14
tools/lib/cache.py Normal file
View File

@@ -0,0 +1,14 @@
import os
import urllib.parse
DEFAULT_CACHE_DIR = os.getenv("CACHE_ROOT", os.path.expanduser("~/.commacache"))
def cache_path_for_file_path(fn, cache_dir=DEFAULT_CACHE_DIR):
dir_ = os.path.join(cache_dir, "local")
os.makedirs(dir_, exist_ok=True)
fn_parsed = urllib.parse.urlparse(fn)
if fn_parsed.scheme == '':
cache_fn = os.path.abspath(fn).replace("/", "_")
else:
cache_fn = f'{fn_parsed.hostname}_{fn_parsed.path.replace("/", "_")}'
return os.path.join(dir_, cache_fn)

2
tools/lib/exceptions.py Normal file
View File

@@ -0,0 +1,2 @@
class DataUnreadableError(Exception):
pass

15
tools/lib/filereader.py Normal file
View File

@@ -0,0 +1,15 @@
import os
from openpilot.tools.lib.url_file import URLFile
DATA_ENDPOINT = os.getenv("DATA_ENDPOINT", "http://data-raw.comma.internal/")
def resolve_name(fn):
if fn.startswith("cd:/"):
return fn.replace("cd:/", DATA_ENDPOINT)
return fn
def FileReader(fn, debug=False):
fn = resolve_name(fn)
if fn.startswith(("http://", "https://")):
return URLFile(fn, debug=debug)
return open(fn, "rb")

537
tools/lib/framereader.py Normal file
View File

@@ -0,0 +1,537 @@
import json
import os
import pickle
import struct
import subprocess
import threading
from enum import IntEnum
from functools import wraps
import numpy as np
from lru import LRU
import _io
from openpilot.tools.lib.cache import cache_path_for_file_path, DEFAULT_CACHE_DIR
from openpilot.tools.lib.exceptions import DataUnreadableError
from openpilot.tools.lib.vidindex import hevc_index
from openpilot.common.file_helpers import atomic_write_in_dir
from openpilot.tools.lib.filereader import FileReader, resolve_name
HEVC_SLICE_B = 0
HEVC_SLICE_P = 1
HEVC_SLICE_I = 2
class GOPReader:
def get_gop(self, num):
# returns (start_frame_num, num_frames, frames_to_skip, gop_data)
raise NotImplementedError
class DoNothingContextManager:
def __enter__(self):
return self
def __exit__(self, *x):
pass
class FrameType(IntEnum):
raw = 1
h265_stream = 2
def fingerprint_video(fn):
with FileReader(fn) as f:
header = f.read(4)
if len(header) == 0:
raise DataUnreadableError(f"{fn} is empty")
elif header == b"\x00\xc0\x12\x00":
return FrameType.raw
elif header == b"\x00\x00\x00\x01":
if 'hevc' in fn:
return FrameType.h265_stream
else:
raise NotImplementedError(fn)
else:
raise NotImplementedError(fn)
def ffprobe(fn, fmt=None):
fn = resolve_name(fn)
cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", "-show_format", "-show_streams"]
if fmt:
cmd += ["-f", fmt]
cmd += ["-i", "-"]
try:
with FileReader(fn) as f:
ffprobe_output = subprocess.check_output(cmd, input=f.read(4096))
except subprocess.CalledProcessError as e:
raise DataUnreadableError(fn) from e
return json.loads(ffprobe_output)
def cache_fn(func):
@wraps(func)
def cache_inner(fn, *args, **kwargs):
if kwargs.pop('no_cache', None):
cache_path = None
else:
cache_dir = kwargs.pop('cache_dir', DEFAULT_CACHE_DIR)
cache_path = cache_path_for_file_path(fn, cache_dir)
if cache_path and os.path.exists(cache_path):
with open(cache_path, "rb") as cache_file:
cache_value = pickle.load(cache_file)
else:
cache_value = func(fn, *args, **kwargs)
if cache_path:
with atomic_write_in_dir(cache_path, mode="wb", overwrite=True) as cache_file:
pickle.dump(cache_value, cache_file, -1)
return cache_value
return cache_inner
@cache_fn
def index_stream(fn, ft):
if ft != FrameType.h265_stream:
raise NotImplementedError("Only h265 supported")
frame_types, dat_len, prefix = hevc_index(fn)
index = np.array(frame_types + [(0xFFFFFFFF, dat_len)], dtype=np.uint32)
probe = ffprobe(fn, "hevc")
return {
'index': index,
'global_prefix': prefix,
'probe': probe
}
def get_video_index(fn, frame_type, cache_dir=DEFAULT_CACHE_DIR):
return index_stream(fn, frame_type, cache_dir=cache_dir)
def read_file_check_size(f, sz, cookie):
buff = bytearray(sz)
bytes_read = f.readinto(buff)
assert bytes_read == sz, (bytes_read, sz)
return buff
def rgb24toyuv(rgb):
yuv_from_rgb = np.array([[ 0.299 , 0.587 , 0.114 ],
[-0.14714119, -0.28886916, 0.43601035 ],
[ 0.61497538, -0.51496512, -0.10001026 ]])
img = np.dot(rgb.reshape(-1, 3), yuv_from_rgb.T).reshape(rgb.shape)
ys = img[:, :, 0]
us = (img[::2, ::2, 1] + img[1::2, ::2, 1] + img[::2, 1::2, 1] + img[1::2, 1::2, 1]) / 4 + 128
vs = (img[::2, ::2, 2] + img[1::2, ::2, 2] + img[::2, 1::2, 2] + img[1::2, 1::2, 2]) / 4 + 128
return ys, us, vs
def rgb24toyuv420(rgb):
ys, us, vs = rgb24toyuv(rgb)
y_len = rgb.shape[0] * rgb.shape[1]
uv_len = y_len // 4
yuv420 = np.empty(y_len + 2 * uv_len, dtype=rgb.dtype)
yuv420[:y_len] = ys.reshape(-1)
yuv420[y_len:y_len + uv_len] = us.reshape(-1)
yuv420[y_len + uv_len:y_len + 2 * uv_len] = vs.reshape(-1)
return yuv420.clip(0, 255).astype('uint8')
def rgb24tonv12(rgb):
ys, us, vs = rgb24toyuv(rgb)
y_len = rgb.shape[0] * rgb.shape[1]
uv_len = y_len // 4
nv12 = np.empty(y_len + 2 * uv_len, dtype=rgb.dtype)
nv12[:y_len] = ys.reshape(-1)
nv12[y_len::2] = us.reshape(-1)
nv12[y_len+1::2] = vs.reshape(-1)
return nv12.clip(0, 255).astype('uint8')
def decompress_video_data(rawdat, vid_fmt, w, h, pix_fmt):
threads = os.getenv("FFMPEG_THREADS", "0")
cuda = os.getenv("FFMPEG_CUDA", "0") == "1"
args = ["ffmpeg", "-v", "quiet",
"-threads", threads,
"-hwaccel", "none" if not cuda else "cuda",
"-c:v", "hevc",
"-vsync", "0",
"-f", vid_fmt,
"-flags2", "showall",
"-i", "-",
"-threads", threads,
"-f", "rawvideo",
"-pix_fmt", pix_fmt,
"-"]
dat = subprocess.check_output(args, input=rawdat)
if pix_fmt == "rgb24":
ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, h, w, 3)
elif pix_fmt == "nv12":
ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, (h*w*3//2))
elif pix_fmt == "yuv420p":
ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, (h*w*3//2))
elif pix_fmt == "yuv444p":
ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, 3, h, w)
else:
raise NotImplementedError
return ret
class BaseFrameReader:
# properties: frame_type, frame_count, w, h
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
def close(self):
pass
def get(self, num, count=1, pix_fmt="yuv420p"):
raise NotImplementedError
def FrameReader(fn, cache_dir=DEFAULT_CACHE_DIR, readahead=False, readbehind=False, index_data=None):
frame_type = fingerprint_video(fn)
if frame_type == FrameType.raw:
return RawFrameReader(fn)
elif frame_type in (FrameType.h265_stream,):
if not index_data:
index_data = get_video_index(fn, frame_type, cache_dir)
return StreamFrameReader(fn, frame_type, index_data, readahead=readahead, readbehind=readbehind)
else:
raise NotImplementedError(frame_type)
class RawData:
def __init__(self, f):
self.f = _io.FileIO(f, 'rb')
self.lenn = struct.unpack("I", self.f.read(4))[0]
self.count = os.path.getsize(f) / (self.lenn+4)
def read(self, i):
self.f.seek((self.lenn+4)*i + 4)
return self.f.read(self.lenn)
class RawFrameReader(BaseFrameReader):
def __init__(self, fn):
# raw camera
self.fn = fn
self.frame_type = FrameType.raw
self.rawfile = RawData(self.fn)
self.frame_count = self.rawfile.count
self.w, self.h = 640, 480
def load_and_debayer(self, img):
img = np.frombuffer(img, dtype='uint8').reshape(960, 1280)
cimg = np.dstack([img[0::2, 1::2], ((img[0::2, 0::2].astype("uint16") + img[1::2, 1::2].astype("uint16")) >> 1).astype("uint8"), img[1::2, 0::2]])
return cimg
def get(self, num, count=1, pix_fmt="yuv420p"):
assert self.frame_count is not None
assert num+count <= self.frame_count
if pix_fmt not in ("nv12", "yuv420p", "rgb24"):
raise ValueError(f"Unsupported pixel format {pix_fmt!r}")
app = []
for i in range(num, num+count):
dat = self.rawfile.read(i)
rgb_dat = self.load_and_debayer(dat)
if pix_fmt == "rgb24":
app.append(rgb_dat)
elif pix_fmt == "nv12":
app.append(rgb24tonv12(rgb_dat))
elif pix_fmt == "yuv420p":
app.append(rgb24toyuv420(rgb_dat))
else:
raise NotImplementedError
return app
class VideoStreamDecompressor:
def __init__(self, fn, vid_fmt, w, h, pix_fmt):
self.fn = fn
self.vid_fmt = vid_fmt
self.w = w
self.h = h
self.pix_fmt = pix_fmt
if pix_fmt in ("nv12", "yuv420p"):
self.out_size = w*h*3//2 # yuv420p
elif pix_fmt in ("rgb24", "yuv444p"):
self.out_size = w*h*3
else:
raise NotImplementedError
self.proc = None
self.t = threading.Thread(target=self.write_thread)
self.t.daemon = True
def write_thread(self):
try:
with FileReader(self.fn) as f:
while True:
r = f.read(1024*1024)
if len(r) == 0:
break
self.proc.stdin.write(r)
except BrokenPipeError:
pass
finally:
self.proc.stdin.close()
def read(self):
threads = os.getenv("FFMPEG_THREADS", "0")
cuda = os.getenv("FFMPEG_CUDA", "0") == "1"
cmd = [
"ffmpeg",
"-threads", threads,
"-hwaccel", "none" if not cuda else "cuda",
"-c:v", "hevc",
# "-avioflags", "direct",
"-analyzeduration", "0",
"-probesize", "32",
"-flush_packets", "0",
# "-fflags", "nobuffer",
"-vsync", "0",
"-f", self.vid_fmt,
"-i", "pipe:0",
"-threads", threads,
"-f", "rawvideo",
"-pix_fmt", self.pix_fmt,
"pipe:1"
]
self.proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
try:
self.t.start()
while True:
dat = self.proc.stdout.read(self.out_size)
if len(dat) == 0:
break
assert len(dat) == self.out_size
if self.pix_fmt == "rgb24":
ret = np.frombuffer(dat, dtype=np.uint8).reshape((self.h, self.w, 3))
elif self.pix_fmt == "yuv420p":
ret = np.frombuffer(dat, dtype=np.uint8)
elif self.pix_fmt == "nv12":
ret = np.frombuffer(dat, dtype=np.uint8)
elif self.pix_fmt == "yuv444p":
ret = np.frombuffer(dat, dtype=np.uint8).reshape((3, self.h, self.w))
else:
raise RuntimeError(f"unknown pix_fmt: {self.pix_fmt}")
yield ret
result_code = self.proc.wait()
assert result_code == 0, result_code
finally:
self.proc.kill()
self.t.join()
class StreamGOPReader(GOPReader):
def __init__(self, fn, frame_type, index_data):
assert frame_type == FrameType.h265_stream
self.fn = fn
self.frame_type = frame_type
self.frame_count = None
self.w, self.h = None, None
self.prefix = None
self.index = None
self.index = index_data['index']
self.prefix = index_data['global_prefix']
probe = index_data['probe']
self.prefix_frame_data = None
self.num_prefix_frames = 0
self.vid_fmt = "hevc"
i = 0
while i < self.index.shape[0] and self.index[i, 0] != HEVC_SLICE_I:
i += 1
self.first_iframe = i
assert self.first_iframe == 0
self.frame_count = len(self.index) - 1
self.w = probe['streams'][0]['width']
self.h = probe['streams'][0]['height']
def _lookup_gop(self, num):
frame_b = num
while frame_b > 0 and self.index[frame_b, 0] != HEVC_SLICE_I:
frame_b -= 1
frame_e = num + 1
while frame_e < (len(self.index) - 1) and self.index[frame_e, 0] != HEVC_SLICE_I:
frame_e += 1
offset_b = self.index[frame_b, 1]
offset_e = self.index[frame_e, 1]
return (frame_b, frame_e, offset_b, offset_e)
def get_gop(self, num):
frame_b, frame_e, offset_b, offset_e = self._lookup_gop(num)
assert frame_b <= num < frame_e
num_frames = frame_e - frame_b
with FileReader(self.fn) as f:
f.seek(offset_b)
rawdat = f.read(offset_e - offset_b)
if num < self.first_iframe:
assert self.prefix_frame_data
rawdat = self.prefix_frame_data + rawdat
rawdat = self.prefix + rawdat
skip_frames = 0
if num < self.first_iframe:
skip_frames = self.num_prefix_frames
return frame_b, num_frames, skip_frames, rawdat
class GOPFrameReader(BaseFrameReader):
#FrameReader with caching and readahead for formats that are group-of-picture based
def __init__(self, readahead=False, readbehind=False):
self.open_ = True
self.readahead = readahead
self.readbehind = readbehind
self.frame_cache = LRU(64)
if self.readahead:
self.cache_lock = threading.RLock()
self.readahead_last = None
self.readahead_len = 30
self.readahead_c = threading.Condition()
self.readahead_thread = threading.Thread(target=self._readahead_thread)
self.readahead_thread.daemon = True
self.readahead_thread.start()
else:
self.cache_lock = DoNothingContextManager()
def close(self):
if not self.open_:
return
self.open_ = False
if self.readahead:
self.readahead_c.acquire()
self.readahead_c.notify()
self.readahead_c.release()
self.readahead_thread.join()
def _readahead_thread(self):
while True:
self.readahead_c.acquire()
try:
if not self.open_:
break
self.readahead_c.wait()
finally:
self.readahead_c.release()
if not self.open_:
break
assert self.readahead_last
num, pix_fmt = self.readahead_last
if self.readbehind:
for k in range(num - 1, max(0, num - self.readahead_len), -1):
self._get_one(k, pix_fmt)
else:
for k in range(num, min(self.frame_count, num + self.readahead_len)):
self._get_one(k, pix_fmt)
def _get_one(self, num, pix_fmt):
assert num < self.frame_count
if (num, pix_fmt) in self.frame_cache:
return self.frame_cache[(num, pix_fmt)]
with self.cache_lock:
if (num, pix_fmt) in self.frame_cache:
return self.frame_cache[(num, pix_fmt)]
frame_b, num_frames, skip_frames, rawdat = self.get_gop(num)
ret = decompress_video_data(rawdat, self.vid_fmt, self.w, self.h, pix_fmt)
ret = ret[skip_frames:]
assert ret.shape[0] == num_frames
for i in range(ret.shape[0]):
self.frame_cache[(frame_b+i, pix_fmt)] = ret[i]
return self.frame_cache[(num, pix_fmt)]
def get(self, num, count=1, pix_fmt="yuv420p"):
assert self.frame_count is not None
if num + count > self.frame_count:
raise ValueError(f"{num + count} > {self.frame_count}")
if pix_fmt not in ("nv12", "yuv420p", "rgb24", "yuv444p"):
raise ValueError(f"Unsupported pixel format {pix_fmt!r}")
ret = [self._get_one(num + i, pix_fmt) for i in range(count)]
if self.readahead:
self.readahead_last = (num+count, pix_fmt)
self.readahead_c.acquire()
self.readahead_c.notify()
self.readahead_c.release()
return ret
class StreamFrameReader(StreamGOPReader, GOPFrameReader):
def __init__(self, fn, frame_type, index_data, readahead=False, readbehind=False):
StreamGOPReader.__init__(self, fn, frame_type, index_data)
GOPFrameReader.__init__(self, readahead, readbehind)
def GOPFrameIterator(gop_reader, pix_fmt):
dec = VideoStreamDecompressor(gop_reader.fn, gop_reader.vid_fmt, gop_reader.w, gop_reader.h, pix_fmt)
yield from dec.read()
def FrameIterator(fn, pix_fmt, **kwargs):
fr = FrameReader(fn, **kwargs)
if isinstance(fr, GOPReader):
yield from GOPFrameIterator(fr, pix_fmt)
else:
for i in range(fr.frame_count):
yield fr.get(i, pix_fmt=pix_fmt)[0]

35
tools/lib/helpers.py Normal file
View File

@@ -0,0 +1,35 @@
import bz2
import datetime
TIME_FMT = "%Y-%m-%d--%H-%M-%S"
# regex patterns
class RE:
DONGLE_ID = r'(?P<dongle_id>[a-z0-9]{16})'
TIMESTAMP = r'(?P<timestamp>[0-9]{4}-[0-9]{2}-[0-9]{2}--[0-9]{2}-[0-9]{2}-[0-9]{2})'
ROUTE_NAME = r'(?P<route_name>{}[|_/]{})'.format(DONGLE_ID, TIMESTAMP)
SEGMENT_NAME = r'{}(?:--|/)(?P<segment_num>[0-9]+)'.format(ROUTE_NAME)
INDEX = r'-?[0-9]+'
SLICE = r'(?P<start>{})?:?(?P<end>{})?:?(?P<step>{})?'.format(INDEX, INDEX, INDEX)
SEGMENT_RANGE = r'{}(?:--|/)?(?P<slice>({}))?/?(?P<selector>([qr]))?'.format(ROUTE_NAME, SLICE)
BOOTLOG_NAME = ROUTE_NAME
EXPLORER_FILE = r'^(?P<segment_name>{})--(?P<file_name>[a-z]+\.[a-z0-9]+)$'.format(SEGMENT_NAME)
OP_SEGMENT_DIR = r'^(?P<segment_name>{})$'.format(SEGMENT_NAME)
def timestamp_to_datetime(t: str) -> datetime.datetime:
"""
Convert an openpilot route timestamp to a python datetime
"""
return datetime.datetime.strptime(t, TIME_FMT)
def save_log(dest, log_msgs, compress=True):
dat = b"".join(msg.as_builder().to_bytes() for msg in log_msgs)
if compress:
dat = bz2.compress(dat)
with open(dest, "wb") as f:
f.write(dat)

81
tools/lib/kbhit.py Executable file
View File

@@ -0,0 +1,81 @@
#!/usr/bin/env python
import sys
import termios
import atexit
from select import select
STDIN_FD = sys.stdin.fileno()
class KBHit:
def __init__(self) -> None:
''' Creates a KBHit object that you can call to do various keyboard things.
'''
self.set_kbhit_terminal()
def set_kbhit_terminal(self) -> None:
''' Save old terminal settings for closure, remove ICANON & ECHO flags.
'''
# Save the terminal settings
self.old_term = termios.tcgetattr(STDIN_FD)
self.new_term = self.old_term.copy()
# New terminal setting unbuffered
self.new_term[3] &= ~(termios.ICANON | termios.ECHO)
termios.tcsetattr(STDIN_FD, termios.TCSAFLUSH, self.new_term)
# Support normal-terminal reset at exit
atexit.register(self.set_normal_term)
def set_normal_term(self) -> None:
''' Resets to normal terminal. On Windows this is a no-op.
'''
termios.tcsetattr(STDIN_FD, termios.TCSAFLUSH, self.old_term)
@staticmethod
def getch() -> str:
''' Returns a keyboard character after kbhit() has been called.
Should not be called in the same program as getarrow().
'''
return sys.stdin.read(1)
@staticmethod
def getarrow() -> int:
''' Returns an arrow-key code after kbhit() has been called. Codes are
0 : up
1 : right
2 : down
3 : left
Should not be called in the same program as getch().
'''
c = sys.stdin.read(3)[2]
vals = [65, 67, 66, 68]
return vals.index(ord(c))
@staticmethod
def kbhit():
''' Returns True if keyboard character was hit, False otherwise.
'''
return select([sys.stdin], [], [], 0)[0] != []
# Test
if __name__ == "__main__":
kb = KBHit()
print('Hit any key, or ESC to exit')
while True:
if kb.kbhit():
c = kb.getch()
if c == '\x1b': # ESC
break
print(c)
kb.set_normal_term()

141
tools/lib/logreader.py Executable file
View File

@@ -0,0 +1,141 @@
#!/usr/bin/env python3
import os
import sys
import bz2
import urllib.parse
import capnp
import warnings
from typing import Iterable, Iterator
from cereal import log as capnp_log
from openpilot.tools.lib.filereader import FileReader
from openpilot.tools.lib.route import Route, SegmentName
LogIterable = Iterable[capnp._DynamicStructReader]
# this is an iterator itself, and uses private variables from LogReader
class MultiLogIterator:
def __init__(self, log_paths, sort_by_time=False):
self._log_paths = log_paths
self.sort_by_time = sort_by_time
self._first_log_idx = next(i for i in range(len(log_paths)) if log_paths[i] is not None)
self._current_log = self._first_log_idx
self._idx = 0
self._log_readers = [None]*len(log_paths)
self.start_time = self._log_reader(self._first_log_idx)._ts[0]
def _log_reader(self, i):
if self._log_readers[i] is None and self._log_paths[i] is not None:
log_path = self._log_paths[i]
self._log_readers[i] = LogReader(log_path, sort_by_time=self.sort_by_time)
return self._log_readers[i]
def __iter__(self) -> Iterator[capnp._DynamicStructReader]:
return self
def _inc(self):
lr = self._log_reader(self._current_log)
if self._idx < len(lr._ents)-1:
self._idx += 1
else:
self._idx = 0
self._current_log = next(i for i in range(self._current_log + 1, len(self._log_readers) + 1)
if i == len(self._log_readers) or self._log_paths[i] is not None)
if self._current_log == len(self._log_readers):
raise StopIteration
def __next__(self):
while 1:
lr = self._log_reader(self._current_log)
ret = lr._ents[self._idx]
self._inc()
return ret
def tell(self):
# returns seconds from start of log
return (self._log_reader(self._current_log)._ts[self._idx] - self.start_time) * 1e-9
def seek(self, ts):
# seek to nearest minute
minute = int(ts/60)
if minute >= len(self._log_paths) or self._log_paths[minute] is None:
return False
self._current_log = minute
# HACK: O(n) seek afterward
self._idx = 0
while self.tell() < ts:
self._inc()
return True
def reset(self):
self.__init__(self._log_paths, sort_by_time=self.sort_by_time)
class LogReader:
def __init__(self, fn, canonicalize=True, only_union_types=False, sort_by_time=False, dat=None):
self.data_version = None
self._only_union_types = only_union_types
ext = None
if not dat:
_, ext = os.path.splitext(urllib.parse.urlparse(fn).path)
if ext not in ('', '.bz2'):
# old rlogs weren't bz2 compressed
raise Exception(f"unknown extension {ext}")
with FileReader(fn) as f:
dat = f.read()
if ext == ".bz2" or dat.startswith(b'BZh9'):
dat = bz2.decompress(dat)
ents = capnp_log.Event.read_multiple_bytes(dat)
_ents = []
try:
for e in ents:
_ents.append(e)
except capnp.KjException:
warnings.warn("Corrupted events detected", RuntimeWarning, stacklevel=1)
self._ents = list(sorted(_ents, key=lambda x: x.logMonoTime) if sort_by_time else _ents)
self._ts = [x.logMonoTime for x in self._ents]
@classmethod
def from_bytes(cls, dat):
return cls("", dat=dat)
def __iter__(self) -> Iterator[capnp._DynamicStructReader]:
for ent in self._ents:
if self._only_union_types:
try:
ent.which()
yield ent
except capnp.lib.capnp.KjException:
pass
else:
yield ent
def logreader_from_route_or_segment(r, sort_by_time=False):
sn = SegmentName(r, allow_route_name=True)
route = Route(sn.route_name.canonical_name)
if sn.segment_num < 0:
return MultiLogIterator(route.log_paths(), sort_by_time=sort_by_time)
else:
return LogReader(route.log_paths()[sn.segment_num], sort_by_time=sort_by_time)
if __name__ == "__main__":
import codecs
# capnproto <= 0.8.0 throws errors converting byte data to string
# below line catches those errors and replaces the bytes with \x__
codecs.register_error("strict", codecs.backslashreplace_errors)
log_path = sys.argv[1]
lr = LogReader(log_path, sort_by_time=True)
for msg in lr:
print(msg)

257
tools/lib/route.py Normal file
View File

@@ -0,0 +1,257 @@
import os
import re
from urllib.parse import urlparse
from collections import defaultdict
from itertools import chain
from typing import Optional
from openpilot.tools.lib.auth_config import get_token
from openpilot.tools.lib.api import CommaApi
from openpilot.tools.lib.helpers import RE
QLOG_FILENAMES = ['qlog', 'qlog.bz2']
QCAMERA_FILENAMES = ['qcamera.ts']
LOG_FILENAMES = ['rlog', 'rlog.bz2', 'raw_log.bz2']
CAMERA_FILENAMES = ['fcamera.hevc', 'video.hevc']
DCAMERA_FILENAMES = ['dcamera.hevc']
ECAMERA_FILENAMES = ['ecamera.hevc']
class Route:
def __init__(self, name, data_dir=None):
self._name = RouteName(name)
self.files = None
if data_dir is not None:
self._segments = self._get_segments_local(data_dir)
else:
self._segments = self._get_segments_remote()
self.max_seg_number = self._segments[-1].name.segment_num
@property
def name(self):
return self._name
@property
def segments(self):
return self._segments
def log_paths(self):
log_path_by_seg_num = {s.name.segment_num: s.log_path for s in self._segments}
return [log_path_by_seg_num.get(i, None) for i in range(self.max_seg_number+1)]
def qlog_paths(self):
qlog_path_by_seg_num = {s.name.segment_num: s.qlog_path for s in self._segments}
return [qlog_path_by_seg_num.get(i, None) for i in range(self.max_seg_number+1)]
def camera_paths(self):
camera_path_by_seg_num = {s.name.segment_num: s.camera_path for s in self._segments}
return [camera_path_by_seg_num.get(i, None) for i in range(self.max_seg_number+1)]
def dcamera_paths(self):
dcamera_path_by_seg_num = {s.name.segment_num: s.dcamera_path for s in self._segments}
return [dcamera_path_by_seg_num.get(i, None) for i in range(self.max_seg_number+1)]
def ecamera_paths(self):
ecamera_path_by_seg_num = {s.name.segment_num: s.ecamera_path for s in self._segments}
return [ecamera_path_by_seg_num.get(i, None) for i in range(self.max_seg_number+1)]
def qcamera_paths(self):
qcamera_path_by_seg_num = {s.name.segment_num: s.qcamera_path for s in self._segments}
return [qcamera_path_by_seg_num.get(i, None) for i in range(self.max_seg_number+1)]
# TODO: refactor this, it's super repetitive
def _get_segments_remote(self):
api = CommaApi(get_token())
route_files = api.get('v1/route/' + self.name.canonical_name + '/files')
self.files = list(chain.from_iterable(route_files.values()))
segments = {}
for url in self.files:
_, dongle_id, time_str, segment_num, fn = urlparse(url).path.rsplit('/', maxsplit=4)
segment_name = f'{dongle_id}|{time_str}--{segment_num}'
if segments.get(segment_name):
segments[segment_name] = Segment(
segment_name,
url if fn in LOG_FILENAMES else segments[segment_name].log_path,
url if fn in QLOG_FILENAMES else segments[segment_name].qlog_path,
url if fn in CAMERA_FILENAMES else segments[segment_name].camera_path,
url if fn in DCAMERA_FILENAMES else segments[segment_name].dcamera_path,
url if fn in ECAMERA_FILENAMES else segments[segment_name].ecamera_path,
url if fn in QCAMERA_FILENAMES else segments[segment_name].qcamera_path,
)
else:
segments[segment_name] = Segment(
segment_name,
url if fn in LOG_FILENAMES else None,
url if fn in QLOG_FILENAMES else None,
url if fn in CAMERA_FILENAMES else None,
url if fn in DCAMERA_FILENAMES else None,
url if fn in ECAMERA_FILENAMES else None,
url if fn in QCAMERA_FILENAMES else None,
)
return sorted(segments.values(), key=lambda seg: seg.name.segment_num)
def _get_segments_local(self, data_dir):
files = os.listdir(data_dir)
segment_files = defaultdict(list)
for f in files:
fullpath = os.path.join(data_dir, f)
explorer_match = re.match(RE.EXPLORER_FILE, f)
op_match = re.match(RE.OP_SEGMENT_DIR, f)
if explorer_match:
segment_name = explorer_match.group('segment_name')
fn = explorer_match.group('file_name')
if segment_name.replace('_', '|').startswith(self.name.canonical_name):
segment_files[segment_name].append((fullpath, fn))
elif op_match and os.path.isdir(fullpath):
segment_name = op_match.group('segment_name')
if segment_name.startswith(self.name.canonical_name):
for seg_f in os.listdir(fullpath):
segment_files[segment_name].append((os.path.join(fullpath, seg_f), seg_f))
elif f == self.name.canonical_name:
for seg_num in os.listdir(fullpath):
if not seg_num.isdigit():
continue
segment_name = f'{self.name.canonical_name}--{seg_num}'
for seg_f in os.listdir(os.path.join(fullpath, seg_num)):
segment_files[segment_name].append((os.path.join(fullpath, seg_num, seg_f), seg_f))
segments = []
for segment, files in segment_files.items():
try:
log_path = next(path for path, filename in files if filename in LOG_FILENAMES)
except StopIteration:
log_path = None
try:
qlog_path = next(path for path, filename in files if filename in QLOG_FILENAMES)
except StopIteration:
qlog_path = None
try:
camera_path = next(path for path, filename in files if filename in CAMERA_FILENAMES)
except StopIteration:
camera_path = None
try:
dcamera_path = next(path for path, filename in files if filename in DCAMERA_FILENAMES)
except StopIteration:
dcamera_path = None
try:
ecamera_path = next(path for path, filename in files if filename in ECAMERA_FILENAMES)
except StopIteration:
ecamera_path = None
try:
qcamera_path = next(path for path, filename in files if filename in QCAMERA_FILENAMES)
except StopIteration:
qcamera_path = None
segments.append(Segment(segment, log_path, qlog_path, camera_path, dcamera_path, ecamera_path, qcamera_path))
if len(segments) == 0:
raise ValueError(f'Could not find segments for route {self.name.canonical_name} in data directory {data_dir}')
return sorted(segments, key=lambda seg: seg.name.segment_num)
class Segment:
def __init__(self, name, log_path, qlog_path, camera_path, dcamera_path, ecamera_path, qcamera_path):
self._name = SegmentName(name)
self.log_path = log_path
self.qlog_path = qlog_path
self.camera_path = camera_path
self.dcamera_path = dcamera_path
self.ecamera_path = ecamera_path
self.qcamera_path = qcamera_path
@property
def name(self):
return self._name
class RouteName:
def __init__(self, name_str: str):
self._name_str = name_str
delim = next(c for c in self._name_str if c in ("|", "/"))
self._dongle_id, self._time_str = self._name_str.split(delim)
assert len(self._dongle_id) == 16, self._name_str
assert len(self._time_str) == 20, self._name_str
self._canonical_name = f"{self._dongle_id}|{self._time_str}"
@property
def canonical_name(self) -> str: return self._canonical_name
@property
def dongle_id(self) -> str: return self._dongle_id
@property
def time_str(self) -> str: return self._time_str
def __str__(self) -> str: return self._canonical_name
class SegmentName:
# TODO: add constructor that takes dongle_id, time_str, segment_num and then create instances
# of this class instead of manually constructing a segment name (use canonical_name prop instead)
def __init__(self, name_str: str, allow_route_name=False):
data_dir_path_separator_index = name_str.rsplit("|", 1)[0].rfind("/")
use_data_dir = (data_dir_path_separator_index != -1) and ("|" in name_str)
self._name_str = name_str[data_dir_path_separator_index + 1:] if use_data_dir else name_str
self._data_dir = name_str[:data_dir_path_separator_index] if use_data_dir else None
seg_num_delim = "--" if self._name_str.count("--") == 2 else "/"
name_parts = self._name_str.rsplit(seg_num_delim, 1)
if allow_route_name and len(name_parts) == 1:
name_parts.append("-1") # no segment number
self._route_name = RouteName(name_parts[0])
self._num = int(name_parts[1])
self._canonical_name = f"{self._route_name._dongle_id}|{self._route_name._time_str}--{self._num}"
@property
def canonical_name(self) -> str: return self._canonical_name
@property
def dongle_id(self) -> str: return self._route_name.dongle_id
@property
def time_str(self) -> str: return self._route_name.time_str
@property
def segment_num(self) -> int: return self._num
@property
def route_name(self) -> RouteName: return self._route_name
@property
def data_dir(self) -> Optional[str]: return self._data_dir
def __str__(self) -> str: return self._canonical_name
class SegmentRange:
def __init__(self, segment_range: str):
self.m = re.fullmatch(RE.SEGMENT_RANGE, segment_range)
assert self.m, f"Segment range is not valid {segment_range}"
@property
def route_name(self):
return self.m.group("route_name")
@property
def dongle_id(self):
return self.m.group("dongle_id")
@property
def timestamp(self):
return self.m.group("timestamp")
@property
def _slice(self):
return self.m.group("slice")
@property
def selector(self):
return self.m.group("selector")

141
tools/lib/srreader.py Normal file
View File

@@ -0,0 +1,141 @@
import enum
import numpy as np
import pathlib
import re
from urllib.parse import parse_qs, urlparse
from openpilot.selfdrive.test.openpilotci import get_url
from openpilot.tools.lib.helpers import RE
from openpilot.tools.lib.logreader import LogReader
from openpilot.tools.lib.route import Route, SegmentRange
class ReadMode(enum.StrEnum):
RLOG = "r" # only read rlogs
QLOG = "q" # only read qlogs
#AUTO = "a" # default to rlogs, fallback to qlogs, not supported yet
def create_slice_from_string(s: str):
m = re.fullmatch(RE.SLICE, s)
assert m is not None, f"Invalid slice: {s}"
start, end, step = m.groups()
start = int(start) if start is not None else None
end = int(end) if end is not None else None
step = int(step) if step is not None else None
if start is not None and ":" not in s and end is None and step is None:
return start
return slice(start, end, step)
def parse_slice(sr: SegmentRange):
route = Route(sr.route_name)
segs = np.arange(route.max_seg_number+1)
s = create_slice_from_string(sr._slice)
return segs[s] if isinstance(s, slice) else [segs[s]]
def comma_api_source(sr: SegmentRange, mode=ReadMode.RLOG, sort_by_time=False):
segs = parse_slice(sr)
route = Route(sr.route_name)
log_paths = route.log_paths() if mode == ReadMode.RLOG else route.qlog_paths()
invalid_segs = [seg for seg in segs if log_paths[seg] is None]
assert not len(invalid_segs), f"Some of the requested segments are not available: {invalid_segs}"
for seg in segs:
yield LogReader(log_paths[seg], sort_by_time=sort_by_time)
def internal_source(sr: SegmentRange, mode=ReadMode.RLOG, sort_by_time=False):
segs = parse_slice(sr)
for seg in segs:
yield LogReader(f"cd:/{sr.dongle_id}/{sr.timestamp}/{seg}/{'rlog' if mode == ReadMode.RLOG else 'qlog'}.bz2", sort_by_time=sort_by_time)
def openpilotci_source(sr: SegmentRange, mode=ReadMode.RLOG, sort_by_time=False):
segs = parse_slice(sr)
for seg in segs:
yield LogReader(get_url(sr.route_name, seg, 'rlog' if mode == ReadMode.RLOG else 'qlog'), sort_by_time=sort_by_time)
def direct_source(file_or_url, sort_by_time):
yield LogReader(file_or_url, sort_by_time=sort_by_time)
def auto_source(*args, **kwargs):
# Automatically determine viable source
try:
next(internal_source(*args, **kwargs))
return internal_source(*args, **kwargs)
except Exception:
pass
try:
next(openpilotci_source(*args, **kwargs))
return openpilotci_source(*args, **kwargs)
except Exception:
pass
return comma_api_source(*args, **kwargs)
def parse_useradmin(identifier):
if "useradmin.comma.ai" in identifier:
query = parse_qs(urlparse(identifier).query)
return query["onebox"][0]
return None
def parse_cabana(identifier):
if "cabana.comma.ai" in identifier:
query = parse_qs(urlparse(identifier).query)
return query["route"][0]
return None
def parse_cd(identifier):
if "cd:/" in identifier:
return identifier.replace("cd:/", "")
return None
def parse_direct(identifier):
if "https://" in identifier or "http://" in identifier or pathlib.Path(identifier).exists():
return identifier
return None
def parse_indirect(identifier):
parsed = parse_useradmin(identifier) or parse_cabana(identifier)
if parsed is not None:
return parsed, comma_api_source, True
parsed = parse_cd(identifier)
if parsed is not None:
return parsed, internal_source, True
return identifier, None, False
class SegmentRangeReader:
def _logreaders_from_identifier(self, identifier):
parsed, source, is_indirect = parse_indirect(identifier)
if not is_indirect:
direct_parsed = parse_direct(identifier)
if direct_parsed is not None:
return direct_source(identifier, sort_by_time=self.sort_by_time)
sr = SegmentRange(parsed)
mode = self.default_mode if sr.selector is None else ReadMode(sr.selector)
source = self.default_source if source is None else source
return source(sr, mode, sort_by_time=self.sort_by_time)
def __init__(self, identifier: str, default_mode=ReadMode.RLOG, default_source=auto_source, sort_by_time=False):
self.default_mode = default_mode
self.default_source = default_source
self.sort_by_time = sort_by_time
self.lrs = self._logreaders_from_identifier(identifier)
def __iter__(self):
for lr in self.lrs:
for m in lr:
yield m

View File

129
tools/lib/tests/test_caching.py Executable file
View File

@@ -0,0 +1,129 @@
#!/usr/bin/env python3
from functools import wraps
import http.server
import os
import threading
import time
import unittest
from parameterized import parameterized
from openpilot.tools.lib.url_file import URLFile
class CachingTestRequestHandler(http.server.BaseHTTPRequestHandler):
FILE_EXISTS = True
def do_GET(self):
if self.FILE_EXISTS:
self.send_response(200, b'1234')
else:
self.send_response(404)
self.end_headers()
def do_HEAD(self):
if self.FILE_EXISTS:
self.send_response(200)
self.send_header("Content-Length", "4")
else:
self.send_response(404)
self.end_headers()
class CachingTestServer(threading.Thread):
def run(self):
self.server = http.server.HTTPServer(("127.0.0.1", 0), CachingTestRequestHandler)
self.port = self.server.server_port
self.server.serve_forever()
def stop(self):
self.server.server_close()
self.server.shutdown()
def with_caching_server(func):
@wraps(func)
def wrapper(*args, **kwargs):
server = CachingTestServer()
server.start()
time.sleep(0.25) # wait for server to get it's port
try:
func(*args, **kwargs, port=server.port)
finally:
server.stop()
return wrapper
class TestFileDownload(unittest.TestCase):
def compare_loads(self, url, start=0, length=None):
"""Compares range between cached and non cached version"""
file_cached = URLFile(url, cache=True)
file_downloaded = URLFile(url, cache=False)
file_cached.seek(start)
file_downloaded.seek(start)
self.assertEqual(file_cached.get_length(), file_downloaded.get_length())
self.assertLessEqual(length + start if length is not None else 0, file_downloaded.get_length())
response_cached = file_cached.read(ll=length)
response_downloaded = file_downloaded.read(ll=length)
self.assertEqual(response_cached, response_downloaded)
# Now test with cache in place
file_cached = URLFile(url, cache=True)
file_cached.seek(start)
response_cached = file_cached.read(ll=length)
self.assertEqual(file_cached.get_length(), file_downloaded.get_length())
self.assertEqual(response_cached, response_downloaded)
def test_small_file(self):
# Make sure we don't force cache
os.environ["FILEREADER_CACHE"] = "0"
small_file_url = "https://raw.githubusercontent.com/commaai/openpilot/master/docs/SAFETY.md"
# If you want large file to be larger than a chunk
# large_file_url = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/fcamera.hevc"
# Load full small file
self.compare_loads(small_file_url)
file_small = URLFile(small_file_url)
length = file_small.get_length()
self.compare_loads(small_file_url, length - 100, 100)
self.compare_loads(small_file_url, 50, 100)
# Load small file 100 bytes at a time
for i in range(length // 100):
self.compare_loads(small_file_url, 100 * i, 100)
def test_large_file(self):
large_file_url = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/qlog.bz2"
# Load the end 100 bytes of both files
file_large = URLFile(large_file_url)
length = file_large.get_length()
self.compare_loads(large_file_url, length - 100, 100)
self.compare_loads(large_file_url)
@parameterized.expand([(True, ), (False, )])
@with_caching_server
def test_recover_from_missing_file(self, cache_enabled, port):
os.environ["FILEREADER_CACHE"] = "1" if cache_enabled else "0"
file_url = f"http://localhost:{port}/test.png"
CachingTestRequestHandler.FILE_EXISTS = False
length = URLFile(file_url).get_length()
self.assertEqual(length, -1)
CachingTestRequestHandler.FILE_EXISTS = True
length = URLFile(file_url).get_length()
self.assertEqual(length, 4)
if __name__ == "__main__":
unittest.main()

67
tools/lib/tests/test_readers.py Executable file
View File

@@ -0,0 +1,67 @@
#!/usr/bin/env python
import unittest
import requests
import tempfile
from collections import defaultdict
import numpy as np
from openpilot.tools.lib.framereader import FrameReader
from openpilot.tools.lib.logreader import LogReader
class TestReaders(unittest.TestCase):
@unittest.skip("skip for bandwidth reasons")
def test_logreader(self):
def _check_data(lr):
hist = defaultdict(int)
for l in lr:
hist[l.which()] += 1
self.assertEqual(hist['carControl'], 6000)
self.assertEqual(hist['logMessage'], 6857)
with tempfile.NamedTemporaryFile(suffix=".bz2") as fp:
r = requests.get("https://github.com/commaai/comma2k19/blob/master/Example_1/b0c9d2329ad1606b%7C2018-08-02--08-34-47/40/raw_log.bz2?raw=true", timeout=10)
fp.write(r.content)
fp.flush()
lr_file = LogReader(fp.name)
_check_data(lr_file)
lr_url = LogReader("https://github.com/commaai/comma2k19/blob/master/Example_1/b0c9d2329ad1606b%7C2018-08-02--08-34-47/40/raw_log.bz2?raw=true")
_check_data(lr_url)
@unittest.skip("skip for bandwidth reasons")
def test_framereader(self):
def _check_data(f):
self.assertEqual(f.frame_count, 1200)
self.assertEqual(f.w, 1164)
self.assertEqual(f.h, 874)
frame_first_30 = f.get(0, 30)
self.assertEqual(len(frame_first_30), 30)
print(frame_first_30[15])
print("frame_0")
frame_0 = f.get(0, 1)
frame_15 = f.get(15, 1)
print(frame_15[0])
assert np.all(frame_first_30[0] == frame_0[0])
assert np.all(frame_first_30[15] == frame_15[0])
with tempfile.NamedTemporaryFile(suffix=".hevc") as fp:
r = requests.get("https://github.com/commaai/comma2k19/blob/master/Example_1/b0c9d2329ad1606b%7C2018-08-02--08-34-47/40/video.hevc?raw=true", timeout=10)
fp.write(r.content)
fp.flush()
fr_file = FrameReader(fp.name)
_check_data(fr_file)
fr_url = FrameReader("https://github.com/commaai/comma2k19/blob/master/Example_1/b0c9d2329ad1606b%7C2018-08-02--08-34-47/40/video.hevc?raw=true")
_check_data(fr_url)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,32 @@
#!/usr/bin/env python
import unittest
from collections import namedtuple
from openpilot.tools.lib.route import SegmentName
class TestRouteLibrary(unittest.TestCase):
def test_segment_name_formats(self):
Case = namedtuple('Case', ['input', 'expected_route', 'expected_segment_num', 'expected_data_dir'])
cases = [ Case("a2a0ccea32023010|2023-07-27--13-01-19", "a2a0ccea32023010|2023-07-27--13-01-19", -1, None),
Case("a2a0ccea32023010/2023-07-27--13-01-19--1", "a2a0ccea32023010|2023-07-27--13-01-19", 1, None),
Case("a2a0ccea32023010|2023-07-27--13-01-19/2", "a2a0ccea32023010|2023-07-27--13-01-19", 2, None),
Case("a2a0ccea32023010/2023-07-27--13-01-19/3", "a2a0ccea32023010|2023-07-27--13-01-19", 3, None),
Case("/data/media/0/realdata/a2a0ccea32023010|2023-07-27--13-01-19", "a2a0ccea32023010|2023-07-27--13-01-19", -1, "/data/media/0/realdata"),
Case("/data/media/0/realdata/a2a0ccea32023010|2023-07-27--13-01-19--1", "a2a0ccea32023010|2023-07-27--13-01-19", 1, "/data/media/0/realdata"),
Case("/data/media/0/realdata/a2a0ccea32023010|2023-07-27--13-01-19/2", "a2a0ccea32023010|2023-07-27--13-01-19", 2, "/data/media/0/realdata") ]
def _validate(case):
route_or_segment_name = case.input
s = SegmentName(route_or_segment_name, allow_route_name=True)
self.assertEqual(str(s.route_name), case.expected_route)
self.assertEqual(s.segment_num, case.expected_segment_num)
self.assertEqual(s.data_dir, case.expected_data_dir)
for case in cases:
_validate(case)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,88 @@
import shutil
import tempfile
import numpy as np
import unittest
from parameterized import parameterized
import requests
from openpilot.tools.lib.route import SegmentRange
from openpilot.tools.lib.srreader import ReadMode, SegmentRangeReader, parse_slice, parse_indirect
NUM_SEGS = 17 # number of segments in the test route
ALL_SEGS = list(np.arange(NUM_SEGS))
TEST_ROUTE = "344c5c15b34f2d8a/2024-01-03--09-37-12"
QLOG_FILE = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/qlog.bz2"
class TestSegmentRangeReader(unittest.TestCase):
@parameterized.expand([
(f"{TEST_ROUTE}", ALL_SEGS),
(f"{TEST_ROUTE.replace('/', '|')}", ALL_SEGS),
(f"{TEST_ROUTE}--0", [0]),
(f"{TEST_ROUTE}--5", [5]),
(f"{TEST_ROUTE}/0", [0]),
(f"{TEST_ROUTE}/5", [5]),
(f"{TEST_ROUTE}/0:10", ALL_SEGS[0:10]),
(f"{TEST_ROUTE}/0:0", []),
(f"{TEST_ROUTE}/4:6", ALL_SEGS[4:6]),
(f"{TEST_ROUTE}/0:-1", ALL_SEGS[0:-1]),
(f"{TEST_ROUTE}/:5", ALL_SEGS[:5]),
(f"{TEST_ROUTE}/2:", ALL_SEGS[2:]),
(f"{TEST_ROUTE}/2:-1", ALL_SEGS[2:-1]),
(f"{TEST_ROUTE}/-1", [ALL_SEGS[-1]]),
(f"{TEST_ROUTE}/-2", [ALL_SEGS[-2]]),
(f"{TEST_ROUTE}/-2:-1", ALL_SEGS[-2:-1]),
(f"{TEST_ROUTE}/-4:-2", ALL_SEGS[-4:-2]),
(f"{TEST_ROUTE}/:10:2", ALL_SEGS[:10:2]),
(f"{TEST_ROUTE}/5::2", ALL_SEGS[5::2]),
(f"https://useradmin.comma.ai/?onebox={TEST_ROUTE}", ALL_SEGS),
(f"https://useradmin.comma.ai/?onebox={TEST_ROUTE.replace('/', '|')}", ALL_SEGS),
(f"https://useradmin.comma.ai/?onebox={TEST_ROUTE.replace('/', '%7C')}", ALL_SEGS),
(f"https://cabana.comma.ai/?route={TEST_ROUTE}", ALL_SEGS),
(f"cd:/{TEST_ROUTE}", ALL_SEGS),
])
def test_indirect_parsing(self, identifier, expected):
parsed, _, _ = parse_indirect(identifier)
sr = SegmentRange(parsed)
segs = parse_slice(sr)
self.assertListEqual(list(segs), expected)
def test_direct_parsing(self):
qlog = tempfile.NamedTemporaryFile(mode='wb', delete=False)
with requests.get(QLOG_FILE, stream=True) as r:
with qlog as f:
shutil.copyfileobj(r.raw, f)
for f in [QLOG_FILE, qlog.name]:
l = len(list(SegmentRangeReader(f)))
self.assertGreater(l, 100)
@parameterized.expand([
(f"{TEST_ROUTE}///",),
(f"{TEST_ROUTE}---",),
(f"{TEST_ROUTE}/-4:--2",),
(f"{TEST_ROUTE}/-a",),
(f"{TEST_ROUTE}/j",),
(f"{TEST_ROUTE}/0:1:2:3",),
(f"{TEST_ROUTE}/:::3",),
])
def test_bad_ranges(self, segment_range):
with self.assertRaises(AssertionError):
sr = SegmentRange(segment_range)
parse_slice(sr)
def test_modes(self):
qlog_len = len(list(SegmentRangeReader(f"{TEST_ROUTE}/0", ReadMode.QLOG)))
rlog_len = len(list(SegmentRangeReader(f"{TEST_ROUTE}/0", ReadMode.RLOG)))
self.assertLess(qlog_len * 6, rlog_len)
def test_modes_from_name(self):
qlog_len = len(list(SegmentRangeReader(f"{TEST_ROUTE}/0/q")))
rlog_len = len(list(SegmentRangeReader(f"{TEST_ROUTE}/0/r")))
self.assertLess(qlog_len * 6, rlog_len)
if __name__ == "__main__":
unittest.main()

156
tools/lib/url_file.py Normal file
View File

@@ -0,0 +1,156 @@
import os
import time
import threading
from hashlib import sha256
from urllib3 import PoolManager
from urllib3.util import Timeout
from tenacity import retry, wait_random_exponential, stop_after_attempt
from openpilot.common.file_helpers import atomic_write_in_dir
from openpilot.system.hardware.hw import Paths
# Cache chunk size
K = 1000
CHUNK_SIZE = 1000 * K
def hash_256(link):
hsh = str(sha256((link.split("?")[0]).encode('utf-8')).hexdigest())
return hsh
class URLFileException(Exception):
pass
class URLFile:
_tlocal = threading.local()
def __init__(self, url, debug=False, cache=None):
self._url = url
self._pos = 0
self._length = None
self._local_file = None
self._debug = debug
# True by default, false if FILEREADER_CACHE is defined, but can be overwritten by the cache input
self._force_download = not int(os.environ.get("FILEREADER_CACHE", "0"))
if cache is not None:
self._force_download = not cache
if not self._force_download:
os.makedirs(Paths.download_cache_root(), exist_ok=True)
try:
self._http_client = URLFile._tlocal.http_client
except AttributeError:
self._http_client = URLFile._tlocal.http_client = PoolManager()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
if self._local_file is not None:
os.remove(self._local_file.name)
self._local_file.close()
self._local_file = None
@retry(wait=wait_random_exponential(multiplier=1, max=5), stop=stop_after_attempt(3), reraise=True)
def get_length_online(self):
timeout = Timeout(connect=50.0, read=500.0)
response = self._http_client.request('HEAD', self._url, timeout=timeout, preload_content=False)
if not (200 <= response.status <= 299):
return -1
length = response.headers.get('content-length', 0)
return int(length)
def get_length(self):
if self._length is not None:
return self._length
file_length_path = os.path.join(Paths.download_cache_root(), hash_256(self._url) + "_length")
if not self._force_download and os.path.exists(file_length_path):
with open(file_length_path) as file_length:
content = file_length.read()
self._length = int(content)
return self._length
self._length = self.get_length_online()
if not self._force_download and self._length != -1:
with atomic_write_in_dir(file_length_path, mode="w") as file_length:
file_length.write(str(self._length))
return self._length
def read(self, ll=None):
if self._force_download:
return self.read_aux(ll=ll)
file_begin = self._pos
file_end = self._pos + ll if ll is not None else self.get_length()
assert file_end != -1, f"Remote file is empty or doesn't exist: {self._url}"
# We have to align with chunks we store. Position is the begginiing of the latest chunk that starts before or at our file
position = (file_begin // CHUNK_SIZE) * CHUNK_SIZE
response = b""
while True:
self._pos = position
chunk_number = self._pos / CHUNK_SIZE
file_name = hash_256(self._url) + "_" + str(chunk_number)
full_path = os.path.join(Paths.download_cache_root(), str(file_name))
data = None
# If we don't have a file, download it
if not os.path.exists(full_path):
data = self.read_aux(ll=CHUNK_SIZE)
with atomic_write_in_dir(full_path, mode="wb") as new_cached_file:
new_cached_file.write(data)
else:
with open(full_path, "rb") as cached_file:
data = cached_file.read()
response += data[max(0, file_begin - position): min(CHUNK_SIZE, file_end - position)]
position += CHUNK_SIZE
if position >= file_end:
self._pos = file_end
return response
@retry(wait=wait_random_exponential(multiplier=1, max=5), stop=stop_after_attempt(3), reraise=True)
def read_aux(self, ll=None):
download_range = False
headers = {'Connection': 'keep-alive'}
if self._pos != 0 or ll is not None:
if ll is None:
end = self.get_length() - 1
else:
end = min(self._pos + ll, self.get_length()) - 1
if self._pos >= end:
return b""
headers['Range'] = f"bytes={self._pos}-{end}"
download_range = True
if self._debug:
t1 = time.time()
timeout = Timeout(connect=50.0, read=500.0)
response = self._http_client.request('GET', self._url, timeout=timeout, preload_content=False, headers=headers)
ret = response.data
if self._debug:
t2 = time.time()
if t2 - t1 > 0.1:
print(f"get {self._url} {headers!r} {t2 - t1:.3f} slow")
response_code = response.status
if response_code == 416: # Requested Range Not Satisfiable
raise URLFileException(f"Error, range out of bounds {response_code} {headers} ({self._url}): {repr(ret)[:500]}")
if download_range and response_code != 206: # Partial Content
raise URLFileException(f"Error, requested range but got unexpected response {response_code} {headers} ({self._url}): {repr(ret)[:500]}")
if (not download_range) and response_code != 200: # OK
raise URLFileException(f"Error {response_code} {headers} ({self._url}): {repr(ret)[:500]}")
self._pos += len(ret)
return ret
def seek(self, pos):
self._pos = pos
@property
def name(self):
return self._url

312
tools/lib/vidindex.py Executable file
View File

@@ -0,0 +1,312 @@
#!/usr/bin/env python3
import argparse
import os
import struct
from enum import IntEnum
from typing import Tuple
from openpilot.tools.lib.filereader import FileReader
DEBUG = int(os.getenv("DEBUG", "0"))
# compare to ffmpeg parsing
# ffmpeg -i <input.hevc> -c copy -bsf:v trace_headers -f null - 2>&1 | grep -B4 -A32 '] 0 '
# H.265 specification
# https://www.itu.int/rec/dologin_pub.asp?lang=e&id=T-REC-H.265-201802-S!!PDF-E&type=items
NAL_UNIT_START_CODE = b"\x00\x00\x01"
NAL_UNIT_START_CODE_SIZE = len(NAL_UNIT_START_CODE)
NAL_UNIT_HEADER_SIZE = 2
class HevcNalUnitType(IntEnum):
TRAIL_N = 0 # RBSP structure: slice_segment_layer_rbsp( )
TRAIL_R = 1 # RBSP structure: slice_segment_layer_rbsp( )
TSA_N = 2 # RBSP structure: slice_segment_layer_rbsp( )
TSA_R = 3 # RBSP structure: slice_segment_layer_rbsp( )
STSA_N = 4 # RBSP structure: slice_segment_layer_rbsp( )
STSA_R = 5 # RBSP structure: slice_segment_layer_rbsp( )
RADL_N = 6 # RBSP structure: slice_segment_layer_rbsp( )
RADL_R = 7 # RBSP structure: slice_segment_layer_rbsp( )
RASL_N = 8 # RBSP structure: slice_segment_layer_rbsp( )
RASL_R = 9 # RBSP structure: slice_segment_layer_rbsp( )
RSV_VCL_N10 = 10
RSV_VCL_R11 = 11
RSV_VCL_N12 = 12
RSV_VCL_R13 = 13
RSV_VCL_N14 = 14
RSV_VCL_R15 = 15
BLA_W_LP = 16 # RBSP structure: slice_segment_layer_rbsp( )
BLA_W_RADL = 17 # RBSP structure: slice_segment_layer_rbsp( )
BLA_N_LP = 18 # RBSP structure: slice_segment_layer_rbsp( )
IDR_W_RADL = 19 # RBSP structure: slice_segment_layer_rbsp( )
IDR_N_LP = 20 # RBSP structure: slice_segment_layer_rbsp( )
CRA_NUT = 21 # RBSP structure: slice_segment_layer_rbsp( )
RSV_IRAP_VCL22 = 22
RSV_IRAP_VCL23 = 23
RSV_VCL24 = 24
RSV_VCL25 = 25
RSV_VCL26 = 26
RSV_VCL27 = 27
RSV_VCL28 = 28
RSV_VCL29 = 29
RSV_VCL30 = 30
RSV_VCL31 = 31
VPS_NUT = 32 # RBSP structure: video_parameter_set_rbsp( )
SPS_NUT = 33 # RBSP structure: seq_parameter_set_rbsp( )
PPS_NUT = 34 # RBSP structure: pic_parameter_set_rbsp( )
AUD_NUT = 35
EOS_NUT = 36
EOB_NUT = 37
FD_NUT = 38
PREFIX_SEI_NUT = 39
SUFFIX_SEI_NUT = 40
RSV_NVCL41 = 41
RSV_NVCL42 = 42
RSV_NVCL43 = 43
RSV_NVCL44 = 44
RSV_NVCL45 = 45
RSV_NVCL46 = 46
RSV_NVCL47 = 47
UNSPEC48 = 48
UNSPEC49 = 49
UNSPEC50 = 50
UNSPEC51 = 51
UNSPEC52 = 52
UNSPEC53 = 53
UNSPEC54 = 54
UNSPEC55 = 55
UNSPEC56 = 56
UNSPEC57 = 57
UNSPEC58 = 58
UNSPEC59 = 59
UNSPEC60 = 60
UNSPEC61 = 61
UNSPEC62 = 62
UNSPEC63 = 63
# B.2.2 Byte stream NAL unit semantics
# - The nal_unit_type within the nal_unit( ) syntax structure is equal to VPS_NUT, SPS_NUT or PPS_NUT.
# - The byte stream NAL unit syntax structure contains the first NAL unit of an access unit in decoding
# order, as specified in clause 7.4.2.4.4.
HEVC_PARAMETER_SET_NAL_UNITS = (
HevcNalUnitType.VPS_NUT,
HevcNalUnitType.SPS_NUT,
HevcNalUnitType.PPS_NUT,
)
# 3.29 coded slice segment NAL unit: A NAL unit that has nal_unit_type in the range of TRAIL_N to RASL_R,
# inclusive, or in the range of BLA_W_LP to RSV_IRAP_VCL23, inclusive, which indicates that the NAL unit
# contains a coded slice segment
HEVC_CODED_SLICE_SEGMENT_NAL_UNITS = (
HevcNalUnitType.TRAIL_N,
HevcNalUnitType.TRAIL_R,
HevcNalUnitType.TSA_N,
HevcNalUnitType.TSA_R,
HevcNalUnitType.STSA_N,
HevcNalUnitType.STSA_R,
HevcNalUnitType.RADL_N,
HevcNalUnitType.RADL_R,
HevcNalUnitType.RASL_N,
HevcNalUnitType.RASL_R,
HevcNalUnitType.BLA_W_LP,
HevcNalUnitType.BLA_W_RADL,
HevcNalUnitType.BLA_N_LP,
HevcNalUnitType.IDR_W_RADL,
HevcNalUnitType.IDR_N_LP,
HevcNalUnitType.CRA_NUT,
)
class VideoFileInvalid(Exception):
pass
def get_ue(dat: bytes, start_idx: int, skip_bits: int) -> Tuple[int, int]:
prefix_val = 0
prefix_len = 0
suffix_val = 0
suffix_len = 0
i = start_idx
while i < len(dat):
j = 7
while j >= 0:
if skip_bits > 0:
skip_bits -= 1
elif prefix_val == 0:
prefix_val = (dat[i] >> j) & 1
prefix_len += 1
else:
suffix_val = (suffix_val << 1) | ((dat[i] >> j) & 1)
suffix_len += 1
j -= 1
if prefix_val == 1 and prefix_len - 1 == suffix_len:
val = 2**(prefix_len-1) - 1 + suffix_val
size = prefix_len + suffix_len
return val, size
i += 1
raise VideoFileInvalid("invalid exponential-golomb code")
def require_nal_unit_start(dat: bytes, nal_unit_start: int) -> None:
if nal_unit_start < 1:
raise ValueError("start index must be greater than zero")
if dat[nal_unit_start:nal_unit_start + NAL_UNIT_START_CODE_SIZE] != NAL_UNIT_START_CODE:
raise VideoFileInvalid("data must begin with start code")
def get_hevc_nal_unit_length(dat: bytes, nal_unit_start: int) -> int:
try:
pos = dat.index(NAL_UNIT_START_CODE, nal_unit_start + NAL_UNIT_START_CODE_SIZE)
except ValueError:
pos = -1
# length of NAL unit is byte count up to next NAL unit start index
nal_unit_len = (pos if pos != -1 else len(dat)) - nal_unit_start
if DEBUG:
print(" nal_unit_len:", nal_unit_len)
return nal_unit_len
def get_hevc_nal_unit_type(dat: bytes, nal_unit_start: int) -> HevcNalUnitType:
# 7.3.1.2 NAL unit header syntax
# nal_unit_header( ) { // descriptor
# forbidden_zero_bit f(1)
# nal_unit_type u(6)
# nuh_layer_id u(6)
# nuh_temporal_id_plus1 u(3)
# }
header_start = nal_unit_start + NAL_UNIT_START_CODE_SIZE
nal_unit_header = dat[header_start:header_start + NAL_UNIT_HEADER_SIZE]
if len(nal_unit_header) != 2:
raise VideoFileInvalid("data to short to contain nal unit header")
nal_unit_type = HevcNalUnitType((nal_unit_header[0] >> 1) & 0x3F)
if DEBUG:
print(" nal_unit_type:", nal_unit_type.name, f"({nal_unit_type.value})")
return nal_unit_type
def get_hevc_slice_type(dat: bytes, nal_unit_start: int, nal_unit_type: HevcNalUnitType) -> Tuple[int, bool]:
# 7.3.2.9 Slice segment layer RBSP syntax
# slice_segment_layer_rbsp( ) {
# slice_segment_header( )
# slice_segment_data( )
# rbsp_slice_segment_trailing_bits( )
# }
# ...
# 7.3.6.1 General slice segment header syntax
# slice_segment_header( ) { // descriptor
# first_slice_segment_in_pic_flag u(1)
# if( nal_unit_type >= BLA_W_LP && nal_unit_type <= RSV_IRAP_VCL23 )
# no_output_of_prior_pics_flag u(1)
# slice_pic_parameter_set_id ue(v)
# if( !first_slice_segment_in_pic_flag ) {
# if( dependent_slice_segments_enabled_flag )
# dependent_slice_segment_flag u(1)
# slice_segment_address u(v)
# }
# if( !dependent_slice_segment_flag ) {
# for( i = 0; i < num_extra_slice_header_bits; i++ )
# slice_reserved_flag[ i ] u(1)
# slice_type ue(v)
# ...
rbsp_start = nal_unit_start + NAL_UNIT_START_CODE_SIZE + NAL_UNIT_HEADER_SIZE
skip_bits = 0
# 7.4.7.1 General slice segment header semantics
# first_slice_segment_in_pic_flag equal to 1 specifies that the slice segment is the first slice segment of the picture in
# decoding order. first_slice_segment_in_pic_flag equal to 0 specifies that the slice segment is not the first slice segment
# of the picture in decoding order.
is_first_slice = dat[rbsp_start] >> 7 & 1 == 1
if not is_first_slice:
# TODO: parse dependent_slice_segment_flag and slice_segment_address and get real slice_type
# for now since we don't use it return -1 for slice_type
return (-1, is_first_slice)
skip_bits += 1 # skip past first_slice_segment_in_pic_flag
if nal_unit_type >= HevcNalUnitType.BLA_W_LP and nal_unit_type <= HevcNalUnitType.RSV_IRAP_VCL23:
# 7.4.7.1 General slice segment header semantics
# no_output_of_prior_pics_flag affects the output of previously-decoded pictures in the decoded picture buffer after the
# decoding of an IDR or a BLA picture that is not the first picture in the bitstream as specified in Annex C.
skip_bits += 1 # skip past no_output_of_prior_pics_flag
# 7.4.7.1 General slice segment header semantics
# slice_pic_parameter_set_id specifies the value of pps_pic_parameter_set_id for the PPS in use.
# The value of slice_pic_parameter_set_id shall be in the range of 0 to 63, inclusive.
_, size = get_ue(dat, rbsp_start, skip_bits)
skip_bits += size # skip past slice_pic_parameter_set_id
# 7.4.3.3.1 General picture parameter set RBSP semanal_unit_lenntics
# num_extra_slice_header_bits specifies the number of extra slice header bits that are present in the slice header RBSP
# for coded pictures referring to the PPS. The value of num_extra_slice_header_bits shall be in the range of 0 to 2, inclusive,
# in bitstreams conforming to this version of this Specification. Other values for num_extra_slice_header_bits are reserved
# for future use by ITU-T | ISO/IEC. However, decoders shall allow num_extra_slice_header_bits to have any value.
# TODO: get from PPS_NUT pic_parameter_set_rbsp( ) for corresponding slice_pic_parameter_set_id
num_extra_slice_header_bits = 0
skip_bits += num_extra_slice_header_bits
# 7.4.7.1 General slice segment header semantics
# slice_type specifies the coding type of the slice according to Table 7-7.
# Table 7-7 - Name association to slice_type
# slice_type | Name of slice_type
# 0 | B (B slice)
# 1 | P (P slice)
# 2 | I (I slice)
# unsigned integer 0-th order Exp-Golomb-coded syntax element with the left bit first
slice_type, _ = get_ue(dat, rbsp_start, skip_bits)
if DEBUG:
print(" slice_type:", slice_type, f"(first slice: {is_first_slice})")
if slice_type > 2:
raise VideoFileInvalid("slice_type must be 0, 1, or 2")
return slice_type, is_first_slice
def hevc_index(hevc_file_name: str, allow_corrupt: bool=False) -> Tuple[list, int, bytes]:
with FileReader(hevc_file_name) as f:
dat = f.read()
if len(dat) < NAL_UNIT_START_CODE_SIZE + 1:
raise VideoFileInvalid("data is too short")
if dat[0] != 0x00:
raise VideoFileInvalid("first byte must be 0x00")
prefix_dat = b""
frame_types = list()
i = 1 # skip past first byte 0x00
try:
while i < len(dat):
require_nal_unit_start(dat, i)
nal_unit_len = get_hevc_nal_unit_length(dat, i)
nal_unit_type = get_hevc_nal_unit_type(dat, i)
if nal_unit_type in HEVC_PARAMETER_SET_NAL_UNITS:
prefix_dat += dat[i:i+nal_unit_len]
elif nal_unit_type in HEVC_CODED_SLICE_SEGMENT_NAL_UNITS:
slice_type, is_first_slice = get_hevc_slice_type(dat, i, nal_unit_type)
if is_first_slice:
frame_types.append((slice_type, i))
i += nal_unit_len
except Exception as e:
if not allow_corrupt:
raise
print(f"ERROR: NAL unit skipped @ {i}\n", str(e))
return frame_types, len(dat), prefix_dat
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("input_file", type=str)
parser.add_argument("output_prefix_file", type=str)
parser.add_argument("output_index_file", type=str)
args = parser.parse_args()
frame_types, dat_len, prefix_dat = hevc_index(args.input_file)
with open(args.output_prefix_file, "wb") as f:
f.write(prefix_dat)
with open(args.output_index_file, "wb") as f:
for ft, fp in frame_types:
f.write(struct.pack("<II", ft, fp))
f.write(struct.pack("<II", 0xFFFFFFFF, dat_len))
if __name__ == "__main__":
main()