fix(anthropic.py): fix tool calling + streaming issue

Krrish Dholakia 2024-05-11 20:15:36 -07:00
parent 83beb41096
commit a456f6bf2b


@@ -165,6 +165,9 @@ class AnthropicChatCompletion(BaseLLM):
         print_verbose,
         encoding,
     ) -> CustomStreamWrapper:
+        """
+        Return stream object for tool-calling + streaming
+        """
         ## LOGGING
         logging_obj.post_call(
             input=messages,
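This hunk only adds a docstring, but it states the intent: when the caller requested streaming for a tool-calling request, the already-completed response is handed back wrapped as a stream. A minimal sketch of that idea, not LiteLLM's actual CustomStreamWrapper plumbing (the generator name and chunk dict shape below are assumptions):

```python
# Illustrative only: re-serve a finished tool-call response as a stream.
# LiteLLM's real iterator/CustomStreamWrapper machinery is more involved;
# fake_stream and the chunk layout are assumptions for this sketch.
def fake_stream(model_response):
    """Yield the already-complete tool-call response as a single streaming chunk."""
    choice = model_response.choices[0]
    yield {
        "choices": [
            {
                "index": 0,
                "delta": {
                    "content": choice.message.content,
                    "tool_calls": choice.message.tool_calls,
                },
                "finish_reason": choice.finish_reason,
            }
        ]
    }
```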
@@ -202,6 +205,18 @@ class AnthropicChatCompletion(BaseLLM):
                 message=str(completion_response["error"]),
                 status_code=response.status_code,
             )
+        _message = litellm.Message(
+            tool_calls=tool_calls,
+            content=text_content or None,
+        )
+        model_response.choices[0].message = _message  # type: ignore
+        model_response._hidden_params["original_response"] = completion_response[
+            "content"
+        ]  # allow user to access raw anthropic tool calling response
+        model_response.choices[0].finish_reason = map_finish_reason(
+            completion_response["stop_reason"]
+        )
         print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
         # return an iterator
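The tool_calls and text_content used above are assembled earlier in the function from the Anthropic Messages API body, whose content field is a list of text and tool_use blocks. A hedged reconstruction of that assembly (the loop itself is outside this hunk; completion_response stands in for the parsed JSON body):

```python
import json

# Assumed context: completion_response is the parsed Anthropic Messages API
# response. Only the block shapes ("text", "tool_use") are Anthropic-documented.
text_content = ""
tool_calls = []
for block in completion_response.get("content", []):
    if block["type"] == "text":
        text_content += block["text"]
    elif block["type"] == "tool_use":
        # Convert an Anthropic tool_use block into an OpenAI-style tool call.
        tool_calls.append(
            {
                "id": block["id"],
                "type": "function",
                "function": {
                    "name": block["name"],
                    "arguments": json.dumps(block["input"]),
                },
            }
        )
```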
@@ -392,13 +407,27 @@ class AnthropicChatCompletion(BaseLLM):
         litellm_params=None,
         logger_fn=None,
         headers={},
-    ) -> ModelResponse:
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
         self.async_handler = AsyncHTTPHandler(
             timeout=httpx.Timeout(timeout=600.0, connect=5.0)
         )
         response = await self.async_handler.post(
             api_base, headers=headers, data=json.dumps(data)
         )
+        if stream and _is_function_call:
+            return self.process_streaming_response(
+                model=model,
+                response=response,
+                model_response=model_response,
+                stream=stream,
+                logging_obj=logging_obj,
+                api_key=api_key,
+                data=data,
+                messages=messages,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                encoding=encoding,
+            )
         return self.process_response(
             model=model,
             response=response,
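With the new stream and _is_function_call branch, an async streaming request that includes tools now returns a CustomStreamWrapper instead of failing. A usage sketch of what this enables from the caller's side (the model name and the OpenAI-style delta access are assumptions, not part of this diff):

```python
import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

response = litellm.completion(
    model="claude-3-opus-20240229",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
    stream=True,
)

# Collect streamed tool-call argument fragments from the OpenAI-style deltas.
arguments = ""
for chunk in response:
    delta = chunk.choices[0].delta
    if delta and getattr(delta, "tool_calls", None):
        arguments += delta.tool_calls[0].function.arguments or ""
print(arguments)
```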