fix: pass do sample generation (#118)

main
Laurel Orr 4 months ago committed by GitHub
parent 637fb147f7
commit c84b2fd10f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@@ -10,7 +10,7 @@ format:
black manifest/ tests/ web_app/
check:
isort -c -v manifest/ tests/ web_app/
isort -c manifest/ tests/ web_app/
black manifest/ tests/ web_app/ --check
flake8 manifest/ tests/ web_app/
mypy manifest/ tests/ web_app/

@@ -75,7 +75,7 @@ class DiffuserModel(Model):
@torch.no_grad()
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float, List[int], List[float]]]:
) -> List[Tuple[Any, float, List[str], List[float]]]:
"""
Generate the prompt from model.

@@ -132,7 +132,7 @@ class GenerationPipeline:
def __call__(
self, text: Union[str, List[str]], **kwargs: Any
) -> List[Dict[str, Union[str, List[float]]]]:
) -> List[Dict[str, Union[str, List[float], List[str]]]]:
"""Generate from text.
Args:
@@ -162,6 +162,7 @@ class GenerationPipeline:
top_p=kwargs.get("top_p"),
repetition_penalty=kwargs.get("repetition_penalty"),
num_return_sequences=kwargs.get("num_return_sequences"),
do_sample=kwargs.get("do_sample"),
)
kwargs_to_pass = {k: v for k, v in kwargs_to_pass.items() if v is not None}
output_dict = self.model.generate( # type: ignore
@@ -587,7 +588,7 @@ class TextGenerationModel(HuggingFaceModel):
@torch.no_grad()
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float, List[int], List[float]]]:
) -> List[Tuple[Any, float, List[str], List[float]]]:
"""
Generate the prompt from model.
@@ -616,7 +617,7 @@ class TextGenerationModel(HuggingFaceModel):
(
cast(str, r["generated_text"]),
sum(cast(List[float], r["logprobs"])),
cast(List[int], r["tokens"]),
cast(List[str], r["tokens"]),
cast(List[float], r["logprobs"]),
)
for r in result

@@ -45,7 +45,7 @@ class Model:
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float, List[int], List[float]]]:
) -> List[Tuple[Any, float, List[str], List[float]]]:
"""
Generate the prompt from model.

@@ -66,7 +66,7 @@ class SentenceTransformerModel(Model):
@torch.no_grad()
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float, List[int], List[float]]]:
) -> List[Tuple[Any, float, List[str], List[float]]]:
"""
Generate the prompt from model.

Loading…
Cancel
Save