Compare commits


7 commits

Author | SHA1 | Message | Date
Ishaan Jaff | 97ecedf997 | use 1 helper to return stream_response on passthrough | 2024-11-20 15:49:33 -08:00
Ishaan Jaff | acf350a2fb | fix check for streaming response | 2024-11-20 15:15:21 -08:00
Ishaan Jaff | 9f916636e1 | fix get_response_body | 2024-11-20 13:18:51 -08:00
Ishaan Jaff | bb7fe53bc5 | fix anthropic_passthrough_handler | 2024-11-20 13:18:35 -08:00
Ishaan Jaff | 83a722a34b | add AnthropicConfig | 2024-11-20 12:09:32 -08:00
Ishaan Jaff | b3b1ff6882 | fix AnthropicConfig test | 2024-11-20 12:07:28 -08:00
Ishaan Jaff | f121b8f630 | move _process_response in transformation | 2024-11-20 12:02:15 -08:00
6 changed files with 330 additions and 208 deletions

View file

@@ -45,9 +45,7 @@ from litellm.types.llms.openai import (
ChatCompletionUsageBlock,
)
from litellm.types.utils import GenericStreamingChunk
from litellm.types.utils import Message as LitellmMessage
from litellm.types.utils import PromptTokensDetailsWrapper
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from litellm.utils import CustomStreamWrapper, ModelResponse
from ...base import BaseLLM
from ..common_utils import AnthropicError, process_anthropic_headers
@@ -201,163 +199,6 @@ class AnthropicChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def _process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: ModelResponse,
stream: bool,
logging_obj: litellm.litellm_core_utils.litellm_logging.Logging, # type: ignore
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
json_mode: bool,
) -> ModelResponse:
_hidden_params: Dict = {}
_hidden_params["additional_headers"] = process_anthropic_headers(
dict(response.headers)
)
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except Exception as e:
response_headers = getattr(response, "headers", None)
raise AnthropicError(
message="Unable to get json response - {}, Original Response: {}".format(
str(e), response.text
),
status_code=response.status_code,
headers=response_headers,
)
if "error" in completion_response:
response_headers = getattr(response, "headers", None)
raise AnthropicError(
message=str(completion_response["error"]),
status_code=response.status_code,
headers=response_headers,
)
else:
text_content = ""
tool_calls: List[ChatCompletionToolCallChunk] = []
for idx, content in enumerate(completion_response["content"]):
if content["type"] == "text":
text_content += content["text"]
## TOOL CALLING
elif content["type"] == "tool_use":
tool_calls.append(
ChatCompletionToolCallChunk(
id=content["id"],
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=content["name"],
arguments=json.dumps(content["input"]),
),
index=idx,
)
)
_message = litellm.Message(
tool_calls=tool_calls,
content=text_content or None,
)
## HANDLE JSON MODE - anthropic returns single function call
if json_mode and len(tool_calls) == 1:
json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
"arguments"
)
if json_mode_content_str is not None:
_converted_message = self._convert_tool_response_to_message(
tool_calls=tool_calls,
)
if _converted_message is not None:
completion_response["stop_reason"] = "stop"
_message = _converted_message
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = completion_response[
"content"
] # allow user to access raw anthropic tool calling response
model_response.choices[0].finish_reason = map_finish_reason(
completion_response["stop_reason"]
)
## CALCULATING USAGE
prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"]
_usage = completion_response["usage"]
cache_creation_input_tokens: int = 0
cache_read_input_tokens: int = 0
model_response.created = int(time.time())
model_response.model = model
if "cache_creation_input_tokens" in _usage:
cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
prompt_tokens += cache_creation_input_tokens
if "cache_read_input_tokens" in _usage:
cache_read_input_tokens = _usage["cache_read_input_tokens"]
prompt_tokens += cache_read_input_tokens
prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=cache_read_input_tokens
)
total_tokens = prompt_tokens + completion_tokens
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens_details=prompt_tokens_details,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
)
setattr(model_response, "usage", usage) # type: ignore
model_response._hidden_params = _hidden_params
return model_response
@staticmethod
def _convert_tool_response_to_message(
tool_calls: List[ChatCompletionToolCallChunk],
) -> Optional[LitellmMessage]:
"""
In JSON mode, Anthropic API returns JSON schema as a tool call, we need to convert it to a message to follow the OpenAI format
"""
## HANDLE JSON MODE - anthropic returns single function call
json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
"arguments"
)
try:
if json_mode_content_str is not None:
args = json.loads(json_mode_content_str)
if (
isinstance(args, dict)
and (values := args.get("values")) is not None
):
_message = litellm.Message(content=json.dumps(values))
return _message
else:
# a lot of the times the `values` key is not present in the tool response
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
_message = litellm.Message(content=json.dumps(args))
return _message
except json.JSONDecodeError:
# json decode error does occur, return the original tool response str
return litellm.Message(content=json_mode_content_str)
return None
async def acompletion_stream_function(
self,
model: str,
@@ -454,7 +295,7 @@ class AnthropicChatCompletion(BaseLLM):
headers=error_headers,
)
return self._process_response(
return AnthropicConfig._process_response(
model=model,
response=response,
model_response=model_response,
@@ -630,7 +471,7 @@ class AnthropicChatCompletion(BaseLLM):
headers=error_headers,
)
return self._process_response(
return AnthropicConfig._process_response(
model=model,
response=response,
model_response=model_response,
@@ -855,7 +696,7 @@ class ModelResponseIterator:
tool_use: The ChatCompletionToolCallChunk to use in the chunk response
"""
if self.json_mode is True and tool_use is not None:
message = AnthropicChatCompletion._convert_tool_response_to_message(
message = AnthropicConfig._convert_tool_response_to_message(
tool_calls=[tool_use]
)
if message is not None:

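The net effect of this file's changes: _process_response and _convert_tool_response_to_message are removed from AnthropicChatCompletion, and the handler (plus ModelResponseIterator) now calls the AnthropicConfig static methods instead. Below is a minimal, standalone sketch of the JSON-mode conversion those call sites rely on; the function name is illustrative and uses only the stdlib, not litellm's API:

import json
from typing import Optional

def tool_arguments_to_content(arguments: Optional[str]) -> Optional[str]:
    # Mirrors the moved _convert_tool_response_to_message logic: parse the tool
    # call's "arguments" JSON, unwrap the optional "values" wrapper, and fall
    # back to the raw string when the JSON is invalid.
    if arguments is None:
        return None
    try:
        args = json.loads(arguments)
    except json.JSONDecodeError:
        return arguments  # invalid JSON is returned untouched
    if isinstance(args, dict) and args.get("values") is not None:
        return json.dumps(args["values"])
    # the "values" key is often absent (https://github.com/BerriAI/litellm/issues/6741)
    return json.dumps(args)

print(tool_arguments_to_content('{"values": {"name": "John", "age": 30}}'))  # {"name": "John", "age": 30}
print(tool_arguments_to_content("invalid json"))  # invalid json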
View file

@@ -1,7 +1,14 @@
import json
import time
import types
from typing import List, Literal, Optional, Tuple, Union
from re import A
from typing import Dict, List, Literal, Optional, Tuple, Union
import httpx
import requests
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
from litellm.types.llms.anthropic import (
AllAnthropicToolsValues,
@@ -18,12 +25,23 @@ from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionCachedContent,
ChatCompletionSystemMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
ChatCompletionUsageBlock,
)
from litellm.types.utils import Message as LitellmMessage
from litellm.types.utils import PromptTokensDetailsWrapper
from litellm.utils import (
CustomStreamWrapper,
ModelResponse,
Usage,
add_dummy_tool,
has_tool_call_blocks,
)
from litellm.utils import add_dummy_tool, has_tool_call_blocks
from ..common_utils import AnthropicError
from ..common_utils import AnthropicError, process_anthropic_headers
class AnthropicConfig:
@@ -534,3 +552,162 @@ class AnthropicConfig:
if not is_vertex_request:
data["model"] = model
return data
@staticmethod
def _process_response(
model: str,
response: Union[requests.Response, httpx.Response],
model_response: ModelResponse,
stream: bool,
logging_obj: litellm.litellm_core_utils.litellm_logging.Logging, # type: ignore
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
json_mode: bool,
) -> ModelResponse:
_hidden_params: Dict = {}
_hidden_params["additional_headers"] = process_anthropic_headers(
dict(response.headers)
)
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except Exception as e:
response_headers = getattr(response, "headers", None)
raise AnthropicError(
message="Unable to get json response - {}, Original Response: {}".format(
str(e), response.text
),
status_code=response.status_code,
headers=response_headers,
)
if "error" in completion_response:
response_headers = getattr(response, "headers", None)
raise AnthropicError(
message=str(completion_response["error"]),
status_code=response.status_code,
headers=response_headers,
)
else:
text_content = ""
tool_calls: List[ChatCompletionToolCallChunk] = []
for idx, content in enumerate(completion_response["content"]):
if content["type"] == "text":
text_content += content["text"]
## TOOL CALLING
elif content["type"] == "tool_use":
tool_calls.append(
ChatCompletionToolCallChunk(
id=content["id"],
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=content["name"],
arguments=json.dumps(content["input"]),
),
index=idx,
)
)
_message = litellm.Message(
tool_calls=tool_calls,
content=text_content or None,
)
## HANDLE JSON MODE - anthropic returns single function call
if json_mode and len(tool_calls) == 1:
json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
"arguments"
)
if json_mode_content_str is not None:
_converted_message = (
AnthropicConfig._convert_tool_response_to_message(
tool_calls=tool_calls,
)
)
if _converted_message is not None:
completion_response["stop_reason"] = "stop"
_message = _converted_message
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = completion_response[
"content"
] # allow user to access raw anthropic tool calling response
model_response.choices[0].finish_reason = map_finish_reason(
completion_response["stop_reason"]
)
## CALCULATING USAGE
prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"]
_usage = completion_response["usage"]
cache_creation_input_tokens: int = 0
cache_read_input_tokens: int = 0
model_response.created = int(time.time())
model_response.model = model
if "cache_creation_input_tokens" in _usage:
cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
prompt_tokens += cache_creation_input_tokens
if "cache_read_input_tokens" in _usage:
cache_read_input_tokens = _usage["cache_read_input_tokens"]
prompt_tokens += cache_read_input_tokens
prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=cache_read_input_tokens
)
total_tokens = prompt_tokens + completion_tokens
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_tokens_details=prompt_tokens_details,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
)
setattr(model_response, "usage", usage) # type: ignore
model_response._hidden_params = _hidden_params
return model_response
@staticmethod
def _convert_tool_response_to_message(
tool_calls: List[ChatCompletionToolCallChunk],
) -> Optional[LitellmMessage]:
"""
In JSON mode, Anthropic API returns JSON schema as a tool call, we need to convert it to a message to follow the OpenAI format
"""
## HANDLE JSON MODE - anthropic returns single function call
json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
"arguments"
)
try:
if json_mode_content_str is not None:
args = json.loads(json_mode_content_str)
if (
isinstance(args, dict)
and (values := args.get("values")) is not None
):
_message = litellm.Message(content=json.dumps(values))
return _message
else:
# a lot of the times the `values` key is not present in the tool response
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
_message = litellm.Message(content=json.dumps(args))
return _message
except json.JSONDecodeError:
# json decode error does occur, return the original tool response str
return litellm.Message(content=json_mode_content_str)
return None

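The moved _process_response also carries over the usage accounting: Anthropic's cache_creation_input_tokens and cache_read_input_tokens are added onto input_tokens before totals are computed, and cache_read_input_tokens doubles as the cached-tokens detail. A standalone sketch of that arithmetic using plain dicts (litellm's Usage and PromptTokensDetailsWrapper types are deliberately omitted):

def build_usage(anthropic_usage: dict) -> dict:
    # Same arithmetic as the usage block in AnthropicConfig._process_response,
    # expressed with plain dicts.
    prompt_tokens = anthropic_usage["input_tokens"]
    completion_tokens = anthropic_usage["output_tokens"]
    cache_creation = anthropic_usage.get("cache_creation_input_tokens", 0)
    cache_read = anthropic_usage.get("cache_read_input_tokens", 0)
    prompt_tokens += cache_creation + cache_read
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
        "cached_tokens": cache_read,
        "cache_creation_input_tokens": cache_creation,
        "cache_read_input_tokens": cache_read,
    }

print(build_usage({"input_tokens": 10, "output_tokens": 5, "cache_read_input_tokens": 3}))
# prompt_tokens=13, completion_tokens=5, total_tokens=18, cached_tokens=3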
View file

@@ -45,11 +45,11 @@ router = APIRouter()
pass_through_endpoint_logging = PassThroughEndpointLogging()
def get_response_body(response: httpx.Response):
def get_response_body(response: httpx.Response) -> Optional[dict]:
try:
return response.json()
except Exception:
return response.text
return None
async def set_env_variables_in_header(custom_headers: Optional[dict]) -> Optional[dict]:
@@ -303,9 +303,29 @@ def get_response_headers(headers: httpx.Headers) -> dict:
def get_endpoint_type(url: str) -> EndpointType:
if ("generateContent") in url or ("streamGenerateContent") in url:
return EndpointType.VERTEX_AI
elif ("api.anthropic.com") in url:
return EndpointType.ANTHROPIC
return EndpointType.GENERIC
async def stream_response(
response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
endpoint_type: EndpointType,
start_time: datetime,
url: str,
) -> AsyncIterable[bytes]:
async for chunk in chunk_processor(
response.aiter_bytes(),
litellm_logging_obj=logging_obj,
endpoint_type=endpoint_type,
start_time=start_time,
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
):
yield chunk
async def pass_through_request( # noqa: PLR0915
request: Request,
target: str,
@@ -445,19 +465,14 @@ async def pass_through_request( # noqa: PLR0915
status_code=e.response.status_code, detail=await e.response.aread()
)
async def stream_response() -> AsyncIterable[bytes]:
async for chunk in chunk_processor(
response.aiter_bytes(),
litellm_logging_obj=logging_obj,
return StreamingResponse(
stream_response(
response=response,
logging_obj=logging_obj,
endpoint_type=endpoint_type,
start_time=start_time,
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
):
yield chunk
return StreamingResponse(
stream_response(),
url=str(url),
),
headers=get_response_headers(response.headers),
status_code=response.status_code,
)
@@ -478,10 +493,9 @@ async def pass_through_request( # noqa: PLR0915
json=_parsed_body,
)
if (
response.headers.get("content-type") is not None
and response.headers["content-type"] == "text/event-stream"
):
verbose_proxy_logger.debug("response.headers= %s", response.headers)
if _is_streaming_response(response) is True:
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
@@ -489,19 +503,14 @@ async def pass_through_request( # noqa: PLR0915
status_code=e.response.status_code, detail=await e.response.aread()
)
async def stream_response() -> AsyncIterable[bytes]:
async for chunk in chunk_processor(
response.aiter_bytes(),
litellm_logging_obj=logging_obj,
return StreamingResponse(
stream_response(
response=response,
logging_obj=logging_obj,
endpoint_type=endpoint_type,
start_time=start_time,
passthrough_success_handler_obj=pass_through_endpoint_logging,
url_route=str(url),
):
yield chunk
return StreamingResponse(
stream_response(),
url=str(url),
),
headers=get_response_headers(response.headers),
status_code=response.status_code,
)
@@ -519,10 +528,12 @@ async def pass_through_request( # noqa: PLR0915
content = await response.aread()
## LOG SUCCESS
passthrough_logging_payload["response_body"] = get_response_body(response)
response_body: Optional[dict] = get_response_body(response)
passthrough_logging_payload["response_body"] = response_body
end_time = datetime.now()
await pass_through_endpoint_logging.pass_through_async_success_handler(
httpx_response=response,
response_body=response_body,
url_route=str(url),
result="",
start_time=start_time,
@@ -619,6 +630,13 @@ def create_pass_through_route(
return endpoint_func
def _is_streaming_response(response: httpx.Response) -> bool:
_content_type = response.headers.get("content-type")
if _content_type is not None and "text/event-stream" in _content_type:
return True
return False
async def initialize_pass_through_endpoints(pass_through_endpoints: list):
verbose_proxy_logger.debug("initializing pass through endpoints")

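Three small behaviors change in this file: get_response_body returns None instead of raw text when the body is not JSON, URLs containing api.anthropic.com map to the new ANTHROPIC endpoint type, and streaming is detected with a substring match on the content-type header. A standalone sketch of the detection logic, with plain strings standing in for httpx objects:

from typing import Optional

def endpoint_type_for(url: str) -> str:
    # Mirrors get_endpoint_type: Vertex AI routes first, then Anthropic, else generic.
    if "generateContent" in url or "streamGenerateContent" in url:
        return "vertex-ai"
    if "api.anthropic.com" in url:
        return "anthropic"
    return "generic"

def is_streaming(content_type: Optional[str]) -> bool:
    # Mirrors _is_streaming_response: a substring check, so values such as
    # "text/event-stream; charset=utf-8" also count as streaming.
    return content_type is not None and "text/event-stream" in content_type

print(endpoint_type_for("https://api.anthropic.com/v1/messages"))  # anthropic
print(is_streaming("text/event-stream; charset=utf-8"))            # True
print(is_streaming("application/json"))                            # False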
View file

@@ -2,12 +2,17 @@ import json
import re
import threading
from datetime import datetime
from typing import Union
from typing import Optional, Union
import httpx
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import (
get_standard_logging_object_payload,
)
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
@@ -23,9 +28,13 @@ class PassThroughEndpointLogging:
"predict",
]
# Anthropic
self.TRACKED_ANTHROPIC_ROUTES = ["/messages"]
async def pass_through_async_success_handler(
self,
httpx_response: httpx.Response,
response_body: Optional[dict],
logging_obj: LiteLLMLoggingObj,
url_route: str,
result: str,
@@ -45,6 +54,18 @@ class PassThroughEndpointLogging:
cache_hit=cache_hit,
**kwargs,
)
elif self.is_anthropic_route(url_route):
await self.anthropic_passthrough_handler(
httpx_response=httpx_response,
response_body=response_body or {},
logging_obj=logging_obj,
url_route=url_route,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
else:
standard_logging_response_object = StandardPassThroughResponseObject(
response=httpx_response.text
@@ -76,6 +97,12 @@ class PassThroughEndpointLogging:
return True
return False
def is_anthropic_route(self, url_route: str):
for route in self.TRACKED_ANTHROPIC_ROUTES:
if route in url_route:
return True
return False
def extract_model_from_url(self, url: str) -> str:
pattern = r"/models/([^:]+)"
match = re.search(pattern, url)
@@ -83,6 +110,72 @@ class PassThroughEndpointLogging:
return match.group(1)
return "unknown"
async def anthropic_passthrough_handler(
self,
httpx_response: httpx.Response,
response_body: dict,
logging_obj: LiteLLMLoggingObj,
url_route: str,
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
**kwargs,
):
"""
Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled
"""
model = response_body.get("model", "")
litellm_model_response: litellm.ModelResponse = (
AnthropicConfig._process_response(
response=httpx_response,
model_response=litellm.ModelResponse(),
model=model,
stream=False,
messages=[],
logging_obj=logging_obj,
optional_params={},
api_key="",
data={},
print_verbose=litellm.print_verbose,
encoding=None,
json_mode=False,
)
)
response_cost = litellm.completion_cost(
completion_response=litellm_model_response,
model=model,
)
kwargs["response_cost"] = response_cost
kwargs["model"] = model
# Make standard logging object for Anthropic
standard_logging_object = get_standard_logging_object_payload(
kwargs=kwargs,
init_response_obj=litellm_model_response,
start_time=start_time,
end_time=end_time,
logging_obj=logging_obj,
status="success",
)
# pretty print standard logging object
verbose_proxy_logger.debug(
"standard_logging_object= %s", json.dumps(standard_logging_object, indent=4)
)
kwargs["standard_logging_object"] = standard_logging_object
await logging_obj.async_success_handler(
result=litellm_model_response,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
pass
async def vertex_passthrough_handler(
self,
httpx_response: httpx.Response,

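anthropic_passthrough_handler reuses AnthropicConfig._process_response to turn the raw Anthropic response into an OpenAI-style ModelResponse, prices it with litellm.completion_cost, and builds a standard logging object for downstream logging. Upstream of that, dispatch is a substring match on the tracked route. A standalone sketch of that dispatch, with a string-returning picker standing in for the real async handlers:

TRACKED_ANTHROPIC_ROUTES = ["/messages"]

def is_anthropic_route(url_route: str) -> bool:
    # Mirrors PassThroughEndpointLogging.is_anthropic_route: any tracked route
    # appearing anywhere in the URL counts as Anthropic traffic.
    return any(route in url_route for route in TRACKED_ANTHROPIC_ROUTES)

def pick_handler(url_route: str) -> str:
    # Illustrative only; the real code awaits the matching coroutine.
    if is_anthropic_route(url_route):
        return "anthropic_passthrough_handler"
    return "generic_passthrough_logging"

print(pick_handler("https://api.anthropic.com/v1/messages"))   # anthropic_passthrough_handler
print(pick_handler("https://example.com/v1/embeddings"))       # generic_passthrough_logging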
View file

@@ -4,6 +4,7 @@ from typing import Optional, TypedDict
class EndpointType(str, Enum):
VERTEX_AI = "vertex-ai"
ANTHROPIC = "anthropic"
GENERIC = "generic"

View file

@@ -712,9 +712,7 @@ def test_convert_tool_response_to_message_with_values():
)
]
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=tool_calls
)
message = AnthropicConfig._convert_tool_response_to_message(tool_calls=tool_calls)
assert message is not None
assert message.content == '{"name": "John", "age": 30}'
@@ -739,9 +737,7 @@ def test_convert_tool_response_to_message_without_values():
)
]
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=tool_calls
)
message = AnthropicConfig._convert_tool_response_to_message(tool_calls=tool_calls)
assert message is not None
assert message.content == '{"name": "John", "age": 30}'
@@ -760,9 +756,7 @@ def test_convert_tool_response_to_message_invalid_json():
)
]
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=tool_calls
)
message = AnthropicConfig._convert_tool_response_to_message(tool_calls=tool_calls)
assert message is not None
assert message.content == "invalid json"
@@ -779,8 +773,6 @@ def test_convert_tool_response_to_message_no_arguments():
)
]
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=tool_calls
)
message = AnthropicConfig._convert_tool_response_to_message(tool_calls=tool_calls)
assert message is None