|
|
|
@ -1,6 +1,6 @@
|
|
|
|
|
import asyncio
|
|
|
|
|
import json
|
|
|
|
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast
|
|
|
|
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
|
|
|
|
|
|
|
|
|
|
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
@ -12,7 +12,7 @@ from langchain_core.prompts import (
|
|
|
|
|
HumanMessagePromptTemplate,
|
|
|
|
|
PromptTemplate,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
|
|
|
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
|
|
|
|
|
|
|
|
|
examples = [
|
|
|
|
|
{
|
|
|
|
@ -122,10 +122,34 @@ default_prompt = ChatPromptTemplate.from_messages(
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_additional_info(input_type: str) -> str:
|
|
|
|
|
# Check if the input_type is one of the allowed values
|
|
|
|
|
if input_type not in ["node", "relationship", "property"]:
|
|
|
|
|
raise ValueError("input_type must be 'node', 'relationship', or 'property'")
|
|
|
|
|
|
|
|
|
|
# Perform actions based on the input_type
|
|
|
|
|
if input_type == "node":
|
|
|
|
|
return (
|
|
|
|
|
"Ensure you use basic or elementary types for node labels.\n"
|
|
|
|
|
"For example, when you identify an entity representing a person, "
|
|
|
|
|
"always label it as **'Person'**. Avoid using more specific terms "
|
|
|
|
|
"like 'Mathematician' or 'Scientist'"
|
|
|
|
|
)
|
|
|
|
|
elif input_type == "relationship":
|
|
|
|
|
return (
|
|
|
|
|
"Instead of using specific and momentary types such as "
|
|
|
|
|
"'BECAME_PROFESSOR', use more general and timeless relationship types like "
|
|
|
|
|
"'PROFESSOR'. However, do not sacrifice any accuracy for generality"
|
|
|
|
|
)
|
|
|
|
|
elif input_type == "property":
|
|
|
|
|
return ""
|
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def optional_enum_field(
|
|
|
|
|
enum_values: Optional[List[str]] = None,
|
|
|
|
|
description: str = "",
|
|
|
|
|
is_rel: bool = False,
|
|
|
|
|
input_type: str = "node",
|
|
|
|
|
**field_kwargs: Any,
|
|
|
|
|
) -> Any:
|
|
|
|
|
"""Utility function to conditionally create a field with an enum constraint."""
|
|
|
|
@ -137,18 +161,7 @@ def optional_enum_field(
|
|
|
|
|
**field_kwargs,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
node_info = (
|
|
|
|
|
"Ensure you use basic or elementary types for node labels.\n"
|
|
|
|
|
"For example, when you identify an entity representing a person, "
|
|
|
|
|
"always label it as **'Person'**. Avoid using more specific terms "
|
|
|
|
|
"like 'Mathematician' or 'Scientist'"
|
|
|
|
|
)
|
|
|
|
|
rel_info = (
|
|
|
|
|
"Instead of using specific and momentary types such as "
|
|
|
|
|
"'BECAME_PROFESSOR', use more general and timeless relationship types like "
|
|
|
|
|
"'PROFESSOR'. However, do not sacrifice any accuracy for generality"
|
|
|
|
|
)
|
|
|
|
|
additional_info = rel_info if is_rel else node_info
|
|
|
|
|
additional_info = _get_additional_info(input_type)
|
|
|
|
|
return Field(..., description=description + additional_info, **field_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -255,20 +268,52 @@ For the following text, extract entities and relations as in the provided exampl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_simple_model(
|
|
|
|
|
node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None
|
|
|
|
|
node_labels: Optional[List[str]] = None,
|
|
|
|
|
rel_types: Optional[List[str]] = None,
|
|
|
|
|
node_properties: Union[bool, List[str]] = False,
|
|
|
|
|
) -> Type[_Graph]:
|
|
|
|
|
"""
|
|
|
|
|
Simple model allows to limit node and/or relationship types.
|
|
|
|
|
Doesn't have any node or relationship properties.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
class SimpleNode(BaseModel):
|
|
|
|
|
"""Represents a node in a graph with associated properties."""
|
|
|
|
|
node_fields: Dict[str, Tuple[Any, Any]] = {
|
|
|
|
|
"id": (
|
|
|
|
|
str,
|
|
|
|
|
Field(..., description="Name or human-readable unique identifier."),
|
|
|
|
|
),
|
|
|
|
|
"type": (
|
|
|
|
|
str,
|
|
|
|
|
optional_enum_field(
|
|
|
|
|
node_labels,
|
|
|
|
|
description="The type or label of the node.",
|
|
|
|
|
input_type="node",
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
}
|
|
|
|
|
if node_properties:
|
|
|
|
|
if isinstance(node_properties, list) and "id" in node_properties:
|
|
|
|
|
raise ValueError("The node property 'id' is reserved and cannot be used.")
|
|
|
|
|
# Map True to empty array
|
|
|
|
|
node_properties_mapped: List[str] = (
|
|
|
|
|
[] if node_properties is True else node_properties
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
id: str = Field(description="Name or human-readable unique identifier.")
|
|
|
|
|
type: str = optional_enum_field(
|
|
|
|
|
node_labels, description="The type or label of the node."
|
|
|
|
|
class Property(BaseModel):
|
|
|
|
|
"""A single property consisting of key and value"""
|
|
|
|
|
|
|
|
|
|
key: str = optional_enum_field(
|
|
|
|
|
node_properties_mapped,
|
|
|
|
|
description="Property key.",
|
|
|
|
|
input_type="property",
|
|
|
|
|
)
|
|
|
|
|
value: str = Field(..., description="value")
|
|
|
|
|
|
|
|
|
|
node_fields["properties"] = (
|
|
|
|
|
Optional[List[Property]],
|
|
|
|
|
Field(None, description="List of node properties"),
|
|
|
|
|
)
|
|
|
|
|
SimpleNode = create_model("SimpleNode", **node_fields) # type: ignore
|
|
|
|
|
|
|
|
|
|
class SimpleRelationship(BaseModel):
|
|
|
|
|
"""Represents a directed relationship between two nodes in a graph."""
|
|
|
|
@ -277,22 +322,28 @@ def create_simple_model(
|
|
|
|
|
description="Name or human-readable unique identifier of source node"
|
|
|
|
|
)
|
|
|
|
|
source_node_type: str = optional_enum_field(
|
|
|
|
|
node_labels, description="The type or label of the source node."
|
|
|
|
|
node_labels,
|
|
|
|
|
description="The type or label of the source node.",
|
|
|
|
|
input_type="node",
|
|
|
|
|
)
|
|
|
|
|
target_node_id: str = Field(
|
|
|
|
|
description="Name or human-readable unique identifier of target node"
|
|
|
|
|
)
|
|
|
|
|
target_node_type: str = optional_enum_field(
|
|
|
|
|
node_labels, description="The type or label of the target node."
|
|
|
|
|
node_labels,
|
|
|
|
|
description="The type or label of the target node.",
|
|
|
|
|
input_type="node",
|
|
|
|
|
)
|
|
|
|
|
type: str = optional_enum_field(
|
|
|
|
|
rel_types, description="The type of the relationship.", is_rel=True
|
|
|
|
|
rel_types,
|
|
|
|
|
description="The type of the relationship.",
|
|
|
|
|
input_type="relationship",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
class DynamicGraph(_Graph):
|
|
|
|
|
"""Represents a graph document consisting of nodes and relationships."""
|
|
|
|
|
|
|
|
|
|
nodes: Optional[List[SimpleNode]] = Field(description="List of nodes")
|
|
|
|
|
nodes: Optional[List[SimpleNode]] = Field(description="List of nodes") # type: ignore
|
|
|
|
|
relationships: Optional[List[SimpleRelationship]] = Field(
|
|
|
|
|
description="List of relationships"
|
|
|
|
|
)
|
|
|
|
@ -302,7 +353,11 @@ def create_simple_model(
|
|
|
|
|
|
|
|
|
|
def map_to_base_node(node: Any) -> Node:
    """Map the SimpleNode to the base Node.

    Args:
        node: Object exposing ``id`` and ``type`` attributes and, optionally,
            a ``properties`` iterable whose items expose ``key`` and ``value``.

    Returns:
        A ``Node`` with the same id and type. Its ``properties`` dict maps each
        key, normalized with ``format_property_key``, to its value, and is
        empty when the node has no properties.
    """
    # The ``properties`` attribute may be absent or None, so fall back to an
    # empty iterable. No early return here: returning before this point would
    # silently drop every extracted property.
    extracted = getattr(node, "properties", None) or []
    properties = {format_property_key(p.key): p.value for p in extracted}
    return Node(id=node.id, type=node.type, properties=properties)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def map_to_base_relationship(rel: Any) -> Relationship:
|
|
|
|
@ -378,6 +433,7 @@ def _format_nodes(nodes: List[Node]) -> List[Node]:
|
|
|
|
|
Node(
|
|
|
|
|
id=el.id.title() if isinstance(el.id, str) else el.id,
|
|
|
|
|
type=el.type.capitalize(),
|
|
|
|
|
properties=el.properties,
|
|
|
|
|
)
|
|
|
|
|
for el in nodes
|
|
|
|
|
]
|
|
|
|
@ -394,6 +450,15 @@ def _format_relationships(rels: List[Relationship]) -> List[Relationship]:
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_property_key(s: str) -> str:
    """Convert a whitespace-separated phrase into a camelCase key.

    The first word is lower-cased; each following word is capitalized (first
    letter upper-case, remainder lower-case); the words are then joined with
    no separator. A string containing no words (empty or whitespace only) is
    returned unchanged.
    """
    tokens = iter(s.split())
    leading = next(tokens, None)
    if leading is None:
        # Nothing to reformat: hand back the input exactly as received.
        return s
    return leading.lower() + "".join(token.capitalize() for token in tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_to_graph_document(
|
|
|
|
|
raw_schema: Dict[Any, Any],
|
|
|
|
|
) -> Tuple[List[Node], List[Relationship]]:
|
|
|
|
@ -474,6 +539,7 @@ class LLMGraphTransformer:
|
|
|
|
|
allowed_relationships: List[str] = [],
|
|
|
|
|
prompt: Optional[ChatPromptTemplate] = None,
|
|
|
|
|
strict_mode: bool = True,
|
|
|
|
|
node_properties: Union[bool, List[str]] = False,
|
|
|
|
|
) -> None:
|
|
|
|
|
self.allowed_nodes = allowed_nodes
|
|
|
|
|
self.allowed_relationships = allowed_relationships
|
|
|
|
@ -485,6 +551,12 @@ class LLMGraphTransformer:
|
|
|
|
|
except NotImplementedError:
|
|
|
|
|
self._function_call = False
|
|
|
|
|
if not self._function_call:
|
|
|
|
|
if node_properties:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"The 'node_properties' parameter cannot be used "
|
|
|
|
|
"in combination with a LLM that doesn't support "
|
|
|
|
|
"native function calling."
|
|
|
|
|
)
|
|
|
|
|
try:
|
|
|
|
|
import json_repair
|
|
|
|
|
|
|
|
|
@ -500,7 +572,9 @@ class LLMGraphTransformer:
|
|
|
|
|
self.chain = prompt | llm
|
|
|
|
|
else:
|
|
|
|
|
# Define chain
|
|
|
|
|
schema = create_simple_model(allowed_nodes, allowed_relationships)
|
|
|
|
|
schema = create_simple_model(
|
|
|
|
|
allowed_nodes, allowed_relationships, node_properties
|
|
|
|
|
)
|
|
|
|
|
structured_llm = llm.with_structured_output(schema, include_raw=True)
|
|
|
|
|
prompt = prompt or default_prompt
|
|
|
|
|
self.chain = prompt | structured_llm
|
|
|
|
|