You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

88 lines
2.4 KiB

"""Model class."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple, Union
import numpy as np
class Model(ABC):
    """Abstract base class defining the model interface.

    Every method here raises ``NotImplementedError``; concrete subclasses
    must override all of them before they can be instantiated.
    """

    @abstractmethod
    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: str,
        device: int,
        use_accelerate: bool,
        use_parallelize: bool,
        use_bitsandbytes: bool,
        use_deepspeed: bool,
        perc_max_gpu_mem_red: float,
        use_fp16: bool,
    ) -> None:
        """
        Initialize model.

        All arguments will be passed in the request from Manifest.

        Args:
            model_name_or_path: model name string.
            cache_dir: cache directory for model.
            device: device to use for model.
            use_accelerate: whether to use accelerate for multi-gpu inference.
            use_parallelize: use HF default parallelize
            use_bitsandbytes: use HF bits and bytes
            use_deepspeed: use deepspeed
            perc_max_gpu_mem_red: percent max memory reduction in accelerate
            use_fp16: use fp16 for model weights.

        Raises:
            NotImplementedError: always; subclasses provide the real setup.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_init_params(self) -> Dict:
        """Return init params to determine what model is being used."""
        raise NotImplementedError()

    @abstractmethod
    def generate(
        self, prompt: Union[str, List[str]], **kwargs: Any
    ) -> List[Tuple[Any, float]]:
        """
        Generate the prompt from model.

        Outputs must be generated text and score, not including prompt.

        Args:
            prompt: prompt to generate from.
            **kwargs: additional generation arguments (implementation-defined).

        Returns:
            list of generated text (list of length 1 for 1 generation).

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError()

    @abstractmethod
    def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
        """
        Compute embedding for prompts.

        Args:
            prompt: prompt to generate from.
            **kwargs: additional embedding arguments (implementation-defined).

        Returns:
            embedding as a numpy array.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError()

    @abstractmethod
    def score_sequence(
        self, prompt: Union[str, List[str]], **kwargs: Any
    ) -> List[float]:
        """
        Score a sequence of choices.

        Args:
            prompt (:obj:`str` or :obj:`List[str]`):
                The prompt to score the choices against.
            **kwargs:
                Additional keyword arguments passed along to the :obj:`__call__` method.

        Returns:
            list of float scores.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError()