diff --git a/mobile/src/main/java/be/mygod/librootkotlinx/RootServer.kt b/mobile/src/main/java/be/mygod/librootkotlinx/RootServer.kt index 6ea67d6a..e06cbae0 100644 --- a/mobile/src/main/java/be/mygod/librootkotlinx/RootServer.kt +++ b/mobile/src/main/java/be/mygod/librootkotlinx/RootServer.kt @@ -11,7 +11,10 @@ import androidx.collection.valueIterator import eu.chainfire.librootjava.AppProcess import eu.chainfire.librootjava.RootJava import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.channels.consumeEach +import kotlinx.coroutines.channels.produce import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import java.io.* @@ -20,7 +23,8 @@ import java.util.concurrent.CountDownLatch import kotlin.system.exitProcess class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> Unit = { Log.w(TAG, it) }) { - private sealed class Callback(protected val server: RootServer, protected val index: Long) { + private sealed class Callback(private val server: RootServer, private val index: Long, + protected val classLoader: ClassLoader?) { var active = true abstract fun cancel() @@ -28,25 +32,45 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U abstract operator fun invoke(input: DataInputStream, result: Byte) suspend fun sendClosed() = withContext(NonCancellable) { server.execute(CancelCommand(index)) } - class Ordinary(server: RootServer, index: Long, private val classLoader: ClassLoader?, - private val callback: CompletableDeferred) : Callback(server, index) { + private fun initException(targetClass: Class<*>, message: String): Throwable { + var targetClass = targetClass + while (true) { + try { + // try to find a message constructor + return targetClass.getDeclaredConstructor(String::class.java).newInstance(message) as Throwable + } catch (_: ReflectiveOperationException) { } + targetClass = targetClass.superclass + } + } + private fun makeRemoteException(cause: Throwable, message: String? = null) = + if (cause is CancellationException) cause else RemoteException(message).initCause(cause) + protected fun DataInputStream.readException(result: Byte) = when (result.toInt()) { + EX_GENERIC -> { + val message = readUTF() + val name = message.split(':', limit = 2)[0] + makeRemoteException(initException(try { + classLoader?.loadClass(name) + } catch (_: ClassNotFoundException) { + null + } ?: Class.forName(name), message), message) + } + EX_PARCELABLE -> makeRemoteException(readParcelable(classLoader) as Throwable) + EX_SERIALIZABLE -> makeRemoteException(readSerializable(classLoader) as Throwable) + else -> throw IllegalArgumentException("Unexpected result $result") + } + + class Ordinary(server: RootServer, index: Long, classLoader: ClassLoader?, + private val callback: CompletableDeferred) : Callback(server, index, classLoader) { override fun cancel() = callback.cancel() override fun shouldRemove(result: Byte) = true override fun invoke(input: DataInputStream, result: Byte) { - when (result.toInt()) { - SUCCESS -> callback.complete(input.readParcelable(classLoader)) - EX_GENERIC -> callback.completeExceptionally(RemoteException(input.readUTF())) - EX_PARCELABLE -> callback.completeExceptionally(RemoteException().initCause( - input.readParcelable(classLoader) as Throwable)) - EX_SERIALIZABLE -> callback.completeExceptionally(RemoteException().initCause( - input.readSerializable(classLoader) as Throwable)) - else -> throw IllegalArgumentException("Unexpected result $result") - } + if (result.toInt() == SUCCESS) callback.complete(input.readParcelable(classLoader)) + else callback.completeExceptionally(input.readException(result)) } } - class Channel(server: RootServer, index: Long, private val classLoader: ClassLoader?, - private val channel: SendChannel) : Callback(server, index) { + class Channel(server: RootServer, index: Long, classLoader: ClassLoader?, + private val channel: SendChannel) : Callback(server, index, classLoader) { val finish: CompletableDeferred = CompletableDeferred() override fun cancel() = finish.cancel() override fun shouldRemove(result: Byte) = result.toInt() != SUCCESS @@ -61,13 +85,8 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U finish.completeExceptionally(closed) return }) - EX_GENERIC -> finish.completeExceptionally(RemoteException(input.readUTF())) - EX_PARCELABLE -> finish.completeExceptionally(RemoteException().initCause( - input.readParcelable(classLoader) as Throwable)) - EX_SERIALIZABLE -> finish.completeExceptionally(RemoteException().initCause( - input.readSerializable(classLoader) as Throwable)) CHANNEL_CONSUMED -> finish.complete(Unit) - else -> throw IllegalArgumentException("Unexpected result $result") + else -> finish.completeExceptionally(input.readException(result)) } } }