Spaces:
Configuration error
Configuration error
Fedir Zadniprovskyi
committed on
Commit
·
a5d2e48
1
Parent(s):
f5d1866
test: capture openai's param handling
Browse files
src/faster_whisper_server/routers/stt.py
CHANGED
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
| 3 |
import asyncio
|
| 4 |
from io import BytesIO
|
| 5 |
import logging
|
| 6 |
-
from typing import TYPE_CHECKING, Annotated
|
| 7 |
|
| 8 |
from fastapi import (
|
| 9 |
APIRouter,
|
|
@@ -30,6 +30,7 @@ from faster_whisper_server.config import (
|
|
| 30 |
from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
|
| 31 |
from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
|
| 32 |
from faster_whisper_server.server_models import (
|
|
|
|
| 33 |
TranscriptionJsonResponse,
|
| 34 |
TranscriptionVerboseJsonResponse,
|
| 35 |
)
|
|
@@ -165,7 +166,7 @@ def transcribe_file(
|
|
| 165 |
response_format: Annotated[ResponseFormat | None, Form()] = None,
|
| 166 |
temperature: Annotated[float, Form()] = 0.0,
|
| 167 |
timestamp_granularities: Annotated[
|
| 168 |
-
|
| 169 |
Form(alias="timestamp_granularities[]"),
|
| 170 |
] = ["segment"],
|
| 171 |
stream: Annotated[bool, Form()] = False,
|
|
|
|
| 3 |
import asyncio
|
| 4 |
from io import BytesIO
|
| 5 |
import logging
|
| 6 |
+
from typing import TYPE_CHECKING, Annotated
|
| 7 |
|
| 8 |
from fastapi import (
|
| 9 |
APIRouter,
|
|
|
|
| 30 |
from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
|
| 31 |
from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
|
| 32 |
from faster_whisper_server.server_models import (
|
| 33 |
+
TimestampGranularities,
|
| 34 |
TranscriptionJsonResponse,
|
| 35 |
TranscriptionVerboseJsonResponse,
|
| 36 |
)
|
|
|
|
| 166 |
response_format: Annotated[ResponseFormat | None, Form()] = None,
|
| 167 |
temperature: Annotated[float, Form()] = 0.0,
|
| 168 |
timestamp_granularities: Annotated[
|
| 169 |
+
TimestampGranularities,
|
| 170 |
Form(alias="timestamp_granularities[]"),
|
| 171 |
] = ["segment"],
|
| 172 |
stream: Annotated[bool, Form()] = False,
|
src/faster_whisper_server/server_models.py
CHANGED
|
@@ -107,3 +107,15 @@ class ModelObject(BaseModel):
|
|
| 107 |
]
|
| 108 |
},
|
| 109 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
]
|
| 108 |
},
|
| 109 |
)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
TimestampGranularities = list[Literal["segment", "word"]]
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
TIMESTAMP_GRANULARITIES_COMBINATIONS: list[TimestampGranularities] = [
|
| 116 |
+
[], # should be treated as ["segment"]. https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities
|
| 117 |
+
["segment"],
|
| 118 |
+
["word"],
|
| 119 |
+
["word", "segment"],
|
| 120 |
+
["segment", "word"], # same as ["word", "segment"] but order is different
|
| 121 |
+
]
|
tests/conftest.py
CHANGED
|
@@ -5,7 +5,7 @@ import os
|
|
| 5 |
from fastapi.testclient import TestClient
|
| 6 |
from faster_whisper_server.main import create_app
|
| 7 |
from httpx import ASGITransport, AsyncClient
|
| 8 |
-
from openai import OpenAI
|
| 9 |
import pytest
|
| 10 |
import pytest_asyncio
|
| 11 |
|
|
@@ -35,3 +35,10 @@ async def aclient() -> AsyncGenerator[AsyncClient, None]:
|
|
| 35 |
@pytest.fixture()
|
| 36 |
def openai_client(client: TestClient) -> OpenAI:
|
| 37 |
return OpenAI(api_key="cant-be-empty", http_client=client)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from fastapi.testclient import TestClient
|
| 6 |
from faster_whisper_server.main import create_app
|
| 7 |
from httpx import ASGITransport, AsyncClient
|
| 8 |
+
from openai import AsyncOpenAI, OpenAI
|
| 9 |
import pytest
|
| 10 |
import pytest_asyncio
|
| 11 |
|
|
|
|
| 35 |
@pytest.fixture()
|
| 36 |
def openai_client(client: TestClient) -> OpenAI:
|
| 37 |
return OpenAI(api_key="cant-be-empty", http_client=client)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@pytest.fixture()
|
| 41 |
+
def actual_openai_client() -> AsyncOpenAI:
|
| 42 |
+
return AsyncOpenAI(
|
| 43 |
+
base_url="https://api.openai.com/v1"
|
| 44 |
+
) # `base_url` is provided in case `OPENAI_API_BASE_URL` is set to a different value
|
tests/openai_timestamp_granularities_test.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenAI's handling of `response_format` and `timestamp_granularities` is a bit confusing and inconsistent. This test module exists to capture the OpenAI API's behavior with respect to these parameters.""" # noqa: E501
|
| 2 |
+
|
| 3 |
+
from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
|
| 4 |
+
from openai import AsyncOpenAI, BadRequestError
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@pytest.mark.asyncio()
|
| 9 |
+
@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
|
| 10 |
+
async def test_openai_json_response_format_and_timestamp_granularities_combinations(
|
| 11 |
+
actual_openai_client: AsyncOpenAI,
|
| 12 |
+
timestamp_granularities: TimestampGranularities,
|
| 13 |
+
) -> None:
|
| 14 |
+
audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230
|
| 15 |
+
|
| 16 |
+
if "word" in timestamp_granularities:
|
| 17 |
+
with pytest.raises(BadRequestError):
|
| 18 |
+
await actual_openai_client.audio.transcriptions.create(
|
| 19 |
+
file=audio_file,
|
| 20 |
+
model="whisper-1",
|
| 21 |
+
response_format="json",
|
| 22 |
+
timestamp_granularities=timestamp_granularities,
|
| 23 |
+
)
|
| 24 |
+
else:
|
| 25 |
+
await actual_openai_client.audio.transcriptions.create(
|
| 26 |
+
file=audio_file, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@pytest.mark.asyncio()
|
| 31 |
+
@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
|
| 32 |
+
async def test_openai_verbose_json_response_format_and_timestamp_granularities_combinations(
|
| 33 |
+
actual_openai_client: AsyncOpenAI,
|
| 34 |
+
timestamp_granularities: TimestampGranularities,
|
| 35 |
+
) -> None:
|
| 36 |
+
audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230
|
| 37 |
+
|
| 38 |
+
transcription = await actual_openai_client.audio.transcriptions.create(
|
| 39 |
+
file=audio_file,
|
| 40 |
+
model="whisper-1",
|
| 41 |
+
response_format="verbose_json",
|
| 42 |
+
timestamp_granularities=timestamp_granularities,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
assert transcription.__pydantic_extra__
|
| 46 |
+
if timestamp_granularities == ["word"]:
|
| 47 |
+
# This is an exception where segments are not present
|
| 48 |
+
assert transcription.__pydantic_extra__.get("segments") is None
|
| 49 |
+
assert transcription.__pydantic_extra__.get("words") is not None
|
| 50 |
+
elif "word" in timestamp_granularities:
|
| 51 |
+
assert transcription.__pydantic_extra__.get("segments") is not None
|
| 52 |
+
assert transcription.__pydantic_extra__.get("words") is not None
|
| 53 |
+
else:
|
| 54 |
+
# Unless explicitly requested, words are not present
|
| 55 |
+
assert transcription.__pydantic_extra__.get("segments") is not None
|
| 56 |
+
assert transcription.__pydantic_extra__.get("words") is None
|