max tokens

pull/79/head
Michael Wornow 1 year ago
parent c1ac99ae34
commit edd5f538e0

@@ -10,15 +10,13 @@
"request": "launch",
"module": "manifest.api.app",
"cwd": "${fileDirname}",
"preLaunchTask": "export_cuda",
"console": "integratedTerminal",
"justMyCode": true,
"args": [
"--model_type", "huggingface",
"--model_name_or_path", "/local-scratch-nvme/nigam/huggingface/pretrained/gpt-j-6B",
"--model_name_or_path", "~/desktop/lumia_model",
"--model_generation_type", "text-generation",
"--use_hf_parallelize",
"--port", "5009",
"--port", "5001",
"--is_flask_debug",
]
}
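For reference, the launch configuration above corresponds to starting the API server from a shell; a minimal sketch using only the module and flags shown in the diff (the model path and port mirror the new values and are illustrative, not a recommended setup):

    python -m manifest.api.app \
        --model_type huggingface \
        --model_name_or_path ~/desktop/lumia_model \
        --model_generation_type text-generation \
        --port 5001 \
        --is_flask_debug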

@@ -137,10 +137,20 @@ class GenerationPipeline:
Returns:
generated text.
"""
# set generation params
max_new_tokens = kwargs.get("max_new_tokens", 30)
temperature = kwargs.get("temperature", 1.0)
top_k = kwargs.get("top_k", 50)
top_p = kwargs.get("top_p", 1)
repetition_penalty = kwargs.get("repetition_penalty", 1)
do_sample = kwargs.get("do_sample", False)
num_return_sequences = kwargs.get("num_return_sequences", 1)
print(f"Generating with parameters: max_new_tokens={max_new_tokens}, temperature={temperature}, top_k={top_k}, top_p={top_p}, repetition_penalty={repetition_penalty}, do_sample={do_sample}, num_return_sequences={num_return_sequences}")
# If text is longer than max model length, we reduce max input length to ensure
# the user indicated generation tokens is preserved.
max_input_len = (
self.max_length - kwargs.get("max_new_tokens")
self.max_length - max_new_tokens
if not self.is_encdec
else self.max_length
)
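The max_input_len computation above budgets the context window so that the requested number of generated tokens still fits; a minimal standalone sketch of that arithmetic, assuming a decoder-only model with an illustrative 2048-token context:

    # Decoder-only models share one context window between prompt and completion,
    # so the prompt must be truncated to leave room for max_new_tokens.
    # Encoder-decoder models keep the full encoder length, hence the else branch.
    max_length = 2048        # illustrative context window, not from the repo
    max_new_tokens = 30      # default used in the hunk above
    is_encdec = False

    max_input_len = max_length - max_new_tokens if not is_encdec else max_length
    assert max_input_len == 2018   # prompt budget left for a decoder-only model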
@@ -152,15 +162,6 @@
return_tensors="pt",
)
encoded_prompt = encoded_prompt.to(self.device)
# set generation params
max_new_tokens = kwargs.get("max_new_tokens", 30)
temperature = kwargs.get("temperature", 1.0)
top_k = kwargs.get("top_k", 50)
top_p = kwargs.get("top_p", 1)
repetition_penalty = kwargs.get("repetition_penalty", 1)
do_sample = kwargs.get("do_sample", False)
num_return_sequences = kwargs.get("num_return_sequences", 1)
print(f"Generating with parameters: max_new_tokens={max_new_tokens}, temperature={temperature}, top_k={top_k}, top_p={top_p}, repetition_penalty={repetition_penalty}, do_sample={do_sample}, num_return_sequences={num_return_sequences}")
output_dict = self.model.generate( # type: ignore
input_ids=encoded_prompt['input_ids'],
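The hunk is cut off mid-call, but the parameters collected above map directly onto Hugging Face generate kwargs; a hedged, self-contained sketch with an illustrative model and prompt (not the repo's actual pipeline code):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")   # illustrative model, not the one in the config
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    encoded_prompt = tokenizer("Hello, world", return_tensors="pt").to(device)
    output_dict = model.generate(
        input_ids=encoded_prompt["input_ids"],
        attention_mask=encoded_prompt["attention_mask"],
        max_new_tokens=30,
        temperature=1.0,               # sampling params only take effect when do_sample=True
        top_k=50,
        top_p=1.0,
        repetition_penalty=1.0,
        do_sample=False,
        num_return_sequences=1,
        return_dict_in_generate=True,  # assumed, since the result is named output_dict
    )
    print(tokenizer.decode(output_dict.sequences[0], skip_special_tokens=True))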
