diff --git a/mobile/src/main/java/be/mygod/librootkotlinx/RootServer.kt b/mobile/src/main/java/be/mygod/librootkotlinx/RootServer.kt index e2172e82..eaff68cc 100644 --- a/mobile/src/main/java/be/mygod/librootkotlinx/RootServer.kt +++ b/mobile/src/main/java/be/mygod/librootkotlinx/RootServer.kt @@ -35,7 +35,9 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U SUCCESS -> callback.complete(input.readParcelable(classLoader)) EX_GENERIC -> callback.completeExceptionally(RemoteException(input.readUTF())) EX_PARCELABLE -> callback.completeExceptionally(RemoteException().initCause( - input.readParcelable(classLoader) as Throwable?)) + input.readParcelable(classLoader) as Throwable)) + EX_SERIALIZABLE -> callback.completeExceptionally(RemoteException().initCause( + input.readSerializable(classLoader) as Throwable)) else -> throw IllegalArgumentException("Unexpected result $result") } } @@ -62,7 +64,9 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U }) EX_GENERIC -> finish.completeExceptionally(RemoteException(input.readUTF())) EX_PARCELABLE -> finish.completeExceptionally(RemoteException().initCause( - input.readParcelable(classLoader) as Throwable?)) + input.readParcelable(classLoader) as Throwable)) + EX_SERIALIZABLE -> finish.completeExceptionally(RemoteException().initCause( + input.readSerializable(classLoader) as Throwable)) CHANNEL_CONSUMED -> finish.complete(Unit) else -> throw IllegalArgumentException("Unexpected result $result") } @@ -294,17 +298,24 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U private const val SUCCESS = 0 private const val EX_GENERIC = 1 private const val EX_PARCELABLE = 2 + private const val EX_SERIALIZABLE = 4 private const val CHANNEL_CONSUMED = 3 + private fun DataInputStream.readByteArray() = ByteArray(readInt()).also { readFully(it) } + private inline fun DataInputStream.readParcelable( - classLoader: ClassLoader? = T::class.java.classLoader - ) = ByteArray(readInt()).also { readFully(it) }.toParcelable(classLoader) + classLoader: ClassLoader? = T::class.java.classLoader) = readByteArray().toParcelable(classLoader) private fun DataOutputStream.writeParcelable(data: Parcelable?, parcelableFlags: Int = 0) { val bytes = data.toByteArray(parcelableFlags) writeInt(bytes.size) write(bytes) } + private fun DataInputStream.readSerializable(classLoader: ClassLoader?) = + object : ObjectInputStream(ByteArrayInputStream(readByteArray())) { + override fun resolveClass(desc: ObjectStreamClass) = Class.forName(desc.name, false, classLoader) + }.readObject() + private inline fun Mutex.synchronized(crossinline block: () -> T): T = runBlocking { withLock { block() } } @@ -324,7 +335,14 @@ class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> U if (e is Parcelable) { writeByte(EX_PARCELABLE) writeParcelable(e) - } else { + } else try { + val bytes = ByteArrayOutputStream().apply { + ObjectOutputStream(this).use { it.writeObject(e) } + }.toByteArray() + writeByte(EX_SERIALIZABLE) + writeInt(bytes.size) + write(bytes) + } catch (_: NotSerializableException) { writeByte(EX_GENERIC) writeUTF(StringWriter().also { e.printStackTrace(PrintWriter(it))