Add PostgreSQL cache (#53)

Add a Cache for PostgreSQL with GCP. 

Co-authored-by: Laurel Orr <lorr1@cs.stanford.edu>
pull/82/head
Sabri Eyuboglu 1 year ago committed by GitHub
parent c4ad007f02
commit e00d285e21

@ -0,0 +1,124 @@
"""Postgres cache."""
import logging
from typing import Any, Dict, Union
# Module logger. Its level is pinned to WARNING, so INFO records sent through it
# (e.g. the "Creating database..." message emitted while connecting) are
# filtered out unless a caller lowers the level of the "postgresql" logger.
logger = logging.getLogger("postgresql")
logger.setLevel(logging.WARNING)
from ..caches.cache import Cache
# The GCP/PostgreSQL stack is an optional extra. Any import failure is recorded
# here instead of raised, so importing this module never fails; the stored
# error is surfaced later, when a connection is actually attempted.
missing_dependencies: Union[ImportError, None] = None
try:
    import sqlalchemy  # type: ignore
    from google.cloud.sql.connector import Connector  # type: ignore
    from sqlalchemy import Column, String  # type: ignore
    from sqlalchemy.ext.declarative import declarative_base  # type: ignore
    from sqlalchemy.orm import sessionmaker  # type: ignore

    Base = declarative_base()

    class Request(Base):  # type: ignore
        """ORM model for the ``requests`` table (one row per cached key)."""

        __tablename__ = "requests"

        # Cache key; doubles as the primary key of the table.
        key = Column(String, primary_key=True)
        # FIXME: ideally should be an hstore, but I don't want to set it up on GCP
        response = Column(String)

except ImportError as import_error:
    missing_dependencies = import_error
class PostgresCache(Cache):
    """A PostgreSQL cache for request/response pairs.

    Entries are stored as rows of the ``requests`` table. The logical cache
    table is folded into the stored key as ``"<table>:<key>"``, so equal keys
    in different tables never overwrite each other.
    """

    def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
        """
        Connect to client.

        Args:
            connection_str: connection string.
            cache_args: arguments for cache should include the following fields:
                {
                    "cache_user": "",
                    "cache_password": "",
                    "cache_db": ""
                }
                These three fields are removed from ``cache_args``.

        Raises:
            ValueError: if the optional GCP/PostgreSQL dependencies are missing.
            KeyError: if one of the required ``cache_args`` fields is absent.
        """
        if missing_dependencies:
            raise ValueError(
                "Missing dependencies for GCP PostgreSQL cache. "
                "Install with `pip install manifest[gcp]`",
                missing_dependencies,
            ) from missing_dependencies

        # Extract the credentials exactly once. The engine invokes ``getconn``
        # for every new pooled connection, so popping inside the callback
        # would raise KeyError on the second connection.
        user = cache_args.pop("cache_user")
        password = cache_args.pop("cache_password")
        db = cache_args.pop("cache_db")

        connector = Connector()

        def getconn() -> Any:
            """Open a new pg8000 connection through the Cloud SQL connector."""
            return connector.connect(
                connection_str,
                "pg8000",
                user=user,
                password=password,
                db=db,
            )

        engine = sqlalchemy.create_engine(
            "postgresql+pg8000://",
            creator=getconn,
        )
        engine.dialect.description_encoding = None  # type: ignore

        # Check for the cache table itself rather than for "any table": a
        # database that already holds unrelated tables still needs ``requests``.
        existing_tables = sqlalchemy.inspect(engine).get_table_names()
        if Request.__tablename__ not in existing_tables:
            logger.info("Creating database...")
            Base.metadata.create_all(engine)

        self.session = sessionmaker(bind=engine)()

    def close(self) -> None:
        """Close the client."""
        self.session.close()

    def _normalize_table_key(self, key: str, table: str) -> str:
        """Cast key for prompt key."""
        return f"{table}:{key}"

    def get_key(self, key: str, table: str = "default") -> Union[str, None]:
        """
        Get the key for a request.

        Will return None if key is not in cache.

        Args:
            key: key for cache.
            table: table to get key in.
        """
        norm_key = self._normalize_table_key(key, table)
        request = self.session.query(Request).filter_by(key=norm_key).first()
        return request.response if request else None  # type: ignore

    def set_key(self, key: str, value: str, table: str = "default") -> None:
        """
        Set the value for the key.

        Will override old value. The change is committed immediately.

        Args:
            key: key for cache.
            value: new value for key.
            table: table to set key in.
        """
        norm_key = self._normalize_table_key(key, table)
        request = self.session.query(Request).filter_by(key=norm_key).first()
        if request:
            request.response = value  # type: ignore
        else:
            self.session.add(Request(key=norm_key, response=value))
        self.commit()

    def commit(self) -> None:
        """Commit any results."""
        self.session.commit()

@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union, cast
import numpy as np
from manifest.caches.noop import NoopCache
from manifest.caches.postgres import PostgresCache
from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
from manifest.clients.ai21 import AI21Client
@ -55,6 +56,7 @@ CACHE_CONSTRUCTORS = {
"redis": RedisCache,
"sqlite": SQLiteCache,
"noop": NoopCache,
"postgres": PostgresCache,
}

@ -56,6 +56,11 @@ EXTRAS = {
"diffusers": [
"pillow>=9.0.0",
],
"gcp": [
"pg8000",
"cloud-sql-python-connector[pg8000]>=1.0.0",
"sqlalchemy",
],
"dev": [
"autopep8>=1.6.0",
"black>=22.3.0",

@ -31,6 +31,19 @@ def redis_cache() -> Generator[str, None, None]:
pass
@pytest.fixture
def postgres_cache(monkeypatch: pytest.MonkeyPatch) -> Generator[str, None, None]:
    """Redirect SQLAlchemy engine creation to an in-memory SQLite engine."""
    import sqlalchemy  # type: ignore

    # Build the real engine first, while create_engine is still unpatched.
    sqlite_url = sqlalchemy.engine.url.URL.create("sqlite", database=":memory:")
    in_memory_engine = sqlalchemy.create_engine(sqlite_url)

    def _fake_create_engine(*args, **kwargs):  # type: ignore
        """Ignore all arguments and hand back the shared SQLite engine."""
        return in_memory_engine

    # From here on, any code calling sqlalchemy.create_engine gets the SQLite
    # engine instead of opening a real database connection.
    monkeypatch.setattr(sqlalchemy, "create_engine", _fake_create_engine)
    return in_memory_engine  # type: ignore
@pytest.fixture
def session_cache(tmpdir: str) -> Generator[Path, None, None]:
"""Session cache dir."""

@ -8,32 +8,53 @@ from sqlitedict import SqliteDict
from manifest.caches.cache import Cache
from manifest.caches.noop import NoopCache
from manifest.caches.postgres import PostgresCache
from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
def _get_postgres_cache(**kwargs) -> Cache:  # type: ignore
    """Build a PostgresCache with empty placeholder credentials.

    Extra keyword arguments are forwarded to the PostgresCache constructor.
    """
    placeholder_credentials = {
        "cache_user": "",
        "cache_password": "",
        "cache_db": "",
    }
    cache = PostgresCache("postgres", cache_args=placeholder_credentials, **kwargs)
    return cache
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache")
@pytest.mark.usefixtures("postgres_cache")
@pytest.mark.parametrize("cache_type", ["sqlite", "redis", "postgres"])
def test_init(
    sqlite_cache: str, redis_cache: str, postgres_cache: str, cache_type: str
) -> None:
    """Test cache initialization for every supported backend."""
    if cache_type == "sqlite":
        sql_cache_obj = SQLiteCache(sqlite_cache)
        assert isinstance(sql_cache_obj.cache, SqliteDict)
    elif cache_type == "redis":
        redis_cache_obj = RedisCache(redis_cache)
        assert isinstance(redis_cache_obj.redis, Redis)
    elif cache_type == "postgres":
        postgres_cache_obj = _get_postgres_cache()
        # The check must be asserted; a bare isinstance() call verifies nothing.
        assert isinstance(postgres_cache_obj, PostgresCache)
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache")
@pytest.mark.parametrize("cache_type", ["sqlite", "redis"])
def test_key_get_and_set(sqlite_cache: str, redis_cache: str, cache_type: str) -> None:
@pytest.mark.usefixtures("postgres_cache")
@pytest.mark.parametrize("cache_type", ["sqlite", "postgres", "redis"])
def test_key_get_and_set(
sqlite_cache: str, redis_cache: str, postgres_cache: str, cache_type: str
) -> None:
"""Test cache key get and set."""
if cache_type == "sqlite":
cache = cast(Cache, SQLiteCache(sqlite_cache))
else:
elif cache_type == "redis":
cache = cast(Cache, RedisCache(redis_cache))
elif cache_type == "postgres":
cache = cast(Cache, _get_postgres_cache())
cache.set_key("test", "valueA")
cache.set_key("testA", "valueB")
@ -50,13 +71,19 @@ def test_key_get_and_set(sqlite_cache: str, redis_cache: str, cache_type: str) -
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache")
@pytest.mark.parametrize("cache_type", ["sqlite", "redis"])
def test_get(sqlite_cache: str, redis_cache: str, cache_type: str) -> None:
@pytest.mark.usefixtures("postgres_cache")
@pytest.mark.parametrize("cache_type", ["sqlite", "redis", "postgres"])
def test_get(
sqlite_cache: str, redis_cache: str, postgres_cache: str, cache_type: str
) -> None:
"""Test cache save prompt."""
if cache_type == "sqlite":
cache = cast(Cache, SQLiteCache(sqlite_cache))
else:
elif cache_type == "redis":
cache = cast(Cache, RedisCache(redis_cache))
elif cache_type == "postgres":
cache = cast(Cache, _get_postgres_cache())
test_request = {"test": "hello", "testA": "world"}
compute = lambda: {"choices": [{"text": "hello"}]}
@ -82,8 +109,11 @@ def test_get(sqlite_cache: str, redis_cache: str, cache_type: str) -> None:
# Test array
if cache_type == "sqlite":
cache = SQLiteCache(sqlite_cache, client_name="diffuser")
else:
elif cache_type == "redis":
cache = RedisCache(redis_cache, client_name="diffuser")
elif cache_type == "postgres":
cache = _get_postgres_cache(client_name="diffuser")
response = cache.get(test_request, overwrite_cache=False, compute=compute_arr)
assert np.allclose(response.get_response(), arr)
assert not response.is_cached()

Loading…
Cancel
Save