You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/cookbook/langgraph_crag.ipynb

529 lines
264 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "459d0bcf-7c60-495e-91c3-85b0b8c67552",
"metadata": {},
"outputs": [],
"source": [
"%pip install langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph tavily-python"
]
},
{
"attachments": {
"5bfa38a2-78a1-4e99-80a2-d98c8a440ea2.png": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA58AAAKUCAYAAACDoYwLAAAMP2lDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBCCSAgJfQmCEgJICWEFkB6EWyEJEAoMQaCiB1dVHDtYgEbuiqi2AGxI3YWwd4XRRSUdbFgV96kgK77yvfO9829//3nzH/OnDu3DADqp7hicQ6qAUCuKF8SGxLAGJucwiB1AwTggAYIgMDl5YlZ0dERANrg+e/27ib0hnbNQab1z/7/app8QR4PACQa4jR+Hi8X4kMA4JU8sSQfAKKMN5+aL5Zh2IC2BCYI8UIZzlDgShlOU+B9cp/4WDbEzQCoqHG5kgwAaG2QZxTwMqAGrQ9iJxFfKAJAnQGxb27uZD7EqRDbQB8xxDJ9ZtoPOhl/00wb0uRyM4awYi5yUwkU5olzuNP+z3L8b8vNkQ7GsIJNLVMSGiubM6zb7ezJ4TKsBnGvKC0yCmItiD8I+XJ/iFFKpjQ0QeGPGvLy2LBmQBdiJz43MBxiQ4iDRTmREUo+LV0YzIEYrhC0UJjPiYdYD+KFgrygOKXPZsnkWGUstC5dwmYp+QtciTyuLNZDaXYCS6n/OlPAUepjtKLM+CSIKRBbFAgTIyGmQeyYlx0XrvQZXZTJjhz0kUhjZflbQBwrEIUEKPSxgnRJcKzSvzQ3b3C+2OZMISdSiQ/kZ8aHKuqDNfO48vzhXLA2gYiVMKgjyBsbMTgXviAwSDF3rFsgSohT6nwQ5wfEKsbiFHFOtNIfNxPkhMh4M4hd8wrilGPxxHy4IBX6eLo4PzpekSdelMUNi1bkgy8DEYANAgEDSGFLA5NBFhC29tb3witFTzDgAgnIAALgoGQGRyTJe0TwGAeKwJ8QCUDe0LgAea8AFED+6xCrODqAdHlvgXxENngKcS4IBznwWiofJRqKlgieQEb4j+hc2Hgw3xzYZP3/nh9kvzMsyEQoGelgRIb6oCcxiBhIDCUGE21xA9wX98Yj4NEfNheciXsOzuO7P+EpoZ3wmHCD0EG4M0lYLPkpyzGgA+oHK2uR9mMtcCuo6YYH4D5QHSrjurgBcMBdYRwW7gcju0GWrcxbVhXGT9p/m8EPd0PpR3Yio+RhZH+yzc8jaXY0tyEVWa1/rI8i17SherOHen6Oz/6h+nx4Dv/ZE1uIHcTOY6exi9gxrB4wsJNYA9aCHZfhodX1RL66BqPFyvPJhjrCf8QbvLOySuY51Tj1OH1R9OULCmXvaMCeLJ4mEWZk5jNY8IsgYHBEPMcRDBcnF1cAZN8XxevrTYz8u4Hotnzn5v0BgM/JgYGBo9+5sJMA7PeAj/+R75wNE346VAG4cIQnlRQoOFx2IMC3hDp80vSBMTAHNnA+LsAdeAN/EATCQBSIB8lgIsw+E65zCZgKZoC5oASUgWVgNVgPNoGtYCfYAw6AenAMnAbnwGXQBm6Ae3D1dIEXoA+8A58RBCEhVISO6CMmiCVij7ggTMQXCUIikFgkGUlFMhARIkVmIPOQMmQFsh7ZglQj+5EjyGnkItKO3EEeIT3Ia+QTiqFqqDZqhFqhI1EmykLD0Xh0ApqBTkGL0PnoEnQtWoXuRuvQ0+hl9Abagb5A+zGAqWK6mCnmgDExNhaFpWDpmASbhZVi5VgVVos1wvt8DevAerGPOBGn4wzcAa7gUDwB5+FT8Fn4Ynw9vhOvw5vxa/gjvA//RqASDAn2BC8ChzCWkEGYSighlBO2Ew4TzsJnqYvwjkgk6hKtiR7wWUwmZhGnExcTNxD3Ek8R24mdxH4SiaRPsif5kKJIXFI+qYS0jrSbdJJ0ldRF+qCiqmKi4qISrJKiIlIpVilX2aVyQuWqyjOVz2QNsiXZixxF5pOnkZeSt5EbyVfIXeTPFE2KNcWHEk/JosylrKXUUs5S7lPeqKqqmql6qsaoClXnqK5V3ad6QfWR6kc1LTU7NbbaeDWp2hK1HWqn1O6ovaFSqVZUf2
oKNZ+6hFpNPUN9SP1Ao9McaRwanzabVkGro12lvVQnq1uqs9Qnqhepl6sfVL+i3qtB1rDSYGtwNWZpVGgc0bil0a9J13TWjNLM1VysuUvzoma3FknLSitIi681X2ur1hmtTjpGN6ez6Tz6PPo2+ll6lzZR21qbo52lXaa9R7tVu09HS8dVJ1GnUKdC57hOhy6ma6XL0c3RXap7QPem7qdhRsNYwwTDFg2rHXZ12Hu94Xr+egK9Ur29ejf0Pukz9IP0s/WX69frPzDADewMYgymGmw0OGvQO1x7uPdw3vDS4QeG3zVEDe0MYw2nG241bDHsNzI2CjESG60zOmPUa6xr7G+cZbzK+IRxjwndxNdEaLLK5KTJc4YOg8XIYaxlNDP6TA1NQ02lpltMW00/m1mbJZgVm+01e2BOMWeap5uvMm8y77MwsRhjMcOixuKuJdmSaZlpucbyvOV7K2urJKsFVvVW3dZ61hzrIusa6/s2VBs/myk2VTbXbYm2TNts2w22bXaonZtdpl2F3RV71N7dXmi/wb59BGGE5wjRiKoRtxzUHFgOBQ41Do8cdR0jHIsd6x1fjrQYmTJy+cjzI785uTnlOG1zuues5RzmXOzc6Pzaxc6F51Lhcn0UdVTwqNmjGka9crV3FbhudL3tRncb47bArcntq7uHu8S91r3Hw8Ij1aPS4xZTmxnNXMy84EnwDPCc7XnM86OXu1e+1wGvv7wdvLO9d3l3j7YeLRi9bXSnj5kP12eLT4cvwzfVd7Nvh5+pH9evyu+xv7k/33+7/zOWLSuLtZv1MsApQBJwOOA924s9k30qEAsMCSwNbA3SCkoIWh/0MNgsOCO4JrgvxC1kesipUEJoeOjy0FscIw6PU83pC/MImxnWHK4WHhe+PvxxhF2EJKJxDDombMzKMfcjLSNFkfVRIIoTtTLqQbR19JToozHEmOiYipinsc6xM2LPx9HjJsXtinsXHxC/NP5egk2CNKEpUT1xfGJ14vukwKQVSR1jR46dOfZyskGyMLkhhZSSmLI9pX9c0LjV47rGu40vGX9zgvWEwgkXJxpMzJl4fJL6JO6kg6mE1KTUXalfuFHcKm5/GietMq2Px+at4b3g+/NX8XsEPoIVgmfpPukr0rszfDJWZvRk+mWWZ/YK2cL1wldZoVmbst5nR2XvyB7IScrZm6uSm5p7RKQlyhY1TzaeXDi5XWwvLhF3TPGasnpKnyRcsj0PyZuQ15CvDX/kW6Q20l+kjwp8CyoKPkxNnHqwULNQVNgyzW7aomnPioKLfpuOT+dNb5phOmPujEczWTO3zEJmpc1qmm0+e/7srjkhc3bOpczNnvt7sVPxiuK385LmNc43mj9nfucvIb/UlNBKJCW3Fngv2LQQXyhc2Lpo1KJ1i76V8ksvlTmVlZd9WcxbfOlX51/X/jqwJH1J61L3pRuXEZeJlt1c7rd85wrNFUUrOleOWVm3irGqdNXb1ZNWXyx3Ld+0hrJGuqZjbcTahnUW65at+7I+c/2NioCKvZWGlYsq32/gb7i60X9j7SajTWWbPm0Wbr69JWRLXZVVVflW4taCrU+3JW47/xvzt+rtBtvLtn/dIdrRsTN2Z3O1R3X1LsNdS2vQGmlNz+7xu9v2BO5pqHWo3bJXd2/ZPrBPuu/5/tT9Nw+EH2g6yDxYe8jyUOVh+uHSOqRuWl1ffWZ9R0NyQ/uRsCNNjd6Nh486Ht1xzPRYxXGd40tPUE7MPzFwsuhk/ynxqd7TGac7myY13Tsz9sz15pjm1rPhZy+cCz535jzr/MkLPheOXfS6eOQS81L9ZffLdS1uLYd/d/v9cKt7a90VjysNbZ5tje2j209c9bt6+lrgtXPXOdcv34i80X4z4ebtW+Nvddzm3+6+k3Pn1d2Cu5/vzblPuF/6QONB+UPDh1V/2P6xt8O94/ijwEctj+Me3+vkdb54kvfkS9f8p9Sn5c9MnlV3u3Qf6wnuaXs+7nnXC/GLz70lf2
r+WfnS5uWhv/z/aukb29f1SvJq4PXiN/pvdrx1fdvUH93/8F3uu8/vSz/of9j5kfnx/KekT88+T/1C+rL2q+3Xxm/h
}
},
"cell_type": "markdown",
"id": "8889a307-fa3f-4d38-9127-d41e4686ae47",
"metadata": {},
"source": [
"# CRAG\n",
"\n",
"Corrective-RAG is a recent paper that introduces an interesting approach for active RAG. \n",
"\n",
"The framework grades retrieved documents relative to the question:\n",
"\n",
"1. Correct documents -\n",
"\n",
"* If at least one document exceeds the threshold for relevance, then it proceeds to generation\n",
"* Before generation, it performs knowledge refinement\n",
"* This partitions the document into \"knowledge strips\"\n",
"* It grades each strip, and filters out irrelevant ones \n",
"\n",
"2. Ambiguous or incorrect documents -\n",
"\n",
"* If all documents fall below the relevance threshold or if the grader is unsure, then the framework seeks an additional datasource\n",
"* It will use web search to supplement retrieval\n",
"* The diagrams in the paper also suggest that query re-writing is used here \n",
"\n",
"![Screenshot 2024-02-04 at 2.50.32 PM.png](attachment:5bfa38a2-78a1-4e99-80a2-d98c8a440ea2.png)\n",
"\n",
"Paper -\n",
"\n",
"https://arxiv.org/pdf/2401.15884.pdf\n",
"\n",
"---\n",
"\n",
"Let's implement this from scratch using [LangGraph](https://python.langchain.com/docs/langgraph).\n",
"\n",
"We can make some simplifications:\n",
"\n",
"* Let's skip the knowledge refinement phase as a first pass. This can be added back as a node, if desired. \n",
"* If *any* document is irrelevant, let's opt to supplement retrieval with web search. \n",
"* We'll use [Tavily Search](https://python.langchain.com/docs/integrations/tools/tavily_search) for web search.\n",
"* Let's use query re-writing to optimize the query for web search.\n",
"\n",
"Set the `TAVILY_API_KEY` (web search) and `OPENAI_API_KEY` (embeddings and LLMs) environment variables."
]
},
{
"cell_type": "markdown",
"id": "a21f32d2-92ce-4995-b309-99347bafe3be",
"metadata": {},
"source": [
"## Retriever\n",
" \n",
"Let's index 3 blog posts."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a566a30-cf0e-4330-ad4d-9bf994bdfa86",
"metadata": {},
"outputs": [],
"source": [
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from langchain_community.document_loaders import WebBaseLoader\n",
"from langchain_community.vectorstores import Chroma\n",
"from langchain_openai import OpenAIEmbeddings\n",
"\n",
"urls = [\n",
"    \"https://lilianweng.github.io/posts/2023-06-23-agent/\",\n",
"    \"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/\",\n",
"    \"https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/\",\n",
"]\n",
"\n",
"# Load each post (every loader yields a list of Documents), then flatten into one list\n",
"docs = [WebBaseLoader(url).load() for url in urls]\n",
"docs_list = [doc for loaded_docs in docs for doc in loaded_docs]\n",
"\n",
"# Split into ~250-token chunks (length measured with tiktoken), with no overlap\n",
"text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n",
"    chunk_size=250,\n",
"    chunk_overlap=0,\n",
")\n",
"doc_splits = text_splitter.split_documents(docs_list)\n",
"\n",
"# Embed the chunks with OpenAI embeddings into the \"rag-chroma\" Chroma collection\n",
"# and expose that collection as the retriever used by the graph nodes\n",
"vectorstore = Chroma.from_documents(\n",
"    documents=doc_splits,\n",
"    embedding=OpenAIEmbeddings(),\n",
"    collection_name=\"rag-chroma\",\n",
")\n",
"retriever = vectorstore.as_retriever()"
]
},
{
"cell_type": "markdown",
"id": "87194a1b-535a-4593-ab95-5736fae176d1",
"metadata": {},
"source": [
"## State\n",
" \n",
"We will define a graph.\n",
"\n",
"Our state will be a `dict`.\n",
"\n",
"We can access this from any graph node as `state['keys']`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "94b3945f-ef0f-458d-a443-f763903550b0",
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, Dict, TypedDict\n",
"\n",
"from langchain_core.messages import BaseMessage\n",
"\n",
"\n",
"class GraphState(TypedDict):\n",
"    \"\"\"\n",
"    Represents the state passed between the nodes of the graph.\n",
"\n",
"    Attributes:\n",
"        keys: A dictionary of named values shared between nodes (in this notebook:\n",
"            \"question\", \"documents\", \"generation\", \"run_web_search\"). No reducer\n",
"            (e.g. `operator.add`) is declared; the nodes here each return the complete\n",
"            `keys` dict they want to pass on.\n",
"    \"\"\"\n",
"\n",
"    keys: Dict[str, Any]"
]
},
{
"attachments": {
"3b65f495-5fc4-497b-83e2-73844a97f6cc.png": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxEAAAEXCAYAAADSoclSAAAMP2lDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBCCSAgJfQmCEgJICWEFkB6EWyEJEAoMQaCiB1dVHDtYgEbuiqi2AGxI3YWwd4XRRSUdbFgV96kgK77yvfO9829//3nzH/OnDu3DADqp7hicQ6qAUCuKF8SGxLAGJucwiB1AwTggAYIgMDl5YlZ0dERANrg+e/27ib0hnbNQab1z/7/app8QR4PACQa4jR+Hi8X4kMA4JU8sSQfAKKMN5+aL5Zh2IC2BCYI8UIZzlDgShlOU+B9cp/4WDbEzQCoqHG5kgwAaG2QZxTwMqAGrQ9iJxFfKAJAnQGxb27uZD7EqRDbQB8xxDJ9ZtoPOhl/00wb0uRyM4awYi5yUwkU5olzuNP+z3L8b8vNkQ7GsIJNLVMSGiubM6zb7ezJ4TKsBnGvKC0yCmItiD8I+XJ/iFFKpjQ0QeGPGvLy2LBmQBdiJz43MBxiQ4iDRTmREUo+LV0YzIEYrhC0UJjPiYdYD+KFgrygOKXPZsnkWGUstC5dwmYp+QtciTyuLNZDaXYCS6n/OlPAUepjtKLM+CSIKRBbFAgTIyGmQeyYlx0XrvQZXZTJjhz0kUhjZflbQBwrEIUEKPSxgnRJcKzSvzQ3b3C+2OZMISdSiQ/kZ8aHKuqDNfO48vzhXLA2gYiVMKgjyBsbMTgXviAwSDF3rFsgSohT6nwQ5wfEKsbiFHFOtNIfNxPkhMh4M4hd8wrilGPxxHy4IBX6eLo4PzpekSdelMUNi1bkgy8DEYANAgEDSGFLA5NBFhC29tb3witFTzDgAgnIAALgoGQGRyTJe0TwGAeKwJ8QCUDe0LgAea8AFED+6xCrODqAdHlvgXxENngKcS4IBznwWiofJRqKlgieQEb4j+hc2Hgw3xzYZP3/nh9kvzMsyEQoGelgRIb6oCcxiBhIDCUGE21xA9wX98Yj4NEfNheciXsOzuO7P+EpoZ3wmHCD0EG4M0lYLPkpyzGgA+oHK2uR9mMtcCuo6YYH4D5QHSrjurgBcMBdYRwW7gcju0GWrcxbVhXGT9p/m8EPd0PpR3Yio+RhZH+yzc8jaXY0tyEVWa1/rI8i17SherOHen6Oz/6h+nx4Dv/ZE1uIHcTOY6exi9gxrB4wsJNYA9aCHZfhodX1RL66BqPFyvPJhjrCf8QbvLOySuY51Tj1OH1R9OULCmXvaMCeLJ4mEWZk5jNY8IsgYHBEPMcRDBcnF1cAZN8XxevrTYz8u4Hotnzn5v0BgM/JgYGBo9+5sJMA7PeAj/+R75wNE346VAG4cIQnlRQoOFx2IMC3hDp80vSBMTAHNnA+LsAdeAN/EATCQBSIB8lgIsw+E65zCZgKZoC5oASUgWVgNVgPNoGtYCfYAw6AenAMnAbnwGXQBm6Ae3D1dIEXoA+8A58RBCEhVISO6CMmiCVij7ggTMQXCUIikFgkGUlFMhARIkVmIPOQMmQFsh7ZglQj+5EjyGnkItKO3EEeIT3Ia+QTiqFqqDZqhFqhI1EmykLD0Xh0ApqBTkGL0PnoEnQtWoXuRuvQ0+hl9Abagb5A+zGAqWK6mCnmgDExNhaFpWDpmASbhZVi5VgVVos1wvt8DevAerGPOBGn4wzcAa7gUDwB5+FT8Fn4Ynw9vhOvw5vxa/gjvA//RqASDAn2BC8ChzCWkEGYSighlBO2Ew4TzsJnqYvwjkgk6hKtiR7wWUwmZhGnExcTNxD3Ek8R24mdxH4SiaRPsif5kKJIXFI+qYS0jrSbdJJ0ldRF+qCiqmKi4qISrJKiIlIpVilX2aVyQuWqyjOVz2QNsiXZixxF5pOnkZeSt5EbyVfIXeTPFE2KNcWHEk/JosylrKXUUs5S7lPeqKqqmql6qsaoClXnqK5V3ad6QfWR6kc1LTU7NbbaeDWp2hK1HWqn1O6ovaFSqVZUf2
oKNZ+6hFpNPUN9SP1Ao9McaRwanzabVkGro12lvVQnq1uqs9Qnqhepl6sfVL+i3qtB1rDSYGtwNWZpVGgc0bil0a9J13TWjNLM1VysuUvzoma3FknLSitIi681X2ur1hmtTjpGN6ez6Tz6PPo2+ll6lzZR21qbo52lXaa9R7tVu09HS8dVJ1GnUKdC57hOhy6ma6XL0c3RXap7QPem7qdhRsNYwwTDFg2rHXZ12Hu94Xr+egK9Ur29ejf0Pukz9IP0s/WX69frPzDADewMYgymGmw0OGvQO1x7uPdw3vDS4QeG3zVEDe0MYw2nG241bDHsNzI2CjESG60zOmPUa6xr7G+cZbzK+IRxjwndxNdEaLLK5KTJc4YOg8XIYaxlNDP6TA1NQ02lpltMW00/m1mbJZgVm+01e2BOMWeap5uvMm8y77MwsRhjMcOixuKuJdmSaZlpucbyvOV7K2urJKsFVvVW3dZ61hzrIusa6/s2VBs/myk2VTbXbYm2TNts2w22bXaonZtdpl2F3RV71N7dXmi/wb59BGGE5wjRiKoRtxzUHFgOBQ41Do8cdR0jHIsd6x1fjrQYmTJy+cjzI785uTnlOG1zuues5RzmXOzc6Pzaxc6F51Lhcn0UdVTwqNmjGka9crV3FbhudL3tRncb47bArcntq7uHu8S91r3Hw8Ij1aPS4xZTmxnNXMy84EnwDPCc7XnM86OXu1e+1wGvv7wdvLO9d3l3j7YeLRi9bXSnj5kP12eLT4cvwzfVd7Nvh5+pH9evyu+xv7k/33+7/zOWLSuLtZv1MsApQBJwOOA924s9k30qEAsMCSwNbA3SCkoIWh/0MNgsOCO4JrgvxC1kesipUEJoeOjy0FscIw6PU83pC/MImxnWHK4WHhe+PvxxhF2EJKJxDDombMzKMfcjLSNFkfVRIIoTtTLqQbR19JToozHEmOiYipinsc6xM2LPx9HjJsXtinsXHxC/NP5egk2CNKEpUT1xfGJ14vukwKQVSR1jR46dOfZyskGyMLkhhZSSmLI9pX9c0LjV47rGu40vGX9zgvWEwgkXJxpMzJl4fJL6JO6kg6mE1KTUXalfuFHcKm5/GietMq2Px+at4b3g+/NX8XsEPoIVgmfpPukr0rszfDJWZvRk+mWWZ/YK2cL1wldZoVmbst5nR2XvyB7IScrZm6uSm5p7RKQlyhY1TzaeXDi5XWwvLhF3TPGasnpKnyRcsj0PyZuQ15CvDX/kW6Q20l+kjwp8CyoKPkxNnHqwULNQVNgyzW7aomnPioKLfpuOT+dNb5phOmPujEczWTO3zEJmpc1qmm0+e/7srjkhc3bOpczNnvt7sVPxiuK385LmNc43mj9nfucvIb/UlNBKJCW3Fngv2LQQXyhc2Lpo1KJ1i76V8ksvlTmVlZd9WcxbfOlX51/X/jqwJH1J61L3pRuXEZeJlt1c7rd85wrNFUUrOleOWVm3irGqdNXb1ZNWXyx3Ld+0hrJGuqZjbcTahnUW65at+7I+c/2NioCKvZWGlYsq32/gb7i60X9j7SajTWWbPm0Wbr69JWRLXZVVVflW4taCrU+3JW47/xvzt+rtBtvLtn/dIdrRsTN2Z3O1R3X1LsNdS2vQGmlNz+7xu9v2BO5pqHWo3bJXd2/ZPrBPuu/5/tT9Nw+EH2g6yDxYe8jyUOVh+uHSOqRuWl1ffWZ9R0NyQ/uRsCNNjd6Nh486Ht1xzPRYxXGd40tPUE7MPzFwsuhk/ynxqd7TGac7myY13Tsz9sz15pjm1rPhZy+cCz535jzr/MkLPheOXfS6eOQS81L9ZffLdS1uLYd/d/v9cKt7a90VjysNbZ5tje2j209c9bt6+lrgtXPXOdcv34i80X4z4ebtW+Nvddzm3+6+k3Pn1d2Cu5/vzblPuF/6QONB+UPDh1V/2P6xt8O94/ijwEctj+Me3+vkdb54kvfkS9f8p9Sn5c9MnlV3u3Qf6wnuaXs+7nnXC/GLz70lf2
r+WfnS5uWhv/z/aukb29f1SvJq4PXiN/pvdrx1fdvUH93/8F3uu8/vSz/of9j5kfnx/KekT88+T/1C+rL2q+3Xxm/h
}
},
"cell_type": "markdown",
"id": "f81239f2-314d-41fe-9af9-d19b5b193b53",
"metadata": {},
"source": [
"## Nodes and Edges\n",
"\n",
"Each `node` will simply modify the `state`.\n",
"\n",
"Each `edge` will choose which `node` to call next.\n",
"\n",
"It will follow the graph diagram shown below.\n",
"\n",
"![Screenshot 2024-02-04 at 1.32.52 PM.png](attachment:3b65f495-5fc4-497b-83e2-73844a97f6cc.png)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "efd639c5-82e2-45e6-a94a-6a4039646ef5",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import operator\n",
"from typing import Annotated, Sequence, TypedDict\n",
"\n",
"from langchain import hub\n",
"from langchain.output_parsers import PydanticOutputParser\n",
"from langchain.output_parsers.openai_tools import PydanticToolsParser\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.schema import Document\n",
"from langchain_community.tools.tavily_search import TavilySearchResults\n",
"from langchain_community.vectorstores import Chroma\n",
"from langchain_core.messages import BaseMessage, FunctionMessage\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"from langchain_core.utils.function_calling import convert_to_openai_tool\n",
"from langchain_openai import ChatOpenAI, OpenAIEmbeddings\n",
"from langgraph.prebuilt import ToolInvocation\n",
"\n",
"### Nodes ###\n",
"\n",
"\n",
"def retrieve(state):\n",
"    \"\"\"\n",
"    Fetch candidate documents for the current question from the vector store.\n",
"\n",
"    Args:\n",
"        state (dict): Graph state; the question is read from state[\"keys\"][\"question\"].\n",
"\n",
"    Returns:\n",
"        dict: {\"keys\": {...}} carrying the retrieved documents and the unchanged question.\n",
"    \"\"\"\n",
"    print(\"---RETRIEVE---\")\n",
"    question = state[\"keys\"][\"question\"]\n",
"    return {\n",
"        \"keys\": {\n",
"            \"documents\": retriever.invoke(question),\n",
"            \"question\": question,\n",
"        }\n",
"    }\n",
"\n",
"\n",
"def generate(state):\n",
"    \"\"\"\n",
"    Generate an answer to the question from the graded (and possibly web-augmented) documents.\n",
"\n",
"    Args:\n",
"        state (dict): Graph state; reads state[\"keys\"][\"question\"] and state[\"keys\"][\"documents\"].\n",
"\n",
"    Returns:\n",
"        dict: {\"keys\": {...}} with the documents, the question, and a new \"generation\" entry.\n",
"    \"\"\"\n",
"    print(\"---GENERATE---\")\n",
"    state_dict = state[\"keys\"]\n",
"    question = state_dict[\"question\"]\n",
"    documents = state_dict[\"documents\"]\n",
"\n",
"    # Prompt\n",
"    prompt = hub.pull(\"rlm/rag-prompt\")\n",
"\n",
"    # LLM\n",
"    llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0, streaming=True)\n",
"\n",
"    # Post-processing: join page contents so the prompt receives plain text\n",
"    # rather than the repr of a list of Document objects (which includes metadata).\n",
"    def format_docs(docs):\n",
"        return \"\\n\\n\".join(doc.page_content for doc in docs)\n",
"\n",
"    # Chain\n",
"    rag_chain = prompt | llm | StrOutputParser()\n",
"\n",
"    # Run\n",
"    generation = rag_chain.invoke(\n",
"        {\"context\": format_docs(documents), \"question\": question}\n",
"    )\n",
"    return {\n",
"        \"keys\": {\"documents\": documents, \"question\": question, \"generation\": generation}\n",
"    }\n",
"\n",
"\n",
"def grade_documents(state):\n",
"    \"\"\"\n",
"    Grade each retrieved document for relevance to the question.\n",
"\n",
"    Irrelevant documents are dropped. If any document is judged irrelevant, or no\n",
"    relevant document remains, the state is flagged so the graph rewrites the query\n",
"    and supplements retrieval with web search.\n",
"\n",
"    Args:\n",
"        state (dict): Graph state; reads state[\"keys\"][\"question\"] and state[\"keys\"][\"documents\"].\n",
"\n",
"    Returns:\n",
"        dict: {\"keys\": {...}} where \"documents\" holds only the relevant documents and\n",
"        \"run_web_search\" is \"Yes\" or \"No\".\n",
"    \"\"\"\n",
"\n",
"    print(\"---CHECK RELEVANCE---\")\n",
"    state_dict = state[\"keys\"]\n",
"    question = state_dict[\"question\"]\n",
"    documents = state_dict[\"documents\"]\n",
"\n",
"    # Data model (its class name is the tool name forced via tool_choice below)\n",
"    class grade(BaseModel):\n",
"        \"\"\"Binary score for relevance check.\"\"\"\n",
"\n",
"        binary_score: str = Field(description=\"Relevance score 'yes' or 'no'\")\n",
"\n",
"    # LLM\n",
"    model = ChatOpenAI(temperature=0, model=\"gpt-4-0125-preview\", streaming=True)\n",
"\n",
"    # Tool schema in OpenAI function-calling format\n",
"    grade_tool_oai = convert_to_openai_tool(grade)\n",
"\n",
"    # LLM with tool and enforce invocation. `grade_tool_oai` is already converted,\n",
"    # so it is passed as-is rather than converted a second time.\n",
"    llm_with_tool = model.bind(\n",
"        tools=[grade_tool_oai],\n",
"        tool_choice={\"type\": \"function\", \"function\": {\"name\": \"grade\"}},\n",
"    )\n",
"\n",
"    # Parser: turns the tool call back into a list of `grade` instances\n",
"    parser_tool = PydanticToolsParser(tools=[grade])\n",
"\n",
"    # Prompt\n",
"    prompt = PromptTemplate(\n",
"        template=\"\"\"You are a grader assessing relevance of a retrieved document to a user question. \\n \n",
" Here is the retrieved document: \\n\\n {context} \\n\\n\n",
" Here is the user question: {question} \\n\n",
" If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \\n\n",
" Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.\"\"\",\n",
"        input_variables=[\"context\", \"question\"],\n",
"    )\n",
"\n",
"    # Chain\n",
"    chain = prompt | llm_with_tool | parser_tool\n",
"\n",
"    # Score each document; by default do not opt for web search\n",
"    filtered_docs = []\n",
"    search = \"No\"\n",
"    for d in documents:\n",
"        score = chain.invoke({\"question\": question, \"context\": d.page_content})\n",
"        # Treat a missing tool call as \"not relevant\"; normalize case/whitespace so\n",
"        # answers like \"Yes\" are accepted as well as \"yes\".\n",
"        relevance = score[0].binary_score.strip().lower() if score else \"no\"\n",
"        if relevance == \"yes\":\n",
"            print(\"---GRADE: DOCUMENT RELEVANT---\")\n",
"            filtered_docs.append(d)\n",
"        else:\n",
"            print(\"---GRADE: DOCUMENT NOT RELEVANT---\")\n",
"            search = \"Yes\"  # Perform web search\n",
"\n",
"    # Nothing usable was retrieved (e.g. an empty result set): fall back to web search\n",
"    if not filtered_docs:\n",
"        search = \"Yes\"\n",
"\n",
"    return {\n",
"        \"keys\": {\n",
"            \"documents\": filtered_docs,\n",
"            \"question\": question,\n",
"            \"run_web_search\": search,\n",
"        }\n",
"    }\n",
"\n",
"\n",
"def transform_query(state):\n",
"    \"\"\"\n",
"    Rewrite the question into a form better suited to retrieval / web search.\n",
"\n",
"    Args:\n",
"        state (dict): Graph state; reads state[\"keys\"][\"question\"] and state[\"keys\"][\"documents\"].\n",
"\n",
"    Returns:\n",
"        dict: {\"keys\": {...}} with the documents passed through unchanged and the\n",
"        rewritten question stored under \"question\".\n",
"    \"\"\"\n",
"\n",
"    print(\"---TRANSFORM QUERY---\")\n",
"    keys = state[\"keys\"]\n",
"    original_question = keys[\"question\"]\n",
"    documents = keys[\"documents\"]\n",
"\n",
"    # Prompt asking the model to infer the intent behind the question and restate it\n",
"    rewrite_prompt = PromptTemplate(\n",
"        template=\"\"\"You are generating questions that is well optimized for retrieval. \\n \n",
" Look at the input and try to reason about the underlying sematic intent / meaning. \\n \n",
" Here is the initial question:\n",
" \\n ------- \\n\n",
" {question} \n",
" \\n ------- \\n\n",
" Formulate an improved question: \"\"\",\n",
"        input_variables=[\"question\"],\n",
"    )\n",
"\n",
"    # Query-rewriting model (this node rewrites; it does not grade)\n",
"    rewriter = ChatOpenAI(temperature=0, model=\"gpt-4-0125-preview\", streaming=True)\n",
"\n",
"    # prompt -> model -> plain string\n",
"    rewrite_chain = rewrite_prompt | rewriter | StrOutputParser()\n",
"    better_question = rewrite_chain.invoke({\"question\": original_question})\n",
"\n",
"    return {\"keys\": {\"documents\": documents, \"question\": better_question}}\n",
"\n",
"\n",
"def web_search(state):\n",
"    \"\"\"\n",
"    Supplement the documents with Tavily web search results for the (rewritten) question.\n",
"\n",
"    Args:\n",
"        state (dict): Graph state; reads state[\"keys\"][\"question\"] and state[\"keys\"][\"documents\"].\n",
"\n",
"    Returns:\n",
"        dict: {\"keys\": {...}} whose \"documents\" is a new list: the incoming documents\n",
"        followed by one Document holding the concatenated search result contents.\n",
"    \"\"\"\n",
"\n",
"    print(\"---WEB SEARCH---\")\n",
"    state_dict = state[\"keys\"]\n",
"    question = state_dict[\"question\"]\n",
"    documents = state_dict[\"documents\"]\n",
"\n",
"    tool = TavilySearchResults()\n",
"    results = tool.invoke({\"query\": question})\n",
"    # Assumes each result is a dict with a \"content\" field -- TODO confirm against Tavily tool output\n",
"    web_content = \"\\n\".join(d[\"content\"] for d in results)\n",
"    web_doc = Document(page_content=web_content)\n",
"\n",
"    # Build a new list instead of appending in place, so the list held by the\n",
"    # previous state (and any streamed node output referencing it) is not mutated.\n",
"    return {\"keys\": {\"documents\": [*documents, web_doc], \"question\": question}}\n",
"\n",
"\n",
"### Edges\n",
"\n",
"\n",
"def decide_to_generate(state):\n",
"    \"\"\"\n",
"    Route after grading: generate directly, or rewrite the query for web search.\n",
"\n",
"    Args:\n",
"        state (dict): Graph state; reads state[\"keys\"][\"run_web_search\"] as set by grade_documents.\n",
"\n",
"    Returns:\n",
"        str: Name of the next node, either \"transform_query\" or \"generate\".\n",
"    \"\"\"\n",
"\n",
"    print(\"---DECIDE TO GENERATE---\")\n",
"    search = state[\"keys\"][\"run_web_search\"]\n",
"\n",
"    if search == \"Yes\":\n",
"        # Grading flagged that retrieval needs supplementing (e.g. a document was\n",
"        # irrelevant), so rewrite the question before running web search.\n",
"        print(\"---DECISION: TRANSFORM QUERY and RUN WEB SEARCH---\")\n",
"        return \"transform_query\"\n",
"    # No web search requested: generate the answer from the graded documents\n",
"    print(\"---DECISION: GENERATE---\")\n",
"    return \"generate\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dedae17a-98c6-474d-90a7-9234b7c8cea0",
"metadata": {},
"outputs": [],
"source": [
"import pprint\n",
"\n",
"from langgraph.graph import END, StateGraph\n",
"\n",
"workflow = StateGraph(GraphState)\n",
"\n",
"# Define the nodes\n",
"workflow.add_node(\"retrieve\", retrieve) # retrieve\n",
"workflow.add_node(\"grade_documents\", grade_documents) # grade documents\n",
"workflow.add_node(\"generate\", generate) # generatae\n",
"workflow.add_node(\"transform_query\", transform_query) # transform_query\n",
"workflow.add_node(\"web_search\", web_search) # web search\n",
"\n",
"# Build graph\n",
"workflow.set_entry_point(\"retrieve\")\n",
"workflow.add_edge(\"retrieve\", \"grade_documents\")\n",
"workflow.add_conditional_edges(\n",
" \"grade_documents\",\n",
" decide_to_generate,\n",
" {\n",
" \"transform_query\": \"transform_query\",\n",
" \"generate\": \"generate\",\n",
" },\n",
")\n",
"workflow.add_edge(\"transform_query\", \"web_search\")\n",
"workflow.add_edge(\"web_search\", \"generate\")\n",
"workflow.add_edge(\"generate\", END)\n",
"\n",
"# Compile\n",
"app = workflow.compile()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5b7c2fe-1fc7-4b76-bf93-ba701a40aa6b",
"metadata": {},
"outputs": [],
"source": [
"# Run a question covered by the indexed blog posts\n",
"inputs = {\"keys\": {\"question\": \"Explain how the different types of agent memory work?\"}}\n",
"for output in app.stream(inputs):\n",
"    for key, value in output.items():\n",
"        # Use print for the text headers: pprint on a str shows its repr, so \"\\n---\\n\"\n",
"        # would be rendered literally instead of as a blank-line separator.\n",
"        print(f\"Output from node '{key}':\")\n",
"        print(\"---\")\n",
"        pprint.pprint(value[\"keys\"], indent=2, width=80, depth=None)\n",
"        print(\"\\n---\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2bee03de-a32c-4bbe-b37a-a13bb825e4cb",
"metadata": {},
"outputs": [],
"source": [
"# Correction for question not present in context (should trigger query rewrite + web search)\n",
"inputs = {\"keys\": {\"question\": \"What is the approach taken in the AlphaCodium paper?\"}}\n",
"for output in app.stream(inputs):\n",
"    for key, value in output.items():\n",
"        # Use print for the text headers: pprint on a str shows its repr, so \"\\n---\\n\"\n",
"        # would be rendered literally instead of as a blank-line separator.\n",
"        print(f\"Output from node '{key}':\")\n",
"        print(\"---\")\n",
"        pprint.pprint(value[\"keys\"], indent=2, width=80, depth=None)\n",
"        print(\"\\n---\\n\")"
]
},
{
"cell_type": "markdown",
"id": "a7e44593-1959-4abf-8405-5e23aa9398f5",
"metadata": {},
"source": [
"Traces -\n",
" \n",
"[Trace](https://smith.langchain.com/public/7e0b9569-abfe-4337-b34b-842b1f93df63/r) and [Trace](https://smith.langchain.com/public/b40c5813-7caf-4cc8-b279-ee66060b2040/r)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "69eddb3e-57f4-4eea-8e40-4822fc50c729",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}