Allow ordinary execute to be cancelled

This commit is contained in:
Mygod
2020-07-03 07:56:35 +08:00
parent 798275e9c9
commit 69750f6609
2 changed files with 32 additions and 21 deletions

View File

@@ -20,13 +20,16 @@ import java.util.concurrent.CountDownLatch
import kotlin.system.exitProcess import kotlin.system.exitProcess
class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> Unit = { Log.w(TAG, it) }) { class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> Unit = { Log.w(TAG, it) }) {
private sealed class Callback { private sealed class Callback(protected val server: RootServer, protected val index: Long) {
var active = true
abstract fun cancel() abstract fun cancel()
abstract fun shouldRemove(result: Byte): Boolean abstract fun shouldRemove(result: Byte): Boolean
abstract operator fun invoke(input: DataInputStream, result: Byte) abstract operator fun invoke(input: DataInputStream, result: Byte)
suspend fun sendClosed() = withContext(NonCancellable) { server.execute(CancelCommand(index)) }
class Ordinary(private val classLoader: ClassLoader?, class Ordinary(server: RootServer, index: Long, private val classLoader: ClassLoader?,
private val callback: CompletableDeferred<Parcelable?>) : Callback() { private val callback: CompletableDeferred<Parcelable?>) : Callback(server, index) {
override fun cancel() = callback.cancel() override fun cancel() = callback.cancel()
override fun shouldRemove(result: Byte) = true override fun shouldRemove(result: Byte) = true
override fun invoke(input: DataInputStream, result: Byte) { override fun invoke(input: DataInputStream, result: Byte) {
@@ -42,11 +45,8 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
} }
} }
class Channel(private val classLoader: ClassLoader?, class Channel(server: RootServer, index: Long, private val classLoader: ClassLoader?,
private val channel: SendChannel<Parcelable?>, private val channel: SendChannel<Parcelable?>) : Callback(server, index) {
private val server: RootServer,
private val index: Long) : Callback() {
var active = true
val finish: CompletableDeferred<Unit> = CompletableDeferred() val finish: CompletableDeferred<Unit> = CompletableDeferred()
override fun cancel() = finish.cancel() override fun cancel() = finish.cancel()
override fun shouldRemove(result: Byte) = result.toInt() != SUCCESS override fun shouldRemove(result: Byte) = result.toInt() != SUCCESS
@@ -70,8 +70,6 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
else -> throw IllegalArgumentException("Unexpected result $result") else -> throw IllegalArgumentException("Unexpected result $result")
} }
} }
suspend fun sendClosed() = server.execute(ChannelClosed(index))
} }
} }
@@ -165,7 +163,12 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
} }
val result = input.readByte() val result = input.readByte()
val callback = mutex.synchronized { val callback = mutex.synchronized {
callbackLookup[index]!!.also { if (it.shouldRemove(result)) callbackLookup.remove(index) } callbackLookup[index]!!.also {
if (it.shouldRemove(result)) {
callbackLookup.remove(index)
it.active = false
}
}
} }
if (DEBUG) Log.d(TAG, "Received callback #$index: $result") if (DEBUG) Log.d(TAG, "Received callback #$index: $result")
callback(input, result) callback(input, result)
@@ -228,14 +231,20 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
@Throws(RemoteException::class) @Throws(RemoteException::class)
suspend fun <T : Parcelable?> execute(command: RootCommand<T>, classLoader: ClassLoader?): T { suspend fun <T : Parcelable?> execute(command: RootCommand<T>, classLoader: ClassLoader?): T {
val future = CompletableDeferred<T>() val future = CompletableDeferred<T>()
@Suppress("UNCHECKED_CAST")
val callback = Callback.Ordinary(this, counter, classLoader, future as CompletableDeferred<Parcelable?>)
mutex.withLock { mutex.withLock {
if (active) { if (active) {
@Suppress("UNCHECKED_CAST") callbackLookup[counter] = callback
callbackLookup[counter] = Callback.Ordinary(classLoader, future as CompletableDeferred<Parcelable?>)
sendLocked(command) sendLocked(command)
} else future.cancel() } else future.cancel()
} }
try {
return future.await() return future.await()
} finally {
if (callback.active) callback.sendClosed()
callback.active = false
}
} }
@ExperimentalCoroutinesApi @ExperimentalCoroutinesApi
@@ -253,7 +262,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
} }
}) { }) {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
val callback = Callback.Channel(classLoader, this as SendChannel<Parcelable?>, this@RootServer, counter) val callback = Callback.Channel(this@RootServer, counter, classLoader, this as SendChannel<Parcelable?>)
mutex.withLock { mutex.withLock {
if (active) { if (active) {
callbackLookup[counter] = callback callbackLookup[counter] = callback
@@ -263,7 +272,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
try { try {
callback.finish.await() callback.finish.await()
} finally { } finally {
if (callback.active) withContext(NonCancellable) { callback.sendClosed() } if (callback.active) callback.sendClosed()
callback.active = false callback.active = false
} }
} }
@@ -381,7 +390,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
CoroutineScope(Dispatchers.Main.immediate + job) CoroutineScope(Dispatchers.Main.immediate + job)
} }
val callbackWorker = newSingleThreadContext("callbackWorker") val callbackWorker = newSingleThreadContext("callbackWorker")
val channels = LongSparseArray<ReceiveChannel<Parcelable?>>() val cancellables = LongSparseArray<() -> Unit>()
// thread safety: usage of output should be guarded by callbackWorker // thread safety: usage of output should be guarded by callbackWorker
val output = DataOutputStream(System.out.buffered().apply { val output = DataOutputStream(System.out.buffered().apply {
@@ -402,7 +411,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
val callback = counter val callback = counter
if (DEBUG) Log.d(TAG, "Received #$callback: $command") if (DEBUG) Log.d(TAG, "Received #$callback: $command")
when (command) { when (command) {
is ChannelClosed -> channels[command.index]?.cancel() is CancelCommand -> cancellables[command.index]?.invoke()
is RootCommandOneWay -> defaultWorker.launch { is RootCommandOneWay -> defaultWorker.launch {
try { try {
command.execute() command.execute()
@@ -422,7 +431,9 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
is RootCommandChannel<*> -> defaultWorker.launch { is RootCommandChannel<*> -> defaultWorker.launch {
val result = try { val result = try {
coroutineScope { coroutineScope {
command.create(this).also { channels[callback] = it }.consumeEach { result -> command.create(this).also {
cancellables[callback] = { it.cancel() }
}.consumeEach { result ->
withContext(callbackWorker) { output.pushResult(callback, result) } withContext(callbackWorker) { output.pushResult(callback, result) }
} }
}; };
@@ -434,7 +445,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
} catch (e: Throwable) { } catch (e: Throwable) {
{ output.pushThrowable(callback, e) } { output.pushThrowable(callback, e) }
} finally { } finally {
channels.remove(callback) cancellables.remove(callback)
} }
withContext(callbackWorker) { result() } withContext(callbackWorker) { result() }
} }

View File

@@ -39,7 +39,7 @@ interface RootCommandChannel<T : Parcelable?> : Parcelable {
} }
@Parcelize @Parcelize
internal class ChannelClosed(val index: Long) : RootCommandOneWay { internal class CancelCommand(val index: Long) : RootCommandOneWay {
override suspend fun execute() = error("Internal implementation") override suspend fun execute() = error("Internal implementation")
} }