Allow ordinary execute to be cancelled
This commit is contained in:
@@ -20,13 +20,16 @@ import java.util.concurrent.CountDownLatch
|
|||||||
import kotlin.system.exitProcess
|
import kotlin.system.exitProcess
|
||||||
|
|
||||||
class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> Unit = { Log.w(TAG, it) }) {
|
class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> Unit = { Log.w(TAG, it) }) {
|
||||||
private sealed class Callback {
|
private sealed class Callback(protected val server: RootServer, protected val index: Long) {
|
||||||
|
var active = true
|
||||||
|
|
||||||
abstract fun cancel()
|
abstract fun cancel()
|
||||||
abstract fun shouldRemove(result: Byte): Boolean
|
abstract fun shouldRemove(result: Byte): Boolean
|
||||||
abstract operator fun invoke(input: DataInputStream, result: Byte)
|
abstract operator fun invoke(input: DataInputStream, result: Byte)
|
||||||
|
suspend fun sendClosed() = withContext(NonCancellable) { server.execute(CancelCommand(index)) }
|
||||||
|
|
||||||
class Ordinary(private val classLoader: ClassLoader?,
|
class Ordinary(server: RootServer, index: Long, private val classLoader: ClassLoader?,
|
||||||
private val callback: CompletableDeferred<Parcelable?>) : Callback() {
|
private val callback: CompletableDeferred<Parcelable?>) : Callback(server, index) {
|
||||||
override fun cancel() = callback.cancel()
|
override fun cancel() = callback.cancel()
|
||||||
override fun shouldRemove(result: Byte) = true
|
override fun shouldRemove(result: Byte) = true
|
||||||
override fun invoke(input: DataInputStream, result: Byte) {
|
override fun invoke(input: DataInputStream, result: Byte) {
|
||||||
@@ -42,11 +45,8 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class Channel(private val classLoader: ClassLoader?,
|
class Channel(server: RootServer, index: Long, private val classLoader: ClassLoader?,
|
||||||
private val channel: SendChannel<Parcelable?>,
|
private val channel: SendChannel<Parcelable?>) : Callback(server, index) {
|
||||||
private val server: RootServer,
|
|
||||||
private val index: Long) : Callback() {
|
|
||||||
var active = true
|
|
||||||
val finish: CompletableDeferred<Unit> = CompletableDeferred()
|
val finish: CompletableDeferred<Unit> = CompletableDeferred()
|
||||||
override fun cancel() = finish.cancel()
|
override fun cancel() = finish.cancel()
|
||||||
override fun shouldRemove(result: Byte) = result.toInt() != SUCCESS
|
override fun shouldRemove(result: Byte) = result.toInt() != SUCCESS
|
||||||
@@ -70,8 +70,6 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
|||||||
else -> throw IllegalArgumentException("Unexpected result $result")
|
else -> throw IllegalArgumentException("Unexpected result $result")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
suspend fun sendClosed() = server.execute(ChannelClosed(index))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,7 +163,12 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
|||||||
}
|
}
|
||||||
val result = input.readByte()
|
val result = input.readByte()
|
||||||
val callback = mutex.synchronized {
|
val callback = mutex.synchronized {
|
||||||
callbackLookup[index]!!.also { if (it.shouldRemove(result)) callbackLookup.remove(index) }
|
callbackLookup[index]!!.also {
|
||||||
|
if (it.shouldRemove(result)) {
|
||||||
|
callbackLookup.remove(index)
|
||||||
|
it.active = false
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (DEBUG) Log.d(TAG, "Received callback #$index: $result")
|
if (DEBUG) Log.d(TAG, "Received callback #$index: $result")
|
||||||
callback(input, result)
|
callback(input, result)
|
||||||
@@ -228,14 +231,20 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
|||||||
@Throws(RemoteException::class)
|
@Throws(RemoteException::class)
|
||||||
suspend fun <T : Parcelable?> execute(command: RootCommand<T>, classLoader: ClassLoader?): T {
|
suspend fun <T : Parcelable?> execute(command: RootCommand<T>, classLoader: ClassLoader?): T {
|
||||||
val future = CompletableDeferred<T>()
|
val future = CompletableDeferred<T>()
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
val callback = Callback.Ordinary(this, counter, classLoader, future as CompletableDeferred<Parcelable?>)
|
||||||
mutex.withLock {
|
mutex.withLock {
|
||||||
if (active) {
|
if (active) {
|
||||||
@Suppress("UNCHECKED_CAST")
|
callbackLookup[counter] = callback
|
||||||
callbackLookup[counter] = Callback.Ordinary(classLoader, future as CompletableDeferred<Parcelable?>)
|
|
||||||
sendLocked(command)
|
sendLocked(command)
|
||||||
} else future.cancel()
|
} else future.cancel()
|
||||||
}
|
}
|
||||||
|
try {
|
||||||
return future.await()
|
return future.await()
|
||||||
|
} finally {
|
||||||
|
if (callback.active) callback.sendClosed()
|
||||||
|
callback.active = false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ExperimentalCoroutinesApi
|
@ExperimentalCoroutinesApi
|
||||||
@@ -253,7 +262,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
|||||||
}
|
}
|
||||||
}) {
|
}) {
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
val callback = Callback.Channel(classLoader, this as SendChannel<Parcelable?>, this@RootServer, counter)
|
val callback = Callback.Channel(this@RootServer, counter, classLoader, this as SendChannel<Parcelable?>)
|
||||||
mutex.withLock {
|
mutex.withLock {
|
||||||
if (active) {
|
if (active) {
|
||||||
callbackLookup[counter] = callback
|
callbackLookup[counter] = callback
|
||||||
@@ -263,7 +272,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
|||||||
try {
|
try {
|
||||||
callback.finish.await()
|
callback.finish.await()
|
||||||
} finally {
|
} finally {
|
||||||
if (callback.active) withContext(NonCancellable) { callback.sendClosed() }
|
if (callback.active) callback.sendClosed()
|
||||||
callback.active = false
|
callback.active = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -381,7 +390,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
|||||||
CoroutineScope(Dispatchers.Main.immediate + job)
|
CoroutineScope(Dispatchers.Main.immediate + job)
|
||||||
}
|
}
|
||||||
val callbackWorker = newSingleThreadContext("callbackWorker")
|
val callbackWorker = newSingleThreadContext("callbackWorker")
|
||||||
val channels = LongSparseArray<ReceiveChannel<Parcelable?>>()
|
val cancellables = LongSparseArray<() -> Unit>()
|
||||||
|
|
||||||
// thread safety: usage of output should be guarded by callbackWorker
|
// thread safety: usage of output should be guarded by callbackWorker
|
||||||
val output = DataOutputStream(System.out.buffered().apply {
|
val output = DataOutputStream(System.out.buffered().apply {
|
||||||
@@ -402,7 +411,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
|||||||
val callback = counter
|
val callback = counter
|
||||||
if (DEBUG) Log.d(TAG, "Received #$callback: $command")
|
if (DEBUG) Log.d(TAG, "Received #$callback: $command")
|
||||||
when (command) {
|
when (command) {
|
||||||
is ChannelClosed -> channels[command.index]?.cancel()
|
is CancelCommand -> cancellables[command.index]?.invoke()
|
||||||
is RootCommandOneWay -> defaultWorker.launch {
|
is RootCommandOneWay -> defaultWorker.launch {
|
||||||
try {
|
try {
|
||||||
command.execute()
|
command.execute()
|
||||||
@@ -422,7 +431,9 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
|||||||
is RootCommandChannel<*> -> defaultWorker.launch {
|
is RootCommandChannel<*> -> defaultWorker.launch {
|
||||||
val result = try {
|
val result = try {
|
||||||
coroutineScope {
|
coroutineScope {
|
||||||
command.create(this).also { channels[callback] = it }.consumeEach { result ->
|
command.create(this).also {
|
||||||
|
cancellables[callback] = { it.cancel() }
|
||||||
|
}.consumeEach { result ->
|
||||||
withContext(callbackWorker) { output.pushResult(callback, result) }
|
withContext(callbackWorker) { output.pushResult(callback, result) }
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -434,7 +445,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
|||||||
} catch (e: Throwable) {
|
} catch (e: Throwable) {
|
||||||
{ output.pushThrowable(callback, e) }
|
{ output.pushThrowable(callback, e) }
|
||||||
} finally {
|
} finally {
|
||||||
channels.remove(callback)
|
cancellables.remove(callback)
|
||||||
}
|
}
|
||||||
withContext(callbackWorker) { result() }
|
withContext(callbackWorker) { result() }
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ interface RootCommandChannel<T : Parcelable?> : Parcelable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Parcelize
|
@Parcelize
|
||||||
internal class ChannelClosed(val index: Long) : RootCommandOneWay {
|
internal class CancelCommand(val index: Long) : RootCommandOneWay {
|
||||||
override suspend fun execute() = error("Internal implementation")
|
override suspend fun execute() = error("Internal implementation")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user