Allow ordinary execute to be cancelled

This commit is contained in:
Mygod
2020-07-03 07:56:35 +08:00
parent 798275e9c9
commit 69750f6609
2 changed files with 32 additions and 21 deletions

View File

@@ -20,13 +20,16 @@ import java.util.concurrent.CountDownLatch
import kotlin.system.exitProcess
class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> Unit = { Log.w(TAG, it) }) {
private sealed class Callback {
private sealed class Callback(protected val server: RootServer, protected val index: Long) {
var active = true
abstract fun cancel()
abstract fun shouldRemove(result: Byte): Boolean
abstract operator fun invoke(input: DataInputStream, result: Byte)
suspend fun sendClosed() = withContext(NonCancellable) { server.execute(CancelCommand(index)) }
class Ordinary(private val classLoader: ClassLoader?,
private val callback: CompletableDeferred<Parcelable?>) : Callback() {
class Ordinary(server: RootServer, index: Long, private val classLoader: ClassLoader?,
private val callback: CompletableDeferred<Parcelable?>) : Callback(server, index) {
override fun cancel() = callback.cancel()
override fun shouldRemove(result: Byte) = true
override fun invoke(input: DataInputStream, result: Byte) {
@@ -42,11 +45,8 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
}
}
class Channel(private val classLoader: ClassLoader?,
private val channel: SendChannel<Parcelable?>,
private val server: RootServer,
private val index: Long) : Callback() {
var active = true
class Channel(server: RootServer, index: Long, private val classLoader: ClassLoader?,
private val channel: SendChannel<Parcelable?>) : Callback(server, index) {
val finish: CompletableDeferred<Unit> = CompletableDeferred()
override fun cancel() = finish.cancel()
override fun shouldRemove(result: Byte) = result.toInt() != SUCCESS
@@ -70,8 +70,6 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
else -> throw IllegalArgumentException("Unexpected result $result")
}
}
suspend fun sendClosed() = server.execute(ChannelClosed(index))
}
}
@@ -165,7 +163,12 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
}
val result = input.readByte()
val callback = mutex.synchronized {
callbackLookup[index]!!.also { if (it.shouldRemove(result)) callbackLookup.remove(index) }
callbackLookup[index]!!.also {
if (it.shouldRemove(result)) {
callbackLookup.remove(index)
it.active = false
}
}
}
if (DEBUG) Log.d(TAG, "Received callback #$index: $result")
callback(input, result)
@@ -228,14 +231,20 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
@Throws(RemoteException::class)
suspend fun <T : Parcelable?> execute(command: RootCommand<T>, classLoader: ClassLoader?): T {
val future = CompletableDeferred<T>()
@Suppress("UNCHECKED_CAST")
val callback = Callback.Ordinary(this, counter, classLoader, future as CompletableDeferred<Parcelable?>)
mutex.withLock {
if (active) {
@Suppress("UNCHECKED_CAST")
callbackLookup[counter] = Callback.Ordinary(classLoader, future as CompletableDeferred<Parcelable?>)
callbackLookup[counter] = callback
sendLocked(command)
} else future.cancel()
}
try {
return future.await()
} finally {
if (callback.active) callback.sendClosed()
callback.active = false
}
}
@ExperimentalCoroutinesApi
@@ -253,7 +262,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
}
}) {
@Suppress("UNCHECKED_CAST")
val callback = Callback.Channel(classLoader, this as SendChannel<Parcelable?>, this@RootServer, counter)
val callback = Callback.Channel(this@RootServer, counter, classLoader, this as SendChannel<Parcelable?>)
mutex.withLock {
if (active) {
callbackLookup[counter] = callback
@@ -263,7 +272,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
try {
callback.finish.await()
} finally {
if (callback.active) withContext(NonCancellable) { callback.sendClosed() }
if (callback.active) callback.sendClosed()
callback.active = false
}
}
@@ -381,7 +390,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
CoroutineScope(Dispatchers.Main.immediate + job)
}
val callbackWorker = newSingleThreadContext("callbackWorker")
val channels = LongSparseArray<ReceiveChannel<Parcelable?>>()
val cancellables = LongSparseArray<() -> Unit>()
// thread safety: usage of output should be guarded by callbackWorker
val output = DataOutputStream(System.out.buffered().apply {
@@ -402,7 +411,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
val callback = counter
if (DEBUG) Log.d(TAG, "Received #$callback: $command")
when (command) {
is ChannelClosed -> channels[command.index]?.cancel()
is CancelCommand -> cancellables[command.index]?.invoke()
is RootCommandOneWay -> defaultWorker.launch {
try {
command.execute()
@@ -422,7 +431,9 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
is RootCommandChannel<*> -> defaultWorker.launch {
val result = try {
coroutineScope {
command.create(this).also { channels[callback] = it }.consumeEach { result ->
command.create(this).also {
cancellables[callback] = { it.cancel() }
}.consumeEach { result ->
withContext(callbackWorker) { output.pushResult(callback, result) }
}
};
@@ -434,7 +445,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
} catch (e: Throwable) {
{ output.pushThrowable(callback, e) }
} finally {
channels.remove(callback)
cancellables.remove(callback)
}
withContext(callbackWorker) { result() }
}

View File

@@ -39,7 +39,7 @@ interface RootCommandChannel<T : Parcelable?> : Parcelable {
}
@Parcelize
internal class ChannelClosed(val index: Long) : RootCommandOneWay {
internal class CancelCommand(val index: Long) : RootCommandOneWay {
override suspend fun execute() = error("Internal implementation")
}