diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py index b7b078b9b..d836ed8db 100644 --- a/litellm/llms/anthropic.py +++ b/litellm/llms/anthropic.py @@ -7,7 +7,8 @@ from typing import Callable, Optional, List from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper import litellm from .prompt_templates.factory import prompt_factory, custom_prompt - +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler +from .base import BaseLLM import httpx @@ -15,6 +16,8 @@ class AnthropicConstants(Enum): HUMAN_PROMPT = "\n\nHuman: " AI_PROMPT = "\n\nAssistant: " + # constants from https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_constants.py + class AnthropicError(Exception): def __init__(self, status_code, message): @@ -36,7 +39,9 @@ class AnthropicConfig: to pass metadata to anthropic, it's {"user_id": "any-relevant-information"} """ - max_tokens: Optional[int] = 4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default) + max_tokens: Optional[int] = ( + 4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default) + ) stop_sequences: Optional[list] = None temperature: Optional[int] = None top_p: Optional[int] = None @@ -46,7 +51,9 @@ class AnthropicConfig: def __init__( self, - max_tokens: Optional[int] = 4096, # You can pass in a value yourself or use the default value 4096 + max_tokens: Optional[ + int + ] = 4096, # You can pass in a value yourself or use the default value 4096 stop_sequences: Optional[list] = None, temperature: Optional[int] = None, top_p: Optional[int] = None, @@ -95,121 +102,23 @@ def validate_environment(api_key, user_headers): return headers -def completion( - model: str, - messages: list, - api_base: str, - custom_prompt_dict: dict, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - api_key, - logging_obj, - optional_params=None, - litellm_params=None, - logger_fn=None, - headers={}, -): - headers = validate_environment(api_key, headers) - _is_function_call = False - messages = copy.deepcopy(messages) - optional_params = copy.deepcopy(optional_params) - if model in custom_prompt_dict: - # check if the model has a registered custom prompt - model_prompt_details = custom_prompt_dict[model] - prompt = custom_prompt( - role_dict=model_prompt_details["roles"], - initial_prompt_value=model_prompt_details["initial_prompt_value"], - final_prompt_value=model_prompt_details["final_prompt_value"], - messages=messages, - ) - else: - # Separate system prompt from rest of message - system_prompt_indices = [] - system_prompt = "" - for idx, message in enumerate(messages): - if message["role"] == "system": - system_prompt += message["content"] - system_prompt_indices.append(idx) - if len(system_prompt_indices) > 0: - for idx in reversed(system_prompt_indices): - messages.pop(idx) - if len(system_prompt) > 0: - optional_params["system"] = system_prompt - # Format rest of message according to anthropic guidelines - try: - messages = prompt_factory( - model=model, messages=messages, custom_llm_provider="anthropic" - ) - except Exception as e: - raise AnthropicError(status_code=400, message=str(e)) - - ## Load Config - config = litellm.AnthropicConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - ## Handle Tool Calling - if "tools" in optional_params: - _is_function_call = True - headers["anthropic-beta"] = "tools-2024-04-04" - - anthropic_tools = [] - for tool in optional_params["tools"]: - new_tool = tool["function"] - new_tool["input_schema"] = new_tool.pop("parameters") # rename key - anthropic_tools.append(new_tool) - - optional_params["tools"] = anthropic_tools - - stream = optional_params.pop("stream", None) - - data = { - "model": model, - "messages": messages, - **optional_params, - } - - ## LOGGING - logging_obj.pre_call( - input=messages, - api_key=api_key, - additional_args={ - "complete_input_dict": data, - "api_base": api_base, - "headers": headers, - }, - ) - print_verbose(f"_is_function_call: {_is_function_call}") - ## COMPLETION CALL - if ( - stream and not _is_function_call - ): # if function call - fake the streaming (need complete blocks for output parsing in openai format) - print_verbose("makes anthropic streaming POST request") - data["stream"] = stream - response = requests.post( - api_base, - headers=headers, - data=json.dumps(data), - stream=stream, - ) - - if response.status_code != 200: - raise AnthropicError( - status_code=response.status_code, message=response.text - ) - - return response.iter_lines() - else: - response = requests.post(api_base, headers=headers, data=json.dumps(data)) - if response.status_code != 200: - raise AnthropicError( - status_code=response.status_code, message=response.text - ) +class AnthropicChatCompletion(BaseLLM): + def __init__(self) -> None: + super().__init__() + def process_response( + self, + model, + response, + model_response, + _is_function_call, + stream, + logging_obj, + api_key, + data, + messages, + print_verbose, + ): ## LOGGING logging_obj.post_call( input=messages, @@ -327,6 +236,272 @@ def completion( model_response.usage = usage return model_response + async def acompletion_stream_function( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + stream, + _is_function_call, + data=None, + optional_params=None, + litellm_params=None, + logger_fn=None, + headers={}, + ): + self.async_handler = AsyncHTTPHandler( + timeout=httpx.Timeout(timeout=600.0, connect=5.0) + ) + response = await self.async_handler.post( + api_base, headers=headers, data=json.dumps(data) + ) + + if response.status_code != 200: + raise AnthropicError( + status_code=response.status_code, message=response.text + ) + + completion_stream = response.aiter_lines() + + streamwrapper = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="anthropic", + logging_obj=logging_obj, + ) + return streamwrapper + + async def acompletion_function( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + stream, + _is_function_call, + data=None, + optional_params=None, + litellm_params=None, + logger_fn=None, + headers={}, + ): + self.async_handler = AsyncHTTPHandler( + timeout=httpx.Timeout(timeout=600.0, connect=5.0) + ) + response = await self.async_handler.post( + api_base, headers=headers, data=json.dumps(data) + ) + return self.process_response( + model=model, + response=response, + model_response=model_response, + _is_function_call=_is_function_call, + stream=stream, + logging_obj=logging_obj, + api_key=api_key, + data=data, + messages=messages, + print_verbose=print_verbose, + ) + + def completion( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + optional_params=None, + acompletion=None, + litellm_params=None, + logger_fn=None, + headers={}, + ): + headers = validate_environment(api_key, headers) + _is_function_call = False + messages = copy.deepcopy(messages) + optional_params = copy.deepcopy(optional_params) + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + prompt = custom_prompt( + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=messages, + ) + else: + # Separate system prompt from rest of message + system_prompt_indices = [] + system_prompt = "" + for idx, message in enumerate(messages): + if message["role"] == "system": + system_prompt += message["content"] + system_prompt_indices.append(idx) + if len(system_prompt_indices) > 0: + for idx in reversed(system_prompt_indices): + messages.pop(idx) + if len(system_prompt) > 0: + optional_params["system"] = system_prompt + # Format rest of message according to anthropic guidelines + try: + messages = prompt_factory( + model=model, messages=messages, custom_llm_provider="anthropic" + ) + except Exception as e: + raise AnthropicError(status_code=400, message=str(e)) + + ## Load Config + config = litellm.AnthropicConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + ## Handle Tool Calling + if "tools" in optional_params: + _is_function_call = True + headers["anthropic-beta"] = "tools-2024-04-04" + + anthropic_tools = [] + for tool in optional_params["tools"]: + new_tool = tool["function"] + new_tool["input_schema"] = new_tool.pop("parameters") # rename key + anthropic_tools.append(new_tool) + + optional_params["tools"] = anthropic_tools + + stream = optional_params.pop("stream", None) + + data = { + "model": model, + "messages": messages, + **optional_params, + } + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "api_base": api_base, + "headers": headers, + }, + ) + print_verbose(f"_is_function_call: {_is_function_call}") + if acompletion == True: + if ( + stream and not _is_function_call + ): # if function call - fake the streaming (need complete blocks for output parsing in openai format) + print_verbose("makes async anthropic streaming POST request") + data["stream"] = stream + return self.acompletion_stream_function( + model=model, + messages=messages, + data=data, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + stream=stream, + _is_function_call=_is_function_call, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + ) + else: + return self.acompletion_function( + model=model, + messages=messages, + data=data, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + stream=stream, + _is_function_call=_is_function_call, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + ) + else: + ## COMPLETION CALL + if ( + stream and not _is_function_call + ): # if function call - fake the streaming (need complete blocks for output parsing in openai format) + print_verbose("makes anthropic streaming POST request") + data["stream"] = stream + response = requests.post( + api_base, + headers=headers, + data=json.dumps(data), + stream=stream, + ) + + if response.status_code != 200: + raise AnthropicError( + status_code=response.status_code, message=response.text + ) + + completion_stream = response.iter_lines() + streaming_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="anthropic", + logging_obj=logging_obj, + ) + return streaming_response + + else: + response = requests.post( + api_base, headers=headers, data=json.dumps(data) + ) + if response.status_code != 200: + raise AnthropicError( + status_code=response.status_code, message=response.text + ) + return self.process_response( + model=model, + response=response, + model_response=model_response, + _is_function_call=_is_function_call, + stream=stream, + logging_obj=logging_obj, + api_key=api_key, + data=data, + messages=messages, + print_verbose=print_verbose, + ) + + def embedding(self): + # logic for parsing in - calling - parsing out model embedding calls + pass + class ModelResponseIterator: def __init__(self, model_response): @@ -352,8 +527,3 @@ class ModelResponseIterator: raise StopAsyncIteration self.is_done = True return self.model_response - - -def embedding(): - # logic for parsing in - calling - parsing out model embedding calls - pass diff --git a/litellm/llms/anthropic_text.py b/litellm/llms/anthropic_text.py index bccc8c769..c9a9adfc2 100644 --- a/litellm/llms/anthropic_text.py +++ b/litellm/llms/anthropic_text.py @@ -4,7 +4,7 @@ from enum import Enum import requests import time from typing import Callable, Optional -from litellm.utils import ModelResponse, Usage +from litellm.utils import ModelResponse, Usage, CustomStreamWrapper import litellm from .prompt_templates.factory import prompt_factory, custom_prompt import httpx @@ -162,8 +162,15 @@ def completion( raise AnthropicError( status_code=response.status_code, message=response.text ) + completion_stream = response.iter_lines() + stream_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="anthropic", + logging_obj=logging_obj, + ) + return stream_response - return response.iter_lines() else: response = requests.post(api_base, headers=headers, data=json.dumps(data)) if response.status_code != 200: diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index 10314d831..67e6c80da 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -1,21 +1,34 @@ import httpx, asyncio -from typing import Optional +from typing import Optional, Union, Mapping, Any + +# https://www.python-httpx.org/advanced/timeouts +_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0) class AsyncHTTPHandler: - def __init__(self, concurrent_limit=1000): + def __init__( + self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000 + ): # Create a client with a connection pool self.client = httpx.AsyncClient( + timeout=timeout, limits=httpx.Limits( max_connections=concurrent_limit, max_keepalive_connections=concurrent_limit, - ) + ), ) async def close(self): # Close the client when you're done with it await self.client.aclose() + async def __aenter__(self): + return self.client + + async def __aexit__(self): + # close the client when exiting + await self.client.aclose() + async def get( self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None ): @@ -25,12 +38,15 @@ class AsyncHTTPHandler: async def post( self, url: str, - data: Optional[dict] = None, + data: Optional[Union[dict, str]] = None, # type: ignore params: Optional[dict] = None, headers: Optional[dict] = None, ): response = await self.client.post( - url, data=data, params=params, headers=headers + url, + data=data, # type: ignore + params=params, + headers=headers, ) return response diff --git a/litellm/main.py b/litellm/main.py index f07def97c..1ee16f36f 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -39,7 +39,6 @@ from litellm.utils import ( get_optional_params_image_gen, ) from .llms import ( - anthropic, anthropic_text, together_ai, ai21, @@ -68,6 +67,7 @@ from .llms import ( from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion from .llms.azure import AzureChatCompletion from .llms.azure_text import AzureTextCompletion +from .llms.anthropic import AnthropicChatCompletion from .llms.huggingface_restapi import Huggingface from .llms.prompt_templates.factory import ( prompt_factory, @@ -99,6 +99,7 @@ from litellm.utils import ( dotenv.load_dotenv() # Loading env variables using dotenv openai_chat_completions = OpenAIChatCompletion() openai_text_completions = OpenAITextCompletion() +anthropic_chat_completions = AnthropicChatCompletion() azure_chat_completions = AzureChatCompletion() azure_text_completions = AzureTextCompletion() huggingface = Huggingface() @@ -304,6 +305,7 @@ async def acompletion( or custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini" or custom_llm_provider == "sagemaker" + or custom_llm_provider == "anthropic" or custom_llm_provider in litellm.openai_compatible_providers ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. init_response = await loop.run_in_executor(None, func_with_context) @@ -1180,10 +1182,11 @@ def completion( or get_secret("ANTHROPIC_API_BASE") or "https://api.anthropic.com/v1/messages" ) - response = anthropic.completion( + response = anthropic_chat_completions.completion( model=model, messages=messages, api_base=api_base, + acompletion=acompletion, custom_prompt_dict=litellm.custom_prompt_dict, model_response=model_response, print_verbose=print_verbose, @@ -1195,19 +1198,6 @@ def completion( logging_obj=logging, headers=headers, ) - if ( - "stream" in optional_params - and optional_params["stream"] == True - and not isinstance(response, CustomStreamWrapper) - ): - # don't try to access stream object, - response = CustomStreamWrapper( - response, - model, - custom_llm_provider="anthropic", - logging_obj=logging, - ) - if optional_params.get("stream", False) or acompletion == True: ## LOGGING logging.post_call( diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 7451f94ab..988a1a8e9 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -831,22 +831,25 @@ def test_bedrock_claude_3_streaming(): pytest.fail(f"Error occurred: {e}") -def test_claude_3_streaming_finish_reason(): +@pytest.mark.asyncio +async def test_claude_3_streaming_finish_reason(): try: litellm.set_verbose = True messages = [ {"role": "system", "content": "Be helpful"}, {"role": "user", "content": "What do you know?"}, ] - response: ModelResponse = completion( # type: ignore + response: ModelResponse = await litellm.acompletion( # type: ignore model="claude-3-opus-20240229", messages=messages, stream=True, + max_tokens=10, ) complete_response = "" - # Add any assertions here to check the response + # Add any assertions here to-check the response num_finish_reason = 0 - for idx, chunk in enumerate(response): + async for chunk in response: + print(f"chunk: {chunk}") if isinstance(chunk, ModelResponse): if chunk.choices[0].finish_reason is not None: num_finish_reason += 1 @@ -2285,7 +2288,7 @@ async def test_acompletion_claude_3_function_call_with_streaming(): elif chunk.choices[0].finish_reason is not None: # last chunk validate_final_streaming_function_calling_chunk(chunk=chunk) idx += 1 - # raise Exception("it worked!") + # raise Exception("it worked! ") except Exception as e: pytest.fail(f"Error occurred: {e}") diff --git a/litellm/utils.py b/litellm/utils.py index 1f6258c53..6a58d56db 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -8764,7 +8764,9 @@ class CustomStreamWrapper: return hold, curr_chunk def handle_anthropic_chunk(self, chunk): - str_line = chunk.decode("utf-8") # Convert bytes to string + str_line = chunk + if isinstance(chunk, bytes): # Handle binary data + str_line = chunk.decode("utf-8") # Convert bytes to string text = "" is_finished = False finish_reason = None @@ -10024,6 +10026,7 @@ class CustomStreamWrapper: or self.custom_llm_provider == "custom_openai" or self.custom_llm_provider == "text-completion-openai" or self.custom_llm_provider == "azure_text" + or self.custom_llm_provider == "anthropic" or self.custom_llm_provider == "huggingface" or self.custom_llm_provider == "ollama" or self.custom_llm_provider == "ollama_chat"