
Ultimate Guide: Using Trustcall for Data Extraction and Long-Term Memory Storage in LangGraph


LangGraph supports both short-term memory, which is scoped to a single thread, and long-term memory, which persists across threads, making it versatile for many use cases. However, using an LLM to extract data that conforms to a complex JSON schema can be unreliable and costly. This is where Trustcall comes in: a Python library that simplifies structured data extraction and makes it easy to create and update long-term memories in LangGraph. In this guide, you'll learn how to use Trustcall to streamline these workflows.

Benefits of using Trustcall for Data Extraction

Simple data extraction with the with_structured_output method works well when you want to extract information from an input according to a simple JSON schema. For complex schemas, however, it can be expensive and unreliable.

Figure: Trustcall data extraction diagram (image credit: Trustcall GitHub repository)

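Trustcall, by contrast, is built on top of LangGraph. When the model's output fails schema validation, Trustcall asks the LLM to generate JSON Patch operations that correct only the faulty parts instead of regenerating the whole object (see the diagram above).

Because only small patches are generated when a fix is needed, Trustcall still returns validated structured output but tends to be cheaper and more reliable than plain structured output for complex JSON schemas.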
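To make the JSON Patch idea concrete, here is a minimal, self-contained sketch that applies a subset of RFC 6902 JSON Patch operations ("add", "replace", "remove") to a draft extraction. It is not Trustcall's code: the apply_patch helper and the sample draft are hypothetical and only illustrate why correcting one field with a patch is cheaper than regenerating the entire object.

```python
import copy
import json


def apply_patch(doc: dict, operations: list) -> dict:
    """Apply a small subset of RFC 6902 JSON Patch ("add", "replace", "remove").

    Toy illustration only: paths are JSON Pointers such as "/interests/0",
    array indices are integers, and "-" appends to the end of an array.
    """
    result = copy.deepcopy(doc)  # never mutate the caller's document
    for op in operations:
        # Split the JSON Pointer into tokens, unescaping "~1" -> "/" and "~0" -> "~".
        tokens = [
            t.replace("~1", "/").replace("~0", "~")
            for t in op["path"].split("/")[1:]
        ]
        # Walk down to the container that holds the target location.
        parent = result
        for token in tokens[:-1]:
            parent = parent[int(token)] if isinstance(parent, list) else parent[token]
        last = tokens[-1]
        kind = op["op"]
        if isinstance(parent, list):
            if kind == "add":
                if last == "-":
                    parent.append(op["value"])
                else:
                    parent.insert(int(last), op["value"])
            elif kind == "replace":
                parent[int(last)] = op["value"]
            elif kind == "remove":
                del parent[int(last)]
            else:
                raise ValueError(f"unsupported op: {kind}")
        else:
            if kind in ("add", "replace"):
                parent[last] = op["value"]
            elif kind == "remove":
                del parent[last]
            else:
                raise ValueError(f"unsupported op: {kind}")
    return result


# Hypothetical first attempt: "interests" came back as a string, not a list.
draft = {"user_name": "Kathan", "interests": "driving"}

# Instead of regenerating the whole object, only a one-operation patch is needed.
patch = [{"op": "replace", "path": "/interests", "value": ["driving"]}]

fixed = apply_patch(draft, patch)
print(json.dumps(fixed))  # {"user_name": "Kathan", "interests": ["driving"]}
```

The patch touches a single path, so the cost of a fix stays roughly constant however large the rest of the document is.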

How to implement Trustcall for Data Extraction?

Using Trustcall is straightforward, and the code looks much like a simple structured-output extraction. Let's see an example:

from typing import List

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from trustcall import create_extractor

# The LLM (ChatOpenAI reads OPENAI_API_KEY from the environment)
llm = ChatOpenAI()

# Schema to extract
class UserProfile(BaseModel):
    """User profile schema with typed fields"""
    user_name: str = Field(description="The user's preferred name")
    interests: List[str] = Field(description="A list of the user's interests")

# Conversation to extract from
conversation = [
    HumanMessage(content="Hi, I'm Kathan."),
    AIMessage(content="Nice to meet you, Kathan."),
    HumanMessage(content="I really like driving around the city."),
]

# Create the extractor; tool_choice forces the model to call UserProfile
trustcall_extractor = create_extractor(
    llm,
    tools=[UserProfile],
    tool_choice="UserProfile",
)

# Instruction
system_msg = "Extract the user profile from the following conversation"

# Invoke the extractor with the instruction followed by the conversation
result = trustcall_extractor.invoke(
    {"messages": [SystemMessage(content=system_msg)] + conversation}
)
print(result)
print(result["responses"])

In create_extractor, the tools argument declares the schema(s) used to extract information from the input, and tool_choice forces the model to call the UserProfile tool. We then pass the system message followed by the conversation to invoke, which returns the following output. The validated Pydantic objects are available in result['responses'].

{'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_T9LalrMrd9FXIut99o8BKCBE', 'function': {'arguments': '{"user_name":"Kathan","interests":["driving","city exploration"]}', 'name': 'UserProfile'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 17, 'prompt_tokens': 138, 'total_tokens': 155, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-1cc74159-eebb-4a32-a387-2a7469adc962-0', tool_calls=[{'name': 'UserProfile', 'args': {'user_name': 'Kathan', 'interests': ['driving', 'city exploration']}, 'id': 'call_T9LalrMrd9FXIut99o8BKCBE', 'type': 'tool_call'}], usage_metadata={'input_tokens': 138, 'output_tokens': 17, 'total_tokens': 155, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})],
 'responses': [UserProfile(user_name='Kathan', interests=['driving', 'city exploration'])],
 'response_metadata': [{'id': 'call_T9LalrMrd9FXIut99o8BKCBE'}],
 'attempts': 1}
 
[UserProfile(user_name='Kathan', interests=['driving', 'city exploration'])]
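Before returning results, Trustcall checks the model's tool call against the schema (it relies on the Pydantic model for this); the attempts value of 1 in the output above indicates the first call was already valid. The standard-library sketch below approximates the kind of check involved, using the raw arguments string from that tool call. The validate_args helper and the USER_PROFILE_FIELDS mapping are hypothetical, hand-written stand-ins for the Pydantic schema, not Trustcall internals.

```python
import json

# Hypothetical, hand-written mirror of the UserProfile schema: field -> expected type.
USER_PROFILE_FIELDS = {"user_name": str, "interests": list}


def validate_args(raw_arguments: str) -> list:
    """Return validation errors for a tool call's JSON arguments (empty list = valid)."""
    try:
        args = json.loads(raw_arguments)
    except json.JSONDecodeError as exc:
        return [f"invalid JSON: {exc.msg}"]
    if not isinstance(args, dict):
        return ["arguments must be a JSON object"]
    errors = []
    for field, expected in USER_PROFILE_FIELDS.items():
        if field not in args:
            errors.append(f"{field}: field required")
        elif not isinstance(args[field], expected):
            errors.append(
                f"{field}: expected {expected.__name__}, got {type(args[field]).__name__}"
            )
    # interests is List[str] in the schema, so also check each element.
    if isinstance(args.get("interests"), list):
        for i, item in enumerate(args["interests"]):
            if not isinstance(item, str):
                errors.append(f"interests/{i}: expected str, got {type(item).__name__}")
    return errors


# The "arguments" string from the tool call in the output above, plus a broken variant.
good = '{"user_name":"Kathan","interests":["driving","city exploration"]}'
bad = '{"user_name":"Kathan","interests":"driving"}'

print(validate_args(good))  # []
print(validate_args(bad))   # ['interests: expected list, got str']
```

Error messages like these are what a repair step needs: they point at the exact path to fix, which is what makes a small patch sufficient.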

How to use Trustcall to Implement Long-term Memory in LangGraph

Now, let's use Trustcall to implement long-term memory in LangGraph by building a simple chatbot that persists both short-term and long-term memory.

First, we are going to import the necessary libraries:

import uuid

from langchain_core.messages import HumanMessage, SystemMessage, merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore
from pydantic import BaseModel, Field
from trustcall import create_extractor

Next, define the schema, the extractor, the prompts, and the memory backends used by the chatbot.

# The LLM
llm = ChatOpenAI()

# Schema for a single memory
class Memory(BaseModel):
    content: str = Field(description="The main content of the memory. For example: User expressed interest in learning about Computer Science.")

trustcall_extractor = create_extractor(
    llm,
    tools=[Memory],
    tool_choice="Memory",
    enable_inserts=True,  # allow inserting new Memory docs, not only updating existing ones
)
    

MODEL_SYSTEM_MESSAGE = """You are a helpful chatbot. You are designed to be a companion to a user.
You have a long term memory which keeps track of information you learn about the user over time.

Current Memory (may include updated memories from this conversation):
{memory}

"""

TRUSTCALL_INSTRUCTION = """Reflect on the following interaction.
Use the provided tools to retain any necessary memories about the user.
Use parallel tool calling to handle updates and insertions simultaneously.
"""

#define short-term and long-term memories

# Store for long-term (across-thread) memory
across_thread_memory = InMemoryStore()

# Checkpointer for short-term (within-thread) memory
within_thread_memory = MemorySaver()
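The difference between the two scopes comes down to which key the data is stored under: the checkpointer keeps message history per thread_id, while the store keeps items under a namespace such as ("memories", user_id). In the toy sketch below, plain dictionaries stand in for MemorySaver and InMemoryStore; the names thread_messages, store, remember, and recall are hypothetical, and this is not how LangGraph implements either class.

```python
from collections import defaultdict

# Short-term stand-in: message history keyed by thread_id (like a checkpointer).
thread_messages: dict = defaultdict(list)

# Long-term stand-in: items keyed by a namespace tuple, then by item key (like a store).
store: dict = defaultdict(dict)


def remember(user_id: str, key: str, content: str) -> None:
    """Save a memory under the user's namespace, independent of any thread."""
    store[("memories", user_id)][key] = {"content": content}


def recall(user_id: str) -> list:
    """Return every memory content saved for this user."""
    return [value["content"] for value in store[("memories", user_id)].values()]


# Thread "1": the user introduces themselves, and a memory is written.
thread_messages["1"].append("Hi, my name is Kathan")
remember("1", "a", "User's name is Kathan")

# Thread "2", same user: no shared history, but the same memory namespace.
print(thread_messages["2"])  # []
print(recall("1"))           # ["User's name is Kathan"]
```

This mirrors the config used later in the article: thread_id selects the conversation history, and user_id selects the memory namespace.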

The following call_model node retrieves the user's memories from the store and includes them in the system prompt before calling the model.

def call_model(state: MessagesState, config: RunnableConfig, store: BaseStore):

    # Get the user ID from the config
    user_id = config["configurable"]["user_id"]

    # Retrieve this user's memories from the store
    namespace = ("memories", user_id)
    memories = store.search(namespace)

    # Format the memories as a bulleted list (one memory per line)
    info = "\n".join(f"- {mem.value['content']}" for mem in memories)

    system_msg = MODEL_SYSTEM_MESSAGE.format(memory=info)

    # Respond using the memories plus the messages of the current thread
    response = llm.invoke([SystemMessage(content=system_msg)] + state["messages"])

    return {"messages": response}
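The prompt-building step in call_model is ordinary string formatting, so it can be checked without a model or a store. This sketch uses plain dictionaries shaped like the value field of the stored items (the sample contents match the stored memories shown later in the article); format_memories and TEMPLATE are hypothetical names. It also shows why the separator must be "\n" rather than "\\n", which would insert a literal backslash and "n" instead of a line break.

```python
def format_memories(values: list) -> str:
    """Render stored memory values as a bulleted block for the system prompt."""
    # "\n" is a real newline; "\\n" would produce the two characters backslash and n.
    return "\n".join(f"- {value['content']}" for value in values)


# Simplified version of the memory section of MODEL_SYSTEM_MESSAGE.
TEMPLATE = "Current Memory (may include updated memories from this conversation):\n{memory}\n"

# Illustrative values shaped like item.value in the store.
values = [
    {"content": "User's name is Kathan"},
    {"content": "User enjoys driving around Mumbai and loves watching movies"},
]

prompt = TEMPLATE.format(memory=format_memories(values))
print(prompt)
```

For a brand-new user, format_memories([]) returns an empty string, so the prompt simply shows an empty memory section.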

Next, create the write_memory node, which uses Trustcall to update existing memories and insert new ones.

def write_memory(state: MessagesState, config: RunnableConfig, store: BaseStore):
    user_id = config["configurable"]["user_id"]

    namespace = ("memories", user_id)

    # Load this user's existing memories
    existing_items = store.search(namespace)

    tool_name = "Memory"

    # Trustcall expects existing docs as (id, tool_name, value) tuples
    existing_memories = (
        [(existing_item.key, tool_name, existing_item.value) for existing_item in existing_items]
        if existing_items
        else None
    )
    updated_messages = list(
        merge_message_runs(messages=[SystemMessage(content=TRUSTCALL_INSTRUCTION)] + state["messages"])
    )

    result = trustcall_extractor.invoke({"messages": updated_messages, "existing": existing_memories})

    # Use json_doc_id (present when an existing memory was updated) as the store key;
    # new memories get a fresh UUID. Values are saved as plain dictionaries.
    for r, rmeta in zip(result["responses"], result["response_metadata"]):
        store.put(
            namespace,
            rmeta.get("json_doc_id", str(uuid.uuid4())),
            r.model_dump(mode="json"),
        )
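The bookkeeping around the Trustcall call in write_memory can be sketched without LangGraph. Below, a plain dictionary stands in for one store namespace, and the extractor's responses and response_metadata are hard-coded, hypothetical values. The sketch assumes, as the code above does, that updates carry a json_doc_id and inserts do not; to_existing and upsert are illustrative names, not library functions.

```python
import uuid
from typing import Optional

TOOL_NAME = "Memory"


def to_existing(stored: dict) -> Optional[list]:
    """Convert stored items (key -> value) into (key, tool_name, value) tuples, or None."""
    return [(key, TOOL_NAME, value) for key, value in stored.items()] or None


def upsert(stored: dict, responses: list, metadata: list) -> None:
    """Write results back: reuse json_doc_id for updates, mint a UUID for inserts."""
    for value, meta in zip(responses, metadata):
        key = meta.get("json_doc_id", str(uuid.uuid4()))
        stored[key] = value


# Illustrative contents of the ("memories", "1") namespace before this turn.
stored = {"doc-1": {"content": "User's name is Kathan"}}
print(to_existing(stored))  # [('doc-1', 'Memory', {'content': "User's name is Kathan"})]

# Hypothetical extractor output: one update to doc-1 and one brand-new memory.
responses = [
    {"content": "User's name is Kathan; enjoys driving around Mumbai"},
    {"content": "User loves watching movies"},
]
metadata = [{"id": "call_a", "json_doc_id": "doc-1"}, {"id": "call_b"}]

upsert(stored, responses, metadata)
print(len(stored))  # 2
```

Reusing the document ID is what turns an update into an overwrite of the same store entry rather than a duplicate memory.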
                  
Figure: flow of the created graph

Now we are ready to build the graph, which runs call_model and then write_memory on every turn.

builder = StateGraph(MessagesState)

#define nodes
builder.add_node("call_model", call_model)
builder.add_node("write_memory", write_memory)

#define edges
builder.add_edge(START, "call_model")
builder.add_edge("call_model", "write_memory")
builder.add_edge("write_memory", END)

# Compile the graph with the checkpointer (short-term) and the store (long-term)
graph = builder.compile(checkpointer=within_thread_memory, store=across_thread_memory)

The graph is ready. Let's test it out.

# We supply a thread ID for short-term (within-thread) memory
# We supply a user ID for long-term (across-thread) memory 
config = {"configurable": {"thread_id": "1", "user_id": "1"}}

# User input 
input_messages = [HumanMessage(content="Hi, my name is Kathan")]

# Run the graph
for chunk in graph.stream({"messages": input_messages}, config, stream_mode="values"):
    chunk["messages"][-1].pretty_print()
  
----OUTPUT----
================================ Human Message =================================

Hi, my name is Kathan
================================== Ai Message ==================================

Hello Kathan! It's nice to meet you. How can I assist you today?

Now send a second message in the same thread:

# User input
input_messages = [HumanMessage(content="I really like driving around Mumbai and love to watch movies.")]

# Run the graph
for chunk in graph.stream({"messages": input_messages}, config, stream_mode="values"):
    chunk["messages"][-1].pretty_print()

----OUTPUT----
================================ Human Message =================================

I really like driving around Mumbai and love to watch movies.
================================== Ai Message ==================================

That's great to know, Kathan! Driving around Mumbai must be quite an adventure with all the hustle and bustle of the city. Do you have a favorite genre of movies you enjoy watching?

Let's see what Trustcall stored in long-term memory.

# Namespace where the memories were saved
user_id = "1"
namespace = ("memories", user_id)
memories = across_thread_memory.search(namespace)
for m in memories:
    print(m.dict())
    
----OUTPUT----
{'value': {'content': "User's name is Kathan"}, 'key': '2b77208c-d8da-4197-9c67-049725105e57', 'namespace': ['memories', '1'], 'created_at': '2024-12-29T09:34:21.536090+00:00', 'updated_at': '2024-12-29T09:34:21.536091+00:00'}
{'value': {'content': 'User enjoys driving around Mumbai and loves watching movies'}, 'key': 'cda932b5-c22f-4d6b-877c-2c3d23fce486', 'namespace': ['memories', '1'], 'created_at': '2024-12-29T09:34:30.044336+00:00', 'updated_at': '2024-12-29T09:34:30.044337+00:00'}

Perfect! Each stored value matches the Memory schema: a dictionary with a single content field, saved under a unique key in the ("memories", "1") namespace.

Now let's start a new thread for the same user and ask a question that relies on the stored memories.

config = {"configurable": {"thread_id": "2", "user_id": "1"}}

# User input 
input_messages = [HumanMessage(content="Which theatres do you recommend for me?")]

# Run the graph
for chunk in graph.stream({"messages": input_messages}, config, stream_mode="values"):
    chunk["messages"][-1].pretty_print()
    
----OUTPUT----
================================ Human Message =================================

Which theatres do you recommend for me?
================================== Ai Message ==================================

Hi Kathan! I recommend PVR Cinemas, INOX, and Carnival Cinemas in Mumbai for a great movie-watching experience. These theaters have comfortable seating, good sound quality, and a wide selection of movies to choose from. Do you have a favorite genre of movies that you usually enjoy watching?

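Excellent! Even though thread "2" has no earlier messages, the chatbot greeted Kathan by name and recommended theatres in Mumbai. That information came from the long-term memories Trustcall wrote during thread "1".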

Conclusion

Trustcall is a valuable tool for extracting data from inputs based on a given schema, especially a complex one. By repairing validation errors with JSON Patch operations instead of regenerating whole outputs, it makes extraction more reliable and cost-effective for LLM-based tasks. As demonstrated in this article, it also makes it straightforward to create and update long-term memory in LangGraph.

