Lift double routing detection out of Routing

This commit is contained in:
Mygod
2019-03-09 01:46:01 -05:00
parent b9292b8be4
commit aedba90196
5 changed files with 24 additions and 31 deletions

View File

@@ -3,7 +3,8 @@ package be.mygod.vpnhotspot
import be.mygod.vpnhotspot.App.Companion.app import be.mygod.vpnhotspot.App.Companion.app
import be.mygod.vpnhotspot.net.Routing import be.mygod.vpnhotspot.net.Routing
import be.mygod.vpnhotspot.net.wifi.WifiDoubleLock import be.mygod.vpnhotspot.net.wifi.WifiDoubleLock
import be.mygod.vpnhotspot.util.Event0 import be.mygod.vpnhotspot.util.putIfAbsentCompat
import be.mygod.vpnhotspot.util.removeCompat
import be.mygod.vpnhotspot.widget.SmartSnackbar import be.mygod.vpnhotspot.widget.SmartSnackbar
import timber.log.Timber import timber.log.Timber
@@ -18,11 +19,10 @@ abstract class RoutingManager(private val caller: Any, val downstream: String, p
} }
set(value) = app.pref.edit().putString(KEY_MASQUERADE_MODE, value.name).apply() set(value) = app.pref.edit().putString(KEY_MASQUERADE_MODE, value.name).apply()
private val onPreCleanRoutings = Event0() private val active = mutableMapOf<String, RoutingManager>()
private val onRoutingsCleaned = Event0()
fun clean() { fun clean() {
onPreCleanRoutings() for (manager in active.values) manager.routing?.stop()
val cleaned = try { val cleaned = try {
Routing.clean() Routing.clean()
true true
@@ -31,7 +31,7 @@ abstract class RoutingManager(private val caller: Any, val downstream: String, p
SmartSnackbar.make(e).show() SmartSnackbar.make(e).show()
false false
} }
if (cleaned) onRoutingsCleaned() if (cleaned) for (manager in active.values) manager.initRouting()
} }
} }
@@ -47,18 +47,16 @@ abstract class RoutingManager(private val caller: Any, val downstream: String, p
} }
} }
var started = false val started get() = active[downstream] === this
private var routing: Routing? = null private var routing: Routing? = null
init { init {
if (isWifi) WifiDoubleLock.acquire(this) if (isWifi) WifiDoubleLock.acquire(this)
} }
fun start(): Boolean { fun start() = when (active.putIfAbsentCompat(downstream, this)) {
check(!started) null -> initRouting()
started = true this -> true // already started
onPreCleanRoutings[this] = { routing?.stop() } else -> throw IllegalStateException("Double routing detected from $caller")
onRoutingsCleaned[this] = { initRouting() }
return initRouting()
} }
private fun initRouting() = try { private fun initRouting() = try {
@@ -81,11 +79,7 @@ abstract class RoutingManager(private val caller: Any, val downstream: String, p
protected abstract fun Routing.configure() protected abstract fun Routing.configure()
fun stop() { fun stop() {
if (!started) return if (active.removeCompat(downstream, this)) routing?.revert()
routing?.revert()
onPreCleanRoutings -= this
onRoutingsCleaned -= this
started = false
} }
fun destroy() { fun destroy() {

View File

@@ -50,7 +50,7 @@ class TetheringService : IpNeighbourMonitoringService() {
val toRemove = downstreams.toMutableMap() // make a copy val toRemove = downstreams.toMutableMap() // make a copy
for (iface in TetheringManager.getTetheredIfaces(extras)) { for (iface in TetheringManager.getTetheredIfaces(extras)) {
val downstream = toRemove.remove(iface) ?: continue val downstream = toRemove.remove(iface) ?: continue
if (downstream.monitor && !downstream.started) downstream.start() if (downstream.monitor) downstream.start()
} }
for ((iface, downstream) in toRemove) { for ((iface, downstream) in toRemove) {
if (downstream.monitor) downstream.stop() else downstreams.remove(iface)?.destroy() if (downstream.monitor) downstream.stop() else downstreams.remove(iface)?.destroy()

View File

@@ -42,11 +42,6 @@ class Routing(private val caller: Any, private val downstream: String) : IpNeigh
*/ */
val IPTABLES = if (Build.VERSION.SDK_INT >= 26) "iptables -w 1" else "iptables -w" val IPTABLES = if (Build.VERSION.SDK_INT >= 26) "iptables -w 1" else "iptables -w"
/**
* For debugging: check that we do not start a Routing for the same interface twice.
*/
private var downstreams = mutableSetOf<String>()
fun clean() { fun clean() {
TrafficRecorder.clean() TrafficRecorder.clean()
RootSession.use { RootSession.use {
@@ -60,7 +55,6 @@ class Routing(private val caller: Any, private val downstream: String) : IpNeigh
it.execQuiet("while ip rule del priority $RULE_PRIORITY_UPSTREAM; do done") it.execQuiet("while ip rule del priority $RULE_PRIORITY_UPSTREAM; do done")
it.execQuiet("while ip rule del priority $RULE_PRIORITY_UPSTREAM_FALLBACK; do done") it.execQuiet("while ip rule del priority $RULE_PRIORITY_UPSTREAM_FALLBACK; do done")
} }
downstreams.clear()
} }
private fun RootSession.Transaction.iptables(command: String, revert: String) { private fun RootSession.Transaction.iptables(command: String, revert: String) {
@@ -89,10 +83,6 @@ class Routing(private val caller: Any, private val downstream: String) : IpNeigh
override val message: String get() = app.getString(R.string.exception_interface_not_found) override val message: String get() = app.getString(R.string.exception_interface_not_found)
} }
init {
check(downstreams.add(downstream)) { "Double routing detected from $caller" }
}
private val hostAddress = try { private val hostAddress = try {
NetworkInterface.getByName(downstream)!!.interfaceAddresses!!.asSequence().single { it.address is Inet4Address } NetworkInterface.getByName(downstream)!!.interfaceAddresses!!.asSequence().single { it.address is Inet4Address }
} catch (e: Exception) { } catch (e: Exception) {
@@ -314,6 +304,5 @@ class Routing(private val caller: Any, private val downstream: String) : IpNeigh
fallbackUpstream.subrouting?.transaction?.revert() fallbackUpstream.subrouting?.transaction?.revert()
upstream.subrouting?.transaction?.revert() upstream.subrouting?.transaction?.revert()
transaction.revert() transaction.revert()
check(downstreams.remove(downstream)) { "Double reverting detected from $caller" }
} }
} }

View File

@@ -9,6 +9,7 @@ import be.mygod.vpnhotspot.room.TrafficRecord
import be.mygod.vpnhotspot.util.Event2 import be.mygod.vpnhotspot.util.Event2
import be.mygod.vpnhotspot.util.RootSession import be.mygod.vpnhotspot.util.RootSession
import be.mygod.vpnhotspot.util.parseNumericAddress import be.mygod.vpnhotspot.util.parseNumericAddress
import be.mygod.vpnhotspot.util.putIfAbsentCompat
import be.mygod.vpnhotspot.widget.SmartSnackbar import be.mygod.vpnhotspot.widget.SmartSnackbar
import timber.log.Timber import timber.log.Timber
import java.net.InetAddress import java.net.InetAddress
@@ -28,7 +29,7 @@ object TrafficRecorder {
AppDatabase.instance.trafficRecordDao.insert(record) AppDatabase.instance.trafficRecordDao.insert(record)
synchronized(this) { synchronized(this) {
DebugHelper.log(TAG, "Registering $ip%$downstream") DebugHelper.log(TAG, "Registering $ip%$downstream")
check(records.put(Pair(ip, downstream), record) == null) check(records.putIfAbsentCompat(Pair(ip, downstream), record) == null)
scheduleUpdateLocked() scheduleUpdateLocked()
} }
} }

View File

@@ -91,5 +91,14 @@ fun Context.stopAndUnbind(connection: ServiceConnection) {
unbindService(connection) unbindService(connection)
} }
fun <K, V> HashMap<K, V>.computeIfAbsentCompat(key: K, value: () -> V) = if (Build.VERSION.SDK_INT >= 26) fun <K, V> MutableMap<K, V>.computeIfAbsentCompat(key: K, value: () -> V) = if (Build.VERSION.SDK_INT >= 26)
computeIfAbsent(key) { value() } else this[key] ?: value().also { put(key, it) } computeIfAbsent(key) { value() } else this[key] ?: value().also { put(key, it) }
fun <K, V> MutableMap<K, V>.putIfAbsentCompat(key: K, value: V) = if (Build.VERSION.SDK_INT >= 24)
putIfAbsent(key, value) else this[key] ?: put(key, value)
fun <K, V> MutableMap<K, V>.removeCompat(key: K, value: V) = if (Build.VERSION.SDK_INT >= 24) remove(key, value) else {
val curValue = get(key)
if (curValue === value && (curValue != null || containsKey(key))) {
remove(key)
true
} else false
}