fix(bedrock_httpx.py): move anthropic bedrock calls to httpx

Fixes https://github.com/BerriAI/litellm/issues/2921
Krrish Dholakia 2024-05-16 21:51:55 -07:00
parent 10a672634d
commit 180bc46ca4
7 changed files with 298 additions and 56 deletions
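
For context, this change routes Bedrock cohere and anthropic claude-3 calls (sync and async) through the new httpx-based handler. A minimal usage sketch, not part of the diff; model id, prompt, and credential setup are illustrative:

# Minimal sketch of the call path this commit targets (model id and prompt are
# illustrative; AWS credentials are assumed to be configured in the environment).
import asyncio
import litellm

messages = [{"role": "user", "content": "Hey! how's it going?"}]

# Sync: bedrock + anthropic claude-3 now goes through the httpx-based handler.
response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=messages,
)
print(response.choices[0].message.content)

# Async: acompletion routes the same models through the async httpx client.
async def main():
    response = await litellm.acompletion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=messages,
    )
    print(response.choices[0].message.content)

asyncio.run(main())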


@@ -21,7 +21,7 @@ class BaseLLM:
         messages: list,
         print_verbose,
         encoding,
-    ) -> litellm.utils.ModelResponse:
+    ) -> Union[litellm.utils.ModelResponse, litellm.utils.CustomStreamWrapper]:
         """
         Helper function to process the response across sync + async completion calls
         """


@@ -1,6 +1,6 @@
 # What is this?
 ## Initial implementation of calling bedrock via httpx client (allows for async calls).
-## V0 - just covers cohere command-r support
+## V1 - covers cohere + anthropic claude-3 support
 import os, types
 import json
@@ -29,12 +29,20 @@ from litellm.utils import (
     get_secret,
     Logging,
 )
-import litellm
-from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
+import litellm, uuid
+from .prompt_templates.factory import (
+    prompt_factory,
+    custom_prompt,
+    cohere_message_pt,
+    construct_tool_use_system_prompt,
+    extract_between_tags,
+    parse_xml_params,
+    contains_tag,
+)
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from .base import BaseLLM
 import httpx  # type: ignore
-from .bedrock import BedrockError, convert_messages_to_prompt
+from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
 from litellm.types.llms.bedrock import *
@@ -280,7 +288,8 @@ class BedrockLLM(BaseLLM):
         messages: List,
         print_verbose,
         encoding,
-    ) -> ModelResponse:
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
+        provider = model.split(".")[0]
         ## LOGGING
         logging_obj.post_call(
             input=messages,
@@ -297,26 +306,147 @@ class BedrockLLM(BaseLLM):
             raise BedrockError(message=response.text, status_code=422)
 
         try:
-            model_response.choices[0].message.content = completion_response["text"]  # type: ignore
+            if provider == "cohere":
+                model_response.choices[0].message.content = completion_response["text"]  # type: ignore
+            elif provider == "anthropic":
+                if model.startswith("anthropic.claude-3"):
+                    json_schemas: dict = {}
+                    _is_function_call = False
+                    ## Handle Tool Calling
+                    if "tools" in optional_params:
+                        _is_function_call = True
+                        for tool in optional_params["tools"]:
+                            json_schemas[tool["function"]["name"]] = tool[
+                                "function"
+                            ].get("parameters", None)
+                    outputText = completion_response.get("content")[0].get("text", None)
+                    if outputText is not None and contains_tag(
+                        "invoke", outputText
+                    ):  # OUTPUT PARSE FUNCTION CALL
+                        function_name = extract_between_tags("tool_name", outputText)[0]
+                        function_arguments_str = extract_between_tags(
+                            "invoke", outputText
+                        )[0].strip()
+                        function_arguments_str = (
+                            f"<invoke>{function_arguments_str}</invoke>"
+                        )
+                        function_arguments = parse_xml_params(
+                            function_arguments_str,
+                            json_schema=json_schemas.get(
+                                function_name, None
+                            ),  # check if we have a json schema for this function name
+                        )
+                        _message = litellm.Message(
+                            tool_calls=[
+                                {
+                                    "id": f"call_{uuid.uuid4()}",
+                                    "type": "function",
+                                    "function": {
+                                        "name": function_name,
+                                        "arguments": json.dumps(function_arguments),
+                                    },
+                                }
+                            ],
+                            content=None,
+                        )
+                        model_response.choices[0].message = _message  # type: ignore
+                        model_response._hidden_params["original_response"] = (
+                            outputText  # allow user to access raw anthropic tool calling response
+                        )
+                    if (
+                        _is_function_call == True
+                        and stream is not None
+                        and stream == True
+                    ):
+                        print_verbose(
+                            f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
+                        )
+                        # return an iterator
+                        streaming_model_response = ModelResponse(stream=True)
+                        streaming_model_response.choices[0].finish_reason = getattr(
+                            model_response.choices[0], "finish_reason", "stop"
+                        )
+                        # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+                        streaming_choice = litellm.utils.StreamingChoices()
+                        streaming_choice.index = model_response.choices[0].index
+                        _tool_calls = []
+                        print_verbose(
+                            f"type of model_response.choices[0]: {type(model_response.choices[0])}"
+                        )
+                        print_verbose(
+                            f"type of streaming_choice: {type(streaming_choice)}"
+                        )
+                        if isinstance(model_response.choices[0], litellm.Choices):
+                            if getattr(
+                                model_response.choices[0].message, "tool_calls", None
+                            ) is not None and isinstance(
+                                model_response.choices[0].message.tool_calls, list
+                            ):
+                                for tool_call in model_response.choices[
+                                    0
+                                ].message.tool_calls:
+                                    _tool_call = {**tool_call.dict(), "index": 0}
+                                    _tool_calls.append(_tool_call)
+                            delta_obj = litellm.utils.Delta(
+                                content=getattr(
+                                    model_response.choices[0].message, "content", None
+                                ),
+                                role=model_response.choices[0].message.role,
+                                tool_calls=_tool_calls,
+                            )
+                            streaming_choice.delta = delta_obj
+                            streaming_model_response.choices = [streaming_choice]
+                            completion_stream = ModelResponseIterator(
+                                model_response=streaming_model_response
+                            )
+                            print_verbose(
+                                f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
+                            )
+                            return litellm.CustomStreamWrapper(
+                                completion_stream=completion_stream,
+                                model=model,
+                                custom_llm_provider="cached_response",
+                                logging_obj=logging_obj,
+                            )
+                    model_response["finish_reason"] = map_finish_reason(
+                        completion_response.get("stop_reason", "")
+                    )
+                    _usage = litellm.Usage(
+                        prompt_tokens=completion_response["usage"]["input_tokens"],
+                        completion_tokens=completion_response["usage"]["output_tokens"],
+                        total_tokens=completion_response["usage"]["input_tokens"]
+                        + completion_response["usage"]["output_tokens"],
+                    )
+                    setattr(model_response, "usage", _usage)
+                else:
+                    outputText = completion_response["completion"]
+                    model_response["finish_reason"] = completion_response["stop_reason"]
         except Exception as e:
-            raise BedrockError(message=response.text, status_code=422)
+            raise BedrockError(
+                message="Error processing={}, Received error={}".format(
+                    response.text, str(e)
+                ),
+                status_code=422,
+            )
 
         ## CALCULATING USAGE - bedrock returns usage in the headers
+        bedrock_input_tokens = response.headers.get(
+            "x-amzn-bedrock-input-token-count", None
+        )
+        bedrock_output_tokens = response.headers.get(
+            "x-amzn-bedrock-output-token-count", None
+        )
         prompt_tokens = int(
-            response.headers.get(
-                "x-amzn-bedrock-input-token-count",
-                len(encoding.encode("".join(m.get("content", "") for m in messages))),
-            )
+            bedrock_input_tokens or litellm.token_counter(messages=messages)
         )
         completion_tokens = int(
-            response.headers.get(
-                "x-amzn-bedrock-output-token-count",
-                len(
-                    encoding.encode(
-                        model_response.choices[0].message.content,  # type: ignore
-                        disallowed_special=(),
-                    )
-                ),
-            )
+            bedrock_output_tokens
+            or litellm.token_counter(
+                text=model_response.choices[0].message.content,  # type: ignore
+                count_response_tokens=True,
+            )
         )
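
The tool-calling branch above parses Claude's XML-style function-call output (via contains_tag / extract_between_tags / parse_xml_params) into an OpenAI-style tool_calls entry. A simplified, self-contained sketch of that parsing step, using stdlib re and a made-up model output instead of the factory helpers:

# Simplified illustration of the XML tool-call parsing performed above.
# The real code uses contains_tag/extract_between_tags/parse_xml_params from
# prompt_templates.factory; this sketch uses stdlib re and a hand-written output.
import json
import re
import uuid

output_text = (
    "<invoke><tool_name>get_weather</tool_name>"
    "<parameters><city>Paris</city></parameters></invoke>"
)

def extract_between(tag: str, text: str) -> list:
    return re.findall(f"<{tag}>(.*?)</{tag}>", text, re.DOTALL)

if extract_between("invoke", output_text):  # model asked to call a tool
    function_name = extract_between("tool_name", output_text)[0]
    params_block = extract_between("parameters", output_text)[0]
    function_arguments = {
        m.group(1): m.group(2)
        for m in re.finditer(r"<(\w+)>(.*?)</\1>", params_block, re.DOTALL)
    }
    tool_call = {
        "id": f"call_{uuid.uuid4()}",
        "type": "function",
        "function": {
            "name": function_name,
            "arguments": json.dumps(function_arguments),
        },
    }
    print(tool_call)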
@@ -426,7 +556,7 @@ class BedrockLLM(BaseLLM):
             model, messages, provider, custom_prompt_dict
         )
         inference_params = copy.deepcopy(optional_params)
+        json_schemas: dict = {}
         if provider == "cohere":
             if model.startswith("cohere.command-r"):
                 ## LOAD CONFIG
@@ -453,6 +583,56 @@ class BedrockLLM(BaseLLM):
                     True  # cohere requires stream = True in inference params
                 )
                 data = json.dumps({"prompt": prompt, **inference_params})
+        elif provider == "anthropic":
+            if model.startswith("anthropic.claude-3"):
+                # Separate system prompt from rest of message
+                system_prompt_idx: list[int] = []
+                system_messages: list[str] = []
+                for idx, message in enumerate(messages):
+                    if message["role"] == "system":
+                        system_messages.append(message["content"])
+                        system_prompt_idx.append(idx)
+                if len(system_prompt_idx) > 0:
+                    inference_params["system"] = "\n".join(system_messages)
+                    messages = [
+                        i for j, i in enumerate(messages) if j not in system_prompt_idx
+                    ]
+                # Format rest of message according to anthropic guidelines
+                messages = prompt_factory(
+                    model=model, messages=messages, custom_llm_provider="anthropic_xml"
+                )  # type: ignore
+                ## LOAD CONFIG
+                config = litellm.AmazonAnthropicClaude3Config.get_config()
+                for k, v in config.items():
+                    if (
+                        k not in inference_params
+                    ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+                        inference_params[k] = v
+                ## Handle Tool Calling
+                if "tools" in inference_params:
+                    _is_function_call = True
+                    for tool in inference_params["tools"]:
+                        json_schemas[tool["function"]["name"]] = tool["function"].get(
+                            "parameters", None
+                        )
+                    tool_calling_system_prompt = construct_tool_use_system_prompt(
+                        tools=inference_params["tools"]
+                    )
+                    inference_params["system"] = (
+                        inference_params.get("system", "\n")
+                        + tool_calling_system_prompt
+                    )  # add the anthropic tool calling prompt to the system prompt
+                    inference_params.pop("tools")
+                data = json.dumps({"messages": messages, **inference_params})
+            else:
+                ## LOAD CONFIG
+                config = litellm.AmazonAnthropicConfig.get_config()
+                for k, v in config.items():
+                    if (
+                        k not in inference_params
+                    ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+                        inference_params[k] = v
+                data = json.dumps({"prompt": prompt, **inference_params})
         else:
             raise Exception("UNSUPPORTED PROVIDER")
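
The request-building branch above pulls system messages out of the chat history and sends them as a top-level "system" field in the claude-3 request body. A minimal sketch of that payload assembly under assumed inference params; the real path also merges AmazonAnthropicClaude3Config defaults and formats messages via prompt_factory:

# Sketch of the request body assembled above for anthropic.claude-3 models.
# Illustrative only: messages and inference_params are made up for this example.
import json

messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Hey! how's it going?"},
]
inference_params = {"max_tokens": 256, "anthropic_version": "bedrock-2023-05-31"}

# Pull system messages out of the list; Bedrock's claude-3 messages API takes
# them as a top-level "system" field rather than as chat messages.
system_messages = [m["content"] for m in messages if m["role"] == "system"]
chat_messages = [m for m in messages if m["role"] != "system"]
if system_messages:
    inference_params["system"] = "\n".join(system_messages)

data = json.dumps({"messages": chat_messages, **inference_params})
print(data)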
@@ -541,7 +721,7 @@ class BedrockLLM(BaseLLM):
                     status_code=response.status_code, message=response.text
                 )
-            decoder = AWSEventStreamDecoder()
+            decoder = AWSEventStreamDecoder(model=model)
             completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
             streaming_response = CustomStreamWrapper(
@@ -591,7 +771,7 @@ class BedrockLLM(BaseLLM):
         logger_fn=None,
         headers={},
         client: Optional[AsyncHTTPHandler] = None,
-    ) -> ModelResponse:
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
         if client is None:
             _params = {}
             if timeout is not None:
@@ -650,7 +830,7 @@ class BedrockLLM(BaseLLM):
         if response.status_code != 200:
             raise BedrockError(status_code=response.status_code, message=response.text)
-        decoder = AWSEventStreamDecoder()
+        decoder = AWSEventStreamDecoder(model=model)
         completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
         streaming_response = CustomStreamWrapper(
@@ -676,11 +856,70 @@ def get_response_stream_shape():
 class AWSEventStreamDecoder:
-    def __init__(self) -> None:
+    def __init__(self, model: str) -> None:
         from botocore.parsers import EventStreamJSONParser
 
+        self.model = model
         self.parser = EventStreamJSONParser()
 
+    def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+        text = ""
+        is_finished = False
+        finish_reason = ""
+        if "outputText" in chunk_data:
+            text = chunk_data["outputText"]
+            # ai21 mapping
+            if "ai21" in self.model:  # fake ai21 streaming
+                text = chunk_data.get("completions")[0].get("data").get("text")  # type: ignore
+                is_finished = True
+                finish_reason = "stop"
+        ######## bedrock.anthropic mappings ###############
+        elif "completion" in chunk_data:  # not claude-3
+            text = chunk_data["completion"]  # bedrock.anthropic
+            stop_reason = chunk_data.get("stop_reason", None)
+            if stop_reason != None:
+                is_finished = True
+                finish_reason = stop_reason
+        elif "delta" in chunk_data:
+            if chunk_data["delta"].get("text", None) is not None:
+                text = chunk_data["delta"]["text"]
+            stop_reason = chunk_data["delta"].get("stop_reason", None)
+            if stop_reason != None:
+                is_finished = True
+                finish_reason = stop_reason
+        ######## bedrock.mistral mappings ###############
+        elif "outputs" in chunk_data:
+            if (
+                len(chunk_data["outputs"]) == 1
+                and chunk_data["outputs"][0].get("text", None) is not None
+            ):
+                text = chunk_data["outputs"][0]["text"]
+            stop_reason = chunk_data.get("stop_reason", None)
+            if stop_reason != None:
+                is_finished = True
+                finish_reason = stop_reason
+        ######## bedrock.cohere mappings ###############
+        # meta mapping
+        elif "generation" in chunk_data:
+            text = chunk_data["generation"]  # bedrock.meta
+        # cohere mapping
+        elif "text" in chunk_data:
+            text = chunk_data["text"]  # bedrock.cohere
+        # cohere mapping for finish reason
+        elif "finish_reason" in chunk_data:
+            finish_reason = chunk_data["finish_reason"]
+            is_finished = True
+        elif chunk_data.get("completionReason", None):
+            is_finished = True
+            finish_reason = chunk_data["completionReason"]
+        return GenericStreamingChunk(
+            **{
+                "text": text,
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+            }
+        )
+
     def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
         """Given an iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
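
The new _chunk_parser centralizes the per-provider streaming-chunk mappings. An illustrative sketch of the payload shapes it branches on; the dicts are hand-written examples rather than captured traffic, and the import path is assumed from the file name in the commit title:

# Feeding representative chunk payloads through the parser added above.
# Illustrative only: payload dicts are hand-written; module path is assumed.
from litellm.llms.bedrock_httpx import AWSEventStreamDecoder

example_chunks = [
    {"completion": "Hello", "stop_reason": None},               # bedrock.anthropic (pre claude-3)
    {"delta": {"text": " world", "stop_reason": "end_turn"}},   # claude-3 delta events
    {"outputs": [{"text": "Hi"}], "stop_reason": None},         # bedrock.mistral
    {"generation": "Hi there"},                                  # bedrock.meta
    {"text": "Bonjour"},                                         # bedrock.cohere
    {"finish_reason": "COMPLETE"},                               # cohere finish event
]

decoder = AWSEventStreamDecoder(model="anthropic.claude-3-sonnet-20240229-v1:0")
for chunk_data in example_chunks:
    parsed = decoder._chunk_parser(chunk_data=chunk_data)
    print(parsed["text"], parsed["is_finished"], parsed["finish_reason"])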
@@ -693,12 +932,7 @@ class AWSEventStreamDecoder:
             if message:
                 # sse_event = ServerSentEvent(data=message, event="completion")
                 _data = json.loads(message)
-                streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
-                    text=_data.get("text", ""),
-                    is_finished=_data.get("is_finished", False),
-                    finish_reason=_data.get("finish_reason", ""),
-                )
-                yield streaming_chunk
+                yield self._chunk_parser(chunk_data=_data)
 
     async def aiter_bytes(
         self, iterator: AsyncIterator[bytes]
@@ -713,12 +947,7 @@ class AWSEventStreamDecoder:
                 message = self._parse_message_from_event(event)
                 if message:
                     _data = json.loads(message)
-                    streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
-                        text=_data.get("text", ""),
-                        is_finished=_data.get("is_finished", False),
-                        finish_reason=_data.get("finish_reason", ""),
-                    )
-                    yield streaming_chunk
+                    yield self._chunk_parser(chunk_data=_data)
 
     def _parse_message_from_event(self, event) -> Optional[str]:
         response_dict = event.to_response_dict()


@@ -326,7 +326,10 @@ async def acompletion(
             or custom_llm_provider == "sagemaker"
             or custom_llm_provider == "anthropic"
             or custom_llm_provider == "predibase"
-            or (custom_llm_provider == "bedrock" and "cohere" in model)
+            or (
+                custom_llm_provider == "bedrock"
+                and ("cohere" in model or "anthropic" in model)
+            )
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
@@ -1979,7 +1982,7 @@ def completion(
             # boto3 reads keys from .env
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-            if "cohere" in model:
+            if "cohere" in model or "anthropic" in model:
                 response = bedrock_chat_completion.completion(
                     model=model,
                     messages=messages,


@@ -4874,11 +4874,12 @@ async def token_counter(request: TokenCountRequest):
     model_to_use = (
         litellm_model_name or request.model
     )  # use litellm model name, if it's not avalable then fallback to request.model
-    total_tokens, tokenizer_used = token_counter(
+    _tokenizer_used = litellm.utils._select_tokenizer(model=model_to_use)
+    tokenizer_used = str(_tokenizer_used["type"])
+    total_tokens = token_counter(
         model=model_to_use,
         text=prompt,
         messages=messages,
-        return_tokenizer_used=True,
     )
     return TokenCountResponse(
         total_tokens=total_tokens,
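
The proxy endpoint above now selects the tokenizer separately and calls the plain int-returning token_counter. A minimal sketch of the same two-step pattern; the model name is illustrative:

# Two-step pattern used by the proxy endpoint above: pick the tokenizer for
# reporting, then count tokens with token_counter, which now returns an int.
import litellm

model_to_use = "gpt-3.5-turbo"  # illustrative model name
messages = [{"role": "user", "content": "hello world"}]

_tokenizer_used = litellm.utils._select_tokenizer(model=model_to_use)
tokenizer_used = str(_tokenizer_used["type"])
total_tokens = litellm.token_counter(model=model_to_use, messages=messages)

print(tokenizer_used, total_tokens)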


@@ -2663,13 +2663,17 @@ def response_format_tests(response: litellm.ModelResponse):
 @pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize(
+    "model",
+    ["bedrock/cohere.command-r-plus-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"],
+)
 @pytest.mark.asyncio
-async def test_completion_bedrock_command_r(sync_mode):
+async def test_completion_bedrock_httpx_models(sync_mode, model):
     litellm.set_verbose = True
     if sync_mode:
         response = completion(
-            model="bedrock/cohere.command-r-plus-v1:0",
+            model=model,
             messages=[{"role": "user", "content": "Hey! how's it going?"}],
         )
@@ -2678,7 +2682,7 @@ async def test_completion_bedrock_command_r(sync_mode):
         response_format_tests(response=response)
     else:
         response = await litellm.acompletion(
-            model="bedrock/cohere.command-r-plus-v1:0",
+            model=model,
            messages=[{"role": "user", "content": "Hey! how's it going?"}],
         )


@@ -1041,14 +1041,21 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
 @pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize(
+    "model",
+    [
+        # "bedrock/cohere.command-r-plus-v1:0",
+        "anthropic.claude-3-sonnet-20240229-v1:0"
+    ],
+)
 @pytest.mark.asyncio
-async def test_bedrock_cohere_command_r_streaming(sync_mode):
+async def test_bedrock_httpx_streaming(sync_mode, model):
     try:
         litellm.set_verbose = True
         if sync_mode:
             final_chunk: Optional[litellm.ModelResponse] = None
             response: litellm.CustomStreamWrapper = completion(  # type: ignore
-                model="bedrock/cohere.command-r-plus-v1:0",
+                model=model,
                 messages=messages,
                 max_tokens=10,  # type: ignore
                 stream=True,
@@ -1069,7 +1076,7 @@ async def test_bedrock_cohere_command_r_streaming(sync_mode):
                 raise Exception("Empty response received")
         else:
             response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
-                model="bedrock/cohere.command-r-plus-v1:0",
+                model=model,
                 messages=messages,
                 max_tokens=100,  # type: ignore
                 stream=True,


@@ -4123,8 +4123,7 @@ def token_counter(
     text: Optional[Union[str, List[str]]] = None,
     messages: Optional[List] = None,
     count_response_tokens: Optional[bool] = False,
-    return_tokenizer_used: Optional[bool] = False,
-):
+) -> int:
     """
     Count the number of tokens in a given text using a specified model.
@@ -4216,10 +4215,6 @@ def token_counter(
         )
     else:
         num_tokens = len(encoding.encode(text, disallowed_special=()))  # type: ignore
-    _tokenizer_type = tokenizer_json["type"]
-    if return_tokenizer_used:
-        # used by litellm proxy server -> POST /utils/token_counter
-        return num_tokens, _tokenizer_type
     return num_tokens
@@ -10642,7 +10637,7 @@ class CustomStreamWrapper:
             raise e
 
     def handle_bedrock_stream(self, chunk):
-        if "cohere" in self.model:
+        if "cohere" in self.model or "anthropic" in self.model:
             return {
                 "text": chunk["text"],
                 "is_finished": chunk["is_finished"],
@@ -11513,7 +11508,10 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "replicate"
                 or self.custom_llm_provider == "cached_response"
                 or self.custom_llm_provider == "predibase"
-                or (self.custom_llm_provider == "bedrock" and "cohere" in self.model)
+                or (
+                    self.custom_llm_provider == "bedrock"
+                    and ("cohere" in self.model or "anthropic" in self.model)
+                )
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
                 async for chunk in self.completion_stream:
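
With the routing and chunk handling above, streamed Bedrock anthropic responses are consumed like any other provider's. A short sketch; the model id is illustrative and AWS credentials are assumed to be configured:

# Consuming a streamed bedrock anthropic response through the updated wrapper.
# Model id is illustrative; AWS credentials are assumed to be set in the environment.
import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "Hey! how's it going?"}],
    max_tokens=10,
    stream=True,
)
for chunk in response:
    delta = chunk.choices[0].delta
    if delta.content is not None:
        print(delta.content, end="")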