openpilot v0.9.6 release
date: 2024-01-12T10:13:37 master commit: ba792d576a49a0899b88a753fa1c52956bedf9e6
This commit is contained in:
51
tools/lib/README.md
Normal file
51
tools/lib/README.md
Normal file
@@ -0,0 +1,51 @@
|
||||
## LogReader
|
||||
|
||||
Route is a class for conveniently accessing all the [logs](/system/loggerd/) from your routes. The LogReader class reads the non-video logs, i.e. rlog.bz2 and qlog.bz2. There's also a matching FrameReader class for reading the videos.
|
||||
|
||||
```python
|
||||
from openpilot.tools.lib.route import Route
|
||||
from openpilot.tools.lib.logreader import LogReader
|
||||
|
||||
r = Route("a2a0ccea32023010|2023-07-27--13-01-19")
|
||||
|
||||
# get a list of paths for the route's rlog files
|
||||
print(r.log_paths())
|
||||
|
||||
# and road camera (fcamera.hevc) files
|
||||
print(r.camera_paths())
|
||||
|
||||
# setup a LogReader to read the route's first rlog
|
||||
lr = LogReader(r.log_paths()[0])
|
||||
|
||||
# print out all the messages in the log
|
||||
import codecs
|
||||
codecs.register_error("strict", codecs.backslashreplace_errors)
|
||||
for msg in lr:
|
||||
print(msg)
|
||||
|
||||
# setup a LogReader for the route's second qlog
|
||||
lr = LogReader(r.log_paths()[1])
|
||||
|
||||
# print all the steering angles values from the log
|
||||
for msg in lr:
|
||||
if msg.which() == "carState":
|
||||
print(msg.carState.steeringAngleDeg)
|
||||
```
|
||||
|
||||
### MultiLogIterator
|
||||
|
||||
`MultiLogIterator` is similar to `LogReader`, but reads multiple logs.
|
||||
|
||||
```python
|
||||
from openpilot.tools.lib.route import Route
|
||||
from openpilot.tools.lib.logreader import MultiLogIterator
|
||||
|
||||
# setup a MultiLogIterator to read all the logs in the route
|
||||
r = Route("a2a0ccea32023010|2023-07-27--13-01-19")
|
||||
lr = MultiLogIterator(r.log_paths())
|
||||
|
||||
# print all the steering angles values from all the logs in the route
|
||||
for msg in lr:
|
||||
if msg.which() == "carState":
|
||||
print(msg.carState.steeringAngleDeg)
|
||||
```
|
||||
0
tools/lib/__init__.py
Normal file
0
tools/lib/__init__.py
Normal file
34
tools/lib/api.py
Normal file
34
tools/lib/api.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import os
|
||||
import requests
|
||||
API_HOST = os.getenv('API_HOST', 'https://api.commadotai.com')
|
||||
|
||||
class CommaApi():
  """Minimal client for the comma.ai HTTP API.

  Wraps a requests.Session carrying the tools User-agent and, when a JWT is
  supplied, an Authorization header applied to every request.
  """

  def __init__(self, token=None):
    self.session = requests.Session()
    self.session.headers['User-agent'] = 'OpenpilotTools'
    if token:
      self.session.headers['Authorization'] = f'JWT {token}'

  def request(self, method, endpoint, **kwargs):
    """Issue a request against API_HOST and return the decoded JSON body.

    Raises:
      UnauthorizedError: on 401/403 error responses.
      APIError: for any other JSON error response; carries a status_code attribute.
    """
    response = self.session.request(method, API_HOST + '/' + endpoint, **kwargs)
    payload = response.json()

    has_error = isinstance(payload, dict) and payload.get('error')
    if not has_error:
      return payload

    if response.status_code in (401, 403):
      raise UnauthorizedError('Unauthorized. Authenticate with tools/lib/auth.py')

    detail = payload.get('description', str(payload['error']))
    err = APIError(str(response.status_code) + ":" + detail)
    err.status_code = response.status_code
    raise err

  def get(self, endpoint, **kwargs):
    """Convenience wrapper for GET requests."""
    return self.request('GET', endpoint, **kwargs)

  def post(self, endpoint, **kwargs):
    """Convenience wrapper for POST requests."""
    return self.request('POST', endpoint, **kwargs)
|
||||
|
||||
class APIError(Exception):
  # Raised for non-auth API error responses; CommaApi.request attaches a
  # status_code attribute before raising.
  pass
|
||||
|
||||
class UnauthorizedError(Exception):
  # Raised on 401/403 API responses; the caller should (re)authenticate,
  # e.g. via tools/lib/auth.py.
  pass
|
||||
145
tools/lib/auth.py
Executable file
145
tools/lib/auth.py
Executable file
@@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Usage::
|
||||
|
||||
usage: auth.py [-h] [{google,apple,github,jwt}] [jwt]
|
||||
|
||||
Login to your comma account
|
||||
|
||||
positional arguments:
|
||||
{google,apple,github,jwt}
|
||||
jwt
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
|
||||
|
||||
Examples::
|
||||
|
||||
./auth.py # Log in with google account
|
||||
./auth.py github # Log in with GitHub Account
|
||||
./auth.py jwt ey......hw # Log in with a JWT from https://jwt.comma.ai, for use in CI
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import pprint
|
||||
import webbrowser
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from typing import Any, Dict
|
||||
from urllib.parse import parse_qs, urlencode
|
||||
|
||||
from openpilot.tools.lib.api import APIError, CommaApi, UnauthorizedError
|
||||
from openpilot.tools.lib.auth_config import set_token, get_token
|
||||
|
||||
PORT = 3000
|
||||
|
||||
|
||||
class ClientRedirectServer(HTTPServer):
  # Local HTTP server that catches the OAuth redirect; the request handler
  # stores the parsed query string here for login() to read.
  query_params: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class ClientRedirectHandler(BaseHTTPRequestHandler):
  def do_GET(self):
    # Only the /auth callback matters; answer anything else with 204 No Content.
    if not self.path.startswith('/auth'):
      self.send_response(204)
      return

    # Stash the OAuth query parameters (code/provider, or error) on the server
    # object so login() can pick them up after handle_request() returns.
    query = self.path.split('?', 1)[-1]
    query_parsed = parse_qs(query, keep_blank_values=True)
    self.server.query_params = query_parsed

    self.send_response(200)
    self.send_header('Content-type', 'text/plain')
    self.end_headers()
    self.wfile.write(b'Return to the CLI to continue')

  def log_message(self, *args):
    pass  # this prevents the http server from dumping messages to stdout
|
||||
|
||||
|
||||
def auth_redirect_link(method):
  """Return the provider URL that starts a browser OAuth flow for `method`.

  The OAuth state routes the final redirect back to localhost:PORT, where
  ClientRedirectHandler captures the authorization code.
  """
  provider_id = {
    'google': 'g',
    'apple': 'a',
    'github': 'h',
  }[method]

  params = {
    'redirect_uri': f"https://api.comma.ai/v2/auth/{provider_id}/redirect/",
    'state': f'service,localhost:{PORT}',
  }

  # (authorization base URL, provider-specific OAuth parameters)
  providers = {
    'google': ('https://accounts.google.com/o/oauth2/auth?', {
      'type': 'web_server',
      'client_id': '45471411055-ornt4svd2miog6dnopve7qtmh5mnu6id.apps.googleusercontent.com',
      'response_type': 'code',
      'scope': 'https://www.googleapis.com/auth/userinfo.email',
      'prompt': 'select_account',
    }),
    'github': ('https://github.com/login/oauth/authorize?', {
      'client_id': '28c4ecb54bb7272cb5a4',
      'scope': 'read:user',
    }),
    'apple': ('https://appleid.apple.com/auth/authorize?', {
      'client_id': 'ai.comma.login',
      'response_type': 'code',
      'response_mode': 'form_post',
      'scope': 'name email',
    }),
  }

  if method not in providers:
    raise NotImplementedError(f"no redirect implemented for method {method}")

  base_url, extra_params = providers[method]
  params.update(extra_params)
  return base_url + urlencode(params)
|
||||
|
||||
|
||||
def login(method):
  """Run a browser-based OAuth login and store the resulting access token.

  Spins up a local HTTP server on PORT to catch the OAuth redirect, points the
  user's browser at the provider's login page, then exchanges the received
  code for a comma JWT stored via set_token().
  """
  oauth_uri = auth_redirect_link(method)

  web_server = ClientRedirectServer(('localhost', PORT), ClientRedirectHandler)
  print(f'To sign in, use your browser and navigate to {oauth_uri}')
  webbrowser.open(oauth_uri, new=2)

  while True:
    # handle_request() serves one request; loop until the /auth callback
    # delivers either a code or an error
    web_server.handle_request()
    if 'code' in web_server.query_params:
      break
    elif 'error' in web_server.query_params:
      print('Authentication Error: "{}". Description: "{}" '.format(
        web_server.query_params['error'],
        web_server.query_params.get('error_description')), file=sys.stderr)
      break

  # bug fix: on the OAuth error path there is no 'code' to exchange; bail out
  # instead of raising an uncaught KeyError below
  if 'code' not in web_server.query_params:
    return

  try:
    auth_resp = CommaApi().post('v2/auth/', data={'code': web_server.query_params['code'], 'provider': web_server.query_params['provider']})
    set_token(auth_resp['access_token'])
  except APIError as e:
    print(f'Authentication Error: {e}', file=sys.stderr)
|
||||
|
||||
|
||||
if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='Login to your comma account')
  # 'method' defaults to google; the positional 'jwt' is only used with method == 'jwt'
  parser.add_argument('method', default='google', const='google', nargs='?', choices=['google', 'apple', 'github', 'jwt'])
  parser.add_argument('jwt', nargs='?')

  args = parser.parse_args()
  if args.method == 'jwt':
    if args.jwt is None:
      print("method JWT selected, but no JWT was provided")
      exit(1)

    set_token(args.jwt)
  else:
    login(args.method)

  try:
    # verify whatever token ended up stored by hitting the identity endpoint
    me = CommaApi(token=get_token()).get('/v1/me')
    print("Authenticated!")
    pprint.pprint(me)
  except UnauthorizedError:
    print("Got invalid JWT")
    exit(1)
|
||||
29
tools/lib/auth_config.py
Normal file
29
tools/lib/auth_config.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import json
|
||||
import os
|
||||
from openpilot.system.hardware.hw import Paths
|
||||
|
||||
|
||||
class MissingAuthConfigError(Exception):
  # Signals that auth credentials are required but not configured.
  # NOTE(review): not raised in this module itself; presumably raised by callers.
  pass
|
||||
|
||||
|
||||
def get_token():
  """Return the stored comma access token, or None when not logged in."""
  token_path = os.path.join(Paths.config_root(), 'auth.json')
  try:
    with open(token_path) as f:
      return json.load(f)['access_token']
  except Exception:
    # missing file, malformed JSON or a missing key all mean "no token"
    return None
|
||||
|
||||
|
||||
def set_token(token):
  """Persist the comma access token as auth.json in the config directory."""
  os.makedirs(Paths.config_root(), exist_ok=True)
  token_path = os.path.join(Paths.config_root(), 'auth.json')
  with open(token_path, 'w') as f:
    json.dump({'access_token': token}, f)
|
||||
|
||||
|
||||
def clear_token():
  """Delete the stored token; quietly succeed if it was never written."""
  token_path = os.path.join(Paths.config_root(), 'auth.json')
  try:
    os.unlink(token_path)
  except FileNotFoundError:
    pass
|
||||
63
tools/lib/bootlog.py
Normal file
63
tools/lib/bootlog.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import datetime
|
||||
import functools
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from openpilot.tools.lib.auth_config import get_token
|
||||
from openpilot.tools.lib.api import CommaApi
|
||||
from openpilot.tools.lib.helpers import RE, timestamp_to_datetime
|
||||
|
||||
|
||||
@functools.total_ordering
class Bootlog:
  """A single device bootlog, identified by the dongle id and timestamp
  parsed out of its URL (RE.BOOTLOG_NAME). Ordered by timestamp."""

  def __init__(self, url: str):
    self._url = url

    r = re.search(RE.BOOTLOG_NAME, url)
    if not r:
      raise Exception(f"Unable to parse: {url}")

    self._dongle_id = r.group('dongle_id')
    self._timestamp = r.group('timestamp')

  @property
  def url(self) -> str:
    return self._url

  @property
  def dongle_id(self) -> str:
    return self._dongle_id

  @property
  def timestamp(self) -> str:
    # raw timestamp string as it appears in the URL
    return self._timestamp

  @property
  def datetime(self) -> datetime.datetime:
    return timestamp_to_datetime(self._timestamp)

  def __str__(self):
    return f"{self._dongle_id}|{self._timestamp}"

  def __eq__(self, b) -> bool:
    if not isinstance(b, Bootlog):
      return False
    return self.datetime == b.datetime

  def __lt__(self, b):
    # bug fix: returning False for foreign types made `bootlog < other` silently
    # "succeed"; NotImplemented lets Python try the reflected operation or raise
    # TypeError, which total_ordering also relies on for the derived operators
    if not isinstance(b, Bootlog):
      return NotImplemented
    return self.datetime < b.datetime
|
||||
|
||||
def get_bootlog_from_id(bootlog_id: str) -> Optional[Bootlog]:
  """Look up a bootlog by its id, or None when the device has no match."""
  # TODO: implement an API endpoint for this
  target = Bootlog(bootlog_id)
  return next((b for b in get_bootlogs(target.dongle_id) if b == target), None)
