feat: toma diffusers support (#48)

pull/82/head
Laurel Orr 1 year ago committed by GitHub
parent 56eae406ce
commit 876d27bd2d

@@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -28,30 +28,14 @@
"manifest = Manifest(\n",
" client_name=\"chatgpt\",\n",
" client_connection=os.environ.get(\"CHATGPT_SESSION_KEY\"),\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sure! Pickling is a way to save things, like food or toys, so that they can be used later. Imagine you have a toy that you really like, but you have to go to school and can't play with it. You can put the toy in a special jar and close the lid tight to keep it safe until you get home. That's kind of like pickling. You're taking something that you want to save, and putting it in a special container so it won't go bad or get lost. Just like the toy in the jar, pickled food can last a long time without going bad.\n"
]
}
],
"source": [
"print(manifest.run(\"Can you explain the pickling process to a four-year old?\"))\n"
")\n",
"print(manifest.run(\"Describe in a single, short sentence what is the best sandwhich in the world. Be short and concise.\"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "bootleg",
"display_name": "mlcore",
"language": "python",
"name": "python3"
},
@@ -65,12 +49,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
"version": "3.10.0"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "7a3f97ab0465937066e9b79893b779dfc8a12d73c41f9d98a7bf05133c798250"
"hash": "1ea9cc00d433352044b557b1784ac6e58df03de4b7bb312554014351989eb135"
}
}
},

File diff suppressed because one or more lines are too long

@@ -1,6 +1,6 @@
"""Huggingface model."""
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from diffusers import StableDiffusionPipeline
@@ -14,14 +14,15 @@ class DiffuserModel(Model):
def __init__(
self,
model_name_or_path: str,
model_config: str = None,
cache_dir: str = None,
model_config: Optional[str] = None,
cache_dir: Optional[str] = None,
device: int = 0,
use_accelerate: bool = False,
use_parallelize: bool = False,
use_bitsandbytes: bool = False,
use_deepspeed: bool = False,
perc_max_gpu_mem_red: float = 1.0,
use_fp16: bool = True,
use_fp16: bool = False,
):
"""
Initialize model.
@@ -36,12 +37,14 @@ class DiffuserModel(Model):
use_accelerate: whether to use accelerate for multi-gpu inference.
use_parallelize: use HF default parallelize
use_bitsandbytes: use HF bits and bytes
use_deepspeed: use deepspeed
perc_max_gpu_mem_red: percent max memory reduction in accelerate
use_fp16: use fp16 for model weights.
"""
if use_accelerate or use_parallelize or use_bitsandbytes:
if use_accelerate or use_parallelize or use_bitsandbytes or use_deepspeed:
raise ValueError(
"Cannot use accelerate or parallelize or bitsandbytes with diffusers"
"Cannot use accelerate or parallelize or "
"bitsandbytes or deepspeeed with diffusers"
)
# Check if providing path
self.model_path = model_name_or_path
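A quick sketch of how the tightened constructor behaves; the import path and model id below are assumptions for illustration, not part of this diff:

from manifest.api.models.diffuser import DiffuserModel  # assumed module path

# fp16 now defaults to off, so this loads full-precision weights onto GPU 0.
model = DiffuserModel("runwayml/stable-diffusion-v1-5", device=0)

# Any of the unsupported acceleration flags should trip the new guard.
try:
    DiffuserModel("runwayml/stable-diffusion-v1-5", use_deepspeed=True)
except ValueError as err:
    print(err)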

@@ -60,7 +60,6 @@ class Model(ABC):
"""
raise NotImplementedError()
@abstractmethod
def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
"""
Compute embedding for prompts.

@@ -11,9 +11,17 @@ RESPONSE_CONSTRUCTORS = {
"logits_key": "logprobs",
"item_key": "array",
},
"tomadiffuser": {
"generation_key": "choices",
"logits_key": "logprobs",
"item_key": "array",
},
}
CACHE_CONSTRUCTOR = {"diffuser": ArraySerializer}
CACHE_CONSTRUCTOR = {
"diffuser": ArraySerializer,
"tomadiffuser": ArraySerializer,
}
class Cache(ABC):
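A minimal sketch of how the new registry entry might be consumed when a cache is set up for the TOMADiffuser client; the helper below is illustrative, not the library's actual factory:

def pick_serializer(client_name: str):
    # Diffusion clients return numpy arrays, so they map to ArraySerializer;
    # other clients can fall back to whatever default the cache uses.
    serializer_cls = CACHE_CONSTRUCTOR.get(client_name)
    return serializer_cls() if serializer_cls is not None else None

serializer = pick_serializer("tomadiffuser")  # ArraySerializer instance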

@@ -22,6 +22,7 @@ class DiffuserClient(Client):
"n": ("num_images_per_prompt", 1),
"guidance_scale": ("guidance_scale", 7.5),
"eta": ("eta", 0.0),
"client_timeout": ("client_timeout", 120), # seconds
}
REQUEST_CLS = DiffusionRequest
@@ -42,6 +43,14 @@ class DiffuserClient(Client):
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
self.model_params = self.get_model_params()
def to_numpy(self, image: np.ndarray) -> np.ndarray:
"""Convert a numpy image to a PIL image.
Adapted from https://github.com/huggingface/diffusers/blob/src/diffusers/pipelines/pipeline_utils.py#L808 # noqa: E501
"""
image = (image * 255).round().astype("uint8")
return image
def close(self) -> None:
"""Close the client."""
pass
@@ -88,5 +97,5 @@ class DiffuserClient(Client):
"""
# Convert array to np.array
for choice in response["choices"]:
choice["array"] = np.array(choice["array"])
choice["array"] = self.to_numpy(np.array(choice["array"]))
return response
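What to_numpy does is a plain scale-and-round from the pipeline's float output to uint8; a standalone sketch with a synthetic image in place of a real sample:

import numpy as np

float_image = np.random.rand(512, 512, 3)  # stand-in for one decoded sample in [0, 1]
uint8_image = (float_image * 255).round().astype("uint8")
assert uint8_image.dtype == np.uint8 and uint8_image.max() <= 255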

@@ -14,7 +14,6 @@ logger = logging.getLogger(__name__)
# Engines are dynamically instantiated from API
# but a few example engines are listed below.
TOMA_ENGINES = {
# "StableDiffusion",
"Together-gpt-JT-6B-v1",
}
@@ -24,7 +23,7 @@ class TOMAClient(Client):
# User param -> (client param, default value)
PARAMS = {
"engine": ("model", "gpt-j-6b"),
"engine": ("model", "Together-gpt-JT-6B-v1"),
"temperature": ("temperature", 0.1),
"max_tokens": ("max_tokens", 32),
# n is deprecated with new API but will come back online soon

@@ -0,0 +1,73 @@
"""TOMA client."""
import base64
import io
import logging
from typing import Any, Dict
import numpy as np
from PIL import Image
from manifest.clients.toma import TOMAClient
from manifest.request import DiffusionRequest
logger = logging.getLogger(__name__)
# Engines are dynamically instantiated from API
# but a few example engines are listed below.
TOMA_ENGINES = {
"StableDiffusion",
}
class TOMADiffuserClient(TOMAClient):
"""TOMADiffuser client."""
# User param -> (client param, default value)
PARAMS = {
"engine": ("model", "StableDiffusion"),
"num_inference_steps": ("steps", 50),
"height": ("height", 512),
"width": ("width", 512),
"n": ("n", 1),
"guidance_scale": ("guidance_scale", 7.5),
"client_timeout": ("client_timeout", 120), # seconds
}
REQUEST_CLS = DiffusionRequest # type: ignore
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Returns:
model params.
"""
return {"model_name": "tomadiffuser", "engine": getattr(self, "engine")}
def format_response(self, response: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
Return:
response as dict
"""
return {
"model": getattr(self, "engine"),
"choices": [
{
"array": np.array(
Image.open(
io.BytesIO(
base64.decodebytes(bytes(item["image_base64"], "utf-8"))
)
)
),
}
for item in response["output"]["choices"]
],
}
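The heart of format_response is the base64-to-array round trip; a self-contained sketch with a synthetic payload standing in for a real TOMA response item:

import base64
import io

import numpy as np
from PIL import Image

# Fake the "image_base64" field the way the server would populate it.
buffer = io.BytesIO()
Image.new("RGB", (64, 64), color=(255, 0, 0)).save(buffer, format="PNG")
item = {"image_base64": base64.b64encode(buffer.getvalue()).decode("utf-8")}

# Decode it the same way format_response does above.
decoded = Image.open(io.BytesIO(base64.decodebytes(bytes(item["image_base64"], "utf-8"))))
array = np.array(decoded)
print(array.shape)  # (64, 64, 3)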

@@ -15,6 +15,7 @@ from manifest.clients.dummy import DummyClient
from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.openai import OpenAIClient
from manifest.clients.toma import TOMAClient
from manifest.clients.toma_diffuser import TOMADiffuserClient
from manifest.response import Response
from manifest.session import Session
@@ -30,6 +31,7 @@ CLIENT_CONSTRUCTORS = {
"diffuser": DiffuserClient,
"dummy": DummyClient,
"toma": TOMAClient,
"tomadiffuser": TOMADiffuserClient,
}
CACHE_CONSTRUCTORS = {
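With the constructor registered, the new client should be selectable like any other; a hedged usage sketch (the connection value and prompt are placeholders, not documented defaults):

import os

from manifest import Manifest

manifest = Manifest(
    client_name="tomadiffuser",
    client_connection=os.environ.get("TOMA_API_KEY"),  # placeholder credential
)
# For diffusion clients the result comes back as an image array rather than text.
image = manifest.run("a watercolor painting of a lighthouse at dusk")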
