diff --git a/docs/my-website/docs/proxy/user_keys.md b/docs/my-website/docs/proxy/user_keys.md index cc1d5fe821..3e0fec7fbd 100644 --- a/docs/my-website/docs/proxy/user_keys.md +++ b/docs/my-website/docs/proxy/user_keys.md @@ -173,6 +173,37 @@ console.log(message); ``` + + + +```python +from openai import OpenAI +import instructor +from pydantic import BaseModel + +my_proxy_api_key = "" # e.g. sk-1234 +my_proxy_base_url = "" # e.g. http://0.0.0.0:4000 + +# This enables response_model keyword +# from client.chat.completions.create +client = instructor.from_openai(OpenAI(api_key=my_proxy_api_key, base_url=my_proxy_base_url)) + +class UserDetail(BaseModel): + name: str + age: int + +user = client.chat.completions.create( + model="gemini-pro-flash", + response_model=UserDetail, + messages=[ + {"role": "user", "content": "Extract Jason is 25 years old"}, + ] +) + +assert isinstance(user, UserDetail) +assert user.name == "Jason" +assert user.age == 25 +``` diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index e4963a6f1a..0bc65a7f1d 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -4,6 +4,8 @@ import time import traceback from typing import List, Literal, Optional, Tuple, Union +from pydantic import BaseModel + import litellm import litellm._logging from litellm import verbose_logger @@ -14,6 +16,9 @@ from litellm.litellm_core_utils.llm_cost_calc.google import ( cost_per_token as google_cost_per_token, ) from litellm.litellm_core_utils.llm_cost_calc.utils import _generic_cost_per_character +from litellm.types.llms.openai import HttpxBinaryResponseContent +from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS + from litellm.utils import ( CallTypes, CostPerToken, @@ -469,7 +474,10 @@ def completion_cost( prompt_characters = 0 completion_tokens = 0 completion_characters = 0 - if completion_response is not None: + if completion_response is not None and ( + isinstance(completion_response, BaseModel) + or isinstance(completion_response, dict) + ): # tts returns a custom class # get input/output tokens from completion_response prompt_tokens = completion_response.get("usage", {}).get("prompt_tokens", 0) completion_tokens = completion_response.get("usage", {}).get( @@ -654,6 +662,7 @@ def response_cost_calculator( ImageResponse, TranscriptionResponse, TextCompletionResponse, + HttpxBinaryResponseContent, ], model: str, custom_llm_provider: Optional[str], @@ -687,7 +696,8 @@ def response_cost_calculator( if cache_hit is not None and cache_hit is True: response_cost = 0.0 else: - response_object._hidden_params["optional_params"] = optional_params + if isinstance(response_object, BaseModel): + response_object._hidden_params["optional_params"] = optional_params if isinstance(response_object, ImageResponse): response_cost = completion_cost( completion_response=response_object, @@ -697,12 +707,11 @@ def response_cost_calculator( ) else: if ( - model in litellm.model_cost - and custom_pricing is not None - and custom_llm_provider is True + model in litellm.model_cost or custom_pricing is True ): # override defaults if custom pricing is set base_model = model # base_model defaults to None if not set on model_info + response_cost = completion_cost( completion_response=response_object, call_type=call_type, diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 381bcc1ac9..f9f32552da 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -24,6 +24,8 @@ from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.redact_messages import ( redact_message_input_output_from_logging, ) +from litellm.types.llms.openai import HttpxBinaryResponseContent +from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS from litellm.types.utils import ( CallTypes, EmbeddingResponse, @@ -517,33 +519,36 @@ class Logging: self.model_call_details["cache_hit"] = cache_hit ## if model in model cost map - log the response cost ## else set cost to None - verbose_logger.debug(f"Model={self.model};") if ( - result is not None - and ( + result is not None and self.stream is not True + ): # handle streaming separately + if ( isinstance(result, ModelResponse) or isinstance(result, EmbeddingResponse) or isinstance(result, ImageResponse) or isinstance(result, TranscriptionResponse) or isinstance(result, TextCompletionResponse) - ) - and self.stream != True - ): # handle streaming separately - self.model_call_details["response_cost"] = ( - litellm.response_cost_calculator( - response_object=result, - model=self.model, - cache_hit=self.model_call_details.get("cache_hit", False), - custom_llm_provider=self.model_call_details.get( - "custom_llm_provider", None - ), - base_model=_get_base_model_from_metadata( - model_call_details=self.model_call_details - ), - call_type=self.call_type, - optional_params=self.optional_params, + or isinstance(result, HttpxBinaryResponseContent) # tts + ): + custom_pricing = use_custom_pricing_for_model( + litellm_params=self.litellm_params + ) + self.model_call_details["response_cost"] = ( + litellm.response_cost_calculator( + response_object=result, + model=self.model, + cache_hit=self.model_call_details.get("cache_hit", False), + custom_llm_provider=self.model_call_details.get( + "custom_llm_provider", None + ), + base_model=_get_base_model_from_metadata( + model_call_details=self.model_call_details + ), + call_type=self.call_type, + optional_params=self.optional_params, + custom_pricing=custom_pricing, + ) ) - ) else: # streaming chunks + image gen. self.model_call_details["response_cost"] = None @@ -2012,3 +2017,17 @@ def get_custom_logger_compatible_class( if isinstance(callback, _PROXY_DynamicRateLimitHandler): return callback # type: ignore return None + + +def use_custom_pricing_for_model(litellm_params: Optional[dict]) -> bool: + if litellm_params is None: + return False + metadata: Optional[dict] = litellm_params.get("metadata", {}) + if metadata is None: + return False + model_info: Optional[dict] = metadata.get("model_info", {}) + if model_info is not None: + for k, v in model_info.items(): + if k in SPECIAL_MODEL_INFO_PARAMS: + return True + return False diff --git a/litellm/main.py b/litellm/main.py index 2a7759e8a5..37ae125b99 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -5022,10 +5022,9 @@ def stream_chunk_builder( for chunk in chunks: if "usage" in chunk: if "prompt_tokens" in chunk["usage"]: - prompt_tokens += chunk["usage"].get("prompt_tokens", 0) or 0 + prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0 if "completion_tokens" in chunk["usage"]: - completion_tokens += chunk["usage"].get("completion_tokens", 0) or 0 - + completion_tokens = chunk["usage"].get("completion_tokens", 0) or 0 try: response["usage"]["prompt_tokens"] = prompt_tokens or token_counter( model=model, messages=messages diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 7f4b86ec40..a7d46d5061 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,12 +1,9 @@ model_list: - - model_name: "*" + - model_name: tts litellm_params: model: "openai/*" - mock_response: "Hello world!" - litellm_settings: success_callback: ["langfuse"] - failure_callback: ["langfuse"] general_settings: alerting: ["slack"] diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 5011b64b3a..f954b9dd54 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -207,6 +207,7 @@ from litellm.router import ModelInfo as RouterModelInfo from litellm.router import updateDeployment from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler from litellm.types.llms.openai import HttpxBinaryResponseContent +from litellm.types.router import RouterGeneralSettings try: from litellm._version import version @@ -1765,7 +1766,11 @@ class ProxyConfig: if k in available_args: router_params[k] = v router = litellm.Router( - **router_params, assistants_config=assistants_config + **router_params, + assistants_config=assistants_config, + router_general_settings=RouterGeneralSettings( + async_only_mode=True # only init async clients + ), ) # type:ignore return router, router.get_model_list(), general_settings @@ -1957,7 +1962,12 @@ class ProxyConfig: ) if len(_model_list) > 0: verbose_proxy_logger.debug(f"_model_list: {_model_list}") - llm_router = litellm.Router(model_list=_model_list) + llm_router = litellm.Router( + model_list=_model_list, + router_general_settings=RouterGeneralSettings( + async_only_mode=True # only init async clients + ), + ) verbose_proxy_logger.debug(f"updated llm_router: {llm_router}") else: verbose_proxy_logger.debug(f"len new_models: {len(new_models)}") diff --git a/litellm/router.py b/litellm/router.py index b54d70dbbf..db68197a4b 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -55,6 +55,10 @@ from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2 +from litellm.router_utils.client_initalization_utils import ( + set_client, + should_initialize_sync_client, +) from litellm.router_utils.handle_error import send_llm_exception_alert from litellm.scheduler import FlowItem, Scheduler from litellm.types.llms.openai import ( @@ -79,6 +83,7 @@ from litellm.types.router import ( ModelInfo, RetryPolicy, RouterErrors, + RouterGeneralSettings, updateDeployment, updateLiteLLMParams, ) @@ -169,6 +174,7 @@ class Router: routing_strategy_args: dict = {}, # just for latency-based routing semaphore: Optional[asyncio.Semaphore] = None, alerting_config: Optional[AlertingConfig] = None, + router_general_settings: Optional[RouterGeneralSettings] = None, ) -> None: """ Initialize the Router class with the given parameters for caching, reliability, and routing strategy. @@ -246,6 +252,9 @@ class Router: verbose_router_logger.setLevel(logging.INFO) elif debug_level == "DEBUG": verbose_router_logger.setLevel(logging.DEBUG) + self.router_general_settings: Optional[RouterGeneralSettings] = ( + router_general_settings + ) self.assistants_config = assistants_config self.deployment_names: List = ( @@ -3247,520 +3256,6 @@ class Router: except Exception as e: raise e - def set_client(self, model: dict): - """ - - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278 - - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994 - """ - client_ttl = self.client_ttl - litellm_params = model.get("litellm_params", {}) - model_name = litellm_params.get("model") - model_id = model["model_info"]["id"] - # ### IF RPM SET - initialize a semaphore ### - rpm = litellm_params.get("rpm", None) - tpm = litellm_params.get("tpm", None) - max_parallel_requests = litellm_params.get("max_parallel_requests", None) - calculated_max_parallel_requests = calculate_max_parallel_requests( - rpm=rpm, - max_parallel_requests=max_parallel_requests, - tpm=tpm, - default_max_parallel_requests=self.default_max_parallel_requests, - ) - if calculated_max_parallel_requests: - semaphore = asyncio.Semaphore(calculated_max_parallel_requests) - cache_key = f"{model_id}_max_parallel_requests_client" - self.cache.set_cache( - key=cache_key, - value=semaphore, - local_only=True, - ) - - #### for OpenAI / Azure we need to initalize the Client for High Traffic ######## - custom_llm_provider = litellm_params.get("custom_llm_provider") - custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or "" - default_api_base = None - default_api_key = None - if custom_llm_provider in litellm.openai_compatible_providers: - _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider( - model=model_name - ) - default_api_base = api_base - default_api_key = api_key - - if ( - model_name in litellm.open_ai_chat_completion_models - or custom_llm_provider in litellm.openai_compatible_providers - or custom_llm_provider == "azure" - or custom_llm_provider == "azure_text" - or custom_llm_provider == "custom_openai" - or custom_llm_provider == "openai" - or custom_llm_provider == "text-completion-openai" - or "ft:gpt-3.5-turbo" in model_name - or model_name in litellm.open_ai_embedding_models - ): - is_azure_ai_studio_model: bool = False - if custom_llm_provider == "azure": - if litellm.utils._is_non_openai_azure_model(model_name): - is_azure_ai_studio_model = True - custom_llm_provider = "openai" - # remove azure prefx from model_name - model_name = model_name.replace("azure/", "") - # glorified / complicated reading of configs - # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env - # we do this here because we init clients for Azure, OpenAI and we need to set the right key - api_key = litellm_params.get("api_key") or default_api_key - if ( - api_key - and isinstance(api_key, str) - and api_key.startswith("os.environ/") - ): - api_key_env_name = api_key.replace("os.environ/", "") - api_key = litellm.get_secret(api_key_env_name) - litellm_params["api_key"] = api_key - - api_base = litellm_params.get("api_base") - base_url = litellm_params.get("base_url") - api_base = ( - api_base or base_url or default_api_base - ) # allow users to pass in `api_base` or `base_url` for azure - if api_base and api_base.startswith("os.environ/"): - api_base_env_name = api_base.replace("os.environ/", "") - api_base = litellm.get_secret(api_base_env_name) - litellm_params["api_base"] = api_base - - ## AZURE AI STUDIO MISTRAL CHECK ## - """ - Make sure api base ends in /v1/ - - if not, add it - https://github.com/BerriAI/litellm/issues/2279 - """ - if ( - is_azure_ai_studio_model is True - and api_base is not None - and isinstance(api_base, str) - and not api_base.endswith("/v1/") - ): - # check if it ends with a trailing slash - if api_base.endswith("/"): - api_base += "v1/" - elif api_base.endswith("/v1"): - api_base += "/" - else: - api_base += "/v1/" - - api_version = litellm_params.get("api_version") - if api_version and api_version.startswith("os.environ/"): - api_version_env_name = api_version.replace("os.environ/", "") - api_version = litellm.get_secret(api_version_env_name) - litellm_params["api_version"] = api_version - - timeout = litellm_params.pop("timeout", None) or litellm.request_timeout - if isinstance(timeout, str) and timeout.startswith("os.environ/"): - timeout_env_name = timeout.replace("os.environ/", "") - timeout = litellm.get_secret(timeout_env_name) - litellm_params["timeout"] = timeout - - stream_timeout = litellm_params.pop( - "stream_timeout", timeout - ) # if no stream_timeout is set, default to timeout - if isinstance(stream_timeout, str) and stream_timeout.startswith( - "os.environ/" - ): - stream_timeout_env_name = stream_timeout.replace("os.environ/", "") - stream_timeout = litellm.get_secret(stream_timeout_env_name) - litellm_params["stream_timeout"] = stream_timeout - - max_retries = litellm_params.pop( - "max_retries", 0 - ) # router handles retry logic - if isinstance(max_retries, str) and max_retries.startswith("os.environ/"): - max_retries_env_name = max_retries.replace("os.environ/", "") - max_retries = litellm.get_secret(max_retries_env_name) - litellm_params["max_retries"] = max_retries - - # proxy support - import os - - import httpx - - # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly. - http_proxy = os.getenv("HTTP_PROXY", None) - https_proxy = os.getenv("HTTPS_PROXY", None) - no_proxy = os.getenv("NO_PROXY", None) - - # Create the proxies dictionary only if the environment variables are set. - sync_proxy_mounts = None - async_proxy_mounts = None - if http_proxy is not None and https_proxy is not None: - sync_proxy_mounts = { - "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)), - "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)), - } - async_proxy_mounts = { - "http://": httpx.AsyncHTTPTransport( - proxy=httpx.Proxy(url=http_proxy) - ), - "https://": httpx.AsyncHTTPTransport( - proxy=httpx.Proxy(url=https_proxy) - ), - } - - # assume no_proxy is a list of comma separated urls - if no_proxy is not None and isinstance(no_proxy, str): - no_proxy_urls = no_proxy.split(",") - - for url in no_proxy_urls: # set no-proxy support for specific urls - sync_proxy_mounts[url] = None # type: ignore - async_proxy_mounts[url] = None # type: ignore - - organization = litellm_params.get("organization", None) - if isinstance(organization, str) and organization.startswith("os.environ/"): - organization_env_name = organization.replace("os.environ/", "") - organization = litellm.get_secret(organization_env_name) - litellm_params["organization"] = organization - - if custom_llm_provider == "azure" or custom_llm_provider == "azure_text": - if api_base is None or not isinstance(api_base, str): - filtered_litellm_params = { - k: v - for k, v in model["litellm_params"].items() - if k != "api_key" - } - _filtered_model = { - "model_name": model["model_name"], - "litellm_params": filtered_litellm_params, - } - raise ValueError( - f"api_base is required for Azure OpenAI. Set it on your config. Model - {_filtered_model}" - ) - azure_ad_token = litellm_params.get("azure_ad_token") - if azure_ad_token is not None: - if azure_ad_token.startswith("oidc/"): - azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) - if api_version is None: - api_version = litellm.AZURE_DEFAULT_API_VERSION - - if "gateway.ai.cloudflare.com" in api_base: - if not api_base.endswith("/"): - api_base += "/" - azure_model = model_name.replace("azure/", "") - api_base += f"{azure_model}" - cache_key = f"{model_id}_async_client" - _client = openai.AsyncAzureOpenAI( - api_key=api_key, - azure_ad_token=azure_ad_token, - base_url=api_base, - api_version=api_version, - timeout=timeout, - max_retries=max_retries, - http_client=httpx.AsyncClient( - transport=AsyncCustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=async_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - cache_key = f"{model_id}_client" - _client = openai.AzureOpenAI( # type: ignore - api_key=api_key, - azure_ad_token=azure_ad_token, - base_url=api_base, - api_version=api_version, - timeout=timeout, - max_retries=max_retries, - http_client=httpx.Client( - transport=CustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=sync_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - # streaming clients can have diff timeouts - cache_key = f"{model_id}_stream_async_client" - _client = openai.AsyncAzureOpenAI( # type: ignore - api_key=api_key, - azure_ad_token=azure_ad_token, - base_url=api_base, - api_version=api_version, - timeout=stream_timeout, - max_retries=max_retries, - http_client=httpx.AsyncClient( - transport=AsyncCustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=async_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - cache_key = f"{model_id}_stream_client" - _client = openai.AzureOpenAI( # type: ignore - api_key=api_key, - azure_ad_token=azure_ad_token, - base_url=api_base, - api_version=api_version, - timeout=stream_timeout, - max_retries=max_retries, - http_client=httpx.Client( - transport=CustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=sync_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - else: - _api_key = api_key - if _api_key is not None and isinstance(_api_key, str): - # only show first 5 chars of api_key - _api_key = _api_key[:8] + "*" * 15 - verbose_router_logger.debug( - f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}" - ) - azure_client_params = { - "api_key": api_key, - "azure_endpoint": api_base, - "api_version": api_version, - "azure_ad_token": azure_ad_token, - } - from litellm.llms.azure import select_azure_base_url_or_endpoint - - # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client - # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client - azure_client_params = select_azure_base_url_or_endpoint( - azure_client_params - ) - - cache_key = f"{model_id}_async_client" - _client = openai.AsyncAzureOpenAI( # type: ignore - **azure_client_params, - timeout=timeout, - max_retries=max_retries, - http_client=httpx.AsyncClient( - transport=AsyncCustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=async_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - cache_key = f"{model_id}_client" - _client = openai.AzureOpenAI( # type: ignore - **azure_client_params, - timeout=timeout, - max_retries=max_retries, - http_client=httpx.Client( - transport=CustomHTTPTransport( - verify=litellm.ssl_verify, - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - ), - mounts=sync_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - # streaming clients should have diff timeouts - cache_key = f"{model_id}_stream_async_client" - _client = openai.AsyncAzureOpenAI( # type: ignore - **azure_client_params, - timeout=stream_timeout, - max_retries=max_retries, - http_client=httpx.AsyncClient( - transport=AsyncCustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=async_proxy_mounts, - ), - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - cache_key = f"{model_id}_stream_client" - _client = openai.AzureOpenAI( # type: ignore - **azure_client_params, - timeout=stream_timeout, - max_retries=max_retries, - http_client=httpx.Client( - transport=CustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=sync_proxy_mounts, - ), - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - else: - _api_key = api_key # type: ignore - if _api_key is not None and isinstance(_api_key, str): - # only show first 5 chars of api_key - _api_key = _api_key[:8] + "*" * 15 - verbose_router_logger.debug( - f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{_api_key}" - ) - cache_key = f"{model_id}_async_client" - _client = openai.AsyncOpenAI( # type: ignore - api_key=api_key, - base_url=api_base, - timeout=timeout, - max_retries=max_retries, - organization=organization, - http_client=httpx.AsyncClient( - transport=AsyncCustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=async_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - cache_key = f"{model_id}_client" - _client = openai.OpenAI( # type: ignore - api_key=api_key, - base_url=api_base, - timeout=timeout, - max_retries=max_retries, - organization=organization, - http_client=httpx.Client( - transport=CustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=sync_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - # streaming clients should have diff timeouts - cache_key = f"{model_id}_stream_async_client" - _client = openai.AsyncOpenAI( # type: ignore - api_key=api_key, - base_url=api_base, - timeout=stream_timeout, - max_retries=max_retries, - organization=organization, - http_client=httpx.AsyncClient( - transport=AsyncCustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=async_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - - # streaming clients should have diff timeouts - cache_key = f"{model_id}_stream_client" - _client = openai.OpenAI( # type: ignore - api_key=api_key, - base_url=api_base, - timeout=stream_timeout, - max_retries=max_retries, - organization=organization, - http_client=httpx.Client( - transport=CustomHTTPTransport( - limits=httpx.Limits( - max_connections=1000, max_keepalive_connections=100 - ), - verify=litellm.ssl_verify, - ), - mounts=sync_proxy_mounts, - ), # type: ignore - ) - self.cache.set_cache( - key=cache_key, - value=_client, - ttl=client_ttl, - local_only=True, - ) # cache for 1 hr - def _generate_model_id(self, model_group: str, litellm_params: dict): """ Helper function to consistently generate the same id for a deployment @@ -3904,7 +3399,9 @@ class Router: raise Exception(f"Unsupported provider - {custom_llm_provider}") # init OpenAI, Azure clients - self.set_client(model=deployment.to_json(exclude_none=True)) + set_client( + litellm_router_instance=self, model=deployment.to_json(exclude_none=True) + ) # set region (if azure model) ## PREVIEW FEATURE ## if litellm.enable_preview_features == True: @@ -4432,7 +3929,7 @@ class Router: """ Re-initialize the client """ - self.set_client(model=deployment) + set_client(litellm_router_instance=self, model=deployment) client = self.cache.get_cache(key=cache_key, local_only=True) return client else: @@ -4442,7 +3939,7 @@ class Router: """ Re-initialize the client """ - self.set_client(model=deployment) + set_client(litellm_router_instance=self, model=deployment) client = self.cache.get_cache(key=cache_key, local_only=True) return client else: @@ -4453,7 +3950,7 @@ class Router: """ Re-initialize the client """ - self.set_client(model=deployment) + set_client(litellm_router_instance=self, model=deployment) client = self.cache.get_cache(key=cache_key) return client else: @@ -4463,7 +3960,7 @@ class Router: """ Re-initialize the client """ - self.set_client(model=deployment) + set_client(litellm_router_instance=self, model=deployment) client = self.cache.get_cache(key=cache_key) return client diff --git a/litellm/router_utils/client_initalization_utils.py b/litellm/router_utils/client_initalization_utils.py new file mode 100644 index 0000000000..0160ffda13 --- /dev/null +++ b/litellm/router_utils/client_initalization_utils.py @@ -0,0 +1,566 @@ +import asyncio +import traceback +from typing import TYPE_CHECKING, Any + +import openai + +import litellm +from litellm._logging import verbose_router_logger +from litellm.llms.azure import get_azure_ad_token_from_oidc +from litellm.llms.custom_httpx.azure_dall_e_2 import ( + AsyncCustomHTTPTransport, + CustomHTTPTransport, +) +from litellm.utils import calculate_max_parallel_requests + +if TYPE_CHECKING: + from litellm.router import Router as _Router + + LitellmRouter = _Router +else: + LitellmRouter = Any + + +def should_initialize_sync_client( + litellm_router_instance: LitellmRouter, +) -> bool: + """ + Returns if Sync OpenAI, Azure Clients should be initialized. + + Do not init sync clients when router.router_general_settings.async_only_mode is True + + """ + if litellm_router_instance is None: + return False + + if litellm_router_instance.router_general_settings is not None: + if ( + hasattr(litellm_router_instance, "router_general_settings") + and hasattr( + litellm_router_instance.router_general_settings, "async_only_mode" + ) + and litellm_router_instance.router_general_settings.async_only_mode is True + ): + return False + + return True + + +def set_client(litellm_router_instance: LitellmRouter, model: dict): + """ + - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278 + - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994 + """ + client_ttl = litellm_router_instance.client_ttl + litellm_params = model.get("litellm_params", {}) + model_name = litellm_params.get("model") + model_id = model["model_info"]["id"] + # ### IF RPM SET - initialize a semaphore ### + rpm = litellm_params.get("rpm", None) + tpm = litellm_params.get("tpm", None) + max_parallel_requests = litellm_params.get("max_parallel_requests", None) + calculated_max_parallel_requests = calculate_max_parallel_requests( + rpm=rpm, + max_parallel_requests=max_parallel_requests, + tpm=tpm, + default_max_parallel_requests=litellm_router_instance.default_max_parallel_requests, + ) + if calculated_max_parallel_requests: + semaphore = asyncio.Semaphore(calculated_max_parallel_requests) + cache_key = f"{model_id}_max_parallel_requests_client" + litellm_router_instance.cache.set_cache( + key=cache_key, + value=semaphore, + local_only=True, + ) + + #### for OpenAI / Azure we need to initalize the Client for High Traffic ######## + custom_llm_provider = litellm_params.get("custom_llm_provider") + custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or "" + default_api_base = None + default_api_key = None + if custom_llm_provider in litellm.openai_compatible_providers: + _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider( + model=model_name + ) + default_api_base = api_base + default_api_key = api_key + + if ( + model_name in litellm.open_ai_chat_completion_models + or custom_llm_provider in litellm.openai_compatible_providers + or custom_llm_provider == "azure" + or custom_llm_provider == "azure_text" + or custom_llm_provider == "custom_openai" + or custom_llm_provider == "openai" + or custom_llm_provider == "text-completion-openai" + or "ft:gpt-3.5-turbo" in model_name + or model_name in litellm.open_ai_embedding_models + ): + is_azure_ai_studio_model: bool = False + if custom_llm_provider == "azure": + if litellm.utils._is_non_openai_azure_model(model_name): + is_azure_ai_studio_model = True + custom_llm_provider = "openai" + # remove azure prefx from model_name + model_name = model_name.replace("azure/", "") + # glorified / complicated reading of configs + # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env + # we do this here because we init clients for Azure, OpenAI and we need to set the right key + api_key = litellm_params.get("api_key") or default_api_key + if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"): + api_key_env_name = api_key.replace("os.environ/", "") + api_key = litellm.get_secret(api_key_env_name) + litellm_params["api_key"] = api_key + + api_base = litellm_params.get("api_base") + base_url = litellm_params.get("base_url") + api_base = ( + api_base or base_url or default_api_base + ) # allow users to pass in `api_base` or `base_url` for azure + if api_base and api_base.startswith("os.environ/"): + api_base_env_name = api_base.replace("os.environ/", "") + api_base = litellm.get_secret(api_base_env_name) + litellm_params["api_base"] = api_base + + ## AZURE AI STUDIO MISTRAL CHECK ## + """ + Make sure api base ends in /v1/ + + if not, add it - https://github.com/BerriAI/litellm/issues/2279 + """ + if ( + is_azure_ai_studio_model is True + and api_base is not None + and isinstance(api_base, str) + and not api_base.endswith("/v1/") + ): + # check if it ends with a trailing slash + if api_base.endswith("/"): + api_base += "v1/" + elif api_base.endswith("/v1"): + api_base += "/" + else: + api_base += "/v1/" + + api_version = litellm_params.get("api_version") + if api_version and api_version.startswith("os.environ/"): + api_version_env_name = api_version.replace("os.environ/", "") + api_version = litellm.get_secret(api_version_env_name) + litellm_params["api_version"] = api_version + + timeout = litellm_params.pop("timeout", None) or litellm.request_timeout + if isinstance(timeout, str) and timeout.startswith("os.environ/"): + timeout_env_name = timeout.replace("os.environ/", "") + timeout = litellm.get_secret(timeout_env_name) + litellm_params["timeout"] = timeout + + stream_timeout = litellm_params.pop( + "stream_timeout", timeout + ) # if no stream_timeout is set, default to timeout + if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"): + stream_timeout_env_name = stream_timeout.replace("os.environ/", "") + stream_timeout = litellm.get_secret(stream_timeout_env_name) + litellm_params["stream_timeout"] = stream_timeout + + max_retries = litellm_params.pop("max_retries", 0) # router handles retry logic + if isinstance(max_retries, str) and max_retries.startswith("os.environ/"): + max_retries_env_name = max_retries.replace("os.environ/", "") + max_retries = litellm.get_secret(max_retries_env_name) + litellm_params["max_retries"] = max_retries + + # proxy support + import os + + import httpx + + # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly. + http_proxy = os.getenv("HTTP_PROXY", None) + https_proxy = os.getenv("HTTPS_PROXY", None) + no_proxy = os.getenv("NO_PROXY", None) + + # Create the proxies dictionary only if the environment variables are set. + sync_proxy_mounts = None + async_proxy_mounts = None + if http_proxy is not None and https_proxy is not None: + sync_proxy_mounts = { + "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)), + "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)), + } + async_proxy_mounts = { + "http://": httpx.AsyncHTTPTransport(proxy=httpx.Proxy(url=http_proxy)), + "https://": httpx.AsyncHTTPTransport( + proxy=httpx.Proxy(url=https_proxy) + ), + } + + # assume no_proxy is a list of comma separated urls + if no_proxy is not None and isinstance(no_proxy, str): + no_proxy_urls = no_proxy.split(",") + + for url in no_proxy_urls: # set no-proxy support for specific urls + sync_proxy_mounts[url] = None # type: ignore + async_proxy_mounts[url] = None # type: ignore + + organization = litellm_params.get("organization", None) + if isinstance(organization, str) and organization.startswith("os.environ/"): + organization_env_name = organization.replace("os.environ/", "") + organization = litellm.get_secret(organization_env_name) + litellm_params["organization"] = organization + + if custom_llm_provider == "azure" or custom_llm_provider == "azure_text": + if api_base is None or not isinstance(api_base, str): + filtered_litellm_params = { + k: v for k, v in model["litellm_params"].items() if k != "api_key" + } + _filtered_model = { + "model_name": model["model_name"], + "litellm_params": filtered_litellm_params, + } + raise ValueError( + f"api_base is required for Azure OpenAI. Set it on your config. Model - {_filtered_model}" + ) + azure_ad_token = litellm_params.get("azure_ad_token") + if azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) + if api_version is None: + api_version = litellm.AZURE_DEFAULT_API_VERSION + + if "gateway.ai.cloudflare.com" in api_base: + if not api_base.endswith("/"): + api_base += "/" + azure_model = model_name.replace("azure/", "") + api_base += f"{azure_model}" + cache_key = f"{model_id}_async_client" + _client = openai.AsyncAzureOpenAI( + api_key=api_key, + azure_ad_token=azure_ad_token, + base_url=api_base, + api_version=api_version, + timeout=timeout, + max_retries=max_retries, + http_client=httpx.AsyncClient( + transport=AsyncCustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=async_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + cache_key = f"{model_id}_client" + _client = openai.AzureOpenAI( # type: ignore + api_key=api_key, + azure_ad_token=azure_ad_token, + base_url=api_base, + api_version=api_version, + timeout=timeout, + max_retries=max_retries, + http_client=httpx.Client( + transport=CustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=sync_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + # streaming clients can have diff timeouts + cache_key = f"{model_id}_stream_async_client" + _client = openai.AsyncAzureOpenAI( # type: ignore + api_key=api_key, + azure_ad_token=azure_ad_token, + base_url=api_base, + api_version=api_version, + timeout=stream_timeout, + max_retries=max_retries, + http_client=httpx.AsyncClient( + transport=AsyncCustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=async_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + cache_key = f"{model_id}_stream_client" + _client = openai.AzureOpenAI( # type: ignore + api_key=api_key, + azure_ad_token=azure_ad_token, + base_url=api_base, + api_version=api_version, + timeout=stream_timeout, + max_retries=max_retries, + http_client=httpx.Client( + transport=CustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=sync_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + else: + _api_key = api_key + if _api_key is not None and isinstance(_api_key, str): + # only show first 5 chars of api_key + _api_key = _api_key[:8] + "*" * 15 + verbose_router_logger.debug( + f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}" + ) + azure_client_params = { + "api_key": api_key, + "azure_endpoint": api_base, + "api_version": api_version, + "azure_ad_token": azure_ad_token, + } + from litellm.llms.azure import select_azure_base_url_or_endpoint + + # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client + # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client + azure_client_params = select_azure_base_url_or_endpoint( + azure_client_params + ) + + cache_key = f"{model_id}_async_client" + _client = openai.AsyncAzureOpenAI( # type: ignore + **azure_client_params, + timeout=timeout, + max_retries=max_retries, + http_client=httpx.AsyncClient( + transport=AsyncCustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=async_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + cache_key = f"{model_id}_client" + _client = openai.AzureOpenAI( # type: ignore + **azure_client_params, + timeout=timeout, + max_retries=max_retries, + http_client=httpx.Client( + transport=CustomHTTPTransport( + verify=litellm.ssl_verify, + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + ), + mounts=sync_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + # streaming clients should have diff timeouts + cache_key = f"{model_id}_stream_async_client" + _client = openai.AsyncAzureOpenAI( # type: ignore + **azure_client_params, + timeout=stream_timeout, + max_retries=max_retries, + http_client=httpx.AsyncClient( + transport=AsyncCustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=async_proxy_mounts, + ), + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + cache_key = f"{model_id}_stream_client" + _client = openai.AzureOpenAI( # type: ignore + **azure_client_params, + timeout=stream_timeout, + max_retries=max_retries, + http_client=httpx.Client( + transport=CustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=sync_proxy_mounts, + ), + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + else: + _api_key = api_key # type: ignore + if _api_key is not None and isinstance(_api_key, str): + # only show first 5 chars of api_key + _api_key = _api_key[:8] + "*" * 15 + verbose_router_logger.debug( + f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{_api_key}" + ) + cache_key = f"{model_id}_async_client" + _client = openai.AsyncOpenAI( # type: ignore + api_key=api_key, + base_url=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + http_client=httpx.AsyncClient( + transport=AsyncCustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=async_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + cache_key = f"{model_id}_client" + _client = openai.OpenAI( # type: ignore + api_key=api_key, + base_url=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + http_client=httpx.Client( + transport=CustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=sync_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + # streaming clients should have diff timeouts + cache_key = f"{model_id}_stream_async_client" + _client = openai.AsyncOpenAI( # type: ignore + api_key=api_key, + base_url=api_base, + timeout=stream_timeout, + max_retries=max_retries, + organization=organization, + http_client=httpx.AsyncClient( + transport=AsyncCustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=async_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr + + if should_initialize_sync_client( + litellm_router_instance=litellm_router_instance + ): + # streaming clients should have diff timeouts + cache_key = f"{model_id}_stream_client" + _client = openai.OpenAI( # type: ignore + api_key=api_key, + base_url=api_base, + timeout=stream_timeout, + max_retries=max_retries, + organization=organization, + http_client=httpx.Client( + transport=CustomHTTPTransport( + limits=httpx.Limits( + max_connections=1000, max_keepalive_connections=100 + ), + verify=litellm.ssl_verify, + ), + mounts=sync_proxy_mounts, + ), # type: ignore + ) + litellm_router_instance.cache.set_cache( + key=cache_key, + value=_client, + ttl=client_ttl, + local_only=True, + ) # cache for 1 hr diff --git a/litellm/tests/test_router_init.py b/litellm/tests/test_router_init.py index f0f0cc541c..13167c10f0 100644 --- a/litellm/tests/test_router_init.py +++ b/litellm/tests/test_router_init.py @@ -1,16 +1,22 @@ # this tests if the router is initialized correctly -import sys, os, time -import traceback, asyncio +import asyncio +import os +import sys +import time +import traceback + import pytest sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor + +from dotenv import load_dotenv + import litellm from litellm import Router -from concurrent.futures import ThreadPoolExecutor -from collections import defaultdict -from dotenv import load_dotenv load_dotenv() @@ -24,6 +30,7 @@ load_dotenv() def test_init_clients(): litellm.set_verbose = True import logging + from litellm._logging import verbose_router_logger verbose_router_logger.setLevel(logging.DEBUG) @@ -489,6 +496,7 @@ def test_init_clients_azure_command_r_plus(): # For azure/command-r-plus we need to use openai.OpenAI because of how the Azure provider requires requests being sent litellm.set_verbose = True import logging + from litellm._logging import verbose_router_logger verbose_router_logger.setLevel(logging.DEBUG) @@ -585,3 +593,46 @@ async def test_text_completion_with_organization(): except Exception as e: pytest.fail(f"Error occurred: {e}") + + +def test_init_clients_async_mode(): + litellm.set_verbose = True + import logging + + from litellm._logging import verbose_router_logger + from litellm.types.router import RouterGeneralSettings + + verbose_router_logger.setLevel(logging.DEBUG) + try: + print("testing init 4 clients with diff timeouts") + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + "timeout": 0.01, + "stream_timeout": 0.000_001, + "max_retries": 7, + }, + }, + ] + router = Router( + model_list=model_list, + set_verbose=True, + router_general_settings=RouterGeneralSettings(async_only_mode=True), + ) + for elem in router.model_list: + model_id = elem["model_info"]["id"] + + # sync clients not initialized in async_only_mode=True + assert router.cache.get_cache(f"{model_id}_client") is None + assert router.cache.get_cache(f"{model_id}_stream_client") is None + + # only async clients initialized in async_only_mode=True + assert router.cache.get_cache(f"{model_id}_async_client") is not None + assert router.cache.get_cache(f"{model_id}_stream_async_client") is not None + except Exception as e: + pytest.fail(f"Error occurred: {e}") diff --git a/litellm/tests/test_stream_chunk_builder.py b/litellm/tests/test_stream_chunk_builder.py index 001ae07e09..342b070ae7 100644 --- a/litellm/tests/test_stream_chunk_builder.py +++ b/litellm/tests/test_stream_chunk_builder.py @@ -1,15 +1,22 @@ -import sys, os, time -import traceback, asyncio +import asyncio +import os +import sys +import time +import traceback + import pytest sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -from litellm import completion, stream_chunk_builder -import litellm -import os, dotenv -from openai import OpenAI +import os + +import dotenv import pytest +from openai import OpenAI + +import litellm +from litellm import completion, stream_chunk_builder dotenv.load_dotenv() @@ -147,3 +154,45 @@ def test_stream_chunk_builder_litellm_tool_call_regular_message(): # test_stream_chunk_builder_litellm_tool_call_regular_message() + + +def test_stream_chunk_builder_litellm_usage_chunks(): + """ + Checks if stream_chunk_builder is able to correctly rebuild with given metadata from streaming chunks + """ + messages = [ + {"role": "user", "content": "Tell me the funniest joke you know."}, + { + "role": "assistant", + "content": "Why did the chicken cross the road?\nYou will not guess this one I bet\n", + }, + {"role": "user", "content": "I do not know, why?"}, + {"role": "assistant", "content": "uhhhh\n\n\nhmmmm.....\nthinking....\n"}, + {"role": "user", "content": "\nI am waiting...\n\n...\n"}, + ] + # make a regular gemini call + response = completion( + model="gemini/gemini-1.5-flash", + messages=messages, + ) + + usage: litellm.Usage = response.usage + + gemini_pt = usage.prompt_tokens + + # make a streaming gemini call + response = completion( + model="gemini/gemini-1.5-flash", + messages=messages, + stream=True, + complete_response=True, + stream_options={"include_usage": True}, + ) + + usage: litellm.Usage = response.usage + + stream_rebuilt_pt = usage.prompt_tokens + + # assert prompt tokens are the same + + assert gemini_pt == stream_rebuilt_pt diff --git a/litellm/types/router.py b/litellm/types/router.py index 78d516d6c7..9c028fa87a 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -324,7 +324,12 @@ class DeploymentTypedDict(TypedDict): litellm_params: LiteLLMParamsTypedDict -SPECIAL_MODEL_INFO_PARAMS = ["input_cost_per_token", "output_cost_per_token"] +SPECIAL_MODEL_INFO_PARAMS = [ + "input_cost_per_token", + "output_cost_per_token", + "input_cost_per_character", + "output_cost_per_character", +] class Deployment(BaseModel): @@ -517,3 +522,9 @@ class CustomRoutingStrategyBase: """ pass + + +class RouterGeneralSettings(BaseModel): + async_only_mode: bool = Field( + default=False + ) # this will only initialize async clients. Good for memory utils diff --git a/ui/litellm-dashboard/src/components/model_dashboard.tsx b/ui/litellm-dashboard/src/components/model_dashboard.tsx index fb189b8c44..5456fb3029 100644 --- a/ui/litellm-dashboard/src/components/model_dashboard.tsx +++ b/ui/litellm-dashboard/src/components/model_dashboard.tsx @@ -743,7 +743,7 @@ const ModelDashboard: React.FC = ({ } const fetchModelMap = async () => { - const data = await modelCostMap(); + const data = await modelCostMap(accessToken); console.log(`received model cost map data: ${Object.keys(data)}`); setModelMap(data); }; diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index 9060c339e1..75819af58c 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -12,11 +12,19 @@ export interface Model { model_info: Object | null; } -export const modelCostMap = async () => { +export const modelCostMap = async ( + accessToken: string, +) => { try { const url = proxyBaseUrl ? `${proxyBaseUrl}/get/litellm_model_cost_map` : `/get/litellm_model_cost_map`; const response = await fetch( - url + url, { + method: "GET", + headers: { + Authorization: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + } ); const jsonData = await response.json(); console.log(`received litellm model cost data: ${jsonData}`); @@ -693,6 +701,9 @@ export const claimOnboardingToken = async ( throw error; } }; +let ModelListerrorShown = false; +let errorTimer: NodeJS.Timeout | null = null; + export const modelInfoCall = async ( accessToken: String, userID: String, @@ -714,8 +725,21 @@ export const modelInfoCall = async ( }); if (!response.ok) { - const errorData = await response.text(); - message.error(errorData, 10); + let errorData = await response.text(); + errorData += `error shown=${ModelListerrorShown}` + if (!ModelListerrorShown) { + if (errorData.includes("No model list passed")) { + errorData = "No Models Exist. Click Add Model to get started."; + } + message.info(errorData, 10); + ModelListerrorShown = true; + + if (errorTimer) clearTimeout(errorTimer); + errorTimer = setTimeout(() => { + ModelListerrorShown = false; + }, 10000); + } + throw new Error("Network response was not ok"); } @@ -750,7 +774,6 @@ export const modelHubCall = async (accessToken: String) => { if (!response.ok) { const errorData = await response.text(); - message.error(errorData, 10); throw new Error("Network response was not ok"); } diff --git a/ui/litellm-dashboard/src/components/usage.tsx b/ui/litellm-dashboard/src/components/usage.tsx index 1ac91dd109..451eaefe67 100644 --- a/ui/litellm-dashboard/src/components/usage.tsx +++ b/ui/litellm-dashboard/src/components/usage.tsx @@ -32,7 +32,6 @@ import { allTagNamesCall, modelMetricsCall, modelAvailableCall, - modelInfoCall, adminspendByProvider, adminGlobalActivity, adminGlobalActivityPerModel,