|
||||
|
||||
def get_bootlogs(dongle_id: str) -> List[Bootlog]:
  """Fetch every bootlog for a device from the comma API."""
  resp = CommaApi(get_token()).get(f'v1/devices/{dongle_id}/bootlogs')
  return [Bootlog(url) for url in resp]
|
||||
14
tools/lib/cache.py
Normal file
14
tools/lib/cache.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import os
|
||||
import urllib.parse
|
||||
|
||||
DEFAULT_CACHE_DIR = os.getenv("CACHE_ROOT", os.path.expanduser("~/.commacache"))


def cache_path_for_file_path(fn, cache_dir=DEFAULT_CACHE_DIR):
  """Map a local path or URL to a unique flat filename under cache_dir/local."""
  local_dir = os.path.join(cache_dir, "local")
  os.makedirs(local_dir, exist_ok=True)

  parsed = urllib.parse.urlparse(fn)
  if parsed.scheme == '':
    # local file: key on its absolute path
    key = os.path.abspath(fn).replace("/", "_")
  else:
    # remote file: key on host plus path
    key = f'{parsed.hostname}_{parsed.path.replace("/", "_")}'
  return os.path.join(local_dir, key)
|
||||
2
tools/lib/exceptions.py
Normal file
2
tools/lib/exceptions.py
Normal file
@@ -0,0 +1,2 @@
|
||||
class DataUnreadableError(Exception):
  # Raised when a log/video file is empty, corrupt or in an unsupported format.
  pass
|
||||
15
tools/lib/filereader.py
Normal file
15
tools/lib/filereader.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import os
|
||||
from openpilot.tools.lib.url_file import URLFile
|
||||
|
||||
DATA_ENDPOINT = os.getenv("DATA_ENDPOINT", "http://data-raw.comma.internal/")


def resolve_name(fn):
  """Expand the internal cd:/ shorthand into a full DATA_ENDPOINT URL."""
  if fn.startswith("cd:/"):
    return fn.replace("cd:/", DATA_ENDPOINT)
  return fn


def FileReader(fn, debug=False):
  """Open a log file for binary reading, transparently handling URLs.

  Returns a URLFile for http(s) names, otherwise a plain binary file object.
  """
  resolved = resolve_name(fn)
  if resolved.startswith(("http://", "https://")):
    return URLFile(resolved, debug=debug)
  return open(resolved, "rb")
|
||||
537
tools/lib/framereader.py
Normal file
537
tools/lib/framereader.py
Normal file
@@ -0,0 +1,537 @@
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import struct
|
||||
import subprocess
|
||||
import threading
|
||||
from enum import IntEnum
|
||||
from functools import wraps
|
||||
|
||||
import numpy as np
|
||||
from lru import LRU
|
||||
|
||||
import _io
|
||||
from openpilot.tools.lib.cache import cache_path_for_file_path, DEFAULT_CACHE_DIR
|
||||
from openpilot.tools.lib.exceptions import DataUnreadableError
|
||||
from openpilot.tools.lib.vidindex import hevc_index
|
||||
from openpilot.common.file_helpers import atomic_write_in_dir
|
||||
|
||||
from openpilot.tools.lib.filereader import FileReader, resolve_name
|
||||
|
||||
# HEVC slice types as stored in the frame index (I-slices start a GOP)
HEVC_SLICE_B = 0
HEVC_SLICE_P = 1
HEVC_SLICE_I = 2
|
||||
|
||||
|
||||
class GOPReader:
  """Interface for readers that can return compressed group-of-pictures data."""

  def get_gop(self, num):
    # returns (start_frame_num, num_frames, frames_to_skip, gop_data)
    raise NotImplementedError
|
||||
|
||||
|
||||
class DoNothingContextManager:
  """No-op stand-in used where a lock/context manager is optional."""

  def __enter__(self):
    return self

  def __exit__(self, *x):
    return None
|
||||
|
||||
|
||||
class FrameType(IntEnum):
  # video formats recognized by fingerprint_video()
  raw = 1          # fixed-size raw frames (see RawFrameReader)
  h265_stream = 2  # bare HEVC bitstream
|
||||
|
||||
|
||||
def fingerprint_video(fn):
  """Determine a file's FrameType from its first four bytes."""
  with FileReader(fn) as f:
    magic = f.read(4)

  if len(magic) == 0:
    raise DataUnreadableError(f"{fn} is empty")

  if magic == b"\x00\xc0\x12\x00":
    return FrameType.raw
  if magic == b"\x00\x00\x00\x01":
    # Annex B start code; only files named *hevc* are treated as HEVC streams
    if 'hevc' in fn:
      return FrameType.h265_stream
    raise NotImplementedError(fn)
  raise NotImplementedError(fn)
|
||||
|
||||
|
||||
def ffprobe(fn, fmt=None):
  """Run ffprobe over the first 4 KiB of a file and return its parsed JSON.

  Raises DataUnreadableError when ffprobe rejects the input.
  """
  fn = resolve_name(fn)
  cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", "-show_format", "-show_streams"]
  if fmt:
    cmd += ["-f", fmt]
  cmd += ["-i", "-"]

  try:
    with FileReader(fn) as f:
      probe_out = subprocess.check_output(cmd, input=f.read(4096))
  except subprocess.CalledProcessError as e:
    raise DataUnreadableError(fn) from e

  return json.loads(probe_out)
|
||||
|
||||
|
||||
def cache_fn(func):
  """Decorator that memoizes func(fn, ...) results on disk, keyed by file path.

  Callers may pass no_cache=True to bypass the cache entirely, or cache_dir=...
  to choose where the pickle lives.
  """
  @wraps(func)
  def cache_inner(fn, *args, **kwargs):
    skip_cache = kwargs.pop('no_cache', None)
    if skip_cache:
      cache_path = None
    else:
      cache_dir = kwargs.pop('cache_dir', DEFAULT_CACHE_DIR)
      cache_path = cache_path_for_file_path(fn, cache_dir)

    if cache_path and os.path.exists(cache_path):
      with open(cache_path, "rb") as cache_file:
        return pickle.load(cache_file)

    result = func(fn, *args, **kwargs)
    if cache_path:
      # write atomically so a concurrent reader never sees a partial pickle
      with atomic_write_in_dir(cache_path, mode="wb", overwrite=True) as cache_file:
        pickle.dump(result, cache_file, -1)
    return result

  return cache_inner
|
||||
|
||||
|
||||
@cache_fn
def index_stream(fn, ft):
  # Build (and disk-cache, via @cache_fn) the frame index of an HEVC stream.
  # Returns a dict with:
  #   'index': uint32 array of (slice_type, byte_offset) rows from hevc_index,
  #            plus a (0xFFFFFFFF, dat_len) sentinel row at the end
  #   'global_prefix': header bytes from hevc_index, prepended to each GOP
  #                    before decoding (see StreamGOPReader.get_gop)
  #   'probe': ffprobe metadata for the stream
  if ft != FrameType.h265_stream:
    raise NotImplementedError("Only h265 supported")

  frame_types, dat_len, prefix = hevc_index(fn)
  index = np.array(frame_types + [(0xFFFFFFFF, dat_len)], dtype=np.uint32)
  probe = ffprobe(fn, "hevc")

  return {
    'index': index,
    'global_prefix': prefix,
    'probe': probe
  }
|
||||
|
||||
|
||||
def get_video_index(fn, frame_type, cache_dir=DEFAULT_CACHE_DIR):
  # Thin wrapper passing cache_dir through to the @cache_fn-decorated index_stream.
  return index_stream(fn, frame_type, cache_dir=cache_dir)
|
||||
|
||||
def read_file_check_size(f, sz, cookie):
  """Read exactly sz bytes from f into a fresh bytearray.

  The cookie argument is accepted for callback-signature compatibility and is
  unused. Asserts on short reads.
  """
  out = bytearray(sz)
  n = f.readinto(out)
  assert n == sz, (n, sz)
  return out
|
||||
|
||||
|
||||
def rgb24toyuv(rgb):
  """Convert an (H, W, 3) RGB image to a full-res Y plane and 2x2-subsampled
  U/V planes (still float; callers clip and cast)."""
  # RGB -> YUV conversion matrix (rows: Y, U, V)
  rgb_to_yuv = np.array([[ 0.299     ,  0.587     ,  0.114      ],
                         [-0.14714119, -0.28886916,  0.43601035 ],
                         [ 0.61497538, -0.51496512, -0.10001026 ]])
  img = np.dot(rgb.reshape(-1, 3), rgb_to_yuv.T).reshape(rgb.shape)

  ys = img[:, :, 0]

  def subsample(c):
    # average each 2x2 block (same summation order as before) and re-center on 128
    return (img[::2, ::2, c] + img[1::2, ::2, c] + img[::2, 1::2, c] + img[1::2, 1::2, c]) / 4 + 128

  return ys, subsample(1), subsample(2)


def rgb24toyuv420(rgb):
  """Convert RGB to planar YUV 4:2:0 (all Y, then all U, then all V) uint8."""
  ys, us, vs = rgb24toyuv(rgb)

  y_len = rgb.shape[0] * rgb.shape[1]
  uv_len = y_len // 4

  out = np.empty(y_len + 2 * uv_len, dtype=rgb.dtype)
  out[:y_len] = ys.reshape(-1)
  out[y_len:y_len + uv_len] = us.reshape(-1)
  out[y_len + uv_len:] = vs.reshape(-1)

  return out.clip(0, 255).astype('uint8')


def rgb24tonv12(rgb):
  """Convert RGB to NV12 (full Y plane followed by interleaved U/V) uint8."""
  ys, us, vs = rgb24toyuv(rgb)

  y_len = rgb.shape[0] * rgb.shape[1]
  uv_len = y_len // 4

  out = np.empty(y_len + 2 * uv_len, dtype=rgb.dtype)
  out[:y_len] = ys.reshape(-1)
  # interleave: U on even offsets, V on odd offsets after the Y plane
  out[y_len::2] = us.reshape(-1)
  out[y_len + 1::2] = vs.reshape(-1)

  return out.clip(0, 255).astype('uint8')
