Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

fix(bedrock_httpx.py): working async bedrock command r calls

parent 59c8c0adff · commit 49ab1a1d3f
6 changed files with 374 additions and 78 deletions
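
For context, a minimal usage sketch of what this change enables. This is not part of the commit; it assumes litellm is installed and AWS Bedrock credentials are configured, and it mirrors the model name and message used by the test added below (test_completion_bedrock_command_r):

import asyncio
import litellm


async def main():
    # Async Bedrock call for Cohere Command R+, now served by bedrock_httpx.py
    response = await litellm.acompletion(
        model="bedrock/cohere.command-r-plus-v1:0",
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
    )
    # Bedrock reports token counts in response headers; the new process_response
    # maps them onto the standard usage object.
    print(response.usage.prompt_tokens, response.usage.completion_tokens)


asyncio.run(main())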
@@ -151,19 +151,120 @@ class AnthropicChatCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    def process_streaming_response(
        self,
        model: str,
        response: requests.Response | httpx.Response,
        model_response: ModelResponse,
        stream: bool,
        logging_obj: litellm.utils.Logging,
        optional_params: dict,
        api_key: str,
        data: dict | str,
        messages: List,
        print_verbose,
        encoding,
    ) -> CustomStreamWrapper:
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )
        print_verbose(f"raw model_response: {response.text}")
        ## RESPONSE OBJECT
        try:
            completion_response = response.json()
        except:
            raise AnthropicError(
                message=response.text, status_code=response.status_code
            )
        text_content = ""
        tool_calls = []
        for content in completion_response["content"]:
            if content["type"] == "text":
                text_content += content["text"]
            ## TOOL CALLING
            elif content["type"] == "tool_use":
                tool_calls.append(
                    {
                        "id": content["id"],
                        "type": "function",
                        "function": {
                            "name": content["name"],
                            "arguments": json.dumps(content["input"]),
                        },
                    }
                )
        if "error" in completion_response:
            raise AnthropicError(
                message=str(completion_response["error"]),
                status_code=response.status_code,
            )

        print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
        # return an iterator
        streaming_model_response = ModelResponse(stream=True)
        streaming_model_response.choices[0].finish_reason = model_response.choices[  # type: ignore
            0
        ].finish_reason
        # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
        streaming_choice = litellm.utils.StreamingChoices()
        streaming_choice.index = model_response.choices[0].index
        _tool_calls = []
        print_verbose(
            f"type of model_response.choices[0]: {type(model_response.choices[0])}"
        )
        print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
        if isinstance(model_response.choices[0], litellm.Choices):
            if getattr(
                model_response.choices[0].message, "tool_calls", None
            ) is not None and isinstance(
                model_response.choices[0].message.tool_calls, list
            ):
                for tool_call in model_response.choices[0].message.tool_calls:
                    _tool_call = {**tool_call.dict(), "index": 0}
                    _tool_calls.append(_tool_call)
            delta_obj = litellm.utils.Delta(
                content=getattr(model_response.choices[0].message, "content", None),
                role=model_response.choices[0].message.role,
                tool_calls=_tool_calls,
            )
            streaming_choice.delta = delta_obj
            streaming_model_response.choices = [streaming_choice]
            completion_stream = ModelResponseIterator(
                model_response=streaming_model_response
            )
            print_verbose(
                "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
            )
            return CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider="cached_response",
                logging_obj=logging_obj,
            )
        else:
            raise AnthropicError(
                status_code=422,
                message="Unprocessable response object - {}".format(response.text),
            )

    def process_response(
        self,
        model,
        response,
        model_response,
        _is_function_call,
        stream,
        logging_obj,
        api_key,
        data,
        messages,
        model: str,
        response: requests.Response | httpx.Response,
        model_response: ModelResponse,
        stream: bool,
        logging_obj: litellm.utils.Logging,
        optional_params: dict,
        api_key: str,
        data: dict | str,
        messages: List,
        print_verbose,
    ):
        encoding,
    ) -> ModelResponse:
        ## LOGGING
        logging_obj.post_call(
            input=messages,
@@ -216,51 +317,6 @@ class AnthropicChatCompletion(BaseLLM):
            completion_response["stop_reason"]
        )

        print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
        if _is_function_call and stream:
            print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
            # return an iterator
            streaming_model_response = ModelResponse(stream=True)
            streaming_model_response.choices[0].finish_reason = model_response.choices[
                0
            ].finish_reason
            # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
            streaming_choice = litellm.utils.StreamingChoices()
            streaming_choice.index = model_response.choices[0].index
            _tool_calls = []
            print_verbose(
                f"type of model_response.choices[0]: {type(model_response.choices[0])}"
            )
            print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
            if isinstance(model_response.choices[0], litellm.Choices):
                if getattr(
                    model_response.choices[0].message, "tool_calls", None
                ) is not None and isinstance(
                    model_response.choices[0].message.tool_calls, list
                ):
                    for tool_call in model_response.choices[0].message.tool_calls:
                        _tool_call = {**tool_call.dict(), "index": 0}
                        _tool_calls.append(_tool_call)
                delta_obj = litellm.utils.Delta(
                    content=getattr(model_response.choices[0].message, "content", None),
                    role=model_response.choices[0].message.role,
                    tool_calls=_tool_calls,
                )
                streaming_choice.delta = delta_obj
                streaming_model_response.choices = [streaming_choice]
                completion_stream = ModelResponseIterator(
                    model_response=streaming_model_response
                )
                print_verbose(
                    "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
                )
                return CustomStreamWrapper(
                    completion_stream=completion_stream,
                    model=model,
                    custom_llm_provider="cached_response",
                    logging_obj=logging_obj,
                )

        ## CALCULATING USAGE
        prompt_tokens = completion_response["usage"]["input_tokens"]
        completion_tokens = completion_response["usage"]["output_tokens"]
@@ -273,7 +329,7 @@ class AnthropicChatCompletion(BaseLLM):
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        )
        model_response.usage = usage
        setattr(model_response, "usage", usage)  # type: ignore
        return model_response

    async def acompletion_stream_function(
@@ -289,7 +345,7 @@ class AnthropicChatCompletion(BaseLLM):
        logging_obj,
        stream,
        _is_function_call,
        data=None,
        data: dict,
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
@@ -331,12 +387,12 @@ class AnthropicChatCompletion(BaseLLM):
        logging_obj,
        stream,
        _is_function_call,
        data=None,
        optional_params=None,
        data: dict,
        optional_params: dict,
        litellm_params=None,
        logger_fn=None,
        headers={},
    ):
    ) -> ModelResponse:
        self.async_handler = AsyncHTTPHandler(
            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
        )
@@ -347,13 +403,14 @@ class AnthropicChatCompletion(BaseLLM):
            model=model,
            response=response,
            model_response=model_response,
            _is_function_call=_is_function_call,
            stream=stream,
            logging_obj=logging_obj,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            optional_params=optional_params,
            encoding=encoding,
        )

    def completion(
@@ -367,7 +424,7 @@ class AnthropicChatCompletion(BaseLLM):
        encoding,
        api_key,
        logging_obj,
        optional_params=None,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
@@ -526,17 +583,33 @@ class AnthropicChatCompletion(BaseLLM):
            raise AnthropicError(
                status_code=response.status_code, message=response.text
            )
        return self.process_response(

        if stream and _is_function_call:
            return self.process_streaming_response(
                model=model,
                response=response,
                model_response=model_response,
                _is_function_call=_is_function_call,
                stream=stream,
                logging_obj=logging_obj,
                api_key=api_key,
                data=data,
                messages=messages,
                print_verbose=print_verbose,
                optional_params=optional_params,
                encoding=encoding,
            )
        return self.process_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream,
            logging_obj=logging_obj,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            optional_params=optional_params,
            encoding=encoding,
        )

    def embedding(self):
@@ -1,12 +1,32 @@
## This is a template base class to be used for adding new LLM providers via API calls
import litellm
import httpx
from typing import Optional
import httpx, requests
from typing import Optional, Union
from litellm.utils import Logging


class BaseLLM:
    _client_session: Optional[httpx.Client] = None

    def process_response(
        self,
        model: str,
        response: Union[requests.Response, httpx.Response],
        model_response: litellm.utils.ModelResponse,
        stream: bool,
        logging_obj: Logging,
        optional_params: dict,
        api_key: str,
        data: Union[dict, str],
        messages: list,
        print_verbose,
        encoding,
    ) -> litellm.utils.ModelResponse:
        """
        Helper function to process the response across sync + async completion calls
        """
        return model_response

    def create_client_session(self):
        if litellm.client_session:
            _client_session = litellm.client_session
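
The BaseLLM.process_response stub above formalizes the pattern the rest of this commit leans on: one response-processing helper shared by the sync and async completion paths. A simplified, hypothetical illustration of that pattern (ToyProviderLLM is not litellm code; it only assumes httpx is available):

import asyncio
import httpx


class ToyProviderLLM:
    def process_response(self, response: httpx.Response) -> dict:
        # Single place to validate and parse the raw HTTP response.
        response.raise_for_status()
        return response.json()

    def completion(self, url: str) -> dict:
        # Sync path delegates parsing to the shared helper.
        return self.process_response(httpx.get(url))

    async def acompletion(self, url: str) -> dict:
        # Async path reuses the exact same helper.
        async with httpx.AsyncClient() as client:
            resp = await client.get(url)
        return self.process_response(resp)


if __name__ == "__main__":
    print(asyncio.run(ToyProviderLLM().acompletion("https://httpbin.org/json")))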
@@ -16,6 +16,7 @@ from litellm.utils import (
    Message,
    Choices,
    get_secret,
    Logging,
)
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
@@ -255,6 +256,70 @@ class BedrockLLM(BaseLLM):

        return session.get_credentials()

    def process_response(
        self,
        model: str,
        response: requests.Response | httpx.Response,
        model_response: ModelResponse,
        stream: bool,
        logging_obj: Logging,
        optional_params: dict,
        api_key: str,
        data: Union[dict, str],
        messages: List,
        print_verbose,
        encoding,
    ) -> ModelResponse:
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )
        print_verbose(f"raw model_response: {response.text}")

        ## RESPONSE OBJECT
        try:
            completion_response = response.json()
        except:
            raise BedrockError(message=response.text, status_code=422)

        try:
            model_response.choices[0].message.content = completion_response["text"]  # type: ignore
        except Exception as e:
            raise BedrockError(message=response.text, status_code=422)

        ## CALCULATING USAGE - bedrock returns usage in the headers
        prompt_tokens = int(
            response.headers.get(
                "x-amzn-bedrock-input-token-count",
                len(encoding.encode("".join(m.get("content", "") for m in messages))),
            )
        )
        completion_tokens = int(
            response.headers.get(
                "x-amzn-bedrock-output-token-count",
                len(
                    encoding.encode(
                        model_response.choices[0].message.content,  # type: ignore
                        disallowed_special=(),
                    )
                ),
            )
        )

        model_response["created"] = int(time.time())
        model_response["model"] = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)

        return model_response

    def completion(
        self,
        model: str,
@@ -268,8 +333,9 @@ class BedrockLLM(BaseLLM):
        timeout: Optional[Union[float, httpx.Timeout]],
        litellm_params=None,
        logger_fn=None,
        acompletion: bool = False,
        extra_headers: Optional[dict] = None,
        client: Optional[HTTPHandler] = None,
        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        try:
            import boto3
@@ -381,13 +447,39 @@ class BedrockLLM(BaseLLM):

        ## COMPLETION CALL
        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}
        request = AWSRequest(
            method="POST", url=endpoint_url, data=data, headers=headers
        )
        sigv4.add_auth(request)
        prepped = request.prepare()

        if client is None:
        ### ROUTING (ASYNC, STREAMING, SYNC)
        if acompletion:
            if isinstance(client, HTTPHandler):
                client = None

            ### ASYNC COMPLETION
            return self.async_completion(
                model=model,
                messages=messages,
                data=data,
                api_base=prepped.url,
                model_response=model_response,
                print_verbose=print_verbose,
                encoding=encoding,
                logging_obj=logging_obj,
                optional_params=optional_params,
                stream=False,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                headers=prepped.headers,
                timeout=timeout,
                client=client,
            )  # type: ignore

        if client is None or isinstance(client, AsyncHTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
@@ -416,7 +508,62 @@ class BedrockLLM(BaseLLM):
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=response.text)

        return response
        return self.process_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream,
            logging_obj=logging_obj,
            optional_params=optional_params,
            api_key="",
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            encoding=encoding,
        )

    async def async_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        print_verbose: Callable,
        data: str,
        timeout: Optional[Union[float, httpx.Timeout]],
        encoding,
        logging_obj,
        stream,
        optional_params: dict,
        litellm_params=None,
        logger_fn=None,
        headers={},
        client: Optional[AsyncHTTPHandler] = None,
    ) -> ModelResponse:
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            self.client = AsyncHTTPHandler(**_params)  # type: ignore
        else:
            self.client = client  # type: ignore

        response = await self.client.post(api_base, headers=headers, data=data)  # type: ignore
        return self.process_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream,
            logging_obj=logging_obj,
            api_key="",
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            optional_params=optional_params,
            encoding=encoding,
        )

    def embedding(self, *args, **kwargs):
        return super().embedding(*args, **kwargs)
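
One detail worth noting from BedrockLLM.process_response above: usage is read from Bedrock's token-count response headers, falling back to a local token count when a header is missing. A stripped-down sketch of that approach (usage_from_headers is a hypothetical helper, not a litellm function):

import httpx


def usage_from_headers(response: httpx.Response, fallback_prompt_tokens: int = 0) -> dict:
    # Prefer the counts Bedrock returns in its headers; fall back to a local estimate.
    prompt_tokens = int(
        response.headers.get("x-amzn-bedrock-input-token-count", fallback_prompt_tokens)
    )
    completion_tokens = int(response.headers.get("x-amzn-bedrock-output-token-count", 0))
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }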
@@ -168,7 +168,7 @@ class PredibaseChatCompletion(BaseLLM):
        logging_obj: litellm.utils.Logging,
        optional_params: dict,
        api_key: str,
        data: dict,
        data: Union[dict, str],
        messages: list,
        print_verbose,
        encoding,
@@ -185,9 +185,7 @@ class PredibaseChatCompletion(BaseLLM):
        try:
            completion_response = response.json()
        except:
            raise PredibaseError(
                message=response.text, status_code=response.status_code
            )
            raise PredibaseError(message=response.text, status_code=422)
        if "error" in completion_response:
            raise PredibaseError(
                message=str(completion_response["error"]),
@@ -363,7 +361,7 @@ class PredibaseChatCompletion(BaseLLM):
            },
        )
        ## COMPLETION CALL
        if acompletion is True:
        if acompletion == True:
            ### ASYNC STREAMING
            if stream == True:
                return self.async_streaming(
@@ -1479,6 +1479,11 @@ class Router:
            return response
        except Exception as e:
            original_exception = e
            """
            - Check if available deployments - 'get_healthy_deployments() -> List`
            - if no, Check if available fallbacks - `is_fallback(model_group: str, exception) -> bool`
            - if no, back-off and retry up till num_retries - `_router_should_retry -> float`
            """
            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
            if (
                isinstance(original_exception, litellm.ContextWindowExceededError)
@@ -2584,13 +2584,66 @@ def test_completion_chat_sagemaker_mistral():
# test_completion_chat_sagemaker_mistral()


def test_completion_bedrock_command_r():
def response_format_tests(response: litellm.ModelResponse):
    assert isinstance(response.id, str)
    assert response.id != ""

    assert isinstance(response.object, str)
    assert response.object != ""

    assert isinstance(response.created, int)

    assert isinstance(response.model, str)
    assert response.model != ""

    assert isinstance(response.choices, list)
    assert len(response.choices) == 1
    choice = response.choices[0]
    assert isinstance(choice, litellm.Choices)
    assert isinstance(choice.get("index"), int)

    message = choice.get("message")
    assert isinstance(message, litellm.Message)
    assert isinstance(message.get("role"), str)
    assert message.get("role") != ""
    assert isinstance(message.get("content"), str)
    assert message.get("content") != ""

    assert choice.get("logprobs") is None
    assert isinstance(choice.get("finish_reason"), str)
    assert choice.get("finish_reason") != ""

    assert isinstance(response.usage, litellm.Usage)  # type: ignore
    assert isinstance(response.usage.prompt_tokens, int)  # type: ignore
    assert isinstance(response.usage.completion_tokens, int)  # type: ignore
    assert isinstance(response.usage.total_tokens, int)  # type: ignore


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_completion_bedrock_command_r(sync_mode):
    litellm.set_verbose = True

    if sync_mode:
        response = completion(
            model="bedrock/cohere.command-r-plus-v1:0",
            messages=[{"role": "user", "content": "Hey! how's it going?"}],
        )

        assert isinstance(response, litellm.ModelResponse)

        response_format_tests(response=response)
    else:
        response = await litellm.acompletion(
            model="bedrock/cohere.command-r-plus-v1:0",
            messages=[{"role": "user", "content": "Hey! how's it going?"}],
        )

        assert isinstance(response, litellm.ModelResponse)

        print(f"response: {response}")
        response_format_tests(response=response)

    print(f"response: {response}")