wip
This commit is contained in:
Binary file not shown.
92
cereal/messaging/bridge.cc
Normal file
92
cereal/messaging/bridge.cc
Normal file
@@ -0,0 +1,92 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <csignal>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
typedef void (*sighandler_t)(int sig);
|
||||
|
||||
#include "cereal/services.h"
|
||||
#include "cereal/messaging/impl_msgq.h"
|
||||
#include "cereal/messaging/impl_zmq.h"
|
||||
|
||||
std::atomic<bool> do_exit = false;
|
||||
static void set_do_exit(int sig) {
|
||||
do_exit = true;
|
||||
}
|
||||
|
||||
void sigpipe_handler(int sig) {
|
||||
assert(sig == SIGPIPE);
|
||||
std::cout << "SIGPIPE received" << std::endl;
|
||||
}
|
||||
|
||||
static std::vector<std::string> get_services(std::string whitelist_str, bool zmq_to_msgq) {
|
||||
std::vector<std::string> service_list;
|
||||
for (const auto& it : services) {
|
||||
std::string name = it.second.name;
|
||||
bool in_whitelist = whitelist_str.find(name) != std::string::npos;
|
||||
if (name == "plusFrame" || name == "uiLayoutState" || (zmq_to_msgq && !in_whitelist)) {
|
||||
continue;
|
||||
}
|
||||
service_list.push_back(name);
|
||||
}
|
||||
return service_list;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
signal(SIGPIPE, (sighandler_t)sigpipe_handler);
|
||||
signal(SIGINT, (sighandler_t)set_do_exit);
|
||||
signal(SIGTERM, (sighandler_t)set_do_exit);
|
||||
|
||||
bool zmq_to_msgq = argc > 2;
|
||||
std::string ip = zmq_to_msgq ? argv[1] : "127.0.0.1";
|
||||
std::string whitelist_str = zmq_to_msgq ? std::string(argv[2]) : "";
|
||||
|
||||
Poller *poller;
|
||||
Context *pub_context;
|
||||
Context *sub_context;
|
||||
if (zmq_to_msgq) { // republishes zmq debugging messages as msgq
|
||||
poller = new ZMQPoller();
|
||||
pub_context = new MSGQContext();
|
||||
sub_context = new ZMQContext();
|
||||
} else {
|
||||
poller = new MSGQPoller();
|
||||
pub_context = new ZMQContext();
|
||||
sub_context = new MSGQContext();
|
||||
}
|
||||
|
||||
std::map<SubSocket*, PubSocket*> sub2pub;
|
||||
for (auto endpoint : get_services(whitelist_str, zmq_to_msgq)) {
|
||||
PubSocket * pub_sock;
|
||||
SubSocket * sub_sock;
|
||||
if (zmq_to_msgq) {
|
||||
pub_sock = new MSGQPubSocket();
|
||||
sub_sock = new ZMQSubSocket();
|
||||
} else {
|
||||
pub_sock = new ZMQPubSocket();
|
||||
sub_sock = new MSGQSubSocket();
|
||||
}
|
||||
pub_sock->connect(pub_context, endpoint);
|
||||
sub_sock->connect(sub_context, endpoint, ip, false);
|
||||
|
||||
poller->registerSocket(sub_sock);
|
||||
sub2pub[sub_sock] = pub_sock;
|
||||
}
|
||||
|
||||
while (!do_exit) {
|
||||
for (auto sub_sock : poller->poll(100)) {
|
||||
Message * msg = sub_sock->receive();
|
||||
if (msg == NULL) continue;
|
||||
int ret;
|
||||
do {
|
||||
ret = sub2pub[sub_sock]->sendMessage(msg);
|
||||
} while (ret == -1 && errno == EINTR && !do_exit);
|
||||
assert(ret >= 0 || do_exit);
|
||||
delete msg;
|
||||
|
||||
if (do_exit) break;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
50
cereal/messaging/demo.cc
Normal file
50
cereal/messaging/demo.cc
Normal file
@@ -0,0 +1,50 @@
|
||||
#include <iostream>
|
||||
#include <cstddef>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
#include <cassert>
|
||||
|
||||
#include "cereal/messaging/messaging.h"
|
||||
#include "cereal/messaging/impl_zmq.h"
|
||||
|
||||
#define MSGS 1e5
|
||||
|
||||
int main() {
|
||||
Context * c = Context::create();
|
||||
SubSocket * sub_sock = SubSocket::create(c, "controlsState");
|
||||
PubSocket * pub_sock = PubSocket::create(c, "controlsState");
|
||||
|
||||
char data[8];
|
||||
|
||||
Poller * poller = Poller::create({sub_sock});
|
||||
|
||||
auto start = std::chrono::steady_clock::now();
|
||||
|
||||
for (uint64_t i = 0; i < MSGS; i++){
|
||||
*(uint64_t*)data = i;
|
||||
pub_sock->send(data, 8);
|
||||
|
||||
auto r = poller->poll(100);
|
||||
|
||||
for (auto p : r){
|
||||
Message * m = p->receive();
|
||||
uint64_t ii = *(uint64_t*)m->getData();
|
||||
assert(i == ii);
|
||||
delete m;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
auto end = std::chrono::steady_clock::now();
|
||||
double elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count() / 1e9;
|
||||
double throughput = ((double) MSGS / (double) elapsed);
|
||||
std::cout << throughput << " msg/s" << std::endl;
|
||||
|
||||
delete poller;
|
||||
delete sub_sock;
|
||||
delete pub_sock;
|
||||
delete c;
|
||||
|
||||
|
||||
return 0;
|
||||
}
|
||||
29
cereal/messaging/demo.py
Normal file
29
cereal/messaging/demo.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import time
|
||||
|
||||
from messaging_pyx import Context, Poller, SubSocket, PubSocket
|
||||
|
||||
MSGS = 1e5
|
||||
|
||||
if __name__ == "__main__":
|
||||
c = Context()
|
||||
sub_sock = SubSocket()
|
||||
pub_sock = PubSocket()
|
||||
|
||||
sub_sock.connect(c, "controlsState")
|
||||
pub_sock.connect(c, "controlsState")
|
||||
|
||||
poller = Poller()
|
||||
poller.registerSocket(sub_sock)
|
||||
|
||||
t = time.time()
|
||||
for i in range(int(MSGS)):
|
||||
bts = i.to_bytes(4, 'little')
|
||||
pub_sock.send(bts)
|
||||
|
||||
for s in poller.poll(100):
|
||||
dat = s.receive()
|
||||
ii = int.from_bytes(dat, 'little')
|
||||
assert(i == ii)
|
||||
|
||||
dt = time.time() - t
|
||||
print("%.1f msg/s" % (MSGS / dt))
|
||||
236
cereal/messaging/event.cc
Normal file
236
cereal/messaging/event.cc
Normal file
@@ -0,0 +1,236 @@
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <exception>
|
||||
#include <filesystem>
|
||||
|
||||
#include <unistd.h>
|
||||
#include <poll.h>
|
||||
#include <signal.h>
|
||||
#include <fcntl.h>
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
|
||||
#include "cereal/messaging/event.h"
|
||||
|
||||
#ifndef __APPLE__
|
||||
#include <sys/eventfd.h>
|
||||
|
||||
void event_state_shm_mmap(std::string endpoint, std::string identifier, char **shm_mem, std::string *shm_path) {
|
||||
const char* op_prefix = std::getenv("OPENPILOT_PREFIX");
|
||||
|
||||
std::string full_path = "/dev/shm/";
|
||||
if (op_prefix) {
|
||||
full_path += std::string(op_prefix) + "/";
|
||||
}
|
||||
full_path += CEREAL_EVENTS_PREFIX + "/";
|
||||
if (identifier.size() > 0) {
|
||||
full_path += identifier + "/";
|
||||
}
|
||||
std::filesystem::create_directories(full_path);
|
||||
full_path += endpoint;
|
||||
|
||||
int shm_fd = open(full_path.c_str(), O_RDWR | O_CREAT, 0664);
|
||||
if (shm_fd < 0) {
|
||||
throw std::runtime_error("Could not open shared memory file.");
|
||||
}
|
||||
|
||||
int rc = ftruncate(shm_fd, sizeof(EventState));
|
||||
if (rc < 0){
|
||||
close(shm_fd);
|
||||
throw std::runtime_error("Could not truncate shared memory file.");
|
||||
}
|
||||
|
||||
char * mem = (char*)mmap(NULL, sizeof(EventState), PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0);
|
||||
close(shm_fd);
|
||||
if (mem == nullptr) {
|
||||
throw std::runtime_error("Could not map shared memory file.");
|
||||
}
|
||||
|
||||
if (shm_mem != nullptr)
|
||||
*shm_mem = mem;
|
||||
if (shm_path != nullptr)
|
||||
*shm_path = full_path;
|
||||
}
|
||||
|
||||
SocketEventHandle::SocketEventHandle(std::string endpoint, std::string identifier, bool override) {
|
||||
char *mem;
|
||||
event_state_shm_mmap(endpoint, identifier, &mem, &this->shm_path);
|
||||
|
||||
this->state = (EventState*)mem;
|
||||
if (override) {
|
||||
this->state->fds[0] = eventfd(0, EFD_NONBLOCK);
|
||||
this->state->fds[1] = eventfd(0, EFD_NONBLOCK);
|
||||
}
|
||||
}
|
||||
|
||||
SocketEventHandle::~SocketEventHandle() {
|
||||
close(this->state->fds[0]);
|
||||
close(this->state->fds[1]);
|
||||
munmap(this->state, sizeof(EventState));
|
||||
unlink(this->shm_path.c_str());
|
||||
}
|
||||
|
||||
bool SocketEventHandle::is_enabled() {
|
||||
return this->state->enabled;
|
||||
}
|
||||
|
||||
void SocketEventHandle::set_enabled(bool enabled) {
|
||||
this->state->enabled = enabled;
|
||||
}
|
||||
|
||||
Event SocketEventHandle::recv_called() {
|
||||
return Event(this->state->fds[0]);
|
||||
}
|
||||
|
||||
Event SocketEventHandle::recv_ready() {
|
||||
return Event(this->state->fds[1]);
|
||||
}
|
||||
|
||||
void SocketEventHandle::toggle_fake_events(bool enabled) {
|
||||
if (enabled)
|
||||
setenv("CEREAL_FAKE", "1", true);
|
||||
else
|
||||
unsetenv("CEREAL_FAKE");
|
||||
}
|
||||
|
||||
void SocketEventHandle::set_fake_prefix(std::string prefix) {
|
||||
if (prefix.size() == 0) {
|
||||
unsetenv("CEREAL_FAKE_PREFIX");
|
||||
} else {
|
||||
setenv("CEREAL_FAKE_PREFIX", prefix.c_str(), true);
|
||||
}
|
||||
}
|
||||
|
||||
std::string SocketEventHandle::fake_prefix() {
|
||||
const char* prefix = std::getenv("CEREAL_FAKE_PREFIX");
|
||||
if (prefix == nullptr) {
|
||||
return "";
|
||||
} else {
|
||||
return std::string(prefix);
|
||||
}
|
||||
}
|
||||
|
||||
Event::Event(int fd): event_fd(fd) {}
|
||||
|
||||
void Event::set() const {
|
||||
throw_if_invalid();
|
||||
|
||||
uint64_t val = 1;
|
||||
size_t count = write(this->event_fd, &val, sizeof(uint64_t));
|
||||
assert(count == sizeof(uint64_t));
|
||||
}
|
||||
|
||||
int Event::clear() const {
|
||||
throw_if_invalid();
|
||||
|
||||
uint64_t val = 0;
|
||||
// read the eventfd to clear it
|
||||
read(this->event_fd, &val, sizeof(uint64_t));
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
void Event::wait(int timeout_sec) const {
|
||||
throw_if_invalid();
|
||||
|
||||
int event_count;
|
||||
struct pollfd fds = { this->event_fd, POLLIN, 0 };
|
||||
struct timespec timeout = { timeout_sec, 0 };;
|
||||
|
||||
sigset_t signals;
|
||||
sigfillset(&signals);
|
||||
sigdelset(&signals, SIGALRM);
|
||||
sigdelset(&signals, SIGINT);
|
||||
sigdelset(&signals, SIGTERM);
|
||||
sigdelset(&signals, SIGQUIT);
|
||||
|
||||
event_count = ppoll(&fds, 1, timeout_sec < 0 ? nullptr : &timeout, &signals);
|
||||
|
||||
if (event_count == 0) {
|
||||
throw std::runtime_error("Event timed out pid: " + std::to_string(getpid()));
|
||||
} else if (event_count < 0) {
|
||||
throw std::runtime_error("Event poll failed, errno: " + std::to_string(errno) + " pid: " + std::to_string(getpid()));
|
||||
}
|
||||
}
|
||||
|
||||
bool Event::peek() const {
|
||||
throw_if_invalid();
|
||||
|
||||
int event_count;
|
||||
|
||||
struct pollfd fds = { this->event_fd, POLLIN, 0 };
|
||||
|
||||
// poll with timeout zero to return status immediately
|
||||
event_count = poll(&fds, 1, 0);
|
||||
|
||||
return event_count != 0;
|
||||
}
|
||||
|
||||
bool Event::is_valid() const {
|
||||
return event_fd != -1;
|
||||
}
|
||||
|
||||
int Event::fd() const {
|
||||
return event_fd;
|
||||
}
|
||||
|
||||
int Event::wait_for_one(const std::vector<Event>& events, int timeout_sec) {
|
||||
struct pollfd fds[events.size()];
|
||||
for (size_t i = 0; i < events.size(); i++) {
|
||||
fds[i] = { events[i].fd(), POLLIN, 0 };
|
||||
}
|
||||
|
||||
struct timespec timeout = { timeout_sec, 0 };
|
||||
|
||||
sigset_t signals;
|
||||
sigfillset(&signals);
|
||||
sigdelset(&signals, SIGALRM);
|
||||
sigdelset(&signals, SIGINT);
|
||||
sigdelset(&signals, SIGTERM);
|
||||
sigdelset(&signals, SIGQUIT);
|
||||
|
||||
int event_count = ppoll(fds, events.size(), timeout_sec < 0 ? nullptr : &timeout, &signals);
|
||||
|
||||
if (event_count == 0) {
|
||||
throw std::runtime_error("Event timed out pid: " + std::to_string(getpid()));
|
||||
} else if (event_count < 0) {
|
||||
throw std::runtime_error("Event poll failed, errno: " + std::to_string(errno) + " pid: " + std::to_string(getpid()));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < events.size(); i++) {
|
||||
if (fds[i].revents & POLLIN) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
throw std::runtime_error("Event poll failed, no events ready");
|
||||
}
|
||||
#else
|
||||
// Stub implementation for Darwin, which does not support eventfd
|
||||
void event_state_shm_mmap(std::string endpoint, std::string identifier, char **shm_mem, std::string *shm_path) {}
|
||||
|
||||
SocketEventHandle::SocketEventHandle(std::string endpoint, std::string identifier, bool override) {
|
||||
std::cerr << "SocketEventHandle not supported on macOS" << std::endl;
|
||||
assert(false);
|
||||
}
|
||||
SocketEventHandle::~SocketEventHandle() {}
|
||||
bool SocketEventHandle::is_enabled() { return this->state->enabled; }
|
||||
void SocketEventHandle::set_enabled(bool enabled) {}
|
||||
Event SocketEventHandle::recv_called() { return Event(); }
|
||||
Event SocketEventHandle::recv_ready() { return Event(); }
|
||||
void SocketEventHandle::toggle_fake_events(bool enabled) {}
|
||||
void SocketEventHandle::set_fake_prefix(std::string prefix) {}
|
||||
std::string SocketEventHandle::fake_prefix() { return ""; }
|
||||
|
||||
Event::Event(int fd): event_fd(fd) {}
|
||||
void Event::set() const {}
|
||||
int Event::clear() const { return 0; }
|
||||
void Event::wait(int timeout_sec) const {}
|
||||
bool Event::peek() const { return false; }
|
||||
bool Event::is_valid() const { return false; }
|
||||
int Event::fd() const { return this->event_fd; }
|
||||
int Event::wait_for_one(const std::vector<Event>& events, int timeout_sec) { return -1; }
|
||||
#endif
|
||||
58
cereal/messaging/event.h
Normal file
58
cereal/messaging/event.h
Normal file
@@ -0,0 +1,58 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#define CEREAL_EVENTS_PREFIX std::string("cereal_events")
|
||||
|
||||
void event_state_shm_mmap(std::string endpoint, std::string identifier, char **shm_mem, std::string *shm_path);
|
||||
|
||||
enum EventPurpose {
|
||||
RECV_CALLED,
|
||||
RECV_READY
|
||||
};
|
||||
|
||||
struct EventState {
|
||||
int fds[2];
|
||||
bool enabled;
|
||||
};
|
||||
|
||||
class Event {
|
||||
private:
|
||||
int event_fd = -1;
|
||||
|
||||
inline void throw_if_invalid() const {
|
||||
if (!this->is_valid()) {
|
||||
throw std::runtime_error("Event does not have valid file descriptor.");
|
||||
}
|
||||
}
|
||||
public:
|
||||
Event(int fd = -1);
|
||||
|
||||
void set() const;
|
||||
int clear() const;
|
||||
void wait(int timeout_sec = -1) const;
|
||||
bool peek() const;
|
||||
bool is_valid() const;
|
||||
int fd() const;
|
||||
|
||||
static int wait_for_one(const std::vector<Event>& events, int timeout_sec = -1);
|
||||
};
|
||||
|
||||
class SocketEventHandle {
|
||||
private:
|
||||
std::string shm_path;
|
||||
EventState* state;
|
||||
public:
|
||||
SocketEventHandle(std::string endpoint, std::string identifier = "", bool override = true);
|
||||
~SocketEventHandle();
|
||||
|
||||
bool is_enabled();
|
||||
void set_enabled(bool enabled);
|
||||
Event recv_called();
|
||||
Event recv_ready();
|
||||
|
||||
static void toggle_fake_events(bool enabled);
|
||||
static void set_fake_prefix(std::string prefix);
|
||||
static std::string fake_prefix();
|
||||
};
|
||||
9
cereal/messaging/impl_fake.cc
Normal file
9
cereal/messaging/impl_fake.cc
Normal file
@@ -0,0 +1,9 @@
|
||||
#include "cereal/messaging/impl_fake.h"
|
||||
|
||||
void FakePoller::registerSocket(SubSocket *socket) {
|
||||
this->sockets.push_back(socket);
|
||||
}
|
||||
|
||||
std::vector<SubSocket*> FakePoller::poll(int timeout) {
|
||||
return this->sockets;
|
||||
}
|
||||
67
cereal/messaging/impl_fake.h
Normal file
67
cereal/messaging/impl_fake.h
Normal file
@@ -0,0 +1,67 @@
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <filesystem>
|
||||
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <fcntl.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "cereal/messaging/messaging.h"
|
||||
#include "cereal/messaging/event.h"
|
||||
|
||||
template<typename TSubSocket>
|
||||
class FakeSubSocket: public TSubSocket {
|
||||
private:
|
||||
Event *recv_called = nullptr;
|
||||
Event *recv_ready = nullptr;
|
||||
EventState *state = nullptr;
|
||||
|
||||
public:
|
||||
FakeSubSocket(): TSubSocket() {}
|
||||
~FakeSubSocket() {
|
||||
delete recv_called;
|
||||
delete recv_ready;
|
||||
if (state != nullptr) {
|
||||
munmap(state, sizeof(EventState));
|
||||
}
|
||||
}
|
||||
|
||||
int connect(Context *context, std::string endpoint, std::string address, bool conflate=false, bool check_endpoint=true) override {
|
||||
const char* cereal_prefix = std::getenv("CEREAL_FAKE_PREFIX");
|
||||
|
||||
char* mem;
|
||||
std::string identifier = cereal_prefix != nullptr ? std::string(cereal_prefix) : "";
|
||||
event_state_shm_mmap(endpoint, identifier, &mem, nullptr);
|
||||
|
||||
this->state = (EventState*)mem;
|
||||
this->recv_called = new Event(state->fds[EventPurpose::RECV_CALLED]);
|
||||
this->recv_ready = new Event(state->fds[EventPurpose::RECV_READY]);
|
||||
|
||||
return TSubSocket::connect(context, endpoint, address, conflate, check_endpoint);
|
||||
}
|
||||
|
||||
Message *receive(bool non_blocking=false) override {
|
||||
if (this->state->enabled) {
|
||||
this->recv_called->set();
|
||||
this->recv_ready->wait();
|
||||
this->recv_ready->clear();
|
||||
}
|
||||
|
||||
return TSubSocket::receive(non_blocking);
|
||||
}
|
||||
};
|
||||
|
||||
class FakePoller: public Poller {
|
||||
private:
|
||||
std::vector<SubSocket*> sockets;
|
||||
|
||||
public:
|
||||
void registerSocket(SubSocket *socket) override;
|
||||
std::vector<SubSocket*> poll(int timeout) override;
|
||||
~FakePoller() {}
|
||||
};
|
||||
215
cereal/messaging/impl_msgq.cc
Normal file
215
cereal/messaging/impl_msgq.cc
Normal file
@@ -0,0 +1,215 @@
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
#include <csignal>
|
||||
#include <cerrno>
|
||||
|
||||
#include "cereal/services.h"
|
||||
#include "cereal/messaging/impl_msgq.h"
|
||||
|
||||
|
||||
volatile sig_atomic_t msgq_do_exit = 0;
|
||||
|
||||
void sig_handler(int signal) {
|
||||
assert(signal == SIGINT || signal == SIGTERM);
|
||||
msgq_do_exit = 1;
|
||||
}
|
||||
|
||||
static bool service_exists(std::string path){
|
||||
return services.count(path) > 0;
|
||||
}
|
||||
|
||||
|
||||
MSGQContext::MSGQContext() {
|
||||
}
|
||||
|
||||
MSGQContext::~MSGQContext() {
|
||||
}
|
||||
|
||||
void MSGQMessage::init(size_t sz) {
|
||||
size = sz;
|
||||
data = new char[size];
|
||||
}
|
||||
|
||||
void MSGQMessage::init(char * d, size_t sz) {
|
||||
size = sz;
|
||||
data = new char[size];
|
||||
memcpy(data, d, size);
|
||||
}
|
||||
|
||||
void MSGQMessage::takeOwnership(char * d, size_t sz) {
|
||||
size = sz;
|
||||
data = d;
|
||||
}
|
||||
|
||||
void MSGQMessage::close() {
|
||||
if (size > 0){
|
||||
delete[] data;
|
||||
}
|
||||
size = 0;
|
||||
}
|
||||
|
||||
MSGQMessage::~MSGQMessage() {
|
||||
this->close();
|
||||
}
|
||||
|
||||
int MSGQSubSocket::connect(Context *context, std::string endpoint, std::string address, bool conflate, bool check_endpoint){
|
||||
assert(context);
|
||||
assert(address == "127.0.0.1");
|
||||
|
||||
if (check_endpoint && !service_exists(std::string(endpoint))){
|
||||
std::cout << "Warning, " << std::string(endpoint) << " is not in service list." << std::endl;
|
||||
}
|
||||
|
||||
q = new msgq_queue_t;
|
||||
int r = msgq_new_queue(q, endpoint.c_str(), DEFAULT_SEGMENT_SIZE);
|
||||
if (r != 0){
|
||||
return r;
|
||||
}
|
||||
|
||||
msgq_init_subscriber(q);
|
||||
|
||||
if (conflate){
|
||||
q->read_conflate = true;
|
||||
}
|
||||
|
||||
timeout = -1;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
Message * MSGQSubSocket::receive(bool non_blocking){
|
||||
msgq_do_exit = 0;
|
||||
|
||||
void (*prev_handler_sigint)(int);
|
||||
void (*prev_handler_sigterm)(int);
|
||||
if (!non_blocking){
|
||||
prev_handler_sigint = std::signal(SIGINT, sig_handler);
|
||||
prev_handler_sigterm = std::signal(SIGTERM, sig_handler);
|
||||
}
|
||||
|
||||
msgq_msg_t msg;
|
||||
|
||||
MSGQMessage *r = NULL;
|
||||
|
||||
int rc = msgq_msg_recv(&msg, q);
|
||||
|
||||
// Hack to implement blocking read with a poller. Don't use this
|
||||
while (!non_blocking && rc == 0 && msgq_do_exit == 0){
|
||||
msgq_pollitem_t items[1];
|
||||
items[0].q = q;
|
||||
|
||||
int t = (timeout != -1) ? timeout : 100;
|
||||
|
||||
int n = msgq_poll(items, 1, t);
|
||||
rc = msgq_msg_recv(&msg, q);
|
||||
|
||||
// The poll indicated a message was ready, but the receive failed. Try again
|
||||
if (n == 1 && rc == 0){
|
||||
continue;
|
||||
}
|
||||
|
||||
if (timeout != -1){
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (!non_blocking){
|
||||
std::signal(SIGINT, prev_handler_sigint);
|
||||
std::signal(SIGTERM, prev_handler_sigterm);
|
||||
}
|
||||
|
||||
errno = msgq_do_exit ? EINTR : 0;
|
||||
|
||||
if (rc > 0){
|
||||
if (msgq_do_exit){
|
||||
msgq_msg_close(&msg); // Free unused message on exit
|
||||
} else {
|
||||
r = new MSGQMessage;
|
||||
r->takeOwnership(msg.data, msg.size);
|
||||
}
|
||||
}
|
||||
|
||||
return (Message*)r;
|
||||
}
|
||||
|
||||
void MSGQSubSocket::setTimeout(int t){
|
||||
timeout = t;
|
||||
}
|
||||
|
||||
MSGQSubSocket::~MSGQSubSocket(){
|
||||
if (q != NULL){
|
||||
msgq_close_queue(q);
|
||||
delete q;
|
||||
}
|
||||
}
|
||||
|
||||
int MSGQPubSocket::connect(Context *context, std::string endpoint, bool check_endpoint){
|
||||
assert(context);
|
||||
|
||||
if (check_endpoint && !service_exists(std::string(endpoint))){
|
||||
std::cout << "Warning, " << std::string(endpoint) << " is not in service list." << std::endl;
|
||||
}
|
||||
|
||||
q = new msgq_queue_t;
|
||||
int r = msgq_new_queue(q, endpoint.c_str(), DEFAULT_SEGMENT_SIZE);
|
||||
if (r != 0){
|
||||
return r;
|
||||
}
|
||||
|
||||
msgq_init_publisher(q);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int MSGQPubSocket::sendMessage(Message *message){
|
||||
msgq_msg_t msg;
|
||||
msg.data = message->getData();
|
||||
msg.size = message->getSize();
|
||||
|
||||
return msgq_msg_send(&msg, q);
|
||||
}
|
||||
|
||||
int MSGQPubSocket::send(char *data, size_t size){
|
||||
msgq_msg_t msg;
|
||||
msg.data = data;
|
||||
msg.size = size;
|
||||
|
||||
return msgq_msg_send(&msg, q);
|
||||
}
|
||||
|
||||
bool MSGQPubSocket::all_readers_updated() {
|
||||
return msgq_all_readers_updated(q);
|
||||
}
|
||||
|
||||
MSGQPubSocket::~MSGQPubSocket(){
|
||||
if (q != NULL){
|
||||
msgq_close_queue(q);
|
||||
delete q;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void MSGQPoller::registerSocket(SubSocket * socket){
|
||||
assert(num_polls + 1 < MAX_POLLERS);
|
||||
polls[num_polls].q = (msgq_queue_t*)socket->getRawSocket();
|
||||
|
||||
sockets.push_back(socket);
|
||||
num_polls++;
|
||||
}
|
||||
|
||||
std::vector<SubSocket*> MSGQPoller::poll(int timeout){
|
||||
std::vector<SubSocket*> r;
|
||||
|
||||
msgq_poll(polls, num_polls, timeout);
|
||||
for (size_t i = 0; i < num_polls; i++){
|
||||
if (polls[i].revents){
|
||||
r.push_back(sockets[i]);
|
||||
}
|
||||
}
|
||||
|
||||
return r;
|
||||
}
|
||||
67
cereal/messaging/impl_msgq.h
Normal file
67
cereal/messaging/impl_msgq.h
Normal file
@@ -0,0 +1,67 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "cereal/messaging/messaging.h"
|
||||
#include "cereal/messaging/msgq.h"
|
||||
|
||||
#define MAX_POLLERS 128
|
||||
|
||||
class MSGQContext : public Context {
|
||||
private:
|
||||
void * context = NULL;
|
||||
public:
|
||||
MSGQContext();
|
||||
void * getRawContext() {return context;}
|
||||
~MSGQContext();
|
||||
};
|
||||
|
||||
class MSGQMessage : public Message {
|
||||
private:
|
||||
char * data;
|
||||
size_t size;
|
||||
public:
|
||||
void init(size_t size);
|
||||
void init(char *data, size_t size);
|
||||
void takeOwnership(char *data, size_t size);
|
||||
size_t getSize(){return size;}
|
||||
char * getData(){return data;}
|
||||
void close();
|
||||
~MSGQMessage();
|
||||
};
|
||||
|
||||
class MSGQSubSocket : public SubSocket {
|
||||
private:
|
||||
msgq_queue_t * q = NULL;
|
||||
int timeout;
|
||||
public:
|
||||
int connect(Context *context, std::string endpoint, std::string address, bool conflate=false, bool check_endpoint=true);
|
||||
void setTimeout(int timeout);
|
||||
void * getRawSocket() {return (void*)q;}
|
||||
Message *receive(bool non_blocking=false);
|
||||
~MSGQSubSocket();
|
||||
};
|
||||
|
||||
class MSGQPubSocket : public PubSocket {
|
||||
private:
|
||||
msgq_queue_t * q = NULL;
|
||||
public:
|
||||
int connect(Context *context, std::string endpoint, bool check_endpoint=true);
|
||||
int sendMessage(Message *message);
|
||||
int send(char *data, size_t size);
|
||||
bool all_readers_updated();
|
||||
~MSGQPubSocket();
|
||||
};
|
||||
|
||||
class MSGQPoller : public Poller {
|
||||
private:
|
||||
std::vector<SubSocket*> sockets;
|
||||
msgq_pollitem_t polls[MAX_POLLERS];
|
||||
size_t num_polls = 0;
|
||||
|
||||
public:
|
||||
void registerSocket(SubSocket *socket);
|
||||
std::vector<SubSocket*> poll(int timeout);
|
||||
~MSGQPoller(){}
|
||||
};
|
||||
162
cereal/messaging/impl_zmq.cc
Normal file
162
cereal/messaging/impl_zmq.cc
Normal file
@@ -0,0 +1,162 @@
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
#include <cerrno>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "cereal/services.h"
|
||||
#include "cereal/messaging/impl_zmq.h"
|
||||
|
||||
static int get_port(std::string endpoint) {
|
||||
return services.at(endpoint).port;
|
||||
}
|
||||
|
||||
ZMQContext::ZMQContext() {
|
||||
context = zmq_ctx_new();
|
||||
}
|
||||
|
||||
ZMQContext::~ZMQContext() {
|
||||
zmq_ctx_term(context);
|
||||
}
|
||||
|
||||
void ZMQMessage::init(size_t sz) {
|
||||
size = sz;
|
||||
data = new char[size];
|
||||
}
|
||||
|
||||
void ZMQMessage::init(char * d, size_t sz) {
|
||||
size = sz;
|
||||
data = new char[size];
|
||||
memcpy(data, d, size);
|
||||
}
|
||||
|
||||
void ZMQMessage::close() {
|
||||
if (size > 0){
|
||||
delete[] data;
|
||||
}
|
||||
size = 0;
|
||||
}
|
||||
|
||||
ZMQMessage::~ZMQMessage() {
|
||||
this->close();
|
||||
}
|
||||
|
||||
|
||||
int ZMQSubSocket::connect(Context *context, std::string endpoint, std::string address, bool conflate, bool check_endpoint){
|
||||
sock = zmq_socket(context->getRawContext(), ZMQ_SUB);
|
||||
if (sock == NULL){
|
||||
return -1;
|
||||
}
|
||||
|
||||
zmq_setsockopt(sock, ZMQ_SUBSCRIBE, "", 0);
|
||||
|
||||
if (conflate){
|
||||
int arg = 1;
|
||||
zmq_setsockopt(sock, ZMQ_CONFLATE, &arg, sizeof(int));
|
||||
}
|
||||
|
||||
int reconnect_ivl = 500;
|
||||
zmq_setsockopt(sock, ZMQ_RECONNECT_IVL_MAX, &reconnect_ivl, sizeof(reconnect_ivl));
|
||||
|
||||
full_endpoint = "tcp://" + address + ":";
|
||||
if (check_endpoint){
|
||||
full_endpoint += std::to_string(get_port(endpoint));
|
||||
} else {
|
||||
full_endpoint += endpoint;
|
||||
}
|
||||
|
||||
return zmq_connect(sock, full_endpoint.c_str());
|
||||
}
|
||||
|
||||
|
||||
Message * ZMQSubSocket::receive(bool non_blocking){
|
||||
zmq_msg_t msg;
|
||||
assert(zmq_msg_init(&msg) == 0);
|
||||
|
||||
int flags = non_blocking ? ZMQ_DONTWAIT : 0;
|
||||
int rc = zmq_msg_recv(&msg, sock, flags);
|
||||
Message *r = NULL;
|
||||
|
||||
if (rc >= 0){
|
||||
// Make a copy to ensure the data is aligned
|
||||
r = new ZMQMessage;
|
||||
r->init((char*)zmq_msg_data(&msg), zmq_msg_size(&msg));
|
||||
}
|
||||
|
||||
zmq_msg_close(&msg);
|
||||
return r;
|
||||
}
|
||||
|
||||
void ZMQSubSocket::setTimeout(int timeout){
|
||||
zmq_setsockopt(sock, ZMQ_RCVTIMEO, &timeout, sizeof(int));
|
||||
}
|
||||
|
||||
ZMQSubSocket::~ZMQSubSocket(){
|
||||
zmq_close(sock);
|
||||
}
|
||||
|
||||
int ZMQPubSocket::connect(Context *context, std::string endpoint, bool check_endpoint){
|
||||
sock = zmq_socket(context->getRawContext(), ZMQ_PUB);
|
||||
if (sock == NULL){
|
||||
return -1;
|
||||
}
|
||||
|
||||
full_endpoint = "tcp://*:";
|
||||
if (check_endpoint){
|
||||
full_endpoint += std::to_string(get_port(endpoint));
|
||||
} else {
|
||||
full_endpoint += endpoint;
|
||||
}
|
||||
|
||||
// ZMQ pub sockets cannot be shared between processes, so we need to ensure pid stays the same
|
||||
pid = getpid();
|
||||
|
||||
return zmq_bind(sock, full_endpoint.c_str());
|
||||
}
|
||||
|
||||
int ZMQPubSocket::sendMessage(Message *message) {
|
||||
assert(pid == getpid());
|
||||
return zmq_send(sock, message->getData(), message->getSize(), ZMQ_DONTWAIT);
|
||||
}
|
||||
|
||||
int ZMQPubSocket::send(char *data, size_t size) {
|
||||
assert(pid == getpid());
|
||||
return zmq_send(sock, data, size, ZMQ_DONTWAIT);
|
||||
}
|
||||
|
||||
bool ZMQPubSocket::all_readers_updated() {
|
||||
assert(false); // TODO not implemented
|
||||
return false;
|
||||
}
|
||||
|
||||
ZMQPubSocket::~ZMQPubSocket(){
|
||||
zmq_close(sock);
|
||||
}
|
||||
|
||||
|
||||
void ZMQPoller::registerSocket(SubSocket * socket){
|
||||
assert(num_polls + 1 < MAX_POLLERS);
|
||||
polls[num_polls].socket = socket->getRawSocket();
|
||||
polls[num_polls].events = ZMQ_POLLIN;
|
||||
|
||||
sockets.push_back(socket);
|
||||
num_polls++;
|
||||
}
|
||||
|
||||
std::vector<SubSocket*> ZMQPoller::poll(int timeout){
|
||||
std::vector<SubSocket*> r;
|
||||
|
||||
int rc = zmq_poll(polls, num_polls, timeout);
|
||||
if (rc < 0){
|
||||
return r;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < num_polls; i++){
|
||||
if (polls[i].revents){
|
||||
r.push_back(sockets[i]);
|
||||
}
|
||||
}
|
||||
|
||||
return r;
|
||||
}
|
||||
68
cereal/messaging/impl_zmq.h
Normal file
68
cereal/messaging/impl_zmq.h
Normal file
@@ -0,0 +1,68 @@
|
||||
#pragma once
|
||||
|
||||
#include <zmq.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "cereal/messaging/messaging.h"
|
||||
|
||||
#define MAX_POLLERS 128
|
||||
|
||||
class ZMQContext : public Context {
|
||||
private:
|
||||
void * context = NULL;
|
||||
public:
|
||||
ZMQContext();
|
||||
void * getRawContext() {return context;}
|
||||
~ZMQContext();
|
||||
};
|
||||
|
||||
class ZMQMessage : public Message {
|
||||
private:
|
||||
char * data;
|
||||
size_t size;
|
||||
public:
|
||||
void init(size_t size);
|
||||
void init(char *data, size_t size);
|
||||
size_t getSize(){return size;}
|
||||
char * getData(){return data;}
|
||||
void close();
|
||||
~ZMQMessage();
|
||||
};
|
||||
|
||||
class ZMQSubSocket : public SubSocket {
|
||||
private:
|
||||
void * sock;
|
||||
std::string full_endpoint;
|
||||
public:
|
||||
int connect(Context *context, std::string endpoint, std::string address, bool conflate=false, bool check_endpoint=true);
|
||||
void setTimeout(int timeout);
|
||||
void * getRawSocket() {return sock;}
|
||||
Message *receive(bool non_blocking=false);
|
||||
~ZMQSubSocket();
|
||||
};
|
||||
|
||||
class ZMQPubSocket : public PubSocket {
|
||||
private:
|
||||
void * sock;
|
||||
std::string full_endpoint;
|
||||
int pid = -1;
|
||||
public:
|
||||
int connect(Context *context, std::string endpoint, bool check_endpoint=true);
|
||||
int sendMessage(Message *message);
|
||||
int send(char *data, size_t size);
|
||||
bool all_readers_updated();
|
||||
~ZMQPubSocket();
|
||||
};
|
||||
|
||||
class ZMQPoller : public Poller {
|
||||
private:
|
||||
std::vector<SubSocket*> sockets;
|
||||
zmq_pollitem_t polls[MAX_POLLERS];
|
||||
size_t num_polls = 0;
|
||||
|
||||
public:
|
||||
void registerSocket(SubSocket *socket);
|
||||
std::vector<SubSocket*> poll(int timeout);
|
||||
~ZMQPoller(){}
|
||||
};
|
||||
120
cereal/messaging/messaging.cc
Normal file
120
cereal/messaging/messaging.cc
Normal file
@@ -0,0 +1,120 @@
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
|
||||
#include "cereal/messaging/messaging.h"
|
||||
#include "cereal/messaging/impl_zmq.h"
|
||||
#include "cereal/messaging/impl_msgq.h"
|
||||
#include "cereal/messaging/impl_fake.h"
|
||||
|
||||
#ifdef __APPLE__
|
||||
const bool MUST_USE_ZMQ = true;
|
||||
#else
|
||||
const bool MUST_USE_ZMQ = false;
|
||||
#endif
|
||||
|
||||
bool messaging_use_zmq(){
|
||||
if (std::getenv("ZMQ") || MUST_USE_ZMQ) {
|
||||
if (std::getenv("OPENPILOT_PREFIX")) {
|
||||
std::cerr << "OPENPILOT_PREFIX not supported with ZMQ backend\n";
|
||||
assert(false);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool messaging_use_fake(){
|
||||
char* fake_enabled = std::getenv("CEREAL_FAKE");
|
||||
return fake_enabled != NULL;
|
||||
}
|
||||
|
||||
Context * Context::create(){
|
||||
Context * c;
|
||||
if (messaging_use_zmq()){
|
||||
c = new ZMQContext();
|
||||
} else {
|
||||
c = new MSGQContext();
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
SubSocket * SubSocket::create(){
|
||||
SubSocket * s;
|
||||
if (messaging_use_fake()) {
|
||||
if (messaging_use_zmq()) {
|
||||
s = new FakeSubSocket<ZMQSubSocket>();
|
||||
} else {
|
||||
s = new FakeSubSocket<MSGQSubSocket>();
|
||||
}
|
||||
} else {
|
||||
if (messaging_use_zmq()){
|
||||
s = new ZMQSubSocket();
|
||||
} else {
|
||||
s = new MSGQSubSocket();
|
||||
}
|
||||
}
|
||||
|
||||
return s;
|
||||
}
|
||||
|
||||
SubSocket * SubSocket::create(Context * context, std::string endpoint, std::string address, bool conflate, bool check_endpoint){
|
||||
SubSocket *s = SubSocket::create();
|
||||
int r = s->connect(context, endpoint, address, conflate, check_endpoint);
|
||||
|
||||
if (r == 0) {
|
||||
return s;
|
||||
} else {
|
||||
std::cerr << "Error, failed to connect SubSocket to " << endpoint << ": " << strerror(errno) << std::endl;
|
||||
|
||||
delete s;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
PubSocket * PubSocket::create(){
|
||||
PubSocket * s;
|
||||
if (messaging_use_zmq()){
|
||||
s = new ZMQPubSocket();
|
||||
} else {
|
||||
s = new MSGQPubSocket();
|
||||
}
|
||||
|
||||
return s;
|
||||
}
|
||||
|
||||
PubSocket * PubSocket::create(Context * context, std::string endpoint, bool check_endpoint){
|
||||
PubSocket *s = PubSocket::create();
|
||||
int r = s->connect(context, endpoint, check_endpoint);
|
||||
|
||||
if (r == 0) {
|
||||
return s;
|
||||
} else {
|
||||
std::cerr << "Error, failed to bind PubSocket to " << endpoint << ": " << strerror(errno) << std::endl;
|
||||
|
||||
delete s;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Poller * Poller::create(){
|
||||
Poller * p;
|
||||
if (messaging_use_fake()) {
|
||||
p = new FakePoller();
|
||||
} else {
|
||||
if (messaging_use_zmq()){
|
||||
p = new ZMQPoller();
|
||||
} else {
|
||||
p = new MSGQPoller();
|
||||
}
|
||||
}
|
||||
return p;
|
||||
}
|
||||
|
||||
Poller * Poller::create(std::vector<SubSocket*> sockets){
|
||||
Poller * p = Poller::create();
|
||||
|
||||
for (auto s : sockets){
|
||||
p->registerSocket(s);
|
||||
}
|
||||
return p;
|
||||
}
|
||||
162
cereal/messaging/messaging.h
Normal file
162
cereal/messaging/messaging.h
Normal file
@@ -0,0 +1,162 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <time.h>
|
||||
|
||||
#include <capnp/serialize.h>
|
||||
|
||||
#include "cereal/gen/cpp/log.capnp.h"
|
||||
|
||||
#ifdef __APPLE__
|
||||
#define CLOCK_BOOTTIME CLOCK_MONOTONIC
|
||||
#endif
|
||||
|
||||
#define MSG_MULTIPLE_PUBLISHERS 100
|
||||
|
||||
bool messaging_use_zmq();
|
||||
|
||||
class Context {
|
||||
public:
|
||||
virtual void * getRawContext() = 0;
|
||||
static Context * create();
|
||||
virtual ~Context(){}
|
||||
};
|
||||
|
||||
class Message {
|
||||
public:
|
||||
virtual void init(size_t size) = 0;
|
||||
virtual void init(char * data, size_t size) = 0;
|
||||
virtual void close() = 0;
|
||||
virtual size_t getSize() = 0;
|
||||
virtual char * getData() = 0;
|
||||
virtual ~Message(){}
|
||||
};
|
||||
|
||||
|
||||
class SubSocket {
|
||||
public:
|
||||
virtual int connect(Context *context, std::string endpoint, std::string address, bool conflate=false, bool check_endpoint=true) = 0;
|
||||
virtual void setTimeout(int timeout) = 0;
|
||||
virtual Message *receive(bool non_blocking=false) = 0;
|
||||
virtual void * getRawSocket() = 0;
|
||||
static SubSocket * create();
|
||||
static SubSocket * create(Context * context, std::string endpoint, std::string address="127.0.0.1", bool conflate=false, bool check_endpoint=true);
|
||||
virtual ~SubSocket(){}
|
||||
};
|
||||
|
||||
class PubSocket {
|
||||
public:
|
||||
virtual int connect(Context *context, std::string endpoint, bool check_endpoint=true) = 0;
|
||||
virtual int sendMessage(Message *message) = 0;
|
||||
virtual int send(char *data, size_t size) = 0;
|
||||
virtual bool all_readers_updated() = 0;
|
||||
static PubSocket * create();
|
||||
static PubSocket * create(Context * context, std::string endpoint, bool check_endpoint=true);
|
||||
static PubSocket * create(Context * context, std::string endpoint, int port, bool check_endpoint=true);
|
||||
virtual ~PubSocket(){}
|
||||
};
|
||||
|
||||
class Poller {
|
||||
public:
|
||||
virtual void registerSocket(SubSocket *socket) = 0;
|
||||
virtual std::vector<SubSocket*> poll(int timeout) = 0;
|
||||
static Poller * create();
|
||||
static Poller * create(std::vector<SubSocket*> sockets);
|
||||
virtual ~Poller(){}
|
||||
};
|
||||
|
||||
class SubMaster {
|
||||
public:
|
||||
SubMaster(const std::vector<const char *> &service_list, const std::vector<const char *> &poll = {},
|
||||
const char *address = nullptr, const std::vector<const char *> &ignore_alive = {});
|
||||
void update(int timeout = 1000);
|
||||
void update_msgs(uint64_t current_time, const std::vector<std::pair<std::string, cereal::Event::Reader>> &messages);
|
||||
inline bool allAlive(const std::vector<const char *> &service_list = {}) { return all_(service_list, false, true); }
|
||||
inline bool allValid(const std::vector<const char *> &service_list = {}) { return all_(service_list, true, false); }
|
||||
inline bool allAliveAndValid(const std::vector<const char *> &service_list = {}) { return all_(service_list, true, true); }
|
||||
void drain();
|
||||
~SubMaster();
|
||||
|
||||
uint64_t frame = 0;
|
||||
bool updated(const char *name) const;
|
||||
bool alive(const char *name) const;
|
||||
bool valid(const char *name) const;
|
||||
uint64_t rcv_frame(const char *name) const;
|
||||
uint64_t rcv_time(const char *name) const;
|
||||
cereal::Event::Reader &operator[](const char *name) const;
|
||||
|
||||
private:
|
||||
bool all_(const std::vector<const char *> &service_list, bool valid, bool alive);
|
||||
Poller *poller_ = nullptr;
|
||||
struct SubMessage;
|
||||
std::map<SubSocket *, SubMessage *> messages_;
|
||||
std::map<std::string, SubMessage *> services_;
|
||||
};
|
||||
|
||||
class MessageBuilder : public capnp::MallocMessageBuilder {
|
||||
public:
|
||||
MessageBuilder() = default;
|
||||
|
||||
cereal::Event::Builder initEvent(bool valid = true) {
|
||||
cereal::Event::Builder event = initRoot<cereal::Event>();
|
||||
struct timespec t;
|
||||
clock_gettime(CLOCK_BOOTTIME, &t);
|
||||
uint64_t current_time = t.tv_sec * 1000000000ULL + t.tv_nsec;
|
||||
event.setLogMonoTime(current_time);
|
||||
event.setValid(valid);
|
||||
return event;
|
||||
}
|
||||
|
||||
kj::ArrayPtr<capnp::byte> toBytes() {
|
||||
heapArray_ = capnp::messageToFlatArray(*this);
|
||||
return heapArray_.asBytes();
|
||||
}
|
||||
|
||||
size_t getSerializedSize() {
|
||||
return capnp::computeSerializedSizeInWords(*this) * sizeof(capnp::word);
|
||||
}
|
||||
|
||||
int serializeToBuffer(unsigned char *buffer, size_t buffer_size) {
|
||||
size_t serialized_size = getSerializedSize();
|
||||
if (serialized_size > buffer_size) { return -1; }
|
||||
kj::ArrayOutputStream out(kj::ArrayPtr<capnp::byte>(buffer, buffer_size));
|
||||
capnp::writeMessage(out, *this);
|
||||
return serialized_size;
|
||||
}
|
||||
|
||||
private:
|
||||
kj::Array<capnp::word> heapArray_;
|
||||
};
|
||||
|
||||
class PubMaster {
|
||||
public:
|
||||
PubMaster(const std::vector<const char *> &service_list);
|
||||
inline int send(const char *name, capnp::byte *data, size_t size) { return sockets_.at(name)->send((char *)data, size); }
|
||||
int send(const char *name, MessageBuilder &msg);
|
||||
~PubMaster();
|
||||
|
||||
private:
|
||||
std::map<std::string, PubSocket *> sockets_;
|
||||
};
|
||||
|
||||
class AlignedBuffer {
|
||||
public:
|
||||
kj::ArrayPtr<const capnp::word> align(const char *data, const size_t size) {
|
||||
words_size = size / sizeof(capnp::word) + 1;
|
||||
if (aligned_buf.size() < words_size) {
|
||||
aligned_buf = kj::heapArray<capnp::word>(words_size < 512 ? 512 : words_size);
|
||||
}
|
||||
memcpy(aligned_buf.begin(), data, size);
|
||||
return aligned_buf.slice(0, words_size);
|
||||
}
|
||||
inline kj::ArrayPtr<const capnp::word> align(Message *m) {
|
||||
return align(m->getData(), m->getSize());
|
||||
}
|
||||
private:
|
||||
kj::Array<capnp::word> aligned_buf;
|
||||
size_t words_size;
|
||||
};
|
||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
468
cereal/messaging/msgq.cc
Normal file
468
cereal/messaging/msgq.cc
Normal file
@@ -0,0 +1,468 @@
|
||||
#include <iostream>
|
||||
#include <cassert>
|
||||
#include <cerrno>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <cstdint>
|
||||
#include <chrono>
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <csignal>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <limits>
|
||||
|
||||
#include <poll.h>
|
||||
#include <sys/ioctl.h>
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <sys/syscall.h>
|
||||
#include <fcntl.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include "cereal/messaging/msgq.h"
|
||||
|
||||
void sigusr2_handler(int signal) {
|
||||
assert(signal == SIGUSR2);
|
||||
}
|
||||
|
||||
uint64_t msgq_get_uid(void){
|
||||
std::random_device rd("/dev/urandom");
|
||||
std::uniform_int_distribution<uint64_t> distribution(0, std::numeric_limits<uint32_t>::max());
|
||||
|
||||
#ifdef __APPLE__
|
||||
// TODO: this doesn't work
|
||||
uint64_t uid = distribution(rd) << 32 | getpid();
|
||||
#else
|
||||
uint64_t uid = distribution(rd) << 32 | syscall(SYS_gettid);
|
||||
#endif
|
||||
|
||||
return uid;
|
||||
}
|
||||
|
||||
int msgq_msg_init_size(msgq_msg_t * msg, size_t size){
|
||||
msg->size = size;
|
||||
msg->data = new(std::nothrow) char[size];
|
||||
|
||||
return (msg->data == NULL) ? -1 : 0;
|
||||
}
|
||||
|
||||
|
||||
int msgq_msg_init_data(msgq_msg_t * msg, char * data, size_t size) {
|
||||
int r = msgq_msg_init_size(msg, size);
|
||||
|
||||
if (r == 0)
|
||||
memcpy(msg->data, data, size);
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
int msgq_msg_close(msgq_msg_t * msg){
|
||||
if (msg->size > 0)
|
||||
delete[] msg->data;
|
||||
|
||||
msg->size = 0;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void msgq_reset_reader(msgq_queue_t * q){
|
||||
int id = q->reader_id;
|
||||
q->read_valids[id]->store(true);
|
||||
q->read_pointers[id]->store(*q->write_pointer);
|
||||
}
|
||||
|
||||
void msgq_wait_for_subscriber(msgq_queue_t *q){
|
||||
while (*q->num_readers == 0){
|
||||
// wait for subscriber
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size){
|
||||
assert(size < 0xFFFFFFFF); // Buffer must be smaller than 2^32 bytes
|
||||
std::signal(SIGUSR2, sigusr2_handler);
|
||||
|
||||
std::string full_path = "/dev/shm/";
|
||||
const char* prefix = std::getenv("OPENPILOT_PREFIX");
|
||||
if (prefix) {
|
||||
full_path += std::string(prefix) + "/";
|
||||
}
|
||||
full_path += path;
|
||||
|
||||
auto fd = open(full_path.c_str(), O_RDWR | O_CREAT, 0664);
|
||||
if (fd < 0) {
|
||||
std::cout << "Warning, could not open: " << full_path << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
int rc = ftruncate(fd, size + sizeof(msgq_header_t));
|
||||
if (rc < 0){
|
||||
close(fd);
|
||||
return -1;
|
||||
}
|
||||
char * mem = (char*)mmap(NULL, size + sizeof(msgq_header_t), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
|
||||
close(fd);
|
||||
|
||||
if (mem == NULL){
|
||||
return -1;
|
||||
}
|
||||
q->mmap_p = mem;
|
||||
|
||||
msgq_header_t *header = (msgq_header_t *)mem;
|
||||
|
||||
// Setup pointers to header segment
|
||||
q->num_readers = reinterpret_cast<std::atomic<uint64_t>*>(&header->num_readers);
|
||||
q->write_pointer = reinterpret_cast<std::atomic<uint64_t>*>(&header->write_pointer);
|
||||
q->write_uid = reinterpret_cast<std::atomic<uint64_t>*>(&header->write_uid);
|
||||
|
||||
for (size_t i = 0; i < NUM_READERS; i++){
|
||||
q->read_pointers[i] = reinterpret_cast<std::atomic<uint64_t>*>(&header->read_pointers[i]);
|
||||
q->read_valids[i] = reinterpret_cast<std::atomic<uint64_t>*>(&header->read_valids[i]);
|
||||
q->read_uids[i] = reinterpret_cast<std::atomic<uint64_t>*>(&header->read_uids[i]);
|
||||
}
|
||||
|
||||
q->data = mem + sizeof(msgq_header_t);
|
||||
q->size = size;
|
||||
q->reader_id = -1;
|
||||
|
||||
q->endpoint = path;
|
||||
q->read_conflate = false;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void msgq_close_queue(msgq_queue_t *q){
|
||||
if (q->mmap_p != NULL){
|
||||
munmap(q->mmap_p, q->size + sizeof(msgq_header_t));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void msgq_init_publisher(msgq_queue_t * q) {
|
||||
//std::cout << "Starting publisher" << std::endl;
|
||||
uint64_t uid = msgq_get_uid();
|
||||
|
||||
*q->write_uid = uid;
|
||||
*q->num_readers = 0;
|
||||
|
||||
for (size_t i = 0; i < NUM_READERS; i++){
|
||||
*q->read_valids[i] = false;
|
||||
*q->read_uids[i] = 0;
|
||||
}
|
||||
|
||||
q->write_uid_local = uid;
|
||||
}
|
||||
|
||||
static void thread_signal(uint32_t tid) {
|
||||
#ifndef SYS_tkill
|
||||
// TODO: this won't work for multithreaded programs
|
||||
kill(tid, SIGUSR2);
|
||||
#else
|
||||
syscall(SYS_tkill, tid, SIGUSR2);
|
||||
#endif
|
||||
}
|
||||
|
||||
void msgq_init_subscriber(msgq_queue_t * q) {
|
||||
assert(q != NULL);
|
||||
assert(q->num_readers != NULL);
|
||||
|
||||
uint64_t uid = msgq_get_uid();
|
||||
|
||||
// Get reader id
|
||||
while (true){
|
||||
uint64_t cur_num_readers = *q->num_readers;
|
||||
uint64_t new_num_readers = cur_num_readers + 1;
|
||||
|
||||
// No more slots available. Reset all subscribers to kick out inactive ones
|
||||
if (new_num_readers > NUM_READERS){
|
||||
//std::cout << "Warning, evicting all subscribers!" << std::endl;
|
||||
*q->num_readers = 0;
|
||||
|
||||
for (size_t i = 0; i < NUM_READERS; i++){
|
||||
*q->read_valids[i] = false;
|
||||
|
||||
uint64_t old_uid = *q->read_uids[i];
|
||||
*q->read_uids[i] = 0;
|
||||
|
||||
// Wake up reader in case they are in a poll
|
||||
thread_signal(old_uid & 0xFFFFFFFF);
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// Use atomic compare and swap to handle race condition
|
||||
// where two subscribers start at the same time
|
||||
if (std::atomic_compare_exchange_strong(q->num_readers,
|
||||
&cur_num_readers,
|
||||
new_num_readers)){
|
||||
q->reader_id = cur_num_readers;
|
||||
q->read_uid_local = uid;
|
||||
|
||||
// We start with read_valid = false,
|
||||
// on the first read the read pointer will be synchronized with the write pointer
|
||||
*q->read_valids[cur_num_readers] = false;
|
||||
*q->read_pointers[cur_num_readers] = 0;
|
||||
*q->read_uids[cur_num_readers] = uid;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
//std::cout << "New subscriber id: " << q->reader_id << " uid: " << q->read_uid_local << " " << q->endpoint << std::endl;
|
||||
msgq_reset_reader(q);
|
||||
}
|
||||
|
||||
int msgq_msg_send(msgq_msg_t * msg, msgq_queue_t *q){
|
||||
// Die if we are no longer the active publisher
|
||||
if (q->write_uid_local != *q->write_uid){
|
||||
std::cout << "Killing old publisher: " << q->endpoint << std::endl;
|
||||
errno = EADDRINUSE;
|
||||
return -1;
|
||||
}
|
||||
|
||||
uint64_t total_msg_size = ALIGN(msg->size + sizeof(int64_t));
|
||||
|
||||
// We need to fit at least three messages in the queue,
|
||||
// then we can always safely access the last message
|
||||
assert(3 * total_msg_size <= q->size);
|
||||
|
||||
uint64_t num_readers = *q->num_readers;
|
||||
|
||||
uint32_t write_cycles, write_pointer;
|
||||
UNPACK64(write_cycles, write_pointer, *q->write_pointer);
|
||||
|
||||
char *p = q->data + write_pointer; // add base offset
|
||||
|
||||
// Check remaining space
|
||||
// Always leave space for a wraparound tag for the next message, including alignment
|
||||
int64_t remaining_space = q->size - write_pointer - total_msg_size - sizeof(int64_t);
|
||||
if (remaining_space <= 0){
|
||||
// Write -1 size tag indicating wraparound
|
||||
*(int64_t*)p = -1;
|
||||
|
||||
// Invalidate all readers that are beyond the write pointer
|
||||
// TODO: should we handle the case where a new reader shows up while this is running?
|
||||
for (uint64_t i = 0; i < num_readers; i++){
|
||||
uint64_t read_pointer = *q->read_pointers[i];
|
||||
uint64_t read_cycles = read_pointer >> 32;
|
||||
read_pointer &= 0xFFFFFFFF;
|
||||
|
||||
if ((read_pointer > write_pointer) && (read_cycles != write_cycles)) {
|
||||
*q->read_valids[i] = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Update global and local copies of write pointer and write_cycles
|
||||
write_pointer = 0;
|
||||
write_cycles = write_cycles + 1;
|
||||
PACK64(*q->write_pointer, write_cycles, write_pointer);
|
||||
|
||||
// Set actual pointer to the beginning of the data segment
|
||||
p = q->data;
|
||||
}
|
||||
|
||||
// Invalidate readers that are in the area that will be written
|
||||
uint64_t start = write_pointer;
|
||||
uint64_t end = ALIGN(start + sizeof(int64_t) + msg->size);
|
||||
|
||||
for (uint64_t i = 0; i < num_readers; i++){
|
||||
uint32_t read_cycles, read_pointer;
|
||||
UNPACK64(read_cycles, read_pointer, *q->read_pointers[i]);
|
||||
|
||||
if ((read_pointer >= start) && (read_pointer < end) && (read_cycles != write_cycles)) {
|
||||
*q->read_valids[i] = false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Write size tag
|
||||
std::atomic<int64_t> *size_p = reinterpret_cast<std::atomic<int64_t>*>(p);
|
||||
*size_p = msg->size;
|
||||
|
||||
// Copy data
|
||||
memcpy(p + sizeof(int64_t), msg->data, msg->size);
|
||||
__sync_synchronize();
|
||||
|
||||
// Update write pointer
|
||||
uint32_t new_ptr = ALIGN(write_pointer + msg->size + sizeof(int64_t));
|
||||
PACK64(*q->write_pointer, write_cycles, new_ptr);
|
||||
|
||||
// Notify readers
|
||||
for (uint64_t i = 0; i < num_readers; i++){
|
||||
uint64_t reader_uid = *q->read_uids[i];
|
||||
thread_signal(reader_uid & 0xFFFFFFFF);
|
||||
}
|
||||
|
||||
return msg->size;
|
||||
}
|
||||
|
||||
|
||||
int msgq_msg_ready(msgq_queue_t * q){
|
||||
start:
|
||||
int id = q->reader_id;
|
||||
assert(id >= 0); // Make sure subscriber is initialized
|
||||
|
||||
if (q->read_uid_local != *q->read_uids[id]){
|
||||
//std::cout << q->endpoint << ": Reader was evicted, reconnecting" << std::endl;
|
||||
msgq_init_subscriber(q);
|
||||
goto start;
|
||||
}
|
||||
|
||||
// Check valid
|
||||
if (!*q->read_valids[id]){
|
||||
msgq_reset_reader(q);
|
||||
goto start;
|
||||
}
|
||||
|
||||
uint32_t read_cycles, read_pointer;
|
||||
UNPACK64(read_cycles, read_pointer, *q->read_pointers[id]);
|
||||
UNUSED(read_cycles);
|
||||
|
||||
uint32_t write_cycles, write_pointer;
|
||||
UNPACK64(write_cycles, write_pointer, *q->write_pointer);
|
||||
UNUSED(write_cycles);
|
||||
|
||||
// Check if new message is available
|
||||
return (read_pointer != write_pointer);
|
||||
}
|
||||
|
||||
int msgq_msg_recv(msgq_msg_t * msg, msgq_queue_t * q){
|
||||
start:
|
||||
int id = q->reader_id;
|
||||
assert(id >= 0); // Make sure subscriber is initialized
|
||||
|
||||
if (q->read_uid_local != *q->read_uids[id]){
|
||||
//std::cout << q->endpoint << ": Reader was evicted, reconnecting" << std::endl;
|
||||
msgq_init_subscriber(q);
|
||||
goto start;
|
||||
}
|
||||
|
||||
// Check valid
|
||||
if (!*q->read_valids[id]){
|
||||
msgq_reset_reader(q);
|
||||
goto start;
|
||||
}
|
||||
|
||||
uint32_t read_cycles, read_pointer;
|
||||
UNPACK64(read_cycles, read_pointer, *q->read_pointers[id]);
|
||||
|
||||
uint32_t write_cycles, write_pointer;
|
||||
UNPACK64(write_cycles, write_pointer, *q->write_pointer);
|
||||
UNUSED(write_cycles);
|
||||
|
||||
char * p = q->data + read_pointer;
|
||||
|
||||
// Check if new message is available
|
||||
if (read_pointer == write_pointer) {
|
||||
msg->size = 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Read potential message size
|
||||
std::atomic<int64_t> *size_p = reinterpret_cast<std::atomic<int64_t>*>(p);
|
||||
std::int64_t size = *size_p;
|
||||
|
||||
// Check if the size that was read is valid
|
||||
if (!*q->read_valids[id]){
|
||||
msgq_reset_reader(q);
|
||||
goto start;
|
||||
}
|
||||
|
||||
// If size is -1 the buffer was full, and we need to wrap around
|
||||
if (size == -1){
|
||||
read_cycles++;
|
||||
PACK64(*q->read_pointers[id], read_cycles, 0);
|
||||
goto start;
|
||||
}
|
||||
|
||||
// crashing is better than passing garbage data to the consumer
|
||||
// the size will have weird value if it was overwritten by data accidentally
|
||||
assert((uint64_t)size < q->size);
|
||||
assert(size > 0);
|
||||
|
||||
uint32_t new_read_pointer = ALIGN(read_pointer + sizeof(std::int64_t) + size);
|
||||
|
||||
// If conflate is true, check if this is the latest message, else start over
|
||||
if (q->read_conflate){
|
||||
if (new_read_pointer != write_pointer){
|
||||
// Update read pointer
|
||||
PACK64(*q->read_pointers[id], read_cycles, new_read_pointer);
|
||||
goto start;
|
||||
}
|
||||
}
|
||||
|
||||
// Copy message
|
||||
if (msgq_msg_init_size(msg, size) < 0)
|
||||
return -1;
|
||||
|
||||
__sync_synchronize();
|
||||
memcpy(msg->data, p + sizeof(int64_t), size);
|
||||
__sync_synchronize();
|
||||
|
||||
// Update read pointer
|
||||
PACK64(*q->read_pointers[id], read_cycles, new_read_pointer);
|
||||
|
||||
// Check if the actual data that was copied is valid
|
||||
if (!*q->read_valids[id]){
|
||||
msgq_msg_close(msg);
|
||||
msgq_reset_reader(q);
|
||||
goto start;
|
||||
}
|
||||
|
||||
|
||||
return msg->size;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int msgq_poll(msgq_pollitem_t * items, size_t nitems, int timeout){
|
||||
int num = 0;
|
||||
|
||||
// Check if messages ready
|
||||
for (size_t i = 0; i < nitems; i++) {
|
||||
items[i].revents = msgq_msg_ready(items[i].q);
|
||||
if (items[i].revents) num++;
|
||||
}
|
||||
|
||||
int ms = (timeout == -1) ? 100 : timeout;
|
||||
struct timespec ts;
|
||||
ts.tv_sec = ms / 1000;
|
||||
ts.tv_nsec = (ms % 1000) * 1000 * 1000;
|
||||
|
||||
|
||||
while (num == 0) {
|
||||
int ret;
|
||||
|
||||
ret = nanosleep(&ts, &ts);
|
||||
|
||||
// Check if messages ready
|
||||
for (size_t i = 0; i < nitems; i++) {
|
||||
if (items[i].revents == 0 && msgq_msg_ready(items[i].q)){
|
||||
num += 1;
|
||||
items[i].revents = 1;
|
||||
}
|
||||
}
|
||||
|
||||
// exit if we had a timeout and the sleep finished
|
||||
if (timeout != -1 && ret == 0){
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return num;
|
||||
}
|
||||
|
||||
bool msgq_all_readers_updated(msgq_queue_t *q) {
|
||||
uint64_t num_readers = *q->num_readers;
|
||||
for (uint64_t i = 0; i < num_readers; i++) {
|
||||
if (*q->read_valids[i] && *q->write_pointer != *q->read_pointers[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return num_readers > 0;
|
||||
}
|
||||
70
cereal/messaging/msgq.h
Normal file
70
cereal/messaging/msgq.h
Normal file
@@ -0,0 +1,70 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <atomic>
|
||||
|
||||
#define DEFAULT_SEGMENT_SIZE (10 * 1024 * 1024)
|
||||
#define NUM_READERS 12
|
||||
#define ALIGN(n) ((n + (8 - 1)) & -8)
|
||||
|
||||
#define UNUSED(x) (void)x
|
||||
#define UNPACK64(higher, lower, input) do {uint64_t tmp = input; higher = tmp >> 32; lower = tmp & 0xFFFFFFFF;} while (0)
|
||||
#define PACK64(output, higher, lower) output = ((uint64_t)higher << 32) | ((uint64_t)lower & 0xFFFFFFFF)
|
||||
|
||||
struct msgq_header_t {
|
||||
uint64_t num_readers;
|
||||
uint64_t write_pointer;
|
||||
uint64_t write_uid;
|
||||
uint64_t read_pointers[NUM_READERS];
|
||||
uint64_t read_valids[NUM_READERS];
|
||||
uint64_t read_uids[NUM_READERS];
|
||||
};
|
||||
|
||||
struct msgq_queue_t {
|
||||
std::atomic<uint64_t> *num_readers;
|
||||
std::atomic<uint64_t> *write_pointer;
|
||||
std::atomic<uint64_t> *write_uid;
|
||||
std::atomic<uint64_t> *read_pointers[NUM_READERS];
|
||||
std::atomic<uint64_t> *read_valids[NUM_READERS];
|
||||
std::atomic<uint64_t> *read_uids[NUM_READERS];
|
||||
char * mmap_p;
|
||||
char * data;
|
||||
size_t size;
|
||||
int reader_id;
|
||||
uint64_t read_uid_local;
|
||||
uint64_t write_uid_local;
|
||||
|
||||
bool read_conflate;
|
||||
std::string endpoint;
|
||||
};
|
||||
|
||||
struct msgq_msg_t {
|
||||
size_t size;
|
||||
char * data;
|
||||
};
|
||||
|
||||
struct msgq_pollitem_t {
|
||||
msgq_queue_t *q;
|
||||
int revents;
|
||||
};
|
||||
|
||||
void msgq_wait_for_subscriber(msgq_queue_t *q);
|
||||
void msgq_reset_reader(msgq_queue_t *q);
|
||||
|
||||
int msgq_msg_init_size(msgq_msg_t *msg, size_t size);
|
||||
int msgq_msg_init_data(msgq_msg_t *msg, char * data, size_t size);
|
||||
int msgq_msg_close(msgq_msg_t *msg);
|
||||
|
||||
int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size);
|
||||
void msgq_close_queue(msgq_queue_t *q);
|
||||
void msgq_init_publisher(msgq_queue_t * q);
|
||||
void msgq_init_subscriber(msgq_queue_t * q);
|
||||
|
||||
int msgq_msg_send(msgq_msg_t *msg, msgq_queue_t *q);
|
||||
int msgq_msg_recv(msgq_msg_t *msg, msgq_queue_t *q);
|
||||
int msgq_msg_ready(msgq_queue_t * q);
|
||||
int msgq_poll(msgq_pollitem_t * items, size_t nitems, int timeout);
|
||||
|
||||
bool msgq_all_readers_updated(msgq_queue_t *q);
|
||||
54
cereal/messaging/msgq.md
Normal file
54
cereal/messaging/msgq.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# MSGQ: A lock free single producer multi consumer message queue
|
||||
|
||||
## What is MSGQ?
|
||||
MSGQ is a system to pass messages from a single producer to multiple consumers. All the consumers need to be able to receive all the messages. It is designed to be a high performance replacement for ZMQ-like SUB/PUB patterns. It uses a ring buffer in shared memory to efficiently read and write data. Each read requires a copy. Writing can be done without a copy, as long as the size of the data is known in advance.
|
||||
|
||||
## Storage
|
||||
The storage for the queue consists of an area of metadata, and the actual buffer. The metadata contains:
|
||||
|
||||
1. A counter to the number of readers that are active
|
||||
2. A pointer to the head of the queue for writing. From now on referred to as *write pointer*
|
||||
3. A cycle counter for the writer. This counter is incremented when the writer wraps around
|
||||
4. N pointers, pointing to the current read position for all the readers. From now on referred to as *read pointer*
|
||||
5. N counters, counting the number of cycles for all the readers
|
||||
6. N booleans, indicating validity for all the readers. From now on referred to as *validity flag*
|
||||
|
||||
The counter and the pointer are both 32 bit values, packed into 64 bit so they can be read and written atomically.
|
||||
|
||||
The data buffer is a ring buffer. All messages are prefixed by an 8 byte size field, followed by the data. A size of -1 indicates a wrap-around, and means the next message is stored at the beginning of the buffer.
|
||||
|
||||
|
||||
## Writing
|
||||
Writing involves the following steps:
|
||||
|
||||
1. Check if the area that is to be written overlaps with any of the read pointers, mark those readers as invalid by clearing the validity flag.
|
||||
2. Write the message
|
||||
3. Increase the write pointer by the size of the message
|
||||
|
||||
In case there is not enough space at the end of the buffer, a special empty message with a prefix of -1 is written. The cycle counter is incremented by one. In this case step 1 will check there are no read pointers pointing to the remainder of the buffer. Then another write cycle will start with the actual message.
|
||||
|
||||
There always needs to be 8 bytes of empty space at the end of the buffer. By doing this there is always space to write the -1.
|
||||
|
||||
## Reset reader
|
||||
When the reader is lagging too much behind the read pointer becomes invalid and no longer points to the beginning of a valid message. To reset a reader to the current write pointer, the following steps are performed:
|
||||
|
||||
1. Set valid flag
|
||||
2. Set read cycle counter to that of the writer
|
||||
3. Set read pointer to write pointer
|
||||
|
||||
## Reading
|
||||
Reading involves the following steps:
|
||||
|
||||
1. Read the size field at the current read pointer
|
||||
2. Read the validity flag
|
||||
3. Copy the data out of the buffer
|
||||
4. Increase the read pointer by the size of the message
|
||||
5. Check the validity flag again
|
||||
|
||||
Before starting the copy, the valid flag is checked. This is to prevent a race condition where the size prefix was invalid, and the read could read outside of the buffer. Make sure that step 1 and 2 are not reordered by your compiler or CPU.
|
||||
|
||||
If a writer overwrites the data while it's being copied out, the data will be invalid. Therefore the validity flag is also checked after reading it. The order of step 4 and 5 does not matter.
|
||||
|
||||
If at steps 2 or 5 the validity flag is not set, the reader is reset. Any data that was already read is discarded. After the reader is reset, the reading starts from the beginning.
|
||||
|
||||
If a message with size -1 is encountered, step 3 and 4 are replaced by increasing the cycle counter and setting the read pointer to the beginning of the buffer. After that another read is performed.
|
||||
394
cereal/messaging/msgq_tests.cc
Normal file
394
cereal/messaging/msgq_tests.cc
Normal file
@@ -0,0 +1,394 @@
|
||||
#include "catch2/catch.hpp"
|
||||
#include "cereal/messaging/msgq.h"
|
||||
|
||||
TEST_CASE("ALIGN"){
|
||||
REQUIRE(ALIGN(0) == 0);
|
||||
REQUIRE(ALIGN(1) == 8);
|
||||
REQUIRE(ALIGN(7) == 8);
|
||||
REQUIRE(ALIGN(8) == 8);
|
||||
REQUIRE(ALIGN(99999) == 100000);
|
||||
}
|
||||
|
||||
TEST_CASE("msgq_msg_init_size"){
|
||||
const size_t msg_size = 30;
|
||||
msgq_msg_t msg;
|
||||
|
||||
msgq_msg_init_size(&msg, msg_size);
|
||||
REQUIRE(msg.size == msg_size);
|
||||
|
||||
msgq_msg_close(&msg);
|
||||
}
|
||||
|
||||
TEST_CASE("msgq_msg_init_data"){
|
||||
const size_t msg_size = 30;
|
||||
char * data = new char[msg_size];
|
||||
|
||||
for (size_t i = 0; i < msg_size; i++){
|
||||
data[i] = i;
|
||||
}
|
||||
|
||||
msgq_msg_t msg;
|
||||
msgq_msg_init_data(&msg, data, msg_size);
|
||||
|
||||
REQUIRE(msg.size == msg_size);
|
||||
REQUIRE(memcmp(msg.data, data, msg_size) == 0);
|
||||
|
||||
delete[] data;
|
||||
msgq_msg_close(&msg);
|
||||
}
|
||||
|
||||
|
||||
TEST_CASE("msgq_init_subscriber"){
|
||||
remove("/dev/shm/test_queue");
|
||||
msgq_queue_t q;
|
||||
msgq_new_queue(&q, "test_queue", 1024);
|
||||
REQUIRE(*q.num_readers == 0);
|
||||
|
||||
q.reader_id = 1;
|
||||
*q.read_valids[0] = false;
|
||||
*q.read_pointers[0] = ((uint64_t)1 << 32);
|
||||
|
||||
*q.write_pointer = 255;
|
||||
|
||||
msgq_init_subscriber(&q);
|
||||
REQUIRE(q.read_conflate == false);
|
||||
REQUIRE(*q.read_valids[0] == true);
|
||||
REQUIRE((*q.read_pointers[0] >> 32) == 0);
|
||||
REQUIRE((*q.read_pointers[0] & 0xFFFFFFFF) == 255);
|
||||
}
|
||||
|
||||
TEST_CASE("msgq_msg_send first message"){
|
||||
remove("/dev/shm/test_queue");
|
||||
msgq_queue_t q;
|
||||
msgq_new_queue(&q, "test_queue", 1024);
|
||||
msgq_init_publisher(&q);
|
||||
|
||||
REQUIRE(*q.write_pointer == 0);
|
||||
|
||||
size_t msg_size = 128;
|
||||
|
||||
SECTION("Aligned message size"){
|
||||
}
|
||||
SECTION("Unaligned message size"){
|
||||
msg_size--;
|
||||
}
|
||||
|
||||
char * data = new char[msg_size];
|
||||
|
||||
for (size_t i = 0; i < msg_size; i++){
|
||||
data[i] = i;
|
||||
}
|
||||
|
||||
msgq_msg_t msg;
|
||||
msgq_msg_init_data(&msg, data, msg_size);
|
||||
|
||||
|
||||
msgq_msg_send(&msg, &q);
|
||||
REQUIRE(*(int64_t*)q.data == msg_size); // Check size tag
|
||||
REQUIRE(*q.write_pointer == 128 + sizeof(int64_t));
|
||||
REQUIRE(memcmp(q.data + sizeof(int64_t), data, msg_size) == 0);
|
||||
|
||||
delete[] data;
|
||||
msgq_msg_close(&msg);
|
||||
}
|
||||
|
||||
TEST_CASE("msgq_msg_send test wraparound"){
|
||||
remove("/dev/shm/test_queue");
|
||||
msgq_queue_t q;
|
||||
msgq_new_queue(&q, "test_queue", 1024);
|
||||
msgq_init_publisher(&q);
|
||||
|
||||
REQUIRE((*q.write_pointer & 0xFFFFFFFF) == 0);
|
||||
REQUIRE((*q.write_pointer >> 32) == 0);
|
||||
|
||||
const size_t msg_size = 120;
|
||||
msgq_msg_t msg;
|
||||
msgq_msg_init_size(&msg, msg_size);
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
msgq_msg_send(&msg, &q);
|
||||
}
|
||||
// Check 8th message was written at the beginning
|
||||
REQUIRE((*q.write_pointer & 0xFFFFFFFF) == msg_size + sizeof(int64_t));
|
||||
|
||||
// Check cycle count
|
||||
REQUIRE((*q.write_pointer >> 32) == 1);
|
||||
|
||||
// Check wraparound tag
|
||||
char * tag_location = q.data;
|
||||
tag_location += 7 * (msg_size + sizeof(int64_t));
|
||||
REQUIRE(*(int64_t*)tag_location == -1);
|
||||
|
||||
msgq_msg_close(&msg);
|
||||
}
|
||||
|
||||
TEST_CASE("msgq_msg_recv test wraparound"){
|
||||
remove("/dev/shm/test_queue");
|
||||
msgq_queue_t q_pub, q_sub;
|
||||
msgq_new_queue(&q_pub, "test_queue", 1024);
|
||||
msgq_new_queue(&q_sub, "test_queue", 1024);
|
||||
|
||||
msgq_init_publisher(&q_pub);
|
||||
msgq_init_subscriber(&q_sub);
|
||||
|
||||
REQUIRE((*q_pub.write_pointer >> 32) == 0);
|
||||
REQUIRE((*q_sub.read_pointers[0] >> 32) == 0);
|
||||
|
||||
const size_t msg_size = 120;
|
||||
msgq_msg_t msg1;
|
||||
msgq_msg_init_size(&msg1, msg_size);
|
||||
|
||||
|
||||
SECTION("Check cycle counter after reset") {
|
||||
for (int i = 0; i < 8; i++) {
|
||||
msgq_msg_send(&msg1, &q_pub);
|
||||
}
|
||||
|
||||
msgq_msg_t msg2;
|
||||
msgq_msg_recv(&msg2, &q_sub);
|
||||
REQUIRE(msg2.size == 0); // Reader had to reset
|
||||
msgq_msg_close(&msg2);
|
||||
}
|
||||
SECTION("Check cycle counter while keeping up with writer") {
|
||||
for (int i = 0; i < 8; i++) {
|
||||
msgq_msg_send(&msg1, &q_pub);
|
||||
|
||||
msgq_msg_t msg2;
|
||||
msgq_msg_recv(&msg2, &q_sub);
|
||||
REQUIRE(msg2.size > 0);
|
||||
msgq_msg_close(&msg2);
|
||||
}
|
||||
}
|
||||
|
||||
REQUIRE((*q_sub.read_pointers[0] >> 32) == 1);
|
||||
msgq_msg_close(&msg1);
|
||||
}
|
||||
|
||||
TEST_CASE("msgq_msg_send test invalidation"){
|
||||
remove("/dev/shm/test_queue");
|
||||
msgq_queue_t q_pub, q_sub;
|
||||
msgq_new_queue(&q_pub, "test_queue", 1024);
|
||||
msgq_new_queue(&q_sub, "test_queue", 1024);
|
||||
|
||||
msgq_init_publisher(&q_pub);
|
||||
msgq_init_subscriber(&q_sub);
|
||||
*q_sub.write_pointer = (uint64_t)1 << 32;
|
||||
|
||||
REQUIRE(*q_sub.read_valids[0] == true);
|
||||
|
||||
SECTION("read pointer in tag"){
|
||||
*q_sub.read_pointers[0] = 0;
|
||||
}
|
||||
SECTION("read pointer in data section"){
|
||||
*q_sub.read_pointers[0] = 64;
|
||||
}
|
||||
SECTION("read pointer in wraparound section"){
|
||||
*q_pub.write_pointer = ((uint64_t)1 << 32) | 1000; // Writer is one cycle ahead
|
||||
*q_sub.read_pointers[0] = 1020;
|
||||
}
|
||||
|
||||
msgq_msg_t msg;
|
||||
msgq_msg_init_size(&msg, 128);
|
||||
msgq_msg_send(&msg, &q_pub);
|
||||
|
||||
REQUIRE(*q_sub.read_valids[0] == false);
|
||||
|
||||
msgq_msg_close(&msg);
|
||||
}
|
||||
|
||||
TEST_CASE("msgq_init_subscriber init 2 subscribers"){
|
||||
remove("/dev/shm/test_queue");
|
||||
msgq_queue_t q1, q2;
|
||||
msgq_new_queue(&q1, "test_queue", 1024);
|
||||
msgq_new_queue(&q2, "test_queue", 1024);
|
||||
|
||||
*q1.num_readers = 0;
|
||||
|
||||
REQUIRE(*q1.num_readers == 0);
|
||||
REQUIRE(*q2.num_readers == 0);
|
||||
|
||||
msgq_init_subscriber(&q1);
|
||||
REQUIRE(*q1.num_readers == 1);
|
||||
REQUIRE(*q2.num_readers == 1);
|
||||
REQUIRE(q1.reader_id == 0);
|
||||
|
||||
msgq_init_subscriber(&q2);
|
||||
REQUIRE(*q1.num_readers == 2);
|
||||
REQUIRE(*q2.num_readers == 2);
|
||||
REQUIRE(q2.reader_id == 1);
|
||||
}
|
||||
|
||||
|
||||
TEST_CASE("Write 1 msg, read 1 msg", "[integration]"){
|
||||
remove("/dev/shm/test_queue");
|
||||
const size_t msg_size = 128;
|
||||
msgq_queue_t writer, reader;
|
||||
|
||||
msgq_new_queue(&writer, "test_queue", 1024);
|
||||
msgq_new_queue(&reader, "test_queue", 1024);
|
||||
|
||||
msgq_init_publisher(&writer);
|
||||
msgq_init_subscriber(&reader);
|
||||
|
||||
// Build 128 byte message
|
||||
msgq_msg_t outgoing_msg;
|
||||
msgq_msg_init_size(&outgoing_msg, msg_size);
|
||||
|
||||
for (size_t i = 0; i < msg_size; i++){
|
||||
outgoing_msg.data[i] = i;
|
||||
}
|
||||
|
||||
REQUIRE(msgq_msg_send(&outgoing_msg, &writer) == msg_size);
|
||||
|
||||
msgq_msg_t incoming_msg1;
|
||||
REQUIRE(msgq_msg_recv(&incoming_msg1, &reader) == msg_size);
|
||||
REQUIRE(memcmp(incoming_msg1.data, outgoing_msg.data, msg_size) == 0);
|
||||
|
||||
// Verify that there are no more messages
|
||||
msgq_msg_t incoming_msg2;
|
||||
REQUIRE(msgq_msg_recv(&incoming_msg2, &reader) == 0);
|
||||
|
||||
msgq_msg_close(&outgoing_msg);
|
||||
msgq_msg_close(&incoming_msg1);
|
||||
msgq_msg_close(&incoming_msg2);
|
||||
}
|
||||
|
||||
TEST_CASE("Write 2 msg, read 2 msg - conflate = false", "[integration]"){
|
||||
remove("/dev/shm/test_queue");
|
||||
const size_t msg_size = 128;
|
||||
msgq_queue_t writer, reader;
|
||||
|
||||
msgq_new_queue(&writer, "test_queue", 1024);
|
||||
msgq_new_queue(&reader, "test_queue", 1024);
|
||||
|
||||
msgq_init_publisher(&writer);
|
||||
msgq_init_subscriber(&reader);
|
||||
|
||||
// Build 128 byte message
|
||||
msgq_msg_t outgoing_msg;
|
||||
msgq_msg_init_size(&outgoing_msg, msg_size);
|
||||
|
||||
for (size_t i = 0; i < msg_size; i++){
|
||||
outgoing_msg.data[i] = i;
|
||||
}
|
||||
|
||||
REQUIRE(msgq_msg_send(&outgoing_msg, &writer) == msg_size);
|
||||
REQUIRE(msgq_msg_send(&outgoing_msg, &writer) == msg_size);
|
||||
|
||||
msgq_msg_t incoming_msg1;
|
||||
REQUIRE(msgq_msg_recv(&incoming_msg1, &reader) == msg_size);
|
||||
REQUIRE(memcmp(incoming_msg1.data, outgoing_msg.data, msg_size) == 0);
|
||||
|
||||
msgq_msg_t incoming_msg2;
|
||||
REQUIRE(msgq_msg_recv(&incoming_msg2, &reader) == msg_size);
|
||||
REQUIRE(memcmp(incoming_msg2.data, outgoing_msg.data, msg_size) == 0);
|
||||
|
||||
msgq_msg_close(&outgoing_msg);
|
||||
msgq_msg_close(&incoming_msg1);
|
||||
msgq_msg_close(&incoming_msg2);
|
||||
}
|
||||
|
||||
TEST_CASE("Write 2 msg, read 2 msg - conflate = true", "[integration]"){
|
||||
remove("/dev/shm/test_queue");
|
||||
const size_t msg_size = 128;
|
||||
msgq_queue_t writer, reader;
|
||||
|
||||
msgq_new_queue(&writer, "test_queue", 1024);
|
||||
msgq_new_queue(&reader, "test_queue", 1024);
|
||||
|
||||
msgq_init_publisher(&writer);
|
||||
msgq_init_subscriber(&reader);
|
||||
reader.read_conflate = true;
|
||||
|
||||
// Build 128 byte message
|
||||
msgq_msg_t outgoing_msg;
|
||||
msgq_msg_init_size(&outgoing_msg, msg_size);
|
||||
|
||||
for (size_t i = 0; i < msg_size; i++){
|
||||
outgoing_msg.data[i] = i;
|
||||
}
|
||||
|
||||
REQUIRE(msgq_msg_send(&outgoing_msg, &writer) == msg_size);
|
||||
REQUIRE(msgq_msg_send(&outgoing_msg, &writer) == msg_size);
|
||||
|
||||
msgq_msg_t incoming_msg1;
|
||||
REQUIRE(msgq_msg_recv(&incoming_msg1, &reader) == msg_size);
|
||||
REQUIRE(memcmp(incoming_msg1.data, outgoing_msg.data, msg_size) == 0);
|
||||
|
||||
// Verify that there are no more messages
|
||||
msgq_msg_t incoming_msg2;
|
||||
REQUIRE(msgq_msg_recv(&incoming_msg2, &reader) == 0);
|
||||
|
||||
msgq_msg_close(&outgoing_msg);
|
||||
msgq_msg_close(&incoming_msg1);
|
||||
msgq_msg_close(&incoming_msg2);
|
||||
}
|
||||
|
||||
TEST_CASE("1 publisher, 1 slow subscriber", "[integration]"){
|
||||
remove("/dev/shm/test_queue");
|
||||
msgq_queue_t writer, reader;
|
||||
|
||||
msgq_new_queue(&writer, "test_queue", 1024);
|
||||
msgq_new_queue(&reader, "test_queue", 1024);
|
||||
|
||||
msgq_init_publisher(&writer);
|
||||
msgq_init_subscriber(&reader);
|
||||
|
||||
int n_received = 0;
|
||||
int n_skipped = 0;
|
||||
|
||||
for (uint64_t i = 0; i < 1e5; i++) {
|
||||
msgq_msg_t outgoing_msg;
|
||||
msgq_msg_init_data(&outgoing_msg, (char*)&i, sizeof(uint64_t));
|
||||
msgq_msg_send(&outgoing_msg, &writer);
|
||||
msgq_msg_close(&outgoing_msg);
|
||||
|
||||
if (i % 10 == 0){
|
||||
msgq_msg_t msg1;
|
||||
msgq_msg_recv(&msg1, &reader);
|
||||
|
||||
if (msg1.size == 0){
|
||||
n_skipped++;
|
||||
} else {
|
||||
n_received++;
|
||||
}
|
||||
msgq_msg_close(&msg1);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: verify these numbers by hand
|
||||
REQUIRE(n_received == 8572);
|
||||
REQUIRE(n_skipped == 1428);
|
||||
}
|
||||
|
||||
TEST_CASE("1 publisher, 2 subscribers", "[integration]"){
|
||||
remove("/dev/shm/test_queue");
|
||||
msgq_queue_t writer, reader1, reader2;
|
||||
|
||||
msgq_new_queue(&writer, "test_queue", 1024);
|
||||
msgq_new_queue(&reader1, "test_queue", 1024);
|
||||
msgq_new_queue(&reader2, "test_queue", 1024);
|
||||
|
||||
msgq_init_publisher(&writer);
|
||||
msgq_init_subscriber(&reader1);
|
||||
msgq_init_subscriber(&reader2);
|
||||
|
||||
for (uint64_t i = 0; i < 1024 * 3; i++) {
|
||||
msgq_msg_t outgoing_msg;
|
||||
msgq_msg_init_data(&outgoing_msg, (char*)&i, sizeof(uint64_t));
|
||||
msgq_msg_send(&outgoing_msg, &writer);
|
||||
|
||||
msgq_msg_t msg1, msg2;
|
||||
msgq_msg_recv(&msg1, &reader1);
|
||||
msgq_msg_recv(&msg2, &reader2);
|
||||
|
||||
REQUIRE(msg1.size == sizeof(uint64_t));
|
||||
REQUIRE(msg2.size == sizeof(uint64_t));
|
||||
REQUIRE(*(uint64_t*)msg1.data == i);
|
||||
REQUIRE(*(uint64_t*)msg2.data == i);
|
||||
|
||||
msgq_msg_close(&outgoing_msg);
|
||||
msgq_msg_close(&msg1);
|
||||
msgq_msg_close(&msg2);
|
||||
}
|
||||
}
|
||||
210
cereal/messaging/socketmaster.cc
Normal file
210
cereal/messaging/socketmaster.cc
Normal file
@@ -0,0 +1,210 @@
|
||||
#include <time.h>
|
||||
#include <assert.h>
|
||||
#include <stdlib.h>
|
||||
#include <string>
|
||||
#include <mutex>
|
||||
|
||||
#include "cereal/services.h"
|
||||
#include "cereal/messaging/messaging.h"
|
||||
|
||||
const bool SIMULATION = (getenv("SIMULATION") != nullptr) && (std::string(getenv("SIMULATION")) == "1");
|
||||
|
||||
static inline uint64_t nanos_since_boot() {
|
||||
struct timespec t;
|
||||
clock_gettime(CLOCK_BOOTTIME, &t);
|
||||
return t.tv_sec * 1000000000ULL + t.tv_nsec;
|
||||
}
|
||||
|
||||
static inline bool inList(const std::vector<const char *> &list, const char *value) {
|
||||
for (auto &v : list) {
|
||||
if (strcmp(value, v) == 0) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
class MessageContext {
|
||||
public:
|
||||
MessageContext() : ctx_(nullptr) {}
|
||||
~MessageContext() { delete ctx_; }
|
||||
inline Context *context() {
|
||||
std::call_once(init_flag, [=]() { ctx_ = Context::create(); });
|
||||
return ctx_;
|
||||
}
|
||||
private:
|
||||
Context *ctx_;
|
||||
std::once_flag init_flag;
|
||||
};
|
||||
|
||||
MessageContext message_context;
|
||||
|
||||
struct SubMaster::SubMessage {
|
||||
std::string name;
|
||||
SubSocket *socket = nullptr;
|
||||
int freq = 0;
|
||||
bool updated = false, alive = false, valid = true, ignore_alive;
|
||||
uint64_t rcv_time = 0, rcv_frame = 0;
|
||||
void *allocated_msg_reader = nullptr;
|
||||
bool is_polled = false;
|
||||
capnp::FlatArrayMessageReader *msg_reader = nullptr;
|
||||
AlignedBuffer aligned_buf;
|
||||
cereal::Event::Reader event;
|
||||
};
|
||||
|
||||
SubMaster::SubMaster(const std::vector<const char *> &service_list, const std::vector<const char *> &poll,
|
||||
const char *address, const std::vector<const char *> &ignore_alive) {
|
||||
poller_ = Poller::create();
|
||||
for (auto name : service_list) {
|
||||
assert(services.count(std::string(name)) > 0);
|
||||
|
||||
service serv = services.at(std::string(name));
|
||||
SubSocket *socket = SubSocket::create(message_context.context(), name, address ? address : "127.0.0.1", true);
|
||||
assert(socket != 0);
|
||||
bool is_polled = inList(poll, name) || poll.empty();
|
||||
if (is_polled) poller_->registerSocket(socket);
|
||||
SubMessage *m = new SubMessage{
|
||||
.name = name,
|
||||
.socket = socket,
|
||||
.freq = serv.frequency,
|
||||
.ignore_alive = inList(ignore_alive, name),
|
||||
.allocated_msg_reader = malloc(sizeof(capnp::FlatArrayMessageReader)),
|
||||
.is_polled = is_polled};
|
||||
m->msg_reader = new (m->allocated_msg_reader) capnp::FlatArrayMessageReader({});
|
||||
messages_[socket] = m;
|
||||
services_[name] = m;
|
||||
}
|
||||
}
|
||||
|
||||
void SubMaster::update(int timeout) {
|
||||
for (auto &kv : messages_) kv.second->updated = false;
|
||||
|
||||
auto sockets = poller_->poll(timeout);
|
||||
|
||||
// add non-polled sockets for non-blocking receive
|
||||
for (auto &kv : messages_) {
|
||||
SubMessage *m = kv.second;
|
||||
SubSocket *s = kv.first;
|
||||
if (!m->is_polled) sockets.push_back(s);
|
||||
}
|
||||
|
||||
uint64_t current_time = nanos_since_boot();
|
||||
|
||||
std::vector<std::pair<std::string, cereal::Event::Reader>> messages;
|
||||
|
||||
for (auto s : sockets) {
|
||||
Message *msg = s->receive(true);
|
||||
if (msg == nullptr) continue;
|
||||
|
||||
SubMessage *m = messages_.at(s);
|
||||
|
||||
m->msg_reader->~FlatArrayMessageReader();
|
||||
capnp::ReaderOptions options;
|
||||
options.traversalLimitInWords = kj::maxValue; // Don't limit
|
||||
m->msg_reader = new (m->allocated_msg_reader) capnp::FlatArrayMessageReader(m->aligned_buf.align(msg), options);
|
||||
delete msg;
|
||||
messages.push_back({m->name, m->msg_reader->getRoot<cereal::Event>()});
|
||||
}
|
||||
|
||||
update_msgs(current_time, messages);
|
||||
}
|
||||
|
||||
void SubMaster::update_msgs(uint64_t current_time, const std::vector<std::pair<std::string, cereal::Event::Reader>> &messages){
|
||||
if (++frame == UINT64_MAX) frame = 1;
|
||||
|
||||
for (auto &kv : messages) {
|
||||
auto m_find = services_.find(kv.first);
|
||||
if (m_find == services_.end()){
|
||||
continue;
|
||||
}
|
||||
SubMessage *m = m_find->second;
|
||||
m->event = kv.second;
|
||||
m->updated = true;
|
||||
m->rcv_time = current_time;
|
||||
m->rcv_frame = frame;
|
||||
m->valid = m->event.getValid();
|
||||
if (SIMULATION) m->alive = true;
|
||||
}
|
||||
|
||||
if (!SIMULATION) {
|
||||
for (auto &kv : messages_) {
|
||||
SubMessage *m = kv.second;
|
||||
m->alive = (m->freq <= (1e-5) || ((current_time - m->rcv_time) * (1e-9)) < (10.0 / m->freq));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool SubMaster::all_(const std::vector<const char *> &service_list, bool valid, bool alive) {
|
||||
int found = 0;
|
||||
for (auto &kv : messages_) {
|
||||
SubMessage *m = kv.second;
|
||||
if (service_list.size() == 0 || inList(service_list, m->name.c_str())) {
|
||||
found += (!valid || m->valid) && (!alive || (m->alive || m->ignore_alive));
|
||||
}
|
||||
}
|
||||
return service_list.size() == 0 ? found == messages_.size() : found == service_list.size();
|
||||
}
|
||||
|
||||
void SubMaster::drain() {
|
||||
while (true) {
|
||||
auto polls = poller_->poll(0);
|
||||
if (polls.size() == 0)
|
||||
break;
|
||||
|
||||
for (auto sock : polls) {
|
||||
Message *msg = sock->receive(true);
|
||||
delete msg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool SubMaster::updated(const char *name) const {
|
||||
return services_.at(name)->updated;
|
||||
}
|
||||
|
||||
bool SubMaster::alive(const char *name) const {
|
||||
return services_.at(name)->alive;
|
||||
}
|
||||
|
||||
bool SubMaster::valid(const char *name) const {
|
||||
return services_.at(name)->valid;
|
||||
}
|
||||
|
||||
uint64_t SubMaster::rcv_frame(const char *name) const {
|
||||
return services_.at(name)->rcv_frame;
|
||||
}
|
||||
|
||||
uint64_t SubMaster::rcv_time(const char *name) const {
|
||||
return services_.at(name)->rcv_time;
|
||||
}
|
||||
|
||||
cereal::Event::Reader &SubMaster::operator[](const char *name) const {
|
||||
return services_.at(name)->event;
|
||||
}
|
||||
|
||||
SubMaster::~SubMaster() {
|
||||
delete poller_;
|
||||
for (auto &kv : messages_) {
|
||||
SubMessage *m = kv.second;
|
||||
m->msg_reader->~FlatArrayMessageReader();
|
||||
free(m->allocated_msg_reader);
|
||||
delete m->socket;
|
||||
delete m;
|
||||
}
|
||||
}
|
||||
|
||||
PubMaster::PubMaster(const std::vector<const char *> &service_list) {
|
||||
for (auto name : service_list) {
|
||||
assert(services.count(name) > 0);
|
||||
PubSocket *socket = PubSocket::create(message_context.context(), name);
|
||||
assert(socket);
|
||||
sockets_[name] = socket;
|
||||
}
|
||||
}
|
||||
|
||||
int PubMaster::send(const char *name, MessageBuilder &msg) {
|
||||
auto bytes = msg.toBytes();
|
||||
return send(name, bytes.begin(), bytes.size());
|
||||
}
|
||||
|
||||
PubMaster::~PubMaster() {
|
||||
for (auto s : sockets_) delete s.second;
|
||||
}
|
||||
14
cereal/messaging/stress.py
Normal file
14
cereal/messaging/stress.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from messaging_pyx import Context, SubSocket, PubSocket
|
||||
|
||||
if __name__ == "__main__":
|
||||
c = Context()
|
||||
pub_sock = PubSocket()
|
||||
pub_sock.connect(c, "controlsState")
|
||||
|
||||
for i in range(int(1e10)):
|
||||
print(i)
|
||||
sub_sock = SubSocket()
|
||||
sub_sock.connect(c, "controlsState")
|
||||
|
||||
pub_sock.send(b'a')
|
||||
print(sub_sock.receive())
|
||||
2
cereal/messaging/test_runner.cc
Normal file
2
cereal/messaging/test_runner.cc
Normal file
@@ -0,0 +1,2 @@
|
||||
#define CATCH_CONFIG_MAIN
|
||||
#include "catch2/catch.hpp"
|
||||
0
cereal/messaging/tests/__init__.py
Normal file
0
cereal/messaging/tests/__init__.py
Normal file
193
cereal/messaging/tests/test_fake.py
Normal file
193
cereal/messaging/tests/test_fake.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import os
|
||||
import unittest
|
||||
import multiprocessing
|
||||
import platform
|
||||
from parameterized import parameterized_class
|
||||
from typing import Optional
|
||||
|
||||
import cereal.messaging as messaging
|
||||
|
||||
WAIT_TIMEOUT = 5
|
||||
|
||||
|
||||
@unittest.skipIf(platform.system() == "Darwin", "Events not supported on macOS")
|
||||
class TestEvents(unittest.TestCase):
|
||||
|
||||
def test_mutation(self):
|
||||
handle = messaging.fake_event_handle("carState")
|
||||
event = handle.recv_called_event
|
||||
|
||||
self.assertFalse(event.peek())
|
||||
event.set()
|
||||
self.assertTrue(event.peek())
|
||||
event.clear()
|
||||
self.assertFalse(event.peek())
|
||||
|
||||
del event
|
||||
|
||||
def test_wait(self):
|
||||
handle = messaging.fake_event_handle("carState")
|
||||
event = handle.recv_called_event
|
||||
|
||||
event.set()
|
||||
try:
|
||||
event.wait(WAIT_TIMEOUT)
|
||||
self.assertTrue(event.peek())
|
||||
except RuntimeError:
|
||||
self.fail("event.wait() timed out")
|
||||
|
||||
def test_wait_multiprocess(self):
|
||||
handle = messaging.fake_event_handle("carState")
|
||||
event = handle.recv_called_event
|
||||
|
||||
def set_event_run():
|
||||
event.set()
|
||||
|
||||
try:
|
||||
p = multiprocessing.Process(target=set_event_run)
|
||||
p.start()
|
||||
event.wait(WAIT_TIMEOUT)
|
||||
self.assertTrue(event.peek())
|
||||
except RuntimeError:
|
||||
self.fail("event.wait() timed out")
|
||||
|
||||
p.kill()
|
||||
|
||||
def test_wait_zero_timeout(self):
|
||||
handle = messaging.fake_event_handle("carState")
|
||||
event = handle.recv_called_event
|
||||
|
||||
try:
|
||||
event.wait(0)
|
||||
self.fail("event.wait() did not time out")
|
||||
except RuntimeError:
|
||||
self.assertFalse(event.peek())
|
||||
|
||||
|
||||
@unittest.skipIf(platform.system() == "Darwin", "FakeSockets not supported on macOS")
|
||||
@unittest.skipIf("ZMQ" in os.environ, "FakeSockets not supported on ZMQ")
|
||||
@parameterized_class([{"prefix": None}, {"prefix": "test"}])
|
||||
class TestFakeSockets(unittest.TestCase):
|
||||
prefix: Optional[str] = None
|
||||
|
||||
def setUp(self):
|
||||
messaging.toggle_fake_events(True)
|
||||
if self.prefix is not None:
|
||||
messaging.set_fake_prefix(self.prefix)
|
||||
else:
|
||||
messaging.delete_fake_prefix()
|
||||
|
||||
def tearDown(self):
|
||||
messaging.toggle_fake_events(False)
|
||||
messaging.delete_fake_prefix()
|
||||
|
||||
def test_event_handle_init(self):
|
||||
handle = messaging.fake_event_handle("controlsState", override=True)
|
||||
|
||||
self.assertFalse(handle.enabled)
|
||||
self.assertGreaterEqual(handle.recv_called_event.fd, 0)
|
||||
self.assertGreaterEqual(handle.recv_ready_event.fd, 0)
|
||||
|
||||
def test_non_managed_socket_state(self):
|
||||
# non managed socket should have zero state
|
||||
_ = messaging.pub_sock("ubloxGnss")
|
||||
|
||||
handle = messaging.fake_event_handle("ubloxGnss", override=False)
|
||||
|
||||
self.assertFalse(handle.enabled)
|
||||
self.assertEqual(handle.recv_called_event.fd, 0)
|
||||
self.assertEqual(handle.recv_ready_event.fd, 0)
|
||||
|
||||
def test_managed_socket_state(self):
|
||||
# managed socket should not change anything about the state
|
||||
handle = messaging.fake_event_handle("ubloxGnss")
|
||||
handle.enabled = True
|
||||
|
||||
expected_enabled = handle.enabled
|
||||
expected_recv_called_fd = handle.recv_called_event.fd
|
||||
expected_recv_ready_fd = handle.recv_ready_event.fd
|
||||
|
||||
_ = messaging.pub_sock("ubloxGnss")
|
||||
|
||||
self.assertEqual(handle.enabled, expected_enabled)
|
||||
self.assertEqual(handle.recv_called_event.fd, expected_recv_called_fd)
|
||||
self.assertEqual(handle.recv_ready_event.fd, expected_recv_ready_fd)
|
||||
|
||||
def test_sockets_enable_disable(self):
|
||||
carState_handle = messaging.fake_event_handle("ubloxGnss", enable=True)
|
||||
recv_called = carState_handle.recv_called_event
|
||||
recv_ready = carState_handle.recv_ready_event
|
||||
|
||||
pub_sock = messaging.pub_sock("ubloxGnss")
|
||||
sub_sock = messaging.sub_sock("ubloxGnss")
|
||||
|
||||
try:
|
||||
carState_handle.enabled = True
|
||||
recv_ready.set()
|
||||
pub_sock.send(b"test")
|
||||
_ = sub_sock.receive()
|
||||
self.assertTrue(recv_called.peek())
|
||||
recv_called.clear()
|
||||
|
||||
carState_handle.enabled = False
|
||||
recv_ready.set()
|
||||
pub_sock.send(b"test")
|
||||
_ = sub_sock.receive()
|
||||
self.assertFalse(recv_called.peek())
|
||||
except RuntimeError:
|
||||
self.fail("event.wait() timed out")
|
||||
|
||||
def test_synced_pub_sub(self):
|
||||
def daemon_repub_process_run():
|
||||
pub_sock = messaging.pub_sock("ubloxGnss")
|
||||
sub_sock = messaging.sub_sock("carState")
|
||||
|
||||
frame = -1
|
||||
while True:
|
||||
frame += 1
|
||||
msg = sub_sock.receive(non_blocking=True)
|
||||
if msg is None:
|
||||
print("none received")
|
||||
continue
|
||||
|
||||
bts = frame.to_bytes(8, 'little')
|
||||
pub_sock.send(bts)
|
||||
|
||||
carState_handle = messaging.fake_event_handle("carState", enable=True)
|
||||
recv_called = carState_handle.recv_called_event
|
||||
recv_ready = carState_handle.recv_ready_event
|
||||
|
||||
p = multiprocessing.Process(target=daemon_repub_process_run)
|
||||
p.start()
|
||||
|
||||
pub_sock = messaging.pub_sock("carState")
|
||||
sub_sock = messaging.sub_sock("ubloxGnss")
|
||||
|
||||
try:
|
||||
for i in range(10):
|
||||
recv_called.wait(WAIT_TIMEOUT)
|
||||
recv_called.clear()
|
||||
|
||||
if i == 0:
|
||||
sub_sock.receive(non_blocking=True)
|
||||
|
||||
bts = i.to_bytes(8, 'little')
|
||||
pub_sock.send(bts)
|
||||
|
||||
recv_ready.set()
|
||||
recv_called.wait(WAIT_TIMEOUT)
|
||||
|
||||
msg = sub_sock.receive(non_blocking=True)
|
||||
self.assertIsNotNone(msg)
|
||||
self.assertEqual(len(msg), 8)
|
||||
|
||||
frame = int.from_bytes(msg, 'little')
|
||||
self.assertEqual(frame, i)
|
||||
except RuntimeError:
|
||||
self.fail("event.wait() timed out")
|
||||
finally:
|
||||
p.kill()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
242
cereal/messaging/tests/test_messaging.py
Normal file
242
cereal/messaging/tests/test_messaging.py
Normal file
@@ -0,0 +1,242 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import capnp
|
||||
import multiprocessing
|
||||
import numbers
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from parameterized import parameterized
|
||||
|
||||
from cereal import log, car
|
||||
import cereal.messaging as messaging
|
||||
from cereal.services import SERVICE_LIST
|
||||
|
||||
events = [evt for evt in log.Event.schema.union_fields if evt in SERVICE_LIST.keys()]
|
||||
|
||||
def random_sock():
|
||||
return random.choice(events)
|
||||
|
||||
def random_socks(num_socks=10):
|
||||
return list({random_sock() for _ in range(num_socks)})
|
||||
|
||||
def random_bytes(length=1000):
|
||||
return bytes([random.randrange(0xFF) for _ in range(length)])
|
||||
|
||||
def zmq_sleep(t=1):
|
||||
if "ZMQ" in os.environ:
|
||||
time.sleep(t)
|
||||
|
||||
def zmq_expected_failure(func):
|
||||
if "ZMQ" in os.environ:
|
||||
return unittest.expectedFailure(func)
|
||||
else:
|
||||
return func
|
||||
|
||||
# TODO: this should take any capnp struct and returrn a msg with random populated data
|
||||
def random_carstate():
|
||||
fields = ["vEgo", "aEgo", "gas", "steeringAngleDeg"]
|
||||
msg = messaging.new_message("carState")
|
||||
cs = msg.carState
|
||||
for f in fields:
|
||||
setattr(cs, f, random.random() * 10)
|
||||
return msg
|
||||
|
||||
# TODO: this should compare any capnp structs
|
||||
def assert_carstate(cs1, cs2):
|
||||
for f in car.CarState.schema.non_union_fields:
|
||||
# TODO: check all types
|
||||
val1, val2 = getattr(cs1, f), getattr(cs2, f)
|
||||
if isinstance(val1, numbers.Number):
|
||||
assert val1 == val2, f"{f}: sent '{val1}' vs recvd '{val2}'"
|
||||
|
||||
def delayed_send(delay, sock, dat):
|
||||
def send_func():
|
||||
sock.send(dat)
|
||||
threading.Timer(delay, send_func).start()
|
||||
|
||||
class TestPubSubSockets(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# ZMQ pub socket takes too long to die
|
||||
# sleep to prevent multiple publishers error between tests
|
||||
zmq_sleep()
|
||||
|
||||
def test_pub_sub(self):
|
||||
sock = random_sock()
|
||||
pub_sock = messaging.pub_sock(sock)
|
||||
sub_sock = messaging.sub_sock(sock, conflate=False, timeout=None)
|
||||
zmq_sleep(3)
|
||||
|
||||
for _ in range(1000):
|
||||
msg = random_bytes()
|
||||
pub_sock.send(msg)
|
||||
recvd = sub_sock.receive()
|
||||
self.assertEqual(msg, recvd)
|
||||
|
||||
def test_conflate(self):
|
||||
sock = random_sock()
|
||||
pub_sock = messaging.pub_sock(sock)
|
||||
for conflate in [True, False]:
|
||||
for _ in range(10):
|
||||
num_msgs = random.randint(3, 10)
|
||||
sub_sock = messaging.sub_sock(sock, conflate=conflate, timeout=None)
|
||||
zmq_sleep()
|
||||
|
||||
sent_msgs = []
|
||||
for __ in range(num_msgs):
|
||||
msg = random_bytes()
|
||||
pub_sock.send(msg)
|
||||
sent_msgs.append(msg)
|
||||
time.sleep(0.1)
|
||||
recvd_msgs = messaging.drain_sock_raw(sub_sock)
|
||||
if conflate:
|
||||
self.assertEqual(len(recvd_msgs), 1)
|
||||
else:
|
||||
# TODO: compare actual data
|
||||
self.assertEqual(len(recvd_msgs), len(sent_msgs))
|
||||
|
||||
def test_receive_timeout(self):
|
||||
sock = random_sock()
|
||||
for _ in range(10):
|
||||
timeout = random.randrange(200)
|
||||
sub_sock = messaging.sub_sock(sock, timeout=timeout)
|
||||
zmq_sleep()
|
||||
|
||||
start_time = time.monotonic()
|
||||
recvd = sub_sock.receive()
|
||||
self.assertLess(time.monotonic() - start_time, 0.2)
|
||||
assert recvd is None
|
||||
|
||||
class TestMessaging(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# ZMQ pub socket takes too long to die
|
||||
# sleep to prevent multiple publishers error between tests
|
||||
zmq_sleep()
|
||||
|
||||
@parameterized.expand(events)
|
||||
def test_new_message(self, evt):
|
||||
try:
|
||||
msg = messaging.new_message(evt)
|
||||
except capnp.lib.capnp.KjException:
|
||||
msg = messaging.new_message(evt, random.randrange(200))
|
||||
self.assertLess(time.monotonic() - msg.logMonoTime, 0.1)
|
||||
self.assertFalse(msg.valid)
|
||||
self.assertEqual(evt, msg.which())
|
||||
|
||||
@parameterized.expand(events)
|
||||
def test_pub_sock(self, evt):
|
||||
messaging.pub_sock(evt)
|
||||
|
||||
@parameterized.expand(events)
|
||||
def test_sub_sock(self, evt):
|
||||
messaging.sub_sock(evt)
|
||||
|
||||
@parameterized.expand([
|
||||
(messaging.drain_sock, capnp._DynamicStructReader),
|
||||
(messaging.drain_sock_raw, bytes),
|
||||
])
|
||||
def test_drain_sock(self, func, expected_type):
|
||||
sock = "carState"
|
||||
pub_sock = messaging.pub_sock(sock)
|
||||
sub_sock = messaging.sub_sock(sock, timeout=1000)
|
||||
zmq_sleep()
|
||||
|
||||
# no wait and no msgs in queue
|
||||
msgs = func(sub_sock)
|
||||
self.assertIsInstance(msgs, list)
|
||||
self.assertEqual(len(msgs), 0)
|
||||
|
||||
# no wait but msgs are queued up
|
||||
num_msgs = random.randrange(3, 10)
|
||||
for _ in range(num_msgs):
|
||||
pub_sock.send(messaging.new_message(sock).to_bytes())
|
||||
time.sleep(0.1)
|
||||
msgs = func(sub_sock)
|
||||
self.assertIsInstance(msgs, list)
|
||||
self.assertTrue(all(isinstance(msg, expected_type) for msg in msgs))
|
||||
self.assertEqual(len(msgs), num_msgs)
|
||||
|
||||
def test_recv_sock(self):
|
||||
sock = "carState"
|
||||
pub_sock = messaging.pub_sock(sock)
|
||||
sub_sock = messaging.sub_sock(sock, timeout=100)
|
||||
zmq_sleep()
|
||||
|
||||
# no wait and no msg in queue, socket should timeout
|
||||
recvd = messaging.recv_sock(sub_sock)
|
||||
self.assertTrue(recvd is None)
|
||||
|
||||
# no wait and one msg in queue
|
||||
msg = random_carstate()
|
||||
pub_sock.send(msg.to_bytes())
|
||||
time.sleep(0.01)
|
||||
recvd = messaging.recv_sock(sub_sock)
|
||||
self.assertIsInstance(recvd, capnp._DynamicStructReader)
|
||||
# https://github.com/python/mypy/issues/13038
|
||||
assert_carstate(msg.carState, recvd.carState) # type: ignore[union-attr]
|
||||
|
||||
def test_recv_one(self):
|
||||
sock = "carState"
|
||||
pub_sock = messaging.pub_sock(sock)
|
||||
sub_sock = messaging.sub_sock(sock, timeout=1000)
|
||||
zmq_sleep()
|
||||
|
||||
# no msg in queue, socket should timeout
|
||||
recvd = messaging.recv_one(sub_sock)
|
||||
self.assertTrue(recvd is None)
|
||||
|
||||
# one msg in queue
|
||||
msg = random_carstate()
|
||||
pub_sock.send(msg.to_bytes())
|
||||
recvd = messaging.recv_one(sub_sock)
|
||||
self.assertIsInstance(recvd, capnp._DynamicStructReader)
|
||||
assert_carstate(msg.carState, recvd.carState) # type: ignore[union-attr]
|
||||
|
||||
@zmq_expected_failure
|
||||
def test_recv_one_or_none(self):
|
||||
sock = "carState"
|
||||
pub_sock = messaging.pub_sock(sock)
|
||||
sub_sock = messaging.sub_sock(sock)
|
||||
zmq_sleep()
|
||||
|
||||
# no msg in queue, socket shouldn't block
|
||||
recvd = messaging.recv_one_or_none(sub_sock)
|
||||
self.assertTrue(recvd is None)
|
||||
|
||||
# one msg in queue
|
||||
msg = random_carstate()
|
||||
pub_sock.send(msg.to_bytes())
|
||||
recvd = messaging.recv_one_or_none(sub_sock)
|
||||
self.assertIsInstance(recvd, capnp._DynamicStructReader)
|
||||
assert_carstate(msg.carState, recvd.carState) # type: ignore[union-attr]
|
||||
|
||||
def test_recv_one_retry(self):
|
||||
sock = "carState"
|
||||
sock_timeout = 0.1
|
||||
pub_sock = messaging.pub_sock(sock)
|
||||
sub_sock = messaging.sub_sock(sock, timeout=round(sock_timeout*1000))
|
||||
zmq_sleep()
|
||||
|
||||
# this test doesn't work with ZMQ since multiprocessing interrupts it
|
||||
if "ZMQ" not in os.environ:
|
||||
# wait 15 socket timeouts and make sure it's still retrying
|
||||
p = multiprocessing.Process(target=messaging.recv_one_retry, args=(sub_sock,))
|
||||
p.start()
|
||||
time.sleep(sock_timeout*15)
|
||||
self.assertTrue(p.is_alive())
|
||||
p.terminate()
|
||||
|
||||
# wait 15 socket timeouts before sending
|
||||
msg = random_carstate()
|
||||
delayed_send(sock_timeout*15, pub_sock, msg.to_bytes())
|
||||
start_time = time.monotonic()
|
||||
recvd = messaging.recv_one_retry(sub_sock)
|
||||
self.assertGreaterEqual(time.monotonic() - start_time, sock_timeout*15)
|
||||
self.assertIsInstance(recvd, capnp._DynamicStructReader)
|
||||
assert_carstate(msg.carState, recvd.carState)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
142
cereal/messaging/tests/test_poller.py
Normal file
142
cereal/messaging/tests/test_poller.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import unittest
|
||||
import time
|
||||
import cereal.messaging as messaging
|
||||
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
def poller():
|
||||
context = messaging.Context()
|
||||
|
||||
p = messaging.Poller()
|
||||
|
||||
sub = messaging.SubSocket()
|
||||
sub.connect(context, 'controlsState')
|
||||
p.registerSocket(sub)
|
||||
|
||||
socks = p.poll(10000)
|
||||
r = [s.receive(non_blocking=True) for s in socks]
|
||||
|
||||
return r
|
||||
|
||||
|
||||
class TestPoller(unittest.TestCase):
|
||||
def test_poll_once(self):
|
||||
context = messaging.Context()
|
||||
|
||||
pub = messaging.PubSocket()
|
||||
pub.connect(context, 'controlsState')
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as e:
|
||||
poll = e.submit(poller)
|
||||
|
||||
time.sleep(0.1) # Slow joiner syndrome
|
||||
|
||||
# Send message
|
||||
pub.send(b"a")
|
||||
|
||||
# Wait for poll result
|
||||
result = poll.result()
|
||||
|
||||
del pub
|
||||
context.term()
|
||||
|
||||
self.assertEqual(result, [b"a"])
|
||||
|
||||
def test_poll_and_create_many_subscribers(self):
|
||||
context = messaging.Context()
|
||||
|
||||
pub = messaging.PubSocket()
|
||||
pub.connect(context, 'controlsState')
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as e:
|
||||
poll = e.submit(poller)
|
||||
|
||||
time.sleep(0.1) # Slow joiner syndrome
|
||||
c = messaging.Context()
|
||||
for _ in range(10):
|
||||
messaging.SubSocket().connect(c, 'controlsState')
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
# Send message
|
||||
pub.send(b"a")
|
||||
|
||||
# Wait for poll result
|
||||
result = poll.result()
|
||||
|
||||
del pub
|
||||
context.term()
|
||||
|
||||
self.assertEqual(result, [b"a"])
|
||||
|
||||
def test_multiple_publishers_exception(self):
|
||||
context = messaging.Context()
|
||||
|
||||
with self.assertRaises(messaging.MultiplePublishersError):
|
||||
pub1 = messaging.PubSocket()
|
||||
pub1.connect(context, 'controlsState')
|
||||
|
||||
pub2 = messaging.PubSocket()
|
||||
pub2.connect(context, 'controlsState')
|
||||
|
||||
pub1.send(b"a")
|
||||
|
||||
del pub1
|
||||
del pub2
|
||||
context.term()
|
||||
|
||||
def test_multiple_messages(self):
|
||||
context = messaging.Context()
|
||||
|
||||
pub = messaging.PubSocket()
|
||||
pub.connect(context, 'controlsState')
|
||||
|
||||
sub = messaging.SubSocket()
|
||||
sub.connect(context, 'controlsState')
|
||||
|
||||
time.sleep(0.1) # Slow joiner
|
||||
|
||||
for i in range(1, 100):
|
||||
pub.send(b'a'*i)
|
||||
|
||||
msg_seen = False
|
||||
i = 1
|
||||
while True:
|
||||
r = sub.receive(non_blocking=True)
|
||||
|
||||
if r is not None:
|
||||
self.assertEqual(b'a'*i, r)
|
||||
|
||||
msg_seen = True
|
||||
i += 1
|
||||
|
||||
if r is None and msg_seen: # ZMQ sometimes receives nothing on the first receive
|
||||
break
|
||||
|
||||
del pub
|
||||
del sub
|
||||
context.term()
|
||||
|
||||
def test_conflate(self):
|
||||
context = messaging.Context()
|
||||
|
||||
pub = messaging.PubSocket()
|
||||
pub.connect(context, 'controlsState')
|
||||
|
||||
sub = messaging.SubSocket()
|
||||
sub.connect(context, 'controlsState', conflate=True)
|
||||
|
||||
time.sleep(0.1) # Slow joiner
|
||||
pub.send(b'a')
|
||||
pub.send(b'b')
|
||||
|
||||
self.assertEqual(b'b', sub.receive())
|
||||
|
||||
del pub
|
||||
del sub
|
||||
context.term()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
163
cereal/messaging/tests/test_pub_sub_master.py
Normal file
163
cereal/messaging/tests/test_pub_sub_master.py
Normal file
@@ -0,0 +1,163 @@
|
||||
#!/usr/bin/env python3
|
||||
import random
|
||||
import time
|
||||
from typing import Sized, cast
|
||||
import unittest
|
||||
|
||||
import cereal.messaging as messaging
|
||||
from cereal.messaging.tests.test_messaging import events, random_sock, random_socks, \
|
||||
random_bytes, random_carstate, assert_carstate, \
|
||||
zmq_sleep
|
||||
|
||||
|
||||
class TestSubMaster(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# ZMQ pub socket takes too long to die
|
||||
# sleep to prevent multiple publishers error between tests
|
||||
zmq_sleep(3)
|
||||
|
||||
def test_init(self):
|
||||
sm = messaging.SubMaster(events)
|
||||
for p in [sm.updated, sm.recv_time, sm.recv_frame, sm.alive,
|
||||
sm.sock, sm.data, sm.logMonoTime, sm.valid]:
|
||||
self.assertEqual(len(cast(Sized, p)), len(events))
|
||||
|
||||
def test_init_state(self):
|
||||
socks = random_socks()
|
||||
sm = messaging.SubMaster(socks)
|
||||
self.assertEqual(sm.frame, -1)
|
||||
self.assertFalse(any(sm.updated.values()))
|
||||
self.assertFalse(any(sm.alive.values()))
|
||||
self.assertTrue(all(t == 0. for t in sm.recv_time.values()))
|
||||
self.assertTrue(all(f == 0 for f in sm.recv_frame.values()))
|
||||
self.assertTrue(all(t == 0 for t in sm.logMonoTime.values()))
|
||||
|
||||
for p in [sm.updated, sm.recv_time, sm.recv_frame, sm.alive,
|
||||
sm.sock, sm.data, sm.logMonoTime, sm.valid]:
|
||||
self.assertEqual(len(cast(Sized, p)), len(socks))
|
||||
|
||||
def test_getitem(self):
|
||||
sock = "carState"
|
||||
pub_sock = messaging.pub_sock(sock)
|
||||
sm = messaging.SubMaster([sock,])
|
||||
zmq_sleep()
|
||||
|
||||
msg = random_carstate()
|
||||
pub_sock.send(msg.to_bytes())
|
||||
sm.update(1000)
|
||||
assert_carstate(msg.carState, sm[sock])
|
||||
|
||||
# TODO: break this test up to individually test SubMaster.update and SubMaster.update_msgs
|
||||
def test_update(self):
|
||||
sock = "carState"
|
||||
pub_sock = messaging.pub_sock(sock)
|
||||
sm = messaging.SubMaster([sock,])
|
||||
zmq_sleep()
|
||||
|
||||
for i in range(10):
|
||||
msg = messaging.new_message(sock)
|
||||
pub_sock.send(msg.to_bytes())
|
||||
sm.update(1000)
|
||||
self.assertEqual(sm.frame, i)
|
||||
self.assertTrue(all(sm.updated.values()))
|
||||
|
||||
def test_update_timeout(self):
|
||||
sock = random_sock()
|
||||
sm = messaging.SubMaster([sock,])
|
||||
for _ in range(5):
|
||||
timeout = random.randrange(1000, 5000)
|
||||
start_time = time.monotonic()
|
||||
sm.update(timeout)
|
||||
t = time.monotonic() - start_time
|
||||
self.assertGreaterEqual(t, timeout/1000.)
|
||||
self.assertLess(t, 5)
|
||||
self.assertFalse(any(sm.updated.values()))
|
||||
|
||||
def test_avg_frequency_checks(self):
|
||||
for poll in (True, False):
|
||||
sm = messaging.SubMaster(["modelV2", "carParams", "carState", "cameraOdometry", "liveCalibration"],
|
||||
poll=("modelV2" if poll else None),
|
||||
frequency=(20. if not poll else None))
|
||||
|
||||
checks = {
|
||||
"carState": (20, 20),
|
||||
"modelV2": (20, 20 if poll else 10),
|
||||
"cameraOdometry": (20, 10),
|
||||
"liveCalibration": (4, 4),
|
||||
"carParams": (None, None),
|
||||
}
|
||||
|
||||
for service, (max_freq, min_freq) in checks.items():
|
||||
if max_freq is not None:
|
||||
assert sm._check_avg_freq(service)
|
||||
assert sm.max_freq[service] == max_freq*1.2
|
||||
assert sm.min_freq[service] == min_freq*0.8
|
||||
else:
|
||||
assert not sm._check_avg_freq(service)
|
||||
|
||||
def test_alive(self):
|
||||
pass
|
||||
|
||||
def test_ignore_alive(self):
|
||||
pass
|
||||
|
||||
def test_valid(self):
|
||||
pass
|
||||
|
||||
# SubMaster should always conflate
|
||||
def test_conflate(self):
|
||||
sock = "carState"
|
||||
pub_sock = messaging.pub_sock(sock)
|
||||
sm = messaging.SubMaster([sock,])
|
||||
|
||||
n = 10
|
||||
for i in range(n+1):
|
||||
msg = messaging.new_message(sock)
|
||||
msg.carState.vEgo = i
|
||||
pub_sock.send(msg.to_bytes())
|
||||
time.sleep(0.01)
|
||||
sm.update(1000)
|
||||
self.assertEqual(sm[sock].vEgo, n)
|
||||
|
||||
|
||||
class TestPubMaster(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# ZMQ pub socket takes too long to die
|
||||
# sleep to prevent multiple publishers error between tests
|
||||
zmq_sleep(3)
|
||||
|
||||
def test_init(self):
|
||||
messaging.PubMaster(events)
|
||||
|
||||
def test_send(self):
|
||||
socks = random_socks()
|
||||
pm = messaging.PubMaster(socks)
|
||||
sub_socks = {s: messaging.sub_sock(s, conflate=True, timeout=1000) for s in socks}
|
||||
zmq_sleep()
|
||||
|
||||
# PubMaster accepts either a capnp msg builder or bytes
|
||||
for capnp in [True, False]:
|
||||
for i in range(100):
|
||||
sock = socks[i % len(socks)]
|
||||
|
||||
if capnp:
|
||||
try:
|
||||
msg = messaging.new_message(sock)
|
||||
except Exception:
|
||||
msg = messaging.new_message(sock, random.randrange(50))
|
||||
else:
|
||||
msg = random_bytes()
|
||||
|
||||
pm.send(sock, msg)
|
||||
recvd = sub_socks[sock].receive()
|
||||
|
||||
if capnp:
|
||||
msg.clear_write_flag()
|
||||
msg = msg.to_bytes()
|
||||
self.assertEqual(msg, recvd, i)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
33
cereal/messaging/tests/test_services.py
Normal file
33
cereal/messaging/tests/test_services.py
Normal file
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Dict
|
||||
import unittest
|
||||
from parameterized import parameterized
|
||||
|
||||
import cereal.services as services
|
||||
from cereal.services import SERVICE_LIST, RESERVED_PORT, STARTING_PORT
|
||||
|
||||
|
||||
class TestServices(unittest.TestCase):
|
||||
|
||||
@parameterized.expand(SERVICE_LIST.keys())
|
||||
def test_services(self, s):
|
||||
service = SERVICE_LIST[s]
|
||||
self.assertTrue(service.port != RESERVED_PORT)
|
||||
self.assertTrue(service.port >= STARTING_PORT)
|
||||
self.assertTrue(service.frequency <= 104)
|
||||
|
||||
def test_no_duplicate_port(self):
|
||||
ports: Dict[int, str] = {}
|
||||
for name, service in SERVICE_LIST.items():
|
||||
self.assertFalse(service.port in ports.keys(), f"duplicate port {service.port}")
|
||||
ports[service.port] = name
|
||||
|
||||
def test_generated_header(self):
|
||||
with tempfile.NamedTemporaryFile(suffix=".h") as f:
|
||||
ret = os.system(f"python3 {services.__file__} > {f.name} && clang++ {f.name}")
|
||||
self.assertEqual(ret, 0, "generated services header is not valid C")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user