|
||||
|
||||
|
||||
def decompress_video_data(rawdat, vid_fmt, w, h, pix_fmt):
  # Decode a compressed bitstream with one ffmpeg invocation and return all
  # frames as a numpy array whose layout depends on pix_fmt.
  threads = os.getenv("FFMPEG_THREADS", "0")
  cuda = os.getenv("FFMPEG_CUDA", "0") == "1"
  args = ["ffmpeg", "-v", "quiet",
          "-threads", threads,
          "-hwaccel", "none" if not cuda else "cuda",
          "-c:v", "hevc",
          "-vsync", "0",
          "-f", vid_fmt,
          "-flags2", "showall",   # also emit frames before the first keyframe
          "-i", "-",
          "-threads", threads,
          "-f", "rawvideo",
          "-pix_fmt", pix_fmt,
          "-"]
  dat = subprocess.check_output(args, input=rawdat)

  # reshape by pixel format: packed RGB, flat 4:2:0 buffers, or planar 4:4:4
  if pix_fmt == "rgb24":
    ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, h, w, 3)
  elif pix_fmt == "nv12":
    ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, (h*w*3//2))
  elif pix_fmt == "yuv420p":
    ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, (h*w*3//2))
  elif pix_fmt == "yuv444p":
    ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, 3, h, w)
  else:
    raise NotImplementedError

  return ret
|
||||
|
||||
|
||||
class BaseFrameReader:
  """Common interface for frame readers.

  Subclasses provide: frame_type, frame_count, w, h.
  """

  def __enter__(self):
    return self

  def __exit__(self, *args):
    self.close()

  def close(self):
    # subclasses holding resources override this
    pass

  def get(self, num, count=1, pix_fmt="yuv420p"):
    raise NotImplementedError
|
||||
|
||||
|
||||
def FrameReader(fn, cache_dir=DEFAULT_CACHE_DIR, readahead=False, readbehind=False, index_data=None):
  # Factory: pick the reader implementation matching the file's fingerprint.
  frame_type = fingerprint_video(fn)
  if frame_type == FrameType.raw:
    return RawFrameReader(fn)
  elif frame_type in (FrameType.h265_stream,):
    # HEVC streams need a frame index; build (or load from cache) if not given
    if not index_data:
      index_data = get_video_index(fn, frame_type, cache_dir)
    return StreamFrameReader(fn, frame_type, index_data, readahead=readahead, readbehind=readbehind)
  else:
    raise NotImplementedError(frame_type)
|
||||
|
||||
|
||||
class RawData:
  """Random access over a file of fixed-size records, each prefixed by a
  4-byte native-endian length (all records share the first record's length)."""

  def __init__(self, f):
    self.f = _io.FileIO(f, 'rb')
    # record payload size, taken from the first record's header
    self.lenn = struct.unpack("I", self.f.read(4))[0]
    self.count = os.path.getsize(f) / (self.lenn + 4)

  def read(self, i):
    # skip i full records plus the 4-byte header of record i
    self.f.seek((self.lenn + 4) * i + 4)
    return self.f.read(self.lenn)
|
||||
|
||||
|
||||
class RawFrameReader(BaseFrameReader):
  """Reads fixed-size raw (bayer-mosaic) camera frames via RawData."""

  def __init__(self, fn):
    # raw camera
    self.fn = fn
    self.frame_type = FrameType.raw
    self.rawfile = RawData(self.fn)
    self.frame_count = self.rawfile.count
    # output resolution is half the 1280x960 sensor mosaic in each dimension
    self.w, self.h = 640, 480

  def load_and_debayer(self, img):
    """Turn one 1280x960 bayer frame into a 640x480 3-channel image.

    Each output pixel takes one sample per channel from its 2x2 mosaic cell;
    the middle channel averages the two diagonal samples (the >> 1).
    """
    img = np.frombuffer(img, dtype='uint8').reshape(960, 1280)
    cimg = np.dstack([img[0::2, 1::2], ((img[0::2, 0::2].astype("uint16") + img[1::2, 1::2].astype("uint16")) >> 1).astype("uint8"), img[1::2, 0::2]])
    return cimg

  def get(self, num, count=1, pix_fmt="yuv420p"):
    """Return `count` frames starting at `num`, converted to pix_fmt."""
    assert self.frame_count is not None
    assert num+count <= self.frame_count

    if pix_fmt not in ("nv12", "yuv420p", "rgb24"):
      raise ValueError(f"Unsupported pixel format {pix_fmt!r}")

    app = []
    for i in range(num, num+count):
      dat = self.rawfile.read(i)
      rgb_dat = self.load_and_debayer(dat)
      if pix_fmt == "rgb24":
        app.append(rgb_dat)
      elif pix_fmt == "nv12":
        app.append(rgb24tonv12(rgb_dat))
      elif pix_fmt == "yuv420p":
        app.append(rgb24toyuv420(rgb_dat))
      else:
        raise NotImplementedError

    return app
|
||||
|
||||
|
||||
class VideoStreamDecompressor:
  """Streams a whole video file through one ffmpeg process, yielding frames.

  A daemon thread pumps the compressed bitstream into ffmpeg's stdin while
  read() consumes fixed-size decoded frames from its stdout.
  """

  def __init__(self, fn, vid_fmt, w, h, pix_fmt):
    self.fn = fn
    self.vid_fmt = vid_fmt
    self.w = w
    self.h = h
    self.pix_fmt = pix_fmt

    # bytes per decoded frame on ffmpeg's stdout
    if pix_fmt in ("nv12", "yuv420p"):
      self.out_size = w*h*3//2  # yuv420p
    elif pix_fmt in ("rgb24", "yuv444p"):
      self.out_size = w*h*3
    else:
      raise NotImplementedError

    # proc is created in read(); the writer thread is started there too
    self.proc = None
    self.t = threading.Thread(target=self.write_thread)
    self.t.daemon = True

  def write_thread(self):
    """Feed the input file into ffmpeg's stdin in 1 MiB chunks."""
    try:
      with FileReader(self.fn) as f:
        while True:
          r = f.read(1024*1024)
          if len(r) == 0:
            break
          self.proc.stdin.write(r)
    except BrokenPipeError:
      # ffmpeg exited early (e.g. reader stopped); nothing to do
      pass
    finally:
      # closing stdin signals EOF so ffmpeg can flush and exit
      self.proc.stdin.close()

  def read(self):
    """Generator yielding one decoded frame (numpy array) at a time."""
    threads = os.getenv("FFMPEG_THREADS", "0")
    cuda = os.getenv("FFMPEG_CUDA", "0") == "1"
    cmd = [
      "ffmpeg",
      "-threads", threads,
      "-hwaccel", "none" if not cuda else "cuda",
      "-c:v", "hevc",
      # "-avioflags", "direct",
      "-analyzeduration", "0",
      "-probesize", "32",
      "-flush_packets", "0",
      # "-fflags", "nobuffer",
      "-vsync", "0",
      "-f", self.vid_fmt,
      "-i", "pipe:0",
      "-threads", threads,
      "-f", "rawvideo",
      "-pix_fmt", self.pix_fmt,
      "pipe:1"
    ]
    self.proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
    try:
      self.t.start()

      while True:
        dat = self.proc.stdout.read(self.out_size)
        if len(dat) == 0:
          break
        assert len(dat) == self.out_size
        if self.pix_fmt == "rgb24":
          ret = np.frombuffer(dat, dtype=np.uint8).reshape((self.h, self.w, 3))
        elif self.pix_fmt == "yuv420p":
          ret = np.frombuffer(dat, dtype=np.uint8)
        elif self.pix_fmt == "nv12":
          ret = np.frombuffer(dat, dtype=np.uint8)
        elif self.pix_fmt == "yuv444p":
          ret = np.frombuffer(dat, dtype=np.uint8).reshape((3, self.h, self.w))
        else:
          raise RuntimeError(f"unknown pix_fmt: {self.pix_fmt}")
        yield ret

      result_code = self.proc.wait()
      assert result_code == 0, result_code
    finally:
      # kill + join even on error so neither the process nor the writer leaks
      self.proc.kill()
      self.t.join()
|
||||
|
||||
class StreamGOPReader(GOPReader):
  """GOP-level access to an HEVC stream using the index from index_stream()."""

  def __init__(self, fn, frame_type, index_data):
    assert frame_type == FrameType.h265_stream

    self.fn = fn

    self.frame_type = frame_type
    self.frame_count = None
    self.w, self.h = None, None

    self.prefix = None
    self.index = None

    # index rows are (slice_type, byte_offset); the last row is a sentinel
    self.index = index_data['index']
    self.prefix = index_data['global_prefix']
    probe = index_data['probe']

    # prefix-frame support is unused for these streams (first_iframe == 0 below)
    self.prefix_frame_data = None
    self.num_prefix_frames = 0
    self.vid_fmt = "hevc"

    # find the first I-slice in the index
    i = 0
    while i < self.index.shape[0] and self.index[i, 0] != HEVC_SLICE_I:
      i += 1
    self.first_iframe = i

    assert self.first_iframe == 0

    # minus one for the sentinel row
    self.frame_count = len(self.index) - 1

    self.w = probe['streams'][0]['width']
    self.h = probe['streams'][0]['height']

  def _lookup_gop(self, num):
    # scan back to the I-frame that starts the GOP containing `num` ...
    frame_b = num
    while frame_b > 0 and self.index[frame_b, 0] != HEVC_SLICE_I:
      frame_b -= 1

    # ... and forward to the start of the next GOP (or end of stream)
    frame_e = num + 1
    while frame_e < (len(self.index) - 1) and self.index[frame_e, 0] != HEVC_SLICE_I:
      frame_e += 1

    offset_b = self.index[frame_b, 1]
    offset_e = self.index[frame_e, 1]

    return (frame_b, frame_e, offset_b, offset_e)

  def get_gop(self, num):
    """Return (start_frame, num_frames, frames_to_skip, bitstream bytes) for
    the GOP containing frame `num`, with the stream prefix prepended so the
    bytes decode standalone."""
    frame_b, frame_e, offset_b, offset_e = self._lookup_gop(num)
    assert frame_b <= num < frame_e

    num_frames = frame_e - frame_b

    with FileReader(self.fn) as f:
      f.seek(offset_b)
      rawdat = f.read(offset_e - offset_b)

    if num < self.first_iframe:
      assert self.prefix_frame_data
      rawdat = self.prefix_frame_data + rawdat

    rawdat = self.prefix + rawdat

    skip_frames = 0
    if num < self.first_iframe:
      skip_frames = self.num_prefix_frames

    return frame_b, num_frames, skip_frames, rawdat
|
||||
|
||||
|
||||
class GOPFrameReader(BaseFrameReader):
  #FrameReader with caching and readahead for formats that are group-of-picture based

  def __init__(self, readahead=False, readbehind=False):
    # open_ doubles as the shutdown flag for the readahead thread
    self.open_ = True

    self.readahead = readahead
    self.readbehind = readbehind
    self.frame_cache = LRU(64)

    if self.readahead:
      self.cache_lock = threading.RLock()
      self.readahead_last = None
      self.readahead_len = 30
      self.readahead_c = threading.Condition()
      self.readahead_thread = threading.Thread(target=self._readahead_thread)
      self.readahead_thread.daemon = True
      self.readahead_thread.start()
    else:
      # no background thread -> no real locking needed
      self.cache_lock = DoNothingContextManager()

  def close(self):
    if not self.open_:
      return
    self.open_ = False

    if self.readahead:
      # wake the readahead thread so it can observe open_ == False and exit
      self.readahead_c.acquire()
      self.readahead_c.notify()
      self.readahead_c.release()
      self.readahead_thread.join()

  def _readahead_thread(self):
    # Background worker: after each get(), prefetches the next readahead_len
    # frames (or the previous ones when readbehind) into frame_cache.
    while True:
      self.readahead_c.acquire()
      try:
        if not self.open_:
          break
        self.readahead_c.wait()
      finally:
        self.readahead_c.release()
      if not self.open_:
        break
      assert self.readahead_last
      num, pix_fmt = self.readahead_last

      if self.readbehind:
        for k in range(num - 1, max(0, num - self.readahead_len), -1):
          self._get_one(k, pix_fmt)
      else:
        for k in range(num, min(self.frame_count, num + self.readahead_len)):
          self._get_one(k, pix_fmt)

  def _get_one(self, num, pix_fmt):
    assert num < self.frame_count

    # fast path without the lock; re-checked under the lock below to avoid
    # racing the readahead thread
    if (num, pix_fmt) in self.frame_cache:
      return self.frame_cache[(num, pix_fmt)]

    with self.cache_lock:
      if (num, pix_fmt) in self.frame_cache:
        return self.frame_cache[(num, pix_fmt)]

      # decode the whole GOP containing `num` and cache every frame of it
      frame_b, num_frames, skip_frames, rawdat = self.get_gop(num)

      ret = decompress_video_data(rawdat, self.vid_fmt, self.w, self.h, pix_fmt)
      ret = ret[skip_frames:]
      assert ret.shape[0] == num_frames

      for i in range(ret.shape[0]):
        self.frame_cache[(frame_b+i, pix_fmt)] = ret[i]

      return self.frame_cache[(num, pix_fmt)]

  def get(self, num, count=1, pix_fmt="yuv420p"):
    """Return `count` decoded frames starting at `num` in pix_fmt."""
    assert self.frame_count is not None

    if num + count > self.frame_count:
      raise ValueError(f"{num + count} > {self.frame_count}")

    if pix_fmt not in ("nv12", "yuv420p", "rgb24", "yuv444p"):
      raise ValueError(f"Unsupported pixel format {pix_fmt!r}")

    ret = [self._get_one(num + i, pix_fmt) for i in range(count)]

    if self.readahead:
      # kick the readahead thread to prefetch past the end of this request
      self.readahead_last = (num+count, pix_fmt)
      self.readahead_c.acquire()
      self.readahead_c.notify()
      self.readahead_c.release()

    return ret
|
||||
|
||||
|
||||
class StreamFrameReader(StreamGOPReader, GOPFrameReader):
  # Combines indexed GOP access (StreamGOPReader) with caching and optional
  # readahead (GOPFrameReader).
  def __init__(self, fn, frame_type, index_data, readahead=False, readbehind=False):
    StreamGOPReader.__init__(self, fn, frame_type, index_data)
    GOPFrameReader.__init__(self, readahead, readbehind)
|
||||
|
||||
|
||||
def GOPFrameIterator(gop_reader, pix_fmt):
  # Decode every frame of a GOP-based reader by streaming the whole file
  # through one ffmpeg process.
  dec = VideoStreamDecompressor(gop_reader.fn, gop_reader.vid_fmt, gop_reader.w, gop_reader.h, pix_fmt)
  yield from dec.read()
|
||||
|
||||
|
||||
def FrameIterator(fn, pix_fmt, **kwargs):
  # Iterate all frames of a video: GOP-based formats are streamed in bulk,
  # anything else falls back to frame-by-frame get().
  fr = FrameReader(fn, **kwargs)
  if isinstance(fr, GOPReader):
    yield from GOPFrameIterator(fr, pix_fmt)
  else:
    for i in range(fr.frame_count):
      yield fr.get(i, pix_fmt=pix_fmt)[0]
|
||||
35
tools/lib/helpers.py
Normal file
35
tools/lib/helpers.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import bz2
|
||||
import datetime
|
||||
|
||||
TIME_FMT = "%Y-%m-%d--%H-%M-%S"  # strftime format of route timestamps

# regex patterns
class RE:
  # building blocks for parsing openpilot route/segment/bootlog names
  DONGLE_ID = r'(?P<dongle_id>[a-z0-9]{16})'
  TIMESTAMP = r'(?P<timestamp>[0-9]{4}-[0-9]{2}-[0-9]{2}--[0-9]{2}-[0-9]{2}-[0-9]{2})'
  ROUTE_NAME = r'(?P<route_name>{}[|_/]{})'.format(DONGLE_ID, TIMESTAMP)
  SEGMENT_NAME = r'{}(?:--|/)(?P<segment_num>[0-9]+)'.format(ROUTE_NAME)
  INDEX = r'-?[0-9]+'
  SLICE = r'(?P<start>{})?:?(?P<end>{})?:?(?P<step>{})?'.format(INDEX, INDEX, INDEX)
  SEGMENT_RANGE = r'{}(?:--|/)?(?P<slice>({}))?/?(?P<selector>([qr]))?'.format(ROUTE_NAME, SLICE)
  BOOTLOG_NAME = ROUTE_NAME  # bootlogs share the dongle_id|timestamp naming

  EXPLORER_FILE = r'^(?P<segment_name>{})--(?P<file_name>[a-z]+\.[a-z0-9]+)$'.format(SEGMENT_NAME)
  OP_SEGMENT_DIR = r'^(?P<segment_name>{})$'.format(SEGMENT_NAME)
|
||||
|
||||
|
||||
def timestamp_to_datetime(t: str) -> datetime.datetime:
  """Parse an openpilot route timestamp (TIME_FMT) into a python datetime."""
  parsed = datetime.datetime.strptime(t, TIME_FMT)
  return parsed
|
||||
|
||||
|
||||
def save_log(dest, log_msgs, compress=True):
  """Serialize capnp log messages and write them to dest, bz2-compressed by default."""
  blob = b"".join(msg.as_builder().to_bytes() for msg in log_msgs)

  if compress:
    blob = bz2.compress(blob)

  with open(dest, "wb") as f:
    f.write(blob)
|
||||
81
tools/lib/kbhit.py
Executable file
81
tools/lib/kbhit.py
Executable file
@@ -0,0 +1,81 @@
|
||||
#!/usr/bin/env python
|
||||
import sys
|
||||
import termios
|
||||
import atexit
|
||||
from select import select
|
||||
|
||||
STDIN_FD = sys.stdin.fileno()
|
||||
|
||||
class KBHit:
  def __init__(self) -> None:
    ''' Creates a KBHit object that you can call to do various keyboard things.
    '''

    self.set_kbhit_terminal()

  def set_kbhit_terminal(self) -> None:
    ''' Save old terminal settings for closure, remove ICANON & ECHO flags.
    '''

    # Save the terminal settings
    self.old_term = termios.tcgetattr(STDIN_FD)
    self.new_term = self.old_term.copy()

    # New terminal setting unbuffered: clear canonical mode and echo in lflags
    self.new_term[3] &= ~(termios.ICANON | termios.ECHO)
    termios.tcsetattr(STDIN_FD, termios.TCSAFLUSH, self.new_term)

    # Support normal-terminal reset at exit
    atexit.register(self.set_normal_term)

  def set_normal_term(self) -> None:
    ''' Resets to normal terminal. On Windows this is a no-op.
    '''

    termios.tcsetattr(STDIN_FD, termios.TCSAFLUSH, self.old_term)

  @staticmethod
  def getch() -> str:
    ''' Returns a keyboard character after kbhit() has been called.
        Should not be called in the same program as getarrow().
    '''
    return sys.stdin.read(1)

  @staticmethod
  def getarrow() -> int:
    ''' Returns an arrow-key code after kbhit() has been called. Codes are
        0 : up
        1 : right
        2 : down
        3 : left
        Should not be called in the same program as getch().
    '''

    # arrow keys arrive as a 3-byte escape sequence; the third byte identifies
    # the direction (65=A/up, 67=C/right, 66=B/down, 68=D/left)
    c = sys.stdin.read(3)[2]
    vals = [65, 67, 66, 68]

    return vals.index(ord(c))

  @staticmethod
  def kbhit():
    ''' Returns True if keyboard character was hit, False otherwise.
    '''
    # zero-timeout select: poll stdin for pending input without blocking
    return select([sys.stdin], [], [], 0)[0] != []
|
||||
|
||||
|
||||
# Test: manual smoke test -- echo keypresses until ESC is hit
if __name__ == "__main__":

  kb = KBHit()

  print('Hit any key, or ESC to exit')

  while True:

    if kb.kbhit():
      c = kb.getch()
      if c == '\x1b':  # ESC
        break
      print(c)

  kb.set_normal_term()
|
||||
141
tools/lib/logreader.py
Executable file
141
tools/lib/logreader.py
Executable file
@@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import sys
|
||||
import bz2
|
||||
import urllib.parse
|
||||
import capnp
|
||||
import warnings
|
||||
|
||||
from typing import Iterable, Iterator
|
||||
|
||||
from cereal import log as capnp_log
|
||||
from openpilot.tools.lib.filereader import FileReader
|
||||
from openpilot.tools.lib.route import Route, SegmentName
|
||||
|
||||
# type alias: anything that yields capnp log events
LogIterable = Iterable[capnp._DynamicStructReader]
|
||||
|
||||
# this is an iterator itself, and uses private variables from LogReader
class MultiLogIterator:
  """Iterates over the events of several (possibly sparse) log files in order.

  log_paths is a list indexed by segment number; entries may be None for
  missing segments, which are skipped.
  """
  def __init__(self, log_paths, sort_by_time=False):
    self._log_paths = log_paths
    self.sort_by_time = sort_by_time

    # index of the first present (non-None) log; next() raises StopIteration
    # here if every entry is None
    self._first_log_idx = next(i for i in range(len(log_paths)) if log_paths[i] is not None)
    self._current_log = self._first_log_idx
    self._idx = 0
    # LogReaders are created lazily, one cache slot per path
    self._log_readers = [None]*len(log_paths)
    # logMonoTime of the first event; zero point for tell()/seek()
    self.start_time = self._log_reader(self._first_log_idx)._ts[0]

  def _log_reader(self, i):
    # lazily construct (and cache) the LogReader for log i
    if self._log_readers[i] is None and self._log_paths[i] is not None:
      log_path = self._log_paths[i]
      self._log_readers[i] = LogReader(log_path, sort_by_time=self.sort_by_time)

    return self._log_readers[i]

  def __iter__(self) -> Iterator[capnp._DynamicStructReader]:
    return self

  def _inc(self):
    # advance to the next event, skipping to the next non-None log when the
    # current one is exhausted; raises StopIteration past the last log
    lr = self._log_reader(self._current_log)
    if self._idx < len(lr._ents)-1:
      self._idx += 1
    else:
      self._idx = 0
      # the range extends one past the end so i == len(...) (no further logs)
      # is reachable and handled below
      self._current_log = next(i for i in range(self._current_log + 1, len(self._log_readers) + 1)
                               if i == len(self._log_readers) or self._log_paths[i] is not None)
      if self._current_log == len(self._log_readers):
        raise StopIteration

  def __next__(self):
    while 1:
      lr = self._log_reader(self._current_log)
      ret = lr._ents[self._idx]
      self._inc()
      return ret

  def tell(self):
    # returns seconds from start of log
    return (self._log_reader(self._current_log)._ts[self._idx] - self.start_time) * 1e-9

  def seek(self, ts):
    # seek to nearest minute
    # NOTE(review): assumes one segment per minute of drive, so segment index
    # == minute; returns False if that segment is missing
    minute = int(ts/60)
    if minute >= len(self._log_paths) or self._log_paths[minute] is None:
      return False

    self._current_log = minute

    # HACK: O(n) seek afterward
    self._idx = 0
    while self.tell() < ts:
      self._inc()
    return True

  def reset(self):
    # re-run __init__ to rewind to the first event
    self.__init__(self._log_paths, sort_by_time=self.sort_by_time)
|
||||
|
||||
|
||||
class LogReader:
  """Reads all capnp events of a single rlog/qlog from a local path, URL, or
  raw bytes."""
  def __init__(self, fn, canonicalize=True, only_union_types=False, sort_by_time=False, dat=None):
    # NOTE(review): canonicalize is accepted but unused in this class
    self.data_version = None
    self._only_union_types = only_union_types

    ext = None
    if not dat:
      _, ext = os.path.splitext(urllib.parse.urlparse(fn).path)
      if ext not in ('', '.bz2'):
        # old rlogs weren't bz2 compressed
        raise Exception(f"unknown extension {ext}")

      with FileReader(fn) as f:
        dat = f.read()

    # decompress when the extension or the bz2 magic bytes say so
    if ext == ".bz2" or dat.startswith(b'BZh9'):
      dat = bz2.decompress(dat)

    ents = capnp_log.Event.read_multiple_bytes(dat)

    # read eagerly so a truncated/corrupted log yields the events up to the
    # corruption point instead of raising mid-iteration
    _ents = []
    try:
      for e in ents:
        _ents.append(e)
    except capnp.KjException:
      warnings.warn("Corrupted events detected", RuntimeWarning, stacklevel=1)

    self._ents = list(sorted(_ents, key=lambda x: x.logMonoTime) if sort_by_time else _ents)
    self._ts = [x.logMonoTime for x in self._ents]

  @classmethod
  def from_bytes(cls, dat):
    # alternate constructor for a log already in memory
    return cls("", dat=dat)

  def __iter__(self) -> Iterator[capnp._DynamicStructReader]:
    for ent in self._ents:
      if self._only_union_types:
        # skip events whose union member is unknown to our schema
        try:
          ent.which()
          yield ent
        except capnp.lib.capnp.KjException:
          pass
      else:
        yield ent
|
||||
|
||||
def logreader_from_route_or_segment(r, sort_by_time=False):
  """Build a reader for either a whole route or a single segment.

  A plain route name (no segment number) yields a MultiLogIterator over every
  segment; a segment name yields a LogReader for just that segment.
  """
  sn = SegmentName(r, allow_route_name=True)
  paths = Route(sn.route_name.canonical_name).log_paths()
  if sn.segment_num >= 0:
    return LogReader(paths[sn.segment_num], sort_by_time=sort_by_time)
  return MultiLogIterator(paths, sort_by_time=sort_by_time)
|
||||
|
||||
|
||||
# CLI: dump every message of the given log to stdout
if __name__ == "__main__":
  import codecs
  # capnproto <= 0.8.0 throws errors converting byte data to string
  # below line catches those errors and replaces the bytes with \x__
  codecs.register_error("strict", codecs.backslashreplace_errors)
  log_path = sys.argv[1]
  lr = LogReader(log_path, sort_by_time=True)
  for msg in lr:
    print(msg)
|
||||
257
tools/lib/route.py
Normal file
257
tools/lib/route.py
Normal file
@@ -0,0 +1,257 @@
|
||||
import os
|
||||
import re
|
||||
from urllib.parse import urlparse
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from typing import Optional
|
||||
|
||||
from openpilot.tools.lib.auth_config import get_token
|
||||
from openpilot.tools.lib.api import CommaApi
|
||||
from openpilot.tools.lib.helpers import RE
|
||||
|
||||
# recognized file names for each artifact type of a segment
QLOG_FILENAMES = ['qlog', 'qlog.bz2']
QCAMERA_FILENAMES = ['qcamera.ts']
LOG_FILENAMES = ['rlog', 'rlog.bz2', 'raw_log.bz2']
CAMERA_FILENAMES = ['fcamera.hevc', 'video.hevc']
DCAMERA_FILENAMES = ['dcamera.hevc']
ECAMERA_FILENAMES = ['ecamera.hevc']
|
||||
|
||||
class Route:
  """All the segments of a single openpilot route.

  Segments are discovered from the comma API by default, or from a local data
  directory when data_dir is given.
  """
  def __init__(self, name, data_dir=None):
    self._name = RouteName(name)
    self.files = None  # flat list of file URLs, populated by _get_segments_remote
    if data_dir is not None:
      self._segments = self._get_segments_local(data_dir)
    else:
      self._segments = self._get_segments_remote()
    self.max_seg_number = self._segments[-1].name.segment_num

  @property
  def name(self):
    return self._name

  @property
  def segments(self):
    return self._segments

  def _paths(self, attr):
    # build a list indexed by segment number (None for missing segments) from
    # the given Segment path attribute
    path_by_seg_num = {s.name.segment_num: getattr(s, attr) for s in self._segments}
    return [path_by_seg_num.get(i) for i in range(self.max_seg_number + 1)]

  def log_paths(self):
    return self._paths('log_path')

  def qlog_paths(self):
    return self._paths('qlog_path')

  def camera_paths(self):
    return self._paths('camera_path')

  def dcamera_paths(self):
    return self._paths('dcamera_path')

  def ecamera_paths(self):
    return self._paths('ecamera_path')

  def qcamera_paths(self):
    return self._paths('qcamera_path')

  def _get_segments_remote(self):
    # query the comma API for all of the route's files, then group them into
    # Segment objects keyed by segment name
    api = CommaApi(get_token())
    route_files = api.get('v1/route/' + self.name.canonical_name + '/files')
    self.files = list(chain.from_iterable(route_files.values()))

    segments = {}
    for url in self.files:
      _, dongle_id, time_str, segment_num, fn = urlparse(url).path.rsplit('/', maxsplit=4)
      segment_name = f'{dongle_id}|{time_str}--{segment_num}'
      # rebuild the Segment, carrying over any paths found on earlier files
      prev = segments.get(segment_name)
      segments[segment_name] = Segment(
        segment_name,
        url if fn in LOG_FILENAMES else (prev.log_path if prev else None),
        url if fn in QLOG_FILENAMES else (prev.qlog_path if prev else None),
        url if fn in CAMERA_FILENAMES else (prev.camera_path if prev else None),
        url if fn in DCAMERA_FILENAMES else (prev.dcamera_path if prev else None),
        url if fn in ECAMERA_FILENAMES else (prev.ecamera_path if prev else None),
        url if fn in QCAMERA_FILENAMES else (prev.qcamera_path if prev else None),
      )

    return sorted(segments.values(), key=lambda seg: seg.name.segment_num)

  def _get_segments_local(self, data_dir):
    # discover segment files in data_dir; supports three layouts:
    #   comma connect explorer downloads: "<segment>--<file>"
    #   on-device segment dirs:           "<segment>/<file>"
    #   route dirs:                       "<route>/<seg_num>/<file>"
    files = os.listdir(data_dir)
    segment_files = defaultdict(list)

    for f in files:
      fullpath = os.path.join(data_dir, f)
      explorer_match = re.match(RE.EXPLORER_FILE, f)
      op_match = re.match(RE.OP_SEGMENT_DIR, f)

      if explorer_match:
        segment_name = explorer_match.group('segment_name')
        fn = explorer_match.group('file_name')
        # explorer file names use "_" in place of the "|" separator
        if segment_name.replace('_', '|').startswith(self.name.canonical_name):
          segment_files[segment_name].append((fullpath, fn))
      elif op_match and os.path.isdir(fullpath):
        segment_name = op_match.group('segment_name')
        if segment_name.startswith(self.name.canonical_name):
          for seg_f in os.listdir(fullpath):
            segment_files[segment_name].append((os.path.join(fullpath, seg_f), seg_f))
      elif f == self.name.canonical_name:
        for seg_num in os.listdir(fullpath):
          if not seg_num.isdigit():
            continue

          segment_name = f'{self.name.canonical_name}--{seg_num}'
          for seg_f in os.listdir(os.path.join(fullpath, seg_num)):
            segment_files[segment_name].append((os.path.join(fullpath, seg_num, seg_f), seg_f))

    def _first_path(seg_files, candidates):
      # first file in the segment whose name is one of the candidates, else None
      return next((path for path, filename in seg_files if filename in candidates), None)

    segments = []
    for segment, seg_files in segment_files.items():
      segments.append(Segment(
        segment,
        _first_path(seg_files, LOG_FILENAMES),
        _first_path(seg_files, QLOG_FILENAMES),
        _first_path(seg_files, CAMERA_FILENAMES),
        _first_path(seg_files, DCAMERA_FILENAMES),
        _first_path(seg_files, ECAMERA_FILENAMES),
        _first_path(seg_files, QCAMERA_FILENAMES),
      ))

    if len(segments) == 0:
      raise ValueError(f'Could not find segments for route {self.name.canonical_name} in data directory {data_dir}')
    return sorted(segments, key=lambda seg: seg.name.segment_num)
|
||||
|
||||
class Segment:
  """One segment of a route plus the paths/URLs of its files (any of which may
  be None when the file is missing)."""
  def __init__(self, name, log_path, qlog_path, camera_path, dcamera_path, ecamera_path, qcamera_path):
    self._name = SegmentName(name)
    self.log_path = log_path
    self.qlog_path = qlog_path
    self.camera_path = camera_path
    self.dcamera_path = dcamera_path
    self.ecamera_path = ecamera_path
    self.qcamera_path = qcamera_path

  @property
  def name(self):
    return self._name
|
||||
|
||||
class RouteName:
  """A route name of the form "<dongle_id>|<time_str>".

  "/" is accepted as the separator on input; the canonical form always uses
  "|".
  """
  def __init__(self, name_str: str):
    self._name_str = name_str
    # accept either "|" or "/" between dongle id and timestamp
    delim = next(c for c in self._name_str if c in ("|", "/"))
    self._dongle_id, self._time_str = self._name_str.split(delim)

    assert len(self._dongle_id) == 16, self._name_str
    assert len(self._time_str) == 20, self._name_str
    self._canonical_name = f"{self._dongle_id}|{self._time_str}"

  @property
  def canonical_name(self) -> str:
    return self._canonical_name

  @property
  def dongle_id(self) -> str:
    return self._dongle_id

  @property
  def time_str(self) -> str:
    return self._time_str

  def __str__(self) -> str:
    return self._canonical_name
|
||||
|
||||
class SegmentName:
  # TODO: add constructor that takes dongle_id, time_str, segment_num and then create instances
  # of this class instead of manually constructing a segment name (use canonical_name prop instead)
  """A segment name, optionally prefixed with a data dir path.

  Accepted forms include "dongle|time--num", "dongle/time/num", and
  "/data/.../dongle|time--num". With allow_route_name=True a bare route name
  is also accepted and segment_num is -1.
  """
  def __init__(self, name_str: str, allow_route_name=False):
    # if the name is prefixed with a filesystem path, split the data dir off
    # the front (last "/" before the "|" separator)
    data_dir_path_separator_index = name_str.rsplit("|", 1)[0].rfind("/")
    use_data_dir = (data_dir_path_separator_index != -1) and ("|" in name_str)
    self._name_str = name_str[data_dir_path_separator_index + 1:] if use_data_dir else name_str
    self._data_dir = name_str[:data_dir_path_separator_index] if use_data_dir else None

    # the segment number is separated by the second "--" or by "/"
    seg_num_delim = "--" if self._name_str.count("--") == 2 else "/"
    name_parts = self._name_str.rsplit(seg_num_delim, 1)
    if allow_route_name and len(name_parts) == 1:
      name_parts.append("-1") # no segment number
    self._route_name = RouteName(name_parts[0])
    self._num = int(name_parts[1])
    self._canonical_name = f"{self._route_name._dongle_id}|{self._route_name._time_str}--{self._num}"

  @property
  def canonical_name(self) -> str: return self._canonical_name

  @property
  def dongle_id(self) -> str: return self._route_name.dongle_id

  @property
  def time_str(self) -> str: return self._route_name.time_str

  @property
  def segment_num(self) -> int: return self._num

  @property
  def route_name(self) -> RouteName: return self._route_name

  @property
  def data_dir(self) -> Optional[str]: return self._data_dir

  def __str__(self) -> str: return self._canonical_name
|
||||
|
||||
|
||||
class SegmentRange:
  """A route plus an optional segment slice and read-mode selector, parsed
  with RE.SEGMENT_RANGE (e.g. "dongle|time/0:10/q")."""
  def __init__(self, segment_range: str):
    self.m = re.fullmatch(RE.SEGMENT_RANGE, segment_range)
    # invalid strings fail here rather than on first property access
    assert self.m, f"Segment range is not valid {segment_range}"

  @property
  def route_name(self):
    return self.m.group("route_name")

  @property
  def dongle_id(self):
    return self.m.group("dongle_id")

  @property
  def timestamp(self):
    return self.m.group("timestamp")

  @property
  def _slice(self):
    # raw slice text; interpreted by srreader's parse_slice
    return self.m.group("slice")

  @property
  def selector(self):
    # "q" or "r" (or None) choosing qlog vs rlog
    return self.m.group("selector")
|
||||
141
tools/lib/srreader.py
Normal file
141
tools/lib/srreader.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import enum
|
||||
import numpy as np
|
||||
import pathlib
|
||||
import re
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from openpilot.selfdrive.test.openpilotci import get_url
|
||||
from openpilot.tools.lib.helpers import RE
|
||||
from openpilot.tools.lib.logreader import LogReader
|
||||
from openpilot.tools.lib.route import Route, SegmentRange
|
||||
|
||||
class ReadMode(enum.StrEnum):
  """Which log type to read per segment; the value doubles as the selector
  suffix of a segment range (e.g. ".../0/q")."""
  RLOG = "r" # only read rlogs
  QLOG = "q" # only read qlogs
  #AUTO = "a" # default to rlogs, fallback to qlogs, not supported yet
|
||||
|
||||
|
||||
def create_slice_from_string(s: str):
  """Parse a segment-slice string into a slice, or an int for a bare index.

  Accepts python slice syntax ("2", ":5", "1:10:2", ...); raises
  AssertionError when RE.SLICE does not match.
  """
  m = re.fullmatch(RE.SLICE, s)
  assert m is not None, f"Invalid slice: {s}"

  start, end, step = (int(g) if g is not None else None for g in m.groups())

  # a bare number with no ":" is a single index, not a slice
  if start is not None and ":" not in s and end is None and step is None:
    return start
  return slice(start, end, step)
|
||||
|
||||
|
||||
def parse_slice(sr: SegmentRange):
  # resolve the segment-range slice against the route's actual segment count,
  # returning concrete segment indices (negative indices/slices work because
  # numpy indexing is used)
  route = Route(sr.route_name)
  segs = np.arange(route.max_seg_number+1)
  s = create_slice_from_string(sr._slice)
  return segs[s] if isinstance(s, slice) else [segs[s]]
|
||||
|
||||
def comma_api_source(sr: SegmentRange, mode=ReadMode.RLOG, sort_by_time=False):
  """Yield a LogReader per requested segment, using file URLs from the comma API."""
  segs = parse_slice(sr)
  route = Route(sr.route_name)

  log_paths = route.log_paths() if mode == ReadMode.RLOG else route.qlog_paths()

  # fail fast if any requested segment has no uploaded log
  invalid_segs = [seg for seg in segs if log_paths[seg] is None]

  assert not len(invalid_segs), f"Some of the requested segments are not available: {invalid_segs}"

  for seg in segs:
    yield LogReader(log_paths[seg], sort_by_time=sort_by_time)
|
||||
|
||||
def internal_source(sr: SegmentRange, mode=ReadMode.RLOG, sort_by_time=False):
  """Yield a LogReader per segment from comma-internal "cd:/" storage."""
  segs = parse_slice(sr)

  for seg in segs:
    yield LogReader(f"cd:/{sr.dongle_id}/{sr.timestamp}/{seg}/{'rlog' if mode == ReadMode.RLOG else 'qlog'}.bz2", sort_by_time=sort_by_time)
|
||||
|
||||
def openpilotci_source(sr: SegmentRange, mode=ReadMode.RLOG, sort_by_time=False):
  """Yield a LogReader per segment from the public openpilot CI bucket."""
  segs = parse_slice(sr)

  for seg in segs:
    yield LogReader(get_url(sr.route_name, seg, 'rlog' if mode == ReadMode.RLOG else 'qlog'), sort_by_time=sort_by_time)
|
||||
|
||||
def direct_source(file_or_url, sort_by_time):
  # a single local file or URL: yields exactly one LogReader
  yield LogReader(file_or_url, sort_by_time=sort_by_time)
|
||||
|
||||
def auto_source(*args, **kwargs):
  # Automatically determine viable source
  # Each source is probed by constructing its first LogReader; on success a
  # fresh generator is returned since the probe consumed one element.
  # The broad excepts are deliberate best-effort fallbacks (no access,
  # missing file, etc. all mean "try the next source").
  try:
    next(internal_source(*args, **kwargs))
    return internal_source(*args, **kwargs)
  except Exception:
    pass

  try:
    next(openpilotci_source(*args, **kwargs))
    return openpilotci_source(*args, **kwargs)
  except Exception:
    pass

  # last resort: the comma API (requires auth)
  return comma_api_source(*args, **kwargs)
|
||||
|
||||
def parse_useradmin(identifier):
  """Extract the route/segment string from a useradmin.comma.ai URL, else None."""
  if "useradmin.comma.ai" not in identifier:
    return None
  query = parse_qs(urlparse(identifier).query)
  return query["onebox"][0]
|
||||
|
||||
def parse_cabana(identifier):
  """Extract the route string from a cabana.comma.ai URL, else None."""
  if "cabana.comma.ai" not in identifier:
    return None
  query = parse_qs(urlparse(identifier).query)
  return query["route"][0]
|
||||
|
||||
def parse_cd(identifier):
  """Strip the comma-internal "cd:/" scheme and return the remainder, else None."""
  prefix = "cd:/"
  if prefix not in identifier:
    return None
  return identifier.replace(prefix, "")
|
||||
|
||||
def parse_direct(identifier):
  """Return the identifier unchanged when it is an http(s) URL or an existing
  local path, else None."""
  is_url = "https://" in identifier or "http://" in identifier
  if is_url or pathlib.Path(identifier).exists():
    return identifier
  return None
|
||||
|
||||
def parse_indirect(identifier):
  """Resolve identifiers that point at a route indirectly (useradmin/cabana
  URLs, cd:/ paths).

  Returns (parsed_identifier, source_or_None, was_indirect).
  """
  for parse, source in ((parse_useradmin, comma_api_source),
                        (parse_cabana, comma_api_source),
                        (parse_cd, internal_source)):
    parsed = parse(identifier)
    if parsed is not None:
      return parsed, source, True

  return identifier, None, False
|
||||
|
||||
class SegmentRangeReader:
  """Iterates over the messages of any route/segment identifier: segment
  ranges, useradmin/cabana URLs, cd:/ paths, local files, or plain URLs."""
  def _logreaders_from_identifier(self, identifier):
    parsed, source, is_indirect = parse_indirect(identifier)

    if not is_indirect:
      # a direct file path or URL needs no segment-range parsing
      direct_parsed = parse_direct(identifier)
      if direct_parsed is not None:
        return direct_source(identifier, sort_by_time=self.sort_by_time)

    sr = SegmentRange(parsed)
    # an explicit selector in the name ("/q" or "/r") overrides the default mode
    mode = self.default_mode if sr.selector is None else ReadMode(sr.selector)
    source = self.default_source if source is None else source

    return source(sr, mode, sort_by_time=self.sort_by_time)

  def __init__(self, identifier: str, default_mode=ReadMode.RLOG, default_source=auto_source, sort_by_time=False):
    self.default_mode = default_mode
    self.default_source = default_source
    self.sort_by_time = sort_by_time

    self.lrs = self._logreaders_from_identifier(identifier)

  def __iter__(self):
    # flatten messages across all segment LogReaders
    for lr in self.lrs:
      for m in lr:
        yield m
|
||||
0
tools/lib/tests/__init__.py
Normal file
0
tools/lib/tests/__init__.py
Normal file
129
tools/lib/tests/test_caching.py
Executable file
129
tools/lib/tests/test_caching.py
Executable file
@@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python3
|
||||
from functools import wraps
|
||||
import http.server
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from openpilot.tools.lib.url_file import URLFile
|
||||
|
||||
|
||||
class CachingTestRequestHandler(http.server.BaseHTTPRequestHandler):
  # toggled by tests to simulate a file appearing/disappearing on the server
  FILE_EXISTS = True

  def do_GET(self):
    if self.FILE_EXISTS:
      # NOTE(review): send_response's second argument is the reason phrase,
      # not a body -- no body bytes are actually written here; the tests only
      # exercise status/Content-Length, so confirm before relying on body data
      self.send_response(200, b'1234')
    else:
      self.send_response(404)
    self.end_headers()

  def do_HEAD(self):
    if self.FILE_EXISTS:
      self.send_response(200)
      self.send_header("Content-Length", "4")
    else:
      self.send_response(404)
    self.end_headers()
|
||||
|
||||
|
||||
class CachingTestServer(threading.Thread):
  """Minimal HTTP server on an ephemeral port, run in a background thread."""
  def run(self):
    # port 0 lets the OS pick a free port; it is published via self.port
    self.server = http.server.HTTPServer(("127.0.0.1", 0), CachingTestRequestHandler)
    self.port = self.server.server_port
    self.server.serve_forever()

  def stop(self):
    self.server.server_close()
    self.server.shutdown()
|
||||
|
||||
def with_caching_server(func):
  """Decorator: run the test with a live CachingTestServer, passing its port
  as a keyword argument and shutting the server down afterwards."""
  @wraps(func)
  def wrapper(*args, **kwargs):
    server = CachingTestServer()
    server.start()
    time.sleep(0.25) # wait for server to get its port
    try:
      func(*args, **kwargs, port=server.port)
    finally:
      server.stop()
  return wrapper
|
||||
|
||||
|
||||
class TestFileDownload(unittest.TestCase):
  """Checks that URLFile's cached reads match uncached reads byte-for-byte."""

  def compare_loads(self, url, start=0, length=None):
    """Compares range between cached and non cached version"""
    file_cached = URLFile(url, cache=True)
    file_downloaded = URLFile(url, cache=False)

    file_cached.seek(start)
    file_downloaded.seek(start)

    self.assertEqual(file_cached.get_length(), file_downloaded.get_length())
    # sanity check: the requested range must fit within the file
    self.assertLessEqual(length + start if length is not None else 0, file_downloaded.get_length())

    response_cached = file_cached.read(ll=length)
    response_downloaded = file_downloaded.read(ll=length)

    self.assertEqual(response_cached, response_downloaded)

    # Now test with cache in place
    file_cached = URLFile(url, cache=True)
    file_cached.seek(start)
    response_cached = file_cached.read(ll=length)

    self.assertEqual(file_cached.get_length(), file_downloaded.get_length())
    self.assertEqual(response_cached, response_downloaded)

  def test_small_file(self):
    # Make sure we don't force cache
    os.environ["FILEREADER_CACHE"] = "0"
    small_file_url = "https://raw.githubusercontent.com/commaai/openpilot/master/docs/SAFETY.md"
    # If you want large file to be larger than a chunk
    # large_file_url = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/fcamera.hevc"

    # Load full small file
    self.compare_loads(small_file_url)

    file_small = URLFile(small_file_url)
    length = file_small.get_length()

    # ranges at the tail and in the middle of the file
    self.compare_loads(small_file_url, length - 100, 100)
    self.compare_loads(small_file_url, 50, 100)

    # Load small file 100 bytes at a time
    for i in range(length // 100):
      self.compare_loads(small_file_url, 100 * i, 100)

  def test_large_file(self):
    large_file_url = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/qlog.bz2"
    # Load the end 100 bytes of both files
    file_large = URLFile(large_file_url)
    length = file_large.get_length()

    self.compare_loads(large_file_url, length - 100, 100)
    self.compare_loads(large_file_url)

  @parameterized.expand([(True, ), (False, )])
  @with_caching_server
  def test_recover_from_missing_file(self, cache_enabled, port):
    os.environ["FILEREADER_CACHE"] = "1" if cache_enabled else "0"

    file_url = f"http://localhost:{port}/test.png"

    # a 404 should report length -1 rather than raising
    CachingTestRequestHandler.FILE_EXISTS = False
    length = URLFile(file_url).get_length()
    self.assertEqual(length, -1)

    # once the file appears its length must be re-fetched, not served stale
    CachingTestRequestHandler.FILE_EXISTS = True
    length = URLFile(file_url).get_length()
    self.assertEqual(length, 4)
|
||||
|
||||
|
||||
|
||||
# run this file's tests directly
if __name__ == "__main__":
  unittest.main()
|
||||
67
tools/lib/tests/test_readers.py
Executable file
67
tools/lib/tests/test_readers.py
Executable file
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import requests
|
||||
import tempfile
|
||||
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
from openpilot.tools.lib.framereader import FrameReader
|
||||
from openpilot.tools.lib.logreader import LogReader
|
||||
|
||||
|
||||
class TestReaders(unittest.TestCase):
  """End-to-end checks of LogReader and FrameReader against a public
  comma2k19 segment (skipped by default: they download real data)."""
  @unittest.skip("skip for bandwidth reasons")
  def test_logreader(self):
    def _check_data(lr):
      # count events per union type and pin two known totals for this segment
      hist = defaultdict(int)
      for l in lr:
        hist[l.which()] += 1

      self.assertEqual(hist['carControl'], 6000)
      self.assertEqual(hist['logMessage'], 6857)

    # read once from a downloaded local file...
    with tempfile.NamedTemporaryFile(suffix=".bz2") as fp:
      r = requests.get("https://github.com/commaai/comma2k19/blob/master/Example_1/b0c9d2329ad1606b%7C2018-08-02--08-34-47/40/raw_log.bz2?raw=true", timeout=10)
      fp.write(r.content)
      fp.flush()

      lr_file = LogReader(fp.name)
      _check_data(lr_file)

    # ...and once streamed straight from the URL; results must match
    lr_url = LogReader("https://github.com/commaai/comma2k19/blob/master/Example_1/b0c9d2329ad1606b%7C2018-08-02--08-34-47/40/raw_log.bz2?raw=true")
    _check_data(lr_url)

  @unittest.skip("skip for bandwidth reasons")
  def test_framereader(self):
    def _check_data(f):
      # pin known frame count and resolution for this video
      self.assertEqual(f.frame_count, 1200)
      self.assertEqual(f.w, 1164)
      self.assertEqual(f.h, 874)

      frame_first_30 = f.get(0, 30)
      self.assertEqual(len(frame_first_30), 30)

      print(frame_first_30[15])

      print("frame_0")
      frame_0 = f.get(0, 1)
      frame_15 = f.get(15, 1)

      print(frame_15[0])

      # single-frame reads must agree with the batched read
      assert np.all(frame_first_30[0] == frame_0[0])
      assert np.all(frame_first_30[15] == frame_15[0])

    with tempfile.NamedTemporaryFile(suffix=".hevc") as fp:
      r = requests.get("https://github.com/commaai/comma2k19/blob/master/Example_1/b0c9d2329ad1606b%7C2018-08-02--08-34-47/40/video.hevc?raw=true", timeout=10)
      fp.write(r.content)
      fp.flush()

      fr_file = FrameReader(fp.name)
      _check_data(fr_file)

    fr_url = FrameReader("https://github.com/commaai/comma2k19/blob/master/Example_1/b0c9d2329ad1606b%7C2018-08-02--08-34-47/40/video.hevc?raw=true")
    _check_data(fr_url)
|
||||
|
||||
# run this file's tests directly
if __name__ == "__main__":
  unittest.main()
|
||||
32
tools/lib/tests/test_route_library.py
Executable file
32
tools/lib/tests/test_route_library.py
Executable file
@@ -0,0 +1,32 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
from collections import namedtuple
|
||||
|
||||
from openpilot.tools.lib.route import SegmentName
|
||||
|
||||
class TestRouteLibrary(unittest.TestCase):
  """Checks SegmentName parsing across all accepted name formats."""
  def test_segment_name_formats(self):
    # input string -> expected canonical route, segment number, and data dir
    Case = namedtuple('Case', ['input', 'expected_route', 'expected_segment_num', 'expected_data_dir'])

    cases = [ Case("a2a0ccea32023010|2023-07-27--13-01-19", "a2a0ccea32023010|2023-07-27--13-01-19", -1, None),
              Case("a2a0ccea32023010/2023-07-27--13-01-19--1", "a2a0ccea32023010|2023-07-27--13-01-19", 1, None),
              Case("a2a0ccea32023010|2023-07-27--13-01-19/2", "a2a0ccea32023010|2023-07-27--13-01-19", 2, None),
              Case("a2a0ccea32023010/2023-07-27--13-01-19/3", "a2a0ccea32023010|2023-07-27--13-01-19", 3, None),
              Case("/data/media/0/realdata/a2a0ccea32023010|2023-07-27--13-01-19", "a2a0ccea32023010|2023-07-27--13-01-19", -1, "/data/media/0/realdata"),
              Case("/data/media/0/realdata/a2a0ccea32023010|2023-07-27--13-01-19--1", "a2a0ccea32023010|2023-07-27--13-01-19", 1, "/data/media/0/realdata"),
              Case("/data/media/0/realdata/a2a0ccea32023010|2023-07-27--13-01-19/2", "a2a0ccea32023010|2023-07-27--13-01-19", 2, "/data/media/0/realdata") ]

    def _validate(case):
      route_or_segment_name = case.input

      s = SegmentName(route_or_segment_name, allow_route_name=True)

      self.assertEqual(str(s.route_name), case.expected_route)
      self.assertEqual(s.segment_num, case.expected_segment_num)
      self.assertEqual(s.data_dir, case.expected_data_dir)

    for case in cases:
      _validate(case)
|
||||
|
||||
# run this file's tests directly
if __name__ == "__main__":
  unittest.main()
|
||||
88
tools/lib/tests/test_srreader.py
Normal file
88
tools/lib/tests/test_srreader.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import unittest
|
||||
from parameterized import parameterized
|
||||
import requests
|
||||
|
||||
from openpilot.tools.lib.route import SegmentRange
|
||||
from openpilot.tools.lib.srreader import ReadMode, SegmentRangeReader, parse_slice, parse_indirect
|
||||
|
||||
NUM_SEGS = 17 # number of segments in the test route
ALL_SEGS = list(np.arange(NUM_SEGS))
# public test route used by the parsing tests below
TEST_ROUTE = "344c5c15b34f2d8a/2024-01-03--09-37-12"
# a single qlog used to exercise direct file/URL reading
QLOG_FILE = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/qlog.bz2"
|
||||
|
||||
class TestSegmentRangeReader(unittest.TestCase):
  """Checks segment-range parsing, direct file reading, and read modes."""
  # every identifier form must resolve to the expected list of segment indices
  @parameterized.expand([
    (f"{TEST_ROUTE}", ALL_SEGS),
    (f"{TEST_ROUTE.replace('/', '|')}", ALL_SEGS),
    (f"{TEST_ROUTE}--0", [0]),
    (f"{TEST_ROUTE}--5", [5]),
    (f"{TEST_ROUTE}/0", [0]),
    (f"{TEST_ROUTE}/5", [5]),
    (f"{TEST_ROUTE}/0:10", ALL_SEGS[0:10]),
    (f"{TEST_ROUTE}/0:0", []),
    (f"{TEST_ROUTE}/4:6", ALL_SEGS[4:6]),
    (f"{TEST_ROUTE}/0:-1", ALL_SEGS[0:-1]),
    (f"{TEST_ROUTE}/:5", ALL_SEGS[:5]),
    (f"{TEST_ROUTE}/2:", ALL_SEGS[2:]),
    (f"{TEST_ROUTE}/2:-1", ALL_SEGS[2:-1]),
    (f"{TEST_ROUTE}/-1", [ALL_SEGS[-1]]),
    (f"{TEST_ROUTE}/-2", [ALL_SEGS[-2]]),
    (f"{TEST_ROUTE}/-2:-1", ALL_SEGS[-2:-1]),
    (f"{TEST_ROUTE}/-4:-2", ALL_SEGS[-4:-2]),
    (f"{TEST_ROUTE}/:10:2", ALL_SEGS[:10:2]),
    (f"{TEST_ROUTE}/5::2", ALL_SEGS[5::2]),
    (f"https://useradmin.comma.ai/?onebox={TEST_ROUTE}", ALL_SEGS),
    (f"https://useradmin.comma.ai/?onebox={TEST_ROUTE.replace('/', '|')}", ALL_SEGS),
    (f"https://useradmin.comma.ai/?onebox={TEST_ROUTE.replace('/', '%7C')}", ALL_SEGS),
    (f"https://cabana.comma.ai/?route={TEST_ROUTE}", ALL_SEGS),
    (f"cd:/{TEST_ROUTE}", ALL_SEGS),
  ])
  def test_indirect_parsing(self, identifier, expected):
    parsed, _, _ = parse_indirect(identifier)
    sr = SegmentRange(parsed)
    segs = parse_slice(sr)
    self.assertListEqual(list(segs), expected)

  def test_direct_parsing(self):
    # read the same qlog once via URL and once via a downloaded local file
    qlog = tempfile.NamedTemporaryFile(mode='wb', delete=False)

    with requests.get(QLOG_FILE, stream=True) as r:
      with qlog as f:
        shutil.copyfileobj(r.raw, f)

    for f in [QLOG_FILE, qlog.name]:
      l = len(list(SegmentRangeReader(f)))
      self.assertGreater(l, 100)

  # malformed segment ranges must be rejected with an AssertionError
  @parameterized.expand([
    (f"{TEST_ROUTE}///",),
    (f"{TEST_ROUTE}---",),
    (f"{TEST_ROUTE}/-4:--2",),
    (f"{TEST_ROUTE}/-a",),
    (f"{TEST_ROUTE}/j",),
    (f"{TEST_ROUTE}/0:1:2:3",),
    (f"{TEST_ROUTE}/:::3",),
  ])
  def test_bad_ranges(self, segment_range):
    with self.assertRaises(AssertionError):
      sr = SegmentRange(segment_range)
      parse_slice(sr)

  def test_modes(self):
    # a qlog is decimated, so it must contain far fewer messages than the rlog
    qlog_len = len(list(SegmentRangeReader(f"{TEST_ROUTE}/0", ReadMode.QLOG)))
    rlog_len = len(list(SegmentRangeReader(f"{TEST_ROUTE}/0", ReadMode.RLOG)))

    self.assertLess(qlog_len * 6, rlog_len)

  def test_modes_from_name(self):
    # the "/q" and "/r" selector suffixes must behave like explicit ReadModes
    qlog_len = len(list(SegmentRangeReader(f"{TEST_ROUTE}/0/q")))
    rlog_len = len(list(SegmentRangeReader(f"{TEST_ROUTE}/0/r")))

    self.assertLess(qlog_len * 6, rlog_len)
|
||||
|
||||
|
||||
# Allow running this test module directly: `python test_srreader.py`.
if __name__ == "__main__":
  unittest.main()
|
||||
156
tools/lib/url_file.py
Normal file
156
tools/lib/url_file.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
from hashlib import sha256
|
||||
from urllib3 import PoolManager
|
||||
from urllib3.util import Timeout
|
||||
from tenacity import retry, wait_random_exponential, stop_after_attempt
|
||||
|
||||
from openpilot.common.file_helpers import atomic_write_in_dir
|
||||
from openpilot.system.hardware.hw import Paths
|
||||
# Cache chunk size
K = 1000  # decimal kilobyte
CHUNK_SIZE = 1000 * K  # 1 MB per cached chunk; cache filenames are keyed by chunk index
|
||||
|
||||
|
||||
def hash_256(link):
  """Return the SHA-256 hex digest of *link* with any query string stripped.

  Only the part before the first '?' is hashed, so URLs that differ solely
  in their query parameters (e.g. signed URL tokens) share one cache key.
  """
  base_url, _, _ = link.partition("?")
  return sha256(base_url.encode('utf-8')).hexdigest()
|
||||
|
||||
|
||||
class URLFileException(Exception):
  """Raised when a remote read returns an unexpected HTTP status."""
  pass
|
||||
|
||||
|
||||
class URLFile:
  """Seekable, read-only file-like object backed by a remote URL.

  Requests go through one urllib3 PoolManager per thread. When caching is
  enabled (FILEREADER_CACHE=1 in the environment, or cache=True), reads are
  assembled from CHUNK_SIZE-aligned chunks stored on disk under
  Paths.download_cache_root(); otherwise every read is a direct download.
  """
  # One shared PoolManager per thread, reused across all URLFile instances.
  _tlocal = threading.local()

  def __init__(self, url, debug=False, cache=None):
    """
    Args:
      url: remote file URL.
      debug: when True, print timings for requests slower than 0.1 s.
      cache: overrides the FILEREADER_CACHE environment default when not None.
    """
    self._url = url
    self._pos = 0  # current read offset into the remote file
    self._length = None  # lazily-fetched content length
    self._local_file = None  # NOTE(review): never assigned in this file; appears vestigial — confirm
    self._debug = debug
    # True by default, false if FILEREADER_CACHE is defined, but can be overwritten by the cache input
    self._force_download = not int(os.environ.get("FILEREADER_CACHE", "0"))
    if cache is not None:
      self._force_download = not cache

    if not self._force_download:
      os.makedirs(Paths.download_cache_root(), exist_ok=True)

    try:
      self._http_client = URLFile._tlocal.http_client
    except AttributeError:
      self._http_client = URLFile._tlocal.http_client = PoolManager()

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    if self._local_file is not None:
      # BUGFIX: close before unlinking (original removed the still-open file
      # first, which fails on platforms that forbid deleting open files).
      local_name = self._local_file.name
      self._local_file.close()
      os.remove(local_name)
      self._local_file = None

  @retry(wait=wait_random_exponential(multiplier=1, max=5), stop=stop_after_attempt(3), reraise=True)
  def get_length_online(self):
    """Fetch the remote content length via a HEAD request; return -1 on non-2xx."""
    timeout = Timeout(connect=50.0, read=500.0)
    response = self._http_client.request('HEAD', self._url, timeout=timeout, preload_content=False)
    if not (200 <= response.status <= 299):
      return -1
    length = response.headers.get('content-length', 0)
    return int(length)

  def get_length(self):
    """Return the remote file length, consulting/updating the on-disk length cache."""
    if self._length is not None:
      return self._length

    file_length_path = os.path.join(Paths.download_cache_root(), hash_256(self._url) + "_length")
    if not self._force_download and os.path.exists(file_length_path):
      with open(file_length_path) as file_length:
        content = file_length.read()
        self._length = int(content)
        return self._length

    self._length = self.get_length_online()
    # Only persist a real length; -1 means the HEAD request failed.
    if not self._force_download and self._length != -1:
      with atomic_write_in_dir(file_length_path, mode="w") as file_length:
        file_length.write(str(self._length))
    return self._length

  def read(self, ll=None):
    """Read *ll* bytes from the current position (or the rest of the file).

    With caching disabled this is a single (optionally ranged) download;
    otherwise the result is stitched together from CHUNK_SIZE-aligned chunks
    that are downloaded once and then served from disk.
    """
    if self._force_download:
      return self.read_aux(ll=ll)

    file_begin = self._pos
    file_end = self._pos + ll if ll is not None else self.get_length()
    assert file_end != -1, f"Remote file is empty or doesn't exist: {self._url}"
    # We have to align with chunks we store. Position is the beginning of the latest chunk that starts before or at our file
    position = (file_begin // CHUNK_SIZE) * CHUNK_SIZE
    response = b""
    while True:
      self._pos = position
      # BUGFIX: integer division — "/" produced float chunk ids ("…_0.0") in
      # cache filenames; entries written under the old naming are orphaned.
      chunk_number = self._pos // CHUNK_SIZE
      file_name = hash_256(self._url) + "_" + str(chunk_number)
      full_path = os.path.join(Paths.download_cache_root(), file_name)
      data = None
      # If we don't have a file, download it
      if not os.path.exists(full_path):
        data = self.read_aux(ll=CHUNK_SIZE)
        with atomic_write_in_dir(full_path, mode="wb") as new_cached_file:
          new_cached_file.write(data)
      else:
        with open(full_path, "rb") as cached_file:
          data = cached_file.read()

      # Trim the chunk to the requested [file_begin, file_end) window.
      response += data[max(0, file_begin - position): min(CHUNK_SIZE, file_end - position)]

      position += CHUNK_SIZE
      if position >= file_end:
        self._pos = file_end
        return response

  @retry(wait=wait_random_exponential(multiplier=1, max=5), stop=stop_after_attempt(3), reraise=True)
  def read_aux(self, ll=None):
    """Download up to *ll* bytes from the current position over HTTP.

    A Range header is used unless the read is "whole file from byte 0".
    Raises URLFileException on unexpected status codes.
    """
    download_range = False
    headers = {'Connection': 'keep-alive'}
    if self._pos != 0 or ll is not None:
      if ll is None:
        end = self.get_length() - 1
      else:
        end = min(self._pos + ll, self.get_length()) - 1
      if self._pos >= end:
        return b""
      headers['Range'] = f"bytes={self._pos}-{end}"
      download_range = True

    if self._debug:
      t1 = time.time()

    timeout = Timeout(connect=50.0, read=500.0)
    response = self._http_client.request('GET', self._url, timeout=timeout, preload_content=False, headers=headers)
    ret = response.data

    if self._debug:
      t2 = time.time()
      if t2 - t1 > 0.1:
        print(f"get {self._url} {headers!r} {t2 - t1:.3f} slow")

    response_code = response.status
    if response_code == 416:  # Requested Range Not Satisfiable
      raise URLFileException(f"Error, range out of bounds {response_code} {headers} ({self._url}): {repr(ret)[:500]}")
    if download_range and response_code != 206:  # Partial Content
      raise URLFileException(f"Error, requested range but got unexpected response {response_code} {headers} ({self._url}): {repr(ret)[:500]}")
    if (not download_range) and response_code != 200:  # OK
      raise URLFileException(f"Error {response_code} {headers} ({self._url}): {repr(ret)[:500]}")

    self._pos += len(ret)
    return ret

  def seek(self, pos):
    """Move the read position to *pos* (not validated against file length)."""
    self._pos = pos

  @property
  def name(self):
    """The remote URL, mirroring the file-object .name attribute."""
    return self._url
|
||||
312
tools/lib/vidindex.py
Executable file
312
tools/lib/vidindex.py
Executable file
@@ -0,0 +1,312 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import os
|
||||
import struct
|
||||
from enum import IntEnum
|
||||
from typing import Tuple
|
||||
|
||||
from openpilot.tools.lib.filereader import FileReader
|
||||
|
||||
# Set DEBUG=1 in the environment to print per-NAL-unit parsing details.
DEBUG = int(os.getenv("DEBUG", "0"))

# compare to ffmpeg parsing
# ffmpeg -i <input.hevc> -c copy -bsf:v trace_headers -f null - 2>&1 | grep -B4 -A32 '] 0 '

# H.265 specification
# https://www.itu.int/rec/dologin_pub.asp?lang=e&id=T-REC-H.265-201802-S!!PDF-E&type=items

# Annex B byte-stream framing: each NAL unit is preceded by this start code.
NAL_UNIT_START_CODE = b"\x00\x00\x01"
NAL_UNIT_START_CODE_SIZE = len(NAL_UNIT_START_CODE)
NAL_UNIT_HEADER_SIZE = 2  # nal_unit_header( ) is two bytes (spec 7.3.1.2)
||||
|
||||
class HevcNalUnitType(IntEnum):
  """nal_unit_type codes (Table 7-1 of the H.265 specification)."""
  TRAIL_N = 0         # RBSP structure: slice_segment_layer_rbsp( )
  TRAIL_R = 1         # RBSP structure: slice_segment_layer_rbsp( )
  TSA_N = 2           # RBSP structure: slice_segment_layer_rbsp( )
  TSA_R = 3           # RBSP structure: slice_segment_layer_rbsp( )
  STSA_N = 4          # RBSP structure: slice_segment_layer_rbsp( )
  STSA_R = 5          # RBSP structure: slice_segment_layer_rbsp( )
  RADL_N = 6          # RBSP structure: slice_segment_layer_rbsp( )
  RADL_R = 7          # RBSP structure: slice_segment_layer_rbsp( )
  RASL_N = 8          # RBSP structure: slice_segment_layer_rbsp( )
  RASL_R = 9          # RBSP structure: slice_segment_layer_rbsp( )
  RSV_VCL_N10 = 10
  RSV_VCL_R11 = 11
  RSV_VCL_N12 = 12
  RSV_VCL_R13 = 13
  RSV_VCL_N14 = 14
  RSV_VCL_R15 = 15
  BLA_W_LP = 16       # RBSP structure: slice_segment_layer_rbsp( )
  BLA_W_RADL = 17     # RBSP structure: slice_segment_layer_rbsp( )
  BLA_N_LP = 18       # RBSP structure: slice_segment_layer_rbsp( )
  IDR_W_RADL = 19     # RBSP structure: slice_segment_layer_rbsp( )
  IDR_N_LP = 20       # RBSP structure: slice_segment_layer_rbsp( )
  CRA_NUT = 21        # RBSP structure: slice_segment_layer_rbsp( )
  RSV_IRAP_VCL22 = 22
  RSV_IRAP_VCL23 = 23
  RSV_VCL24 = 24
  RSV_VCL25 = 25
  RSV_VCL26 = 26
  RSV_VCL27 = 27
  RSV_VCL28 = 28
  RSV_VCL29 = 29
  RSV_VCL30 = 30
  RSV_VCL31 = 31
  VPS_NUT = 32        # RBSP structure: video_parameter_set_rbsp( )
  SPS_NUT = 33        # RBSP structure: seq_parameter_set_rbsp( )
  PPS_NUT = 34        # RBSP structure: pic_parameter_set_rbsp( )
  AUD_NUT = 35
  EOS_NUT = 36
  EOB_NUT = 37
  FD_NUT = 38
  PREFIX_SEI_NUT = 39
  SUFFIX_SEI_NUT = 40
  RSV_NVCL41 = 41
  RSV_NVCL42 = 42
  RSV_NVCL43 = 43
  RSV_NVCL44 = 44
  RSV_NVCL45 = 45
  RSV_NVCL46 = 46
  RSV_NVCL47 = 47
  UNSPEC48 = 48
  UNSPEC49 = 49
  UNSPEC50 = 50
  UNSPEC51 = 51
  UNSPEC52 = 52
  UNSPEC53 = 53
  UNSPEC54 = 54
  UNSPEC55 = 55
  UNSPEC56 = 56
  UNSPEC57 = 57
  UNSPEC58 = 58
  UNSPEC59 = 59
  UNSPEC60 = 60
  UNSPEC61 = 61
  UNSPEC62 = 62
  UNSPEC63 = 63
|
||||
|
||||
# B.2.2 Byte stream NAL unit semantics
# - The nal_unit_type within the nal_unit( ) syntax structure is equal to VPS_NUT, SPS_NUT or PPS_NUT.
# - The byte stream NAL unit syntax structure contains the first NAL unit of an access unit in decoding
#   order, as specified in clause 7.4.2.4.4.
# These units carry decoder configuration and are concatenated into the
# "prefix" that hevc_index() returns.
HEVC_PARAMETER_SET_NAL_UNITS = (
  HevcNalUnitType.VPS_NUT,
  HevcNalUnitType.SPS_NUT,
  HevcNalUnitType.PPS_NUT,
)

# 3.29 coded slice segment NAL unit: A NAL unit that has nal_unit_type in the range of TRAIL_N to RASL_R,
# inclusive, or in the range of BLA_W_LP to RSV_IRAP_VCL23, inclusive, which indicates that the NAL unit
# contains a coded slice segment
HEVC_CODED_SLICE_SEGMENT_NAL_UNITS = (
  HevcNalUnitType.TRAIL_N,
  HevcNalUnitType.TRAIL_R,
  HevcNalUnitType.TSA_N,
  HevcNalUnitType.TSA_R,
  HevcNalUnitType.STSA_N,
  HevcNalUnitType.STSA_R,
  HevcNalUnitType.RADL_N,
  HevcNalUnitType.RADL_R,
  HevcNalUnitType.RASL_N,
  HevcNalUnitType.RASL_R,
  HevcNalUnitType.BLA_W_LP,
  HevcNalUnitType.BLA_W_RADL,
  HevcNalUnitType.BLA_N_LP,
  HevcNalUnitType.IDR_W_RADL,
  HevcNalUnitType.IDR_N_LP,
  HevcNalUnitType.CRA_NUT,
)
|
||||
|
||||
class VideoFileInvalid(Exception):
  """Raised when the bitstream violates the H.265 Annex-B structure."""
  pass
|
||||
|
||||
def get_ue(dat: bytes, start_idx: int, skip_bits: int) -> Tuple[int, int]:
  """Decode one unsigned Exp-Golomb (ue(v)) value from an MSB-first bitstream.

  Decoding starts at byte *start_idx* after discarding *skip_bits* bits.
  Returns (value, bit_count), where bit_count excludes the skipped bits.
  Raises VideoFileInvalid if the data ends before the code completes.
  """
  total_bits = len(dat) * 8
  pos = start_idx * 8 + skip_bits  # absolute bit position, MSB-first

  def bit_at(p: int) -> int:
    return (dat[p >> 3] >> (7 - (p & 7))) & 1

  # An Exp-Golomb code is: N leading zeros, a '1', then an N-bit suffix.
  leading = 0
  while pos < total_bits and bit_at(pos) == 0:
    leading += 1
    pos += 1

  # Need the terminating '1' plus `leading` suffix bits within the data.
  if pos >= total_bits or pos + 1 + leading > total_bits:
    raise VideoFileInvalid("invalid exponential-golomb code")
  pos += 1  # consume the terminating '1'

  suffix = 0
  for _ in range(leading):
    suffix = (suffix << 1) | bit_at(pos)
    pos += 1

  # value = 2^N - 1 + suffix; size = N zeros + 1 marker + N suffix bits.
  return (1 << leading) - 1 + suffix, 2 * leading + 1
|
||||
|
||||
def require_nal_unit_start(dat: bytes, nal_unit_start: int) -> None:
  """Validate that a NAL unit start code (00 00 01) begins at *nal_unit_start*.

  Raises ValueError for a non-positive index and VideoFileInvalid when the
  start-code bytes are not present at that offset.
  """
  if nal_unit_start < 1:
    raise ValueError("start index must be greater than zero")

  window = dat[nal_unit_start:nal_unit_start + NAL_UNIT_START_CODE_SIZE]
  if window != NAL_UNIT_START_CODE:
    raise VideoFileInvalid("data must begin with start code")
|
||||
|
||||
def get_hevc_nal_unit_length(dat: bytes, nal_unit_start: int) -> int:
  """Return the byte length of the NAL unit starting at *nal_unit_start*.

  The unit extends up to the next start code, or to the end of *dat* when
  no further start code exists.
  """
  next_start = dat.find(NAL_UNIT_START_CODE, nal_unit_start + NAL_UNIT_START_CODE_SIZE)
  end = len(dat) if next_start == -1 else next_start

  # length of NAL unit is byte count up to next NAL unit start index
  nal_unit_len = end - nal_unit_start
  if DEBUG:
    print(" nal_unit_len:", nal_unit_len)
  return nal_unit_len
|
||||
|
||||
def get_hevc_nal_unit_type(dat: bytes, nal_unit_start: int) -> HevcNalUnitType:
  """Return the nal_unit_type from the NAL unit header at *nal_unit_start*.

  7.3.1.2 NAL unit header syntax
  nal_unit_header( ) { // descriptor
    forbidden_zero_bit f(1)
    nal_unit_type u(6)
    nuh_layer_id u(6)
    nuh_temporal_id_plus1 u(3)
  }

  Raises VideoFileInvalid when fewer than two header bytes are available.
  """
  header_start = nal_unit_start + NAL_UNIT_START_CODE_SIZE
  nal_unit_header = dat[header_start:header_start + NAL_UNIT_HEADER_SIZE]
  if len(nal_unit_header) != 2:
    # BUGFIX: error message typo — was "data to short ..."
    raise VideoFileInvalid("data too short to contain nal unit header")
  # nal_unit_type is the 6 bits following forbidden_zero_bit in the first byte.
  nal_unit_type = HevcNalUnitType((nal_unit_header[0] >> 1) & 0x3F)
  if DEBUG:
    print(" nal_unit_type:", nal_unit_type.name, f"({nal_unit_type.value})")
  return nal_unit_type
|
||||
|
||||
def get_hevc_slice_type(dat: bytes, nal_unit_start: int, nal_unit_type: HevcNalUnitType) -> Tuple[int, bool]:
  """Parse the slice segment header at *nal_unit_start* and return
  (slice_type, is_first_slice). slice_type is -1 for dependent/non-first
  slices, whose header is not fully parsed here.

  7.3.2.9 Slice segment layer RBSP syntax
  slice_segment_layer_rbsp( ) {
    slice_segment_header( )
    slice_segment_data( )
    rbsp_slice_segment_trailing_bits( )
  }
  ...
  7.3.6.1 General slice segment header syntax
  slice_segment_header( ) { // descriptor
    first_slice_segment_in_pic_flag u(1)
    if( nal_unit_type >= BLA_W_LP && nal_unit_type <= RSV_IRAP_VCL23 )
      no_output_of_prior_pics_flag u(1)
    slice_pic_parameter_set_id ue(v)
    if( !first_slice_segment_in_pic_flag ) {
      if( dependent_slice_segments_enabled_flag )
        dependent_slice_segment_flag u(1)
      slice_segment_address u(v)
    }
    if( !dependent_slice_segment_flag ) {
      for( i = 0; i < num_extra_slice_header_bits; i++ )
        slice_reserved_flag[ i ] u(1)
      slice_type ue(v)
    ...
  """

  # RBSP payload begins right after the start code and the 2-byte NAL header.
  rbsp_start = nal_unit_start + NAL_UNIT_START_CODE_SIZE + NAL_UNIT_HEADER_SIZE
  # Running count of header bits already consumed before slice_type.
  skip_bits = 0

  # 7.4.7.1 General slice segment header semantics
  # first_slice_segment_in_pic_flag equal to 1 specifies that the slice segment is the first slice segment of the picture in
  # decoding order. first_slice_segment_in_pic_flag equal to 0 specifies that the slice segment is not the first slice segment
  # of the picture in decoding order.
  is_first_slice = dat[rbsp_start] >> 7 & 1 == 1
  if not is_first_slice:
    # TODO: parse dependent_slice_segment_flag and slice_segment_address and get real slice_type
    # for now since we don't use it return -1 for slice_type
    return (-1, is_first_slice)
  skip_bits += 1  # skip past first_slice_segment_in_pic_flag

  if nal_unit_type >= HevcNalUnitType.BLA_W_LP and nal_unit_type <= HevcNalUnitType.RSV_IRAP_VCL23:
    # 7.4.7.1 General slice segment header semantics
    # no_output_of_prior_pics_flag affects the output of previously-decoded pictures in the decoded picture buffer after the
    # decoding of an IDR or a BLA picture that is not the first picture in the bitstream as specified in Annex C.
    skip_bits += 1  # skip past no_output_of_prior_pics_flag

  # 7.4.7.1 General slice segment header semantics
  # slice_pic_parameter_set_id specifies the value of pps_pic_parameter_set_id for the PPS in use.
  # The value of slice_pic_parameter_set_id shall be in the range of 0 to 63, inclusive.
  _, size = get_ue(dat, rbsp_start, skip_bits)
  skip_bits += size  # skip past slice_pic_parameter_set_id

  # 7.4.3.3.1 General picture parameter set RBSP semantics
  # num_extra_slice_header_bits specifies the number of extra slice header bits that are present in the slice header RBSP
  # for coded pictures referring to the PPS. The value of num_extra_slice_header_bits shall be in the range of 0 to 2, inclusive,
  # in bitstreams conforming to this version of this Specification. Other values for num_extra_slice_header_bits are reserved
  # for future use by ITU-T | ISO/IEC. However, decoders shall allow num_extra_slice_header_bits to have any value.
  # TODO: get from PPS_NUT pic_parameter_set_rbsp( ) for corresponding slice_pic_parameter_set_id
  num_extra_slice_header_bits = 0
  skip_bits += num_extra_slice_header_bits

  # 7.4.7.1 General slice segment header semantics
  # slice_type specifies the coding type of the slice according to Table 7-7.
  # Table 7-7 - Name association to slice_type
  # slice_type | Name of slice_type
  # 0          | B (B slice)
  # 1          | P (P slice)
  # 2          | I (I slice)
  # unsigned integer 0-th order Exp-Golomb-coded syntax element with the left bit first
  slice_type, _ = get_ue(dat, rbsp_start, skip_bits)
  if DEBUG:
    print(" slice_type:", slice_type, f"(first slice: {is_first_slice})")
  if slice_type > 2:
    raise VideoFileInvalid("slice_type must be 0, 1, or 2")
  return slice_type, is_first_slice
|
||||
|
||||
def hevc_index(hevc_file_name: str, allow_corrupt: bool=False) -> Tuple[list, int, bytes]:
  """Index an HEVC Annex-B byte stream.

  Returns (frame_types, data_length, prefix_data): frame_types is a list of
  (slice_type, byte_offset) pairs for each first slice of a picture, and
  prefix_data is the concatenation of the VPS/SPS/PPS parameter-set units.
  With allow_corrupt=True, a parse error stops indexing at the failing
  offset (printing it) instead of raising.
  """
  with FileReader(hevc_file_name) as f:
    dat = f.read()

  # Need at least one start code plus a header byte past the leading 0x00.
  if len(dat) < NAL_UNIT_START_CODE_SIZE + 1:
    raise VideoFileInvalid("data is too short")

  if dat[0] != 0x00:
    raise VideoFileInvalid("first byte must be 0x00")

  prefix_dat = b""
  frame_types = []

  i = 1  # skip past first byte 0x00
  try:
    while i < len(dat):
      require_nal_unit_start(dat, i)
      nal_unit_len = get_hevc_nal_unit_length(dat, i)
      nal_unit_type = get_hevc_nal_unit_type(dat, i)

      if nal_unit_type in HEVC_PARAMETER_SET_NAL_UNITS:
        prefix_dat += dat[i:i + nal_unit_len]
      elif nal_unit_type in HEVC_CODED_SLICE_SEGMENT_NAL_UNITS:
        slice_type, is_first_slice = get_hevc_slice_type(dat, i, nal_unit_type)
        if is_first_slice:
          frame_types.append((slice_type, i))

      i += nal_unit_len
  except Exception as e:
    if not allow_corrupt:
      raise
    print(f"ERROR: NAL unit skipped @ {i}\n", str(e))

  return frame_types, len(dat), prefix_dat
|
||||
|
||||
def main() -> None:
  """CLI entry point: index an HEVC file into a prefix file and a frame-index file."""
  parser = argparse.ArgumentParser()
  parser.add_argument("input_file", type=str)
  parser.add_argument("output_prefix_file", type=str)
  parser.add_argument("output_index_file", type=str)
  args = parser.parse_args()

  frame_types, dat_len, prefix_dat = hevc_index(args.input_file)

  # Parameter-set units (VPS/SPS/PPS), concatenated.
  with open(args.output_prefix_file, "wb") as prefix_f:
    prefix_f.write(prefix_dat)

  # One little-endian (slice_type, byte_offset) u32 pair per frame, then a
  # 0xFFFFFFFF sentinel carrying the total data length.
  with open(args.output_index_file, "wb") as index_f:
    for slice_type, offset in frame_types:
      index_f.write(struct.pack("<II", slice_type, offset))
    index_f.write(struct.pack("<II", 0xFFFFFFFF, dat_len))
|
||||
|
||||
# Usage: vidindex.py <input.hevc> <output_prefix_file> <output_index_file>
if __name__ == "__main__":
  main()
|
||||
Reference in New Issue
Block a user