mirror of https://github.com/HazyResearch/manifest
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
125 lines
3.6 KiB
Python
125 lines
3.6 KiB
Python
"""Postgres cache."""
|
|
import logging
|
|
from typing import Any, Dict, Union
|
|
|
|
# Module logger. Its level is pinned to WARNING, so INFO/DEBUG messages
# sent through this logger are dropped unless the level is lowered elsewhere.
logger = logging.getLogger(name="postgresql")
logger.setLevel(level=logging.WARNING)
|
|
|
|
from ..caches.cache import Cache
|
|
|
|
try:
    # Optional GCP/PostgreSQL dependencies (the `manifest[gcp]` extra).
    # A failed import is stored in `missing_dependencies` instead of being
    # raised, so this module still imports without them; the error is
    # reported later when `PostgresCache.connect` is called.
    import sqlalchemy  # type: ignore
    from google.cloud.sql.connector import Connector  # type: ignore
    from sqlalchemy import Column, String  # type: ignore
    from sqlalchemy.ext.declarative import declarative_base  # type: ignore
    from sqlalchemy.orm import sessionmaker  # type: ignore

    # Declarative base; its metadata holds the cache table definition below.
    Base = declarative_base()

    class Request(Base):  # type: ignore
        """ORM model for the ``requests`` table mapping cache keys to responses."""

        __tablename__ = "requests"

        # Cache key; it is the primary key, so each key maps to at most one row.
        key = Column(String, primary_key=True)
        # Cached response, stored as a plain string column.
        response = Column(
            String
        )  # FIXME: ideally should be an hstore, but I don't want to set it up on GCP

    # Every optional import succeeded.
    missing_dependencies = None

except ImportError as e:
    # Keep the import failure so it can be reported when the cache is used.
    missing_dependencies = e
|
|
|
|
|
|
class PostgresCache(Cache):
    """A PostgreSQL cache for request/response pairs."""

    def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
        """
        Connect to client.

        Args:
            connection_str: connection string.
            cache_args: arguments for cache should include the following fields:
                {
                    "cache_user": "",
                    "cache_password": "",
                    "cache_db": ""
                }
                These three fields are removed (popped) from ``cache_args``.

        Raises:
            ValueError: if the optional GCP/PostgreSQL dependencies are missing.
            KeyError: if one of the required ``cache_args`` fields is absent.
        """
        if missing_dependencies is not None:
            raise ValueError(
                "Missing dependencies for GCP PostgreSQL cache. "
                "Install with `pip install manifest[gcp]`",
                missing_dependencies,
            ) from missing_dependencies

        # Pop the credentials once, up front. The engine's pool calls
        # `getconn` each time it opens a new DB connection (pool growth,
        # reconnect after a dropped connection); popping inside `getconn`
        # made every call after the first fail with KeyError.
        user = cache_args.pop("cache_user")
        password = cache_args.pop("cache_password")
        db = cache_args.pop("cache_db")

        connector = Connector()
        self._connector = connector

        def getconn() -> Any:
            # Open a fresh pg8000 DBAPI connection through the Cloud SQL connector.
            return connector.connect(
                connection_str,
                "pg8000",
                user=user,
                password=password,
                db=db,
            )

        engine = sqlalchemy.create_engine(
            "postgresql+pg8000://",
            creator=getconn,
        )
        engine.dialect.description_encoding = None  # type: ignore
        self._engine = engine

        # Create the cache table only when it is missing. Checking for this
        # specific table (rather than "any table exists") also works when the
        # database is shared with unrelated tables.
        table_names = sqlalchemy.inspect(engine).get_table_names()
        if Request.__tablename__ not in table_names:
            logger.info("Creating database...")
            Base.metadata.create_all(engine)

        self.session = sessionmaker(bind=engine)()

    def close(self) -> None:
        """Close the session, then release the engine pool and the Cloud SQL connector."""
        self.session.close()
        # Dispose pooled DB connections before shutting down the connector
        # they were opened through.
        engine = getattr(self, "_engine", None)
        if engine is not None:
            engine.dispose()
        connector = getattr(self, "_connector", None)
        if connector is not None:
            connector.close()

    def _normalize_table_key(self, key: str, table: str) -> str:
        """Return ``key`` prefixed with its table name as ``"<table>:<key>"``."""
        return f"{table}:{key}"

    def get_key(self, key: str, table: str = "default") -> Union[str, None]:
        """
        Get the key for a request.

        Will return None if key is not in cache.

        Args:
            key: key for cache.
            table: table to get key in. NOTE: currently ignored; all entries
                live in the single ``requests`` table keyed by the raw key.
                Switching to ``_normalize_table_key`` would orphan rows that
                are already persisted, so the raw key is kept.

        Returns:
            The cached response string, or None on a cache miss.
        """
        row = self.session.query(Request).filter_by(key=key).first()
        return row.response if row is not None else None  # type: ignore

    def set_key(self, key: str, value: str, table: str = "default") -> None:
        """
        Set the value for the key.

        Will override old value. The change is committed immediately.

        Args:
            key: key for cache.
            value: new value for key.
            table: table to set key in. NOTE: currently ignored, for the same
                reason as in ``get_key``.
        """
        row = self.session.query(Request).filter_by(key=key).first()
        if row is not None:
            row.response = value  # type: ignore
        else:
            self.session.add(Request(key=key, response=value))
        self.commit()

    def commit(self) -> None:
        """Commit pending changes; roll the session back if the commit fails."""
        try:
            self.session.commit()
        except Exception:
            # Without a rollback the session may stay in a failed state in
            # which later operations raise until rollback() is called.
            self.session.rollback()
            raise
|