mirror of https://github.com/HazyResearch/manifest
Add PostgreSQL cache (#53)
Add a Cache for PostgreSQL with GCP. Co-authored-by: Laurel Orr <lorr1@cs.stanford.edu> pull/82/head
parent
c4ad007f02
commit
e00d285e21
@ -0,0 +1,124 @@
|
||||
"""Postgres cache."""
|
||||
import logging
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
# Module logger, registered under the fixed name "postgresql". Its threshold
# is WARNING, so the `logger.info` call made while connecting is not emitted
# unless the level is lowered elsewhere.
logger = logging.getLogger("postgresql")
logger.setLevel(logging.WARNING)
|
||||
|
||||
from ..caches.cache import Cache
|
||||
|
||||
# The GCP / SQLAlchemy stack is optional. Import it eagerly, but remember a
# failure instead of raising, so that importing this module always succeeds;
# `PostgresCache.connect` reports the stored error when the cache is used.
try:
    import sqlalchemy  # type: ignore
    from google.cloud.sql.connector import Connector  # type: ignore
    from sqlalchemy import Column, String  # type: ignore
    from sqlalchemy.ext.declarative import declarative_base  # type: ignore
    from sqlalchemy.orm import sessionmaker  # type: ignore

    # Declarative base that owns the metadata for the cache table below.
    Base = declarative_base()

    class Request(Base):  # type: ignore
        """The request table."""

        __tablename__ = "requests"

        # Cache key; primary key of the table.
        key = Column(String, primary_key=True)
        # FIXME: ideally should be an hstore, but I don't want to set it up on GCP
        response = Column(String)

    # Every optional import succeeded.
    missing_dependencies = None

except ImportError as import_error:
    # Keep the original error so it can be attached to the user-facing message.
    missing_dependencies = import_error
|
||||
|
||||
|
||||
class PostgresCache(Cache):
    """A PostgreSQL cache for request/response pairs.

    Entries are stored as rows of the ``requests`` table (see ``Request``)
    and accessed through a SQLAlchemy session. Database connections are
    opened by the Google Cloud SQL connector using the ``pg8000`` driver.
    """

    def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
        """
        Connect to client.

        Args:
            connection_str: connection string.
            cache_args: arguments for cache should include the following fields:
                {
                    "cache_user": "",
                    "cache_password": "",
                    "cache_db": ""
                }
                These three fields are removed from ``cache_args``.

        Raises:
            ValueError: if the optional GCP dependencies failed to import.
            KeyError: if a required field is missing from ``cache_args``.
        """
        if missing_dependencies:
            raise ValueError(
                "Missing dependencies for GCP PostgreSQL cache. "
                "Install with `pip install manifest[gcp]`",
                missing_dependencies,
            ) from missing_dependencies

        # Read the credentials exactly once. The engine invokes `getconn` for
        # every new pooled connection, so popping inside the callback would
        # raise KeyError on the second connection.
        user = cache_args.pop("cache_user")
        password = cache_args.pop("cache_password")
        db = cache_args.pop("cache_db")

        connector = Connector()

        def getconn() -> Any:
            """Open one new database connection through the connector."""
            return connector.connect(
                connection_str,
                "pg8000",
                user=user,
                password=password,
                db=db,
            )

        engine = sqlalchemy.create_engine(
            "postgresql+pg8000://",
            creator=getconn,
        )
        engine.dialect.description_encoding = None  # type: ignore

        # Check for the cache table itself rather than for "any table": a
        # database that already holds unrelated tables still needs `requests`.
        existing_tables = sqlalchemy.inspect(engine).get_table_names()
        if Request.__tablename__ not in existing_tables:
            logger.info("Creating database...")
            Base.metadata.create_all(engine)

        self.session = sessionmaker(bind=engine)()

    def close(self) -> None:
        """Close the client."""
        self.session.close()

    def _normalize_table_key(self, key: str, table: str) -> str:
        """Cast key for prompt key."""
        return f"{table}:{key}"

    def get_key(self, key: str, table: str = "default") -> Union[str, None]:
        """
        Get the key for a request.

        Will return None if key is not in cache.

        Args:
            key: key for cache.
            table: table to get key in.
        """
        # Namespace the key by table so equal keys in different tables do
        # not collide in the single `requests` table.
        table_key = self._normalize_table_key(key, table)
        request = self.session.query(Request).filter_by(key=table_key).first()
        out = request.response if request else None
        return out  # type: ignore

    def set_key(self, key: str, value: str, table: str = "default") -> None:
        """
        Set the value for the key.

        Will override old value.

        Args:
            key: key for cache.
            value: new value for key.
            table: table to set key in.
        """
        # Same table namespacing as in `get_key`.
        table_key = self._normalize_table_key(key, table)
        request = self.session.query(Request).filter_by(key=table_key).first()
        if request:
            request.response = value  # type: ignore
        else:
            self.session.add(Request(key=table_key, response=value))
        self.commit()

    def commit(self) -> None:
        """Commit any results."""
        self.session.commit()
|
Loading…
Reference in New Issue