How to count Amazon Bedrock Anthropic tokens with Langchain
Tracking Amazon Bedrock Claude token usage is simple using Langchain!
How to track tokens using Langchain
For OpenAI models, Langchain provides a native Callback handler for tracking token usage as documented here.
```python
from langchain.callbacks import get_openai_callback
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)
with get_openai_callback() as cb:
    llm("What is the square root of 4?")

total_tokens = cb.total_tokens
assert total_tokens > 0
```
However, other model families do not have this luxury.
Counting Anthropic Bedrock Tokens
Anthropic released a Bedrock client for their models that provides a method for counting tokens. I wrote an entire article about how to use it here.
Under the hood, Langchain uses the Anthropic client within the Bedrock LLM class to provide token counting.
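If you're curious what that wrapped call roughly looks like, here's a minimal sketch using the anthropic SDK directly. This assumes a 0.x release of the SDK, where the synchronous client exposes a local count_tokens helper (newer releases may no longer include it):

```python
from anthropic import Anthropic

# The key is only a placeholder: count_tokens runs the Claude tokenizer
# locally and never calls the API (anthropic 0.x SDK).
anthropic_client = Anthropic(api_key="placeholder")

print(anthropic_client.count_tokens("who are you?"))
```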
How to count tokens using the Langchain Bedrock LLM, step by step
Before starting, please review my post about using Bedrock with Langchain here. It describes how to use `boto3` with `langchain` to easily instantiate clients. For the sake of this tutorial, I’ll assume you already have your AWS session configured.
Installation
Firstly, install the required libraries:
```bash
pip3 install anthropic boto3 langchain
```
Counting Bedrock LLM tokens
To keep it simple, we’ll begin by importing the Bedrock LLM and counting the input and output tokens.
```python
from langchain.llms import Bedrock

import boto3

client = boto3.client("bedrock-runtime")

llm = Bedrock(client=client, model_id="anthropic.claude-instant-v1")
prompt = "who are you?"
print(llm.get_num_tokens(prompt))

result = llm(prompt)
print(result, llm.get_num_tokens(result))
```
Output:
```
4
My name is Claude. I'm an AI assistant created by Anthropic. 16
```
In this case, 4 input tokens and 16 output tokens were used.
Counting Bedrock LLM tokens in a chain
For my use case, I need to count the tokens used in chains. We’ll need to build a custom handler since Langchain does not provide native support for counting Bedrock tokens.
For more information on Callback handlers, please reference the Langchain documentation here.
In the code below, I extended the `BaseCallbackHandler` class to create `AnthropicTokenCounter`. In the constructor, a Bedrock LLM is required, since it provides the token counting method.
```python
from langchain.callbacks.base import BaseCallbackHandler


class AnthropicTokenCounter(BaseCallbackHandler):
    def __init__(self, llm):
        self.llm = llm
        self.input_tokens = 0
        self.output_tokens = 0

    def on_llm_start(self, serialized, prompts, **kwargs):
        for p in prompts:
            self.input_tokens += self.llm.get_num_tokens(p)

    def on_llm_end(self, response, **kwargs):
        results = response.flatten()
        for r in results:
            self.output_tokens += self.llm.get_num_tokens(r.generations[0][0].text)
```
- The `on_llm_start` function supplies a list of prompts. For each prompt, I add the tokens to an instance property named `input_tokens`.
- The `on_llm_end` function supplies an `LLMResult`.
  - The `LLMResult`, as I discovered in the source code, has a `flatten()` method that converts an `LLMResult` into a list of `LLMResult`s with one generation each.
  - I add the tokens to an instance property named `output_tokens`.
I used the callback handler when calling the custom chain:
```python
from langchain.llms import Bedrock
from langchain.prompts import ChatPromptTemplate
from langchain.callbacks.base import BaseCallbackHandler

import boto3


class AnthropicTokenCounter(BaseCallbackHandler):
    def __init__(self, llm):
        self.llm = llm
        self.input_tokens = 0
        self.output_tokens = 0

    def on_llm_start(self, serialized, prompts, **kwargs):
        for p in prompts:
            self.input_tokens += self.llm.get_num_tokens(p)

    def on_llm_end(self, response, **kwargs):
        results = response.flatten()
        for r in results:
            self.output_tokens += self.llm.get_num_tokens(r.generations[0][0].text)


client = boto3.client("bedrock-runtime")

llm = Bedrock(client=client, model_id="anthropic.claude-instant-v1")

prompt = ChatPromptTemplate.from_template("tell me a joke about {subject}")
chain = prompt | llm

token_counter = AnthropicTokenCounter(llm)
print(chain.invoke({"subject": "ai"}, config={"callbacks": [token_counter]}))
print(token_counter.input_tokens)
print(token_counter.output_tokens)
```
Output:
```
Here's one: Why can't AI assistants tell jokes? Because they lack a sense of humor!
8
20
```

In this case, 8 input tokens and 20 output tokens were used.