@@ -7,6 +7,7 @@ from tenacity import (
     wait_random_exponential,  # type: ignore
 )
 import openai
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 MessageRole = Literal["system", "user", "assistant"]
@@ -114,14 +115,10 @@ class HFModelBase(ModelBase):
     Base for huggingface chat models
     """
 
-    def __init__(self, model_name: str, hf_model_name: str, eos_token_id: int):
-        import torch
-        from transformers import pipeline
+    def __init__(self, model_name: str, model, tokenizer):
         self.name = model_name
-        self.hf_model_name = hf_model_name
-        self.eos_token_id = eos_token_id
-        self.pipe = pipeline(
-            "text-generation", model=hf_model_name, torch_dtype=torch.bfloat16, device_map="auto")
+        self.model = model
+        self.tokenizer = tokenizer
         self.is_chat = True
 
     def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
@@ -131,18 +128,18 @@ class HFModelBase(ModelBase):
         prompt = self.prepare_prompt(messages)
 
-        outputs = self.pipe(
+        outputs = self.model.generate(
             prompt,
             max_new_tokens=min(
-                max_tokens, self.pipe.model.config.max_position_embeddings),
+                max_tokens, self.model.config.max_position_embeddings),
             do_sample=True,
             temperature=temperature,
             top_p=0.95,
-            eos_token_id=self.eos_token_id,
+            eos_token_id=self.tokenizer.eos_token_id,
             num_return_sequences=num_comps,
         )
 
-        outs = [output['generated_text'] for output in outputs]  # type: ignore
+        outs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
         assert isinstance(outs, list)
         for i, out in enumerate(outs):
             assert isinstance(out, str)
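Note on this hunk: unlike the text-generation pipeline, `model.generate()` consumes token ids rather than a raw string, so a string prompt from `prepare_prompt` has to be encoded first. A minimal sketch of the flow the new code relies on (model name and prompt are illustrative, not taken from the diff):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/starchat-beta")
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceH4/starchat-beta", torch_dtype=torch.bfloat16, device_map="auto")

# Encode first: generate() takes a tensor of input ids, not a raw string.
prompt = tokenizer.encode(
    "<|user|>\nSay hello.\n<|end|>\n<|assistant|>\n",
    return_tensors="pt").to(model.device)

outputs = model.generate(
    prompt,
    max_new_tokens=64,
    do_sample=True,
    temperature=0.2,
    top_p=0.95,
    eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```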
@@ -162,7 +159,16 @@ class HFModelBase(ModelBase):
 
 class StarChat(HFModelBase):
     def __init__(self):
-        super().__init__("star-chat", "HuggingFaceH4/starchat-beta", 49155)
+        import torch
+        model = AutoModelForCausalLM.from_pretrained(
+            "HuggingFaceH4/starchat-beta",
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            "HuggingFaceH4/starchat-beta",
+        )
+        super().__init__("star-chat", model, tokenizer)
 
     def prepare_prompt(self, messages: List[Message]) -> str:
         prompt = ""
@@ -181,22 +187,36 @@ class StarChat(HFModelBase):
         return out
 
-# class CodeLlama(HFModelBase):
-#     def __init__(self):
-#         super().__init__("star-chat", "HuggingFaceH4/starchat-beta", 49155)
+class CodeLlama(HFModelBase):
+    B_INST, E_INST = "[INST]", "[/INST]"
+    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 
-#     def prepare_prompt(self, messages: List[Message]) -> str:
-#         prompt = ""
-#         for i, message in enumerate(messages):
-#             prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
-#             if i == len(messages) - 1:
-#                 prompt += "<|assistant|>\n"
+    DEFAULT_SYSTEM_PROMPT = """\
+You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
 
-#         return prompt
-
-#     def extract_output(self, output: str) -> str:
-#         out = output.split("<|assistant|>")[1]
-#         if out.endswith("<|end|>"):
-#             out = out[:-len("<|end|>")]
-#         return out
+    def __init__(self):
+        import torch
+        # The base class now takes a loaded model and tokenizer, so build both
+        # here (mirroring StarChat) instead of passing a name and EOS token id.
+        model = AutoModelForCausalLM.from_pretrained(
+            "codellama/CodeLlama-34b-Instruct-hf",
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            "codellama/CodeLlama-34b-Instruct-hf",
+            add_eos_token=True,
+            add_bos_token=True,
+            padding_side='left',
+        )
+        super().__init__("code-llama", model, tokenizer)
+
+    def prepare_prompt(self, messages: List[Message]) -> str:
+        prompt = ""
+        for i, message in enumerate(messages):
+            prompt += f"<|{message.role}|>\n{message.content}\n<|end|>\n"
+            if i == len(messages) - 1:
+                prompt += "<|assistant|>\n"
+        return prompt
+
+    def extract_output(self, output: str) -> str:
+        out = output.split("<|assistant|>")[1]
+        if out.endswith("<|end|>"):
+            out = out[:-len("<|end|>")]
+        return out
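For reference, the `B_INST`/`E_INST` and `B_SYS`/`E_SYS` constants added here encode the Llama-2 chat format: each user turn is wrapped in `[INST] ... [/INST]`, and the system prompt is folded into the first user turn inside `<<SYS>>` tags. A sketch of the one-turn prompt that format produces (contents illustrative):

```python
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

system = "You are a helpful coding assistant."
user = "Reverse a linked list in Python."

# The system prompt rides inside the first [INST] block.
prompt = f"{B_INST} {B_SYS}{system}{E_SYS}{user} {E_INST}"
print(prompt)
# [INST] <<SYS>>
# You are a helpful coding assistant.
# <</SYS>>
#
# Reverse a linked list in Python. [/INST]
```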