Refactor TTSServer

This commit is contained in:
2020-05-06 22:15:49 +02:00
parent 0d75b9d9c3
commit 2286843356
2 changed files with 78 additions and 66 deletions

View File

@@ -0,0 +1,67 @@
package io.bartek.tts
import android.content.Context
import android.speech.tts.TextToSpeech
import android.speech.tts.UtteranceProgressListener
import java.io.BufferedInputStream
import java.io.FileInputStream
import java.io.InputStream
import java.lang.RuntimeException
import java.util.*
/**
 * Result of one synthesis run.
 *
 * @property stream stream over the synthesized audio
 * @property size byte count reported for that audio
 */
data class SpeechData(
    val stream: InputStream,
    val size: Long
)
/**
 * Thin wrapper around Android's [TextToSpeech] engine that renders text into an audio stream.
 *
 * @param context context used to create the underlying engine
 * @param initListener listener handed to the [TextToSpeech] constructor
 */
class TTS(context: Context, initListener: TextToSpeech.OnInitListener) {
    private val tts = TextToSpeech(context, initListener)

    /**
     * Synthesizes [text] in [language] into a temporary `.wav` file and returns a stream over it.
     *
     * The call blocks until the engine reports completion or an error for this utterance.
     * Calls are serialized: the engine holds a single progress listener, so a concurrent call
     * would replace the listener of the one already waiting and leave it blocked forever.
     *
     * @param text text to synthesize
     * @param language locale applied to the engine before synthesizing
     * @return the audio stream together with the length of the synthesized file
     * @throws RuntimeException if the engine rejects the request or reports an error
     */
    @Synchronized
    fun performTTS(text: String, language: Locale): SpeechData {
        val file = createTempFile("tmp_tts_server", ".wav")
        try {
            val uuid = UUID.randomUUID().toString()
            val lock = Lock()
            tts.setOnUtteranceProgressListener(TTSProcessListener(uuid, lock))

            // The monitor is taken before the request is queued so the listener cannot
            // notify before this thread has started waiting.
            synchronized(lock) {
                tts.language = language
                // A rejected request never reaches the listener; waiting would hang forever.
                if (tts.synthesizeToFile(text, null, file, uuid) != TextToSpeech.SUCCESS) {
                    throw RuntimeException("TTS failed")
                }
                // NOTE(review): a bare wait() can wake spuriously; guarding it needs a
                // "finished" flag on Lock, which is outside this class.
                lock.wait()
            }

            if (!lock.success) {
                throw RuntimeException("TTS failed")
            }

            // Open the stream first, then read the length; the file is removed afterwards.
            val stream = BufferedInputStream(FileInputStream(file))
            return SpeechData(stream, file.length())
        } finally {
            // Runs on success and on every failure path, so the temp file is never left behind.
            file.delete()
        }
    }
}
/**
 * Monitor shared between a synthesis call and its progress listener.
 *
 * Extends [Object] so `wait()` and `notifyAll()` can be called on it.
 *
 * @property success outcome flag written by the listener; starts as `false`
 */
private data class Lock(
    var success: Boolean = false
) : Object()
/**
 * Progress listener that records the outcome of a single utterance on [lock] and wakes
 * every thread waiting on it. Callbacks for any other utterance id are ignored.
 *
 * @param uuid id of the utterance this listener reacts to
 * @param lock monitor that receives the result and the notification
 */
private class TTSProcessListener(
    private val uuid: String,
    private val lock: Lock
) : UtteranceProgressListener() {

    override fun onStart(utteranceId: String?) {}

    override fun onDone(utteranceId: String?) = finish(utteranceId, succeeded = true)

    override fun onError(utteranceId: String?) = finish(utteranceId, succeeded = false)

    /** Stores [succeeded] on the lock and notifies waiters, but only for our own utterance. */
    private fun finish(utteranceId: String?, succeeded: Boolean) {
        if (utteranceId != uuid) return
        synchronized(lock) {
            lock.success = succeeded
            lock.notifyAll()
        }
    }
}

View File

@@ -2,57 +2,24 @@ package io.bartek.web
import android.content.Context
import android.speech.tts.TextToSpeech
import android.speech.tts.UtteranceProgressListener
import android.widget.Toast
import fi.iki.elonen.NanoHTTPD
import fi.iki.elonen.NanoHTTPD.Response.Status.*
import io.bartek.R
import io.bartek.tts.TTS
import org.json.JSONObject
import java.io.BufferedInputStream
import java.io.FileInputStream
import java.io.InputStream
import java.lang.RuntimeException
import java.util.*
import kotlin.collections.HashMap
// NOTE(review): the two lines below declare the same class twice (public, then private).
// They look like the before/after lines of a diff; only one can stay in a compilable
// file. TODO confirm which one is intended.
/** Text to synthesize and the locale to speak it in, as extracted from one request. */
data class TTSRequestData(val text: String, val language: Locale)
private data class TTSRequestData(val text: String, val language: Locale)
/**
 * Monitor object used to hand a synthesis result from the progress listener to the waiting
 * thread. Extends [Object] so `wait()` and `notifyAll()` are available.
 *
 * @property success result flag, `false` until a listener sets it
 */
