From 69750f6609a7a32a74f2c1f53b0664988efdb708 Mon Sep 17 00:00:00 2001 From: Mygod Date: Fri, 3 Jul 2020 07:56:35 +0800 Subject: [PATCH] Allow ordinary execute to be cancelled --- .../be/mygod/librootkotlinx/RootServer.kt | 51 +++++++++++-------- .../be/mygod/librootkotlinx/ServerCommands.kt | 2 +- 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/mobile/src/main/java/be/mygod/librootkotlinx/RootServer.kt b/mobile/src/main/java/be/mygod/librootkotlinx/RootServer.kt index ae465a41..ee0e6fa4 100644 --- a/mobile/src/main/java/be/mygod/librootkotlinx/RootServer.kt +++ b/mobile/src/main/java/be/mygod/librootkotlinx/RootServer.kt @@ -20,13 +20,16 @@ import java.util.concurrent.CountDownLatch import kotlin.system.exitProcess class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> Unit = { Log.w(TAG, it) }) { - private sealed class Callback { + private sealed class Callback(protected val server: RootServer, protected val index: Long) { + var active = true + abstract fun cancel() abstract fun shouldRemove(result: Byte): Boolean abstract operator fun invoke(input: DataInputStream, result: Byte) + suspend fun sendClosed() = withContext(NonCancellable) { server.execute(CancelCommand(index)) } - class Ordinary(private val classLoader: ClassLoader?, - private val callback: CompletableDeferred) : Callback() { + class Ordinary(server: RootServer, index: Long, private val classLoader: ClassLoader?, + private val callback: CompletableDeferred) : Callback(server, index) { override fun cancel() = callback.cancel() override fun shouldRemove(result: Byte) = true override fun invoke(input: DataInputStream, result: Byte) { @@ -42,11 +45,8 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U } } - class Channel(private val classLoader: ClassLoader?, - private val channel: SendChannel, - private val server: RootServer, - private val index: Long) : Callback() { - var active = true + class Channel(server: RootServer, index: Long, private val classLoader: ClassLoader?, + private val channel: SendChannel) : Callback(server, index) { val finish: CompletableDeferred = CompletableDeferred() override fun cancel() = finish.cancel() override fun shouldRemove(result: Byte) = result.toInt() != SUCCESS @@ -70,8 +70,6 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U else -> throw IllegalArgumentException("Unexpected result $result") } } - - suspend fun sendClosed() = server.execute(ChannelClosed(index)) } } @@ -165,7 +163,12 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U } val result = input.readByte() val callback = mutex.synchronized { - callbackLookup[index]!!.also { if (it.shouldRemove(result)) callbackLookup.remove(index) } + callbackLookup[index]!!.also { + if (it.shouldRemove(result)) { + callbackLookup.remove(index) + it.active = false + } + } } if (DEBUG) Log.d(TAG, "Received callback #$index: $result") callback(input, result) @@ -228,14 +231,20 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U @Throws(RemoteException::class) suspend fun execute(command: RootCommand, classLoader: ClassLoader?): T { val future = CompletableDeferred() + @Suppress("UNCHECKED_CAST") + val callback = Callback.Ordinary(this, counter, classLoader, future as CompletableDeferred) mutex.withLock { if (active) { - @Suppress("UNCHECKED_CAST") - callbackLookup[counter] = Callback.Ordinary(classLoader, future as CompletableDeferred) + callbackLookup[counter] = callback sendLocked(command) } else future.cancel() } - return future.await() + try { + return future.await() + } finally { + if (callback.active) callback.sendClosed() + callback.active = false + } } @ExperimentalCoroutinesApi @@ -253,7 +262,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U } }) { @Suppress("UNCHECKED_CAST") - val callback = Callback.Channel(classLoader, this as SendChannel, this@RootServer, counter) + val callback = Callback.Channel(this@RootServer, counter, classLoader, this as SendChannel) mutex.withLock { if (active) { callbackLookup[counter] = callback @@ -263,7 +272,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U try { callback.finish.await() } finally { - if (callback.active) withContext(NonCancellable) { callback.sendClosed() } + if (callback.active) callback.sendClosed() callback.active = false } } @@ -381,7 +390,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U CoroutineScope(Dispatchers.Main.immediate + job) } val callbackWorker = newSingleThreadContext("callbackWorker") - val channels = LongSparseArray>() + val cancellables = LongSparseArray<() -> Unit>() // thread safety: usage of output should be guarded by callbackWorker val output = DataOutputStream(System.out.buffered().apply { @@ -402,7 +411,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U val callback = counter if (DEBUG) Log.d(TAG, "Received #$callback: $command") when (command) { - is ChannelClosed -> channels[command.index]?.cancel() + is CancelCommand -> cancellables[command.index]?.invoke() is RootCommandOneWay -> defaultWorker.launch { try { command.execute() @@ -422,7 +431,9 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U is RootCommandChannel<*> -> defaultWorker.launch { val result = try { coroutineScope { - command.create(this).also { channels[callback] = it }.consumeEach { result -> + command.create(this).also { + cancellables[callback] = { it.cancel() } + }.consumeEach { result -> withContext(callbackWorker) { output.pushResult(callback, result) } } }; @@ -434,7 +445,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U } catch (e: Throwable) { { output.pushThrowable(callback, e) } } finally { - channels.remove(callback) + cancellables.remove(callback) } withContext(callbackWorker) { result() } } diff --git a/mobile/src/main/java/be/mygod/librootkotlinx/ServerCommands.kt b/mobile/src/main/java/be/mygod/librootkotlinx/ServerCommands.kt index 14b1797e..7fbc318d 100644 --- a/mobile/src/main/java/be/mygod/librootkotlinx/ServerCommands.kt +++ b/mobile/src/main/java/be/mygod/librootkotlinx/ServerCommands.kt @@ -39,7 +39,7 @@ interface RootCommandChannel : Parcelable { } @Parcelize -internal class ChannelClosed(val index: Long) : RootCommandOneWay { +internal class CancelCommand(val index: Long) : RootCommandOneWay { override suspend fun execute() = error("Internal implementation") }