diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index b7b078b9b..db41ae6e3 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -7,6 +7,9 @@ from typing import Callable, Optional, List
 from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+
+async_handler = AsyncHTTPHandler()
 import httpx
 
 
@@ -36,7 +39,9 @@ class AnthropicConfig:
     to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
     """
 
-    max_tokens: Optional[int] = 4096  # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
+    max_tokens: Optional[int] = (
+        4096  # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
+    )
     stop_sequences: Optional[list] = None
     temperature: Optional[int] = None
     top_p: Optional[int] = None
@@ -46,7 +51,9 @@
 
     def __init__(
         self,
-        max_tokens: Optional[int] = 4096,  # You can pass in a value yourself or use the default value 4096
+        max_tokens: Optional[
+            int
+        ] = 4096,  # You can pass in a value yourself or use the default value 4096
         stop_sequences: Optional[list] = None,
         temperature: Optional[int] = None,
         top_p: Optional[int] = None,
@@ -95,6 +102,169 @@ def validate_environment(api_key, user_headers):
     return headers
 
 
+def process_response(
+    model,
+    response,
+    model_response,
+    _is_function_call,
+    stream,
+    logging_obj,
+    api_key,
+    data,
+    messages,
+    print_verbose,
+):
+    ## LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key=api_key,
+        original_response=response.text,
+        additional_args={"complete_input_dict": data},
+    )
+    print_verbose(f"raw model_response: {response.text}")
+    ## RESPONSE OBJECT
+    try:
+        completion_response = response.json()
+    except:
+        raise AnthropicError(message=response.text, status_code=response.status_code)
+    if "error" in completion_response:
+        raise AnthropicError(
+            message=str(completion_response["error"]),
+            status_code=response.status_code,
+        )
+    elif len(completion_response["content"]) == 0:
+        raise AnthropicError(
+            message="No content in response",
+            status_code=response.status_code,
+        )
+    else:
+        text_content = ""
+        tool_calls = []
+        for content in completion_response["content"]:
+            if content["type"] == "text":
+                text_content += content["text"]
+            ## TOOL CALLING
+            elif content["type"] == "tool_use":
+                tool_calls.append(
+                    {
+                        "id": content["id"],
+                        "type": "function",
+                        "function": {
+                            "name": content["name"],
+                            "arguments": json.dumps(content["input"]),
+                        },
+                    }
+                )
+
+        _message = litellm.Message(
+            tool_calls=tool_calls,
+            content=text_content or None,
+        )
+        model_response.choices[0].message = _message  # type: ignore
+        model_response._hidden_params["original_response"] = completion_response[
+            "content"
+        ]  # allow user to access raw anthropic tool calling response
+
+        model_response.choices[0].finish_reason = map_finish_reason(
+            completion_response["stop_reason"]
+        )
+
+    print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
+    if _is_function_call and stream:
+        print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
+        # return an iterator
+        streaming_model_response = ModelResponse(stream=True)
+        streaming_model_response.choices[0].finish_reason = model_response.choices[
+            0
+        ].finish_reason
+        # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+        streaming_choice = litellm.utils.StreamingChoices()
+        streaming_choice.index = model_response.choices[0].index
+        _tool_calls = []
+        print_verbose(
+            f"type of model_response.choices[0]: {type(model_response.choices[0])}"
+        )
+        print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
+        if isinstance(model_response.choices[0], litellm.Choices):
+            if getattr(
+                model_response.choices[0].message, "tool_calls", None
+            ) is not None and isinstance(
+                model_response.choices[0].message.tool_calls, list
+            ):
+                for tool_call in model_response.choices[0].message.tool_calls:
+                    _tool_call = {**tool_call.dict(), "index": 0}
+                    _tool_calls.append(_tool_call)
+            delta_obj = litellm.utils.Delta(
+                content=getattr(model_response.choices[0].message, "content", None),
+                role=model_response.choices[0].message.role,
+                tool_calls=_tool_calls,
+            )
+            streaming_choice.delta = delta_obj
+            streaming_model_response.choices = [streaming_choice]
+            completion_stream = ModelResponseIterator(
+                model_response=streaming_model_response
+            )
+            print_verbose(
+                "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
+            )
+            return CustomStreamWrapper(
+                completion_stream=completion_stream,
+                model=model,
+                custom_llm_provider="cached_response",
+                logging_obj=logging_obj,
+            )
+
+    ## CALCULATING USAGE
+    prompt_tokens = completion_response["usage"]["input_tokens"]
+    completion_tokens = completion_response["usage"]["output_tokens"]
+    total_tokens = prompt_tokens + completion_tokens
+
+    model_response["created"] = int(time.time())
+    model_response["model"] = model
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=total_tokens,
+    )
+    model_response.usage = usage
+    return model_response
+
+
+async def acompletion_function(
+    model: str,
+    messages: list,
+    api_base: str,
+    custom_prompt_dict: dict,
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    encoding,
+    api_key,
+    logging_obj,
+    stream,
+    _is_function_call,
+    data=None,
+    optional_params=None,
+    litellm_params=None,
+    logger_fn=None,
+    headers={},
+):
+    response = await async_handler.post(
+        api_base, headers=headers, data=json.dumps(data)
+    )
+    return process_response(
+        model=model,
+        response=response,
+        model_response=model_response,
+        _is_function_call=_is_function_call,
+        stream=stream,
+        logging_obj=logging_obj,
+        api_key=api_key,
+        data=data,
+        messages=messages,
+        print_verbose=print_verbose,
+    )
+
+
 def completion(
     model: str,
     messages: list,
@@ -106,6 +276,7 @@
     api_key,
     logging_obj,
     optional_params=None,
+    acompletion=None,
     litellm_params=None,
     logger_fn=None,
     headers={},
@@ -184,148 +355,66 @@
         },
     )
     print_verbose(f"_is_function_call: {_is_function_call}")
-    ## COMPLETION CALL
-    if (
-        stream and not _is_function_call
-    ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
-        print_verbose("makes anthropic streaming POST request")
-        data["stream"] = stream
-        response = requests.post(
-            api_base,
-            headers=headers,
-            data=json.dumps(data),
-            stream=stream,
-        )
-
-        if response.status_code != 200:
-            raise AnthropicError(
-                status_code=response.status_code, message=response.text
-            )
-
-        return response.iter_lines()
-    else:
-        response = requests.post(api_base, headers=headers, data=json.dumps(data))
-        if response.status_code != 200:
-            raise AnthropicError(
-                status_code=response.status_code, message=response.text
-            )
-
-        ## LOGGING
-        logging_obj.post_call(
-            input=messages,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
-        )
-        print_verbose(f"raw model_response: {response.text}")
-        ## RESPONSE OBJECT
-        try:
-            completion_response = response.json()
-        except:
-            raise AnthropicError(
-                message=response.text, status_code=response.status_code
-            )
-        if "error" in completion_response:
-            raise AnthropicError(
-                message=str(completion_response["error"]),
-                status_code=response.status_code,
-            )
-        elif len(completion_response["content"]) == 0:
-            raise AnthropicError(
-                message="No content in response",
-                status_code=response.status_code,
-            )
+    if acompletion == True:
+        if optional_params.get("stream", False):
+            pass
         else:
-            text_content = ""
-            tool_calls = []
-            for content in completion_response["content"]:
-                if content["type"] == "text":
-                    text_content += content["text"]
-                ## TOOL CALLING
-                elif content["type"] == "tool_use":
-                    tool_calls.append(
-                        {
-                            "id": content["id"],
-                            "type": "function",
-                            "function": {
-                                "name": content["name"],
-                                "arguments": json.dumps(content["input"]),
-                            },
-                        }
-                    )
-
-            _message = litellm.Message(
-                tool_calls=tool_calls,
-                content=text_content or None,
+            return acompletion_function(
+                model=model,
+                messages=messages,
+                data=data,
+                api_base=api_base,
+                custom_prompt_dict=custom_prompt_dict,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                encoding=encoding,
+                api_key=api_key,
+                logging_obj=logging_obj,
+                optional_params=optional_params,
+                stream=stream,
+                _is_function_call=_is_function_call,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                headers=headers,
             )
-            model_response.choices[0].message = _message  # type: ignore
-            model_response._hidden_params["original_response"] = completion_response[
-                "content"
-            ]  # allow user to access raw anthropic tool calling response
-
-            model_response.choices[0].finish_reason = map_finish_reason(
-                completion_response["stop_reason"]
+    else:
+        ## COMPLETION CALL
+        if (
+            stream and not _is_function_call
+        ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
+            print_verbose("makes anthropic streaming POST request")
+            data["stream"] = stream
+            response = requests.post(
+                api_base,
+                headers=headers,
+                data=json.dumps(data),
+                stream=stream,
             )
 
-        print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
-        if _is_function_call and stream:
-            print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
-            # return an iterator
-            streaming_model_response = ModelResponse(stream=True)
-            streaming_model_response.choices[0].finish_reason = model_response.choices[
-                0
-            ].finish_reason
-            # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
-            streaming_choice = litellm.utils.StreamingChoices()
-            streaming_choice.index = model_response.choices[0].index
-            _tool_calls = []
-            print_verbose(
-                f"type of model_response.choices[0]: {type(model_response.choices[0])}"
-            )
-            print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
-            if isinstance(model_response.choices[0], litellm.Choices):
-                if getattr(
-                    model_response.choices[0].message, "tool_calls", None
-                ) is not None and isinstance(
-                    model_response.choices[0].message.tool_calls, list
-                ):
-                    for tool_call in model_response.choices[0].message.tool_calls:
-                        _tool_call = {**tool_call.dict(), "index": 0}
-                        _tool_calls.append(_tool_call)
-                delta_obj = litellm.utils.Delta(
-                    content=getattr(model_response.choices[0].message, "content", None),
-                    role=model_response.choices[0].message.role,
-                    tool_calls=_tool_calls,
-                )
-                streaming_choice.delta = delta_obj
-                streaming_model_response.choices = [streaming_choice]
-                completion_stream = ModelResponseIterator(
-                    model_response=streaming_model_response
-                )
-                print_verbose(
-                    "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
-                )
-                return CustomStreamWrapper(
-                    completion_stream=completion_stream,
-                    model=model,
-                    custom_llm_provider="cached_response",
-                    logging_obj=logging_obj,
+            if response.status_code != 200:
+                raise AnthropicError(
+                    status_code=response.status_code, message=response.text
                 )
-
-        ## CALCULATING USAGE
-        prompt_tokens = completion_response["usage"]["input_tokens"]
-        completion_tokens = completion_response["usage"]["output_tokens"]
-        total_tokens = prompt_tokens + completion_tokens
-
-        model_response["created"] = int(time.time())
-        model_response["model"] = model
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=total_tokens,
-        )
-        model_response.usage = usage
-        return model_response
+            return response.iter_lines()
+        else:
+            response = requests.post(api_base, headers=headers, data=json.dumps(data))
+            if response.status_code != 200:
+                raise AnthropicError(
+                    status_code=response.status_code, message=response.text
+                )
+            return process_response(
+                model=model,
+                response=response,
+                model_response=model_response,
+                _is_function_call=_is_function_call,
+                stream=stream,
+                logging_obj=logging_obj,
+                api_key=api_key,
+                data=data,
+                messages=messages,
+                print_verbose=print_verbose,
+            )
 
 
 class ModelResponseIterator:
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index 10314d831..51723a2f9 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -1,5 +1,5 @@
 import httpx, asyncio
-from typing import Optional
+from typing import Optional, Union
 
 
 class AsyncHTTPHandler:
@@ -25,7 +25,7 @@ class AsyncHTTPHandler:
     async def post(
         self,
         url: str,
-        data: Optional[dict] = None,
+        data: Optional[Union[dict, str]] = None,
         params: Optional[dict] = None,
         headers: Optional[dict] = None,
     ):
diff --git a/litellm/main.py b/litellm/main.py
index 5a9eb6e45..b7e5a3ba9 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -304,6 +304,7 @@
             or custom_llm_provider == "vertex_ai"
             or custom_llm_provider == "gemini"
             or custom_llm_provider == "sagemaker"
+            or custom_llm_provider == "anthropic"
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
@@ -1184,6 +1185,7 @@
                 model=model,
                 messages=messages,
                 api_base=api_base,
+                acompletion=acompletion,
                 custom_prompt_dict=litellm.custom_prompt_dict,
                 model_response=model_response,
                 print_verbose=print_verbose,
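For reference, a minimal sketch of how the new async path is exercised from the caller's side. The model id and prompt below are illustrative assumptions, not part of this patch, and `ANTHROPIC_API_KEY` is assumed to be set in the environment:

```python
# Usage sketch (assumed caller code, not included in this diff).
import asyncio
import litellm


async def main():
    # For a claude-3 model, custom_llm_provider resolves to "anthropic", so
    # acompletion() now awaits acompletion_function() above, which posts via
    # the shared AsyncHTTPHandler instead of blocking on requests.post().
    response = await litellm.acompletion(
        model="claude-3-opus-20240229",  # placeholder model id
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        max_tokens=256,
    )
    print(response.choices[0].message.content)


asyncio.run(main())
```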
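Two things worth noting about the shape of this change: the async branch only returns a response for the non-streaming case (when `optional_params.get("stream", False)` is truthy, the new `if acompletion == True:` block just executes `pass`, so async streaming is not wired up by this diff), and the `Optional[Union[dict, str]]` widening in `http_handler.py` is needed because `acompletion_function` serializes the payload itself and passes `json.dumps(data)` as a raw string body to `AsyncHTTPHandler.post`.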