diff --git a/enterprise/enterprise_callbacks/generic_api_callback.py b/enterprise/enterprise_callbacks/generic_api_callback.py index 076c13d5e..cf1d22e8f 100644 --- a/enterprise/enterprise_callbacks/generic_api_callback.py +++ b/enterprise/enterprise_callbacks/generic_api_callback.py @@ -10,7 +10,6 @@ from litellm.caching import DualCache from typing import Literal, Union -dotenv.load_dotenv() # Loading env variables using dotenv import traceback @@ -19,8 +18,6 @@ import traceback import dotenv, os import requests - -dotenv.load_dotenv() # Loading env variables using dotenv import traceback import datetime, subprocess, sys import litellm, uuid diff --git a/litellm/__init__.py b/litellm/__init__.py index cdbd74cdb..6c7b26617 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -743,6 +743,7 @@ from .llms.sagemaker import SagemakerConfig from .llms.ollama import OllamaConfig from .llms.ollama_chat import OllamaChatConfig from .llms.maritalk import MaritTalkConfig +from .llms.bedrock_httpx import AmazonCohereChatConfig from .llms.bedrock import ( AmazonTitanConfig, AmazonAI21Config, diff --git a/litellm/integrations/aispend.py b/litellm/integrations/aispend.py index a893f8923..2fe8ea0df 100644 --- a/litellm/integrations/aispend.py +++ b/litellm/integrations/aispend.py @@ -1,8 +1,6 @@ #### What this does #### # On success + failure, log events to aispend.io import dotenv, os - -dotenv.load_dotenv() # Loading env variables using dotenv import traceback import datetime diff --git a/litellm/integrations/berrispend.py b/litellm/integrations/berrispend.py index 1f0ae4581..7d30b706c 100644 --- a/litellm/integrations/berrispend.py +++ b/litellm/integrations/berrispend.py @@ -3,7 +3,6 @@ import dotenv, os import requests # type: ignore -dotenv.load_dotenv() # Loading env variables using dotenv import traceback import datetime diff --git a/litellm/integrations/clickhouse.py b/litellm/integrations/clickhouse.py index 7d1fb37d9..0c38b8626 100644 --- a/litellm/integrations/clickhouse.py +++ b/litellm/integrations/clickhouse.py @@ -8,8 +8,6 @@ from litellm.proxy._types import UserAPIKeyAuth from litellm.caching import DualCache from typing import Literal, Union - -dotenv.load_dotenv() # Loading env variables using dotenv import traceback @@ -18,8 +16,6 @@ import traceback import dotenv, os import requests - -dotenv.load_dotenv() # Loading env variables using dotenv import traceback import datetime, subprocess, sys import litellm, uuid diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index 8a3e0f467..d50882592 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -6,8 +6,6 @@ from litellm.proxy._types import UserAPIKeyAuth from litellm.caching import DualCache from typing import Literal, Union, Optional - -dotenv.load_dotenv() # Loading env variables using dotenv import traceback diff --git a/litellm/integrations/datadog.py b/litellm/integrations/datadog.py index d969341fc..6d5e08faf 100644 --- a/litellm/integrations/datadog.py +++ b/litellm/integrations/datadog.py @@ -3,8 +3,6 @@ import dotenv, os import requests # type: ignore - -dotenv.load_dotenv() # Loading env variables using dotenv import traceback import datetime, subprocess, sys import litellm, uuid diff --git a/litellm/integrations/dynamodb.py b/litellm/integrations/dynamodb.py index b5462ee7f..21ccabe4b 100644 --- a/litellm/integrations/dynamodb.py +++ b/litellm/integrations/dynamodb.py @@ -3,8 +3,6 @@ import dotenv, os import requests # type: ignore - 
-dotenv.load_dotenv() # Loading env variables using dotenv import traceback import datetime, subprocess, sys import litellm, uuid diff --git a/litellm/integrations/helicone.py b/litellm/integrations/helicone.py index c8c107541..85e73258e 100644 --- a/litellm/integrations/helicone.py +++ b/litellm/integrations/helicone.py @@ -3,8 +3,6 @@ import dotenv, os import requests # type: ignore import litellm - -dotenv.load_dotenv() # Loading env variables using dotenv import traceback diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py index ae8031bc1..0d9c0640c 100644 --- a/litellm/integrations/langfuse.py +++ b/litellm/integrations/langfuse.py @@ -1,8 +1,6 @@ #### What this does #### # On success, logs events to Langfuse -import dotenv, os - -dotenv.load_dotenv() # Loading env variables using dotenv +import os import copy import traceback from packaging.version import Version diff --git a/litellm/integrations/langsmith.py b/litellm/integrations/langsmith.py index 8a0fb3852..92e440215 100644 --- a/litellm/integrations/langsmith.py +++ b/litellm/integrations/langsmith.py @@ -3,8 +3,6 @@ import dotenv, os # type: ignore import requests # type: ignore from datetime import datetime - -dotenv.load_dotenv() # Loading env variables using dotenv import traceback import asyncio import types diff --git a/litellm/integrations/lunary.py b/litellm/integrations/lunary.py index 6b23f0987..2e16e44a1 100644 --- a/litellm/integrations/lunary.py +++ b/litellm/integrations/lunary.py @@ -2,13 +2,10 @@ # On success + failure, log events to lunary.ai from datetime import datetime, timezone import traceback -import dotenv import importlib import packaging -dotenv.load_dotenv() - # convert to {completion: xx, tokens: xx} def parse_usage(usage): @@ -79,14 +76,16 @@ class LunaryLogger: version = importlib.metadata.version("lunary") # if version < 0.1.43 then raise ImportError if packaging.version.Version(version) < packaging.version.Version("0.1.43"): - print( + print( # noqa "Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'" ) raise ImportError self.lunary_client = lunary except ImportError: - print("Lunary not installed. Please install it using 'pip install lunary'") + print( # noqa + "Lunary not installed. 
Please install it using 'pip install lunary'" + ) # noqa raise ImportError def log_event( diff --git a/litellm/integrations/openmeter.py b/litellm/integrations/openmeter.py index a454739d5..2c470d6f4 100644 --- a/litellm/integrations/openmeter.py +++ b/litellm/integrations/openmeter.py @@ -3,8 +3,6 @@ import dotenv, os, json import litellm - -dotenv.load_dotenv() # Loading env variables using dotenv import traceback from litellm.integrations.custom_logger import CustomLogger from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index 577946ce1..6fbc6ca4c 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -4,8 +4,6 @@ import dotenv, os import requests # type: ignore - -dotenv.load_dotenv() # Loading env variables using dotenv import traceback import datetime, subprocess, sys import litellm, uuid diff --git a/litellm/integrations/prometheus_services.py b/litellm/integrations/prometheus_services.py index d276bb85b..8fce8930d 100644 --- a/litellm/integrations/prometheus_services.py +++ b/litellm/integrations/prometheus_services.py @@ -5,8 +5,6 @@ import dotenv, os import requests # type: ignore - -dotenv.load_dotenv() # Loading env variables using dotenv import traceback import datetime, subprocess, sys import litellm, uuid diff --git a/litellm/integrations/prompt_layer.py b/litellm/integrations/prompt_layer.py index ce610e1ef..531ed75fe 100644 --- a/litellm/integrations/prompt_layer.py +++ b/litellm/integrations/prompt_layer.py @@ -3,8 +3,6 @@ import dotenv, os import requests # type: ignore from pydantic import BaseModel - -dotenv.load_dotenv() # Loading env variables using dotenv import traceback diff --git a/litellm/integrations/s3.py b/litellm/integrations/s3.py index d31b15840..d131e44f0 100644 --- a/litellm/integrations/s3.py +++ b/litellm/integrations/s3.py @@ -1,9 +1,7 @@ #### What this does #### # On success + failure, log events to Supabase -import dotenv, os - -dotenv.load_dotenv() # Loading env variables using dotenv +import os import traceback import datetime, subprocess, sys import litellm, uuid diff --git a/litellm/integrations/slack_alerting.py b/litellm/integrations/slack_alerting.py index 07c3585f0..d03922bc1 100644 --- a/litellm/integrations/slack_alerting.py +++ b/litellm/integrations/slack_alerting.py @@ -2,8 +2,6 @@ # Class for sending Slack Alerts # import dotenv, os from litellm.proxy._types import UserAPIKeyAuth - -dotenv.load_dotenv() # Loading env variables using dotenv from litellm._logging import verbose_logger, verbose_proxy_logger import litellm, threading from typing import List, Literal, Any, Union, Optional, Dict diff --git a/litellm/integrations/supabase.py b/litellm/integrations/supabase.py index 58beba8a3..4e6bf517f 100644 --- a/litellm/integrations/supabase.py +++ b/litellm/integrations/supabase.py @@ -3,8 +3,6 @@ import dotenv, os import requests # type: ignore - -dotenv.load_dotenv() # Loading env variables using dotenv import traceback import datetime, subprocess, sys import litellm diff --git a/litellm/integrations/weights_biases.py b/litellm/integrations/weights_biases.py index 53e6070a5..a56233b22 100644 --- a/litellm/integrations/weights_biases.py +++ b/litellm/integrations/weights_biases.py @@ -21,11 +21,11 @@ try: # contains a (known) object attribute object: Literal["chat.completion", "edit", "text_completion"] - def __getitem__(self, key: K) -> V: - ... 
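The hunks above (and the enterprise callback at the top of the patch) drop the module-level dotenv.load_dotenv() calls, so importing an integration no longer reads a .env file as an import-time side effect. A minimal sketch of the resulting pattern, not part of the patch: an application that still wants .env support loads it once at its own entrypoint, and the litellm modules simply read os.environ.

# Hypothetical application entrypoint, not part of this patch.
import os
from dotenv import load_dotenv  # the application opts in to .env support itself

load_dotenv()  # run once, before litellm reads any settings

import litellm  # integrations now only read what is already in os.environ

litellm.api_key = os.getenv("OPENAI_API_KEY")  # example of an explicitly wired setting
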
# pragma: no cover + def __getitem__(self, key: K) -> V: ... # noqa - def get(self, key: K, default: Optional[V] = None) -> Optional[V]: - ... # pragma: no cover + def get( # noqa + self, key: K, default: Optional[V] = None + ) -> Optional[V]: ... # pragma: no cover class OpenAIRequestResponseResolver: def __call__( @@ -173,12 +173,11 @@ except: #### What this does #### # On success, logs events to Langfuse -import dotenv, os +import os import requests import requests from datetime import datetime -dotenv.load_dotenv() # Loading env variables using dotenv import traceback diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py index 818c4ecb3..97a473a2e 100644 --- a/litellm/llms/anthropic.py +++ b/litellm/llms/anthropic.py @@ -3,7 +3,7 @@ import json from enum import Enum import requests, copy # type: ignore import time -from typing import Callable, Optional, List +from typing import Callable, Optional, List, Union from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper import litellm from .prompt_templates.factory import prompt_factory, custom_prompt @@ -151,19 +151,135 @@ class AnthropicChatCompletion(BaseLLM): def __init__(self) -> None: super().__init__() + def process_streaming_response( + self, + model: str, + response: Union[requests.Response, httpx.Response], + model_response: ModelResponse, + stream: bool, + logging_obj: litellm.utils.Logging, + optional_params: dict, + api_key: str, + data: Union[dict, str], + messages: List, + print_verbose, + encoding, + ) -> CustomStreamWrapper: + """ + Return stream object for tool-calling + streaming + """ + ## LOGGING + logging_obj.post_call( + input=messages, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + print_verbose(f"raw model_response: {response.text}") + ## RESPONSE OBJECT + try: + completion_response = response.json() + except: + raise AnthropicError( + message=response.text, status_code=response.status_code + ) + text_content = "" + tool_calls = [] + for content in completion_response["content"]: + if content["type"] == "text": + text_content += content["text"] + ## TOOL CALLING + elif content["type"] == "tool_use": + tool_calls.append( + { + "id": content["id"], + "type": "function", + "function": { + "name": content["name"], + "arguments": json.dumps(content["input"]), + }, + } + ) + if "error" in completion_response: + raise AnthropicError( + message=str(completion_response["error"]), + status_code=response.status_code, + ) + _message = litellm.Message( + tool_calls=tool_calls, + content=text_content or None, + ) + model_response.choices[0].message = _message # type: ignore + model_response._hidden_params["original_response"] = completion_response[ + "content" + ] # allow user to access raw anthropic tool calling response + + model_response.choices[0].finish_reason = map_finish_reason( + completion_response["stop_reason"] + ) + + print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK") + # return an iterator + streaming_model_response = ModelResponse(stream=True) + streaming_model_response.choices[0].finish_reason = model_response.choices[ # type: ignore + 0 + ].finish_reason + # streaming_model_response.choices = [litellm.utils.StreamingChoices()] + streaming_choice = litellm.utils.StreamingChoices() + streaming_choice.index = model_response.choices[0].index + _tool_calls = [] + print_verbose( + f"type of model_response.choices[0]: {type(model_response.choices[0])}" + ) + print_verbose(f"type of 
streaming_choice: {type(streaming_choice)}") + if isinstance(model_response.choices[0], litellm.Choices): + if getattr( + model_response.choices[0].message, "tool_calls", None + ) is not None and isinstance( + model_response.choices[0].message.tool_calls, list + ): + for tool_call in model_response.choices[0].message.tool_calls: + _tool_call = {**tool_call.dict(), "index": 0} + _tool_calls.append(_tool_call) + delta_obj = litellm.utils.Delta( + content=getattr(model_response.choices[0].message, "content", None), + role=model_response.choices[0].message.role, + tool_calls=_tool_calls, + ) + streaming_choice.delta = delta_obj + streaming_model_response.choices = [streaming_choice] + completion_stream = ModelResponseIterator( + model_response=streaming_model_response + ) + print_verbose( + "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object" + ) + return CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="cached_response", + logging_obj=logging_obj, + ) + else: + raise AnthropicError( + status_code=422, + message="Unprocessable response object - {}".format(response.text), + ) + def process_response( self, - model, - response, - model_response, - _is_function_call, - stream, - logging_obj, - api_key, - data, - messages, + model: str, + response: Union[requests.Response, httpx.Response], + model_response: ModelResponse, + stream: bool, + logging_obj: litellm.utils.Logging, + optional_params: dict, + api_key: str, + data: Union[dict, str], + messages: List, print_verbose, - ): + encoding, + ) -> ModelResponse: ## LOGGING logging_obj.post_call( input=messages, @@ -216,51 +332,6 @@ class AnthropicChatCompletion(BaseLLM): completion_response["stop_reason"] ) - print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}") - if _is_function_call and stream: - print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK") - # return an iterator - streaming_model_response = ModelResponse(stream=True) - streaming_model_response.choices[0].finish_reason = model_response.choices[ - 0 - ].finish_reason - # streaming_model_response.choices = [litellm.utils.StreamingChoices()] - streaming_choice = litellm.utils.StreamingChoices() - streaming_choice.index = model_response.choices[0].index - _tool_calls = [] - print_verbose( - f"type of model_response.choices[0]: {type(model_response.choices[0])}" - ) - print_verbose(f"type of streaming_choice: {type(streaming_choice)}") - if isinstance(model_response.choices[0], litellm.Choices): - if getattr( - model_response.choices[0].message, "tool_calls", None - ) is not None and isinstance( - model_response.choices[0].message.tool_calls, list - ): - for tool_call in model_response.choices[0].message.tool_calls: - _tool_call = {**tool_call.dict(), "index": 0} - _tool_calls.append(_tool_call) - delta_obj = litellm.utils.Delta( - content=getattr(model_response.choices[0].message, "content", None), - role=model_response.choices[0].message.role, - tool_calls=_tool_calls, - ) - streaming_choice.delta = delta_obj - streaming_model_response.choices = [streaming_choice] - completion_stream = ModelResponseIterator( - model_response=streaming_model_response - ) - print_verbose( - "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object" - ) - return CustomStreamWrapper( - completion_stream=completion_stream, - model=model, - custom_llm_provider="cached_response", - logging_obj=logging_obj, - ) - ## CALCULATING USAGE prompt_tokens = 
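The new process_streaming_response above turns an already-complete tool-calling response into a stream by wrapping it in a ModelResponseIterator and handing that to CustomStreamWrapper with the "cached_response" provider. A rough, self-contained sketch of that iterator idea (assumed names, not the actual ModelResponseIterator implementation): yield the one pre-built response, then stop, so both sync and async streaming consumers work unchanged.

class OneShotResponseIterator:
    """Stand-in illustrating the single-chunk 'cached response' stream."""

    def __init__(self, model_response):
        self.model_response = model_response
        self.is_done = False

    # sync iteration
    def __iter__(self):
        return self

    def __next__(self):
        if self.is_done:
            raise StopIteration
        self.is_done = True
        return self.model_response

    # async iteration, so `async for` over the wrapper also works
    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.is_done:
            raise StopAsyncIteration
        self.is_done = True
        return self.model_response
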
completion_response["usage"]["input_tokens"] completion_tokens = completion_response["usage"]["output_tokens"] @@ -273,7 +344,7 @@ class AnthropicChatCompletion(BaseLLM): completion_tokens=completion_tokens, total_tokens=total_tokens, ) - model_response.usage = usage + setattr(model_response, "usage", usage) # type: ignore return model_response async def acompletion_stream_function( @@ -289,7 +360,7 @@ class AnthropicChatCompletion(BaseLLM): logging_obj, stream, _is_function_call, - data=None, + data: dict, optional_params=None, litellm_params=None, logger_fn=None, @@ -331,29 +402,44 @@ class AnthropicChatCompletion(BaseLLM): logging_obj, stream, _is_function_call, - data=None, - optional_params=None, + data: dict, + optional_params: dict, litellm_params=None, logger_fn=None, headers={}, - ): + ) -> Union[ModelResponse, CustomStreamWrapper]: self.async_handler = AsyncHTTPHandler( timeout=httpx.Timeout(timeout=600.0, connect=5.0) ) response = await self.async_handler.post( api_base, headers=headers, data=json.dumps(data) ) + if stream and _is_function_call: + return self.process_streaming_response( + model=model, + response=response, + model_response=model_response, + stream=stream, + logging_obj=logging_obj, + api_key=api_key, + data=data, + messages=messages, + print_verbose=print_verbose, + optional_params=optional_params, + encoding=encoding, + ) return self.process_response( model=model, response=response, model_response=model_response, - _is_function_call=_is_function_call, stream=stream, logging_obj=logging_obj, api_key=api_key, data=data, messages=messages, print_verbose=print_verbose, + optional_params=optional_params, + encoding=encoding, ) def completion( @@ -367,7 +453,7 @@ class AnthropicChatCompletion(BaseLLM): encoding, api_key, logging_obj, - optional_params=None, + optional_params: dict, acompletion=None, litellm_params=None, logger_fn=None, @@ -526,17 +612,33 @@ class AnthropicChatCompletion(BaseLLM): raise AnthropicError( status_code=response.status_code, message=response.text ) + + if stream and _is_function_call: + return self.process_streaming_response( + model=model, + response=response, + model_response=model_response, + stream=stream, + logging_obj=logging_obj, + api_key=api_key, + data=data, + messages=messages, + print_verbose=print_verbose, + optional_params=optional_params, + encoding=encoding, + ) return self.process_response( model=model, response=response, model_response=model_response, - _is_function_call=_is_function_call, stream=stream, logging_obj=logging_obj, api_key=api_key, data=data, messages=messages, print_verbose=print_verbose, + optional_params=optional_params, + encoding=encoding, ) def embedding(self): diff --git a/litellm/llms/anthropic_text.py b/litellm/llms/anthropic_text.py index cef31c269..0093d9f35 100644 --- a/litellm/llms/anthropic_text.py +++ b/litellm/llms/anthropic_text.py @@ -100,7 +100,7 @@ class AnthropicTextCompletion(BaseLLM): def __init__(self) -> None: super().__init__() - def process_response( + def _process_response( self, model_response: ModelResponse, response, encoding, prompt: str, model: str ): ## RESPONSE OBJECT @@ -171,7 +171,7 @@ class AnthropicTextCompletion(BaseLLM): additional_args={"complete_input_dict": data}, ) - response = self.process_response( + response = self._process_response( model_response=model_response, response=response, encoding=encoding, @@ -330,7 +330,7 @@ class AnthropicTextCompletion(BaseLLM): ) print_verbose(f"raw model_response: {response.text}") - response = self.process_response( + response = 
self._process_response( model_response=model_response, response=response, encoding=encoding, diff --git a/litellm/llms/base.py b/litellm/llms/base.py index 62b8069f0..d940d9471 100644 --- a/litellm/llms/base.py +++ b/litellm/llms/base.py @@ -1,12 +1,32 @@ ## This is a template base class to be used for adding new LLM providers via API calls import litellm -import httpx -from typing import Optional +import httpx, requests +from typing import Optional, Union +from litellm.utils import Logging class BaseLLM: _client_session: Optional[httpx.Client] = None + def process_response( + self, + model: str, + response: Union[requests.Response, httpx.Response], + model_response: litellm.utils.ModelResponse, + stream: bool, + logging_obj: Logging, + optional_params: dict, + api_key: str, + data: Union[dict, str], + messages: list, + print_verbose, + encoding, + ) -> litellm.utils.ModelResponse: + """ + Helper function to process the response across sync + async completion calls + """ + return model_response + def create_client_session(self): if litellm.client_session: _client_session = litellm.client_session diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py new file mode 100644 index 000000000..1ff3767bd --- /dev/null +++ b/litellm/llms/bedrock_httpx.py @@ -0,0 +1,733 @@ +# What is this? +## Initial implementation of calling bedrock via httpx client (allows for async calls). +## V0 - just covers cohere command-r support + +import os, types +import json +from enum import Enum +import requests, copy # type: ignore +import time +from typing import ( + Callable, + Optional, + List, + Literal, + Union, + Any, + TypedDict, + Tuple, + Iterator, + AsyncIterator, +) +from litellm.utils import ( + ModelResponse, + Usage, + map_finish_reason, + CustomStreamWrapper, + Message, + Choices, + get_secret, + Logging, +) +import litellm +from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from .base import BaseLLM +import httpx # type: ignore +from .bedrock import BedrockError, convert_messages_to_prompt +from litellm.types.llms.bedrock import * + + +class AmazonCohereChatConfig: + """ + Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html + """ + + documents: Optional[List[Document]] = None + search_queries_only: Optional[bool] = None + preamble: Optional[str] = None + max_tokens: Optional[int] = None + temperature: Optional[float] = None + p: Optional[float] = None + k: Optional[float] = None + prompt_truncation: Optional[str] = None + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + seed: Optional[int] = None + return_prompt: Optional[bool] = None + stop_sequences: Optional[List[str]] = None + raw_prompting: Optional[bool] = None + + def __init__( + self, + documents: Optional[List[Document]] = None, + search_queries_only: Optional[bool] = None, + preamble: Optional[str] = None, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + p: Optional[float] = None, + k: Optional[float] = None, + prompt_truncation: Optional[str] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + return_prompt: Optional[bool] = None, + stop_sequences: Optional[str] = None, + raw_prompting: Optional[bool] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value 
is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params(self) -> List[str]: + return [ + "max_tokens", + "stream", + "stop", + "temperature", + "top_p", + "frequency_penalty", + "presence_penalty", + "seed", + "stop", + ] + + def map_openai_params( + self, non_default_params: dict, optional_params: dict + ) -> dict: + for param, value in non_default_params.items(): + if param == "max_tokens": + optional_params["max_tokens"] = value + if param == "stream": + optional_params["stream"] = value + if param == "stop": + if isinstance(value, str): + value = [value] + optional_params["stop_sequences"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["p"] = value + if param == "frequency_penalty": + optional_params["frequency_penalty"] = value + if param == "presence_penalty": + optional_params["presence_penalty"] = value + if "seed": + optional_params["seed"] = value + return optional_params + + +class BedrockLLM(BaseLLM): + """ + Example call + + ``` + curl --location --request POST 'https://bedrock-runtime.{aws_region_name}.amazonaws.com/model/{bedrock_model_name}/invoke' \ + --header 'Content-Type: application/json' \ + --header 'Accept: application/json' \ + --user "$AWS_ACCESS_KEY_ID":"$AWS_SECRET_ACCESS_KEY" \ + --aws-sigv4 "aws:amz:us-east-1:bedrock" \ + --data-raw '{ + "prompt": "Hi", + "temperature": 0, + "p": 0.9, + "max_tokens": 4096 + }' + ``` + """ + + def __init__(self) -> None: + super().__init__() + + def convert_messages_to_prompt( + self, model, messages, provider, custom_prompt_dict + ) -> Tuple[str, Optional[list]]: + # handle anthropic prompts and amazon titan prompts + prompt = "" + chat_history: Optional[list] = None + if provider == "anthropic" or provider == "amazon": + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + prompt = custom_prompt( + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=messages, + ) + else: + prompt = prompt_factory( + model=model, messages=messages, custom_llm_provider="bedrock" + ) + elif provider == "mistral": + prompt = prompt_factory( + model=model, messages=messages, custom_llm_provider="bedrock" + ) + elif provider == "meta": + prompt = prompt_factory( + model=model, messages=messages, custom_llm_provider="bedrock" + ) + elif provider == "cohere": + prompt, chat_history = cohere_message_pt(messages=messages) + else: + prompt = "" + for message in messages: + if "role" in message: + if message["role"] == "user": + prompt += f"{message['content']}" + else: + prompt += f"{message['content']}" + else: + prompt += f"{message['content']}" + return prompt, chat_history # type: ignore + + def get_credentials( + self, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_session_name: Optional[str] = None, + aws_profile_name: Optional[str] = None, + aws_role_name: Optional[str] = None, + ): + """ + Return a boto3.Credentials object + """ + import boto3 + + ## CHECK IS 'os.environ/' 
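A small usage sketch of the parameter mapping defined above (values are made up): OpenAI-style arguments are translated to the Bedrock Cohere names, e.g. top_p becomes p and a string stop becomes a stop_sequences list.

import litellm

config = litellm.AmazonCohereChatConfig()
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 256, "top_p": 0.9, "stop": "Human:"},
    optional_params={},
)
# -> {"max_tokens": 256, "p": 0.9, "stop_sequences": ["Human:"], ...}
# Note: the bare `if "seed":` check above is always truthy, so a stray "seed" key is
# copied in as well; it presumably should read `if param == "seed":`.
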
passed in + params_to_check: List[Optional[str]] = [ + aws_access_key_id, + aws_secret_access_key, + aws_region_name, + aws_session_name, + aws_profile_name, + aws_role_name, + ] + + # Iterate over parameters and update if needed + for i, param in enumerate(params_to_check): + if param and param.startswith("os.environ/"): + _v = get_secret(param) + if _v is not None and isinstance(_v, str): + params_to_check[i] = _v + # Assign updated values back to parameters + ( + aws_access_key_id, + aws_secret_access_key, + aws_region_name, + aws_session_name, + aws_profile_name, + aws_role_name, + ) = params_to_check + + ### CHECK STS ### + if aws_role_name is not None and aws_session_name is not None: + sts_client = boto3.client( + "sts", + aws_access_key_id=aws_access_key_id, # [OPTIONAL] + aws_secret_access_key=aws_secret_access_key, # [OPTIONAL] + ) + + sts_response = sts_client.assume_role( + RoleArn=aws_role_name, RoleSessionName=aws_session_name + ) + + return sts_response["Credentials"] + elif aws_profile_name is not None: ### CHECK SESSION ### + # uses auth values from AWS profile usually stored in ~/.aws/credentials + client = boto3.Session(profile_name=aws_profile_name) + + return client.get_credentials() + else: + session = boto3.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + region_name=aws_region_name, + ) + + return session.get_credentials() + + def process_response( + self, + model: str, + response: Union[requests.Response, httpx.Response], + model_response: ModelResponse, + stream: bool, + logging_obj: Logging, + optional_params: dict, + api_key: str, + data: Union[dict, str], + messages: List, + print_verbose, + encoding, + ) -> ModelResponse: + ## LOGGING + logging_obj.post_call( + input=messages, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + print_verbose(f"raw model_response: {response.text}") + + ## RESPONSE OBJECT + try: + completion_response = response.json() + except: + raise BedrockError(message=response.text, status_code=422) + + try: + model_response.choices[0].message.content = completion_response["text"] # type: ignore + except Exception as e: + raise BedrockError(message=response.text, status_code=422) + + ## CALCULATING USAGE - bedrock returns usage in the headers + prompt_tokens = int( + response.headers.get( + "x-amzn-bedrock-input-token-count", + len(encoding.encode("".join(m.get("content", "") for m in messages))), + ) + ) + completion_tokens = int( + response.headers.get( + "x-amzn-bedrock-output-token-count", + len( + encoding.encode( + model_response.choices[0].message.content, # type: ignore + disallowed_special=(), + ) + ), + ) + ) + + model_response["created"] = int(time.time()) + model_response["model"] = model + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + setattr(model_response, "usage", usage) + + return model_response + + def completion( + self, + model: str, + messages: list, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + optional_params: dict, + acompletion: bool, + timeout: Optional[Union[float, httpx.Timeout]], + litellm_params=None, + logger_fn=None, + extra_headers: Optional[dict] = None, + client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, + ) -> Union[ModelResponse, CustomStreamWrapper]: + try: + import boto3 + + from botocore.auth import SigV4Auth + from 
botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError as e: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + + ## SETUP ## + stream = optional_params.pop("stream", None) + + ## CREDENTIALS ## + # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_region_name = optional_params.pop("aws_region_name", None) + aws_role_name = optional_params.pop("aws_role_name", None) + aws_session_name = optional_params.pop("aws_session_name", None) + aws_profile_name = optional_params.pop("aws_profile_name", None) + aws_bedrock_runtime_endpoint = optional_params.pop( + "aws_bedrock_runtime_endpoint", None + ) # https://bedrock-runtime.{region_name}.amazonaws.com + + ### SET REGION NAME ### + if aws_region_name is None: + # check env # + litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) + + if litellm_aws_region_name is not None and isinstance( + litellm_aws_region_name, str + ): + aws_region_name = litellm_aws_region_name + + standard_aws_region_name = get_secret("AWS_REGION", None) + if standard_aws_region_name is not None and isinstance( + standard_aws_region_name, str + ): + aws_region_name = standard_aws_region_name + + if aws_region_name is None: + aws_region_name = "us-west-2" + + credentials: Credentials = self.get_credentials( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_region_name=aws_region_name, + aws_session_name=aws_session_name, + aws_profile_name=aws_profile_name, + aws_role_name=aws_role_name, + ) + + ### SET RUNTIME ENDPOINT ### + endpoint_url = "" + env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") + if aws_bedrock_runtime_endpoint is not None and isinstance( + aws_bedrock_runtime_endpoint, str + ): + endpoint_url = aws_bedrock_runtime_endpoint + elif env_aws_bedrock_runtime_endpoint and isinstance( + env_aws_bedrock_runtime_endpoint, str + ): + endpoint_url = env_aws_bedrock_runtime_endpoint + else: + endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" + + if stream is not None and stream == True: + endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream" + else: + endpoint_url = f"{endpoint_url}/model/{model}/invoke" + + sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) + + provider = model.split(".")[0] + prompt, chat_history = self.convert_messages_to_prompt( + model, messages, provider, custom_prompt_dict + ) + inference_params = copy.deepcopy(optional_params) + + if provider == "cohere": + if model.startswith("cohere.command-r"): + ## LOAD CONFIG + config = litellm.AmazonCohereChatConfig().get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + inference_params[k] = v + _data = {"message": prompt, **inference_params} + if chat_history is not None: + _data["chat_history"] = chat_history + data = json.dumps(_data) + else: + ## LOAD CONFIG + config = litellm.AmazonCohereConfig.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + inference_params[k] = v + if stream == True: + inference_params["stream"] = ( + 
True # cohere requires stream = True in inference params + ) + data = json.dumps({"prompt": prompt, **inference_params}) + else: + raise Exception("UNSUPPORTED PROVIDER") + + ## COMPLETION CALL + + headers = {"Content-Type": "application/json"} + if extra_headers is not None: + headers = {"Content-Type": "application/json", **extra_headers} + request = AWSRequest( + method="POST", url=endpoint_url, data=data, headers=headers + ) + sigv4.add_auth(request) + prepped = request.prepare() + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": prepped.url, + "headers": prepped.headers, + }, + ) + + ### ROUTING (ASYNC, STREAMING, SYNC) + if acompletion: + if isinstance(client, HTTPHandler): + client = None + if stream: + return self.async_streaming( + model=model, + messages=messages, + data=data, + api_base=prepped.url, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + logging_obj=logging_obj, + optional_params=optional_params, + stream=True, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=prepped.headers, + timeout=timeout, + client=client, + ) # type: ignore + ### ASYNC COMPLETION + return self.async_completion( + model=model, + messages=messages, + data=data, + api_base=prepped.url, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + logging_obj=logging_obj, + optional_params=optional_params, + stream=False, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=prepped.headers, + timeout=timeout, + client=client, + ) # type: ignore + + if client is None or isinstance(client, AsyncHTTPHandler): + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + self.client = HTTPHandler(**_params) # type: ignore + else: + self.client = client + if stream is not None and stream == True: + response = self.client.post( + url=prepped.url, + headers=prepped.headers, # type: ignore + data=data, + stream=stream, + ) + + if response.status_code != 200: + raise BedrockError( + status_code=response.status_code, message=response.text + ) + + decoder = AWSEventStreamDecoder() + + completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024)) + streaming_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="bedrock", + logging_obj=logging_obj, + ) + return streaming_response + + response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore + + try: + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise BedrockError(status_code=error_code, message=response.text) + + return self.process_response( + model=model, + response=response, + model_response=model_response, + stream=stream, + logging_obj=logging_obj, + optional_params=optional_params, + api_key="", + data=data, + messages=messages, + print_verbose=print_verbose, + encoding=encoding, + ) + + async def async_completion( + self, + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + data: str, + timeout: Optional[Union[float, httpx.Timeout]], + encoding, + logging_obj, + stream, + optional_params: dict, + litellm_params=None, + logger_fn=None, + headers={}, + client: Optional[AsyncHTTPHandler] = None, + ) -> ModelResponse: + if client is None: + _params = {} 
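For reference, a standalone sketch of the request flow used above, outside of litellm: sign a Bedrock Runtime invoke call with botocore's SigV4 signer and send it with httpx. The region fallback, credentials source, model id, and the "text" response field are taken from this patch; treat the whole thing as an illustrative sketch rather than a supported API.

import json

import boto3
import httpx
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

region = "us-west-2"  # the patch falls back to us-west-2 when no region is configured
model_id = "cohere.command-r-plus-v1:0"
url = f"https://bedrock-runtime.{region}.amazonaws.com/model/{model_id}/invoke"

body = json.dumps({"message": "Hey! how's it going?", "max_tokens": 64})
request = AWSRequest(
    method="POST", url=url, data=body, headers={"Content-Type": "application/json"}
)
SigV4Auth(boto3.Session().get_credentials(), "bedrock", region).add_auth(request)
prepped = request.prepare()  # signed URL + headers, same as the completion() flow above

response = httpx.post(prepped.url, headers=dict(prepped.headers.items()), content=body)
response.raise_for_status()
print(response.json()["text"])  # Command R returns the completion under "text"
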
+ if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + self.client = AsyncHTTPHandler(**_params) # type: ignore + else: + self.client = client # type: ignore + + response = await self.client.post(api_base, headers=headers, data=data) # type: ignore + return self.process_response( + model=model, + response=response, + model_response=model_response, + stream=stream, + logging_obj=logging_obj, + api_key="", + data=data, + messages=messages, + print_verbose=print_verbose, + optional_params=optional_params, + encoding=encoding, + ) + + async def async_streaming( + self, + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + data: str, + timeout: Optional[Union[float, httpx.Timeout]], + encoding, + logging_obj, + stream, + optional_params: dict, + litellm_params=None, + logger_fn=None, + headers={}, + client: Optional[AsyncHTTPHandler] = None, + ) -> CustomStreamWrapper: + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + self.client = AsyncHTTPHandler(**_params) # type: ignore + else: + self.client = client # type: ignore + + response = await self.client.post(api_base, headers=headers, data=data, stream=True) # type: ignore + + if response.status_code != 200: + raise BedrockError(status_code=response.status_code, message=response.text) + + decoder = AWSEventStreamDecoder() + + completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024)) + streaming_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="bedrock", + logging_obj=logging_obj, + ) + return streaming_response + + def embedding(self, *args, **kwargs): + return super().embedding(*args, **kwargs) + + +def get_response_stream_shape(): + from botocore.model import ServiceModel + from botocore.loaders import Loader + + loader = Loader() + bedrock_service_dict = loader.load_service_model("bedrock-runtime", "service-2") + bedrock_service_model = ServiceModel(bedrock_service_dict) + return bedrock_service_model.shape_for("ResponseStream") + + +class AWSEventStreamDecoder: + def __init__(self) -> None: + from botocore.parsers import EventStreamJSONParser + + self.parser = EventStreamJSONParser() + + def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]: + """Given an iterator that yields lines, iterate over it & yield every event encountered""" + from botocore.eventstream import EventStreamBuffer + + event_stream_buffer = EventStreamBuffer() + for chunk in iterator: + event_stream_buffer.add_data(chunk) + for event in event_stream_buffer: + message = self._parse_message_from_event(event) + if message: + # sse_event = ServerSentEvent(data=message, event="completion") + _data = json.loads(message) + streaming_chunk: GenericStreamingChunk = GenericStreamingChunk( + text=_data.get("text", ""), + is_finished=_data.get("is_finished", False), + finish_reason=_data.get("finish_reason", ""), + ) + yield streaming_chunk + + async def aiter_bytes( + self, iterator: AsyncIterator[bytes] + ) -> AsyncIterator[GenericStreamingChunk]: + """Given an async iterator that yields lines, iterate over it & yield every event encountered""" + from botocore.eventstream import EventStreamBuffer + + event_stream_buffer = EventStreamBuffer() + async for chunk in iterator: + 
event_stream_buffer.add_data(chunk) + for event in event_stream_buffer: + message = self._parse_message_from_event(event) + if message: + _data = json.loads(message) + streaming_chunk: GenericStreamingChunk = GenericStreamingChunk( + text=_data.get("text", ""), + is_finished=_data.get("is_finished", False), + finish_reason=_data.get("finish_reason", ""), + ) + yield streaming_chunk + + def _parse_message_from_event(self, event) -> Optional[str]: + response_dict = event.to_response_dict() + parsed_response = self.parser.parse(response_dict, get_response_stream_shape()) + if response_dict["status_code"] != 200: + raise ValueError(f"Bad response code, expected 200: {response_dict}") + + chunk = parsed_response.get("chunk") + if not chunk: + return None + + return chunk.get("bytes").decode() # type: ignore[no-any-return] diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index 7c7d4938a..0adbd95bf 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -58,16 +58,25 @@ class AsyncHTTPHandler: class HTTPHandler: def __init__( - self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000 + self, + timeout: Optional[httpx.Timeout] = None, + concurrent_limit=1000, + client: Optional[httpx.Client] = None, ): - # Create a client with a connection pool - self.client = httpx.Client( - timeout=timeout, - limits=httpx.Limits( - max_connections=concurrent_limit, - max_keepalive_connections=concurrent_limit, - ), - ) + if timeout is None: + timeout = _DEFAULT_TIMEOUT + + if client is None: + # Create a client with a connection pool + self.client = httpx.Client( + timeout=timeout, + limits=httpx.Limits( + max_connections=concurrent_limit, + max_keepalive_connections=concurrent_limit, + ), + ) + else: + self.client = client def close(self): # Close the client when you're done with it @@ -82,11 +91,15 @@ class HTTPHandler: def post( self, url: str, - data: Optional[dict] = None, + data: Optional[Union[dict, str]] = None, params: Optional[dict] = None, headers: Optional[dict] = None, + stream: bool = False, ): - response = self.client.post(url, data=data, params=params, headers=headers) + req = self.client.build_request( + "POST", url, data=data, params=params, headers=headers # type: ignore + ) + response = self.client.send(req, stream=stream) return response def __del__(self) -> None: diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py index c3424d244..1e7e1d334 100644 --- a/litellm/llms/predibase.py +++ b/litellm/llms/predibase.py @@ -168,7 +168,7 @@ class PredibaseChatCompletion(BaseLLM): logging_obj: litellm.utils.Logging, optional_params: dict, api_key: str, - data: dict, + data: Union[dict, str], messages: list, print_verbose, encoding, @@ -185,9 +185,7 @@ class PredibaseChatCompletion(BaseLLM): try: completion_response = response.json() except: - raise PredibaseError( - message=response.text, status_code=response.status_code - ) + raise PredibaseError(message=response.text, status_code=422) if "error" in completion_response: raise PredibaseError( message=str(completion_response["error"]), @@ -363,7 +361,7 @@ class PredibaseChatCompletion(BaseLLM): }, ) ## COMPLETION CALL - if acompletion is True: + if acompletion == True: ### ASYNC STREAMING if stream == True: return self.async_streaming( diff --git a/litellm/main.py b/litellm/main.py index c1d90451b..16188b253 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -76,6 +76,7 @@ from .llms.anthropic import AnthropicChatCompletion 
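The HTTPHandler changes above switch post() to build_request() + send() so callers can pass stream=True and consume the body incrementally; this is what the Bedrock client uses for invoke-with-response-stream. A hedged usage sketch with a placeholder URL:

from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()  # timeout now defaults inside __init__; an existing httpx.Client can also be injected
response = client.post(
    url="https://example.com/stream",  # placeholder endpoint
    data='{"prompt": "Hi"}',           # data may now be a dict or a pre-serialized string
    headers={"Content-Type": "application/json"},
    stream=True,
)
for chunk in response.iter_bytes(chunk_size=1024):
    ...  # feed chunks to a decoder, e.g. AWSEventStreamDecoder.iter_bytes() above
response.close()  # streamed responses should be closed when done
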
from .llms.anthropic_text import AnthropicTextCompletion from .llms.huggingface_restapi import Huggingface from .llms.predibase import PredibaseChatCompletion +from .llms.bedrock_httpx import BedrockLLM from .llms.triton import TritonChatCompletion from .llms.prompt_templates.factory import ( prompt_factory, @@ -105,7 +106,6 @@ from litellm.utils import ( ) ####### ENVIRONMENT VARIABLES ################### -dotenv.load_dotenv() # Loading env variables using dotenv openai_chat_completions = OpenAIChatCompletion() openai_text_completions = OpenAITextCompletion() anthropic_chat_completions = AnthropicChatCompletion() @@ -115,6 +115,7 @@ azure_text_completions = AzureTextCompletion() huggingface = Huggingface() predibase_chat_completions = PredibaseChatCompletion() triton_chat_completions = TritonChatCompletion() +bedrock_chat_completion = BedrockLLM() ####### COMPLETION ENDPOINTS ################ @@ -257,7 +258,7 @@ async def acompletion( - If `stream` is True, the function returns an async generator that yields completion lines. """ loop = asyncio.get_event_loop() - custom_llm_provider = None + custom_llm_provider = kwargs.get("custom_llm_provider", None) # Adjusted to use explicit arguments instead of *args and **kwargs completion_kwargs = { "model": model, @@ -289,9 +290,10 @@ async def acompletion( "model_list": model_list, "acompletion": True, # assuming this is a required parameter } - _, custom_llm_provider, _, _ = get_llm_provider( - model=model, api_base=completion_kwargs.get("base_url", None) - ) + if custom_llm_provider is None: + _, custom_llm_provider, _, _ = get_llm_provider( + model=model, api_base=completion_kwargs.get("base_url", None) + ) try: # Use a partial function to pass your keyword arguments func = partial(completion, **completion_kwargs, **kwargs) @@ -300,9 +302,6 @@ async def acompletion( ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) - _, custom_llm_provider, _, _ = get_llm_provider( - model=model, api_base=kwargs.get("api_base", None) - ) if ( custom_llm_provider == "openai" or custom_llm_provider == "azure" @@ -324,6 +323,7 @@ async def acompletion( or custom_llm_provider == "sagemaker" or custom_llm_provider == "anthropic" or custom_llm_provider == "predibase" + or (custom_llm_provider == "bedrock" and "cohere" in model) or custom_llm_provider in litellm.openai_compatible_providers ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. 
init_response = await loop.run_in_executor(None, func_with_context) @@ -1976,41 +1976,59 @@ def completion( elif custom_llm_provider == "bedrock": # boto3 reads keys from .env custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - response = bedrock.completion( - model=model, - messages=messages, - custom_prompt_dict=litellm.custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - extra_headers=extra_headers, - timeout=timeout, - ) - if ( - "stream" in optional_params - and optional_params["stream"] == True - and not isinstance(response, CustomStreamWrapper) - ): - # don't try to access stream object, - if "ai21" in model: - response = CustomStreamWrapper( - response, - model, - custom_llm_provider="bedrock", - logging_obj=logging, - ) - else: - response = CustomStreamWrapper( - iter(response), - model, - custom_llm_provider="bedrock", - logging_obj=logging, - ) + if "cohere" in model: + response = bedrock_chat_completion.completion( + model=model, + messages=messages, + custom_prompt_dict=litellm.custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + extra_headers=extra_headers, + timeout=timeout, + acompletion=acompletion, + ) + else: + response = bedrock.completion( + model=model, + messages=messages, + custom_prompt_dict=litellm.custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + extra_headers=extra_headers, + timeout=timeout, + ) + + if ( + "stream" in optional_params + and optional_params["stream"] == True + and not isinstance(response, CustomStreamWrapper) + ): + # don't try to access stream object, + if "ai21" in model: + response = CustomStreamWrapper( + response, + model, + custom_llm_provider="bedrock", + logging_obj=logging, + ) + else: + response = CustomStreamWrapper( + iter(response), + model, + custom_llm_provider="bedrock", + logging_obj=logging, + ) if optional_params.get("stream", False): ## LOGGING diff --git a/litellm/proxy/example_config_yaml/custom_auth.py b/litellm/proxy/example_config_yaml/custom_auth.py index a764a647a..6cecf466c 100644 --- a/litellm/proxy/example_config_yaml/custom_auth.py +++ b/litellm/proxy/example_config_yaml/custom_auth.py @@ -1,10 +1,7 @@ from litellm.proxy._types import UserAPIKeyAuth, GenerateKeyRequest from fastapi import Request -from dotenv import load_dotenv import os -load_dotenv() - async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth: try: diff --git a/litellm/router.py b/litellm/router.py index 52fa8561d..4c0312500 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -1507,7 +1507,6 @@ class Router: return response except Exception as e: original_exception = e - """ Retry Logic diff --git a/litellm/router_strategy/least_busy.py b/litellm/router_strategy/least_busy.py index 54d44b41d..417651fb3 100644 --- a/litellm/router_strategy/least_busy.py +++ b/litellm/router_strategy/least_busy.py @@ -8,8 +8,6 @@ import dotenv, os, requests, random # type: ignore from typing import Optional - -dotenv.load_dotenv() # Loading env variables using dotenv import traceback from litellm.caching import DualCache 
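With the bedrock routing added to main.py above, Command R models go through the new httpx-based BedrockLLM client for sync, async, and streaming calls, while every other Bedrock model keeps the existing boto3 path. A short usage sketch mirroring the tests added later in this patch:

import litellm

# non-streaming
resp = litellm.completion(
    model="bedrock/cohere.command-r-plus-v1:0",
    messages=[{"role": "user", "content": "Hey! how's it going?"}],
)
print(resp.choices[0].message.content)

# streaming: chunks come from AWSEventStreamDecoder and surface as OpenAI-style deltas
stream = litellm.completion(
    model="bedrock/cohere.command-r-plus-v1:0",
    messages=[{"role": "user", "content": "Hey! how's it going?"}],
    stream=True,
    max_tokens=10,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")
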
from litellm.integrations.custom_logger import CustomLogger diff --git a/litellm/router_strategy/lowest_cost.py b/litellm/router_strategy/lowest_cost.py index 279af2ae9..fde7781b9 100644 --- a/litellm/router_strategy/lowest_cost.py +++ b/litellm/router_strategy/lowest_cost.py @@ -1,12 +1,11 @@ #### What this does #### # picks based on response time (for streaming, this is time to first token) from pydantic import BaseModel, Extra, Field, root_validator -import dotenv, os, requests, random # type: ignore +import os, requests, random # type: ignore from typing import Optional, Union, List, Dict from datetime import datetime, timedelta import random -dotenv.load_dotenv() # Loading env variables using dotenv import traceback from litellm.caching import DualCache from litellm.integrations.custom_logger import CustomLogger diff --git a/litellm/router_strategy/lowest_latency.py b/litellm/router_strategy/lowest_latency.py index afdfc1779..a7b93d344 100644 --- a/litellm/router_strategy/lowest_latency.py +++ b/litellm/router_strategy/lowest_latency.py @@ -5,8 +5,6 @@ import dotenv, os, requests, random # type: ignore from typing import Optional, Union, List, Dict from datetime import datetime, timedelta import random - -dotenv.load_dotenv() # Loading env variables using dotenv import traceback from litellm.caching import DualCache from litellm.integrations.custom_logger import CustomLogger diff --git a/litellm/router_strategy/lowest_tpm_rpm.py b/litellm/router_strategy/lowest_tpm_rpm.py index 0a7773a84..625db7048 100644 --- a/litellm/router_strategy/lowest_tpm_rpm.py +++ b/litellm/router_strategy/lowest_tpm_rpm.py @@ -4,8 +4,6 @@ import dotenv, os, requests, random from typing import Optional, Union, List, Dict from datetime import datetime - -dotenv.load_dotenv() # Loading env variables using dotenv import traceback from litellm import token_counter from litellm.caching import DualCache diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py index f7a55d970..23e55f4a3 100644 --- a/litellm/router_strategy/lowest_tpm_rpm_v2.py +++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py @@ -5,8 +5,6 @@ import dotenv, os, requests, random from typing import Optional, Union, List, Dict import datetime as datetime_og from datetime import datetime - -dotenv.load_dotenv() # Loading env variables using dotenv import traceback, asyncio, httpx import litellm from litellm import token_counter diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index bbe81f8ad..310a5f818 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -2584,6 +2584,69 @@ def test_completion_chat_sagemaker_mistral(): # test_completion_chat_sagemaker_mistral() +def response_format_tests(response: litellm.ModelResponse): + assert isinstance(response.id, str) + assert response.id != "" + + assert isinstance(response.object, str) + assert response.object != "" + + assert isinstance(response.created, int) + + assert isinstance(response.model, str) + assert response.model != "" + + assert isinstance(response.choices, list) + assert len(response.choices) == 1 + choice = response.choices[0] + assert isinstance(choice, litellm.Choices) + assert isinstance(choice.get("index"), int) + + message = choice.get("message") + assert isinstance(message, litellm.Message) + assert isinstance(message.get("role"), str) + assert message.get("role") != "" + assert isinstance(message.get("content"), str) + assert message.get("content") != "" + + assert 
choice.get("logprobs") is None + assert isinstance(choice.get("finish_reason"), str) + assert choice.get("finish_reason") != "" + + assert isinstance(response.usage, litellm.Usage) # type: ignore + assert isinstance(response.usage.prompt_tokens, int) # type: ignore + assert isinstance(response.usage.completion_tokens, int) # type: ignore + assert isinstance(response.usage.total_tokens, int) # type: ignore + + +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_completion_bedrock_command_r(sync_mode): + litellm.set_verbose = True + + if sync_mode: + response = completion( + model="bedrock/cohere.command-r-plus-v1:0", + messages=[{"role": "user", "content": "Hey! how's it going?"}], + ) + + assert isinstance(response, litellm.ModelResponse) + + response_format_tests(response=response) + else: + response = await litellm.acompletion( + model="bedrock/cohere.command-r-plus-v1:0", + messages=[{"role": "user", "content": "Hey! how's it going?"}], + ) + + assert isinstance(response, litellm.ModelResponse) + + print(f"response: {response}") + response_format_tests(response=response) + + print(f"response: {response}") + + def test_completion_bedrock_titan_null_response(): try: response = completion( diff --git a/litellm/tests/test_rules.py b/litellm/tests/test_rules.py index 7e2d7b819..0bafbf48f 100644 --- a/litellm/tests/test_rules.py +++ b/litellm/tests/test_rules.py @@ -132,12 +132,15 @@ def test_post_call_rule_streaming(): ) -def test_post_call_processing_error_async_response(): - response = asyncio.run( - acompletion( +@pytest.mark.asyncio +async def test_post_call_processing_error_async_response(): + try: + response = await acompletion( model="command-nightly", # Just used as an example messages=[{"content": "Hello, how are you?", "role": "user"}], api_base="https://openai-proxy.berriai.repl.co", # Just used as an example custom_llm_provider="openai", ) - ) + pytest.fail("This call should have failed") + except Exception as e: + pass diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 13cadf227..a948a5683 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -983,6 +983,64 @@ def test_vertex_ai_stream(): # pytest.fail(f"Error occurred: {e}") +@pytest.mark.parametrize("sync_mode", [True]) +@pytest.mark.asyncio +async def test_bedrock_cohere_command_r_streaming(sync_mode): + try: + litellm.set_verbose = True + if sync_mode: + final_chunk: Optional[litellm.ModelResponse] = None + response: litellm.CustomStreamWrapper = completion( # type: ignore + model="bedrock/cohere.command-r-plus-v1:0", + messages=messages, + max_tokens=10, # type: ignore + stream=True, + ) + complete_response = "" + # Add any assertions here to check the response + has_finish_reason = False + for idx, chunk in enumerate(response): + final_chunk = chunk + chunk, finished = streaming_format_tests(idx, chunk) + if finished: + has_finish_reason = True + break + complete_response += chunk + if has_finish_reason == False: + raise Exception("finish reason not set") + if complete_response.strip() == "": + raise Exception("Empty response received") + else: + response: litellm.CustomStreamWrapper = await litellm.acompletion( # type: ignore + model="bedrock/cohere.command-r-plus-v1:0", + messages=messages, + max_tokens=100, # type: ignore + stream=True, + ) + complete_response = "" + # Add any assertions here to check the response + has_finish_reason = False + idx = 0 + final_chunk: Optional[litellm.ModelResponse] = None + async for 
chunk in response: + final_chunk = chunk + chunk, finished = streaming_format_tests(idx, chunk) + if finished: + has_finish_reason = True + break + complete_response += chunk + idx += 1 + if has_finish_reason == False: + raise Exception("finish reason not set") + if complete_response.strip() == "": + raise Exception("Empty response received") + print(f"completion_response: {complete_response}\n\nFinalChunk: {final_chunk}") + except RateLimitError: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + def test_bedrock_claude_3_streaming(): try: litellm.set_verbose = True diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py new file mode 100644 index 000000000..0c8259682 --- /dev/null +++ b/litellm/types/llms/bedrock.py @@ -0,0 +1,63 @@ +from typing import TypedDict, Any, Union, Optional +import json +from typing_extensions import ( + Self, + Protocol, + TypeGuard, + override, + get_origin, + runtime_checkable, + Required, +) + + +class GenericStreamingChunk(TypedDict): + text: Required[str] + is_finished: Required[bool] + finish_reason: Required[str] + + +class Document(TypedDict): + title: str + snippet: str + + +class ServerSentEvent: + def __init__( + self, + *, + event: Optional[str] = None, + data: Optional[str] = None, + id: Optional[str] = None, + retry: Optional[int] = None, + ) -> None: + if data is None: + data = "" + + self._id = id + self._data = data + self._event = event or None + self._retry = retry + + @property + def event(self) -> Optional[str]: + return self._event + + @property + def id(self) -> Optional[str]: + return self._id + + @property + def retry(self) -> Optional[int]: + return self._retry + + @property + def data(self) -> str: + return self._data + + def json(self) -> Any: + return json.loads(self.data) + + @override + def __repr__(self) -> str: + return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})" diff --git a/litellm/utils.py b/litellm/utils.py index 9ba19b5e9..f77baf8bd 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -132,7 +132,6 @@ MAX_THREADS = 100 # Create a ThreadPoolExecutor executor = ThreadPoolExecutor(max_workers=MAX_THREADS) -dotenv.load_dotenv() # Loading env variables using dotenv sentry_sdk_instance = None capture_exception = None add_breadcrumb = None @@ -10474,6 +10473,12 @@ class CustomStreamWrapper: raise e def handle_bedrock_stream(self, chunk): + if "cohere" in self.model: + return { + "text": chunk["text"], + "is_finished": chunk["is_finished"], + "finish_reason": chunk["finish_reason"], + } if hasattr(chunk, "get"): chunk = chunk.get("chunk") chunk_data = json.loads(chunk.get("bytes").decode()) @@ -11322,6 +11327,7 @@ class CustomStreamWrapper: or self.custom_llm_provider == "gemini" or self.custom_llm_provider == "cached_response" or self.custom_llm_provider == "predibase" + or (self.custom_llm_provider == "bedrock" and "cohere" in self.model) or self.custom_llm_provider in litellm.openai_compatible_endpoints ): async for chunk in self.completion_stream: