mongodb[patch]: Make ObjectId JSON-serializable on generation (#21394)

pull/21678/head
Jib 3 weeks ago committed by GitHub
parent 12b599c47f
commit a97473c846
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -16,6 +16,7 @@ from typing import (
)
import numpy as np
from bson import ObjectId, json_util
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
@ -210,9 +211,23 @@ class MongoDBAtlasVectorSearch(VectorStore):
pipeline.extend(post_filter_pipeline)
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
docs = []
def _make_serializable(obj: Dict[str, Any]) -> None:
for k, v in obj.items():
if isinstance(v, dict):
_make_serializable(v)
elif isinstance(v, list) and v and isinstance(v[0], ObjectId):
obj[k] = [json_util.default(item) for item in v]
elif isinstance(v, ObjectId):
obj[k] = json_util.default(v)
for res in cursor:
text = res.pop(self._text_key)
score = res.pop("score")
# Make every ObjectId found JSON-Serializable
# following format used in bson.json_util.loads
# e.g. loads('{"_id": {"$oid": "664..."}}') == {'_id': ObjectId('664..')} # noqa: E501
_make_serializable(res)
docs.append((Document(page_content=text, metadata=res), score))
return docs

@ -1,6 +1,8 @@
from json import dumps, loads
from typing import Any, Optional
import pytest
from bson import ObjectId, json_util
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from pymongo.collection import Collection
@ -75,6 +77,11 @@ class TestMongoDBAtlasVectorSearch:
output = vectorstore.similarity_search("", k=1)
assert output[0].page_content == page_content
assert output[0].metadata.get("c") == metadata
# Validate the ObjectId provided is json serializable
assert loads(dumps(output[0].page_content)) == output[0].page_content
assert loads(dumps(output[0].metadata)) == output[0].metadata
json_metadata = dumps(output[0].metadata) # normal json.dumps
assert isinstance(json_util.loads(json_metadata)["_id"], ObjectId)
def test_from_documents(
self, embedding_openai: Embeddings, collection: MockCollection

@ -1,9 +1,9 @@
from __future__ import annotations
import uuid
from copy import deepcopy
from typing import Any, Dict, List, Mapping, Optional, cast
from bson import ObjectId
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@ -162,7 +162,7 @@ class MockCollection(Collection):
def insert_many(self, to_insert: List[Any], *args, **kwargs) -> InsertManyResult: # type: ignore
mongodb_inserts = [
{"_id": str(uuid.uuid4()), "score": 1, **insert} for insert in to_insert
{"_id": ObjectId(), "score": 1, **insert} for insert in to_insert
]
self._data.extend(mongodb_inserts)
return self._insert_result or InsertManyResult(

Loading…
Cancel
Save