@@ -9,6 +9,8 @@ from tenacity import (
 import openai
 
 MessageRole = Literal["system", "user", "assistant"]
 
 
 @dataclasses.dataclass()
 class Message():
     role: MessageRole
@@ -64,6 +66,7 @@ def gpt_chat(
     return [choice.message.content for choice in response.choices]  # type: ignore
 
 
 class ModelBase():
     def __init__(self, name: str):
         self.name = name
@@ -106,33 +109,36 @@ class GPTDavinci(ModelBase):
         return gpt_completion(self.name, prompt, max_tokens, stop_strs, temperature, num_comps)
 
 
-class StarChat(ModelBase):
-    def __init__(self):
+class HFModelBase(ModelBase):
+    """
+    Base for huggingface chat models
+    """
+
+    def __init__(self, model_name: str, hf_model_name: str, eos_token_id: int):
         import torch
         from transformers import pipeline
-        self.name = "star-chat"
+        self.name = model_name
+        self.hf_model_name = hf_model_name
+        self.eos_token_id = eos_token_id
         self.pipe = pipeline(
-            "text-generation", model="HuggingFaceH4/starchat-beta", torch_dtype=torch.bfloat16, device_map="auto")
+            "text-generation", model=hf_model_name, torch_dtype=torch.bfloat16, device_map="auto")
         self.is_chat = True
 
     def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
         # NOTE: HF does not like temp of 0.0.
         if temperature < 0.0001:
             temperature = 0.0001
 
-        prompt = ""
-        for i, message in enumerate(messages):
-            prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
-            if i == len(messages) - 1:
-                prompt += "<|assistant|>\n"
+        prompt = self.prepare_prompt(messages)
 
         outputs = self.pipe(
             prompt,
-            max_new_tokens=max_tokens,
+            max_new_tokens=min(
+                max_tokens, self.pipe.model.config.max_position_embeddings),
             do_sample=True,
             temperature=temperature,
             top_p=0.95,
-            eos_token_id=49155,
+            eos_token_id=self.eos_token_id,
             num_return_sequences=num_comps,
         )
 
@@ -140,13 +146,57 @@ class StarChat(ModelBase):
         assert isinstance(outs, list)
         for i, out in enumerate(outs):
             assert isinstance(out, str)
-            out = out.split("<|assistant|>")[1]
-            if out.endswith("<|end|>"):
-                out = out[:-len("<|end|>")]
-            outs[i] = out
+            outs[i] = self.extract_output(out)
 
         if len(outs) == 1:
             return outs[0]  # type: ignore
         else:
             return outs  # type: ignore
 
+    def prepare_prompt(self, messages: List[Message]) -> str:
+        raise NotImplementedError
+
+    def extract_output(self, output: str) -> str:
+        raise NotImplementedError
+
+
+class StarChat(HFModelBase):
+    def __init__(self):
+        super().__init__("star-chat", "HuggingFaceH4/starchat-beta", 49155)
+
+    def prepare_prompt(self, messages: List[Message]) -> str:
+        prompt = ""
+        for i, message in enumerate(messages):
+            prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
+            if i == len(messages) - 1:
+                prompt += "<|assistant|>\n"
+
+        return prompt
+
+    def extract_output(self, output: str) -> str:
+        out = output.split("<|assistant|>")[1]
+        if out.endswith("<|end|>"):
+            out = out[:-len("<|end|>")]
+
+        return out
+
+
+# class CodeLlama(HFModelBase):
+#     def __init__(self):
+#         super().__init__("star-chat", "HuggingFaceH4/starchat-beta", 49155)
+
+#     def prepare_prompt(self, messages: List[Message]) -> str:
+#         prompt = ""
+#         for i, message in enumerate(messages):
+#             prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
+#             if i == len(messages) - 1:
+#                 prompt += "<|assistant|>\n"
+
+#         return prompt
+
+#     def extract_output(self, output: str) -> str:
+#         out = output.split("<|assistant|>")[1]
+#         if out.endswith("<|end|>"):
+#             out = out[:-len("<|end|>")]
+
+#         return out
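
For reference, not part of the diff: a minimal sketch of how the refactored classes might be exercised. The import path "model" and the message contents are assumptions (the Message dataclass's `content: str` field is implied by its use in prepare_prompt), and instantiating StarChat downloads the starchat-beta weights, which needs substantial GPU memory.

# Hypothetical usage of the HFModelBase/StarChat refactor above.
# prepare_prompt() wraps each message as "<|role|>\n{content}\n<|end|>\n"
# and appends "<|assistant|>\n"; extract_output() strips the prompt echo
# and any trailing <|end|> token from the generation.
from model import Message, StarChat  # assumed module path

chat = StarChat()  # builds pipeline("text-generation", model="HuggingFaceH4/starchat-beta", ...)
messages = [
    Message(role="system", content="You are a careful Python assistant."),
    Message(role="user", content="Write a function that reverses a list."),
]
# With num_comps=1 (the default), generate_chat returns a single str.
out = chat.generate_chat(messages, max_tokens=256, temperature=0.2)
print(out)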