diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index 818c4ecb3a..f3e2e2d700 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -151,19 +151,120 @@ class AnthropicChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
+    def process_streaming_response(
+        self,
+        model: str,
+        response: requests.Response | httpx.Response,
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: litellm.utils.Logging,
+        optional_params: dict,
+        api_key: str,
+        data: dict | str,
+        messages: List,
+        print_verbose,
+        encoding,
+    ) -> CustomStreamWrapper:
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
+        ## RESPONSE OBJECT
+        try:
+            completion_response = response.json()
+        except:
+            raise AnthropicError(
+                message=response.text, status_code=response.status_code
+            )
+        text_content = ""
+        tool_calls = []
+        for content in completion_response["content"]:
+            if content["type"] == "text":
+                text_content += content["text"]
+            ## TOOL CALLING
+            elif content["type"] == "tool_use":
+                tool_calls.append(
+                    {
+                        "id": content["id"],
+                        "type": "function",
+                        "function": {
+                            "name": content["name"],
+                            "arguments": json.dumps(content["input"]),
+                        },
+                    }
+                )
+        if "error" in completion_response:
+            raise AnthropicError(
+                message=str(completion_response["error"]),
+                status_code=response.status_code,
+            )
+
+        print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
+        # return an iterator
+        streaming_model_response = ModelResponse(stream=True)
+        streaming_model_response.choices[0].finish_reason = model_response.choices[  # type: ignore
+            0
+        ].finish_reason
+        # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+        streaming_choice = litellm.utils.StreamingChoices()
+        streaming_choice.index = model_response.choices[0].index
+        _tool_calls = []
+        print_verbose(
+            f"type of model_response.choices[0]: {type(model_response.choices[0])}"
+        )
+        print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
+        if isinstance(model_response.choices[0], litellm.Choices):
+            if getattr(
+                model_response.choices[0].message, "tool_calls", None
+            ) is not None and isinstance(
+                model_response.choices[0].message.tool_calls, list
+            ):
+                for tool_call in model_response.choices[0].message.tool_calls:
+                    _tool_call = {**tool_call.dict(), "index": 0}
+                    _tool_calls.append(_tool_call)
+            delta_obj = litellm.utils.Delta(
+                content=getattr(model_response.choices[0].message, "content", None),
+                role=model_response.choices[0].message.role,
+                tool_calls=_tool_calls,
+            )
+            streaming_choice.delta = delta_obj
+            streaming_model_response.choices = [streaming_choice]
+            completion_stream = ModelResponseIterator(
+                model_response=streaming_model_response
+            )
+            print_verbose(
+                "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
+            )
+            return CustomStreamWrapper(
+                completion_stream=completion_stream,
+                model=model,
+                custom_llm_provider="cached_response",
+                logging_obj=logging_obj,
+            )
+        else:
+            raise AnthropicError(
+                status_code=422,
+                message="Unprocessable response object - {}".format(response.text),
+            )
+
     def process_response(
         self,
-        model,
-        response,
-        model_response,
-        _is_function_call,
-        stream,
-        logging_obj,
-        api_key,
-        data,
-        messages,
+        model: str,
+        response: requests.Response | httpx.Response,
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: litellm.utils.Logging,
+        optional_params: dict,
+        api_key: str,
+        data: dict | str,
+        messages: List,
         print_verbose,
-    ):
+        encoding,
+    ) -> ModelResponse:
         ## LOGGING
         logging_obj.post_call(
             input=messages,
@@ -216,51 +317,6 @@ class AnthropicChatCompletion(BaseLLM):
                 completion_response["stop_reason"]
             )
 
-        print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
-        if _is_function_call and stream:
-            print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
-            # return an iterator
-            streaming_model_response = ModelResponse(stream=True)
-            streaming_model_response.choices[0].finish_reason = model_response.choices[
-                0
-            ].finish_reason
-            # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
-            streaming_choice = litellm.utils.StreamingChoices()
-            streaming_choice.index = model_response.choices[0].index
-            _tool_calls = []
-            print_verbose(
-                f"type of model_response.choices[0]: {type(model_response.choices[0])}"
-            )
-            print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
-            if isinstance(model_response.choices[0], litellm.Choices):
-                if getattr(
-                    model_response.choices[0].message, "tool_calls", None
-                ) is not None and isinstance(
-                    model_response.choices[0].message.tool_calls, list
-                ):
-                    for tool_call in model_response.choices[0].message.tool_calls:
-                        _tool_call = {**tool_call.dict(), "index": 0}
-                        _tool_calls.append(_tool_call)
-                delta_obj = litellm.utils.Delta(
-                    content=getattr(model_response.choices[0].message, "content", None),
-                    role=model_response.choices[0].message.role,
-                    tool_calls=_tool_calls,
-                )
-                streaming_choice.delta = delta_obj
-                streaming_model_response.choices = [streaming_choice]
-                completion_stream = ModelResponseIterator(
-                    model_response=streaming_model_response
-                )
-                print_verbose(
-                    "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
-                )
-                return CustomStreamWrapper(
-                    completion_stream=completion_stream,
-                    model=model,
-                    custom_llm_provider="cached_response",
-                    logging_obj=logging_obj,
-                )
-
         ## CALCULATING USAGE
         prompt_tokens = completion_response["usage"]["input_tokens"]
         completion_tokens = completion_response["usage"]["output_tokens"]
@@ -273,7 +329,7 @@ class AnthropicChatCompletion(BaseLLM):
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
         )
-        model_response.usage = usage
+        setattr(model_response, "usage", usage)  # type: ignore
         return model_response
 
     async def acompletion_stream_function(
@@ -289,7 +345,7 @@ class AnthropicChatCompletion(BaseLLM):
         logging_obj,
         stream,
         _is_function_call,
-        data=None,
+        data: dict,
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
@@ -331,12 +387,12 @@ class AnthropicChatCompletion(BaseLLM):
         logging_obj,
         stream,
         _is_function_call,
-        data=None,
-        optional_params=None,
+        data: dict,
+        optional_params: dict,
         litellm_params=None,
         logger_fn=None,
         headers={},
-    ):
+    ) -> ModelResponse:
         self.async_handler = AsyncHTTPHandler(
             timeout=httpx.Timeout(timeout=600.0, connect=5.0)
         )
@@ -347,13 +403,14 @@ class AnthropicChatCompletion(BaseLLM):
             model=model,
             response=response,
             model_response=model_response,
-            _is_function_call=_is_function_call,
             stream=stream,
             logging_obj=logging_obj,
             api_key=api_key,
             data=data,
             messages=messages,
             print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
         )
 
     def completion(
@@ -367,7 +424,7 @@ class AnthropicChatCompletion(BaseLLM):
         encoding,
         api_key,
         logging_obj,
-        optional_params=None,
+        optional_params: dict,
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
@@ -526,17 +583,33 @@ class AnthropicChatCompletion(BaseLLM):
             raise AnthropicError(
                 status_code=response.status_code, message=response.text
             )
+
+        if stream and _is_function_call:
+            return self.process_streaming_response(
+                model=model,
+                response=response,
+                model_response=model_response,
+                stream=stream,
+                logging_obj=logging_obj,
+                api_key=api_key,
+                data=data,
+                messages=messages,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                encoding=encoding,
+            )
         return self.process_response(
             model=model,
             response=response,
             model_response=model_response,
-            _is_function_call=_is_function_call,
             stream=stream,
             logging_obj=logging_obj,
             api_key=api_key,
             data=data,
             messages=messages,
             print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
         )
 
     def embedding(self):
diff --git a/litellm/llms/base.py b/litellm/llms/base.py
index 62b8069f06..d940d94714 100644
--- a/litellm/llms/base.py
+++ b/litellm/llms/base.py
@@ -1,12 +1,32 @@
 ## This is a template base class to be used for adding new LLM providers via API calls
 import litellm
-import httpx
-from typing import Optional
+import httpx, requests
+from typing import Optional, Union
+from litellm.utils import Logging
 
 
 class BaseLLM:
     _client_session: Optional[httpx.Client] = None
 
+    def process_response(
+        self,
+        model: str,
+        response: Union[requests.Response, httpx.Response],
+        model_response: litellm.utils.ModelResponse,
+        stream: bool,
+        logging_obj: Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: list,
+        print_verbose,
+        encoding,
+    ) -> litellm.utils.ModelResponse:
+        """
+        Helper function to process the response across sync + async completion calls
+        """
+        return model_response
+
     def create_client_session(self):
         if litellm.client_session:
             _client_session = litellm.client_session
diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index d3062b5ed8..2c0e41b1d2 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -16,6 +16,7 @@ from litellm.utils import (
     Message,
     Choices,
     get_secret,
+    Logging,
 )
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
@@ -255,6 +256,70 @@ class BedrockLLM(BaseLLM):
 
         return session.get_credentials()
 
+    def process_response(
+        self,
+        model: str,
+        response: requests.Response | httpx.Response,
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: List,
+        print_verbose,
+        encoding,
+    ) -> ModelResponse:
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
+
+        ## RESPONSE OBJECT
+        try:
+            completion_response = response.json()
+        except:
+            raise BedrockError(message=response.text, status_code=422)
+
+        try:
+            model_response.choices[0].message.content = completion_response["text"]  # type: ignore
+        except Exception as e:
+            raise BedrockError(message=response.text, status_code=422)
+
+        ## CALCULATING USAGE - bedrock returns usage in the headers
+        prompt_tokens = int(
+            response.headers.get(
+                "x-amzn-bedrock-input-token-count",
+                len(encoding.encode("".join(m.get("content", "") for m in messages))),
+            )
+        )
+        completion_tokens = int(
+            response.headers.get(
+                "x-amzn-bedrock-output-token-count",
+                len(
+                    encoding.encode(
+                        model_response.choices[0].message.content,  # type: ignore
+                        disallowed_special=(),
+                    )
+                ),
+            )
+        )
+
+        model_response["created"] = int(time.time())
+        model_response["model"] = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        setattr(model_response, "usage", usage)
+
+        return model_response
+
     def completion(
         self,
         model: str,
@@ -268,8 +333,9 @@ class BedrockLLM(BaseLLM):
         timeout: Optional[Union[float, httpx.Timeout]],
         litellm_params=None,
         logger_fn=None,
+        acompletion: bool = False,
         extra_headers: Optional[dict] = None,
-        client: Optional[HTTPHandler] = None,
+        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         try:
             import boto3
@@ -381,13 +447,39 @@ class BedrockLLM(BaseLLM):
 
         ## COMPLETION CALL
         headers = {"Content-Type": "application/json"}
+        if extra_headers is not None:
+            headers = {"Content-Type": "application/json", **extra_headers}
         request = AWSRequest(
             method="POST", url=endpoint_url, data=data, headers=headers
         )
         sigv4.add_auth(request)
         prepped = request.prepare()
 
-        if client is None:
+        ### ROUTING (ASYNC, STREAMING, SYNC)
+        if acompletion:
+            if isinstance(client, HTTPHandler):
+                client = None
+
+            ### ASYNC COMPLETION
+            return self.async_completion(
+                model=model,
+                messages=messages,
+                data=data,
+                api_base=prepped.url,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                encoding=encoding,
+                logging_obj=logging_obj,
+                optional_params=optional_params,
+                stream=False,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                headers=prepped.headers,
+                timeout=timeout,
+                client=client,
+            )  # type: ignore
+
+        if client is None or isinstance(client, AsyncHTTPHandler):
             _params = {}
             if timeout is not None:
                 if isinstance(timeout, float) or isinstance(timeout, int):
@@ -416,7 +508,62 @@ class BedrockLLM(BaseLLM):
             error_code = err.response.status_code
             raise BedrockError(status_code=error_code, message=response.text)
 
-        return response
+        return self.process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=stream,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            encoding=encoding,
+        )
+
+    async def async_completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        data: str,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        encoding,
+        logging_obj,
+        stream,
+        optional_params: dict,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> ModelResponse:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            self.client = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            self.client = client  # type: ignore
+
+        response = await self.client.post(api_base, headers=headers, data=data)  # type: ignore
+        return self.process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=stream,
+            logging_obj=logging_obj,
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
+        )
 
     def embedding(self, *args, **kwargs):
         return super().embedding(*args, **kwargs)
diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py
index c3424d244b..1e7e1d3348 100644
--- a/litellm/llms/predibase.py
+++ b/litellm/llms/predibase.py
@@ -168,7 +168,7 @@ class PredibaseChatCompletion(BaseLLM):
         logging_obj: litellm.utils.Logging,
         optional_params: dict,
         api_key: str,
-        data: dict,
+        data: Union[dict, str],
         messages: list,
         print_verbose,
         encoding,
@@ -185,9 +185,7 @@ class PredibaseChatCompletion(BaseLLM):
         try:
             completion_response = response.json()
         except:
-            raise PredibaseError(
-                message=response.text, status_code=response.status_code
-            )
+            raise PredibaseError(message=response.text, status_code=422)
         if "error" in completion_response:
             raise PredibaseError(
                 message=str(completion_response["error"]),
@@ -363,7 +361,7 @@ class PredibaseChatCompletion(BaseLLM):
             },
         )
         ## COMPLETION CALL
-        if acompletion is True:
+        if acompletion == True:
             ### ASYNC STREAMING
             if stream == True:
                 return self.async_streaming(
diff --git a/litellm/router.py b/litellm/router.py
index f0d94908e9..33dc5c13ce 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -1479,6 +1479,11 @@ class Router:
             return response
         except Exception as e:
             original_exception = e
+            """
+            - Check if available deployments - 'get_healthy_deployments() -> List`
+            - if no, Check if available fallbacks - `is_fallback(model_group: str, exception) -> bool`
+            - if no, back-off and retry up till num_retries - `_router_should_retry -> float`
+            """
             ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
             if (
                 isinstance(original_exception, litellm.ContextWindowExceededError)
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 0cf6dda835..1d245cd272 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -2584,12 +2584,65 @@ def test_completion_chat_sagemaker_mistral():
 # test_completion_chat_sagemaker_mistral()
 
 
-def test_completion_bedrock_command_r():
+def response_format_tests(response: litellm.ModelResponse):
+    assert isinstance(response.id, str)
+    assert response.id != ""
+
+    assert isinstance(response.object, str)
+    assert response.object != ""
+
+    assert isinstance(response.created, int)
+
+    assert isinstance(response.model, str)
+    assert response.model != ""
+
+    assert isinstance(response.choices, list)
+    assert len(response.choices) == 1
+    choice = response.choices[0]
+    assert isinstance(choice, litellm.Choices)
+    assert isinstance(choice.get("index"), int)
+
+    message = choice.get("message")
+    assert isinstance(message, litellm.Message)
+    assert isinstance(message.get("role"), str)
+    assert message.get("role") != ""
+    assert isinstance(message.get("content"), str)
+    assert message.get("content") != ""
+
+    assert choice.get("logprobs") is None
+    assert isinstance(choice.get("finish_reason"), str)
+    assert choice.get("finish_reason") != ""
+
+    assert isinstance(response.usage, litellm.Usage)  # type: ignore
+    assert isinstance(response.usage.prompt_tokens, int)  # type: ignore
+    assert isinstance(response.usage.completion_tokens, int)  # type: ignore
+    assert isinstance(response.usage.total_tokens, int)  # type: ignore
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_completion_bedrock_command_r(sync_mode):
     litellm.set_verbose = True
-    response = completion(
-        model="bedrock/cohere.command-r-plus-v1:0",
-        messages=[{"role": "user", "content": "Hey! how's it going?"}],
-    )
+
+    if sync_mode:
+        response = completion(
+            model="bedrock/cohere.command-r-plus-v1:0",
+            messages=[{"role": "user", "content": "Hey! how's it going?"}],
+        )
+
+        assert isinstance(response, litellm.ModelResponse)
+
+        response_format_tests(response=response)
+    else:
+        response = await litellm.acompletion(
+            model="bedrock/cohere.command-r-plus-v1:0",
+            messages=[{"role": "user", "content": "Hey! how's it going?"}],
+        )
+
+        assert isinstance(response, litellm.ModelResponse)
+
+        print(f"response: {response}")
+        response_format_tests(response=response)
 
     print(f"response: {response}")
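Below is a minimal usage sketch (not part of the diff) of the sync and async Bedrock Cohere paths this change wires up, mirroring the parametrized test above; it assumes AWS Bedrock credentials and region are already configured in the environment, and the model name and messages are taken directly from the new test.

# Usage sketch mirroring test_completion_bedrock_command_r above.
# Assumes AWS Bedrock credentials/region are configured in the environment.
import asyncio

import litellm
from litellm import completion

# Sync path: routed through BedrockLLM.completion() -> process_response().
sync_response = completion(
    model="bedrock/cohere.command-r-plus-v1:0",
    messages=[{"role": "user", "content": "Hey! how's it going?"}],
)
assert isinstance(sync_response, litellm.ModelResponse)


# Async path: acompletion routes through BedrockLLM.async_completion(),
# which awaits AsyncHTTPHandler.post() and reuses the same process_response().
async def main():
    async_response = await litellm.acompletion(
        model="bedrock/cohere.command-r-plus-v1:0",
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
    )
    assert isinstance(async_response, litellm.ModelResponse)
    print(async_response)


asyncio.run(main())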