manifest/manifest/api/app.py


"""Flask app."""
import argparse
import io
import json
import logging
import os
import socket
from typing import Dict

import pkg_resources
from flask import Flask, Response, request

from manifest.api.models.diffuser import DiffuserModel
from manifest.api.models.huggingface import (
    MODEL_GENTYPE_REGISTRY,
    CrossModalEncoderModel,
    TextGenerationModel,
)
from manifest.api.response import ModelResponse

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logger = logging.getLogger(__name__)

app = Flask(__name__)  # define app using Flask

# Will be global
model = None
model_type = None
PORT = int(os.environ.get("FLASK_PORT", 5000))

MODEL_CONSTRUCTORS = {
    "huggingface": TextGenerationModel,
    "huggingface_crossmodal": CrossModalEncoderModel,
    "diffuser": DiffuserModel,
}


def parse_args() -> argparse.Namespace:
    """Generate args."""
    parser = argparse.ArgumentParser(description="Model args")
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type used for finding constructor.",
        choices=MODEL_CONSTRUCTORS.keys(),
    )
    parser.add_argument(
        "--model_generation_type",
        default=None,
        type=str,
        help="Model generation type.",
        choices=MODEL_GENTYPE_REGISTRY.keys(),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        help="Name of model or path to model. Used in initialize of model class.",
    )
    parser.add_argument(
        "--cache_dir", default=None, type=str, help="Cache directory for models."
    )
    parser.add_argument(
        "--device", type=int, default=0, help="Model device. -1 for CPU."
    )
    parser.add_argument(
        "--fp16", action="store_true", help="Force use fp16 for model params."
    )
    parser.add_argument(
        "--percent_max_gpu_mem_reduction",
        type=float,
        default=0.85,
        help="Used with accelerate multigpu. Scales down max memory.",
    )
    parser.add_argument(
        "--use_bitsandbytes",
        action="store_true",
        help="Use bits and bytes. This will override --device parameter.",
    )
    parser.add_argument(
        "--use_accelerate_multigpu",
        action="store_true",
        help=(
            "Use accelerate for multi gpu inference. "
            "This will override --device parameter."
        ),
    )
    parser.add_argument(
        "--use_hf_parallelize",
        action="store_true",
        help=(
            "Use HF parallelize for multi gpu inference. "
            "This will override --device parameter."
        ),
    )
    parser.add_argument(
        "--use_deepspeed",
        action="store_true",
        help="Use deepspeed. This will override --device parameter.",
    )
    args = parser.parse_args()
    return args


def is_port_in_use(port: int) -> bool:
    """Check if port is in use."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(("localhost", port)) == 0


def main() -> None:
    """Run main."""
    kwargs = parse_args()
    if is_port_in_use(PORT):
        raise ValueError(f"Port {PORT} is already in use.")
    global model_type
    model_type = kwargs.model_type
    model_gen_type = kwargs.model_generation_type
    model_name_or_path = kwargs.model_name_or_path
    if not model_name_or_path:
        raise ValueError("Must provide model_name_or_path.")
    if kwargs.use_accelerate_multigpu:
        logger.info("Using accelerate. Overriding --device argument.")
        if (
            kwargs.percent_max_gpu_mem_reduction <= 0
            or kwargs.percent_max_gpu_mem_reduction > 1
        ):
            raise ValueError("percent_max_gpu_mem_reduction must be in (0, 1].")
    # The multi-GPU / quantization strategies are mutually exclusive.
    if (
        sum(
            [
                kwargs.use_accelerate_multigpu,
                kwargs.use_hf_parallelize,
                kwargs.use_bitsandbytes,
                kwargs.use_deepspeed,
            ]
        )
        > 1
    ):
        raise ValueError(
            "Only one of use_accelerate_multigpu, use_hf_parallelize, "
            "use_bitsandbytes, and use_deepspeed can be set."
        )
    # Global model
    global model
    model = MODEL_CONSTRUCTORS[model_type](
        model_name_or_path,
        model_type=model_gen_type,
        cache_dir=kwargs.cache_dir,
        device=kwargs.device,
        use_accelerate=kwargs.use_accelerate_multigpu,
        use_parallelize=kwargs.use_hf_parallelize,
        use_bitsandbytes=kwargs.use_bitsandbytes,
        use_deepspeed=kwargs.use_deepspeed,
        perc_max_gpu_mem_red=kwargs.percent_max_gpu_mem_reduction,
        use_fp16=kwargs.fp16,
    )
    app.run(host="0.0.0.0", port=PORT)
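

# Example launch command (an illustrative sketch, not part of the original
# file): assuming the manifest package is installed, the server is typically
# started as a module, with the flags defined in parse_args() above. The model
# name placeholder below is hypothetical; any supported name or local path works.
#
#   python -m manifest.api.app --model_type huggingface \
#       --model_name_or_path <hf-model-name-or-local-path> --device 0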
@app.route("/completions", methods=["POST"])
def completions() -> Response:
"""Get completions for generation."""
prompt = request.json["prompt"]
del request.json["prompt"]
generation_args = request.json
if not isinstance(prompt, (str, list)):
raise ValueError("Prompt must be a str or list of str")
try:
result_gens = []
for generations in model.generate(prompt, **generation_args):
result_gens.append(generations)
if model_type == "diffuser":
# Assign None logprob as it's not supported in diffusers
results = [
{"array": r[0], "logprob": None, "token_logprobs": None}
for r in result_gens
]
res_type = "image_generation"
else:
results = [
{"text": r[0], "logprob": r[1], "token_logprobs": r[2]}
for r in result_gens
]
res_type = "text_completion"
# transform the result into the openai format
return Response(
json.dumps(ModelResponse(results, response_type=res_type).__dict__()),
status=200,
)
except Exception as e:
logger.error(e)
return Response(
json.dumps({"message": str(e)}),
status=400,
)
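

# Example request to /completions (illustrative sketch, not part of the
# original file). "prompt" is required and may be a string or list of strings;
# any other JSON keys are forwarded as keyword arguments to model.generate, so
# the supported options depend on the loaded model class. The port assumes the
# default FLASK_PORT of 5000.
#
#   curl -X POST http://localhost:5000/completions \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello world"}'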
@app.route("/embed", methods=["POST"])
def embed() -> Dict:
"""Get embed for generation."""
modality = request.json["modality"]
if modality == "text":
prompts = request.json["prompts"]
elif modality == "image":
import base64
from PIL import Image
prompts = [
Image.open(io.BytesIO(base64.b64decode(data)))
for data in request.json["prompts"]
]
else:
raise ValueError("modality must be text or image")
results = []
embeddings = model.embed(prompts)
for embedding in embeddings:
results.append(embedding.tolist())
# transform the result into the openai format
# return Response(results, response_type="text_completion").__dict__()
return {"result": results}
@app.route("/score_sequence", methods=["POST"])
def score_sequence() -> Response:
"""Get logprob of prompt."""
prompt = request.json["prompt"]
del request.json["prompt"]
generation_args = request.json
if not isinstance(prompt, (str, list)):
raise ValueError("Prompt must be a str or list of str")
try:
score_list = model.score_sequence(prompt, **generation_args)
results = [
{
"text": prompt if isinstance(prompt, str) else prompt[i],
"logprob": r[0],
"token_logprobs": r[1],
}
for i, r in enumerate(score_list)
]
# transform the result into the openai format
return Response(
json.dumps(
ModelResponse(results, response_type="prompt_logit_score").__dict__()
),
status=200,
)
except Exception as e:
logger.error(e)
return Response(
json.dumps({"message": str(e)}),
status=400,
)
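

# Example request to /score_sequence (illustrative sketch). As with
# /completions, "prompt" may be a string or list of strings, and any extra
# JSON keys are passed through to model.score_sequence.
#
#   curl -X POST http://localhost:5000/score_sequence \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": ["The cat sat on the mat."]}'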
@app.route("/params", methods=["POST"])
def params() -> Dict:
"""Get model params."""
return model.get_init_params()
@app.route("/")
def index() -> str:
"""Get index completion."""
fn = pkg_resources.resource_filename("metaseq", "service/index.html")
with open(fn) as f:
return f.read()


if __name__ == "__main__":
    main()