community[minor]: Revamp PGVector Filtering (#18992)

This PR makes the following updates to the PGVector vector store:

1. Use a JSONB field for metadata instead of JSON
2. Update the operator syntax to require a `$` prefix on operators
(otherwise operator names collide with field names); see the sketch after
this list
3. The change is non-breaking: the old behavior remains the default, but
it now emits a deprecation warning
4. The previous implementation has comparison bugs caused by casting
values to text (so lexical ordering is incorrectly applied to numeric
fields)
5. Add a GIN index on the JSONB field for more efficient querying
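
A minimal usage sketch of the new behavior (the embeddings, collection name,
and connection string below are placeholders for illustration, not part of
this PR):

    from langchain_community.embeddings import FakeEmbeddings
    from langchain_community.vectorstores.pgvector import PGVector

    # Hypothetical connection details for illustration only.
    CONNECTION_STRING = "postgresql+psycopg2://postgres:postgres@localhost:5432/postgres"

    store = PGVector.from_texts(
        texts=["id 1", "id 2"],
        metadatas=[{"id": 1, "name": "adam"}, {"id": 2, "name": "bob"}],
        embedding=FakeEmbeddings(size=1536),
        collection_name="filter_demo",
        connection_string=CONNECTION_STRING,
        use_jsonb=True,  # opt in to the new JSONB schema
    )

    # Operators now carry a `$` prefix so they cannot collide with field names.
    store.similarity_search("query", k=2, filter={"name": {"$in": ["adam", "bob"]}})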
Commit 6cdca4355d by Eugene Yurtsev, parent e276817e1d

langchain_community/vectorstores/pgvector.py
@@ -2,6 +2,7 @@ from __future__ import annotations

 import contextlib
 import enum
+import json
 import logging
 import uuid
 from typing import (
@@ -18,8 +19,9 @@ from typing import (
 import numpy as np
 import sqlalchemy
-from sqlalchemy import delete
-from sqlalchemy.dialects.postgresql import JSON, UUID
+from langchain_core._api import warn_deprecated
+from sqlalchemy import SQLColumnExpression, delete, func
+from sqlalchemy.dialects.postgresql import JSON, JSONB, UUID
 from sqlalchemy.orm import Session, relationship

 try:
@@ -61,8 +63,39 @@ class BaseModel(Base):

 _classes: Any = None

+COMPARISONS_TO_NATIVE = {
+    "$eq": "==",
+    "$ne": "!=",
+    "$lt": "<",
+    "$lte": "<=",
+    "$gt": ">",
+    "$gte": ">=",
+}
+
+SPECIAL_CASED_OPERATORS = {
+    "$in",
+    "$nin",
+    "$between",
+}
+
+TEXT_OPERATORS = {
+    "$like",
+    "$ilike",
+}
+
+LOGICAL_OPERATORS = {"$and", "$or"}
+
+SUPPORTED_OPERATORS = (
+    set(COMPARISONS_TO_NATIVE)
+    .union(TEXT_OPERATORS)
+    .union(LOGICAL_OPERATORS)
+    .union(SPECIAL_CASED_OPERATORS)
+)
+
-def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> Any:
+def _get_embedding_collection_store(
+    vector_dimension: Optional[int] = None, *, use_jsonb: bool = True
+) -> Any:
     global _classes
     if _classes is not None:
         return _classes
@@ -111,26 +144,60 @@ def _get_embedding_collection_store(
             created = True
             return collection, created

-    class EmbeddingStore(BaseModel):
-        """Embedding store."""
-
-        __tablename__ = "langchain_pg_embedding"
-
-        collection_id = sqlalchemy.Column(
-            UUID(as_uuid=True),
-            sqlalchemy.ForeignKey(
-                f"{CollectionStore.__tablename__}.uuid",
-                ondelete="CASCADE",
-            ),
-        )
-        collection = relationship(CollectionStore, back_populates="embeddings")
-
-        embedding: Vector = sqlalchemy.Column(Vector(vector_dimension))
-        document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
-        cmetadata = sqlalchemy.Column(JSON, nullable=True)
-
-        # custom_id : any user defined id
-        custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+    if use_jsonb:
+        # TODO(PRIOR TO LANDING): Create a gin index on the cmetadata field
+        class EmbeddingStore(BaseModel):
+            """Embedding store."""
+
+            __tablename__ = "langchain_pg_embedding"
+
+            collection_id = sqlalchemy.Column(
+                UUID(as_uuid=True),
+                sqlalchemy.ForeignKey(
+                    f"{CollectionStore.__tablename__}.uuid",
+                    ondelete="CASCADE",
+                ),
+            )
+            collection = relationship(CollectionStore, back_populates="embeddings")
+
+            embedding: Vector = sqlalchemy.Column(Vector(vector_dimension))
+            document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+            cmetadata = sqlalchemy.Column(JSONB, nullable=True)
+
+            # custom_id : any user defined id
+            custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+
+            __table_args__ = (
+                sqlalchemy.Index(
+                    "ix_cmetadata_gin",
+                    "cmetadata",
+                    postgresql_using="gin",
+                    postgresql_ops={"cmetadata": "jsonb_path_ops"},
+                ),
+            )
+    else:
+        # For backwards compatibility with older versions of pgvector
+        # This should be removed in the future (remove during migration)
+        class EmbeddingStore(BaseModel):  # type: ignore[no-redef]
+            """Embedding store."""
+
+            __tablename__ = "langchain_pg_embedding"
+
+            collection_id = sqlalchemy.Column(
+                UUID(as_uuid=True),
+                sqlalchemy.ForeignKey(
+                    f"{CollectionStore.__tablename__}.uuid",
+                    ondelete="CASCADE",
+                ),
+            )
+            collection = relationship(CollectionStore, back_populates="embeddings")
+
+            embedding: Vector = sqlalchemy.Column(Vector(vector_dimension))
+            document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+            cmetadata = sqlalchemy.Column(JSON, nullable=True)
+
+            # custom_id : any user defined id
+            custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)

     _classes = (EmbeddingStore, CollectionStore)
@@ -163,6 +230,11 @@ class PGVector(VectorStore):
         pre_delete_collection: If True, will delete the collection if it exists.
             (default: False). Useful for testing.
         engine_args: SQLAlchemy's create engine arguments.
+        use_jsonb: Use JSONB instead of JSON for metadata. (default: False)
+            Using JSON is strongly discouraged as it is not as efficient
+            for querying.
+            It is provided here only for backwards compatibility with older
+            versions, and will be removed in the future.

     Example:
         .. code-block:: python
@@ -178,9 +250,8 @@ class PGVector(VectorStore):
                 documents=docs,
                 collection_name=COLLECTION_NAME,
                 connection_string=CONNECTION_STRING,
+                use_jsonb=True,
             )
     """

     def __init__(
@@ -197,7 +268,9 @@ class PGVector(VectorStore):
         *,
         connection: Optional[sqlalchemy.engine.Connection] = None,
         engine_args: Optional[dict[str, Any]] = None,
+        use_jsonb: bool = False,
     ) -> None:
+        """Initialize the PGVector store."""
         self.connection_string = connection_string
         self.embedding_function = embedding_function
         self._embedding_length = embedding_length
@@ -209,6 +282,29 @@ class PGVector(VectorStore):
         self.override_relevance_score_fn = relevance_score_fn
         self.engine_args = engine_args or {}
         self._bind = connection if connection else self._create_engine()
+        self.use_jsonb = use_jsonb
+
+        if not use_jsonb:
+            # Replace with a deprecation warning.
+            warn_deprecated(
+                "0.0.29",
+                pending=True,
+                message=(
+                    "Please use JSONB instead of JSON for metadata. "
+                    "This change will allow for more efficient querying that "
+                    "involves filtering based on metadata. "
+                    "Please note that filtering operators have been changed "
+                    "when using JSONB metadata to be prefixed with a $ sign "
+                    "to avoid name collisions with columns. "
+                    "If you're using an existing database, you will need to create a "
+                    "db migration for your metadata column to be JSONB and update your "
+                    "queries to use the new operators. "
+                ),
+                alternative=(
+                    "Instantiate with use_jsonb=True to use JSONB instead "
+                    "of JSON for metadata."
+                ),
+            )
         self.__post_init__()

     def __post_init__(
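
For existing installations, the warning above implies a one-off column
migration. A minimal sketch, assuming the default table name
(langchain_pg_embedding) and plain SQL issued through SQLAlchemy; this is
not an official migration script:

    from sqlalchemy import create_engine, text

    engine = create_engine(CONNECTION_STRING)  # placeholder DSN as above
    with engine.begin() as conn:
        # Convert the metadata column from JSON to JSONB in place.
        conn.execute(text(
            "ALTER TABLE langchain_pg_embedding "
            "ALTER COLUMN cmetadata TYPE JSONB USING cmetadata::jsonb"
        ))
        # Add the GIN index that the new schema creates for fresh installs.
        conn.execute(text(
            "CREATE INDEX IF NOT EXISTS ix_cmetadata_gin "
            "ON langchain_pg_embedding USING gin (cmetadata jsonb_path_ops)"
        ))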
@@ -218,7 +314,7 @@ class PGVector(VectorStore):
         self.create_vector_extension()

         EmbeddingStore, CollectionStore = _get_embedding_collection_store(
-            self._embedding_length
+            self._embedding_length, use_jsonb=self.use_jsonb
         )
         self.CollectionStore = CollectionStore
         self.EmbeddingStore = EmbeddingStore
@@ -336,6 +432,8 @@ class PGVector(VectorStore):
         distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
         connection_string: Optional[str] = None,
         pre_delete_collection: bool = False,
+        *,
+        use_jsonb: bool = False,
         **kwargs: Any,
     ) -> PGVector:
         if ids is None:
@@ -352,6 +450,7 @@ class PGVector(VectorStore):
             embedding_function=embedding,
             distance_strategy=distance_strategy,
             pre_delete_collection=pre_delete_collection,
+            use_jsonb=use_jsonb,
             **kwargs,
         )
@@ -508,7 +607,117 @@ class PGVector(VectorStore):
         ]
         return docs

-    def _create_filter_clause(self, key, value):  # type: ignore[no-untyped-def]
+    def _handle_field_filter(
+        self,
+        field: str,
+        value: Any,
+    ) -> SQLColumnExpression:
+        """Create a filter for a specific field.
+
+        Args:
+            field: name of field
+            value: value to filter
+                If provided as is then this will be an equality filter
+                If provided as a dictionary then this will be a filter, the key
+                will be the operator and the value will be the value to filter by
+
+        Returns:
+            sqlalchemy expression
+        """
+        if not isinstance(field, str):
+            raise ValueError(
+                f"field should be a string but got: {type(field)} with value: {field}"
+            )
+
+        if field.startswith("$"):
+            raise ValueError(
+                f"Invalid filter condition. Expected a field but got an operator: "
+                f"{field}"
+            )
+
+        # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters
+        if not field.isidentifier():
+            raise ValueError(
+                f"Invalid field name: {field}. Expected a valid identifier."
+            )
+
+        if isinstance(value, dict):
+            # This is a filter specification
+            if len(value) != 1:
+                raise ValueError(
+                    "Invalid filter condition. Expected a value which "
+                    "is a dictionary with a single key that corresponds to an operator "
+                    f"but got a dictionary with {len(value)} keys. The first few "
+                    f"keys are: {list(value.keys())[:3]}"
+                )
+            operator, filter_value = list(value.items())[0]
+            # Verify that the operator is supported
+            if operator not in SUPPORTED_OPERATORS:
+                raise ValueError(
+                    f"Invalid operator: {operator}. "
+                    f"Expected one of {SUPPORTED_OPERATORS}"
+                )
+        else:  # Then we assume an equality operator
+            operator = "$eq"
+            filter_value = value
+
+        if operator in COMPARISONS_TO_NATIVE:
+            # Then we implement a comparison filter via a native operator
+            # native is trusted input
+            native = COMPARISONS_TO_NATIVE[operator]
+            return func.jsonb_path_match(
+                self.EmbeddingStore.cmetadata,
+                f"$.{field} {native} $value",
+                json.dumps({"value": filter_value}),
+            )
+        elif operator == "$between":
+            # Use AND with two comparisons
+            low, high = filter_value
+            lower_bound = func.jsonb_path_match(
+                self.EmbeddingStore.cmetadata,
+                f"$.{field} >= $value",
+                json.dumps({"value": low}),
+            )
+            upper_bound = func.jsonb_path_match(
+                self.EmbeddingStore.cmetadata,
+                f"$.{field} <= $value",
+                json.dumps({"value": high}),
+            )
+            return sqlalchemy.and_(lower_bound, upper_bound)
+        elif operator in {"$in", "$nin", "$like", "$ilike"}:
+            # We'll do force coercion to text
+            if operator in {"$in", "$nin"}:
+                for val in filter_value:
+                    if not isinstance(val, (str, int, float)):
+                        raise NotImplementedError(
+                            f"Unsupported type: {type(val)} for value: {val}"
+                        )
+            queried_field = self.EmbeddingStore.cmetadata[field].astext
+
+            if operator in {"$in"}:
+                return queried_field.in_([str(val) for val in filter_value])
+            elif operator in {"$nin"}:
+                return queried_field.not_in([str(val) for val in filter_value])
+            elif operator in {"$like"}:
+                return queried_field.like(filter_value)
+            elif operator in {"$ilike"}:
+                return queried_field.ilike(filter_value)
+            else:
+                raise NotImplementedError()
+        else:
+            raise NotImplementedError()
+
+    def _create_filter_clause_deprecated(self, key, value):  # type: ignore[no-untyped-def]
+        """Deprecated functionality.
+
+        This is for backwards compatibility with the JSON based schema for metadata.
+        It uses incorrect operator syntax (operators are not prefixed with $).
+
+        This implementation is not efficient, and has bugs associated with
+        the way that it handles numeric filter clauses.
+        """
         IN, NIN, BETWEEN, GT, LT, NE = "in", "nin", "between", "gt", "lt", "ne"
         EQ, LIKE, CONTAINS, OR, AND = "eq", "like", "contains", "or", "and"
@@ -568,6 +777,117 @@ class PGVector(VectorStore):

         return filter_by_metadata

+    def _create_filter_clause_json_deprecated(
+        self, filter: Any
+    ) -> List[SQLColumnExpression]:
+        """Convert filters from IR to SQL clauses.
+
+        **DEPRECATED** This functionality will be removed in the future.
+
+        It implements translation of filters for a schema that uses JSON
+        for metadata rather than the JSONB field which is more efficient
+        for querying.
+        """
+        filter_clauses = []
+        for key, value in filter.items():
+            if isinstance(value, dict):
+                filter_by_metadata = self._create_filter_clause_deprecated(key, value)
+
+                if filter_by_metadata is not None:
+                    filter_clauses.append(filter_by_metadata)
+            else:
+                filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str(
+                    value
+                )
+                filter_clauses.append(filter_by_metadata)
+        return filter_clauses
+
+    def _create_filter_clause(self, filters: Any) -> Any:
+        """Convert LangChain IR filter representation to matching SQLAlchemy clauses.
+
+        At the top level, we still don't know if we're working with a field
+        or an operator for the keys. After we've determined that, we can
+        call the appropriate logic to handle filter creation.
+
+        Args:
+            filters: Dictionary of filters to apply to the query.
+
+        Returns:
+            SQLAlchemy clause to apply to the query.
+        """
+        if isinstance(filters, dict):
+            if len(filters) == 1:
+                # The only operators allowed at the top level are $and and $or
+                # First check if an operator or a field
+                key, value = list(filters.items())[0]
+                if key.startswith("$"):
+                    # Then it's an operator
+                    if key.lower() not in ["$and", "$or"]:
+                        raise ValueError(
+                            f"Invalid filter condition. Expected $and or $or "
+                            f"but got: {key}"
+                        )
+                else:
+                    # Then it's a field
+                    return self._handle_field_filter(key, filters[key])
+
+                # Here we handle the $and and $or operators
+                if not isinstance(value, list):
+                    raise ValueError(
+                        f"Expected a list, but got {type(value)} for value: {value}"
+                    )
+                if key.lower() == "$and":
+                    and_ = [self._create_filter_clause(el) for el in value]
+                    if len(and_) > 1:
+                        return sqlalchemy.and_(*and_)
+                    elif len(and_) == 1:
+                        return and_[0]
+                    else:
+                        raise ValueError(
+                            "Invalid filter condition. Expected a list with "
+                            "at least one filter, but got an empty list"
+                        )
+                elif key.lower() == "$or":
+                    or_ = [self._create_filter_clause(el) for el in value]
+                    if len(or_) > 1:
+                        return sqlalchemy.or_(*or_)
+                    elif len(or_) == 1:
+                        return or_[0]
+                    else:
+                        raise ValueError(
+                            "Invalid filter condition. Expected a list with "
+                            "at least one filter, but got an empty list"
+                        )
+                else:
+                    raise ValueError(
+                        f"Invalid filter condition. Expected $and or $or "
+                        f"but got: {key}"
+                    )
+            elif len(filters) > 1:
+                # Then all keys have to be fields (they cannot be operators)
+                for key in filters.keys():
+                    if key.startswith("$"):
+                        raise ValueError(
+                            f"Invalid filter condition. Expected a field but got: {key}"
+                        )
+                # These should all be fields and combined using an $and operator
+                and_ = [self._handle_field_filter(k, v) for k, v in filters.items()]
+                if len(and_) > 1:
+                    return sqlalchemy.and_(*and_)
+                elif len(and_) == 1:
+                    return and_[0]
+                else:
+                    raise ValueError(
+                        "Invalid filter condition. Expected a dictionary "
+                        "but got an empty dictionary"
+                    )
+            else:
+                raise ValueError("Got an empty dictionary for filters.")
+        else:
+            raise ValueError(
+                f"Invalid type: Expected a dictionary but got type: {type(filters)}"
+            )
+
     def __query_collection(
         self,
         embedding: List[float],
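
Since _create_filter_clause recurses into the values of $and/$or, logical
operators can nest. A small sketch, reusing the hypothetical store from the
example above:

    # Top-level sibling keys are implicitly AND-ed; $and/$or take lists of
    # sub-filters and may be nested arbitrarily.
    store.similarity_search(
        "query",
        k=4,
        filter={
            "$or": [
                {"name": "adam"},
                {"$and": [{"id": {"$gte": 2}}, {"name": {"$like": "b%"}}]},
            ]
        },
    )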
@@ -580,24 +900,16 @@ class PGVector(VectorStore):
         if not collection:
             raise ValueError("Collection not found")

-        filter_by = self.EmbeddingStore.collection_id == collection.uuid
-
-        if filter is not None:
-            filter_clauses = []
-
-            for key, value in filter.items():
-                if isinstance(value, dict):
-                    filter_by_metadata = self._create_filter_clause(key, value)
-
-                    if filter_by_metadata is not None:
-                        filter_clauses.append(filter_by_metadata)
-                else:
-                    filter_by_metadata = self.EmbeddingStore.cmetadata[
-                        key
-                    ].astext == str(value)
-                    filter_clauses.append(filter_by_metadata)
-
-            filter_by = sqlalchemy.and_(filter_by, *filter_clauses)
+        filter_by = [self.EmbeddingStore.collection_id == collection.uuid]
+        if filter:
+            if self.use_jsonb:
+                filter_clauses = self._create_filter_clause(filter)
+                if filter_clauses is not None:
+                    filter_by.append(filter_clauses)
+            else:
+                # Old way of doing things
+                filter_clauses = self._create_filter_clause_json_deprecated(filter)
+                filter_by.extend(filter_clauses)

         _type = self.EmbeddingStore
@@ -606,7 +918,7 @@ class PGVector(VectorStore):
                 self.EmbeddingStore,
                 self.distance_strategy(embedding).label("distance"),  # type: ignore
             )
-            .filter(filter_by)
+            .filter(*filter_by)
             .order_by(sqlalchemy.asc("distance"))
             .join(
                 self.CollectionStore,
@@ -615,6 +927,7 @@ class PGVector(VectorStore):
             .limit(k)
             .all()
         )
+
         return results

     def similarity_search_by_vector(
@@ -649,6 +962,8 @@ class PGVector(VectorStore):
         distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
         ids: Optional[List[str]] = None,
         pre_delete_collection: bool = False,
+        *,
+        use_jsonb: bool = False,
         **kwargs: Any,
     ) -> PGVector:
         """
@@ -668,6 +983,7 @@ class PGVector(VectorStore):
             collection_name=collection_name,
             distance_strategy=distance_strategy,
             pre_delete_collection=pre_delete_collection,
+            use_jsonb=use_jsonb,
             **kwargs,
         )
@@ -769,6 +1085,8 @@ class PGVector(VectorStore):
         distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
         ids: Optional[List[str]] = None,
         pre_delete_collection: bool = False,
+        *,
+        use_jsonb: bool = False,
         **kwargs: Any,
     ) -> PGVector:
         """
@@ -792,6 +1110,7 @@ class PGVector(VectorStore):
             metadatas=metadatas,
             ids=ids,
             collection_name=collection_name,
+            use_jsonb=use_jsonb,
             **kwargs,
         )
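
To see the SQL a given filter translates to, the generated clause can be
compiled directly, exactly as the tests at the end of this PR do. A sketch,
again reusing the hypothetical store (note that _create_filter_clause is a
private method, so this is for inspection only):

    from sqlalchemy.dialects import postgresql

    clause = store._create_filter_clause({"id": {"$between": (1, 2)}})
    print(
        clause.compile(
            dialect=postgresql.dialect(),
            compile_kwargs={"literal_binds": True},  # inline parameter values
        )
    )
    # jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id >= $value',
    #   '{"value": 1}') AND jsonb_path_match(langchain_pg_embedding.cmetadata,
    #   '$.id <= $value', '{"value": 2}')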

tests/integration_tests/vectorstores/fixtures/filtering_test_cases.py (new file)
@@ -0,0 +1,222 @@
"""Module contains test cases for testing filtering of documents in vector stores.
"""
from langchain_core.documents import Document
metadatas = [
{
"name": "adam",
"date": "2021-01-01",
"count": 1,
"is_active": True,
"tags": ["a", "b"],
"location": [1.0, 2.0],
"info": {"address": "123 main st", "phone": "123-456-7890"},
"id": 1,
"height": 10.0, # Float column
"happiness": 0.9, # Float column
"sadness": 0.1, # Float column
},
{
"name": "bob",
"date": "2021-01-02",
"count": 2,
"is_active": False,
"tags": ["b", "c"],
"location": [2.0, 3.0],
"info": {"address": "456 main st", "phone": "123-456-7890"},
"id": 2,
"height": 5.7, # Float column
"happiness": 0.8, # Float column
"sadness": 0.1, # Float column
},
{
"name": "jane",
"date": "2021-01-01",
"count": 3,
"is_active": True,
"tags": ["b", "d"],
"location": [3.0, 4.0],
"info": {"address": "789 main st", "phone": "123-456-7890"},
"id": 3,
"height": 2.4, # Float column
"happiness": None,
# Sadness missing intentionally
},
]
texts = ["id {id}".format(id=metadata["id"]) for metadata in metadatas]
DOCUMENTS = [
Document(page_content=text, metadata=metadata)
for text, metadata in zip(texts, metadatas)
]
TYPE_1_FILTERING_TEST_CASES = [
    # These tests only involve equality checks
    (
        {"id": 1},
        [1],
    ),
    # String field
    (
        # check name
        {"name": "adam"},
        [1],
    ),
    # Boolean fields
    (
        {"is_active": True},
        [1, 3],
    ),
    (
        {"is_active": False},
        [2],
    ),
    # And semantics for top level filtering
    (
        {"id": 1, "is_active": True},
        [1],
    ),
    (
        {"id": 1, "is_active": False},
        [],
    ),
]
TYPE_2_FILTERING_TEST_CASES = [
    # These involve equality checks and other operators
    # like $ne, $gt, $gte, $lt, $lte
    (
        {"id": 1},
        [1],
    ),
    (
        {"id": {"$ne": 1}},
        [2, 3],
    ),
    (
        {"id": {"$gt": 1}},
        [2, 3],
    ),
    (
        {"id": {"$gte": 1}},
        [1, 2, 3],
    ),
    (
        {"id": {"$lt": 1}},
        [],
    ),
    (
        {"id": {"$lte": 1}},
        [1],
    ),
    # Repeat all the same tests with name (string column)
    (
        {"name": "adam"},
        [1],
    ),
    (
        {"name": "bob"},
        [2],
    ),
    (
        {"name": {"$eq": "adam"}},
        [1],
    ),
    (
        {"name": {"$ne": "adam"}},
        [2, 3],
    ),
    # And also gt, gte, lt, lte relying on lexicographical ordering
    (
        {"name": {"$gt": "jane"}},
        [],
    ),
    (
        {"name": {"$gte": "jane"}},
        [3],
    ),
    (
        {"name": {"$lt": "jane"}},
        [1, 2],
    ),
    (
        {"name": {"$lte": "jane"}},
        [1, 2, 3],
    ),
    (
        {"is_active": {"$eq": True}},
        [1, 3],
    ),
    (
        {"is_active": {"$ne": True}},
        [2],
    ),
    # Test float column.
    (
        {"height": {"$gt": 5.0}},
        [1, 2],
    ),
    (
        {"height": {"$gte": 5.0}},
        [1, 2],
    ),
    (
        {"height": {"$lt": 5.0}},
        [3],
    ),
    (
        {"height": {"$lte": 5.8}},
        [2, 3],
    ),
]
TYPE_3_FILTERING_TEST_CASES = [
    # These involve usage of AND and OR operators
    (
        {"$or": [{"id": 1}, {"id": 2}]},
        [1, 2],
    ),
    (
        {"$or": [{"id": 1}, {"name": "bob"}]},
        [1, 2],
    ),
    (
        {"$and": [{"id": 1}, {"id": 2}]},
        [],
    ),
    (
        {"$or": [{"id": 1}, {"id": 2}, {"id": 3}]},
        [1, 2, 3],
    ),
]
TYPE_4_FILTERING_TEST_CASES = [
    # These involve special operators like $in, $nin, $between
    # Test between
    (
        {"id": {"$between": (1, 2)}},
        [1, 2],
    ),
    (
        {"id": {"$between": (1, 1)}},
        [1],
    ),
    (
        {"name": {"$in": ["adam", "bob"]}},
        [1, 2],
    ),
]
TYPE_5_FILTERING_TEST_CASES = [
    # These involve special operators like $like, $ilike that
    # may be specific to certain databases.
    (
        {"name": {"$like": "a%"}},
        [1],
    ),
    (
        {"name": {"$like": "%a%"}},  # adam and jane
        [1, 3],
    ),
]

tests/integration_tests/vectorstores/test_pgvector.py
@@ -1,13 +1,26 @@
 """Test PGVector functionality."""
 import os
-from typing import List
+from typing import Any, Dict, Generator, List, Type, Union

+import pytest
 import sqlalchemy
 from langchain_core.documents import Document
+from sqlalchemy.dialects import postgresql
 from sqlalchemy.orm import Session

-from langchain_community.vectorstores.pgvector import PGVector
+from langchain_community.vectorstores.pgvector import (
+    SUPPORTED_OPERATORS,
+    PGVector,
+)
 from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
+from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import (
+    DOCUMENTS,
+    TYPE_1_FILTERING_TEST_CASES,
+    TYPE_2_FILTERING_TEST_CASES,
+    TYPE_3_FILTERING_TEST_CASES,
+    TYPE_4_FILTERING_TEST_CASES,
+    TYPE_5_FILTERING_TEST_CASES,
+)

 # The connection string matches the default settings in the docker-compose file
 # located in the root of the repository: [root]/docker/docker-compose.yml
@@ -42,7 +55,7 @@ class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
         return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)]


-def test_pgvector() -> None:
+def test_pgvector(pgvector: PGVector) -> None:
     """Test end to end construction and search."""
     texts = ["foo", "bar", "baz"]
     docsearch = PGVector.from_texts(
@@ -375,3 +388,255 @@ def test_pgvector_with_custom_engine_args() -> None:
     )
     output = docsearch.similarity_search("foo", k=1)
     assert output == [Document(page_content="foo")]
# We should reuse this test-case across other integrations
# Add database fixture using pytest
@pytest.fixture
def pgvector() -> Generator[PGVector, None, None]:
    """Create a PGVector instance."""
    store = PGVector.from_documents(
        documents=DOCUMENTS,
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
        relevance_score_fn=lambda d: d * 0,
        use_jsonb=True,
    )
    try:
        yield store
    # Do clean up
    finally:
        store.drop_tables()

@pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES)
def test_pgvector_with_with_metadata_filters_1(
pgvector: PGVector,
test_filter: Dict[str, Any],
expected_ids: List[int],
) -> None:
"""Test end to end construction and search."""
docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
@pytest.mark.parametrize("test_filter, expected_ids", TYPE_2_FILTERING_TEST_CASES)
def test_pgvector_with_with_metadata_filters_2(
pgvector: PGVector,
test_filter: Dict[str, Any],
expected_ids: List[int],
) -> None:
"""Test end to end construction and search."""
docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
@pytest.mark.parametrize("test_filter, expected_ids", TYPE_3_FILTERING_TEST_CASES)
def test_pgvector_with_with_metadata_filters_3(
pgvector: PGVector,
test_filter: Dict[str, Any],
expected_ids: List[int],
) -> None:
"""Test end to end construction and search."""
docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
@pytest.mark.parametrize("test_filter, expected_ids", TYPE_4_FILTERING_TEST_CASES)
def test_pgvector_with_with_metadata_filters_4(
pgvector: PGVector,
test_filter: Dict[str, Any],
expected_ids: List[int],
) -> None:
"""Test end to end construction and search."""
docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
@pytest.mark.parametrize("test_filter, expected_ids", TYPE_5_FILTERING_TEST_CASES)
def test_pgvector_with_with_metadata_filters_5(
pgvector: PGVector,
test_filter: Dict[str, Any],
expected_ids: List[int],
) -> None:
"""Test end to end construction and search."""
docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
@pytest.mark.parametrize(
    "invalid_filter",
    [
        ["hello"],
        {
            "id": 2,
            "$name": "foo",
        },
        {"$or": {}},
        {"$and": {}},
        {"$between": {}},
        {"$eq": {}},
    ],
)
def test_invalid_filters(pgvector: PGVector, invalid_filter: Any) -> None:
    """Verify that invalid filters raise an error."""
    with pytest.raises(ValueError):
        pgvector._create_filter_clause(invalid_filter)

@pytest.mark.parametrize(
    "filter,compiled",
    [
        ({"id 'evil code'": 2}, ValueError),
        (
            {"id": "'evil code' == 2"},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, "
                "'$.id == $value', "
                "'{\"value\": \"''evil code'' == 2\"}')"
            ),
        ),
        (
            {"name": 'a"b'},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, "
                "'$.name == $value', "
                '\'{"value": "a\\\\"b"}\')'
            ),
        ),
    ],
)
def test_evil_code(
    pgvector: PGVector, filter: Any, compiled: Union[Type[Exception], str]
) -> None:
    """Test that malicious filter input is rejected or safely escaped."""
    if isinstance(compiled, str):
        clause = pgvector._create_filter_clause(filter)
        compiled_stmt = str(
            clause.compile(
                dialect=postgresql.dialect(),
                compile_kwargs={
                    # This substitutes the parameters with their actual values
                    "literal_binds": True
                },
            )
        )
        assert compiled_stmt == compiled
    else:
        with pytest.raises(compiled):
            pgvector._create_filter_clause(filter)

@pytest.mark.parametrize(
    "filter,compiled",
    [
        (
            {"id": 2},
            "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
            "'{\"value\": 2}')",
        ),
        (
            {"id": {"$eq": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
                "'{\"value\": 2}')"
            ),
        ),
        (
            {"name": "foo"},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, "
                "'$.name == $value', "
                '\'{"value": "foo"}\')'
            ),
        ),
        (
            {"id": {"$ne": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id != $value', "
                "'{\"value\": 2}')"
            ),
        ),
        (
            {"id": {"$gt": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id > $value', "
                "'{\"value\": 2}')"
            ),
        ),
        (
            {"id": {"$gte": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id >= $value', "
                "'{\"value\": 2}')"
            ),
        ),
        (
            {"id": {"$lt": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id < $value', "
                "'{\"value\": 2}')"
            ),
        ),
        (
            {"id": {"$lte": 2}},
            (
                "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id <= $value', "
                "'{\"value\": 2}')"
            ),
        ),
        (
            {"name": {"$ilike": "foo"}},
            "langchain_pg_embedding.cmetadata ->> 'name' ILIKE 'foo'",
        ),
        (
            {"name": {"$like": "foo"}},
            "langchain_pg_embedding.cmetadata ->> 'name' LIKE 'foo'",
        ),
        (
            {"$or": [{"id": 1}, {"id": 2}]},
            # Please note that this might not be super optimized
            # Another way to phrase the query is as
            # langchain_pg_embedding.cmetadata @@ '($.id == 1 || $.id == 2)'
            "jsonb_path_match(langchain_pg_embedding.cmetadata, '$.id == $value', "
            "'{\"value\": 1}') OR jsonb_path_match(langchain_pg_embedding.cmetadata, "
            "'$.id == $value', '{\"value\": 2}')",
        ),
    ],
)
def test_pgvector_query_compilation(
    pgvector: PGVector, filter: Any, compiled: str
) -> None:
    """Test translation from IR to SQL."""
    clause = pgvector._create_filter_clause(filter)
    compiled_stmt = str(
        clause.compile(
            dialect=postgresql.dialect(),
            compile_kwargs={
                # This substitutes the parameters with their actual values
                "literal_binds": True
            },
        )
    )
    assert compiled_stmt == compiled

def test_validate_operators() -> None:
    """Verify that all operators have been categorized."""
    assert sorted(SUPPORTED_OPERATORS) == [
        "$and",
        "$between",
        "$eq",
        "$gt",
        "$gte",
        "$ilike",
        "$in",
        "$like",
        "$lt",
        "$lte",
        "$ne",
        "$nin",
        "$or",
    ]
