Merge branch 'main' into litellm_anthropic_tool_calling_streaming_fix

Krish Dholakia 2024-07-03 20:43:51 -07:00 committed by GitHub
commit 06c6c65d2a
24 changed files with 868 additions and 508 deletions


@@ -446,6 +446,20 @@ class AnthropicChatCompletion(BaseLLM):
headers={},
):
data["stream"] = True
# async_handler = AsyncHTTPHandler(
# timeout=httpx.Timeout(timeout=600.0, connect=20.0)
# )
# response = await async_handler.post(
# api_base, headers=headers, json=data, stream=True
# )
# if response.status_code != 200:
# raise AnthropicError(
# status_code=response.status_code, message=response.text
# )
# completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
completion_stream=None,
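For context, the commented-out block above corresponds roughly to the following httpx streaming pattern. This is a minimal sketch, not litellm's AsyncHTTPHandler API; the helper name and the generic exception are placeholders for illustration.

import httpx

async def stream_anthropic_sketch(api_base: str, headers: dict, data: dict):
    # Roughly what the commented-out call did: POST with a long timeout and
    # yield the response line by line for the stream wrapper to consume.
    timeout = httpx.Timeout(600.0, connect=20.0)
    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream("POST", api_base, headers=headers, json=data) as response:
            if response.status_code != 200:
                body = await response.aread()
                raise Exception(f"Anthropic returned {response.status_code}: {body!r}")
            async for line in response.aiter_lines():
                yield line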
@@ -485,6 +499,7 @@ class AnthropicChatCompletion(BaseLLM):
headers={},
) -> Union[ModelResponse, CustomStreamWrapper]:
async_handler = _get_async_httpx_client()
try:
response = await async_handler.post(api_base, headers=headers, json=data)
except Exception as e:
@@ -496,6 +511,7 @@ class AnthropicChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
)
raise e
return self.process_response(
model=model,
response=response,
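The two hunks above wrap the async POST in a try/except that logs the failure via logging_obj.post_call before re-raising, then hand the successful response to process_response. A minimal sketch of that pattern, with a plain httpx client standing in for litellm's async handler and a hypothetical logging object:

import httpx

async def post_with_failure_logging(api_base, headers, data, messages, api_key, logging_obj):
    # On any request failure, record the exception and the full request payload
    # through the logging object, then re-raise so the caller still sees the error.
    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(api_base, headers=headers, json=data)
        except Exception as e:
            logging_obj.post_call(
                input=messages,
                api_key=api_key,
                original_response=str(e),
                additional_args={"complete_input_dict": data},
            )
            raise e
        return response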
@@ -585,16 +601,13 @@ class AnthropicChatCompletion(BaseLLM):
optional_params["tools"] = anthropic_tools
stream = optional_params.pop("stream", None)
is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
data = {
"model": model,
"messages": messages,
**optional_params,
}
if is_vertex_request is False:
data["model"] = model
## LOGGING
logging_obj.pre_call(
input=messages,
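Reading this hunk with the add/remove markers stripped, the new code appears to drop the unconditional "model" key from the request body and add it only when the call is not a Vertex AI request (Vertex carries the model in the endpoint rather than the payload). A hedged sketch of the resulting payload construction; the helper name is hypothetical:

def build_anthropic_payload(model, messages, optional_params):
    # "is_vertex_request" is popped so it never leaks into the provider payload.
    is_vertex_request = optional_params.pop("is_vertex_request", False)
    data = {
        "messages": messages,
        **optional_params,
    }
    if is_vertex_request is False:
        # Direct Anthropic API calls need the model in the body;
        # Vertex AI requests specify it outside the payload.
        data["model"] = model
    return data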
@@ -680,27 +693,10 @@ class AnthropicChatCompletion(BaseLLM):
return streaming_response
else:
try:
response = requests.post(
api_base, headers=headers, data=json.dumps(data)
)
except Exception as e:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=str(e),
additional_args={"complete_input_dict": data},
)
raise e
response = requests.post(
api_base, headers=headers, data=json.dumps(data)
)
if response.status_code != 200:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
raise AnthropicError(
status_code=response.status_code, message=response.text
)
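The final hunk appears to drop the try/except (and its post_call logging) around the synchronous requests.post, keeping the non-200 check that raises AnthropicError. A self-contained sketch of a direct sync call that raises on error responses, with a plain Exception standing in for AnthropicError:

import json
import requests

def anthropic_sync_post_sketch(api_base, headers, data):
    # Post the serialized payload; transport errors now propagate unlogged,
    # and HTTP errors surface through the status-code check below.
    response = requests.post(api_base, headers=headers, data=json.dumps(data))
    if response.status_code != 200:
        raise Exception(f"Anthropic error {response.status_code}: {response.text}")
    return response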