fix(bedrock_httpx.py): working async bedrock command r calls

Krrish Dholakia 2024-05-11 16:45:20 -07:00
parent 59c8c0adff
commit 49ab1a1d3f
6 changed files with 374 additions and 78 deletions
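For orientation, this is the call pattern the commit makes work end to end, mirroring the test added at the bottom of this diff. A minimal sketch, assuming litellm is installed and AWS credentials for Bedrock are configured in the environment:

import asyncio
import litellm

async def main():
    # Async (httpx-based) completion against Bedrock's Cohere Command R+ model.
    response = await litellm.acompletion(
        model="bedrock/cohere.command-r-plus-v1:0",
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
    )
    print(response)

asyncio.run(main())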

@@ -151,19 +151,120 @@ class AnthropicChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def process_streaming_response(
self,
model: str,
response: requests.Response | httpx.Response,
model_response: ModelResponse,
stream: bool,
logging_obj: litellm.utils.Logging,
optional_params: dict,
api_key: str,
data: dict | str,
messages: List,
print_verbose,
encoding,
) -> CustomStreamWrapper:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise AnthropicError(
message=response.text, status_code=response.status_code
)
text_content = ""
tool_calls = []
for content in completion_response["content"]:
if content["type"] == "text":
text_content += content["text"]
## TOOL CALLING
elif content["type"] == "tool_use":
tool_calls.append(
{
"id": content["id"],
"type": "function",
"function": {
"name": content["name"],
"arguments": json.dumps(content["input"]),
},
}
)
if "error" in completion_response:
raise AnthropicError(
message=str(completion_response["error"]),
status_code=response.status_code,
)
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[ # type: ignore
0
].finish_reason
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
streaming_choice = litellm.utils.StreamingChoices()
streaming_choice.index = model_response.choices[0].index
_tool_calls = []
print_verbose(
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
)
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
if isinstance(model_response.choices[0], litellm.Choices):
if getattr(
model_response.choices[0].message, "tool_calls", None
) is not None and isinstance(
model_response.choices[0].message.tool_calls, list
):
for tool_call in model_response.choices[0].message.tool_calls:
_tool_call = {**tool_call.dict(), "index": 0}
_tool_calls.append(_tool_call)
delta_obj = litellm.utils.Delta(
content=getattr(model_response.choices[0].message, "content", None),
role=model_response.choices[0].message.role,
tool_calls=_tool_calls,
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
completion_stream = ModelResponseIterator(
model_response=streaming_model_response
)
print_verbose(
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
else:
raise AnthropicError(
status_code=422,
message="Unprocessable response object - {}".format(response.text),
)
def process_response(
self,
- model,
- response,
- model_response,
- _is_function_call,
- stream,
- logging_obj,
- api_key,
- data,
- messages,
- print_verbose,
- ):
model: str,
response: requests.Response | httpx.Response,
model_response: ModelResponse,
stream: bool,
logging_obj: litellm.utils.Logging,
optional_params: dict,
api_key: str,
data: dict | str,
messages: List,
print_verbose,
encoding,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
input=messages,
@@ -216,51 +317,6 @@ class AnthropicChatCompletion(BaseLLM):
completion_response["stop_reason"]
)
- print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
- if _is_function_call and stream:
- print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
- # return an iterator
- streaming_model_response = ModelResponse(stream=True)
- streaming_model_response.choices[0].finish_reason = model_response.choices[
- 0
- ].finish_reason
- # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
- streaming_choice = litellm.utils.StreamingChoices()
- streaming_choice.index = model_response.choices[0].index
- _tool_calls = []
- print_verbose(
- f"type of model_response.choices[0]: {type(model_response.choices[0])}"
- )
- print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
- if isinstance(model_response.choices[0], litellm.Choices):
- if getattr(
- model_response.choices[0].message, "tool_calls", None
- ) is not None and isinstance(
- model_response.choices[0].message.tool_calls, list
- ):
- for tool_call in model_response.choices[0].message.tool_calls:
- _tool_call = {**tool_call.dict(), "index": 0}
- _tool_calls.append(_tool_call)
- delta_obj = litellm.utils.Delta(
- content=getattr(model_response.choices[0].message, "content", None),
- role=model_response.choices[0].message.role,
- tool_calls=_tool_calls,
- )
- streaming_choice.delta = delta_obj
- streaming_model_response.choices = [streaming_choice]
- completion_stream = ModelResponseIterator(
- model_response=streaming_model_response
- )
- print_verbose(
- "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
- )
- return CustomStreamWrapper(
- completion_stream=completion_stream,
- model=model,
- custom_llm_provider="cached_response",
- logging_obj=logging_obj,
- )
## CALCULATING USAGE
prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"]
@@ -273,7 +329,7 @@ class AnthropicChatCompletion(BaseLLM):
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
- model_response.usage = usage
setattr(model_response, "usage", usage) # type: ignore
return model_response
async def acompletion_stream_function(
@@ -289,7 +345,7 @@ class AnthropicChatCompletion(BaseLLM):
logging_obj,
stream,
_is_function_call,
- data=None,
data: dict,
optional_params=None,
litellm_params=None,
logger_fn=None,
@@ -331,12 +387,12 @@ class AnthropicChatCompletion(BaseLLM):
logging_obj,
stream,
_is_function_call,
- data=None,
- optional_params=None,
data: dict,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
- ):
) -> ModelResponse:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
@@ -347,13 +403,14 @@ class AnthropicChatCompletion(BaseLLM):
model=model,
response=response,
model_response=model_response,
- _is_function_call=_is_function_call,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
def completion(
@@ -367,7 +424,7 @@ class AnthropicChatCompletion(BaseLLM):
encoding,
api_key,
logging_obj,
- optional_params=None,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
@@ -526,17 +583,33 @@ class AnthropicChatCompletion(BaseLLM):
raise AnthropicError(
status_code=response.status_code, message=response.text
)
if stream and _is_function_call:
return self.process_streaming_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
- _is_function_call=_is_function_call,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
def embedding(self):

@@ -1,12 +1,32 @@
## This is a template base class to be used for adding new LLM providers via API calls
import litellm
- import httpx
- from typing import Optional
import httpx, requests
from typing import Optional, Union
from litellm.utils import Logging
class BaseLLM:
_client_session: Optional[httpx.Client] = None
def process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: litellm.utils.ModelResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: list,
print_verbose,
encoding,
) -> litellm.utils.ModelResponse:
"""
Helper function to process the response across sync + async completion calls
"""
return model_response
def create_client_session(self):
if litellm.client_session:
_client_session = litellm.client_session

@@ -16,6 +16,7 @@ from litellm.utils import (
Message,
Choices,
get_secret,
Logging,
)
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
@@ -255,6 +256,70 @@ class BedrockLLM(BaseLLM):
return session.get_credentials()
def process_response(
self,
model: str,
response: requests.Response | httpx.Response,
model_response: ModelResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise BedrockError(message=response.text, status_code=422)
try:
model_response.choices[0].message.content = completion_response["text"] # type: ignore
except Exception as e:
raise BedrockError(message=response.text, status_code=422)
## CALCULATING USAGE - bedrock returns usage in the headers
prompt_tokens = int(
response.headers.get(
"x-amzn-bedrock-input-token-count",
len(encoding.encode("".join(m.get("content", "") for m in messages))),
)
)
completion_tokens = int(
response.headers.get(
"x-amzn-bedrock-output-token-count",
len(
encoding.encode(
model_response.choices[0].message.content, # type: ignore
disallowed_special=(),
)
),
)
)
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def completion(
self,
model: str,
@@ -268,8 +333,9 @@ class BedrockLLM(BaseLLM):
timeout: Optional[Union[float, httpx.Timeout]],
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
extra_headers: Optional[dict] = None,
- client: Optional[HTTPHandler] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
try:
import boto3
@@ -381,13 +447,39 @@ class BedrockLLM(BaseLLM):
## COMPLETION CALL
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=endpoint_url, data=data, headers=headers
)
sigv4.add_auth(request)
prepped = request.prepare()
- if client is None:
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
if isinstance(client, HTTPHandler):
client = None
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data,
api_base=prepped.url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=False,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
if client is None or isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
@@ -416,7 +508,62 @@ class BedrockLLM(BaseLLM):
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=response.text)
- return response
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
optional_params=optional_params,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
async def async_completion(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> ModelResponse:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
self.client = AsyncHTTPHandler(**_params) # type: ignore
else:
self.client = client # type: ignore
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
def embedding(self, *args, **kwargs):
return super().embedding(*args, **kwargs)
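Distilled from the routing added above: when acompletion is set, completion() hands back a coroutine from async_completion() (dropping any sync client), otherwise it makes a blocking post and both paths converge on process_response(). A self-contained sketch of that pattern (the classes below are simplified stand-ins, not litellm source):

import asyncio

class HTTPHandler:
    def post(self, url: str, data: str) -> str:
        return f"sync response from {url}"

class AsyncHTTPHandler:
    async def post(self, url: str, data: str) -> str:
        return f"async response from {url}"

def process_response(raw: str) -> dict:
    # Shared post-processing used by both the sync and async paths.
    return {"text": raw}

def completion(url: str, data: str, acompletion: bool = False, client=None):
    if acompletion:
        # A sync client cannot serve the async path, so drop it (mirrors the diff).
        if isinstance(client, HTTPHandler):
            client = None
        # Hand back a coroutine; the caller awaits it.
        return async_completion(url, data, client=client)
    sync_client = client if isinstance(client, HTTPHandler) else HTTPHandler()
    return process_response(sync_client.post(url, data))

async def async_completion(url: str, data: str, client=None):
    async_client = client or AsyncHTTPHandler()
    return process_response(await async_client.post(url, data))

print(completion("https://bedrock.example/invoke", "{}"))
print(asyncio.run(completion("https://bedrock.example/invoke", "{}", acompletion=True)))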

@@ -168,7 +168,7 @@ class PredibaseChatCompletion(BaseLLM):
logging_obj: litellm.utils.Logging,
optional_params: dict,
api_key: str,
- data: dict,
data: Union[dict, str],
messages: list,
print_verbose,
encoding,
@@ -185,9 +185,7 @@ class PredibaseChatCompletion(BaseLLM):
try:
completion_response = response.json()
except:
- raise PredibaseError(
- message=response.text, status_code=response.status_code
- )
raise PredibaseError(message=response.text, status_code=422)
if "error" in completion_response:
raise PredibaseError(
message=str(completion_response["error"]),
@@ -363,7 +361,7 @@ class PredibaseChatCompletion(BaseLLM):
},
)
## COMPLETION CALL
- if acompletion is True:
if acompletion == True:
### ASYNC STREAMING
if stream == True:
return self.async_streaming(

@@ -1479,6 +1479,11 @@ class Router:
return response
except Exception as e:
original_exception = e
"""
- Check if available deployments - 'get_healthy_deployments() -> List`
- if no, Check if available fallbacks - `is_fallback(model_group: str, exception) -> bool`
- if no, back-off and retry up till num_retries - `_router_should_retry -> float`
"""
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
if (
isinstance(original_exception, litellm.ContextWindowExceededError)
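The docstring added in this hunk outlines the intended failure-handling order. As a standalone illustration of that flow (the function bodies below are invented stand-ins; only the names and return types come from the docstring):

import random
import time
from typing import List

def get_healthy_deployments(model_group: str) -> List[str]:
    # Stand-in: return deployments currently considered healthy for this group.
    return []

def is_fallback(model_group: str, exception: Exception) -> bool:
    # Stand-in: is a fallback group configured that covers this exception?
    return False

def _router_should_retry(attempt: int) -> float:
    # Stand-in backoff: exponential with jitter, returned in seconds.
    return (2**attempt) * 0.1 + random.uniform(0, 0.05)

def handle_failure(model_group: str, exception: Exception, num_retries: int = 3) -> str:
    # 1. Prefer another healthy deployment if one exists.
    if get_healthy_deployments(model_group):
        return "route to a healthy deployment"
    # 2. Otherwise use a configured fallback group, if the error qualifies.
    if is_fallback(model_group, exception):
        return "route to the fallback model group"
    # 3. Otherwise back off and retry the original call up to num_retries times.
    for attempt in range(num_retries):
        time.sleep(_router_should_retry(attempt))
        # ... re-attempt the original deployment here ...
    raise exception

try:
    handle_failure("bedrock/cohere.command-r-plus-v1:0", RuntimeError("rate limited"))
except RuntimeError as err:
    print(f"gave up after retries: {err}")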

@@ -2584,12 +2584,65 @@ def test_completion_chat_sagemaker_mistral():
# test_completion_chat_sagemaker_mistral()
- def test_completion_bedrock_command_r():
def response_format_tests(response: litellm.ModelResponse):
assert isinstance(response.id, str)
assert response.id != ""
assert isinstance(response.object, str)
assert response.object != ""
assert isinstance(response.created, int)
assert isinstance(response.model, str)
assert response.model != ""
assert isinstance(response.choices, list)
assert len(response.choices) == 1
choice = response.choices[0]
assert isinstance(choice, litellm.Choices)
assert isinstance(choice.get("index"), int)
message = choice.get("message")
assert isinstance(message, litellm.Message)
assert isinstance(message.get("role"), str)
assert message.get("role") != ""
assert isinstance(message.get("content"), str)
assert message.get("content") != ""
assert choice.get("logprobs") is None
assert isinstance(choice.get("finish_reason"), str)
assert choice.get("finish_reason") != ""
assert isinstance(response.usage, litellm.Usage) # type: ignore
assert isinstance(response.usage.prompt_tokens, int) # type: ignore
assert isinstance(response.usage.completion_tokens, int) # type: ignore
assert isinstance(response.usage.total_tokens, int) # type: ignore
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_completion_bedrock_command_r(sync_mode):
litellm.set_verbose = True
- response = completion(
- model="bedrock/cohere.command-r-plus-v1:0",
- messages=[{"role": "user", "content": "Hey! how's it going?"}],
- )
if sync_mode:
response = completion(
model="bedrock/cohere.command-r-plus-v1:0",
messages=[{"role": "user", "content": "Hey! how's it going?"}],
)
assert isinstance(response, litellm.ModelResponse)
response_format_tests(response=response)
else:
response = await litellm.acompletion(
model="bedrock/cohere.command-r-plus-v1:0",
messages=[{"role": "user", "content": "Hey! how's it going?"}],
)
assert isinstance(response, litellm.ModelResponse)
print(f"response: {response}")
response_format_tests(response=response)
print(f"response: {response}")
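If you want to exercise just the new test, something like the following should work; the tests-file path is an assumption, pytest-asyncio must be installed for the @pytest.mark.asyncio case, and valid AWS Bedrock credentials need to be set in the environment:

import pytest

# Run only the new parametrized test (covers both the sync and async cases).
# The file path below is assumed; point it at test_completion.py in your checkout.
pytest.main(["-s", "-k", "test_completion_bedrock_command_r", "litellm/tests/test_completion.py"])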