fix(bedrock_httpx.py): working async bedrock command r calls

Krrish Dholakia 2024-05-11 16:45:20 -07:00
parent 59c8c0adff
commit 49ab1a1d3f
6 changed files with 374 additions and 78 deletions
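
As orientation for the diff below (not part of the commit itself): the feature being fixed is the async Bedrock Cohere Command R path through LiteLLM. A minimal usage sketch, reusing the exact model name and prompt from the new test added at the bottom of this commit:

import asyncio
import litellm

async def main():
    # async Bedrock Command R call -- the path this commit's bedrock_httpx.py changes make work
    response = await litellm.acompletion(
        model="bedrock/cohere.command-r-plus-v1:0",
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
    )
    print(response)  # a litellm.ModelResponse, validated field-by-field by the new test

asyncio.run(main())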

@@ -151,19 +151,120 @@ class AnthropicChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def process_streaming_response(
self,
model: str,
response: requests.Response | httpx.Response,
model_response: ModelResponse,
stream: bool,
logging_obj: litellm.utils.Logging,
optional_params: dict,
api_key: str,
data: dict | str,
messages: List,
print_verbose,
encoding,
) -> CustomStreamWrapper:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise AnthropicError(
message=response.text, status_code=response.status_code
)
text_content = ""
tool_calls = []
for content in completion_response["content"]:
if content["type"] == "text":
text_content += content["text"]
## TOOL CALLING
elif content["type"] == "tool_use":
tool_calls.append(
{
"id": content["id"],
"type": "function",
"function": {
"name": content["name"],
"arguments": json.dumps(content["input"]),
},
}
)
if "error" in completion_response:
raise AnthropicError(
message=str(completion_response["error"]),
status_code=response.status_code,
)
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[ # type: ignore
0
].finish_reason
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
streaming_choice = litellm.utils.StreamingChoices()
streaming_choice.index = model_response.choices[0].index
_tool_calls = []
print_verbose(
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
)
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
if isinstance(model_response.choices[0], litellm.Choices):
if getattr(
model_response.choices[0].message, "tool_calls", None
) is not None and isinstance(
model_response.choices[0].message.tool_calls, list
):
for tool_call in model_response.choices[0].message.tool_calls:
_tool_call = {**tool_call.dict(), "index": 0}
_tool_calls.append(_tool_call)
delta_obj = litellm.utils.Delta(
content=getattr(model_response.choices[0].message, "content", None),
role=model_response.choices[0].message.role,
tool_calls=_tool_calls,
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
completion_stream = ModelResponseIterator(
model_response=streaming_model_response
)
print_verbose(
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
else:
raise AnthropicError(
status_code=422,
message="Unprocessable response object - {}".format(response.text),
)
def process_response(
self,
model,
response,
model_response,
_is_function_call,
stream,
logging_obj,
api_key,
data,
messages,
model: str,
response: requests.Response | httpx.Response,
model_response: ModelResponse,
stream: bool,
logging_obj: litellm.utils.Logging,
optional_params: dict,
api_key: str,
data: dict | str,
messages: List,
print_verbose,
):
encoding,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
input=messages,
@@ -216,51 +317,6 @@ class AnthropicChatCompletion(BaseLLM):
completion_response["stop_reason"]
)
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
if _is_function_call and stream:
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[
0
].finish_reason
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
streaming_choice = litellm.utils.StreamingChoices()
streaming_choice.index = model_response.choices[0].index
_tool_calls = []
print_verbose(
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
)
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
if isinstance(model_response.choices[0], litellm.Choices):
if getattr(
model_response.choices[0].message, "tool_calls", None
) is not None and isinstance(
model_response.choices[0].message.tool_calls, list
):
for tool_call in model_response.choices[0].message.tool_calls:
_tool_call = {**tool_call.dict(), "index": 0}
_tool_calls.append(_tool_call)
delta_obj = litellm.utils.Delta(
content=getattr(model_response.choices[0].message, "content", None),
role=model_response.choices[0].message.role,
tool_calls=_tool_calls,
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
completion_stream = ModelResponseIterator(
model_response=streaming_model_response
)
print_verbose(
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
## CALCULATING USAGE
prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"]
@@ -273,7 +329,7 @@ class AnthropicChatCompletion(BaseLLM):
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage) # type: ignore
return model_response
async def acompletion_stream_function(
@@ -289,7 +345,7 @@ class AnthropicChatCompletion(BaseLLM):
logging_obj,
stream,
_is_function_call,
data=None,
data: dict,
optional_params=None,
litellm_params=None,
logger_fn=None,
@@ -331,12 +387,12 @@ class AnthropicChatCompletion(BaseLLM):
logging_obj,
stream,
_is_function_call,
data=None,
optional_params=None,
data: dict,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
):
) -> ModelResponse:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
@@ -347,13 +403,14 @@ class AnthropicChatCompletion(BaseLLM):
model=model,
response=response,
model_response=model_response,
_is_function_call=_is_function_call,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
def completion(
@@ -367,7 +424,7 @@ class AnthropicChatCompletion(BaseLLM):
encoding,
api_key,
logging_obj,
optional_params=None,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
@@ -526,17 +583,33 @@ class AnthropicChatCompletion(BaseLLM):
raise AnthropicError(
status_code=response.status_code, message=response.text
)
return self.process_response(
if stream and _is_function_call:
return self.process_streaming_response(
model=model,
response=response,
model_response=model_response,
_is_function_call=_is_function_call,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
def embedding(self):
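
For orientation only (not part of the diff): the process_streaming_response branch added above is taken when an Anthropic tool-calling request is made with stream=True. A rough sketch of such a call; get_weather is a made-up tool used purely for illustration:

import litellm

# hypothetical streaming tool-call request; the result is replayed through the
# "cached_response" CustomStreamWrapper built in process_streaming_response above
response = litellm.completion(
    model="claude-3-opus-20240229",
    messages=[{"role": "user", "content": "What's the weather in San Francisco?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # made-up tool for illustration
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    stream=True,
)
for chunk in response:
    print(chunk)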

@@ -1,12 +1,32 @@
## This is a template base class to be used for adding new LLM providers via API calls
import litellm
import httpx
from typing import Optional
import httpx, requests
from typing import Optional, Union
from litellm.utils import Logging
class BaseLLM:
_client_session: Optional[httpx.Client] = None
def process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: litellm.utils.ModelResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: list,
print_verbose,
encoding,
) -> litellm.utils.ModelResponse:
"""
Helper function to process the response across sync + async completion calls
"""
return model_response
def create_client_session(self):
if litellm.client_session:
_client_session = litellm.client_session
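
Reading aid, not part of the commit: the process_response stub added to BaseLLM above is the shared hook that the Anthropic and Bedrock classes in this commit override for their sync and async paths. Assuming the BaseLLM class above is in scope, a provider override follows roughly this shape; MyProviderLLM and the "text" response field are placeholders:

# hypothetical provider subclass, mirroring how this commit's providers use the hook
class MyProviderLLM(BaseLLM):
    def process_response(
        self, model, response, model_response, stream, logging_obj,
        optional_params, api_key, data, messages, print_verbose, encoding,
    ):
        completion_response = response.json()  # parse the raw HTTP body
        model_response.choices[0].message.content = completion_response["text"]  # placeholder field
        return model_response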

@@ -16,6 +16,7 @@ from litellm.utils import (
Message,
Choices,
get_secret,
Logging,
)
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
@@ -255,6 +256,70 @@ class BedrockLLM(BaseLLM):
return session.get_credentials()
def process_response(
self,
model: str,
response: requests.Response | httpx.Response,
model_response: ModelResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise BedrockError(message=response.text, status_code=422)
try:
model_response.choices[0].message.content = completion_response["text"] # type: ignore
except Exception as e:
raise BedrockError(message=response.text, status_code=422)
## CALCULATING USAGE - bedrock returns usage in the headers
prompt_tokens = int(
response.headers.get(
"x-amzn-bedrock-input-token-count",
len(encoding.encode("".join(m.get("content", "") for m in messages))),
)
)
completion_tokens = int(
response.headers.get(
"x-amzn-bedrock-output-token-count",
len(
encoding.encode(
model_response.choices[0].message.content, # type: ignore
disallowed_special=(),
)
),
)
)
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def completion(
self,
model: str,
@@ -268,8 +333,9 @@ class BedrockLLM(BaseLLM):
timeout: Optional[Union[float, httpx.Timeout]],
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
extra_headers: Optional[dict] = None,
client: Optional[HTTPHandler] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
try:
import boto3
@@ -381,13 +447,39 @@ class BedrockLLM(BaseLLM):
## COMPLETION CALL
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=endpoint_url, data=data, headers=headers
)
sigv4.add_auth(request)
prepped = request.prepare()
if client is None:
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
if isinstance(client, HTTPHandler):
client = None
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data,
api_base=prepped.url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=False,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
if client is None or isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
@@ -416,7 +508,62 @@ class BedrockLLM(BaseLLM):
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=response.text)
return response
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
optional_params=optional_params,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
async def async_completion(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> ModelResponse:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
self.client = AsyncHTTPHandler(**_params) # type: ignore
else:
self.client = client # type: ignore
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
def embedding(self, *args, **kwargs):
return super().embedding(*args, **kwargs)

@@ -168,7 +168,7 @@ class PredibaseChatCompletion(BaseLLM):
logging_obj: litellm.utils.Logging,
optional_params: dict,
api_key: str,
data: dict,
data: Union[dict, str],
messages: list,
print_verbose,
encoding,
@@ -185,9 +185,7 @@ class PredibaseChatCompletion(BaseLLM):
try:
completion_response = response.json()
except:
raise PredibaseError(
message=response.text, status_code=response.status_code
)
raise PredibaseError(message=response.text, status_code=422)
if "error" in completion_response:
raise PredibaseError(
message=str(completion_response["error"]),
@@ -363,7 +361,7 @@ class PredibaseChatCompletion(BaseLLM):
},
)
## COMPLETION CALL
if acompletion is True:
if acompletion == True:
### ASYNC STREAMING
if stream == True:
return self.async_streaming(

@@ -1479,6 +1479,11 @@ class Router:
return response
except Exception as e:
original_exception = e
"""
- Check if available deployments - 'get_healthy_deployments() -> List`
- if no, Check if available fallbacks - `is_fallback(model_group: str, exception) -> bool`
- if no, back-off and retry up till num_retries - `_router_should_retry -> float`
"""
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
if (
isinstance(original_exception, litellm.ContextWindowExceededError)
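
A rough, non-authoritative sketch of the retry/fallback flow described by the inline planning note above; get_healthy_deployments, is_fallback, and _router_should_retry are the helper names given in that note and are treated here as assumptions, not as the shipped Router API:

import asyncio

async def _retry_flow(router, model_group, original_exception, attempt_call, num_retries):
    # 1. healthy deployments available?  -> just retry the call
    # 2. none healthy, but a fallback exists for this model group? -> let fallback handling take over
    # 3. otherwise back off, then retry, up to num_retries
    for _ in range(num_retries):
        if not router.get_healthy_deployments():                     # assumed helper -> List
            if router.is_fallback(model_group, original_exception):  # assumed helper -> bool
                raise original_exception  # surfaced so the fallback path can handle it
            await asyncio.sleep(router._router_should_retry())       # assumed helper -> float back-off seconds
        try:
            return await attempt_call()
        except Exception as e:
            original_exception = e
    raise original_exception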

@@ -2584,13 +2584,66 @@ def test_completion_chat_sagemaker_mistral():
# test_completion_chat_sagemaker_mistral()
def test_completion_bedrock_command_r():
def response_format_tests(response: litellm.ModelResponse):
assert isinstance(response.id, str)
assert response.id != ""
assert isinstance(response.object, str)
assert response.object != ""
assert isinstance(response.created, int)
assert isinstance(response.model, str)
assert response.model != ""
assert isinstance(response.choices, list)
assert len(response.choices) == 1
choice = response.choices[0]
assert isinstance(choice, litellm.Choices)
assert isinstance(choice.get("index"), int)
message = choice.get("message")
assert isinstance(message, litellm.Message)
assert isinstance(message.get("role"), str)
assert message.get("role") != ""
assert isinstance(message.get("content"), str)
assert message.get("content") != ""
assert choice.get("logprobs") is None
assert isinstance(choice.get("finish_reason"), str)
assert choice.get("finish_reason") != ""
assert isinstance(response.usage, litellm.Usage) # type: ignore
assert isinstance(response.usage.prompt_tokens, int) # type: ignore
assert isinstance(response.usage.completion_tokens, int) # type: ignore
assert isinstance(response.usage.total_tokens, int) # type: ignore
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_completion_bedrock_command_r(sync_mode):
litellm.set_verbose = True
if sync_mode:
response = completion(
model="bedrock/cohere.command-r-plus-v1:0",
messages=[{"role": "user", "content": "Hey! how's it going?"}],
)
assert isinstance(response, litellm.ModelResponse)
response_format_tests(response=response)
else:
response = await litellm.acompletion(
model="bedrock/cohere.command-r-plus-v1:0",
messages=[{"role": "user", "content": "Hey! how's it going?"}],
)
assert isinstance(response, litellm.ModelResponse)
print(f"response: {response}")
response_format_tests(response=response)
print(f"response: {response}")