feat: #8 Allow user to specify a different LLM instead of OpenAI

pull/11/head
namuan 1 year ago
parent 825353b9a2
commit b72cd5006e

poetry.lock (generated)

@@ -2887,4 +2887,4 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools"
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9.0, <4.0"
-content-hash = "34a3fad8c1a4f517b1a0b69126ef05892054d27a1ee8600d0c57da19f61130fa"
+content-hash = "3bc8834a38ff7a2334dc44c7b811dfa36b18aeba49dc8afe86147606e13c4af9"

pyproject.toml
@@ -42,6 +42,7 @@ python-dotenv = "^0.21.0"
 panel = "^0.14.2"
 slug = "^2.0"
 sentence-transformers = "^2.2.2"
+transformers = "^4.26.0"
 
 [tool.poetry.group.dev.dependencies]
 autoflake = "*"

@@ -39,6 +39,13 @@ def parse_args() -> Namespace:
         default="openai",
         help="Embedding to use",
     )
+    parser.add_argument(
+        "-l",
+        "--llm",
+        choices=["openai", "huggingface"],
+        default="openai",
+        help="LLM to use",
+    )
     parser.add_argument(
         "-v",

@@ -6,11 +6,14 @@ from typing import Any
 import faiss # type: ignore
 import openai
+import torch
 from langchain import OpenAI, VectorDBQA
 from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceHubEmbeddings
 from langchain.embeddings.base import Embeddings
 from langchain.embeddings.cohere import CohereEmbeddings
 from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.llms.base import BaseLLM
+from langchain.llms.huggingface_pipeline import HuggingFacePipeline
 from langchain.prompts import PromptTemplate
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores.faiss import FAISS
@@ -18,6 +21,7 @@ from py_executable_checklist.workflow import WorkflowBase, run_command
 from pypdf import PdfReader
 from rich import print
 from slug import slug # type: ignore
+from transformers import pipeline # type: ignore
 
 from doc_search import retry
@@ -144,7 +148,7 @@ class CombineAllText(WorkflowBase):
         for file in self.pages_text_path.glob("*.txt"):
             text += file.read_text()
 
-        text_splitter = CharacterTextSplitter(chunk_size=2000, chunk_overlap=0)
+        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
         texts = text_splitter.split_text(text)
 
         return {
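
Halving chunk_size from 2000 to 1000 characters yields more, smaller chunks, so each passage later retrieved and stuffed into the question prompt is shorter. A standalone sketch of the effect on synthetic text (CharacterTextSplitter merges paragraph-separated pieces up to chunk_size):

from langchain.text_splitter import CharacterTextSplitter

# Synthetic document: 40 short paragraphs separated by blank lines.
text = "\n\n".join(f"Paragraph {i}: " + "lorem ipsum dolor sit amet " * 10 for i in range(40))

for size in (2000, 1000):
    splitter = CharacterTextSplitter(chunk_size=size, chunk_overlap=0)
    chunks = splitter.split_text(text)
    print(size, "->", len(chunks), "chunks, longest", max(len(c) for c in chunks), "chars")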
@@ -232,6 +236,7 @@ class AskQuestion(WorkflowBase):
     input_question: str
     search_index: Any
+    llm: str
 
     def prompt_from_question(self) -> PromptTemplate:
         template = """
@@ -251,9 +256,21 @@ ${question}
 
         return PromptTemplate(input_variables=["context", "question"], template=template)
 
+    def llm_provider(self) -> BaseLLM:
+        if self.llm == "huggingface":
+            pipe = pipeline(
+                "text2text-generation",
+                model="pszemraj/long-t5-tglobal-base-16384-book-summary",
+                device=0 if torch.cuda.is_available() else -1,
+            )
+            return HuggingFacePipeline(pipeline=pipe)
+        else:
+            return OpenAI()
+
     def execute(self) -> dict:
+        llm = self.llm_provider()
         prompt = self.prompt_from_question()
-        qa = VectorDBQA.from_llm(llm=OpenAI(), prompt=prompt, vectorstore=self.search_index)
+        qa = VectorDBQA.from_llm(llm=llm, prompt=prompt, vectorstore=self.search_index)
         output = self.send_prompt(qa, self.input_question)
         return {"output": output}
