From 9992beaff9205825993ca65f589e9661bcadd939 Mon Sep 17 00:00:00 2001 From: roiperlman Date: Thu, 9 May 2024 03:53:13 +0300 Subject: [PATCH] community: Add arguments to whisper parser (#20378) **Description:** Added a few additional arguments to the whisper parser, which can be consumed by the underlying API. The prompt is especially important to fine-tune transcriptions. --------- Co-authored-by: Roi Perlman Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur --- .../document_loaders/parsers/audio.py | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/libs/community/langchain_community/document_loaders/parsers/audio.py b/libs/community/langchain_community/document_loaders/parsers/audio.py index 31d542e87b..c8fa4c3ed3 100644 --- a/libs/community/langchain_community/document_loaders/parsers/audio.py +++ b/libs/community/langchain_community/document_loaders/parsers/audio.py @@ -1,7 +1,7 @@ import logging import os import time -from typing import Dict, Iterator, Optional, Tuple +from typing import Any, Dict, Iterator, Literal, Optional, Tuple, Union from langchain_core.documents import Document @@ -31,12 +31,32 @@ class OpenAIWhisperParser(BaseBlobParser): *, chunk_duration_threshold: float = 0.1, base_url: Optional[str] = None, + language: Union[str, None] = None, + prompt: Union[str, None] = None, + response_format: Union[ + Literal["json", "text", "srt", "verbose_json", "vtt"], None + ] = None, + temperature: Union[float, None] = None, ): self.api_key = api_key self.chunk_duration_threshold = chunk_duration_threshold self.base_url = ( base_url if base_url is not None else os.environ.get("OPENAI_API_BASE") ) + self.language = language + self.prompt = prompt + self.response_format = response_format + self.temperature = temperature + + @property + def _create_params(self) -> Dict[str, Any]: + params = { + "language": self.language, + "prompt": self.prompt, + "response_format": self.response_format, + "temperature": self.temperature, + } + return {k: v for k, v in params.items() if v is not None} def lazy_parse(self, blob: Blob) -> Iterator[Document]: """Lazily parse the blob.""" @@ -95,7 +115,7 @@ class OpenAIWhisperParser(BaseBlobParser): try: if is_openai_v1(): transcript = client.audio.transcriptions.create( - model="whisper-1", file=file_obj + model="whisper-1", file=file_obj, **self._create_params ) else: transcript = openai.Audio.transcribe("whisper-1", file_obj)