diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py index 47e485ecb..3eca11bef 100644 --- a/litellm/llms/anthropic.py +++ b/litellm/llms/anthropic.py @@ -8,7 +8,7 @@ from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamW import litellm from .prompt_templates.factory import prompt_factory, custom_prompt from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler - +from .base import BaseLLM import httpx @@ -19,9 +19,6 @@ class AnthropicConstants(Enum): # constants from https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_constants.py -async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) - - class AnthropicError(Exception): def __init__(self, status_code, message): self.status_code = status_code @@ -105,384 +102,402 @@ def validate_environment(api_key, user_headers): return headers -def process_response( - model, - response, - model_response, - _is_function_call, - stream, - logging_obj, - api_key, - data, - messages, - print_verbose, -): - ## LOGGING - logging_obj.post_call( - input=messages, - api_key=api_key, - original_response=response.text, - additional_args={"complete_input_dict": data}, - ) - print_verbose(f"raw model_response: {response.text}") - ## RESPONSE OBJECT - try: - completion_response = response.json() - except: - raise AnthropicError(message=response.text, status_code=response.status_code) - if "error" in completion_response: - raise AnthropicError( - message=str(completion_response["error"]), - status_code=response.status_code, - ) - elif len(completion_response["content"]) == 0: - raise AnthropicError( - message="No content in response", - status_code=response.status_code, - ) - else: - text_content = "" - tool_calls = [] - for content in completion_response["content"]: - if content["type"] == "text": - text_content += content["text"] - ## TOOL CALLING - elif content["type"] == "tool_use": - tool_calls.append( - { - "id": content["id"], - "type": "function", - "function": { - "name": content["name"], - "arguments": json.dumps(content["input"]), - }, - } - ) - - _message = litellm.Message( - tool_calls=tool_calls, - content=text_content or None, - ) - model_response.choices[0].message = _message # type: ignore - model_response._hidden_params["original_response"] = completion_response[ - "content" - ] # allow user to access raw anthropic tool calling response - - model_response.choices[0].finish_reason = map_finish_reason( - completion_response["stop_reason"] +class AnthropicChatCompletion(BaseLLM): + def __init__(self) -> None: + super().__init__() + self.async_handler = AsyncHTTPHandler( + timeout=httpx.Timeout(timeout=600.0, connect=5.0) ) - print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}") - if _is_function_call and stream: - print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK") - # return an iterator - streaming_model_response = ModelResponse(stream=True) - streaming_model_response.choices[0].finish_reason = model_response.choices[ - 0 - ].finish_reason - # streaming_model_response.choices = [litellm.utils.StreamingChoices()] - streaming_choice = litellm.utils.StreamingChoices() - streaming_choice.index = model_response.choices[0].index - _tool_calls = [] - print_verbose( - f"type of model_response.choices[0]: {type(model_response.choices[0])}" + def process_response( + self, + model, + response, + model_response, + _is_function_call, + stream, + logging_obj, + api_key, + data, + messages, + print_verbose, + ): + ## LOGGING + logging_obj.post_call( + input=messages, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, ) - print_verbose(f"type of streaming_choice: {type(streaming_choice)}") - if isinstance(model_response.choices[0], litellm.Choices): - if getattr( - model_response.choices[0].message, "tool_calls", None - ) is not None and isinstance( - model_response.choices[0].message.tool_calls, list - ): - for tool_call in model_response.choices[0].message.tool_calls: - _tool_call = {**tool_call.dict(), "index": 0} - _tool_calls.append(_tool_call) - delta_obj = litellm.utils.Delta( - content=getattr(model_response.choices[0].message, "content", None), - role=model_response.choices[0].message.role, - tool_calls=_tool_calls, - ) - streaming_choice.delta = delta_obj - streaming_model_response.choices = [streaming_choice] - completion_stream = ModelResponseIterator( - model_response=streaming_model_response - ) - print_verbose( - "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object" - ) - return CustomStreamWrapper( - completion_stream=completion_stream, - model=model, - custom_llm_provider="cached_response", - logging_obj=logging_obj, - ) - - ## CALCULATING USAGE - prompt_tokens = completion_response["usage"]["input_tokens"] - completion_tokens = completion_response["usage"]["output_tokens"] - total_tokens = prompt_tokens + completion_tokens - - model_response["created"] = int(time.time()) - model_response["model"] = model - usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - ) - model_response.usage = usage - return model_response - - -async def acompletion_stream_function( - model: str, - messages: list, - api_base: str, - custom_prompt_dict: dict, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - api_key, - logging_obj, - stream, - _is_function_call, - data=None, - optional_params=None, - litellm_params=None, - logger_fn=None, - headers={}, -): - response = await async_handler.post( - api_base, headers=headers, data=json.dumps(data) - ) - - if response.status_code != 200: - raise AnthropicError(status_code=response.status_code, message=response.text) - - completion_stream = response.aiter_lines() - - streamwrapper = CustomStreamWrapper( - completion_stream=completion_stream, - model=model, - custom_llm_provider="anthropic", - logging_obj=logging_obj, - ) - return streamwrapper - - -async def acompletion_function( - model: str, - messages: list, - api_base: str, - custom_prompt_dict: dict, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - api_key, - logging_obj, - stream, - _is_function_call, - data=None, - optional_params=None, - litellm_params=None, - logger_fn=None, - headers={}, -): - response = await async_handler.post( - api_base, headers=headers, data=json.dumps(data) - ) - return process_response( - model=model, - response=response, - model_response=model_response, - _is_function_call=_is_function_call, - stream=stream, - logging_obj=logging_obj, - api_key=api_key, - data=data, - messages=messages, - print_verbose=print_verbose, - ) - - -def completion( - model: str, - messages: list, - api_base: str, - custom_prompt_dict: dict, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - api_key, - logging_obj, - optional_params=None, - acompletion=None, - litellm_params=None, - logger_fn=None, - headers={}, -): - headers = validate_environment(api_key, headers) - _is_function_call = False - messages = copy.deepcopy(messages) - optional_params = copy.deepcopy(optional_params) - if model in custom_prompt_dict: - # check if the model has a registered custom prompt - model_prompt_details = custom_prompt_dict[model] - prompt = custom_prompt( - role_dict=model_prompt_details["roles"], - initial_prompt_value=model_prompt_details["initial_prompt_value"], - final_prompt_value=model_prompt_details["final_prompt_value"], - messages=messages, - ) - else: - # Separate system prompt from rest of message - system_prompt_indices = [] - system_prompt = "" - for idx, message in enumerate(messages): - if message["role"] == "system": - system_prompt += message["content"] - system_prompt_indices.append(idx) - if len(system_prompt_indices) > 0: - for idx in reversed(system_prompt_indices): - messages.pop(idx) - if len(system_prompt) > 0: - optional_params["system"] = system_prompt - # Format rest of message according to anthropic guidelines + print_verbose(f"raw model_response: {response.text}") + ## RESPONSE OBJECT try: - messages = prompt_factory( - model=model, messages=messages, custom_llm_provider="anthropic" + completion_response = response.json() + except: + raise AnthropicError( + message=response.text, status_code=response.status_code ) - except Exception as e: - raise AnthropicError(status_code=400, message=str(e)) - - ## Load Config - config = litellm.AnthropicConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - ## Handle Tool Calling - if "tools" in optional_params: - _is_function_call = True - headers["anthropic-beta"] = "tools-2024-04-04" - - anthropic_tools = [] - for tool in optional_params["tools"]: - new_tool = tool["function"] - new_tool["input_schema"] = new_tool.pop("parameters") # rename key - anthropic_tools.append(new_tool) - - optional_params["tools"] = anthropic_tools - - stream = optional_params.pop("stream", None) - - data = { - "model": model, - "messages": messages, - **optional_params, - } - - ## LOGGING - logging_obj.pre_call( - input=messages, - api_key=api_key, - additional_args={ - "complete_input_dict": data, - "api_base": api_base, - "headers": headers, - }, - ) - print_verbose(f"_is_function_call: {_is_function_call}") - if acompletion == True: - if ( - stream and not _is_function_call - ): # if function call - fake the streaming (need complete blocks for output parsing in openai format) - print_verbose("makes async anthropic streaming POST request") - data["stream"] = stream - return acompletion_stream_function( - model=model, - messages=messages, - data=data, - api_base=api_base, - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - api_key=api_key, - logging_obj=logging_obj, - optional_params=optional_params, - stream=stream, - _is_function_call=_is_function_call, - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=headers, + if "error" in completion_response: + raise AnthropicError( + message=str(completion_response["error"]), + status_code=response.status_code, + ) + elif len(completion_response["content"]) == 0: + raise AnthropicError( + message="No content in response", + status_code=response.status_code, ) else: - return acompletion_function( - model=model, + text_content = "" + tool_calls = [] + for content in completion_response["content"]: + if content["type"] == "text": + text_content += content["text"] + ## TOOL CALLING + elif content["type"] == "tool_use": + tool_calls.append( + { + "id": content["id"], + "type": "function", + "function": { + "name": content["name"], + "arguments": json.dumps(content["input"]), + }, + } + ) + + _message = litellm.Message( + tool_calls=tool_calls, + content=text_content or None, + ) + model_response.choices[0].message = _message # type: ignore + model_response._hidden_params["original_response"] = completion_response[ + "content" + ] # allow user to access raw anthropic tool calling response + + model_response.choices[0].finish_reason = map_finish_reason( + completion_response["stop_reason"] + ) + + print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}") + if _is_function_call and stream: + print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK") + # return an iterator + streaming_model_response = ModelResponse(stream=True) + streaming_model_response.choices[0].finish_reason = model_response.choices[ + 0 + ].finish_reason + # streaming_model_response.choices = [litellm.utils.StreamingChoices()] + streaming_choice = litellm.utils.StreamingChoices() + streaming_choice.index = model_response.choices[0].index + _tool_calls = [] + print_verbose( + f"type of model_response.choices[0]: {type(model_response.choices[0])}" + ) + print_verbose(f"type of streaming_choice: {type(streaming_choice)}") + if isinstance(model_response.choices[0], litellm.Choices): + if getattr( + model_response.choices[0].message, "tool_calls", None + ) is not None and isinstance( + model_response.choices[0].message.tool_calls, list + ): + for tool_call in model_response.choices[0].message.tool_calls: + _tool_call = {**tool_call.dict(), "index": 0} + _tool_calls.append(_tool_call) + delta_obj = litellm.utils.Delta( + content=getattr(model_response.choices[0].message, "content", None), + role=model_response.choices[0].message.role, + tool_calls=_tool_calls, + ) + streaming_choice.delta = delta_obj + streaming_model_response.choices = [streaming_choice] + completion_stream = ModelResponseIterator( + model_response=streaming_model_response + ) + print_verbose( + "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object" + ) + return CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="cached_response", + logging_obj=logging_obj, + ) + + ## CALCULATING USAGE + prompt_tokens = completion_response["usage"]["input_tokens"] + completion_tokens = completion_response["usage"]["output_tokens"] + total_tokens = prompt_tokens + completion_tokens + + model_response["created"] = int(time.time()) + model_response["model"] = model + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + model_response.usage = usage + return model_response + + async def acompletion_stream_function( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + stream, + _is_function_call, + data=None, + optional_params=None, + litellm_params=None, + logger_fn=None, + headers={}, + ): + response = await self.async_handler.post( + api_base, headers=headers, data=json.dumps(data) + ) + + if response.status_code != 200: + raise AnthropicError( + status_code=response.status_code, message=response.text + ) + + completion_stream = response.aiter_lines() + + streamwrapper = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="anthropic", + logging_obj=logging_obj, + ) + return streamwrapper + + async def acompletion_function( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + stream, + _is_function_call, + data=None, + optional_params=None, + litellm_params=None, + logger_fn=None, + headers={}, + ): + response = await self.async_handler.post( + api_base, headers=headers, data=json.dumps(data) + ) + return self.process_response( + model=model, + response=response, + model_response=model_response, + _is_function_call=_is_function_call, + stream=stream, + logging_obj=logging_obj, + api_key=api_key, + data=data, + messages=messages, + print_verbose=print_verbose, + ) + + def completion( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + optional_params=None, + acompletion=None, + litellm_params=None, + logger_fn=None, + headers={}, + ): + headers = validate_environment(api_key, headers) + _is_function_call = False + messages = copy.deepcopy(messages) + optional_params = copy.deepcopy(optional_params) + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + prompt = custom_prompt( + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], messages=messages, - data=data, - api_base=api_base, - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - api_key=api_key, - logging_obj=logging_obj, - optional_params=optional_params, - stream=stream, - _is_function_call=_is_function_call, - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=headers, ) - else: - ## COMPLETION CALL - if ( - stream and not _is_function_call - ): # if function call - fake the streaming (need complete blocks for output parsing in openai format) - print_verbose("makes anthropic streaming POST request") - data["stream"] = stream - response = requests.post( - api_base, - headers=headers, - data=json.dumps(data), - stream=stream, - ) - - if response.status_code != 200: - raise AnthropicError( - status_code=response.status_code, message=response.text - ) - - completion_stream = response.iter_lines() - streaming_response = CustomStreamWrapper( - completion_stream=completion_stream, - model=model, - custom_llm_provider="anthropic", - logging_obj=logging_obj, - ) - return streaming_response - else: - response = requests.post(api_base, headers=headers, data=json.dumps(data)) - if response.status_code != 200: - raise AnthropicError( - status_code=response.status_code, message=response.text + # Separate system prompt from rest of message + system_prompt_indices = [] + system_prompt = "" + for idx, message in enumerate(messages): + if message["role"] == "system": + system_prompt += message["content"] + system_prompt_indices.append(idx) + if len(system_prompt_indices) > 0: + for idx in reversed(system_prompt_indices): + messages.pop(idx) + if len(system_prompt) > 0: + optional_params["system"] = system_prompt + # Format rest of message according to anthropic guidelines + try: + messages = prompt_factory( + model=model, messages=messages, custom_llm_provider="anthropic" ) - return process_response( - model=model, - response=response, - model_response=model_response, - _is_function_call=_is_function_call, - stream=stream, - logging_obj=logging_obj, - api_key=api_key, - data=data, - messages=messages, - print_verbose=print_verbose, - ) + except Exception as e: + raise AnthropicError(status_code=400, message=str(e)) + + ## Load Config + config = litellm.AnthropicConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + ## Handle Tool Calling + if "tools" in optional_params: + _is_function_call = True + headers["anthropic-beta"] = "tools-2024-04-04" + + anthropic_tools = [] + for tool in optional_params["tools"]: + new_tool = tool["function"] + new_tool["input_schema"] = new_tool.pop("parameters") # rename key + anthropic_tools.append(new_tool) + + optional_params["tools"] = anthropic_tools + + stream = optional_params.pop("stream", None) + + data = { + "model": model, + "messages": messages, + **optional_params, + } + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "api_base": api_base, + "headers": headers, + }, + ) + print_verbose(f"_is_function_call: {_is_function_call}") + if acompletion == True: + if ( + stream and not _is_function_call + ): # if function call - fake the streaming (need complete blocks for output parsing in openai format) + print_verbose("makes async anthropic streaming POST request") + data["stream"] = stream + return self.acompletion_stream_function( + model=model, + messages=messages, + data=data, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + stream=stream, + _is_function_call=_is_function_call, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + ) + else: + return self.acompletion_function( + model=model, + messages=messages, + data=data, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + stream=stream, + _is_function_call=_is_function_call, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + ) + else: + ## COMPLETION CALL + if ( + stream and not _is_function_call + ): # if function call - fake the streaming (need complete blocks for output parsing in openai format) + print_verbose("makes anthropic streaming POST request") + data["stream"] = stream + response = requests.post( + api_base, + headers=headers, + data=json.dumps(data), + stream=stream, + ) + + if response.status_code != 200: + raise AnthropicError( + status_code=response.status_code, message=response.text + ) + + completion_stream = response.iter_lines() + streaming_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="anthropic", + logging_obj=logging_obj, + ) + return streaming_response + + else: + response = requests.post( + api_base, headers=headers, data=json.dumps(data) + ) + if response.status_code != 200: + raise AnthropicError( + status_code=response.status_code, message=response.text + ) + return self.process_response( + model=model, + response=response, + model_response=model_response, + _is_function_call=_is_function_call, + stream=stream, + logging_obj=logging_obj, + api_key=api_key, + data=data, + messages=messages, + print_verbose=print_verbose, + ) + + def embedding(self): + # logic for parsing in - calling - parsing out model embedding calls + pass class ModelResponseIterator: @@ -509,8 +524,3 @@ class ModelResponseIterator: raise StopAsyncIteration self.is_done = True return self.model_response - - -def embedding(): - # logic for parsing in - calling - parsing out model embedding calls - pass diff --git a/litellm/main.py b/litellm/main.py index 75382b8ce..a387c9147 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -39,7 +39,6 @@ from litellm.utils import ( get_optional_params_image_gen, ) from .llms import ( - anthropic, anthropic_text, together_ai, ai21, @@ -68,6 +67,7 @@ from .llms import ( from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion from .llms.azure import AzureChatCompletion from .llms.azure_text import AzureTextCompletion +from .llms.anthropic import AnthropicChatCompletion from .llms.huggingface_restapi import Huggingface from .llms.prompt_templates.factory import ( prompt_factory, @@ -99,6 +99,7 @@ from litellm.utils import ( dotenv.load_dotenv() # Loading env variables using dotenv openai_chat_completions = OpenAIChatCompletion() openai_text_completions = OpenAITextCompletion() +anthropic_chat_completions = AnthropicChatCompletion() azure_chat_completions = AzureChatCompletion() azure_text_completions = AzureTextCompletion() huggingface = Huggingface() @@ -1181,7 +1182,7 @@ def completion( or get_secret("ANTHROPIC_API_BASE") or "https://api.anthropic.com/v1/messages" ) - response = anthropic.completion( + response = anthropic_chat_completions.completion( model=model, messages=messages, api_base=api_base,