package be.mygod.librootkotlinx import android.content.Context import android.os.Build import android.os.Looper import android.os.Parcelable import android.os.RemoteException import android.system.Os import android.util.Base64 import android.util.Log import androidx.collection.LongSparseArray import androidx.collection.set import androidx.collection.valueIterator import eu.chainfire.librootjava.AppProcess import eu.chainfire.librootjava.Logger import eu.chainfire.librootjava.RootJava import kotlinx.coroutines.* import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.channels.consumeEach import kotlinx.coroutines.channels.produce import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import java.io.* import java.security.MessageDigest import java.util.* import java.util.concurrent.CountDownLatch import kotlin.system.exitProcess class RootServer @JvmOverloads constructor(private val warnLogger: (String) -> Unit = { Log.w(TAG, it) }) { private sealed class Callback(private val server: RootServer, private val index: Long, protected val classLoader: ClassLoader?) { var active = true abstract fun cancel() abstract fun shouldRemove(result: Byte): Boolean abstract operator fun invoke(input: DataInputStream, result: Byte) suspend fun sendClosed() = withContext(NonCancellable) { server.execute(CancelCommand(index)) } private fun initException(targetClass: Class<*>, message: String): Throwable { var targetClass = targetClass while (true) { try { // try to find a message constructor return targetClass.getDeclaredConstructor(String::class.java).newInstance(message) as Throwable } catch (_: ReflectiveOperationException) { } targetClass = targetClass.superclass } } private fun makeRemoteException(cause: Throwable, message: String? = null) = if (cause is CancellationException) cause else RemoteException(message).initCause(cause) protected fun DataInputStream.readException(result: Byte) = when (result.toInt()) { EX_GENERIC -> { val message = readUTF() val name = message.split(':', limit = 2)[0] makeRemoteException(initException(try { classLoader?.loadClass(name) } catch (_: ClassNotFoundException) { null } ?: Class.forName(name), message), message) } EX_PARCELABLE -> makeRemoteException(readParcelable(classLoader) as Throwable) EX_SERIALIZABLE -> makeRemoteException(readSerializable(classLoader) as Throwable) else -> throw IllegalArgumentException("Unexpected result $result") } class Ordinary(server: RootServer, index: Long, classLoader: ClassLoader?, private val callback: CompletableDeferred) : Callback(server, index, classLoader) { override fun cancel() = callback.cancel() override fun shouldRemove(result: Byte) = true override fun invoke(input: DataInputStream, result: Byte) { if (result.toInt() == SUCCESS) callback.complete(input.readParcelable(classLoader)) else callback.completeExceptionally(input.readException(result)) } } class Channel(server: RootServer, index: Long, classLoader: ClassLoader?, private val channel: SendChannel) : Callback(server, index, classLoader) { val finish: CompletableDeferred = CompletableDeferred() override fun cancel() = finish.cancel() override fun shouldRemove(result: Byte) = result.toInt() != SUCCESS override fun invoke(input: DataInputStream, result: Byte) { when (result.toInt()) { // the channel we are supporting should never block SUCCESS -> check(try { channel.offer(input.readParcelable(classLoader)) } catch (closed: Throwable) { active = false GlobalScope.launch(Dispatchers.Unconfined) { sendClosed() } finish.completeExceptionally(closed) return }) CHANNEL_CONSUMED -> finish.complete(Unit) else -> finish.completeExceptionally(input.readException(result)) } } } } private lateinit var process: Process /** * Thread safety: needs to be protected by mutex. */ private lateinit var output: DataOutputStream @Volatile var active = false private var counter = 0L private lateinit var callbackListenerExit: Deferred private val callbackLookup = LongSparseArray() private val mutex = Mutex() /** * If we encountered unexpected output from stderr during initialization, its content will be stored here. * * It is advised to read this after initializing the instance. */ fun readUnexpectedStderr(): String? { if (!this::process.isInitialized) return null var available = process.errorStream.available() return if (available <= 0) null else String(ByteArrayOutputStream().apply { while (available > 0) { val bytes = ByteArray(available) val len = process.errorStream.read(bytes) if (len < 0) throw EOFException() // should not happen write(bytes, 0, len) available = process.errorStream.available() } }.toByteArray()) } private fun BufferedReader.lookForToken(token: String) { while (true) { val line = readLine() ?: throw EOFException() if (line.endsWith(token)) { val extraLength = line.length - token.length if (extraLength > 0) warnLogger(line.substring(0, extraLength)) break } warnLogger(line) } } @Suppress("BlockingMethodInNonBlockingContext") private suspend fun doInit(context: Context, niceName: String) { val token2 = UUID.randomUUID().toString() val list = coroutineScope { val init = async(start = CoroutineStart.LAZY) { try { process = ProcessBuilder("su").start() val token1 = UUID.randomUUID().toString() val writer = DataOutputStream(process.outputStream.buffered()) writer.writeBytes("echo $token1\n") writer.flush() val reader = process.inputStream.bufferedReader() reader.lookForToken(token1) if (isDebugEnabled) Log.d(TAG, "Root shell initialized") reader to writer } catch (e: Exception) { throw NoShellException(e) } } val launchString = async(start = CoroutineStart.LAZY) { val appProcess = AppProcess.getAppProcess() val relocated = if (Build.VERSION.SDK_INT < 29) { val target = File(context.codeCacheDir, MessageDigest.getInstance("SHA-256").run { update(appProcess.toByteArray()) Base64.encodeToString(digest(), Base64.NO_PADDING or Base64.NO_WRAP or Base64.URL_SAFE) }) if (!target.canExecute()) { // copy file to local cache to reset xattr/SELinux context File(appProcess).copyTo(target, true) Os.chmod(target.absolutePath, 0b111_101_101) } check(target.canExecute()) target.absolutePath } else appProcess RootJava.getLaunchString(context.packageCodePath + " exec", // hack: plugging in exec RootServer::class.java.name, relocated, AppProcess.guessIfAppProcessIs64Bits(appProcess), arrayOf("$token2\n"), niceName).let { result -> if (Build.VERSION.SDK_INT < 24) result // undo the patch on newer APIs to let linker do the work else result.replaceFirst(" LD_LIBRARY_PATH=", " __SUPPRESSED_LD_LIBRARY_PATH=") } } awaitAll(init, launchString) } val (reader, writer) = list[0] as Pair writer.writeBytes(list[1] as String) writer.flush() reader.lookForToken(token2) // wait for ready signal output = writer require(!active) active = true if (isDebugEnabled) Log.d(TAG, "Root server initialized") } private fun callbackSpin() { val input = DataInputStream(process.inputStream.buffered()) while (active) { val index = try { input.readLong() } catch (_: EOFException) { break } val result = input.readByte() val callback = mutex.synchronized { callbackLookup[index]!!.also { if (it.shouldRemove(result)) { callbackLookup.remove(index) it.active = false } } } if (isDebugEnabled) Log.d(TAG, "Received callback #$index: $result") callback(input, result) } } /** * Initialize a RootServer synchronously, can throw a lot of exceptions. * * @param context Any [Context] from the app. * @param niceName Name to call the rooted Java process. */ suspend fun init(context: Context, niceName: String = "${context.packageName}:root") { val future = CompletableDeferred() callbackListenerExit = GlobalScope.async(Dispatchers.IO) { try { doInit(context, niceName) future.complete(Unit) } catch (e: Throwable) { future.completeExceptionally(e) return@async } val errorReader = async(Dispatchers.IO) { try { process.errorStream.bufferedReader().useLines { seq -> for (line in seq) warnLogger(line) } } catch (_: IOException) { } } try { callbackSpin() if (active) throw RemoteException("Root process exited unexpectedly") } catch (e: Throwable) { process.destroy() throw e } finally { if (isDebugEnabled) Log.d(TAG, "Waiting for exit") errorReader.await() process.waitFor() withContext(NonCancellable) { closeInternal(true) } } } future.await() } /** * Convenience function that initializes and also logs warnings to [Log]. */ suspend fun initAndroidLog(context: Context, niceName: String = "${context.packageName}:root") = try { init(context, niceName) } finally { readUnexpectedStderr()?.let { Log.e(TAG, it) } } /** * Caller should check for active. */ private fun sendLocked(command: Parcelable) { output.writeParcelable(command) output.flush() if (isDebugEnabled) Log.d(TAG, "Sent #$counter: $command") counter++ } suspend fun execute(command: RootCommandOneWay) = mutex.withLock { if (active) sendLocked(command) } @Throws(RemoteException::class) suspend inline fun execute(command: RootCommand) = execute(command, T::class.java.classLoader) @Throws(RemoteException::class) suspend fun execute(command: RootCommand, classLoader: ClassLoader?): T { val future = CompletableDeferred() @Suppress("UNCHECKED_CAST") val callback = Callback.Ordinary(this, counter, classLoader, future as CompletableDeferred) mutex.withLock { if (active) { callbackLookup[counter] = callback sendLocked(command) } else future.cancel() } try { return future.await() } finally { if (callback.active) callback.sendClosed() callback.active = false } } @ExperimentalCoroutinesApi @Throws(RemoteException::class) inline fun create(command: RootCommandChannel, scope: CoroutineScope) = create(command, scope, T::class.java.classLoader) @ExperimentalCoroutinesApi @Throws(RemoteException::class) fun create(command: RootCommandChannel, scope: CoroutineScope, classLoader: ClassLoader?) = scope.produce( capacity = command.capacity.also { when (it) { Channel.UNLIMITED, Channel.CONFLATED -> { } else -> throw IllegalArgumentException("Unsupported channel capacity $it") } }) { @Suppress("UNCHECKED_CAST") val callback = Callback.Channel(this@RootServer, counter, classLoader, this as SendChannel) mutex.withLock { if (active) { callbackLookup[counter] = callback sendLocked(command) } else callback.finish.cancel() } try { callback.finish.await() } finally { if (callback.active) callback.sendClosed() callback.active = false } } private suspend fun closeInternal(fromWorker: Boolean = false) = mutex.withLock { if (active) { active = false if (isDebugEnabled) Log.d(TAG, if (fromWorker) "Shutting down from worker" else "Shutting down from client") try { sendLocked(Shutdown()) output.close() process.outputStream.close() } catch (e: IOException) { Log.i(TAG, "send Shutdown failed", e) } if (isDebugEnabled) Log.d(TAG, "Client closed") } if (fromWorker) { for (callback in callbackLookup.valueIterator()) callback.cancel() callbackLookup.clear() } } /** * Shutdown the instance gracefully. */ suspend fun close() { closeInternal() callbackListenerExit.await() } companion object { /** * If set to true, debug information will be printed to logcat. */ @JvmStatic var isDebugEnabled = false set(value) { field = value Logger.setDebugLogging(value) } private const val TAG = "RootServer" private const val SUCCESS = 0 private const val EX_GENERIC = 1 private const val EX_PARCELABLE = 2 private const val EX_SERIALIZABLE = 4 private const val CHANNEL_CONSUMED = 3 private fun DataInputStream.readByteArray() = ByteArray(readInt()).also { readFully(it) } private inline fun DataInputStream.readParcelable( classLoader: ClassLoader? = T::class.java.classLoader) = readByteArray().toParcelable(classLoader) private fun DataOutputStream.writeParcelable(data: Parcelable?, parcelableFlags: Int = 0) { val bytes = data.toByteArray(parcelableFlags) writeInt(bytes.size) write(bytes) } private fun DataInputStream.readSerializable(classLoader: ClassLoader?) = object : ObjectInputStream(ByteArrayInputStream(readByteArray())) { override fun resolveClass(desc: ObjectStreamClass) = Class.forName(desc.name, false, classLoader) }.readObject() private inline fun Mutex.synchronized(crossinline block: () -> T): T = runBlocking { withLock { block() } } @JvmStatic fun main(args: Array) { Thread.setDefaultUncaughtExceptionHandler { thread, throwable -> Log.e(TAG, "Uncaught exception from $thread", throwable) exitProcess(1) } rootMain(args) exitProcess(0) // there might be other non-daemon threads } private fun DataOutputStream.pushThrowable(callback: Long, e: Throwable) { writeLong(callback) if (e is Parcelable) { writeByte(EX_PARCELABLE) writeParcelable(e) } else try { val bytes = ByteArrayOutputStream().apply { ObjectOutputStream(this).use { it.writeObject(e) } }.toByteArray() writeByte(EX_SERIALIZABLE) writeInt(bytes.size) write(bytes) } catch (_: NotSerializableException) { writeByte(EX_GENERIC) writeUTF(StringWriter().also { e.printStackTrace(PrintWriter(it)) }.toString()) } flush() } private fun DataOutputStream.pushResult(callback: Long, result: Parcelable?) { writeLong(callback) writeByte(SUCCESS) writeParcelable(result) flush() } private fun rootMain(args: Array) { require(args.isNotEmpty()) if (Build.VERSION.SDK_INT < 24) RootJava.restoreOriginalLdLibraryPath() val mainInitialized = CountDownLatch(1) val main = Thread({ @Suppress("DEPRECATION") Looper.prepareMainLooper() mainInitialized.countDown() Looper.loop() }, "main") main.start() val job = Job() val defaultWorker by lazy { mainInitialized.await() CoroutineScope(Dispatchers.Main.immediate + job) } val callbackWorker = newSingleThreadContext("callbackWorker") val cancellables = LongSparseArray<() -> Unit>() // thread safety: usage of output should be guarded by callbackWorker val output = DataOutputStream(System.out.buffered().apply { val writer = writer() writer.appendln(args[0]) // echo ready signal writer.flush() }) // thread safety: usage of input should be in main thread val input = DataInputStream(System.`in`.buffered()) var counter = 0L if (isDebugEnabled) Log.d(TAG, "Server entering main loop") loop@ while (true) { val command = try { input.readParcelable(RootServer::class.java.classLoader) } catch (_: EOFException) { break } val callback = counter if (isDebugEnabled) Log.d(TAG, "Received #$callback: $command") when (command) { is CancelCommand -> cancellables[command.index]?.invoke() is RootCommandOneWay -> defaultWorker.launch { try { command.execute() } catch (e: Throwable) { Log.e(command.javaClass.simpleName, "Unexpected exception in RootCommandOneWay", e) } } is RootCommand<*> -> { val commandJob = Job() cancellables[callback] = { commandJob.cancel() } defaultWorker.launch(commandJob) { val result = try { val result = command.execute() withContext(callbackWorker + NonCancellable) { output.pushResult(callback, result) } } catch (e: Throwable) { withContext(callbackWorker + NonCancellable) { output.pushThrowable(callback, e) } } finally { cancellables.remove(callback) } } } is RootCommandChannel<*> -> defaultWorker.launch { try { coroutineScope { command.create(this).also { cancellables[callback] = { it.cancel() } }.consumeEach { result -> withContext(callbackWorker) { output.pushResult(callback, result) } } } @Suppress("BlockingMethodInNonBlockingContext") withContext(callbackWorker + NonCancellable) { output.writeByte(CHANNEL_CONSUMED) output.writeLong(callback) output.flush() } } catch (e: Throwable) { withContext(callbackWorker + NonCancellable) { output.pushThrowable(callback, e) } } finally { cancellables.remove(callback) } } is Shutdown -> break@loop else -> throw IllegalArgumentException("Unrecognized input: $command") } counter++ } job.cancel() if (isDebugEnabled) Log.d(TAG, "Clean up initiated before exit. Jobs: ${job.children.joinToString()}") if (runBlocking { withTimeoutOrNull(5000) { job.join() } } == null) { Log.w(TAG, "Clean up timeout: ${job.children.joinToString()}") } else if (isDebugEnabled) Log.d(TAG, "Clean up finished, exiting") } } }