diff --git a/docs/my-website/docs/proxy/user_keys.md b/docs/my-website/docs/proxy/user_keys.md
index cc1d5fe82..3e0fec7fb 100644
--- a/docs/my-website/docs/proxy/user_keys.md
+++ b/docs/my-website/docs/proxy/user_keys.md
@@ -173,6 +173,37 @@ console.log(message);
```
+
+
+
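+Example with the [instructor](https://github.com/jxnl/instructor) library - point the OpenAI client at the proxy, and `instructor` enables the `response_model` keyword for structured (Pydantic) outputs: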
+```python
+from openai import OpenAI
+import instructor
+from pydantic import BaseModel
+
+my_proxy_api_key = "" # e.g. sk-1234
+my_proxy_base_url = "" # e.g. http://0.0.0.0:4000
+
+# This enables response_model keyword
+# from client.chat.completions.create
+client = instructor.from_openai(OpenAI(api_key=my_proxy_api_key, base_url=my_proxy_base_url))
+
+class UserDetail(BaseModel):
+ name: str
+ age: int
+
+user = client.chat.completions.create(
+ model="gemini-pro-flash",
+ response_model=UserDetail,
+ messages=[
+ {"role": "user", "content": "Extract Jason is 25 years old"},
+ ]
+)
+
+assert isinstance(user, UserDetail)
+assert user.name == "Jason"
+assert user.age == 25
+```
diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py
index e4963a6f1..0bc65a7f1 100644
--- a/litellm/cost_calculator.py
+++ b/litellm/cost_calculator.py
@@ -4,6 +4,8 @@ import time
import traceback
from typing import List, Literal, Optional, Tuple, Union
+from pydantic import BaseModel
+
import litellm
import litellm._logging
from litellm import verbose_logger
@@ -14,6 +16,9 @@ from litellm.litellm_core_utils.llm_cost_calc.google import (
cost_per_token as google_cost_per_token,
)
from litellm.litellm_core_utils.llm_cost_calc.utils import _generic_cost_per_character
+from litellm.types.llms.openai import HttpxBinaryResponseContent
+from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
+
from litellm.utils import (
CallTypes,
CostPerToken,
@@ -469,7 +474,10 @@ def completion_cost(
prompt_characters = 0
completion_tokens = 0
completion_characters = 0
- if completion_response is not None:
+ if completion_response is not None and (
+ isinstance(completion_response, BaseModel)
+ or isinstance(completion_response, dict)
+ ): # tts returns a custom class
# get input/output tokens from completion_response
prompt_tokens = completion_response.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = completion_response.get("usage", {}).get(
@@ -654,6 +662,7 @@ def response_cost_calculator(
ImageResponse,
TranscriptionResponse,
TextCompletionResponse,
+ HttpxBinaryResponseContent,
],
model: str,
custom_llm_provider: Optional[str],
@@ -687,7 +696,8 @@ def response_cost_calculator(
if cache_hit is not None and cache_hit is True:
response_cost = 0.0
else:
- response_object._hidden_params["optional_params"] = optional_params
+ if isinstance(response_object, BaseModel):
+ response_object._hidden_params["optional_params"] = optional_params
if isinstance(response_object, ImageResponse):
response_cost = completion_cost(
completion_response=response_object,
@@ -697,12 +707,11 @@ def response_cost_calculator(
)
else:
if (
- model in litellm.model_cost
- and custom_pricing is not None
- and custom_llm_provider is True
+ model in litellm.model_cost or custom_pricing is True
): # override defaults if custom pricing is set
base_model = model
# base_model defaults to None if not set on model_info
+
response_cost = completion_cost(
completion_response=response_object,
call_type=call_type,
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 381bcc1ac..f9f32552d 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -24,6 +24,8 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_logging,
)
+from litellm.types.llms.openai import HttpxBinaryResponseContent
+from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
from litellm.types.utils import (
CallTypes,
EmbeddingResponse,
@@ -517,33 +519,36 @@ class Logging:
self.model_call_details["cache_hit"] = cache_hit
## if model in model cost map - log the response cost
## else set cost to None
- verbose_logger.debug(f"Model={self.model};")
if (
- result is not None
- and (
+ result is not None and self.stream is not True
+ ): # handle streaming separately
+ if (
isinstance(result, ModelResponse)
or isinstance(result, EmbeddingResponse)
or isinstance(result, ImageResponse)
or isinstance(result, TranscriptionResponse)
or isinstance(result, TextCompletionResponse)
- )
- and self.stream != True
- ): # handle streaming separately
- self.model_call_details["response_cost"] = (
- litellm.response_cost_calculator(
- response_object=result,
- model=self.model,
- cache_hit=self.model_call_details.get("cache_hit", False),
- custom_llm_provider=self.model_call_details.get(
- "custom_llm_provider", None
- ),
- base_model=_get_base_model_from_metadata(
- model_call_details=self.model_call_details
- ),
- call_type=self.call_type,
- optional_params=self.optional_params,
+ or isinstance(result, HttpxBinaryResponseContent) # tts
+ ):
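+                # custom pricing = deployment-level cost overrides (e.g. input_cost_per_token) set via model_info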
+ custom_pricing = use_custom_pricing_for_model(
+ litellm_params=self.litellm_params
+ )
+ self.model_call_details["response_cost"] = (
+ litellm.response_cost_calculator(
+ response_object=result,
+ model=self.model,
+ cache_hit=self.model_call_details.get("cache_hit", False),
+ custom_llm_provider=self.model_call_details.get(
+ "custom_llm_provider", None
+ ),
+ base_model=_get_base_model_from_metadata(
+ model_call_details=self.model_call_details
+ ),
+ call_type=self.call_type,
+ optional_params=self.optional_params,
+ custom_pricing=custom_pricing,
+ )
)
- )
else: # streaming chunks + image gen.
self.model_call_details["response_cost"] = None
@@ -2012,3 +2017,17 @@ def get_custom_logger_compatible_class(
if isinstance(callback, _PROXY_DynamicRateLimitHandler):
return callback # type: ignore
return None
+
+
+def use_custom_pricing_for_model(litellm_params: Optional[dict]) -> bool:
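+    """
+    Returns True if the deployment's model_info sets any custom pricing params
+    (SPECIAL_MODEL_INFO_PARAMS, e.g. input_cost_per_token), else False.
+    """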
+ if litellm_params is None:
+ return False
+ metadata: Optional[dict] = litellm_params.get("metadata", {})
+ if metadata is None:
+ return False
+ model_info: Optional[dict] = metadata.get("model_info", {})
+ if model_info is not None:
+ for k, v in model_info.items():
+ if k in SPECIAL_MODEL_INFO_PARAMS:
+ return True
+ return False
diff --git a/litellm/main.py b/litellm/main.py
index 2a7759e8a..37ae125b9 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -5022,10 +5022,9 @@ def stream_chunk_builder(
for chunk in chunks:
if "usage" in chunk:
if "prompt_tokens" in chunk["usage"]:
- prompt_tokens += chunk["usage"].get("prompt_tokens", 0) or 0
+ prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0
if "completion_tokens" in chunk["usage"]:
- completion_tokens += chunk["usage"].get("completion_tokens", 0) or 0
-
+ completion_tokens = chunk["usage"].get("completion_tokens", 0) or 0
try:
response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
model=model, messages=messages
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 7f4b86ec4..a7d46d506 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,12 +1,9 @@
model_list:
- - model_name: "*"
+ - model_name: tts
litellm_params:
model: "openai/*"
- mock_response: "Hello world!"
-
litellm_settings:
success_callback: ["langfuse"]
- failure_callback: ["langfuse"]
general_settings:
alerting: ["slack"]
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 5011b64b3..f954b9dd5 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -207,6 +207,7 @@ from litellm.router import ModelInfo as RouterModelInfo
from litellm.router import updateDeployment
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
from litellm.types.llms.openai import HttpxBinaryResponseContent
+from litellm.types.router import RouterGeneralSettings
try:
from litellm._version import version
@@ -1765,7 +1766,11 @@ class ProxyConfig:
if k in available_args:
router_params[k] = v
router = litellm.Router(
- **router_params, assistants_config=assistants_config
+ **router_params,
+ assistants_config=assistants_config,
+ router_general_settings=RouterGeneralSettings(
+ async_only_mode=True # only init async clients
+ ),
) # type:ignore
return router, router.get_model_list(), general_settings
@@ -1957,7 +1962,12 @@ class ProxyConfig:
)
if len(_model_list) > 0:
verbose_proxy_logger.debug(f"_model_list: {_model_list}")
- llm_router = litellm.Router(model_list=_model_list)
+ llm_router = litellm.Router(
+ model_list=_model_list,
+ router_general_settings=RouterGeneralSettings(
+ async_only_mode=True # only init async clients
+ ),
+ )
verbose_proxy_logger.debug(f"updated llm_router: {llm_router}")
else:
verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
diff --git a/litellm/router.py b/litellm/router.py
index b54d70dbb..db68197a4 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -55,6 +55,10 @@ from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
+from litellm.router_utils.client_initalization_utils import (
+ set_client,
+ should_initialize_sync_client,
+)
from litellm.router_utils.handle_error import send_llm_exception_alert
from litellm.scheduler import FlowItem, Scheduler
from litellm.types.llms.openai import (
@@ -79,6 +83,7 @@ from litellm.types.router import (
ModelInfo,
RetryPolicy,
RouterErrors,
+ RouterGeneralSettings,
updateDeployment,
updateLiteLLMParams,
)
@@ -169,6 +174,7 @@ class Router:
routing_strategy_args: dict = {}, # just for latency-based routing
semaphore: Optional[asyncio.Semaphore] = None,
alerting_config: Optional[AlertingConfig] = None,
+ router_general_settings: Optional[RouterGeneralSettings] = None,
) -> None:
"""
Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
@@ -246,6 +252,9 @@ class Router:
verbose_router_logger.setLevel(logging.INFO)
elif debug_level == "DEBUG":
verbose_router_logger.setLevel(logging.DEBUG)
+ self.router_general_settings: Optional[RouterGeneralSettings] = (
+ router_general_settings
+ )
self.assistants_config = assistants_config
self.deployment_names: List = (
@@ -3247,520 +3256,6 @@ class Router:
except Exception as e:
raise e
- def set_client(self, model: dict):
- """
- - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
- - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
- """
- client_ttl = self.client_ttl
- litellm_params = model.get("litellm_params", {})
- model_name = litellm_params.get("model")
- model_id = model["model_info"]["id"]
- # ### IF RPM SET - initialize a semaphore ###
- rpm = litellm_params.get("rpm", None)
- tpm = litellm_params.get("tpm", None)
- max_parallel_requests = litellm_params.get("max_parallel_requests", None)
- calculated_max_parallel_requests = calculate_max_parallel_requests(
- rpm=rpm,
- max_parallel_requests=max_parallel_requests,
- tpm=tpm,
- default_max_parallel_requests=self.default_max_parallel_requests,
- )
- if calculated_max_parallel_requests:
- semaphore = asyncio.Semaphore(calculated_max_parallel_requests)
- cache_key = f"{model_id}_max_parallel_requests_client"
- self.cache.set_cache(
- key=cache_key,
- value=semaphore,
- local_only=True,
- )
-
- #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
- custom_llm_provider = litellm_params.get("custom_llm_provider")
- custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
- default_api_base = None
- default_api_key = None
- if custom_llm_provider in litellm.openai_compatible_providers:
- _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(
- model=model_name
- )
- default_api_base = api_base
- default_api_key = api_key
-
- if (
- model_name in litellm.open_ai_chat_completion_models
- or custom_llm_provider in litellm.openai_compatible_providers
- or custom_llm_provider == "azure"
- or custom_llm_provider == "azure_text"
- or custom_llm_provider == "custom_openai"
- or custom_llm_provider == "openai"
- or custom_llm_provider == "text-completion-openai"
- or "ft:gpt-3.5-turbo" in model_name
- or model_name in litellm.open_ai_embedding_models
- ):
- is_azure_ai_studio_model: bool = False
- if custom_llm_provider == "azure":
- if litellm.utils._is_non_openai_azure_model(model_name):
- is_azure_ai_studio_model = True
- custom_llm_provider = "openai"
- # remove azure prefx from model_name
- model_name = model_name.replace("azure/", "")
- # glorified / complicated reading of configs
- # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
- # we do this here because we init clients for Azure, OpenAI and we need to set the right key
- api_key = litellm_params.get("api_key") or default_api_key
- if (
- api_key
- and isinstance(api_key, str)
- and api_key.startswith("os.environ/")
- ):
- api_key_env_name = api_key.replace("os.environ/", "")
- api_key = litellm.get_secret(api_key_env_name)
- litellm_params["api_key"] = api_key
-
- api_base = litellm_params.get("api_base")
- base_url = litellm_params.get("base_url")
- api_base = (
- api_base or base_url or default_api_base
- ) # allow users to pass in `api_base` or `base_url` for azure
- if api_base and api_base.startswith("os.environ/"):
- api_base_env_name = api_base.replace("os.environ/", "")
- api_base = litellm.get_secret(api_base_env_name)
- litellm_params["api_base"] = api_base
-
- ## AZURE AI STUDIO MISTRAL CHECK ##
- """
- Make sure api base ends in /v1/
-
- if not, add it - https://github.com/BerriAI/litellm/issues/2279
- """
- if (
- is_azure_ai_studio_model is True
- and api_base is not None
- and isinstance(api_base, str)
- and not api_base.endswith("/v1/")
- ):
- # check if it ends with a trailing slash
- if api_base.endswith("/"):
- api_base += "v1/"
- elif api_base.endswith("/v1"):
- api_base += "/"
- else:
- api_base += "/v1/"
-
- api_version = litellm_params.get("api_version")
- if api_version and api_version.startswith("os.environ/"):
- api_version_env_name = api_version.replace("os.environ/", "")
- api_version = litellm.get_secret(api_version_env_name)
- litellm_params["api_version"] = api_version
-
- timeout = litellm_params.pop("timeout", None) or litellm.request_timeout
- if isinstance(timeout, str) and timeout.startswith("os.environ/"):
- timeout_env_name = timeout.replace("os.environ/", "")
- timeout = litellm.get_secret(timeout_env_name)
- litellm_params["timeout"] = timeout
-
- stream_timeout = litellm_params.pop(
- "stream_timeout", timeout
- ) # if no stream_timeout is set, default to timeout
- if isinstance(stream_timeout, str) and stream_timeout.startswith(
- "os.environ/"
- ):
- stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
- stream_timeout = litellm.get_secret(stream_timeout_env_name)
- litellm_params["stream_timeout"] = stream_timeout
-
- max_retries = litellm_params.pop(
- "max_retries", 0
- ) # router handles retry logic
- if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
- max_retries_env_name = max_retries.replace("os.environ/", "")
- max_retries = litellm.get_secret(max_retries_env_name)
- litellm_params["max_retries"] = max_retries
-
- # proxy support
- import os
-
- import httpx
-
- # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
- http_proxy = os.getenv("HTTP_PROXY", None)
- https_proxy = os.getenv("HTTPS_PROXY", None)
- no_proxy = os.getenv("NO_PROXY", None)
-
- # Create the proxies dictionary only if the environment variables are set.
- sync_proxy_mounts = None
- async_proxy_mounts = None
- if http_proxy is not None and https_proxy is not None:
- sync_proxy_mounts = {
- "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
- "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)),
- }
- async_proxy_mounts = {
- "http://": httpx.AsyncHTTPTransport(
- proxy=httpx.Proxy(url=http_proxy)
- ),
- "https://": httpx.AsyncHTTPTransport(
- proxy=httpx.Proxy(url=https_proxy)
- ),
- }
-
- # assume no_proxy is a list of comma separated urls
- if no_proxy is not None and isinstance(no_proxy, str):
- no_proxy_urls = no_proxy.split(",")
-
- for url in no_proxy_urls: # set no-proxy support for specific urls
- sync_proxy_mounts[url] = None # type: ignore
- async_proxy_mounts[url] = None # type: ignore
-
- organization = litellm_params.get("organization", None)
- if isinstance(organization, str) and organization.startswith("os.environ/"):
- organization_env_name = organization.replace("os.environ/", "")
- organization = litellm.get_secret(organization_env_name)
- litellm_params["organization"] = organization
-
- if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
- if api_base is None or not isinstance(api_base, str):
- filtered_litellm_params = {
- k: v
- for k, v in model["litellm_params"].items()
- if k != "api_key"
- }
- _filtered_model = {
- "model_name": model["model_name"],
- "litellm_params": filtered_litellm_params,
- }
- raise ValueError(
- f"api_base is required for Azure OpenAI. Set it on your config. Model - {_filtered_model}"
- )
- azure_ad_token = litellm_params.get("azure_ad_token")
- if azure_ad_token is not None:
- if azure_ad_token.startswith("oidc/"):
- azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
- if api_version is None:
- api_version = litellm.AZURE_DEFAULT_API_VERSION
-
- if "gateway.ai.cloudflare.com" in api_base:
- if not api_base.endswith("/"):
- api_base += "/"
- azure_model = model_name.replace("azure/", "")
- api_base += f"{azure_model}"
- cache_key = f"{model_id}_async_client"
- _client = openai.AsyncAzureOpenAI(
- api_key=api_key,
- azure_ad_token=azure_ad_token,
- base_url=api_base,
- api_version=api_version,
- timeout=timeout,
- max_retries=max_retries,
- http_client=httpx.AsyncClient(
- transport=AsyncCustomHTTPTransport(
- limits=httpx.Limits(
- max_connections=1000, max_keepalive_connections=100
- ),
- verify=litellm.ssl_verify,
- ),
- mounts=async_proxy_mounts,
- ), # type: ignore
- )
- self.cache.set_cache(
- key=cache_key,
- value=_client,
- ttl=client_ttl,
- local_only=True,
- ) # cache for 1 hr
-
- cache_key = f"{model_id}_client"
- _client = openai.AzureOpenAI( # type: ignore
- api_key=api_key,
- azure_ad_token=azure_ad_token,
- base_url=api_base,
- api_version=api_version,
- timeout=timeout,
- max_retries=max_retries,
- http_client=httpx.Client(
- transport=CustomHTTPTransport(
- limits=httpx.Limits(
- max_connections=1000, max_keepalive_connections=100
- ),
- verify=litellm.ssl_verify,
- ),
- mounts=sync_proxy_mounts,
- ), # type: ignore
- )
- self.cache.set_cache(
- key=cache_key,
- value=_client,
- ttl=client_ttl,
- local_only=True,
- ) # cache for 1 hr
- # streaming clients can have diff timeouts
- cache_key = f"{model_id}_stream_async_client"
- _client = openai.AsyncAzureOpenAI( # type: ignore
- api_key=api_key,
- azure_ad_token=azure_ad_token,
- base_url=api_base,
- api_version=api_version,
- timeout=stream_timeout,
- max_retries=max_retries,
- http_client=httpx.AsyncClient(
- transport=AsyncCustomHTTPTransport(
- limits=httpx.Limits(
- max_connections=1000, max_keepalive_connections=100
- ),
- verify=litellm.ssl_verify,
- ),
- mounts=async_proxy_mounts,
- ), # type: ignore
- )
- self.cache.set_cache(
- key=cache_key,
- value=_client,
- ttl=client_ttl,
- local_only=True,
- ) # cache for 1 hr
-
- cache_key = f"{model_id}_stream_client"
- _client = openai.AzureOpenAI( # type: ignore
- api_key=api_key,
- azure_ad_token=azure_ad_token,
- base_url=api_base,
- api_version=api_version,
- timeout=stream_timeout,
- max_retries=max_retries,
- http_client=httpx.Client(
- transport=CustomHTTPTransport(
- limits=httpx.Limits(
- max_connections=1000, max_keepalive_connections=100
- ),
- verify=litellm.ssl_verify,
- ),
- mounts=sync_proxy_mounts,
- ), # type: ignore
- )
- self.cache.set_cache(
- key=cache_key,
- value=_client,
- ttl=client_ttl,
- local_only=True,
- ) # cache for 1 hr
- else:
- _api_key = api_key
- if _api_key is not None and isinstance(_api_key, str):
- # only show first 5 chars of api_key
- _api_key = _api_key[:8] + "*" * 15
- verbose_router_logger.debug(
- f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
- )
- azure_client_params = {
- "api_key": api_key,
- "azure_endpoint": api_base,
- "api_version": api_version,
- "azure_ad_token": azure_ad_token,
- }
- from litellm.llms.azure import select_azure_base_url_or_endpoint
-
- # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
- # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
- azure_client_params = select_azure_base_url_or_endpoint(
- azure_client_params
- )
-
- cache_key = f"{model_id}_async_client"
- _client = openai.AsyncAzureOpenAI( # type: ignore
- **azure_client_params,
- timeout=timeout,
- max_retries=max_retries,
- http_client=httpx.AsyncClient(
- transport=AsyncCustomHTTPTransport(
- limits=httpx.Limits(
- max_connections=1000, max_keepalive_connections=100
- ),
- verify=litellm.ssl_verify,
- ),
- mounts=async_proxy_mounts,
- ), # type: ignore
- )
- self.cache.set_cache(
- key=cache_key,
- value=_client,
- ttl=client_ttl,
- local_only=True,
- ) # cache for 1 hr
-
- cache_key = f"{model_id}_client"
- _client = openai.AzureOpenAI( # type: ignore
- **azure_client_params,
- timeout=timeout,
- max_retries=max_retries,
- http_client=httpx.Client(
- transport=CustomHTTPTransport(
- verify=litellm.ssl_verify,
- limits=httpx.Limits(
- max_connections=1000, max_keepalive_connections=100
- ),
- ),
- mounts=sync_proxy_mounts,
- ), # type: ignore
- )
- self.cache.set_cache(
- key=cache_key,
- value=_client,
- ttl=client_ttl,
- local_only=True,
- ) # cache for 1 hr
-
- # streaming clients should have diff timeouts
- cache_key = f"{model_id}_stream_async_client"
- _client = openai.AsyncAzureOpenAI( # type: ignore
- **azure_client_params,
- timeout=stream_timeout,
- max_retries=max_retries,
- http_client=httpx.AsyncClient(
- transport=AsyncCustomHTTPTransport(
- limits=httpx.Limits(
- max_connections=1000, max_keepalive_connections=100
- ),
- verify=litellm.ssl_verify,
- ),
- mounts=async_proxy_mounts,
- ),
- )
- self.cache.set_cache(
- key=cache_key,
- value=_client,
- ttl=client_ttl,
- local_only=True,
- ) # cache for 1 hr
-
- cache_key = f"{model_id}_stream_client"
- _client = openai.AzureOpenAI( # type: ignore
- **azure_client_params,
- timeout=stream_timeout,
- max_retries=max_retries,
- http_client=httpx.Client(
- transport=CustomHTTPTransport(
- limits=httpx.Limits(
- max_connections=1000, max_keepalive_connections=100
- ),
- verify=litellm.ssl_verify,
- ),
- mounts=sync_proxy_mounts,
- ),
- )
- self.cache.set_cache(
- key=cache_key,
- value=_client,
- ttl=client_ttl,
- local_only=True,
- ) # cache for 1 hr
-
- else:
- _api_key = api_key # type: ignore
- if _api_key is not None and isinstance(_api_key, str):
- # only show first 5 chars of api_key
- _api_key = _api_key[:8] + "*" * 15
- verbose_router_logger.debug(
- f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{_api_key}"
- )
- cache_key = f"{model_id}_async_client"
- _client = openai.AsyncOpenAI( # type: ignore
- api_key=api_key,
- base_url=api_base,
- timeout=timeout,
- max_retries=max_retries,
- organization=organization,
- http_client=httpx.AsyncClient(
- transport=AsyncCustomHTTPTransport(
- limits=httpx.Limits(
- max_connections=1000, max_keepalive_connections=100
- ),
- verify=litellm.ssl_verify,
- ),
- mounts=async_proxy_mounts,
- ), # type: ignore
- )
- self.cache.set_cache(
- key=cache_key,
- value=_client,
- ttl=client_ttl,
- local_only=True,
- ) # cache for 1 hr
-
- cache_key = f"{model_id}_client"
- _client = openai.OpenAI( # type: ignore
- api_key=api_key,
- base_url=api_base,
- timeout=timeout,
- max_retries=max_retries,
- organization=organization,
- http_client=httpx.Client(
- transport=CustomHTTPTransport(
- limits=httpx.Limits(
- max_connections=1000, max_keepalive_connections=100
- ),
- verify=litellm.ssl_verify,
- ),
- mounts=sync_proxy_mounts,
- ), # type: ignore
- )
- self.cache.set_cache(
- key=cache_key,
- value=_client,
- ttl=client_ttl,
- local_only=True,
- ) # cache for 1 hr
-
- # streaming clients should have diff timeouts
- cache_key = f"{model_id}_stream_async_client"
- _client = openai.AsyncOpenAI( # type: ignore
- api_key=api_key,
- base_url=api_base,
- timeout=stream_timeout,
- max_retries=max_retries,
- organization=organization,
- http_client=httpx.AsyncClient(
- transport=AsyncCustomHTTPTransport(
- limits=httpx.Limits(
- max_connections=1000, max_keepalive_connections=100
- ),
- verify=litellm.ssl_verify,
- ),
- mounts=async_proxy_mounts,
- ), # type: ignore
- )
- self.cache.set_cache(
- key=cache_key,
- value=_client,
- ttl=client_ttl,
- local_only=True,
- ) # cache for 1 hr
-
- # streaming clients should have diff timeouts
- cache_key = f"{model_id}_stream_client"
- _client = openai.OpenAI( # type: ignore
- api_key=api_key,
- base_url=api_base,
- timeout=stream_timeout,
- max_retries=max_retries,
- organization=organization,
- http_client=httpx.Client(
- transport=CustomHTTPTransport(
- limits=httpx.Limits(
- max_connections=1000, max_keepalive_connections=100
- ),
- verify=litellm.ssl_verify,
- ),
- mounts=sync_proxy_mounts,
- ), # type: ignore
- )
- self.cache.set_cache(
- key=cache_key,
- value=_client,
- ttl=client_ttl,
- local_only=True,
- ) # cache for 1 hr
-
def _generate_model_id(self, model_group: str, litellm_params: dict):
"""
Helper function to consistently generate the same id for a deployment
@@ -3904,7 +3399,9 @@ class Router:
raise Exception(f"Unsupported provider - {custom_llm_provider}")
# init OpenAI, Azure clients
- self.set_client(model=deployment.to_json(exclude_none=True))
+ set_client(
+ litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
+ )
# set region (if azure model) ## PREVIEW FEATURE ##
if litellm.enable_preview_features == True:
@@ -4432,7 +3929,7 @@ class Router:
"""
Re-initialize the client
"""
- self.set_client(model=deployment)
+ set_client(litellm_router_instance=self, model=deployment)
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
else:
@@ -4442,7 +3939,7 @@ class Router:
"""
Re-initialize the client
"""
- self.set_client(model=deployment)
+ set_client(litellm_router_instance=self, model=deployment)
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
else:
@@ -4453,7 +3950,7 @@ class Router:
"""
Re-initialize the client
"""
- self.set_client(model=deployment)
+ set_client(litellm_router_instance=self, model=deployment)
client = self.cache.get_cache(key=cache_key)
return client
else:
@@ -4463,7 +3960,7 @@ class Router:
"""
Re-initialize the client
"""
- self.set_client(model=deployment)
+ set_client(litellm_router_instance=self, model=deployment)
client = self.cache.get_cache(key=cache_key)
return client
diff --git a/litellm/router_utils/client_initalization_utils.py b/litellm/router_utils/client_initalization_utils.py
new file mode 100644
index 000000000..0160ffda1
--- /dev/null
+++ b/litellm/router_utils/client_initalization_utils.py
@@ -0,0 +1,566 @@
+import asyncio
+import traceback
+from typing import TYPE_CHECKING, Any
+
+import openai
+
+import litellm
+from litellm._logging import verbose_router_logger
+from litellm.llms.azure import get_azure_ad_token_from_oidc
+from litellm.llms.custom_httpx.azure_dall_e_2 import (
+ AsyncCustomHTTPTransport,
+ CustomHTTPTransport,
+)
+from litellm.utils import calculate_max_parallel_requests
+
+if TYPE_CHECKING:
+ from litellm.router import Router as _Router
+
+ LitellmRouter = _Router
+else:
+ LitellmRouter = Any
+
+
+def should_initialize_sync_client(
+ litellm_router_instance: LitellmRouter,
+) -> bool:
+ """
+    Returns whether sync OpenAI/Azure clients should be initialized.
+
+    Sync clients are not initialized when router.router_general_settings.async_only_mode is True.
+
+ """
+ if litellm_router_instance is None:
+ return False
+
+ if litellm_router_instance.router_general_settings is not None:
+ if (
+ hasattr(litellm_router_instance, "router_general_settings")
+ and hasattr(
+ litellm_router_instance.router_general_settings, "async_only_mode"
+ )
+ and litellm_router_instance.router_general_settings.async_only_mode is True
+ ):
+ return False
+
+ return True
+
+
+def set_client(litellm_router_instance: LitellmRouter, model: dict):
+ """
+ - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+ - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
+ """
+ client_ttl = litellm_router_instance.client_ttl
+ litellm_params = model.get("litellm_params", {})
+ model_name = litellm_params.get("model")
+ model_id = model["model_info"]["id"]
+ # ### IF RPM SET - initialize a semaphore ###
+ rpm = litellm_params.get("rpm", None)
+ tpm = litellm_params.get("tpm", None)
+ max_parallel_requests = litellm_params.get("max_parallel_requests", None)
+ calculated_max_parallel_requests = calculate_max_parallel_requests(
+ rpm=rpm,
+ max_parallel_requests=max_parallel_requests,
+ tpm=tpm,
+ default_max_parallel_requests=litellm_router_instance.default_max_parallel_requests,
+ )
+ if calculated_max_parallel_requests:
+ semaphore = asyncio.Semaphore(calculated_max_parallel_requests)
+ cache_key = f"{model_id}_max_parallel_requests_client"
+ litellm_router_instance.cache.set_cache(
+ key=cache_key,
+ value=semaphore,
+ local_only=True,
+ )
+
+    #### for OpenAI / Azure we need to initialize the Client for High Traffic ########
+ custom_llm_provider = litellm_params.get("custom_llm_provider")
+ custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
+ default_api_base = None
+ default_api_key = None
+ if custom_llm_provider in litellm.openai_compatible_providers:
+ _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(
+ model=model_name
+ )
+ default_api_base = api_base
+ default_api_key = api_key
+
+ if (
+ model_name in litellm.open_ai_chat_completion_models
+ or custom_llm_provider in litellm.openai_compatible_providers
+ or custom_llm_provider == "azure"
+ or custom_llm_provider == "azure_text"
+ or custom_llm_provider == "custom_openai"
+ or custom_llm_provider == "openai"
+ or custom_llm_provider == "text-completion-openai"
+ or "ft:gpt-3.5-turbo" in model_name
+ or model_name in litellm.open_ai_embedding_models
+ ):
+ is_azure_ai_studio_model: bool = False
+ if custom_llm_provider == "azure":
+ if litellm.utils._is_non_openai_azure_model(model_name):
+ is_azure_ai_studio_model = True
+ custom_llm_provider = "openai"
+            # remove azure prefix from model_name
+ model_name = model_name.replace("azure/", "")
+ # glorified / complicated reading of configs
+        # user can pass vars directly or they can pass os.environ/AZURE_API_KEY, in which case we will read the env
+ # we do this here because we init clients for Azure, OpenAI and we need to set the right key
+ api_key = litellm_params.get("api_key") or default_api_key
+ if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"):
+ api_key_env_name = api_key.replace("os.environ/", "")
+ api_key = litellm.get_secret(api_key_env_name)
+ litellm_params["api_key"] = api_key
+
+ api_base = litellm_params.get("api_base")
+ base_url = litellm_params.get("base_url")
+ api_base = (
+ api_base or base_url or default_api_base
+ ) # allow users to pass in `api_base` or `base_url` for azure
+ if api_base and api_base.startswith("os.environ/"):
+ api_base_env_name = api_base.replace("os.environ/", "")
+ api_base = litellm.get_secret(api_base_env_name)
+ litellm_params["api_base"] = api_base
+
+ ## AZURE AI STUDIO MISTRAL CHECK ##
+ """
+ Make sure api base ends in /v1/
+
+ if not, add it - https://github.com/BerriAI/litellm/issues/2279
+ """
+ if (
+ is_azure_ai_studio_model is True
+ and api_base is not None
+ and isinstance(api_base, str)
+ and not api_base.endswith("/v1/")
+ ):
+ # check if it ends with a trailing slash
+ if api_base.endswith("/"):
+ api_base += "v1/"
+ elif api_base.endswith("/v1"):
+ api_base += "/"
+ else:
+ api_base += "/v1/"
+
+ api_version = litellm_params.get("api_version")
+ if api_version and api_version.startswith("os.environ/"):
+ api_version_env_name = api_version.replace("os.environ/", "")
+ api_version = litellm.get_secret(api_version_env_name)
+ litellm_params["api_version"] = api_version
+
+ timeout = litellm_params.pop("timeout", None) or litellm.request_timeout
+ if isinstance(timeout, str) and timeout.startswith("os.environ/"):
+ timeout_env_name = timeout.replace("os.environ/", "")
+ timeout = litellm.get_secret(timeout_env_name)
+ litellm_params["timeout"] = timeout
+
+ stream_timeout = litellm_params.pop(
+ "stream_timeout", timeout
+ ) # if no stream_timeout is set, default to timeout
+ if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
+ stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
+ stream_timeout = litellm.get_secret(stream_timeout_env_name)
+ litellm_params["stream_timeout"] = stream_timeout
+
+ max_retries = litellm_params.pop("max_retries", 0) # router handles retry logic
+ if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
+ max_retries_env_name = max_retries.replace("os.environ/", "")
+ max_retries = litellm.get_secret(max_retries_env_name)
+ litellm_params["max_retries"] = max_retries
+
+ # proxy support
+ import os
+
+ import httpx
+
+ # Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
+ http_proxy = os.getenv("HTTP_PROXY", None)
+ https_proxy = os.getenv("HTTPS_PROXY", None)
+ no_proxy = os.getenv("NO_PROXY", None)
+
+ # Create the proxies dictionary only if the environment variables are set.
+ sync_proxy_mounts = None
+ async_proxy_mounts = None
+ if http_proxy is not None and https_proxy is not None:
+ sync_proxy_mounts = {
+ "http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
+ "https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)),
+ }
+ async_proxy_mounts = {
+ "http://": httpx.AsyncHTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
+ "https://": httpx.AsyncHTTPTransport(
+ proxy=httpx.Proxy(url=https_proxy)
+ ),
+ }
+
+    # assume no_proxy is a list of comma-separated urls
+ if no_proxy is not None and isinstance(no_proxy, str):
+ no_proxy_urls = no_proxy.split(",")
+
+ for url in no_proxy_urls: # set no-proxy support for specific urls
+ sync_proxy_mounts[url] = None # type: ignore
+ async_proxy_mounts[url] = None # type: ignore
+
+ organization = litellm_params.get("organization", None)
+ if isinstance(organization, str) and organization.startswith("os.environ/"):
+ organization_env_name = organization.replace("os.environ/", "")
+ organization = litellm.get_secret(organization_env_name)
+ litellm_params["organization"] = organization
+
+ if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
+ if api_base is None or not isinstance(api_base, str):
+ filtered_litellm_params = {
+ k: v for k, v in model["litellm_params"].items() if k != "api_key"
+ }
+ _filtered_model = {
+ "model_name": model["model_name"],
+ "litellm_params": filtered_litellm_params,
+ }
+ raise ValueError(
+ f"api_base is required for Azure OpenAI. Set it on your config. Model - {_filtered_model}"
+ )
+ azure_ad_token = litellm_params.get("azure_ad_token")
+ if azure_ad_token is not None:
+ if azure_ad_token.startswith("oidc/"):
+ azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
+ if api_version is None:
+ api_version = litellm.AZURE_DEFAULT_API_VERSION
+
+ if "gateway.ai.cloudflare.com" in api_base:
+ if not api_base.endswith("/"):
+ api_base += "/"
+ azure_model = model_name.replace("azure/", "")
+ api_base += f"{azure_model}"
+ cache_key = f"{model_id}_async_client"
+ _client = openai.AsyncAzureOpenAI(
+ api_key=api_key,
+ azure_ad_token=azure_ad_token,
+ base_url=api_base,
+ api_version=api_version,
+ timeout=timeout,
+ max_retries=max_retries,
+ http_client=httpx.AsyncClient(
+ transport=AsyncCustomHTTPTransport(
+ limits=httpx.Limits(
+ max_connections=1000, max_keepalive_connections=100
+ ),
+ verify=litellm.ssl_verify,
+ ),
+ mounts=async_proxy_mounts,
+ ), # type: ignore
+ )
+ litellm_router_instance.cache.set_cache(
+ key=cache_key,
+ value=_client,
+ ttl=client_ttl,
+ local_only=True,
+ ) # cache for 1 hr
+
+ if should_initialize_sync_client(
+ litellm_router_instance=litellm_router_instance
+ ):
+ cache_key = f"{model_id}_client"
+ _client = openai.AzureOpenAI( # type: ignore
+ api_key=api_key,
+ azure_ad_token=azure_ad_token,
+ base_url=api_base,
+ api_version=api_version,
+ timeout=timeout,
+ max_retries=max_retries,
+ http_client=httpx.Client(
+ transport=CustomHTTPTransport(
+ limits=httpx.Limits(
+ max_connections=1000, max_keepalive_connections=100
+ ),
+ verify=litellm.ssl_verify,
+ ),
+ mounts=sync_proxy_mounts,
+ ), # type: ignore
+ )
+ litellm_router_instance.cache.set_cache(
+ key=cache_key,
+ value=_client,
+ ttl=client_ttl,
+ local_only=True,
+ ) # cache for 1 hr
+ # streaming clients can have diff timeouts
+ cache_key = f"{model_id}_stream_async_client"
+ _client = openai.AsyncAzureOpenAI( # type: ignore
+ api_key=api_key,
+ azure_ad_token=azure_ad_token,
+ base_url=api_base,
+ api_version=api_version,
+ timeout=stream_timeout,
+ max_retries=max_retries,
+ http_client=httpx.AsyncClient(
+ transport=AsyncCustomHTTPTransport(
+ limits=httpx.Limits(
+ max_connections=1000, max_keepalive_connections=100
+ ),
+ verify=litellm.ssl_verify,
+ ),
+ mounts=async_proxy_mounts,
+ ), # type: ignore
+ )
+ litellm_router_instance.cache.set_cache(
+ key=cache_key,
+ value=_client,
+ ttl=client_ttl,
+ local_only=True,
+ ) # cache for 1 hr
+
+ if should_initialize_sync_client(
+ litellm_router_instance=litellm_router_instance
+ ):
+ cache_key = f"{model_id}_stream_client"
+ _client = openai.AzureOpenAI( # type: ignore
+ api_key=api_key,
+ azure_ad_token=azure_ad_token,
+ base_url=api_base,
+ api_version=api_version,
+ timeout=stream_timeout,
+ max_retries=max_retries,
+ http_client=httpx.Client(
+ transport=CustomHTTPTransport(
+ limits=httpx.Limits(
+ max_connections=1000, max_keepalive_connections=100
+ ),
+ verify=litellm.ssl_verify,
+ ),
+ mounts=sync_proxy_mounts,
+ ), # type: ignore
+ )
+ litellm_router_instance.cache.set_cache(
+ key=cache_key,
+ value=_client,
+ ttl=client_ttl,
+ local_only=True,
+ ) # cache for 1 hr
+ else:
+ _api_key = api_key
+ if _api_key is not None and isinstance(_api_key, str):
+            # only show the first 8 chars of the api_key
+ _api_key = _api_key[:8] + "*" * 15
+ verbose_router_logger.debug(
+ f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
+ )
+ azure_client_params = {
+ "api_key": api_key,
+ "azure_endpoint": api_base,
+ "api_version": api_version,
+ "azure_ad_token": azure_ad_token,
+ }
+ from litellm.llms.azure import select_azure_base_url_or_endpoint
+
+ # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
+ # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
+ azure_client_params = select_azure_base_url_or_endpoint(
+ azure_client_params
+ )
+
+ cache_key = f"{model_id}_async_client"
+ _client = openai.AsyncAzureOpenAI( # type: ignore
+ **azure_client_params,
+ timeout=timeout,
+ max_retries=max_retries,
+ http_client=httpx.AsyncClient(
+ transport=AsyncCustomHTTPTransport(
+ limits=httpx.Limits(
+ max_connections=1000, max_keepalive_connections=100
+ ),
+ verify=litellm.ssl_verify,
+ ),
+ mounts=async_proxy_mounts,
+ ), # type: ignore
+ )
+ litellm_router_instance.cache.set_cache(
+ key=cache_key,
+ value=_client,
+ ttl=client_ttl,
+ local_only=True,
+ ) # cache for 1 hr
+ if should_initialize_sync_client(
+ litellm_router_instance=litellm_router_instance
+ ):
+ cache_key = f"{model_id}_client"
+ _client = openai.AzureOpenAI( # type: ignore
+ **azure_client_params,
+ timeout=timeout,
+ max_retries=max_retries,
+ http_client=httpx.Client(
+ transport=CustomHTTPTransport(
+ verify=litellm.ssl_verify,
+ limits=httpx.Limits(
+ max_connections=1000, max_keepalive_connections=100
+ ),
+ ),
+ mounts=sync_proxy_mounts,
+ ), # type: ignore
+ )
+ litellm_router_instance.cache.set_cache(
+ key=cache_key,
+ value=_client,
+ ttl=client_ttl,
+ local_only=True,
+ ) # cache for 1 hr
+
+ # streaming clients should have diff timeouts
+ cache_key = f"{model_id}_stream_async_client"
+ _client = openai.AsyncAzureOpenAI( # type: ignore
+ **azure_client_params,
+ timeout=stream_timeout,
+ max_retries=max_retries,
+ http_client=httpx.AsyncClient(
+ transport=AsyncCustomHTTPTransport(
+ limits=httpx.Limits(
+ max_connections=1000, max_keepalive_connections=100
+ ),
+ verify=litellm.ssl_verify,
+ ),
+ mounts=async_proxy_mounts,
+ ),
+ )
+ litellm_router_instance.cache.set_cache(
+ key=cache_key,
+ value=_client,
+ ttl=client_ttl,
+ local_only=True,
+ ) # cache for 1 hr
+
+ if should_initialize_sync_client(
+ litellm_router_instance=litellm_router_instance
+ ):
+ cache_key = f"{model_id}_stream_client"
+ _client = openai.AzureOpenAI( # type: ignore
+ **azure_client_params,
+ timeout=stream_timeout,
+ max_retries=max_retries,
+ http_client=httpx.Client(
+ transport=CustomHTTPTransport(
+ limits=httpx.Limits(
+ max_connections=1000, max_keepalive_connections=100
+ ),
+ verify=litellm.ssl_verify,
+ ),
+ mounts=sync_proxy_mounts,
+ ),
+ )
+ litellm_router_instance.cache.set_cache(
+ key=cache_key,
+ value=_client,
+ ttl=client_ttl,
+ local_only=True,
+ ) # cache for 1 hr
+
+ else:
+ _api_key = api_key # type: ignore
+ if _api_key is not None and isinstance(_api_key, str):
+            # only show the first 8 chars of the api_key
+ _api_key = _api_key[:8] + "*" * 15
+ verbose_router_logger.debug(
+ f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{_api_key}"
+ )
+ cache_key = f"{model_id}_async_client"
+ _client = openai.AsyncOpenAI( # type: ignore
+ api_key=api_key,
+ base_url=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ http_client=httpx.AsyncClient(
+ transport=AsyncCustomHTTPTransport(
+ limits=httpx.Limits(
+ max_connections=1000, max_keepalive_connections=100
+ ),
+ verify=litellm.ssl_verify,
+ ),
+ mounts=async_proxy_mounts,
+ ), # type: ignore
+ )
+ litellm_router_instance.cache.set_cache(
+ key=cache_key,
+ value=_client,
+ ttl=client_ttl,
+ local_only=True,
+ ) # cache for 1 hr
+
+ if should_initialize_sync_client(
+ litellm_router_instance=litellm_router_instance
+ ):
+ cache_key = f"{model_id}_client"
+ _client = openai.OpenAI( # type: ignore
+ api_key=api_key,
+ base_url=api_base,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ http_client=httpx.Client(
+ transport=CustomHTTPTransport(
+ limits=httpx.Limits(
+ max_connections=1000, max_keepalive_connections=100
+ ),
+ verify=litellm.ssl_verify,
+ ),
+ mounts=sync_proxy_mounts,
+ ), # type: ignore
+ )
+ litellm_router_instance.cache.set_cache(
+ key=cache_key,
+ value=_client,
+ ttl=client_ttl,
+ local_only=True,
+ ) # cache for 1 hr
+
+ # streaming clients should have diff timeouts
+ cache_key = f"{model_id}_stream_async_client"
+ _client = openai.AsyncOpenAI( # type: ignore
+ api_key=api_key,
+ base_url=api_base,
+ timeout=stream_timeout,
+ max_retries=max_retries,
+ organization=organization,
+ http_client=httpx.AsyncClient(
+ transport=AsyncCustomHTTPTransport(
+ limits=httpx.Limits(
+ max_connections=1000, max_keepalive_connections=100
+ ),
+ verify=litellm.ssl_verify,
+ ),
+ mounts=async_proxy_mounts,
+ ), # type: ignore
+ )
+ litellm_router_instance.cache.set_cache(
+ key=cache_key,
+ value=_client,
+ ttl=client_ttl,
+ local_only=True,
+ ) # cache for 1 hr
+
+ if should_initialize_sync_client(
+ litellm_router_instance=litellm_router_instance
+ ):
+ # streaming clients should have diff timeouts
+ cache_key = f"{model_id}_stream_client"
+ _client = openai.OpenAI( # type: ignore
+ api_key=api_key,
+ base_url=api_base,
+ timeout=stream_timeout,
+ max_retries=max_retries,
+ organization=organization,
+ http_client=httpx.Client(
+ transport=CustomHTTPTransport(
+ limits=httpx.Limits(
+ max_connections=1000, max_keepalive_connections=100
+ ),
+ verify=litellm.ssl_verify,
+ ),
+ mounts=sync_proxy_mounts,
+ ), # type: ignore
+ )
+ litellm_router_instance.cache.set_cache(
+ key=cache_key,
+ value=_client,
+ ttl=client_ttl,
+ local_only=True,
+ ) # cache for 1 hr
diff --git a/litellm/tests/test_router_init.py b/litellm/tests/test_router_init.py
index f0f0cc541..13167c10f 100644
--- a/litellm/tests/test_router_init.py
+++ b/litellm/tests/test_router_init.py
@@ -1,16 +1,22 @@
# this tests if the router is initialized correctly
-import sys, os, time
-import traceback, asyncio
+import asyncio
+import os
+import sys
+import time
+import traceback
+
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+
+from dotenv import load_dotenv
+
import litellm
from litellm import Router
-from concurrent.futures import ThreadPoolExecutor
-from collections import defaultdict
-from dotenv import load_dotenv
load_dotenv()
@@ -24,6 +30,7 @@ load_dotenv()
def test_init_clients():
litellm.set_verbose = True
import logging
+
from litellm._logging import verbose_router_logger
verbose_router_logger.setLevel(logging.DEBUG)
@@ -489,6 +496,7 @@ def test_init_clients_azure_command_r_plus():
# For azure/command-r-plus we need to use openai.OpenAI because of how the Azure provider requires requests being sent
litellm.set_verbose = True
import logging
+
from litellm._logging import verbose_router_logger
verbose_router_logger.setLevel(logging.DEBUG)
@@ -585,3 +593,46 @@ async def test_text_completion_with_organization():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
+
+
+def test_init_clients_async_mode():
+ litellm.set_verbose = True
+ import logging
+
+ from litellm._logging import verbose_router_logger
+ from litellm.types.router import RouterGeneralSettings
+
+ verbose_router_logger.setLevel(logging.DEBUG)
+ try:
+ print("testing init 4 clients with diff timeouts")
+ model_list = [
+ {
+ "model_name": "gpt-3.5-turbo",
+ "litellm_params": {
+ "model": "azure/chatgpt-v-2",
+ "api_key": os.getenv("AZURE_API_KEY"),
+ "api_version": os.getenv("AZURE_API_VERSION"),
+ "api_base": os.getenv("AZURE_API_BASE"),
+ "timeout": 0.01,
+ "stream_timeout": 0.000_001,
+ "max_retries": 7,
+ },
+ },
+ ]
+ router = Router(
+ model_list=model_list,
+ set_verbose=True,
+ router_general_settings=RouterGeneralSettings(async_only_mode=True),
+ )
+ for elem in router.model_list:
+ model_id = elem["model_info"]["id"]
+
+ # sync clients not initialized in async_only_mode=True
+ assert router.cache.get_cache(f"{model_id}_client") is None
+ assert router.cache.get_cache(f"{model_id}_stream_client") is None
+
+ # only async clients initialized in async_only_mode=True
+ assert router.cache.get_cache(f"{model_id}_async_client") is not None
+ assert router.cache.get_cache(f"{model_id}_stream_async_client") is not None
+ except Exception as e:
+ pytest.fail(f"Error occurred: {e}")
diff --git a/litellm/tests/test_stream_chunk_builder.py b/litellm/tests/test_stream_chunk_builder.py
index 001ae07e0..342b070ae 100644
--- a/litellm/tests/test_stream_chunk_builder.py
+++ b/litellm/tests/test_stream_chunk_builder.py
@@ -1,15 +1,22 @@
-import sys, os, time
-import traceback, asyncio
+import asyncio
+import os
+import sys
+import time
+import traceback
+
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
-from litellm import completion, stream_chunk_builder
-import litellm
-import os, dotenv
-from openai import OpenAI
+import os
+
+import dotenv
import pytest
+from openai import OpenAI
+
+import litellm
+from litellm import completion, stream_chunk_builder
dotenv.load_dotenv()
@@ -147,3 +154,45 @@ def test_stream_chunk_builder_litellm_tool_call_regular_message():
# test_stream_chunk_builder_litellm_tool_call_regular_message()
+
+
+def test_stream_chunk_builder_litellm_usage_chunks():
+ """
+    Checks that stream_chunk_builder correctly rebuilds usage (prompt tokens) from streaming chunks when stream_options={"include_usage": True} is set
+ """
+ messages = [
+ {"role": "user", "content": "Tell me the funniest joke you know."},
+ {
+ "role": "assistant",
+ "content": "Why did the chicken cross the road?\nYou will not guess this one I bet\n",
+ },
+ {"role": "user", "content": "I do not know, why?"},
+ {"role": "assistant", "content": "uhhhh\n\n\nhmmmm.....\nthinking....\n"},
+ {"role": "user", "content": "\nI am waiting...\n\n...\n"},
+ ]
+ # make a regular gemini call
+ response = completion(
+ model="gemini/gemini-1.5-flash",
+ messages=messages,
+ )
+
+ usage: litellm.Usage = response.usage
+
+ gemini_pt = usage.prompt_tokens
+
+ # make a streaming gemini call
+ response = completion(
+ model="gemini/gemini-1.5-flash",
+ messages=messages,
+ stream=True,
+ complete_response=True,
+ stream_options={"include_usage": True},
+ )
+
+ usage: litellm.Usage = response.usage
+
+ stream_rebuilt_pt = usage.prompt_tokens
+
+ # assert prompt tokens are the same
+
+ assert gemini_pt == stream_rebuilt_pt
diff --git a/litellm/types/router.py b/litellm/types/router.py
index 78d516d6c..9c028fa87 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -324,7 +324,12 @@ class DeploymentTypedDict(TypedDict):
litellm_params: LiteLLMParamsTypedDict
-SPECIAL_MODEL_INFO_PARAMS = ["input_cost_per_token", "output_cost_per_token"]
+SPECIAL_MODEL_INFO_PARAMS = [
+ "input_cost_per_token",
+ "output_cost_per_token",
+ "input_cost_per_character",
+ "output_cost_per_character",
+]
class Deployment(BaseModel):
@@ -517,3 +522,9 @@ class CustomRoutingStrategyBase:
"""
pass
+
+
+class RouterGeneralSettings(BaseModel):
+ async_only_mode: bool = Field(
+ default=False
+    )  # when True, only async clients are initialized - reduces memory utilization
diff --git a/ui/litellm-dashboard/src/components/model_dashboard.tsx b/ui/litellm-dashboard/src/components/model_dashboard.tsx
index fb189b8c4..5456fb302 100644
--- a/ui/litellm-dashboard/src/components/model_dashboard.tsx
+++ b/ui/litellm-dashboard/src/components/model_dashboard.tsx
@@ -743,7 +743,7 @@ const ModelDashboard: React.FC = ({
}
const fetchModelMap = async () => {
- const data = await modelCostMap();
+ const data = await modelCostMap(accessToken);
console.log(`received model cost map data: ${Object.keys(data)}`);
setModelMap(data);
};
diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx
index 9060c339e..75819af58 100644
--- a/ui/litellm-dashboard/src/components/networking.tsx
+++ b/ui/litellm-dashboard/src/components/networking.tsx
@@ -12,11 +12,19 @@ export interface Model {
model_info: Object | null;
}
-export const modelCostMap = async () => {
+export const modelCostMap = async (
+ accessToken: string,
+) => {
try {
const url = proxyBaseUrl ? `${proxyBaseUrl}/get/litellm_model_cost_map` : `/get/litellm_model_cost_map`;
const response = await fetch(
- url
+ url, {
+ method: "GET",
+ headers: {
+ Authorization: `Bearer ${accessToken}`,
+ "Content-Type": "application/json",
+ },
+ }
);
const jsonData = await response.json();
console.log(`received litellm model cost data: ${jsonData}`);
@@ -693,6 +701,9 @@ export const claimOnboardingToken = async (
throw error;
}
};
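+// surface the model-list error toast at most once per 10s window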
+let ModelListerrorShown = false;
+let errorTimer: NodeJS.Timeout | null = null;
+
export const modelInfoCall = async (
accessToken: String,
userID: String,
@@ -714,8 +725,21 @@ export const modelInfoCall = async (
});
if (!response.ok) {
- const errorData = await response.text();
- message.error(errorData, 10);
+ let errorData = await response.text();
+ if (!ModelListerrorShown) {
+ if (errorData.includes("No model list passed")) {
+ errorData = "No Models Exist. Click Add Model to get started.";
+ }
+ message.info(errorData, 10);
+ ModelListerrorShown = true;
+
+ if (errorTimer) clearTimeout(errorTimer);
+ errorTimer = setTimeout(() => {
+ ModelListerrorShown = false;
+ }, 10000);
+ }
+
throw new Error("Network response was not ok");
}
@@ -750,7 +774,6 @@ export const modelHubCall = async (accessToken: String) => {
if (!response.ok) {
const errorData = await response.text();
- message.error(errorData, 10);
throw new Error("Network response was not ok");
}
diff --git a/ui/litellm-dashboard/src/components/usage.tsx b/ui/litellm-dashboard/src/components/usage.tsx
index 1ac91dd10..451eaefe6 100644
--- a/ui/litellm-dashboard/src/components/usage.tsx
+++ b/ui/litellm-dashboard/src/components/usage.tsx
@@ -32,7 +32,6 @@ import {
allTagNamesCall,
modelMetricsCall,
modelAvailableCall,
- modelInfoCall,
adminspendByProvider,
adminGlobalActivity,
adminGlobalActivityPerModel,