Anthropic Citations API Support (#8382)

* test(test_anthropic_completion.py): add test ensuring anthropic structured output response is consistent

Resolves https://github.com/BerriAI/litellm/issues/8291

* feat(anthropic.py): support citations api with new user document message format

Resolves https://github.com/BerriAI/litellm/issues/7970

* fix(anthropic/chat/transformation.py): return citations as a provider-specific-field

Resolves https://github.com/BerriAI/litellm/issues/7970

* feat(anthropic/chat/handler.py): add streaming citations support

Resolves https://github.com/BerriAI/litellm/issues/7970

* fix(handler.py): fix code qa error

* fix(handler.py): only set provider specific fields if non-empty dict

* docs(anthropic.md): add citations api to anthropic docs
Krish Dholakia, 2025-02-07 22:27:01 -08:00 (committed by GitHub)
parent 8a8e8f6cdd
commit 7759e86cf5
9 changed files with 308 additions and 21 deletions

View file

@@ -987,6 +987,106 @@ curl http://0.0.0.0:4000/v1/chat/completions \
</TabItem>
</Tabs>
## [BETA] Citations API
Pass `citations: {"enabled": true}` to Anthropic to get citations on your document responses.

Note: This interface is in BETA. If you have feedback on how citations should be returned, please [tell us here](https://github.com/BerriAI/litellm/issues/7970#issuecomment-2644437943).
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
resp = completion(
    model="claude-3-5-sonnet-20241022",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "document",
                    "source": {
                        "type": "text",
                        "media_type": "text/plain",
                        "data": "The grass is green. The sky is blue.",
                    },
                    "title": "My Document",
                    "context": "This is a trustworthy document.",
                    "citations": {"enabled": True},
                },
                {
                    "type": "text",
                    "text": "What color is the grass and sky?",
                },
            ],
        }
    ],
)

citations = resp.choices[0].message.provider_specific_fields["citations"]
assert citations is not None
```
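
The returned `citations` value is a list with one entry per Anthropic content block, where each entry is that block's list of citation objects. A minimal sketch for inspecting them — the field names (`cited_text`, `document_title`) follow Anthropic's citation schema and are an assumption here:

```python
# Sketch: walk the citations extracted above. Keys such as "cited_text"
# and "document_title" are assumed from Anthropic's citation schema.
for block_citations in citations:
    for citation in block_citations:
        print(citation.get("cited_text"), "->", citation.get("document_title"))
```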
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Set up config.yaml
```yaml
model_list:
  - model_name: anthropic-claude
    litellm_params:
      model: anthropic/claude-3-5-sonnet-20241022
      api_key: os.environ/ANTHROPIC_API_KEY
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
3. Test it!
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
  "model": "anthropic-claude",
  "messages": [
    {
      "role": "user",
      "content": [
        {
          "type": "document",
          "source": {
            "type": "text",
            "media_type": "text/plain",
            "data": "The grass is green. The sky is blue."
          },
          "title": "My Document",
          "context": "This is a trustworthy document.",
          "citations": {"enabled": true}
        },
        {
          "type": "text",
          "text": "What color is the grass and sky?"
        }
      ]
    }
  ]
}'
```
</TabItem>
</Tabs>
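
Streaming works the same way: citation deltas arrive on each chunk under `delta.provider_specific_fields["citation"]` (see the streaming test added in this commit). A minimal SDK sketch:

```python
from litellm import completion

resp = completion(
    model="claude-3-5-sonnet-20241022",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "document",
                    "source": {
                        "type": "text",
                        "media_type": "text/plain",
                        "data": "The grass is green. The sky is blue.",
                    },
                    "citations": {"enabled": True},
                },
                {"type": "text", "text": "What color is the grass and sky?"},
            ],
        }
    ],
    stream=True,
)

for chunk in resp:
    # Citation deltas surface as provider-specific fields on the delta.
    fields = chunk.choices[0].delta.provider_specific_fields
    if fields and "citation" in fields:
        print(fields["citation"])
```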
## Usage - passing 'user_id' to Anthropic

LiteLLM translates the OpenAI `user` param to Anthropic's `metadata[user_id]` param.

View file

@@ -1421,6 +1421,8 @@ def anthropic_messages_pt(  # noqa: PLR0915
                         )
                         user_content.append(_content_element)
+                    elif m.get("type", "") == "document":
+                        user_content.append(cast(AnthropicMessagesDocumentParam, m))
             elif isinstance(user_message_types_block["content"], str):
                 _anthropic_content_text_element: AnthropicMessagesTextParam = {
                     "type": "text",
View file

@@ -809,7 +809,10 @@ class CustomStreamWrapper:
                 if self.sent_first_chunk is False:
                     completion_obj["role"] = "assistant"
                     self.sent_first_chunk = True
+                if response_obj.get("provider_specific_fields") is not None:
+                    completion_obj["provider_specific_fields"] = response_obj[
+                        "provider_specific_fields"
+                    ]
                 model_response.choices[0].delta = Delta(**completion_obj)
                 _index: Optional[int] = completion_obj.get("index")
                 if _index is not None:

View file

@@ -4,7 +4,7 @@ Calling + translation logic for anthropic's `/v1/messages` endpoint
 import copy
 import json
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import httpx  # type: ignore
@@ -506,6 +506,29 @@
         return usage_block
 
+    def _content_block_delta_helper(self, chunk: dict):
+        text = ""
+        tool_use: Optional[ChatCompletionToolCallChunk] = None
+        provider_specific_fields = {}
+        content_block = ContentBlockDelta(**chunk)  # type: ignore
+        self.content_blocks.append(content_block)
+        if "text" in content_block["delta"]:
+            text = content_block["delta"]["text"]
+        elif "partial_json" in content_block["delta"]:
+            tool_use = {
+                "id": None,
+                "type": "function",
+                "function": {
+                    "name": None,
+                    "arguments": content_block["delta"]["partial_json"],
+                },
+                "index": self.tool_index,
+            }
+        elif "citation" in content_block["delta"]:
+            provider_specific_fields["citation"] = content_block["delta"]["citation"]
+        return text, tool_use, provider_specific_fields
+
     def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
         try:
             type_chunk = chunk.get("type", "") or ""
@@ -515,6 +538,7 @@
             is_finished = False
             finish_reason = ""
             usage: Optional[ChatCompletionUsageBlock] = None
+            provider_specific_fields: Dict[str, Any] = {}
 
             index = int(chunk.get("index", 0))
             if type_chunk == "content_block_delta":
@@ -522,20 +546,9 @@
                 Anthropic content chunk
                 chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
                 """
-                content_block = ContentBlockDelta(**chunk)  # type: ignore
-                self.content_blocks.append(content_block)
-                if "text" in content_block["delta"]:
-                    text = content_block["delta"]["text"]
-                elif "partial_json" in content_block["delta"]:
-                    tool_use = {
-                        "id": None,
-                        "type": "function",
-                        "function": {
-                            "name": None,
-                            "arguments": content_block["delta"]["partial_json"],
-                        },
-                        "index": self.tool_index,
-                    }
+                text, tool_use, provider_specific_fields = (
+                    self._content_block_delta_helper(chunk=chunk)
+                )
             elif type_chunk == "content_block_start":
                 """
                 event: content_block_start
@@ -628,6 +641,9 @@
                 finish_reason=finish_reason,
                 usage=usage,
                 index=index,
+                provider_specific_fields=(
+                    provider_specific_fields if provider_specific_fields else None
+                ),
             )
 
             return returned_chunk
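
For reference, a streamed citation event from Anthropic looks roughly like the chunk below; the `citations_delta` type and the `citation` payload fields follow Anthropic's documented streaming format and are an assumption here:

```python
# Hypothetical raw Anthropic streaming event carrying a citation.
# For a chunk like this, _content_block_delta_helper returns
# ("", None, {"citation": {...}}).
chunk = {
    "type": "content_block_delta",
    "index": 0,
    "delta": {
        "type": "citations_delta",
        "citation": {
            "type": "char_location",
            "cited_text": "The grass is green.",
            "document_index": 0,
            "document_title": "My Document",
            "start_char_index": 0,
            "end_char_index": 20,
        },
    },
}
```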

View file

@@ -628,6 +628,7 @@ class AnthropicConfig(BaseConfig):
             )
         else:
             text_content = ""
+            citations: List[Any] = []
             tool_calls: List[ChatCompletionToolCallChunk] = []
             for idx, content in enumerate(completion_response["content"]):
                 if content["type"] == "text":
@@ -645,10 +646,14 @@
                             index=idx,
                         )
                     )
+                ## CITATIONS
+                if content.get("citations", None) is not None:
+                    citations.append(content["citations"])
 
             _message = litellm.Message(
                 tool_calls=tool_calls,
                 content=text_content or None,
+                provider_specific_fields={"citations": citations},
             )
 
         ## HANDLE JSON MODE - anthropic returns single function call
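
In the non-streaming path, Anthropic attaches a `citations` list to each text content block, and the transformation above collects those per-block lists into `provider_specific_fields["citations"]`. A sketch of the relevant response shape (field names assumed from Anthropic's citation schema):

```python
# Hypothetical Anthropic response content; each text block may carry its
# own "citations" list, which the code above appends to `citations`.
completion_response = {
    "content": [
        {
            "type": "text",
            "text": "The grass is green",
            "citations": [
                {
                    "type": "char_location",
                    "cited_text": "The grass is green.",
                    "document_index": 0,
                    "document_title": "My Document",
                    "start_char_index": 0,
                    "end_char_index": 20,
                }
            ],
        }
    ]
}
# Result: message.provider_specific_fields["citations"] is a list holding
# each block's "citations" list.
```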

View file

@@ -92,10 +92,17 @@ class AnthropicMessagesImageParam(TypedDict, total=False):
     cache_control: Optional[Union[dict, ChatCompletionCachedContent]]
 
 
+class CitationsObject(TypedDict):
+    enabled: bool
+
+
 class AnthropicMessagesDocumentParam(TypedDict, total=False):
     type: Required[Literal["document"]]
     source: Required[AnthropicContentParamSource]
     cache_control: Optional[Union[dict, ChatCompletionCachedContent]]
+    title: str
+    context: str
+    citations: Optional[CitationsObject]
 
 
 class AnthropicMessagesToolResultContent(TypedDict):
@@ -173,6 +180,11 @@ class ContentTextBlockDelta(TypedDict):
     text: str
 
 
+class ContentCitationsBlockDelta(TypedDict):
+    type: Literal["citations"]
+    citation: dict
+
+
 class ContentJsonBlockDelta(TypedDict):
     """
     "delta": {"type": "input_json_delta","partial_json": "{\"location\": \"San Fra"}}
@@ -185,7 +197,9 @@ class ContentJsonBlockDelta(TypedDict):
 class ContentBlockDelta(TypedDict):
     type: Literal["content_block_delta"]
     index: int
-    delta: Union[ContentTextBlockDelta, ContentJsonBlockDelta]
+    delta: Union[
+        ContentTextBlockDelta, ContentJsonBlockDelta, ContentCitationsBlockDelta
+    ]
 
 
 class ContentBlockStop(TypedDict):
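
A quick usage sketch for the new `AnthropicMessagesDocumentParam` (the `litellm.types.llms.anthropic` import path is an assumption):

```python
# Import path assumed; these TypedDicts live in LiteLLM's Anthropic types.
from litellm.types.llms.anthropic import AnthropicMessagesDocumentParam

# Source shape mirrors the docs example above.
doc: AnthropicMessagesDocumentParam = {
    "type": "document",
    "source": {
        "type": "text",
        "media_type": "text/plain",
        "data": "The grass is green. The sky is blue.",
    },
    "title": "My Document",
    "context": "This is a trustworthy document.",
    "citations": {"enabled": True},
}
```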

View file

@@ -382,10 +382,29 @@ class ChatCompletionAudioObject(ChatCompletionContentPartInputAudioParam):
     pass
 
 
+class DocumentObject(TypedDict):
+    type: Literal["text"]
+    media_type: str
+    data: str
+
+
+class CitationsObject(TypedDict):
+    enabled: bool
+
+
+class ChatCompletionDocumentObject(TypedDict):
+    type: Literal["document"]
+    source: DocumentObject
+    title: str
+    context: str
+    citations: Optional[CitationsObject]
+
+
 OpenAIMessageContentListBlock = Union[
     ChatCompletionTextObject,
     ChatCompletionImageObject,
     ChatCompletionAudioObject,
+    ChatCompletionDocumentObject,
 ]
 
 OpenAIMessageContent = Union[
@@ -460,6 +479,7 @@ ValidUserMessageContentTypes = [
     "text",
     "image_url",
     "input_audio",
+    "document",
 ]  # used for validating user messages. Prevent users from accidentally sending anthropic messages.

View file

@@ -551,6 +551,7 @@ class Delta(OpenAIObject):
     ):
         super(Delta, self).__init__(**params)
         provider_specific_fields: Dict[str, Any] = {}
+
         if "reasoning_content" in params:
             provider_specific_fields["reasoning_content"] = params["reasoning_content"]
             setattr(self, "reasoning_content", params["reasoning_content"])
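
Together with the `CustomStreamWrapper` change above, this lets `provider_specific_fields` flow through the `Delta` constructor. A minimal sketch (the `litellm.types.utils` import path is an assumption):

```python
# Import path assumed; Delta is LiteLLM's OpenAI-style streaming delta.
from litellm.types.utils import Delta

delta = Delta(
    content="The grass is green",
    provider_specific_fields={"citation": {"cited_text": "The grass is green."}},
)
print(delta.provider_specific_fields)  # expect the citation dict back
```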

View file

@@ -1022,10 +1022,26 @@ def test_anthropic_json_mode_and_tool_call_response(
     [
         ("stop", ["stop"], True),  # basic string
         (["stop1", "stop2"], ["stop1", "stop2"], True),  # list of strings
-        (" ", None, True),  # whitespace string should be dropped when drop_params is True
-        (" ", [" "], False),  # whitespace string should be kept when drop_params is False
-        (["stop1", " ", "stop2"], ["stop1", "stop2"], True),  # list with whitespace that should be filtered
-        (["stop1", " ", "stop2"], ["stop1", " ", "stop2"], False),  # list with whitespace that should be kept
+        (
+            " ",
+            None,
+            True,
+        ),  # whitespace string should be dropped when drop_params is True
+        (
+            " ",
+            [" "],
+            False,
+        ),  # whitespace string should be kept when drop_params is False
+        (
+            ["stop1", " ", "stop2"],
+            ["stop1", "stop2"],
+            True,
+        ),  # list with whitespace that should be filtered
+        (
+            ["stop1", " ", "stop2"],
+            ["stop1", " ", "stop2"],
+            False,
+        ),  # list with whitespace that should be kept
         (None, None, True),  # None input
     ],
 )
@@ -1035,3 +1051,113 @@ def test_map_stop_sequences(stop_input, expected_output, drop_params):
     config = AnthropicConfig()
     result = config._map_stop_sequences(stop_input)
     assert result == expected_output
+
+
+@pytest.mark.asyncio
+async def test_anthropic_structured_output():
+    """
+    Test the _transform_response_for_structured_output
+
+    Relevant Issue: https://github.com/BerriAI/litellm/issues/8291
+    """
+    from litellm import acompletion
+
+    args = {
+        "model": "claude-3-5-sonnet-20240620",
+        "seed": 3015206306868917280,
+        "stop": None,
+        "messages": [
+            {
+                "role": "system",
+                "content": 'You are a hello world agent.\nAlways respond in the following valid JSON format: {\n "response": "response",\n}\n',
+            },
+            {"role": "user", "content": "Respond with hello world"},
+        ],
+        "temperature": 0,
+        "response_format": {"type": "json_object"},
+        "drop_params": True,
+    }
+
+    response = await acompletion(**args)
+    assert response is not None
+
+    print(response)
+
+
+def test_anthropic_citations_api():
+    """
+    Test the citations API
+    """
+    from litellm import completion
+
+    resp = completion(
+        model="claude-3-5-sonnet-20241022",
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "document",
+                        "source": {
+                            "type": "text",
+                            "media_type": "text/plain",
+                            "data": "The grass is green. The sky is blue.",
+                        },
+                        "title": "My Document",
+                        "context": "This is a trustworthy document.",
+                        "citations": {"enabled": True},
+                    },
+                    {
+                        "type": "text",
+                        "text": "What color is the grass and sky?",
+                    },
+                ],
+            }
+        ],
+    )
+
+    citations = resp.choices[0].message.provider_specific_fields["citations"]
+
+    assert citations is not None
+
+
+def test_anthropic_citations_api_streaming():
+    from litellm import completion
+
+    resp = completion(
+        model="claude-3-5-sonnet-20241022",
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "document",
+                        "source": {
+                            "type": "text",
+                            "media_type": "text/plain",
+                            "data": "The grass is green. The sky is blue.",
+                        },
+                        "title": "My Document",
+                        "context": "This is a trustworthy document.",
+                        "citations": {"enabled": True},
+                    },
+                    {
+                        "type": "text",
+                        "text": "What color is the grass and sky?",
+                    },
+                ],
+            }
+        ],
+        stream=True,
+    )
+
+    has_citations = False
+    for chunk in resp:
+        print(f"returned chunk: {chunk}")
+        if (
+            chunk.choices[0].delta.provider_specific_fields
+            and "citation" in chunk.choices[0].delta.provider_specific_fields
+        ):
+            has_citations = True
+
+    assert has_citations