Allow ordinary execute to be cancelled
This commit is contained in:
@@ -20,13 +20,16 @@ import java.util.concurrent.CountDownLatch
|
||||
import kotlin.system.exitProcess
|
||||
|
||||
class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> Unit = { Log.w(TAG, it) }) {
|
||||
private sealed class Callback {
|
||||
private sealed class Callback(protected val server: RootServer, protected val index: Long) {
|
||||
var active = true
|
||||
|
||||
abstract fun cancel()
|
||||
abstract fun shouldRemove(result: Byte): Boolean
|
||||
abstract operator fun invoke(input: DataInputStream, result: Byte)
|
||||
suspend fun sendClosed() = withContext(NonCancellable) { server.execute(CancelCommand(index)) }
|
||||
|
||||
class Ordinary(private val classLoader: ClassLoader?,
|
||||
private val callback: CompletableDeferred<Parcelable?>) : Callback() {
|
||||
class Ordinary(server: RootServer, index: Long, private val classLoader: ClassLoader?,
|
||||
private val callback: CompletableDeferred<Parcelable?>) : Callback(server, index) {
|
||||
override fun cancel() = callback.cancel()
|
||||
override fun shouldRemove(result: Byte) = true
|
||||
override fun invoke(input: DataInputStream, result: Byte) {
|
||||
@@ -42,11 +45,8 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
||||
}
|
||||
}
|
||||
|
||||
class Channel(private val classLoader: ClassLoader?,
|
||||
private val channel: SendChannel<Parcelable?>,
|
||||
private val server: RootServer,
|
||||
private val index: Long) : Callback() {
|
||||
var active = true
|
||||
class Channel(server: RootServer, index: Long, private val classLoader: ClassLoader?,
|
||||
private val channel: SendChannel<Parcelable?>) : Callback(server, index) {
|
||||
val finish: CompletableDeferred<Unit> = CompletableDeferred()
|
||||
override fun cancel() = finish.cancel()
|
||||
override fun shouldRemove(result: Byte) = result.toInt() != SUCCESS
|
||||
@@ -70,8 +70,6 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
||||
else -> throw IllegalArgumentException("Unexpected result $result")
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun sendClosed() = server.execute(ChannelClosed(index))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,7 +163,12 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
||||
}
|
||||
val result = input.readByte()
|
||||
val callback = mutex.synchronized {
|
||||
callbackLookup[index]!!.also { if (it.shouldRemove(result)) callbackLookup.remove(index) }
|
||||
callbackLookup[index]!!.also {
|
||||
if (it.shouldRemove(result)) {
|
||||
callbackLookup.remove(index)
|
||||
it.active = false
|
||||
}
|
||||
}
|
||||
}
|
||||
if (DEBUG) Log.d(TAG, "Received callback #$index: $result")
|
||||
callback(input, result)
|
||||
@@ -228,14 +231,20 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
||||
@Throws(RemoteException::class)
|
||||
suspend fun <T : Parcelable?> execute(command: RootCommand<T>, classLoader: ClassLoader?): T {
|
||||
val future = CompletableDeferred<T>()
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val callback = Callback.Ordinary(this, counter, classLoader, future as CompletableDeferred<Parcelable?>)
|
||||
mutex.withLock {
|
||||
if (active) {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
callbackLookup[counter] = Callback.Ordinary(classLoader, future as CompletableDeferred<Parcelable?>)
|
||||
callbackLookup[counter] = callback
|
||||
sendLocked(command)
|
||||
} else future.cancel()
|
||||
}
|
||||
try {
|
||||
return future.await()
|
||||
} finally {
|
||||
if (callback.active) callback.sendClosed()
|
||||
callback.active = false
|
||||
}
|
||||
}
|
||||
|
||||
@ExperimentalCoroutinesApi
|
||||
@@ -253,7 +262,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
||||
}
|
||||
}) {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val callback = Callback.Channel(classLoader, this as SendChannel<Parcelable?>, this@RootServer, counter)
|
||||
val callback = Callback.Channel(this@RootServer, counter, classLoader, this as SendChannel<Parcelable?>)
|
||||
mutex.withLock {
|
||||
if (active) {
|
||||
callbackLookup[counter] = callback
|
||||
@@ -263,7 +272,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
||||
try {
|
||||
callback.finish.await()
|
||||
} finally {
|
||||
if (callback.active) withContext(NonCancellable) { callback.sendClosed() }
|
||||
if (callback.active) callback.sendClosed()
|
||||
callback.active = false
|
||||
}
|
||||
}
|
||||
@@ -381,7 +390,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
||||
CoroutineScope(Dispatchers.Main.immediate + job)
|
||||
}
|
||||
val callbackWorker = newSingleThreadContext("callbackWorker")
|
||||
val channels = LongSparseArray<ReceiveChannel<Parcelable?>>()
|
||||
val cancellables = LongSparseArray<() -> Unit>()
|
||||
|
||||
// thread safety: usage of output should be guarded by callbackWorker
|
||||
val output = DataOutputStream(System.out.buffered().apply {
|
||||
@@ -402,7 +411,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
||||
val callback = counter
|
||||
if (DEBUG) Log.d(TAG, "Received #$callback: $command")
|
||||
when (command) {
|
||||
is ChannelClosed -> channels[command.index]?.cancel()
|
||||
is CancelCommand -> cancellables[command.index]?.invoke()
|
||||
is RootCommandOneWay -> defaultWorker.launch {
|
||||
try {
|
||||
command.execute()
|
||||
@@ -422,7 +431,9 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
||||
is RootCommandChannel<*> -> defaultWorker.launch {
|
||||
val result = try {
|
||||
coroutineScope {
|
||||
command.create(this).also { channels[callback] = it }.consumeEach { result ->
|
||||
command.create(this).also {
|
||||
cancellables[callback] = { it.cancel() }
|
||||
}.consumeEach { result ->
|
||||
withContext(callbackWorker) { output.pushResult(callback, result) }
|
||||
}
|
||||
};
|
||||
@@ -434,7 +445,7 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
|
||||
} catch (e: Throwable) {
|
||||
{ output.pushThrowable(callback, e) }
|
||||
} finally {
|
||||
channels.remove(callback)
|
||||
cancellables.remove(callback)
|
||||
}
|
||||
withContext(callbackWorker) { result() }
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ interface RootCommandChannel<T : Parcelable?> : Parcelable {
|
||||
}
|
||||
|
||||
@Parcelize
|
||||
internal class ChannelClosed(val index: Long) : RootCommandOneWay {
|
||||
internal class CancelCommand(val index: Long) : RootCommandOneWay {
|
||||
override suspend fun execute() = error("Internal implementation")
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user