You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
29 lines
896 B
Python
29 lines
896 B
Python
from typing import Union, Literal
|
|
from langchain.chat_models import ChatOpenAI
|
|
from langchain import OpenAI
|
|
from langchain.schema import (
|
|
HumanMessage
|
|
)
|
|
|
|
class AnyOpenAILLM:
    """Prompt-string wrapper over LangChain's completion and chat OpenAI classes.

    A model whose name's first ``-``-separated part is ``text`` (for example
    ``text-davinci-003``) is wrapped in LangChain's completion-style ``OpenAI``
    class; every other model name is wrapped in ``ChatOpenAI``. Calling the
    instance with a prompt string hides that difference from the caller.

    Attributes:
        model: The underlying LangChain ``OpenAI`` or ``ChatOpenAI`` instance.
        model_type: ``'completion'`` or ``'chat'``, recording which class was used.
    """

    def __init__(self, *args, **kwargs):
        """Construct the underlying LangChain model.

        Args:
            *args: Forwarded unchanged to ``OpenAI`` / ``ChatOpenAI``.
            **kwargs: Forwarded unchanged as well. The model name used for
                routing is taken from ``model_name``, or else from ``model``
                (the alias LangChain also accepts), or else defaults to
                ``'gpt-3.5-turbo'``.
        """
        model_name = self._resolve_model_name(kwargs)
        self.model_type: Literal['completion', 'chat'] = (
            'completion' if self._is_completion_model(model_name) else 'chat'
        )
        if self.model_type == 'completion':
            self.model = OpenAI(*args, **kwargs)
        else:
            self.model = ChatOpenAI(*args, **kwargs)

    @staticmethod
    def _resolve_model_name(kwargs):
        """Return the model name requested in *kwargs*, or ``'gpt-3.5-turbo'``.

        ``model_name`` takes precedence over its ``model`` alias. A value of
        ``None`` counts as "not given", so an explicit ``model_name=None``
        no longer crashes the routing below.
        """
        for key in ('model_name', 'model'):
            name = kwargs.get(key)
            if name is not None:
                return name
        return 'gpt-3.5-turbo'

    @staticmethod
    def _is_completion_model(model_name):
        """Return True when the first ``-``-separated part of *model_name* is ``text``."""
        return model_name.split('-')[0] == 'text'

    def __call__(self, prompt: str):
        """Send *prompt* to the wrapped model and return its reply.

        Completion models receive the raw prompt, and their result is returned
        as-is. Chat models receive the prompt as a single ``HumanMessage``, and
        the ``content`` of the returned message is returned.
        """
        if self.model_type == 'completion':
            return self.model(prompt)
        # Chat models take a list of messages and return a message object.
        return self.model([HumanMessage(content=prompt)]).content