Merge pull request #2879 from BerriAI/litellm_async_anthropic_api

[Feat] Async Anthropic API 97.5% lower median latency
Ishaan Jaff 2024-04-07 09:56:52 -07:00 committed by GitHub
commit a5aef6ec00
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 339 additions and 150 deletions
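For context, this change routes Anthropic requests through a new async HTTP path when litellm.acompletion is used, instead of wrapping the blocking call in a thread pool. A minimal sketch of the call pattern this enables, adapted from the streaming test in this diff (assumes ANTHROPIC_API_KEY is set in the environment):

import asyncio
import litellm

async def main():
    # Async streaming completion against an Anthropic chat model; the request
    # goes through the new AsyncHTTPHandler rather than a blocking requests.post.
    response = await litellm.acompletion(
        model="claude-3-opus-20240229",
        messages=[
            {"role": "system", "content": "Be helpful"},
            {"role": "user", "content": "What do you know?"},
        ],
        stream=True,
        max_tokens=10,
    )
    async for chunk in response:
        print(chunk)

asyncio.run(main())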

View file

@@ -7,7 +7,8 @@ from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx
@@ -15,6 +16,8 @@ class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
# constants from https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_constants.py
class AnthropicError(Exception):
def __init__(self, status_code, message):
@@ -36,7 +39,9 @@ class AnthropicConfig:
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
"""
max_tokens: Optional[int] = 4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
max_tokens: Optional[int] = (
4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
)
stop_sequences: Optional[list] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
@@ -46,7 +51,9 @@ class AnthropicConfig:
def __init__(
self,
max_tokens: Optional[int] = 4096, # You can pass in a value yourself or use the default value 4096
max_tokens: Optional[
int
] = 4096, # You can pass in a value yourself or use the default value 4096
stop_sequences: Optional[list] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
@@ -95,121 +102,23 @@ def validate_environment(api_key, user_headers):
return headers
def completion(
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
headers = validate_environment(api_key, headers)
_is_function_call = False
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
# Separate system prompt from rest of message
system_prompt_indices = []
system_prompt = ""
for idx, message in enumerate(messages):
if message["role"] == "system":
system_prompt += message["content"]
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
if len(system_prompt) > 0:
optional_params["system"] = system_prompt
# Format rest of message according to anthropic guidelines
try:
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
except Exception as e:
raise AnthropicError(status_code=400, message=str(e))
## Load Config
config = litellm.AnthropicConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
headers["anthropic-beta"] = "tools-2024-04-04"
anthropic_tools = []
for tool in optional_params["tools"]:
new_tool = tool["function"]
new_tool["input_schema"] = new_tool.pop("parameters") # rename key
anthropic_tools.append(new_tool)
optional_params["tools"] = anthropic_tools
stream = optional_params.pop("stream", None)
data = {
"model": model,
"messages": messages,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
print_verbose(f"_is_function_call: {_is_function_call}")
## COMPLETION CALL
if (
stream and not _is_function_call
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose("makes anthropic streaming POST request")
data["stream"] = stream
response = requests.post(
api_base,
headers=headers,
data=json.dumps(data),
stream=stream,
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
return response.iter_lines()
else:
response = requests.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
class AnthropicChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def process_response(
self,
model,
response,
model_response,
_is_function_call,
stream,
logging_obj,
api_key,
data,
messages,
print_verbose,
):
## LOGGING
logging_obj.post_call(
input=messages,
@@ -327,6 +236,272 @@ def completion(
model_response.usage = usage
return model_response
async def acompletion_stream_function(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
stream,
_is_function_call,
data=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic",
logging_obj=logging_obj,
)
return streamwrapper
async def acompletion_function(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
stream,
_is_function_call,
data=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
_is_function_call=_is_function_call,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
)
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params=None,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
):
headers = validate_environment(api_key, headers)
_is_function_call = False
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
# Separate system prompt from rest of message
system_prompt_indices = []
system_prompt = ""
for idx, message in enumerate(messages):
if message["role"] == "system":
system_prompt += message["content"]
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
if len(system_prompt) > 0:
optional_params["system"] = system_prompt
# Format rest of message according to anthropic guidelines
try:
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
except Exception as e:
raise AnthropicError(status_code=400, message=str(e))
## Load Config
config = litellm.AnthropicConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
headers["anthropic-beta"] = "tools-2024-04-04"
anthropic_tools = []
for tool in optional_params["tools"]:
new_tool = tool["function"]
new_tool["input_schema"] = new_tool.pop("parameters") # rename key
anthropic_tools.append(new_tool)
optional_params["tools"] = anthropic_tools
stream = optional_params.pop("stream", None)
data = {
"model": model,
"messages": messages,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
print_verbose(f"_is_function_call: {_is_function_call}")
if acompletion == True:
if (
stream and not _is_function_call
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose("makes async anthropic streaming POST request")
data["stream"] = stream
return self.acompletion_stream_function(
model=model,
messages=messages,
data=data,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
_is_function_call=_is_function_call,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
)
else:
return self.acompletion_function(
model=model,
messages=messages,
data=data,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
_is_function_call=_is_function_call,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
)
else:
## COMPLETION CALL
if (
stream and not _is_function_call
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose("makes anthropic streaming POST request")
data["stream"] = stream
response = requests.post(
api_base,
headers=headers,
data=json.dumps(data),
stream=stream,
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.iter_lines()
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic",
logging_obj=logging_obj,
)
return streaming_response
else:
response = requests.post(
api_base, headers=headers, data=json.dumps(data)
)
if response.status_code != 200:
raise AnthropicError(
status_code=response.status_code, message=response.text
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
_is_function_call=_is_function_call,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
)
def embedding(self):
# logic for parsing in - calling - parsing out model embedding calls
pass
class ModelResponseIterator:
def __init__(self, model_response):
@@ -352,8 +527,3 @@ class ModelResponseIterator:
raise StopAsyncIteration
self.is_done = True
return self.model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
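For reference, the tool-calling block in AnthropicChatCompletion.completion above renames the OpenAI-style "parameters" field to Anthropic's "input_schema" and opts into the tools beta via the anthropic-beta header. A small illustration of that transform; the weather tool here is a hypothetical example, not part of this commit:

import copy

openai_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool, for illustration only
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

# Mirrors the loop above: rename "parameters" to "input_schema"
new_tool = copy.deepcopy(openai_tool)["function"]
new_tool["input_schema"] = new_tool.pop("parameters")
anthropic_tools = [new_tool]
headers = {"anthropic-beta": "tools-2024-04-04"}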

View file

@@ -4,7 +4,7 @@ from enum import Enum
import requests
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
import httpx
@@ -162,8 +162,15 @@ def completion(
raise AnthropicError(
status_code=response.status_code, message=response.text
)
completion_stream = response.iter_lines()
stream_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="anthropic",
logging_obj=logging_obj,
)
return stream_response
return response.iter_lines()
else:
response = requests.post(api_base, headers=headers, data=json.dumps(data))
if response.status_code != 200:

View file

@@ -1,21 +1,34 @@
import httpx, asyncio
from typing import Optional
from typing import Optional, Union, Mapping, Any
# https://www.python-httpx.org/advanced/timeouts
_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
class AsyncHTTPHandler:
def __init__(self, concurrent_limit=1000):
def __init__(
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
):
# Create a client with a connection pool
self.client = httpx.AsyncClient(
timeout=timeout,
limits=httpx.Limits(
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
)
),
)
async def close(self):
# Close the client when you're done with it
await self.client.aclose()
async def __aenter__(self):
return self.client
async def __aexit__(self, exc_type, exc_value, traceback):
# close the client when exiting
await self.client.aclose()
async def get(
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
):
@@ -25,12 +38,15 @@ class AsyncHTTPHandler:
async def post(
self,
url: str,
data: Optional[dict] = None,
data: Optional[Union[dict, str]] = None, # type: ignore
params: Optional[dict] = None,
headers: Optional[dict] = None,
):
response = await self.client.post(
url, data=data, params=params, headers=headers
url,
data=data, # type: ignore
params=params,
headers=headers,
)
return response
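As used by the new Anthropic code path, AsyncHTTPHandler is a thin wrapper over a pooled httpx.AsyncClient. A minimal usage sketch with the same construction the Anthropic integration uses (the URL and payload below are placeholders):

import asyncio
import json
import httpx
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

async def main():
    # Long read timeout, 5s connect timeout, mirroring the Anthropic call site
    handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
    response = await handler.post(
        "https://example.com/v1/messages",  # placeholder URL
        headers={"content-type": "application/json"},
        data=json.dumps({"hello": "world"}),  # placeholder payload
    )
    print(response.status_code)
    await handler.close()

asyncio.run(main())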

View file

@@ -39,7 +39,6 @@ from litellm.utils import (
get_optional_params_image_gen,
)
from .llms import (
anthropic,
anthropic_text,
together_ai,
ai21,
@@ -68,6 +67,7 @@ from .llms import (
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.azure import AzureChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.anthropic import AnthropicChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.prompt_templates.factory import (
prompt_factory,
@@ -99,6 +99,7 @@ from litellm.utils import (
dotenv.load_dotenv() # Loading env variables using dotenv
openai_chat_completions = OpenAIChatCompletion()
openai_text_completions = OpenAITextCompletion()
anthropic_chat_completions = AnthropicChatCompletion()
azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion()
huggingface = Huggingface()
@@ -304,6 +305,7 @@ async def acompletion(
or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "gemini"
or custom_llm_provider == "sagemaker"
or custom_llm_provider == "anthropic"
or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
@@ -1180,10 +1182,11 @@ def completion(
or get_secret("ANTHROPIC_API_BASE")
or "https://api.anthropic.com/v1/messages"
)
response = anthropic.completion(
response = anthropic_chat_completions.completion(
model=model,
messages=messages,
api_base=api_base,
acompletion=acompletion,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
@@ -1195,19 +1198,6 @@ def completion(
logging_obj=logging,
headers=headers,
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and not isinstance(response, CustomStreamWrapper)
):
# don't try to access stream object,
response = CustomStreamWrapper(
response,
model,
custom_llm_provider="anthropic",
logging_obj=logging,
)
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
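Since the Anthropic integration now returns an already-wrapped CustomStreamWrapper (see the anthropic.py hunks above), main.py no longer re-wraps the response; synchronous callers keep iterating the returned object as before. A rough sketch of that unchanged caller-side pattern, with an illustrative prompt:

import litellm

response = litellm.completion(
    model="claude-3-opus-20240229",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
    max_tokens=10,
)
# The streaming object returned for the anthropic provider is already a CustomStreamWrapper
for chunk in response:
    print(chunk)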

View file

@@ -831,22 +831,25 @@ def test_bedrock_claude_3_streaming():
pytest.fail(f"Error occurred: {e}")
def test_claude_3_streaming_finish_reason():
@pytest.mark.asyncio
async def test_claude_3_streaming_finish_reason():
try:
litellm.set_verbose = True
messages = [
{"role": "system", "content": "Be helpful"},
{"role": "user", "content": "What do you know?"},
]
response: ModelResponse = completion( # type: ignore
response: ModelResponse = await litellm.acompletion( # type: ignore
model="claude-3-opus-20240229",
messages=messages,
stream=True,
max_tokens=10,
)
complete_response = ""
# Add any assertions here to check the response
num_finish_reason = 0
for idx, chunk in enumerate(response):
async for chunk in response:
print(f"chunk: {chunk}")
if isinstance(chunk, ModelResponse):
if chunk.choices[0].finish_reason is not None:
num_finish_reason += 1
@@ -2285,7 +2288,7 @@ async def test_acompletion_claude_3_function_call_with_streaming():
elif chunk.choices[0].finish_reason is not None: # last chunk
validate_final_streaming_function_calling_chunk(chunk=chunk)
idx += 1
# raise Exception("it worked!")
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@@ -8764,7 +8764,9 @@ class CustomStreamWrapper:
return hold, curr_chunk
def handle_anthropic_chunk(self, chunk):
str_line = chunk.decode("utf-8") # Convert bytes to string
str_line = chunk
if isinstance(chunk, bytes): # Handle binary data
str_line = chunk.decode("utf-8") # Convert bytes to string
text = "" text = ""
is_finished = False is_finished = False
finish_reason = None finish_reason = None
@ -10024,6 +10026,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "custom_openai" or self.custom_llm_provider == "custom_openai"
or self.custom_llm_provider == "text-completion-openai" or self.custom_llm_provider == "text-completion-openai"
or self.custom_llm_provider == "azure_text" or self.custom_llm_provider == "azure_text"
or self.custom_llm_provider == "anthropic"
or self.custom_llm_provider == "huggingface"
or self.custom_llm_provider == "ollama"
or self.custom_llm_provider == "ollama_chat"