feat(proxy_server.py): working /v1/messages endpoint

Works with Claude Engineer
Krrish Dholakia 2024-07-10 18:15:38 -07:00
parent 5d6e172d5c
commit 2f8dbbeb97
9 changed files with 272 additions and 152 deletions
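
For context: the endpoint speaks Anthropic's Messages API schema, so anthropic SDK clients (Claude Engineer uses one) can point straight at the proxy. A minimal sketch; base_url, api_key, and model below are placeholders, not values from this commit:

import anthropic  # official Anthropic SDK

# Placeholders: a locally running LiteLLM proxy and its master key
client = anthropic.Anthropic(
    base_url="http://0.0.0.0:4000",
    api_key="sk-1234",
)

message = client.messages.create(
    model="gpt-3.5-turbo",  # any model configured on the proxy
    max_tokens=256,
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(message.content)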

View file

@@ -8,11 +8,12 @@ from typing import Literal, Optional
import dotenv
import httpx
from pydantic import BaseModel
import litellm
from litellm import ChatCompletionRequest, verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.llms.anthropic import AnthropicMessagesRequest
from litellm.types.llms.anthropic import AnthropicMessagesRequest, AnthropicResponse
class AnthropicAdapter(CustomLogger):
@@ -31,12 +32,18 @@ class AnthropicAdapter(CustomLogger):
translated_body = litellm.AnthropicConfig().translate_anthropic_to_openai(
anthropic_message_request=request_body
)
return translated_body
-def translate_completion_output_params(self, response: litellm.ModelResponse):
-    return super().translate_completion_output_params(response)
+def translate_completion_output_params(
+    self, response: litellm.ModelResponse
+) -> Optional[AnthropicResponse]:
+    return litellm.AnthropicConfig().translate_openai_response_to_anthropic(
+        response=response
+    )
-def translate_completion_output_params_streaming(self):
+def translate_completion_output_params_streaming(self) -> Optional[BaseModel]:
    return super().translate_completion_output_params_streaming()
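
Taken together, the adapter now translates both directions. A round-trip sketch, using litellm's mock_response so no provider call is made (the request body is illustrative):

import litellm
from litellm.adapters.anthropic_adapter import anthropic_adapter

# Anthropic-format request body, as a client would POST to /v1/messages
data = {
    "model": "gpt-3.5-turbo",
    "max_tokens": 256,
    "messages": [{"role": "user", "content": "Hey, how's it going?"}],
}

# Anthropic -> OpenAI input translation
openai_kwargs = anthropic_adapter.translate_completion_input_params(kwargs=data)
assert openai_kwargs is not None

# OpenAI -> Anthropic output translation, on a mocked completion
response = litellm.completion(mock_response="Hi there!", **openai_kwargs)
anthropic_response = anthropic_adapter.translate_completion_output_params(
    response=response
)
print(anthropic_response)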

View file

@@ -5,6 +5,7 @@ import traceback
from typing import Literal, Optional, Union
import dotenv
from pydantic import BaseModel
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
@@ -67,13 +68,15 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
"""
pass
-def translate_completion_output_params(self, response: ModelResponse):
+def translate_completion_output_params(
+    self, response: ModelResponse
+) -> Optional[BaseModel]:
"""
Translates the output params from the OpenAI format to the custom format.
"""
pass
-def translate_completion_output_params_streaming(self):
+def translate_completion_output_params_streaming(self) -> Optional[BaseModel]:
"""
Translates the streaming chunk from the OpenAI format to the custom format.
"""

View file

@@ -21,10 +21,15 @@ from litellm.llms.custom_httpx.http_handler import (
)
from litellm.types.llms.anthropic import (
AnthopicMessagesAssistantMessageParam,
AnthropicFinishReason,
AnthropicMessagesRequest,
AnthropicMessagesTool,
AnthropicMessagesToolChoice,
AnthropicMessagesUserMessageParam,
AnthropicResponse,
AnthropicResponseContentBlockText,
AnthropicResponseContentBlockToolUse,
AnthropicResponseUsageBlock,
ContentBlockDelta,
ContentBlockStart,
MessageBlockDelta,
@@ -51,7 +56,7 @@ from litellm.types.llms.openai import (
ChatCompletionUsageBlock,
ChatCompletionUserMessage,
)
-from litellm.types.utils import GenericStreamingChunk
+from litellm.types.utils import Choices, GenericStreamingChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from .base import BaseLLM
@@ -187,6 +192,8 @@ class AnthropicConfig:
optional_params["top_p"] = value
return optional_params
### FOR [BETA] `/v1/messages` endpoint support
def translatable_anthropic_params(self) -> List:
"""
Which anthropic params we need to translate to the openai format.
@@ -300,9 +307,13 @@ class AnthropicConfig:
)
)
if assistant_message_str is not None or len(tool_calls) > 0:
    assistant_message = ChatCompletionAssistantMessage(
-        role="assistant", content=assistant_message_str, tool_calls=tool_calls
+        role="assistant",
+        content=assistant_message_str,
    )
+    if len(tool_calls) > 0:
+        assistant_message["tool_calls"] = tool_calls
new_messages.append(assistant_message)
return new_messages
@@ -391,6 +402,77 @@ class AnthropicConfig:
return new_kwargs
def _translate_openai_content_to_anthropic(
self, choices: List[Choices]
) -> List[
Union[AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse]
]:
new_content: List[
Union[
AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse
]
] = []
for choice in choices:
if (
choice.message.tool_calls is not None
and len(choice.message.tool_calls) > 0
):
for tool_call in choice.message.tool_calls:
new_content.append(
AnthropicResponseContentBlockToolUse(
type="tool_use",
id=tool_call.id,
name=tool_call.function.name,
input=tool_call.function.arguments,
)
)
elif choice.message.content is not None:
new_content.append(
AnthropicResponseContentBlockText(
type="text", text=choice.message.content
)
)
return new_content
def _translate_openai_finish_reason_to_anthropic(
self, openai_finish_reason: str
) -> AnthropicFinishReason:
if openai_finish_reason == "stop":
return "end_turn"
elif openai_finish_reason == "length":
return "max_tokens"
elif openai_finish_reason == "tool_calls":
return "tool_use"
return "end_turn"
def translate_openai_response_to_anthropic(
self, response: litellm.ModelResponse
) -> AnthropicResponse:
## translate content block
anthropic_content = self._translate_openai_content_to_anthropic(choices=response.choices) # type: ignore
## extract finish reason
anthropic_finish_reason = self._translate_openai_finish_reason_to_anthropic(
openai_finish_reason=response.choices[0].finish_reason # type: ignore
)
# extract usage
usage: litellm.Usage = getattr(response, "usage")
anthropic_usage = AnthropicResponseUsageBlock(
input_tokens=usage.prompt_tokens, output_tokens=usage.completion_tokens
)
translated_obj = AnthropicResponse(
id=response.id,
type="message",
role="assistant",
model=response.model or "unknown-model",
stop_sequence=None,
usage=anthropic_usage,
content=anthropic_content,
stop_reason=anthropic_finish_reason,
)
return translated_obj
# makes headers for API call
def validate_environment(api_key, user_headers):
@@ -454,121 +536,6 @@ class AnthropicChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
# def process_streaming_response(
# self,
# model: str,
# response: Union[requests.Response, httpx.Response],
# model_response: ModelResponse,
# stream: bool,
# logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
# optional_params: dict,
# api_key: str,
# data: Union[dict, str],
# messages: List,
# print_verbose,
# encoding,
# ) -> CustomStreamWrapper:
# """
# Return stream object for tool-calling + streaming
# """
# ## LOGGING
# logging_obj.post_call(
# input=messages,
# api_key=api_key,
# original_response=response.text,
# additional_args={"complete_input_dict": data},
# )
# print_verbose(f"raw model_response: {response.text}")
# ## RESPONSE OBJECT
# try:
# completion_response = response.json()
# except:
# raise AnthropicError(
# message=response.text, status_code=response.status_code
# )
# text_content = ""
# tool_calls = []
# for content in completion_response["content"]:
# if content["type"] == "text":
# text_content += content["text"]
# ## TOOL CALLING
# elif content["type"] == "tool_use":
# tool_calls.append(
# {
# "id": content["id"],
# "type": "function",
# "function": {
# "name": content["name"],
# "arguments": json.dumps(content["input"]),
# },
# }
# )
# if "error" in completion_response:
# raise AnthropicError(
# message=str(completion_response["error"]),
# status_code=response.status_code,
# )
# _message = litellm.Message(
# tool_calls=tool_calls,
# content=text_content or None,
# )
# model_response.choices[0].message = _message # type: ignore
# model_response._hidden_params["original_response"] = completion_response[
# "content"
# ] # allow user to access raw anthropic tool calling response
# model_response.choices[0].finish_reason = map_finish_reason(
# completion_response["stop_reason"]
# )
# print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# # return an iterator
# streaming_model_response = ModelResponse(stream=True)
# streaming_model_response.choices[0].finish_reason = model_response.choices[ # type: ignore
# 0
# ].finish_reason
# # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
# streaming_choice = litellm.utils.StreamingChoices()
# streaming_choice.index = model_response.choices[0].index
# _tool_calls = []
# print_verbose(
# f"type of model_response.choices[0]: {type(model_response.choices[0])}"
# )
# print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
# if isinstance(model_response.choices[0], litellm.Choices):
# if getattr(
# model_response.choices[0].message, "tool_calls", None
# ) is not None and isinstance(
# model_response.choices[0].message.tool_calls, list
# ):
# for tool_call in model_response.choices[0].message.tool_calls:
# _tool_call = {**tool_call.dict(), "index": 0}
# _tool_calls.append(_tool_call)
# delta_obj = litellm.utils.Delta(
# content=getattr(model_response.choices[0].message, "content", None),
# role=model_response.choices[0].message.role,
# tool_calls=_tool_calls,
# )
# streaming_choice.delta = delta_obj
# streaming_model_response.choices = [streaming_choice]
# completion_stream = ModelResponseIterator(
# model_response=streaming_model_response
# )
# print_verbose(
# "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
# )
# return CustomStreamWrapper(
# completion_stream=completion_stream,
# model=model,
# custom_llm_provider="cached_response",
# logging_obj=logging_obj,
# )
# else:
# raise AnthropicError(
# status_code=422,
# message="Unprocessable response object - {}".format(response.text),
# )
def process_response(
self,
model: str,
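
The new AnthropicConfig helpers are also usable standalone; a short sketch, again leaning on mock_response so no provider key is needed:

import litellm

# Build an OpenAI-format ModelResponse without a network call
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    mock_response="Doing well, thanks!",
)

anthropic_response = litellm.AnthropicConfig().translate_openai_response_to_anthropic(
    response=response
)
assert anthropic_response.type == "message"
assert anthropic_response.stop_reason == "end_turn"  # mapped from OpenAI's "stop"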

View file

@@ -38,6 +38,7 @@ import dotenv
import httpx
import openai
import tiktoken
from pydantic import BaseModel
from typing_extensions import overload
import litellm
@@ -3947,7 +3948,7 @@ def text_completion(
###### Adapter Completion ################
-def adapter_completion(*, adapter_id: str, **kwargs) -> Any:
+def adapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
translation_obj: Optional[CustomLogger] = None
for item in litellm.adapters:
if item["id"] == adapter_id:

View file

@@ -71,6 +71,11 @@ azure_api_key_header = APIKeyHeader(
auto_error=False,
description="Some older versions of the openai Python package will send an API-Key header with just the API key ",
)
anthropic_api_key_header = APIKeyHeader(
name="x-api-key",
auto_error=False,
description="If anthropic client used.",
)
def _get_bearer_token(
@@ -87,6 +92,9 @@ async def user_api_key_auth(
request: Request,
api_key: str = fastapi.Security(api_key_header),
azure_api_key_header: str = fastapi.Security(azure_api_key_header),
anthropic_api_key_header: Optional[str] = fastapi.Security(
anthropic_api_key_header
),
) -> UserAPIKeyAuth:
from litellm.proxy.proxy_server import (
@@ -114,6 +122,9 @@
elif isinstance(azure_api_key_header, str):
api_key = azure_api_key_header
elif isinstance(anthropic_api_key_header, str):
api_key = anthropic_api_key_header
parent_otel_span: Optional[Span] = None
if open_telemetry_logger is not None:
parent_otel_span = open_telemetry_logger.tracer.start_span(
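
The practical effect: Anthropic-style clients that send x-api-key instead of Authorization: Bearer now resolve to the same virtual-key auth path. A raw-request sketch with httpx (URL and key are placeholders):

import httpx

resp = httpx.post(
    "http://0.0.0.0:4000/v1/messages",  # assumed local proxy address
    headers={"x-api-key": "sk-1234", "content-type": "application/json"},
    json={
        "model": "gpt-3.5-turbo",
        "max_tokens": 256,
        "messages": [{"role": "user", "content": "Hey, how's it going?"}],
    },
)
print(resp.json())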

View file

@@ -210,6 +210,12 @@ from litellm.router import (
from litellm.router import ModelInfo as RouterModelInfo
from litellm.router import updateDeployment
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
from litellm.types.llms.anthropic import (
AnthropicMessagesRequest,
AnthropicResponse,
AnthropicResponseContentBlockText,
AnthropicResponseUsageBlock,
)
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.router import RouterGeneralSettings
@@ -5030,6 +5036,34 @@
)
#### ANTHROPIC ENDPOINTS ####
@router.post(
"/v1/messages",
tags=["[beta] Anthropic `/v1/messages`"],
dependencies=[Depends(user_api_key_auth)],
response_model=AnthropicResponse,
)
async def anthropic_response(data: AnthropicMessagesRequest):
from litellm import adapter_completion
from litellm.adapters.anthropic_adapter import anthropic_adapter
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
response: Optional[BaseModel] = adapter_completion(adapter_id="anthropic", **data)
if response is None:
raise Exception("Response is None.")
elif not isinstance(response, AnthropicResponse):
raise Exception(
"Invalid model response={}. Not in 'AnthropicResponse' format".format(
response
)
)
return response
#### DEV UTILS ####
# @router.get(
@@ -7546,7 +7580,7 @@ async def login(request: Request):
litellm_dashboard_ui += "/ui/"
import jwt
-jwt_token = jwt.encode(
+jwt_token = jwt.encode(  # type: ignore
{
"user_id": user_id,
"key": key,
@@ -7610,7 +7644,7 @@ async def login(request: Request):
litellm_dashboard_ui += "/ui/"
import jwt
-jwt_token = jwt.encode(
+jwt_token = jwt.encode(  # type: ignore
{
"user_id": user_id,
"key": key,
@@ -7745,7 +7779,7 @@ async def onboarding(invite_link: str):
litellm_dashboard_ui += "/ui/onboarding"
import jwt
-jwt_token = jwt.encode(
+jwt_token = jwt.encode(  # type: ignore
{
"user_id": user_obj.user_id,
"key": key,
@@ -8162,7 +8196,7 @@ async def auth_callback(request: Request):
import jwt
-jwt_token = jwt.encode(
+jwt_token = jwt.encode(  # type: ignore
{
"user_id": user_id,
"key": key,

View file

@@ -20,16 +20,51 @@ from unittest.mock import MagicMock, patch
import pytest
import litellm
-from litellm import adapter_completion
+from litellm import AnthropicConfig, adapter_completion
from litellm.adapters.anthropic_adapter import anthropic_adapter
from litellm.types.llms.anthropic import AnthropicResponse
-def test_anthropic_completion():
+def test_anthropic_completion_messages_translation():
messages = [{"role": "user", "content": "Hey, how's it going?"}]
translated_messages = AnthropicConfig().translate_anthropic_messages_to_openai(messages=messages) # type: ignore
assert translated_messages == [{"role": "user", "content": "Hey, how's it going?"}]
def test_anthropic_completion_input_translation():
data = {
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Hey, how's it going?"}],
}
translated_input = anthropic_adapter.translate_completion_input_params(kwargs=data)
assert translated_input is not None
assert translated_input["model"] == "gpt-3.5-turbo"
assert translated_input["messages"] == [
{"role": "user", "content": "Hey, how's it going?"}
]
def test_anthropic_completion_e2e():
litellm.set_verbose = True
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
messages = [{"role": "user", "content": "Hey, how's it going?"}]
response = adapter_completion(
-    model="gpt-3.5-turbo", messages=messages, adapter_id="anthropic"
+    model="gpt-3.5-turbo",
+    messages=messages,
+    adapter_id="anthropic",
+    mock_response="This is a fake call",
)
-print(response)
+print("Response: {}".format(response))
+assert response is not None
+assert isinstance(response, AnthropicResponse)

View file

@@ -223,3 +223,51 @@ class MessageStartBlock(TypedDict):
type: Literal["message_start"]
message: MessageChunk
class AnthropicResponseContentBlockText(BaseModel):
type: Literal["text"]
text: str
class AnthropicResponseContentBlockToolUse(BaseModel):
type: Literal["tool_use"]
id: str
name: str
input: str
class AnthropicResponseUsageBlock(BaseModel):
input_tokens: int
output_tokens: int
AnthropicFinishReason = Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
class AnthropicResponse(BaseModel):
id: str
"""Unique object identifier."""
type: Literal["message"]
"""For Messages, this is always "message"."""
role: Literal["assistant"]
"""Conversational role of the generated message. This will always be "assistant"."""
content: List[
Union[AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse]
]
"""Content generated by the model."""
model: str
"""The model that handled the request."""
stop_reason: Optional[AnthropicFinishReason]
"""The reason that we stopped."""
stop_sequence: Optional[str]
"""Which custom stop sequence was generated, if any."""
usage: AnthropicResponseUsageBlock
"""Billing and rate-limit usage."""

View file

@@ -166,7 +166,7 @@ class FunctionCall(OpenAIObject):
class Function(OpenAIObject):
arguments: str
-    name: Optional[str] = None
+    name: str
def __init__(
self,
@@ -280,29 +280,43 @@ class ChatCompletionMessageToolCall(OpenAIObject):
setattr(self, key, value)
"""
Reference:
ChatCompletionMessage(content='This is a test', role='assistant', function_call=None, tool_calls=None)
"""
class Message(OpenAIObject):
content: Optional[str]
role: Literal["assistant"]
tool_calls: Optional[List[ChatCompletionMessageToolCall]]
function_call: Optional[FunctionCall]
def __init__(
    self,
-    content: Optional[str] = "default",
-    role="assistant",
-    logprobs=None,
+    content: Optional[str] = None,
+    role: Literal["assistant"] = "assistant",
    function_call=None,
    tool_calls=None,
    **params,
):
-    super(Message, self).__init__(**params)
-    self.content = content
-    self.role = role
-    if function_call is not None:
-        self.function_call = FunctionCall(**function_call)
-    if tool_calls is not None:
-        self.tool_calls = []
-        for tool_call in tool_calls:
-            self.tool_calls.append(ChatCompletionMessageToolCall(**tool_call))
-    if logprobs is not None:
-        self._logprobs = ChoiceLogprobs(**logprobs)
+    init_values = {
+        "content": content,
+        "role": role,
+        "function_call": (
+            FunctionCall(**function_call) if function_call is not None else None
+        ),
+        "tool_calls": (
+            [ChatCompletionMessageToolCall(**tool_call) for tool_call in tool_calls]
+            if tool_calls is not None
+            else None
+        ),
+    }
+    super(Message, self).__init__(
+        **init_values,
+        **params,
+    )
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
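
With this refactor, field population runs through pydantic validation in a single super().__init__ call instead of post-init attribute assignment, so malformed tool_calls fail fast. A usage sketch (ids and function payloads are illustrative):

from litellm.types.utils import Message

m = Message(
    content="This is a test",
    tool_calls=[
        {
            "id": "call_123",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "SF"}'},
        }
    ],
)
assert m.role == "assistant"
assert m.tool_calls is not None and m.tool_calls[0].function.name == "get_weather"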