|
|
|
@ -137,10 +137,20 @@ class GenerationPipeline:
|
|
|
|
|
Returns:
|
|
|
|
|
generated text.
|
|
|
|
|
"""
|
|
|
|
|
# set generation params
|
|
|
|
|
max_new_tokens = kwargs.get("max_new_tokens", 30)
|
|
|
|
|
temperature = kwargs.get("temperature", 1.0)
|
|
|
|
|
top_k = kwargs.get("top_k", 50)
|
|
|
|
|
top_p = kwargs.get("top_p", 1)
|
|
|
|
|
repetition_penalty = kwargs.get("repetition_penalty", 1)
|
|
|
|
|
do_sample = kwargs.get("do_sample", False)
|
|
|
|
|
num_return_sequences = kwargs.get("num_return_sequences", 1)
|
|
|
|
|
print(f"Generating with parameters: max_new_tokens={max_new_tokens}, temperature={temperature}, top_k={top_k}, top_p={top_p}, repetition_penalty={repetition_penalty}, do_sample={do_sample}, num_return_sequences={num_return_sequences}")
|
|
|
|
|
|
|
|
|
|
# If text is longer than max model length, we reduce max input length to ensure
|
|
|
|
|
# the user indicated generation tokens is preserved.
|
|
|
|
|
max_input_len = (
|
|
|
|
|
self.max_length - kwargs.get("max_new_tokens")
|
|
|
|
|
self.max_length - max_new_tokens
|
|
|
|
|
if not self.is_encdec
|
|
|
|
|
else self.max_length
|
|
|
|
|
)
|
|
|
|
@ -152,15 +162,6 @@ class GenerationPipeline:
|
|
|
|
|
return_tensors="pt",
|
|
|
|
|
)
|
|
|
|
|
encoded_prompt = encoded_prompt.to(self.device)
|
|
|
|
|
# set generation params
|
|
|
|
|
max_new_tokens = kwargs.get("max_new_tokens", 30)
|
|
|
|
|
temperature = kwargs.get("temperature", 1.0)
|
|
|
|
|
top_k = kwargs.get("top_k", 50)
|
|
|
|
|
top_p = kwargs.get("top_p", 1)
|
|
|
|
|
repetition_penalty = kwargs.get("repetition_penalty", 1)
|
|
|
|
|
do_sample = kwargs.get("do_sample", False)
|
|
|
|
|
num_return_sequences = kwargs.get("num_return_sequences", 1)
|
|
|
|
|
print(f"Generating with parameters: max_new_tokens={max_new_tokens}, temperature={temperature}, top_k={top_k}, top_p={top_p}, repetition_penalty={repetition_penalty}, do_sample={do_sample}, num_return_sequences={num_return_sequences}")
|
|
|
|
|
|
|
|
|
|
output_dict = self.model.generate( # type: ignore
|
|
|
|
|
input_ids=encoded_prompt['input_ids'],
|
|
|
|
|