Unify file-based responses handling

This commit is contained in:
2020-07-03 08:40:44 +02:00
parent 699136d640
commit 26771e5e42
3 changed files with 20 additions and 32 deletions

View File

@@ -3,11 +3,10 @@ package com.bartlomiejpluta.ttsserver.core.lua.lib
import cafe.adriel.androidaudioconverter.model.AudioFormat import cafe.adriel.androidaudioconverter.model.AudioFormat
import com.bartlomiejpluta.ttsserver.core.tts.engine.TTSEngine import com.bartlomiejpluta.ttsserver.core.tts.engine.TTSEngine
import org.luaj.vm2.LuaNil import org.luaj.vm2.LuaNil
import org.luaj.vm2.LuaString
import org.luaj.vm2.LuaValue import org.luaj.vm2.LuaValue
import org.luaj.vm2.lib.ThreeArgFunction import org.luaj.vm2.lib.ThreeArgFunction
import org.luaj.vm2.lib.TwoArgFunction import org.luaj.vm2.lib.TwoArgFunction
import java.lang.IllegalArgumentException import org.luaj.vm2.lib.jse.CoerceJavaToLua
import java.util.* import java.util.*
class TTSLibrary(private val ttsEngine: TTSEngine) : TwoArgFunction() { class TTSLibrary(private val ttsEngine: TTSEngine) : TwoArgFunction() {
@@ -42,7 +41,7 @@ class TTSLibrary(private val ttsEngine: TTSEngine) : TwoArgFunction() {
val file = ttsEngine.createTTSFile(text.checkjstring(), lang, audioFormat) val file = ttsEngine.createTTSFile(text.checkjstring(), lang, audioFormat)
return LuaValue.valueOf(file.absolutePath) return CoerceJavaToLua.coerce(file)
} }
} }
} }

View File

@@ -10,14 +10,11 @@ import cafe.adriel.androidaudioconverter.model.AudioFormat
import com.bartlomiejpluta.ttsserver.core.tts.exception.TTSException import com.bartlomiejpluta.ttsserver.core.tts.exception.TTSException
import com.bartlomiejpluta.ttsserver.core.tts.listener.GongListener import com.bartlomiejpluta.ttsserver.core.tts.listener.GongListener
import com.bartlomiejpluta.ttsserver.core.tts.listener.TTSProcessListener import com.bartlomiejpluta.ttsserver.core.tts.listener.TTSProcessListener
import com.bartlomiejpluta.ttsserver.core.tts.model.TTSStream
import com.bartlomiejpluta.ttsserver.core.tts.status.TTSStatus import com.bartlomiejpluta.ttsserver.core.tts.status.TTSStatus
import com.bartlomiejpluta.ttsserver.core.tts.status.TTSStatusHolder import com.bartlomiejpluta.ttsserver.core.tts.status.TTSStatusHolder
import com.bartlomiejpluta.ttsserver.core.util.AudioConverter import com.bartlomiejpluta.ttsserver.core.util.AudioConverter
import com.bartlomiejpluta.ttsserver.ui.preference.key.PreferenceKey import com.bartlomiejpluta.ttsserver.ui.preference.key.PreferenceKey
import java.io.BufferedInputStream
import java.io.File import java.io.File
import java.io.FileInputStream
import java.security.MessageDigest import java.security.MessageDigest
import java.util.* import java.util.*
@@ -33,7 +30,11 @@ class TTSEngine(
val status: TTSStatus val status: TTSStatus
get() = ttsStatusHolder.status get() = ttsStatusHolder.status
fun createTTSFile(text: String, language: Locale, audioFormat: AudioFormat = AudioFormat.WAV): File { fun createTTSFile(
text: String,
language: Locale,
audioFormat: AudioFormat = AudioFormat.WAV
): File {
val digest = hash(text, language) val digest = hash(text, language)
val targetFilename = "tts_$digest.${audioFormat.format}" val targetFilename = "tts_$digest.${audioFormat.format}"
val wavFilename = "tts_$digest.wav" val wavFilename = "tts_$digest.wav"
@@ -68,26 +69,6 @@ class TTSEngine(
return digest.fold("", { str, it -> str + "%02x".format(it) }) return digest.fold("", { str, it -> str + "%02x".format(it) })
} }
fun fetchTTSStream(text: String, language: Locale, audioFormat: AudioFormat = AudioFormat.WAV): TTSStream {
val file = createTempFile("tmp_tts_server", ".wav")
val uuid = UUID.randomUUID().toString()
val listener = TTSProcessListener(uuid)
tts.setOnUtteranceProgressListener(listener)
tts.language = language
tts.synthesizeToFile(text, null, file, uuid)
listener.await()
val converted = convertFile(file, audioFormat)
val stream = BufferedInputStream(FileInputStream(converted))
val length = converted.length()
file.delete()
return TTSStream(stream, length)
}
fun performTTS(text: String, language: Locale) { fun performTTS(text: String, language: Locale) {
val uuid = UUID.randomUUID().toString() val uuid = UUID.randomUUID().toString()
val listener = TTSProcessListener(uuid) val listener = TTSProcessListener(uuid)

View File

@@ -4,6 +4,7 @@ import com.bartlomiejpluta.ttsserver.core.web.dto.Request
import com.bartlomiejpluta.ttsserver.core.web.uri.UriTemplate import com.bartlomiejpluta.ttsserver.core.web.uri.UriTemplate
import fi.iki.elonen.NanoHTTPD.* import fi.iki.elonen.NanoHTTPD.*
import org.luaj.vm2.* import org.luaj.vm2.*
import org.luaj.vm2.lib.jse.CoerceLuaToJava
import java.io.BufferedInputStream import java.io.BufferedInputStream
import java.io.File import java.io.File
import java.io.FileInputStream import java.io.FileInputStream
@@ -24,10 +25,10 @@ class DefaultEndpoint(
.let { provideResponse(it) } .let { provideResponse(it) }
private fun provideResponse(response: LuaTable) = when { private fun provideResponse(response: LuaTable) = when (response.get("data")) {
response.get("data") !is LuaNil -> getTextResponse(response) is LuaString -> getTextResponse(response)
response.get("file") !is LuaNil -> getFileResponse(response) is LuaUserdata -> getFileResponse(response)
else -> throw IllegalArgumentException("Provide 'data' or 'file' in response table") else -> throw IllegalArgumentException("Supported only string and file data types")
} }
private fun getTextResponse(response: LuaTable) = newFixedLengthResponse( private fun getTextResponse(response: LuaTable) = newFixedLengthResponse(
@@ -37,9 +38,14 @@ class DefaultEndpoint(
) )
private fun getFileResponse(response: LuaTable): Response? { private fun getFileResponse(response: LuaTable): Response? {
val file = File(response.get("file").checkjstring()) val file = CoerceLuaToJava.coerce(response.get("data"), File::class.java) as File
val stream = BufferedInputStream(FileInputStream(file)) val stream = BufferedInputStream(FileInputStream(file))
val length = file.length() val length = file.length()
if(!getCached(response)) {
file.delete()
}
return newFixedLengthResponse( return newFixedLengthResponse(
getStatus(response), getStatus(response),
getMimeType(response), getMimeType(response),
@@ -60,5 +66,7 @@ class DefaultEndpoint(
private fun getData(response: LuaTable) = response.get("data").checkjstring() private fun getData(response: LuaTable) = response.get("data").checkjstring()
private fun getCached(response: LuaTable) = response.get("cached").optboolean(false)
override fun toString() = "D[${uri.template}]" override fun toString() = "D[${uri.template}]"
} }