better design

main
cassanof 9 months ago
parent c3cfcb8863
commit f42e651445

@@ -9,6 +9,8 @@ from tenacity import (
import openai
MessageRole = Literal["system", "user", "assistant"]
@dataclasses.dataclass()
class Message():
    role: MessageRole
@@ -64,6 +66,7 @@ def gpt_chat(
    return [choice.message.content for choice in response.choices] # type: ignore
class ModelBase():
    def __init__(self, name: str):
        self.name = name
@@ -106,33 +109,36 @@ class GPTDavinci(ModelBase):
        return gpt_completion(self.name, prompt, max_tokens, stop_strs, temperature, num_comps)
class StarChat(ModelBase):
    def __init__(self):
class HFModelBase(ModelBase):
    """
    Base for huggingface chat models
    """
    def __init__(self, model_name: str, hf_model_name: str, eos_token_id: int):
        import torch
        from transformers import pipeline
        self.name = "star-chat"
        self.name = model_name
        self.hf_model_name = hf_model_name
        self.eos_token_id = eos_token_id
        self.pipe = pipeline(
            "text-generation", model="HuggingFaceH4/starchat-beta", torch_dtype=torch.bfloat16, device_map="auto")
            "text-generation", model=hf_model_name, torch_dtype=torch.bfloat16, device_map="auto")
        self.is_chat = True
    def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
        # NOTE: HF does not like temp of 0.0.
        # NOTE: HF does not like temp of 0.0.
        if temperature < 0.0001:
            temperature = 0.0001
        prompt = ""
        for i, message in enumerate(messages):
            prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
            if i == len(messages) - 1:
                prompt += "<|assistant|>\n"
        prompt = self.prepare_prompt(messages)
        outputs = self.pipe(
            prompt,
            max_new_tokens=max_tokens,
            max_new_tokens=min(
                max_tokens, self.pipe.model.config.max_position_embeddings),
            do_sample=True,
            temperature=temperature,
            top_p=0.95,
            eos_token_id=49155,
            eos_token_id=self.eos_token_id,
            num_return_sequences=num_comps,
        )
@@ -140,13 +146,57 @@ class StarChat(ModelBase):
        assert isinstance(outs, list)
        for i, out in enumerate(outs):
            assert isinstance(out, str)
            out = out.split("<|assistant|>")[1]
            if out.endswith("<|end|>"):
                out = out[:-len("<|end|>")]
            outs[i] = out
            outs[i] = self.extract_output(out)
        if len(outs) == 1:
            return outs[0] # type: ignore
        else:
            return outs # type: ignore
    def prepare_prompt(self, messages: List[Message]) -> str:
        raise NotImplementedError
    def extract_output(self, output: str) -> str:
        raise NotImplementedError
class StarChat(HFModelBase):
    def __init__(self):
        super().__init__("star-chat", "HuggingFaceH4/starchat-beta", 49155)
    def prepare_prompt(self, messages: List[Message]) -> str:
        prompt = ""
        for i, message in enumerate(messages):
            prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
            if i == len(messages) - 1:
                prompt += "<|assistant|>\n"
        return prompt
    def extract_output(self, output: str) -> str:
        out = output.split("<|assistant|>")[1]
        if out.endswith("<|end|>"):
            out = out[:-len("<|end|>")]
        return out
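With this design, every Hugging Face chat model shares the generation loop in HFModelBase and only supplies its own prepare_prompt and extract_output. A minimal usage sketch follows; the module path generators.model and the example messages are assumptions, not part of this commit.

# Hypothetical usage sketch; the import path is assumed and not shown in this diff.
from generators.model import Message, StarChat

model = StarChat()  # loads HuggingFaceH4/starchat-beta via transformers.pipeline
reply = model.generate_chat(
    messages=[
        Message(role="system", content="You are a helpful Python assistant."),
        Message(role="user", content="Write a function that reverses a string."),
    ],
    max_tokens=256,
    temperature=0.2,
    num_comps=1,  # a single string is returned when num_comps == 1
)
print(reply)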
# class CodeLlama(HFModelBase):
#     def __init__(self):
#         super().__init__("star-chat", "HuggingFaceH4/starchat-beta", 49155)
#     def prepare_prompt(self, messages: List[Message]) -> str:
#         prompt = ""
#         for i, message in enumerate(messages):
#             prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
#             if i == len(messages) - 1:
#                 prompt += "<|assistant|>\n"
#         return prompt
#     def extract_output(self, output: str) -> str:
#         out = output.split("<|assistant|>")[1]
#         if out.endswith("<|end|>"):
#             out = out[:-len("<|end|>")]
#         return out
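The commented-out CodeLlama block above is still a verbatim copy of StarChat, while the new run script below passes --model "codellama". One possible shape for that subclass is sketched here; the checkpoint name, the EOS token id of 2, and the [INST]-style prompt handling follow the Llama 2 chat convention and are assumptions, not part of this commit.

# Hypothetical sketch, not part of this commit: adapts HFModelBase to the
# Llama 2 / CodeLlama-Instruct prompt convention.
class CodeLlama(HFModelBase):
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

    def __init__(self):
        # Checkpoint name and EOS token id (2) are assumptions.
        super().__init__("codellama", "codellama/CodeLlama-7b-Instruct-hf", 2)

    def prepare_prompt(self, messages: List[Message]) -> str:
        # Fold a leading system message into the first user turn.
        if len(messages) > 1 and messages[0].role == "system":
            merged = self.B_SYS + messages[0].content + self.E_SYS + messages[1].content
            messages = [Message(role="user", content=merged)] + messages[2:]
        prompt = ""
        for message in messages:
            if message.role == "assistant":
                prompt += f" {message.content} "
            else:
                prompt += f"{self.B_INST} {message.content} {self.E_INST}"
        return prompt

    def extract_output(self, output: str) -> str:
        # Everything after the last [/INST] is the newly generated assistant turn.
        return output.split(self.E_INST)[-1].strip()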

@@ -0,0 +1,10 @@
CUDA_VISIBLE_DEVICES=$1 python main.py \
--run_name "reflexion_codellama_$1" \
--root_dir "root" \
--dataset_path ./benchmarks/humaneval-py.jsonl \
--strategy "reflexion" \
--language "py" \
--model "codellama" \
--pass_at_k "1" \
--max_iters "2" \
--verbose | tee ./logs/reflexion_codellama_$1
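The new script (its file name is not shown in this hunk) takes a single GPU index as $1: it pins the run to that device through CUDA_VISIBLE_DEVICES, names the run reflexion_codellama_<index>, and tees the output into ./logs/reflexion_codellama_<index>, so the ./logs directory has to exist before the first run.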