Throw CancellationException directly, without wrapping in RemoteException

This commit is contained in:
Mygod
2020-07-07 06:28:45 +08:00
parent a1b076cee9
commit 90f49b3159

View File

@@ -11,7 +11,10 @@ import androidx.collection.valueIterator
import eu.chainfire.librootjava.AppProcess
import eu.chainfire.librootjava.RootJava
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.channels.consumeEach
import kotlinx.coroutines.channels.produce
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import java.io.*
@@ -20,7 +23,8 @@ import java.util.concurrent.CountDownLatch
import kotlin.system.exitProcess
class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> Unit = { Log.w(TAG, it) }) {
private sealed class Callback(protected val server: RootServer, protected val index: Long) {
private sealed class Callback(private val server: RootServer, private val index: Long,
protected val classLoader: ClassLoader?) {
var active = true
abstract fun cancel()
@@ -28,25 +32,45 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
abstract operator fun invoke(input: DataInputStream, result: Byte)
suspend fun sendClosed() = withContext(NonCancellable) { server.execute(CancelCommand(index)) }
class Ordinary(server: RootServer, index: Long, private val classLoader: ClassLoader?,
private val callback: CompletableDeferred<Parcelable?>) : Callback(server, index) {
private fun initException(targetClass: Class<*>, message: String): Throwable {
var targetClass = targetClass
while (true) {
try {
// try to find a message constructor
return targetClass.getDeclaredConstructor(String::class.java).newInstance(message) as Throwable
} catch (_: ReflectiveOperationException) { }
targetClass = targetClass.superclass
}
}
private fun makeRemoteException(cause: Throwable, message: String? = null) =
if (cause is CancellationException) cause else RemoteException(message).initCause(cause)
protected fun DataInputStream.readException(result: Byte) = when (result.toInt()) {
EX_GENERIC -> {
val message = readUTF()
val name = message.split(':', limit = 2)[0]
makeRemoteException(initException(try {
classLoader?.loadClass(name)
} catch (_: ClassNotFoundException) {
null
} ?: Class.forName(name), message), message)
}
EX_PARCELABLE -> makeRemoteException(readParcelable<Parcelable>(classLoader) as Throwable)
EX_SERIALIZABLE -> makeRemoteException(readSerializable(classLoader) as Throwable)
else -> throw IllegalArgumentException("Unexpected result $result")
}
class Ordinary(server: RootServer, index: Long, classLoader: ClassLoader?,
private val callback: CompletableDeferred<Parcelable?>) : Callback(server, index, classLoader) {
override fun cancel() = callback.cancel()
override fun shouldRemove(result: Byte) = true
override fun invoke(input: DataInputStream, result: Byte) {
when (result.toInt()) {
SUCCESS -> callback.complete(input.readParcelable(classLoader))
EX_GENERIC -> callback.completeExceptionally(RemoteException(input.readUTF()))
EX_PARCELABLE -> callback.completeExceptionally(RemoteException().initCause(
input.readParcelable<Parcelable>(classLoader) as Throwable))
EX_SERIALIZABLE -> callback.completeExceptionally(RemoteException().initCause(
input.readSerializable(classLoader) as Throwable))
else -> throw IllegalArgumentException("Unexpected result $result")
}
if (result.toInt() == SUCCESS) callback.complete(input.readParcelable(classLoader))
else callback.completeExceptionally(input.readException(result))
}
}
class Channel(server: RootServer, index: Long, private val classLoader: ClassLoader?,
private val channel: SendChannel<Parcelable?>) : Callback(server, index) {
class Channel(server: RootServer, index: Long, classLoader: ClassLoader?,
private val channel: SendChannel<Parcelable?>) : Callback(server, index, classLoader) {
val finish: CompletableDeferred<Unit> = CompletableDeferred()
override fun cancel() = finish.cancel()
override fun shouldRemove(result: Byte) = result.toInt() != SUCCESS
@@ -61,13 +85,8 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U
finish.completeExceptionally(closed)
return
})
EX_GENERIC -> finish.completeExceptionally(RemoteException(input.readUTF()))
EX_PARCELABLE -> finish.completeExceptionally(RemoteException().initCause(
input.readParcelable<Parcelable>(classLoader) as Throwable))
EX_SERIALIZABLE -> finish.completeExceptionally(RemoteException().initCause(
input.readSerializable(classLoader) as Throwable))
CHANNEL_CONSUMED -> finish.complete(Unit)
else -> throw IllegalArgumentException("Unexpected result $result")
else -> finish.completeExceptionally(input.readException(result))
}
}
}