data class Lock(
    var success: Boolean = false
) : Object()
/**
 * Audio produced by a synthesis call.
 *
 * @property stream stream over the audio data
 * @property size length of the audio data in bytes
 */
data class SpeechData(
    val stream: InputStream,
    val size: Long
)
/**
 * Listener that reports the result of one utterance through [lock].
 *
 * Only callbacks whose utterance id equals [uuid] have any effect: the outcome is stored in
 * `lock.success` and all threads waiting on the lock are notified.
 *
 * @param uuid id of the utterance to watch
 * @param lock monitor that carries the result back to the waiting thread
 */
class TTSProcessListener(
    private val uuid: String,
    private val lock: Lock
) : UtteranceProgressListener() {

    override fun onStart(utteranceId: String?) {}

    override fun onDone(utteranceId: String?) {
        complete(utteranceId, true)
    }

    override fun onError(utteranceId: String?) {
        complete(utteranceId, false)
    }

    /** Publishes [outcome] on the lock when [utteranceId] is the watched one. */
    private fun complete(utteranceId: String?, outcome: Boolean) {
        if (utteranceId != uuid) {
            return
        }
        synchronized(lock) {
            lock.success = outcome
            lock.notifyAll()
        }
    }
}
class TTSServer(port: Int, private val context: Context) : NanoHTTPD(port),
TextToSpeech.OnInitListener {
private val tts = TextToSpeech(context, this)
private val tts = TTS(context, this)
override fun serve(session: IHTTPSession?): Response {
try {
val (text, language) = getRequestData(validateRequest(session))
val (stream, size) = performTTS(text, language)
return newFixedLengthResponse(OK, "audio/x-wav", stream, size)
return tryToServe(session)
} catch (e: ResponseException) {
throw e
} catch (e: Exception) {
@@ -60,6 +27,12 @@ class TTSServer(port: Int, private val context: Context) : NanoHTTPD(port),
}
}
/**
 * Validates [session], extracts the request data, synthesizes the speech and wraps the
 * resulting audio in a fixed-length `audio/x-wav` response.
 */
private fun tryToServe(session: IHTTPSession?): Response {
    val request = getRequestData(validateRequest(session))
    val speech = tts.performTTS(request.text, request.language)
    return newFixedLengthResponse(OK, "audio/x-wav", speech.stream, speech.size)
}
private fun validateRequest(session: IHTTPSession?): IHTTPSession {
if (session == null) {
throw ResponseException(BAD_REQUEST, "")
@@ -81,40 +54,12 @@ class TTSServer(port: Int, private val context: Context) : NanoHTTPD(port),
return session
}
/**
 * Synthesizes [text] in [language] into a temporary `.wav` file and returns a stream over it.
 *
 * Blocks until the engine reports completion or an error for this utterance. Calls are
 * serialized because the engine keeps a single progress listener: a concurrent call would
 * replace the listener of the call already waiting and leave it blocked forever.
 *
 * @param text text to synthesize
 * @param language locale applied to the engine before synthesizing
 * @return the audio stream together with the length of the synthesized file
 * @throws RuntimeException if the engine rejects the request or reports an error
 */
@Synchronized
private fun performTTS(text: String, language: Locale): SpeechData {
    val file = createTempFile("tmp_tts_server", ".wav")
    try {
        val uuid = UUID.randomUUID().toString()
        val lock = Lock()
        tts.setOnUtteranceProgressListener(TTSProcessListener(uuid, lock))

        // Hold the monitor while queuing so the listener cannot notify before we wait.
        synchronized(lock) {
            tts.language = language
            // A rejected request never reaches the listener; waiting would hang forever.
            if (tts.synthesizeToFile(text, null, file, uuid) != TextToSpeech.SUCCESS) {
                throw RuntimeException("TTS failed")
            }
            lock.wait()
        }

        if (!lock.success) {
            throw RuntimeException("TTS failed")
        }

        // Open the stream first, then read the length; the file is removed afterwards.
        val stream = BufferedInputStream(FileInputStream(file))
        return SpeechData(stream, file.length())
    } finally {
        // Runs on success and on every failure path, so the temp file is never left behind.
        file.delete()
    }
}
/**
 * Reads the JSON body of [session] and extracts the text to speak and the target locale.
 *
 * The body is expected under the `postData` key filled in by `parseBody`; a missing body is
 * treated as an empty JSON object. `language` defaults to `en_US` and may use either `_` or
 * `-` as separator.
 *
 * @param session request whose body is parsed
 * @return the requested text and locale
 * @throws ResponseException with `BAD_REQUEST` when `text` is absent or JSON null
 */
private fun getRequestData(session: IHTTPSession): TTSRequestData {
    val body = mutableMapOf<String, String>()
    session.parseBody(body)
    val json = JSONObject(body["postData"] ?: "{}")

    // optString never returns null (it falls back to ""), so an elvis on its result can
    // never fire; presence has to be checked explicitly.
    if (json.isNull("text")) {
        throw ResponseException(BAD_REQUEST, "")
    }
    val text = json.optString("text")

    // Locale(String) takes a bare language code, so "en_US" would become the bogus
    // language "en_us". Parse the value as a language tag instead.
    val languageTag = json.optString("language", "en_US").replace('_', '-')
    val language = Locale.forLanguageTag(languageTag)

    return TTSRequestData(text, language)
}