langchain[minor]: Add PebbloRetrievalQA chain with Identity & Semantic Enforcement support (#20641)

- **Description:** PebbloRetrievalQA chain introduces identity
enforcement using vector-db metadata filtering
- **Dependencies:** None
- **Issue:** None
- **Documentation:** Adding documentation for PebbloRetrievalQA chain in
a separate PR(https://github.com/langchain-ai/langchain/pull/20746)
- **Unit tests:** New unit-tests added

---------

Co-authored-by: Eugene Yurtsev <eugene@langchain.dev>
pull/21709/head
Rajendra Kadam 3 weeks ago committed by GitHub
parent f2f970f93d
commit 54e003268e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,24 @@
"""
Chains module for langchain_community
This module contains the community chains.
"""
import importlib
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from langchain_community.chains.pebblo_retrieval.base import PebbloRetrievalQA
__all__ = ["PebbloRetrievalQA"]
_module_lookup = {
"PebbloRetrievalQA": "langchain_community.chains.pebblo_retrieval.base"
}
def __getattr__(name: str) -> Any:
if name in _module_lookup:
module = importlib.import_module(_module_lookup[name])
return getattr(module, name)
raise AttributeError(f"module {__name__} has no attribute {name}")

@ -0,0 +1,218 @@
"""
Pebblo Retrieval Chain with Identity & Semantic Enforcement for question-answering
against a vector database.
"""
import inspect
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Extra, Field, validator
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
SUPPORTED_VECTORSTORES,
set_enforcement_filters,
)
from langchain_community.chains.pebblo_retrieval.models import (
AuthContext,
SemanticContext,
)
class PebbloRetrievalQA(Chain):
"""
Retrieval Chain with Identity & Semantic Enforcement for question-answering
against a vector database.
"""
combine_documents_chain: BaseCombineDocumentsChain
"""Chain to use to combine the documents."""
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
return_source_documents: bool = False
"""Return the source documents or not."""
retriever: VectorStoreRetriever = Field(exclude=True)
"""VectorStore to use for retrieval."""
auth_context_key: str = "auth_context" #: :meta private:
"""Authentication context for identity enforcement."""
semantic_context_key: str = "semantic_context" #: :meta private:
"""Semantic context for semantic enforcement."""
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns
the retrieved documents as well under the key 'source_documents'.
Example:
.. code-block:: python
res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key)
semantic_context = inputs.get(self.semantic_context_key)
accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
)
if accepts_run_manager:
docs = self._get_docs(
question, auth_context, semantic_context, run_manager=_run_manager
)
else:
docs = self._get_docs(question, auth_context, semantic_context) # type: ignore[call-arg]
answer = self.combine_documents_chain.run(
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns
the retrieved documents as well under the key 'source_documents'.
Example:
.. code-block:: python
res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key)
semantic_context = inputs.get(self.semantic_context_key)
accepts_run_manager = (
"run_manager" in inspect.signature(self._aget_docs).parameters
)
if accepts_run_manager:
docs = await self._aget_docs(
question, auth_context, semantic_context, run_manager=_run_manager
)
else:
docs = await self._aget_docs(question, auth_context, semantic_context) # type: ignore[call-arg]
answer = await self.combine_documents_chain.arun(
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
allow_population_by_field_name = True
@property
def input_keys(self) -> List[str]:
"""Input keys.
:meta private:
"""
return [self.input_key, self.auth_context_key, self.semantic_context_key]
@property
def output_keys(self) -> List[str]:
"""Output keys.
:meta private:
"""
_output_keys = [self.output_key]
if self.return_source_documents:
_output_keys += ["source_documents"]
return _output_keys
@property
def _chain_type(self) -> str:
"""Return the chain type."""
return "pebblo_retrieval_qa"
@classmethod
def from_chain_type(
cls,
llm: BaseLanguageModel,
chain_type: str = "stuff",
chain_type_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> "PebbloRetrievalQA":
"""Load chain from chain type."""
from langchain.chains.question_answering import load_qa_chain
_chain_type_kwargs = chain_type_kwargs or {}
combine_documents_chain = load_qa_chain(
llm, chain_type=chain_type, **_chain_type_kwargs
)
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
@validator("retriever", pre=True, always=True)
def validate_vectorstore(
cls, retriever: VectorStoreRetriever
) -> VectorStoreRetriever:
"""
Validate that the vectorstore of the retriever is supported vectorstores.
"""
if not any(
isinstance(retriever.vectorstore, supported_class)
for supported_class in SUPPORTED_VECTORSTORES
):
raise ValueError(
f"Vectorstore must be an instance of one of the supported "
f"vectorstores: {SUPPORTED_VECTORSTORES}. "
f"Got {type(retriever.vectorstore).__name__} instead."
)
return retriever
def _get_docs(
self,
question: str,
auth_context: Optional[AuthContext],
semantic_context: Optional[SemanticContext],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
set_enforcement_filters(self.retriever, auth_context, semantic_context)
return self.retriever.get_relevant_documents(
question, callbacks=run_manager.get_child()
)
async def _aget_docs(
self,
question: str,
auth_context: Optional[AuthContext],
semantic_context: Optional[SemanticContext],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
set_enforcement_filters(self.retriever, auth_context, semantic_context)
return await self.retriever.aget_relevant_documents(
question, callbacks=run_manager.get_child()
)

@ -0,0 +1,265 @@
"""
Identity & Semantic Enforcement filters for PebbloRetrievalQA chain:
This module contains methods for applying Identity and Semantic Enforcement filters
in the PebbloRetrievalQA chain.
These filters are used to control the retrieval of documents based on authorization and
semantic context.
The Identity Enforcement filter ensures that only authorized identities can access
certain documents, while the Semantic Enforcement filter controls document retrieval
based on semantic context.
The methods in this module are designed to work with different types of vector stores.
"""
import logging
from typing import List, Optional, Union
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_community.chains.pebblo_retrieval.models import (
AuthContext,
SemanticContext,
)
from langchain_community.vectorstores import Pinecone, Qdrant
logger = logging.getLogger(__name__)
SUPPORTED_VECTORSTORES = [Pinecone, Qdrant]
def set_enforcement_filters(
retriever: VectorStoreRetriever,
auth_context: Optional[AuthContext],
semantic_context: Optional[SemanticContext],
) -> None:
"""
Set identity and semantic enforcement filters in the retriever.
"""
if auth_context is not None:
_set_identity_enforcement_filter(retriever, auth_context)
if semantic_context is not None:
_set_semantic_enforcement_filter(retriever, semantic_context)
def _apply_qdrant_semantic_filter(
search_kwargs: dict, semantic_context: Optional[SemanticContext]
) -> None:
"""
Set semantic enforcement filter in search_kwargs for Qdrant vectorstore.
"""
try:
from qdrant_client.http import models as rest
except ImportError as e:
raise ValueError(
"Could not import `qdrant-client.http` python package. "
"Please install it with `pip install qdrant-client`."
) from e
# Create a semantic enforcement filter condition
semantic_filters: List[
Union[
rest.FieldCondition,
rest.IsEmptyCondition,
rest.IsNullCondition,
rest.HasIdCondition,
rest.NestedCondition,
rest.Filter,
]
] = []
if (
semantic_context is not None
and semantic_context.pebblo_semantic_topics is not None
):
semantic_topics_filter = rest.FieldCondition(
key="metadata.pebblo_semantic_topics",
match=rest.MatchAny(any=semantic_context.pebblo_semantic_topics.deny),
)
semantic_filters.append(semantic_topics_filter)
if (
semantic_context is not None
and semantic_context.pebblo_semantic_entities is not None
):
semantic_entities_filter = rest.FieldCondition(
key="metadata.pebblo_semantic_entities",
match=rest.MatchAny(any=semantic_context.pebblo_semantic_entities.deny),
)
semantic_filters.append(semantic_entities_filter)
# If 'filter' already exists in search_kwargs
if "filter" in search_kwargs:
existing_filter: rest.Filter = search_kwargs["filter"]
# Check if existing_filter is a qdrant-client filter
if isinstance(existing_filter, rest.Filter):
# If 'must_not' condition exists in the existing filter
if isinstance(existing_filter.must_not, list):
# Warn if 'pebblo_semantic_topics' or 'pebblo_semantic_entities'
# filter is overridden
new_must_not_conditions: List[
Union[
rest.FieldCondition,
rest.IsEmptyCondition,
rest.IsNullCondition,
rest.HasIdCondition,
rest.NestedCondition,
rest.Filter,
]
] = []
# Drop semantic filter conditions if already present
for condition in existing_filter.must_not:
if hasattr(condition, "key"):
if condition.key == "metadata.pebblo_semantic_topics":
continue
if condition.key == "metadata.pebblo_semantic_entities":
continue
new_must_not_conditions.append(condition)
# Add semantic enforcement filters to 'must_not' conditions
existing_filter.must_not = new_must_not_conditions
existing_filter.must_not.extend(semantic_filters)
else:
# Set 'must_not' condition with semantic enforcement filters
existing_filter.must_not = semantic_filters
else:
raise TypeError(
"Using dict as a `filter` is deprecated. "
"Please use qdrant-client filters directly: "
"https://qdrant.tech/documentation/concepts/filtering/"
)
else:
# If 'filter' does not exist in search_kwargs, create it
search_kwargs["filter"] = rest.Filter(must_not=semantic_filters)
def _apply_qdrant_authorization_filter(
search_kwargs: dict, auth_context: Optional[AuthContext]
) -> None:
"""
Set identity enforcement filter in search_kwargs for Qdrant vectorstore.
"""
try:
from qdrant_client.http import models as rest
except ImportError as e:
raise ValueError(
"Could not import `qdrant-client.http` python package. "
"Please install it with `pip install qdrant-client`."
) from e
if auth_context is not None:
# Create a identity enforcement filter condition
identity_enforcement_filter = rest.FieldCondition(
key="metadata.authorized_identities",
match=rest.MatchAny(any=auth_context.user_auth),
)
else:
return
# If 'filter' already exists in search_kwargs
if "filter" in search_kwargs:
existing_filter: rest.Filter = search_kwargs["filter"]
# Check if existing_filter is a qdrant-client filter
if isinstance(existing_filter, rest.Filter):
# If 'must' exists in the existing filter
if existing_filter.must:
new_must_conditions: List[
Union[
rest.FieldCondition,
rest.IsEmptyCondition,
rest.IsNullCondition,
rest.HasIdCondition,
rest.NestedCondition,
rest.Filter,
]
] = []
# Drop 'authorized_identities' filter condition if already present
for condition in existing_filter.must:
if (
hasattr(condition, "key")
and condition.key == "metadata.authorized_identities"
):
continue
new_must_conditions.append(condition)
# Add identity enforcement filter to 'must' conditions
existing_filter.must = new_must_conditions
existing_filter.must.append(identity_enforcement_filter)
else:
# Set 'must' condition with identity enforcement filter
existing_filter.must = [identity_enforcement_filter]
else:
raise TypeError(
"Using dict as a `filter` is deprecated. "
"Please use qdrant-client filters directly: "
"https://qdrant.tech/documentation/concepts/filtering/"
)
else:
# If 'filter' does not exist in search_kwargs, create it
search_kwargs["filter"] = rest.Filter(must=[identity_enforcement_filter])
def _apply_pinecone_semantic_filter(
search_kwargs: dict, semantic_context: Optional[SemanticContext]
) -> None:
"""
Set semantic enforcement filter in search_kwargs for Pinecone vectorstore.
"""
# Check if semantic_context is provided
semantic_context = semantic_context
if semantic_context is not None:
if semantic_context.pebblo_semantic_topics is not None:
# Add pebblo_semantic_topics filter to search_kwargs
search_kwargs.setdefault("filter", {})["pebblo_semantic_topics"] = {
"$nin": semantic_context.pebblo_semantic_topics.deny
}
if semantic_context.pebblo_semantic_entities is not None:
# Add pebblo_semantic_entities filter to search_kwargs
search_kwargs.setdefault("filter", {})["pebblo_semantic_entities"] = {
"$nin": semantic_context.pebblo_semantic_entities.deny
}
def _apply_pinecone_authorization_filter(
search_kwargs: dict, auth_context: Optional[AuthContext]
) -> None:
"""
Set identity enforcement filter in search_kwargs for Pinecone vectorstore.
"""
if auth_context is not None:
search_kwargs.setdefault("filter", {})["authorized_identities"] = {
"$in": auth_context.user_auth
}
def _set_identity_enforcement_filter(
retriever: VectorStoreRetriever, auth_context: Optional[AuthContext]
) -> None:
"""
Set identity enforcement filter in search_kwargs.
This method sets the identity enforcement filter in the search_kwargs
of the retriever based on the type of the vectorstore.
"""
search_kwargs = retriever.search_kwargs
if isinstance(retriever.vectorstore, Pinecone):
_apply_pinecone_authorization_filter(search_kwargs, auth_context)
elif isinstance(retriever.vectorstore, Qdrant):
_apply_qdrant_authorization_filter(search_kwargs, auth_context)
def _set_semantic_enforcement_filter(
retriever: VectorStoreRetriever, semantic_context: Optional[SemanticContext]
) -> None:
"""
Set semantic enforcement filter in search_kwargs.
This method sets the semantic enforcement filter in the search_kwargs
of the retriever based on the type of the vectorstore.
"""
search_kwargs = retriever.search_kwargs
if isinstance(retriever.vectorstore, Pinecone):
_apply_pinecone_semantic_filter(search_kwargs, semantic_context)
elif isinstance(retriever.vectorstore, Qdrant):
_apply_qdrant_semantic_filter(search_kwargs, semantic_context)

@ -0,0 +1,62 @@
"""Models for the PebbloRetrievalQA chain."""
from typing import Any, List, Optional
from langchain_core.pydantic_v1 import BaseModel
class AuthContext(BaseModel):
"""Class for an authorization context."""
name: Optional[str] = None
user_id: str
user_auth: List[str]
"""List of user authorizations, which may include their User ID and
the groups they are part of"""
class SemanticEntities(BaseModel):
"""Class for a semantic entity filter."""
deny: List[str]
class SemanticTopics(BaseModel):
"""Class for a semantic topic filter."""
deny: List[str]
class SemanticContext(BaseModel):
"""Class for a semantic context."""
pebblo_semantic_entities: Optional[SemanticEntities] = None
pebblo_semantic_topics: Optional[SemanticTopics] = None
def __init__(self, **data: Any) -> None:
super().__init__(**data)
# Validate semantic_context
if (
self.pebblo_semantic_entities is None
and self.pebblo_semantic_topics is None
):
raise ValueError(
"semantic_context must contain 'pebblo_semantic_entities' or "
"'pebblo_semantic_topics'"
)
class ChainInput(BaseModel):
"""Input for PebbloRetrievalQA chain."""
query: str
auth_context: Optional[AuthContext] = None
semantic_context: Optional[SemanticContext] = None
def dict(self, **kwargs: Any) -> dict:
base_dict = super().dict(**kwargs)
# Keep auth_context and semantic_context as it is(Pydantic models)
base_dict["auth_context"] = self.auth_context
base_dict["semantic_context"] = self.semantic_context
return base_dict

@ -0,0 +1,129 @@
"""
Unit tests for the PebbloRetrievalQA chain
"""
from typing import List
from unittest.mock import Mock
import pytest
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
from langchain_community.chains import PebbloRetrievalQA
from langchain_community.chains.pebblo_retrieval.models import (
AuthContext,
ChainInput,
SemanticContext,
)
from langchain_community.vectorstores.chroma import Chroma
from langchain_community.vectorstores.pinecone import Pinecone
from tests.unit_tests.llms.fake_llm import FakeLLM
class FakeRetriever(VectorStoreRetriever):
"""
Test util that parrots the query back as documents
"""
vectorstore: VectorStore = Mock()
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
return [Document(page_content=query)]
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
return [Document(page_content=query)]
@pytest.fixture
def unsupported_retriever() -> FakeRetriever:
"""
Create a FakeRetriever instance
"""
retriever = FakeRetriever()
retriever.search_kwargs = {}
# Set the class of vectorstore to Chroma
retriever.vectorstore.__class__ = Chroma
return retriever
@pytest.fixture
def retriever() -> FakeRetriever:
"""
Create a FakeRetriever instance
"""
retriever = FakeRetriever()
retriever.search_kwargs = {}
# Set the class of vectorstore to Pinecone
retriever.vectorstore.__class__ = Pinecone
return retriever
@pytest.fixture
def pebblo_retrieval_qa(retriever: FakeRetriever) -> PebbloRetrievalQA:
"""
Create a PebbloRetrievalQA instance
"""
pebblo_retrieval_qa = PebbloRetrievalQA.from_chain_type(
llm=FakeLLM(), chain_type="stuff", retriever=retriever
)
return pebblo_retrieval_qa
def test_invoke(pebblo_retrieval_qa: PebbloRetrievalQA) -> None:
"""
Test that the invoke method returns a non-None result
"""
# Create a fake auth context and semantic context
auth_context = AuthContext(
user_id="fake_user@email.com",
user_auth=["fake-group", "fake-group2"],
)
semantic_context_dict = {
"pebblo_semantic_topics": {"deny": ["harmful-advice"]},
"pebblo_semantic_entities": {"deny": ["credit-card"]},
}
semantic_context = SemanticContext(**semantic_context_dict)
question = "What is the meaning of life?"
chain_input_obj = ChainInput(
query=question, auth_context=auth_context, semantic_context=semantic_context
)
response = pebblo_retrieval_qa.invoke(chain_input_obj.dict())
assert response is not None
def test_validate_vectorstore(
retriever: FakeRetriever, unsupported_retriever: FakeRetriever
) -> None:
"""
Test vectorstore validation
"""
# No exception should be raised for supported vectorstores (Pinecone)
_ = PebbloRetrievalQA.from_chain_type(
llm=FakeLLM(),
chain_type="stuff",
retriever=retriever,
)
# validate_vectorstore method should raise a ValueError for unsupported vectorstores
with pytest.raises(ValueError) as exc_info:
_ = PebbloRetrievalQA.from_chain_type(
llm=FakeLLM(),
chain_type="stuff",
retriever=unsupported_retriever,
)
assert (
"Vectorstore must be an instance of one of the supported vectorstores"
in str(exc_info.value)
)
Loading…
Cancel
Save