From 0bf75968396368f6b50672aaee5d0a212ad9b85b Mon Sep 17 00:00:00 2001 From: Tomaz Bratanic Date: Tue, 7 May 2024 17:41:09 +0200 Subject: [PATCH] Add simple node properties to llm graph transformer (#21369) Add support for simple node properties in llm graph transformer. Linter and dynamic pydantic classes aren't friends, hence I added two ignores --- .../graph_transformers/llm.py | 128 ++++++++++++++---- 1 file changed, 101 insertions(+), 27 deletions(-) diff --git a/libs/experimental/langchain_experimental/graph_transformers/llm.py b/libs/experimental/langchain_experimental/graph_transformers/llm.py index 0c98517ea5..6cd97941a9 100644 --- a/libs/experimental/langchain_experimental/graph_transformers/llm.py +++ b/libs/experimental/langchain_experimental/graph_transformers/llm.py @@ -1,6 +1,6 @@ import asyncio import json -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship from langchain_core.documents import Document @@ -12,7 +12,7 @@ from langchain_core.prompts import ( HumanMessagePromptTemplate, PromptTemplate, ) -from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.pydantic_v1 import BaseModel, Field, create_model examples = [ { @@ -122,10 +122,34 @@ default_prompt = ChatPromptTemplate.from_messages( ) +def _get_additional_info(input_type: str) -> str: + # Check if the input_type is one of the allowed values + if input_type not in ["node", "relationship", "property"]: + raise ValueError("input_type must be 'node', 'relationship', or 'property'") + + # Perform actions based on the input_type + if input_type == "node": + return ( + "Ensure you use basic or elementary types for node labels.\n" + "For example, when you identify an entity representing a person, " + "always label it as **'Person'**. Avoid using more specific terms " + "like 'Mathematician' or 'Scientist'" + ) + elif input_type == "relationship": + return ( + "Instead of using specific and momentary types such as " + "'BECAME_PROFESSOR', use more general and timeless relationship types like " + "'PROFESSOR'. However, do not sacrifice any accuracy for generality" + ) + elif input_type == "property": + return "" + return "" + + def optional_enum_field( enum_values: Optional[List[str]] = None, description: str = "", - is_rel: bool = False, + input_type: str = "node", **field_kwargs: Any, ) -> Any: """Utility function to conditionally create a field with an enum constraint.""" @@ -137,18 +161,7 @@ def optional_enum_field( **field_kwargs, ) else: - node_info = ( - "Ensure you use basic or elementary types for node labels.\n" - "For example, when you identify an entity representing a person, " - "always label it as **'Person'**. Avoid using more specific terms " - "like 'Mathematician' or 'Scientist'" - ) - rel_info = ( - "Instead of using specific and momentary types such as " - "'BECAME_PROFESSOR', use more general and timeless relationship types like " - "'PROFESSOR'. However, do not sacrifice any accuracy for generality" - ) - additional_info = rel_info if is_rel else node_info + additional_info = _get_additional_info(input_type) return Field(..., description=description + additional_info, **field_kwargs) @@ -255,20 +268,52 @@ For the following text, extract entities and relations as in the provided exampl def create_simple_model( - node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None + node_labels: Optional[List[str]] = None, + rel_types: Optional[List[str]] = None, + node_properties: Union[bool, List[str]] = False, ) -> Type[_Graph]: """ Simple model allows to limit node and/or relationship types. Doesn't have any node or relationship properties. """ - class SimpleNode(BaseModel): - """Represents a node in a graph with associated properties.""" + node_fields: Dict[str, Tuple[Any, Any]] = { + "id": ( + str, + Field(..., description="Name or human-readable unique identifier."), + ), + "type": ( + str, + optional_enum_field( + node_labels, + description="The type or label of the node.", + input_type="node", + ), + ), + } + if node_properties: + if isinstance(node_properties, list) and "id" in node_properties: + raise ValueError("The node property 'id' is reserved and cannot be used.") + # Map True to empty array + node_properties_mapped: List[str] = ( + [] if node_properties is True else node_properties + ) - id: str = Field(description="Name or human-readable unique identifier.") - type: str = optional_enum_field( - node_labels, description="The type or label of the node." + class Property(BaseModel): + """A single property consisting of key and value""" + + key: str = optional_enum_field( + node_properties_mapped, + description="Property key.", + input_type="property", + ) + value: str = Field(..., description="value") + + node_fields["properties"] = ( + Optional[List[Property]], + Field(None, description="List of node properties"), ) + SimpleNode = create_model("SimpleNode", **node_fields) # type: ignore class SimpleRelationship(BaseModel): """Represents a directed relationship between two nodes in a graph.""" @@ -277,22 +322,28 @@ def create_simple_model( description="Name or human-readable unique identifier of source node" ) source_node_type: str = optional_enum_field( - node_labels, description="The type or label of the source node." + node_labels, + description="The type or label of the source node.", + input_type="node", ) target_node_id: str = Field( description="Name or human-readable unique identifier of target node" ) target_node_type: str = optional_enum_field( - node_labels, description="The type or label of the target node." + node_labels, + description="The type or label of the target node.", + input_type="node", ) type: str = optional_enum_field( - rel_types, description="The type of the relationship.", is_rel=True + rel_types, + description="The type of the relationship.", + input_type="relationship", ) class DynamicGraph(_Graph): """Represents a graph document consisting of nodes and relationships.""" - nodes: Optional[List[SimpleNode]] = Field(description="List of nodes") + nodes: Optional[List[SimpleNode]] = Field(description="List of nodes") # type: ignore relationships: Optional[List[SimpleRelationship]] = Field( description="List of relationships" ) @@ -302,7 +353,11 @@ def create_simple_model( def map_to_base_node(node: Any) -> Node: """Map the SimpleNode to the base Node.""" - return Node(id=node.id, type=node.type) + properties = {} + if hasattr(node, "properties") and node.properties: + for p in node.properties: + properties[format_property_key(p.key)] = p.value + return Node(id=node.id, type=node.type, properties=properties) def map_to_base_relationship(rel: Any) -> Relationship: @@ -378,6 +433,7 @@ def _format_nodes(nodes: List[Node]) -> List[Node]: Node( id=el.id.title() if isinstance(el.id, str) else el.id, type=el.type.capitalize(), + properties=el.properties, ) for el in nodes ] @@ -394,6 +450,15 @@ def _format_relationships(rels: List[Relationship]) -> List[Relationship]: ] +def format_property_key(s: str) -> str: + words = s.split() + if not words: + return s + first_word = words[0].lower() + capitalized_words = [word.capitalize() for word in words[1:]] + return "".join([first_word] + capitalized_words) + + def _convert_to_graph_document( raw_schema: Dict[Any, Any], ) -> Tuple[List[Node], List[Relationship]]: @@ -474,6 +539,7 @@ class LLMGraphTransformer: allowed_relationships: List[str] = [], prompt: Optional[ChatPromptTemplate] = None, strict_mode: bool = True, + node_properties: Union[bool, List[str]] = False, ) -> None: self.allowed_nodes = allowed_nodes self.allowed_relationships = allowed_relationships @@ -485,6 +551,12 @@ class LLMGraphTransformer: except NotImplementedError: self._function_call = False if not self._function_call: + if node_properties: + raise ValueError( + "The 'node_properties' parameter cannot be used " + "in combination with a LLM that doesn't support " + "native function calling." + ) try: import json_repair @@ -500,7 +572,9 @@ class LLMGraphTransformer: self.chain = prompt | llm else: # Define chain - schema = create_simple_model(allowed_nodes, allowed_relationships) + schema = create_simple_model( + allowed_nodes, allowed_relationships, node_properties + ) structured_llm = llm.with_structured_output(schema, include_raw=True) prompt = prompt or default_prompt self.chain = prompt | structured_llm