Merge branch 'main' into litellm_tts_pricing

Krish Dholakia 2024-07-06 14:57:34 -07:00 committed by GitHub
commit 127f08ee67
14 changed files with 833 additions and 572 deletions


@@ -173,6 +173,37 @@ console.log(message);
```
</TabItem>
<TabItem value="instructor" label="Instructor">
```python
from openai import OpenAI
import instructor
from pydantic import BaseModel
my_proxy_api_key = "" # e.g. sk-1234
my_proxy_base_url = "" # e.g. http://0.0.0.0:4000
# This enables response_model keyword
# from client.chat.completions.create
client = instructor.from_openai(OpenAI(api_key=my_proxy_api_key, base_url=my_proxy_base_url))
class UserDetail(BaseModel):
name: str
age: int
user = client.chat.completions.create(
model="gemini-pro-flash",
response_model=UserDetail,
messages=[
{"role": "user", "content": "Extract Jason is 25 years old"},
]
)
assert isinstance(user, UserDetail)
assert user.name == "Jason"
assert user.age == 25
```
</TabItem>
</Tabs>


@@ -4,6 +4,8 @@ import time
import traceback
from typing import List, Literal, Optional, Tuple, Union
+from pydantic import BaseModel
import litellm
import litellm._logging
from litellm import verbose_logger
@@ -14,6 +16,9 @@ from litellm.litellm_core_utils.llm_cost_calc.google import (
cost_per_token as google_cost_per_token,
)
from litellm.litellm_core_utils.llm_cost_calc.utils import _generic_cost_per_character
+from litellm.types.llms.openai import HttpxBinaryResponseContent
+from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
from litellm.utils import (
CallTypes,
CostPerToken,
@@ -469,7 +474,10 @@ def completion_cost(
prompt_characters = 0
completion_tokens = 0
completion_characters = 0
-if completion_response is not None:
+if completion_response is not None and (
+isinstance(completion_response, BaseModel)
+or isinstance(completion_response, dict)
+): # tts returns a custom class
# get input/output tokens from completion_response
prompt_tokens = completion_response.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = completion_response.get("usage", {}).get(
@@ -654,6 +662,7 @@ def response_cost_calculator(
ImageResponse,
TranscriptionResponse,
TextCompletionResponse,
+HttpxBinaryResponseContent,
],
model: str,
custom_llm_provider: Optional[str],
@@ -687,6 +696,7 @@ def response_cost_calculator(
if cache_hit is not None and cache_hit is True:
response_cost = 0.0
else:
+if isinstance(response_object, BaseModel):
response_object._hidden_params["optional_params"] = optional_params
if isinstance(response_object, ImageResponse):
response_cost = completion_cost(
@@ -697,12 +707,11 @@ def response_cost_calculator(
)
else:
if (
-model in litellm.model_cost
-and custom_pricing is not None
-and custom_llm_provider is True
+model in litellm.model_cost or custom_pricing is True
): # override defaults if custom pricing is set
base_model = model
# base_model defaults to None if not set on model_info
response_cost = completion_cost(
completion_response=response_object,
call_type=call_type,
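
With `HttpxBinaryResponseContent` now accepted by `response_cost_calculator`, text-to-speech responses from `litellm.speech()` can be costed like other response types. A minimal sketch of observing that cost from a custom success callback; the callback signature follows litellm's custom-callback convention, and the model name plus the `response_cost` key in `kwargs` are assumptions drawn from this diff rather than guaranteed API:

```python
# Hedged sketch: print the computed cost of a TTS call via a success callback.
import litellm


def log_cost(kwargs, completion_response, start_time, end_time):
    # the logging layer (next file in this diff) stores the calculated cost in
    # model_call_details, surfaced to callbacks as kwargs["response_cost"]
    print("response_cost:", kwargs.get("response_cost"))


litellm.success_callback = [log_cost]

audio = litellm.speech(
    model="openai/tts-1",  # illustrative model name
    voice="alloy",
    input="Hello from the cost calculator",
)
```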


@@ -24,6 +24,8 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_logging,
)
+from litellm.types.llms.openai import HttpxBinaryResponseContent
+from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
from litellm.types.utils import (
CallTypes,
EmbeddingResponse,
@@ -517,18 +519,20 @@ class Logging:
self.model_call_details["cache_hit"] = cache_hit
## if model in model cost map - log the response cost
## else set cost to None
verbose_logger.debug(f"Model={self.model};")
if (
-result is not None
-and (
+result is not None and self.stream is not True
+): # handle streaming separately
+if (
isinstance(result, ModelResponse)
or isinstance(result, EmbeddingResponse)
or isinstance(result, ImageResponse)
or isinstance(result, TranscriptionResponse)
or isinstance(result, TextCompletionResponse)
+or isinstance(result, HttpxBinaryResponseContent) # tts
+):
+custom_pricing = use_custom_pricing_for_model(
+litellm_params=self.litellm_params
)
-and self.stream != True
-): # handle streaming separately
self.model_call_details["response_cost"] = (
litellm.response_cost_calculator(
response_object=result,
@@ -542,6 +546,7 @@ class Logging:
),
call_type=self.call_type,
optional_params=self.optional_params,
+custom_pricing=custom_pricing,
)
)
else: # streaming chunks + image gen.
@@ -2012,3 +2017,17 @@ def get_custom_logger_compatible_class(
if isinstance(callback, _PROXY_DynamicRateLimitHandler):
return callback # type: ignore
return None
def use_custom_pricing_for_model(litellm_params: Optional[dict]) -> bool:
if litellm_params is None:
return False
metadata: Optional[dict] = litellm_params.get("metadata", {})
if metadata is None:
return False
model_info: Optional[dict] = metadata.get("model_info", {})
if model_info is not None:
for k, v in model_info.items():
if k in SPECIAL_MODEL_INFO_PARAMS:
return True
return False
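
For reference, a hedged sketch of the input shape this helper inspects; the import path and the deployment values below are assumptions for illustration only:

```python
# Hedged sketch: any key from SPECIAL_MODEL_INFO_PARAMS found under
# metadata["model_info"] marks the deployment as using custom pricing.
from litellm.litellm_core_utils.litellm_logging import (  # module path assumed
    use_custom_pricing_for_model,
)

litellm_params = {
    "metadata": {
        "model_info": {
            "id": "my-tts-deployment",             # illustrative deployment id
            "input_cost_per_character": 0.000015,  # placeholder price
        }
    }
}

assert use_custom_pricing_for_model(litellm_params) is True
assert use_custom_pricing_for_model({"metadata": {}}) is False
assert use_custom_pricing_for_model(None) is False
```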


@@ -5022,10 +5022,9 @@ def stream_chunk_builder(
for chunk in chunks:
if "usage" in chunk:
if "prompt_tokens" in chunk["usage"]:
-prompt_tokens += chunk["usage"].get("prompt_tokens", 0) or 0
+prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0
if "completion_tokens" in chunk["usage"]:
-completion_tokens += chunk["usage"].get("completion_tokens", 0) or 0
+completion_tokens = chunk["usage"].get("completion_tokens", 0) or 0
try:
response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
model=model, messages=messages
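
The switch from `+=` to `=` assumes the usage block a provider attaches to a streamed chunk already carries the request totals (as with OpenAI's `stream_options={"include_usage": True}`), so accumulating across chunks would double count. A small sketch of the failure mode the assignment avoids:

```python
# Sketch: two chunks each carrying the same cumulative usage block. With `+=`
# the rebuilt usage would report 20 prompt tokens; with `=` it stays at 10.
chunks = [
    {"usage": {"prompt_tokens": 10, "completion_tokens": 5}},
    {"usage": {"prompt_tokens": 10, "completion_tokens": 5}},
]

prompt_tokens = 0
for chunk in chunks:
    if "usage" in chunk and "prompt_tokens" in chunk["usage"]:
        prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0  # was `+=`

assert prompt_tokens == 10
```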


@@ -1,12 +1,9 @@
model_list:
-- model_name: "*"
+- model_name: tts
litellm_params:
model: "openai/*"
-mock_response: "Hello world!"
litellm_settings:
success_callback: ["langfuse"]
-failure_callback: ["langfuse"]
general_settings:
alerting: ["slack"]


@@ -207,6 +207,7 @@ from litellm.router import ModelInfo as RouterModelInfo
from litellm.router import updateDeployment
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
from litellm.types.llms.openai import HttpxBinaryResponseContent
+from litellm.types.router import RouterGeneralSettings
try:
from litellm._version import version
@@ -1765,7 +1766,11 @@ class ProxyConfig:
if k in available_args:
router_params[k] = v
router = litellm.Router(
-**router_params, assistants_config=assistants_config
+**router_params,
+assistants_config=assistants_config,
+router_general_settings=RouterGeneralSettings(
+async_only_mode=True # only init async clients
+),
) # type:ignore
return router, router.get_model_list(), general_settings
@@ -1957,7 +1962,12 @@ class ProxyConfig:
)
if len(_model_list) > 0:
verbose_proxy_logger.debug(f"_model_list: {_model_list}")
-llm_router = litellm.Router(model_list=_model_list)
+llm_router = litellm.Router(
+model_list=_model_list,
+router_general_settings=RouterGeneralSettings(
+async_only_mode=True # only init async clients
+),
+)
verbose_proxy_logger.debug(f"updated llm_router: {llm_router}")
else:
verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")


@@ -55,6 +55,10 @@ from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler
from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2
+from litellm.router_utils.client_initalization_utils import (
+set_client,
+should_initialize_sync_client,
+)
from litellm.router_utils.handle_error import send_llm_exception_alert
from litellm.scheduler import FlowItem, Scheduler
from litellm.types.llms.openai import (
@@ -79,6 +83,7 @@ from litellm.types.router import (
ModelInfo,
RetryPolicy,
RouterErrors,
+RouterGeneralSettings,
updateDeployment,
updateLiteLLMParams,
)
@@ -169,6 +174,7 @@ class Router:
routing_strategy_args: dict = {}, # just for latency-based routing
semaphore: Optional[asyncio.Semaphore] = None,
alerting_config: Optional[AlertingConfig] = None,
+router_general_settings: Optional[RouterGeneralSettings] = None,
) -> None:
"""
Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
@@ -246,6 +252,9 @@ class Router:
verbose_router_logger.setLevel(logging.INFO)
elif debug_level == "DEBUG":
verbose_router_logger.setLevel(logging.DEBUG)
+self.router_general_settings: Optional[RouterGeneralSettings] = (
+router_general_settings
+)
self.assistants_config = assistants_config
self.deployment_names: List = (
@@ -3247,520 +3256,6 @@ class Router:
except Exception as e:
raise e
def set_client(self, model: dict):
"""
- Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
- Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
"""
client_ttl = self.client_ttl
litellm_params = model.get("litellm_params", {})
model_name = litellm_params.get("model")
model_id = model["model_info"]["id"]
# ### IF RPM SET - initialize a semaphore ###
rpm = litellm_params.get("rpm", None)
tpm = litellm_params.get("tpm", None)
max_parallel_requests = litellm_params.get("max_parallel_requests", None)
calculated_max_parallel_requests = calculate_max_parallel_requests(
rpm=rpm,
max_parallel_requests=max_parallel_requests,
tpm=tpm,
default_max_parallel_requests=self.default_max_parallel_requests,
)
if calculated_max_parallel_requests:
semaphore = asyncio.Semaphore(calculated_max_parallel_requests)
cache_key = f"{model_id}_max_parallel_requests_client"
self.cache.set_cache(
key=cache_key,
value=semaphore,
local_only=True,
)
#### for OpenAI / Azure we need to initalize the Client for High Traffic ########
custom_llm_provider = litellm_params.get("custom_llm_provider")
custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
default_api_base = None
default_api_key = None
if custom_llm_provider in litellm.openai_compatible_providers:
_, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(
model=model_name
)
default_api_base = api_base
default_api_key = api_key
if (
model_name in litellm.open_ai_chat_completion_models
or custom_llm_provider in litellm.openai_compatible_providers
or custom_llm_provider == "azure"
or custom_llm_provider == "azure_text"
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "openai"
or custom_llm_provider == "text-completion-openai"
or "ft:gpt-3.5-turbo" in model_name
or model_name in litellm.open_ai_embedding_models
):
is_azure_ai_studio_model: bool = False
if custom_llm_provider == "azure":
if litellm.utils._is_non_openai_azure_model(model_name):
is_azure_ai_studio_model = True
custom_llm_provider = "openai"
# remove azure prefx from model_name
model_name = model_name.replace("azure/", "")
# glorified / complicated reading of configs
# user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
api_key = litellm_params.get("api_key") or default_api_key
if (
api_key
and isinstance(api_key, str)
and api_key.startswith("os.environ/")
):
api_key_env_name = api_key.replace("os.environ/", "")
api_key = litellm.get_secret(api_key_env_name)
litellm_params["api_key"] = api_key
api_base = litellm_params.get("api_base")
base_url = litellm_params.get("base_url")
api_base = (
api_base or base_url or default_api_base
) # allow users to pass in `api_base` or `base_url` for azure
if api_base and api_base.startswith("os.environ/"):
api_base_env_name = api_base.replace("os.environ/", "")
api_base = litellm.get_secret(api_base_env_name)
litellm_params["api_base"] = api_base
## AZURE AI STUDIO MISTRAL CHECK ##
"""
Make sure api base ends in /v1/
if not, add it - https://github.com/BerriAI/litellm/issues/2279
"""
if (
is_azure_ai_studio_model is True
and api_base is not None
and isinstance(api_base, str)
and not api_base.endswith("/v1/")
):
# check if it ends with a trailing slash
if api_base.endswith("/"):
api_base += "v1/"
elif api_base.endswith("/v1"):
api_base += "/"
else:
api_base += "/v1/"
api_version = litellm_params.get("api_version")
if api_version and api_version.startswith("os.environ/"):
api_version_env_name = api_version.replace("os.environ/", "")
api_version = litellm.get_secret(api_version_env_name)
litellm_params["api_version"] = api_version
timeout = litellm_params.pop("timeout", None) or litellm.request_timeout
if isinstance(timeout, str) and timeout.startswith("os.environ/"):
timeout_env_name = timeout.replace("os.environ/", "")
timeout = litellm.get_secret(timeout_env_name)
litellm_params["timeout"] = timeout
stream_timeout = litellm_params.pop(
"stream_timeout", timeout
) # if no stream_timeout is set, default to timeout
if isinstance(stream_timeout, str) and stream_timeout.startswith(
"os.environ/"
):
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
stream_timeout = litellm.get_secret(stream_timeout_env_name)
litellm_params["stream_timeout"] = stream_timeout
max_retries = litellm_params.pop(
"max_retries", 0
) # router handles retry logic
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
max_retries_env_name = max_retries.replace("os.environ/", "")
max_retries = litellm.get_secret(max_retries_env_name)
litellm_params["max_retries"] = max_retries
# proxy support
import os
import httpx
# Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
http_proxy = os.getenv("HTTP_PROXY", None)
https_proxy = os.getenv("HTTPS_PROXY", None)
no_proxy = os.getenv("NO_PROXY", None)
# Create the proxies dictionary only if the environment variables are set.
sync_proxy_mounts = None
async_proxy_mounts = None
if http_proxy is not None and https_proxy is not None:
sync_proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
"https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)),
}
async_proxy_mounts = {
"http://": httpx.AsyncHTTPTransport(
proxy=httpx.Proxy(url=http_proxy)
),
"https://": httpx.AsyncHTTPTransport(
proxy=httpx.Proxy(url=https_proxy)
),
}
# assume no_proxy is a list of comma separated urls
if no_proxy is not None and isinstance(no_proxy, str):
no_proxy_urls = no_proxy.split(",")
for url in no_proxy_urls: # set no-proxy support for specific urls
sync_proxy_mounts[url] = None # type: ignore
async_proxy_mounts[url] = None # type: ignore
organization = litellm_params.get("organization", None)
if isinstance(organization, str) and organization.startswith("os.environ/"):
organization_env_name = organization.replace("os.environ/", "")
organization = litellm.get_secret(organization_env_name)
litellm_params["organization"] = organization
if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
if api_base is None or not isinstance(api_base, str):
filtered_litellm_params = {
k: v
for k, v in model["litellm_params"].items()
if k != "api_key"
}
_filtered_model = {
"model_name": model["model_name"],
"litellm_params": filtered_litellm_params,
}
raise ValueError(
f"api_base is required for Azure OpenAI. Set it on your config. Model - {_filtered_model}"
)
azure_ad_token = litellm_params.get("azure_ad_token")
if azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
if api_version is None:
api_version = litellm.AZURE_DEFAULT_API_VERSION
if "gateway.ai.cloudflare.com" in api_base:
if not api_base.endswith("/"):
api_base += "/"
azure_model = model_name.replace("azure/", "")
api_base += f"{azure_model}"
cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI(
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
# streaming clients can have diff timeouts
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
else:
_api_key = api_key
if _api_key is not None and isinstance(_api_key, str):
# only show first 5 chars of api_key
_api_key = _api_key[:8] + "*" * 15
verbose_router_logger.debug(
f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
)
azure_client_params = {
"api_key": api_key,
"azure_endpoint": api_base,
"api_version": api_version,
"azure_ad_token": azure_ad_token,
}
from litellm.llms.azure import select_azure_base_url_or_endpoint
# this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
# required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params
)
cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
**azure_client_params,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore
**azure_client_params,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(
verify=litellm.ssl_verify,
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
),
mounts=sync_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
# streaming clients should have diff timeouts
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
**azure_client_params,
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
),
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
**azure_client_params,
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
),
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
else:
_api_key = api_key # type: ignore
if _api_key is not None and isinstance(_api_key, str):
# only show first 5 chars of api_key
_api_key = _api_key[:8] + "*" * 15
verbose_router_logger.debug(
f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{_api_key}"
)
cache_key = f"{model_id}_async_client"
_client = openai.AsyncOpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
cache_key = f"{model_id}_client"
_client = openai.OpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
# streaming clients should have diff timeouts
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncOpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
# streaming clients should have diff timeouts
cache_key = f"{model_id}_stream_client"
_client = openai.OpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore
)
self.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
def _generate_model_id(self, model_group: str, litellm_params: dict):
"""
Helper function to consistently generate the same id for a deployment
@@ -3904,7 +3399,9 @@ class Router:
raise Exception(f"Unsupported provider - {custom_llm_provider}")
# init OpenAI, Azure clients
-self.set_client(model=deployment.to_json(exclude_none=True))
+set_client(
+litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
+)
# set region (if azure model) ## PREVIEW FEATURE ##
if litellm.enable_preview_features == True:
@@ -4432,7 +3929,7 @@ class Router:
"""
Re-initialize the client
"""
-self.set_client(model=deployment)
+set_client(litellm_router_instance=self, model=deployment)
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
else:
@@ -4442,7 +3939,7 @@ class Router:
"""
Re-initialize the client
"""
-self.set_client(model=deployment)
+set_client(litellm_router_instance=self, model=deployment)
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
else:
@@ -4453,7 +3950,7 @@ class Router:
"""
Re-initialize the client
"""
-self.set_client(model=deployment)
+set_client(litellm_router_instance=self, model=deployment)
client = self.cache.get_cache(key=cache_key)
return client
else:
@@ -4463,7 +3960,7 @@ class Router:
"""
Re-initialize the client
"""
-self.set_client(model=deployment)
+set_client(litellm_router_instance=self, model=deployment)
client = self.cache.get_cache(key=cache_key)
return client


@@ -0,0 +1,566 @@
import asyncio
import traceback
from typing import TYPE_CHECKING, Any
import openai
import litellm
from litellm._logging import verbose_router_logger
from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.llms.custom_httpx.azure_dall_e_2 import (
AsyncCustomHTTPTransport,
CustomHTTPTransport,
)
from litellm.utils import calculate_max_parallel_requests
if TYPE_CHECKING:
from litellm.router import Router as _Router
LitellmRouter = _Router
else:
LitellmRouter = Any
def should_initialize_sync_client(
litellm_router_instance: LitellmRouter,
) -> bool:
"""
Returns if Sync OpenAI, Azure Clients should be initialized.
Do not init sync clients when router.router_general_settings.async_only_mode is True
"""
if litellm_router_instance is None:
return False
if litellm_router_instance.router_general_settings is not None:
if (
hasattr(litellm_router_instance, "router_general_settings")
and hasattr(
litellm_router_instance.router_general_settings, "async_only_mode"
)
and litellm_router_instance.router_general_settings.async_only_mode is True
):
return False
return True
def set_client(litellm_router_instance: LitellmRouter, model: dict):
"""
- Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
- Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
"""
client_ttl = litellm_router_instance.client_ttl
litellm_params = model.get("litellm_params", {})
model_name = litellm_params.get("model")
model_id = model["model_info"]["id"]
# ### IF RPM SET - initialize a semaphore ###
rpm = litellm_params.get("rpm", None)
tpm = litellm_params.get("tpm", None)
max_parallel_requests = litellm_params.get("max_parallel_requests", None)
calculated_max_parallel_requests = calculate_max_parallel_requests(
rpm=rpm,
max_parallel_requests=max_parallel_requests,
tpm=tpm,
default_max_parallel_requests=litellm_router_instance.default_max_parallel_requests,
)
if calculated_max_parallel_requests:
semaphore = asyncio.Semaphore(calculated_max_parallel_requests)
cache_key = f"{model_id}_max_parallel_requests_client"
litellm_router_instance.cache.set_cache(
key=cache_key,
value=semaphore,
local_only=True,
)
#### for OpenAI / Azure we need to initalize the Client for High Traffic ########
custom_llm_provider = litellm_params.get("custom_llm_provider")
custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
default_api_base = None
default_api_key = None
if custom_llm_provider in litellm.openai_compatible_providers:
_, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(
model=model_name
)
default_api_base = api_base
default_api_key = api_key
if (
model_name in litellm.open_ai_chat_completion_models
or custom_llm_provider in litellm.openai_compatible_providers
or custom_llm_provider == "azure"
or custom_llm_provider == "azure_text"
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "openai"
or custom_llm_provider == "text-completion-openai"
or "ft:gpt-3.5-turbo" in model_name
or model_name in litellm.open_ai_embedding_models
):
is_azure_ai_studio_model: bool = False
if custom_llm_provider == "azure":
if litellm.utils._is_non_openai_azure_model(model_name):
is_azure_ai_studio_model = True
custom_llm_provider = "openai"
# remove azure prefx from model_name
model_name = model_name.replace("azure/", "")
# glorified / complicated reading of configs
# user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
api_key = litellm_params.get("api_key") or default_api_key
if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"):
api_key_env_name = api_key.replace("os.environ/", "")
api_key = litellm.get_secret(api_key_env_name)
litellm_params["api_key"] = api_key
api_base = litellm_params.get("api_base")
base_url = litellm_params.get("base_url")
api_base = (
api_base or base_url or default_api_base
) # allow users to pass in `api_base` or `base_url` for azure
if api_base and api_base.startswith("os.environ/"):
api_base_env_name = api_base.replace("os.environ/", "")
api_base = litellm.get_secret(api_base_env_name)
litellm_params["api_base"] = api_base
## AZURE AI STUDIO MISTRAL CHECK ##
"""
Make sure api base ends in /v1/
if not, add it - https://github.com/BerriAI/litellm/issues/2279
"""
if (
is_azure_ai_studio_model is True
and api_base is not None
and isinstance(api_base, str)
and not api_base.endswith("/v1/")
):
# check if it ends with a trailing slash
if api_base.endswith("/"):
api_base += "v1/"
elif api_base.endswith("/v1"):
api_base += "/"
else:
api_base += "/v1/"
api_version = litellm_params.get("api_version")
if api_version and api_version.startswith("os.environ/"):
api_version_env_name = api_version.replace("os.environ/", "")
api_version = litellm.get_secret(api_version_env_name)
litellm_params["api_version"] = api_version
timeout = litellm_params.pop("timeout", None) or litellm.request_timeout
if isinstance(timeout, str) and timeout.startswith("os.environ/"):
timeout_env_name = timeout.replace("os.environ/", "")
timeout = litellm.get_secret(timeout_env_name)
litellm_params["timeout"] = timeout
stream_timeout = litellm_params.pop(
"stream_timeout", timeout
) # if no stream_timeout is set, default to timeout
if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
stream_timeout = litellm.get_secret(stream_timeout_env_name)
litellm_params["stream_timeout"] = stream_timeout
max_retries = litellm_params.pop("max_retries", 0) # router handles retry logic
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
max_retries_env_name = max_retries.replace("os.environ/", "")
max_retries = litellm.get_secret(max_retries_env_name)
litellm_params["max_retries"] = max_retries
# proxy support
import os
import httpx
# Check if the HTTP_PROXY and HTTPS_PROXY environment variables are set and use them accordingly.
http_proxy = os.getenv("HTTP_PROXY", None)
https_proxy = os.getenv("HTTPS_PROXY", None)
no_proxy = os.getenv("NO_PROXY", None)
# Create the proxies dictionary only if the environment variables are set.
sync_proxy_mounts = None
async_proxy_mounts = None
if http_proxy is not None and https_proxy is not None:
sync_proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
"https://": httpx.HTTPTransport(proxy=httpx.Proxy(url=https_proxy)),
}
async_proxy_mounts = {
"http://": httpx.AsyncHTTPTransport(proxy=httpx.Proxy(url=http_proxy)),
"https://": httpx.AsyncHTTPTransport(
proxy=httpx.Proxy(url=https_proxy)
),
}
# assume no_proxy is a list of comma separated urls
if no_proxy is not None and isinstance(no_proxy, str):
no_proxy_urls = no_proxy.split(",")
for url in no_proxy_urls: # set no-proxy support for specific urls
sync_proxy_mounts[url] = None # type: ignore
async_proxy_mounts[url] = None # type: ignore
organization = litellm_params.get("organization", None)
if isinstance(organization, str) and organization.startswith("os.environ/"):
organization_env_name = organization.replace("os.environ/", "")
organization = litellm.get_secret(organization_env_name)
litellm_params["organization"] = organization
if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
if api_base is None or not isinstance(api_base, str):
filtered_litellm_params = {
k: v for k, v in model["litellm_params"].items() if k != "api_key"
}
_filtered_model = {
"model_name": model["model_name"],
"litellm_params": filtered_litellm_params,
}
raise ValueError(
f"api_base is required for Azure OpenAI. Set it on your config. Model - {_filtered_model}"
)
azure_ad_token = litellm_params.get("azure_ad_token")
if azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
if api_version is None:
api_version = litellm.AZURE_DEFAULT_API_VERSION
if "gateway.ai.cloudflare.com" in api_base:
if not api_base.endswith("/"):
api_base += "/"
azure_model = model_name.replace("azure/", "")
api_base += f"{azure_model}"
cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI(
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
if should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
# streaming clients can have diff timeouts
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
if should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
else:
_api_key = api_key
if _api_key is not None and isinstance(_api_key, str):
# only show first 5 chars of api_key
_api_key = _api_key[:8] + "*" * 15
verbose_router_logger.debug(
f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
)
azure_client_params = {
"api_key": api_key,
"azure_endpoint": api_base,
"api_version": api_version,
"azure_ad_token": azure_ad_token,
}
from litellm.llms.azure import select_azure_base_url_or_endpoint
# this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
# required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params
)
cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
**azure_client_params,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
if should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore
**azure_client_params,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(
verify=litellm.ssl_verify,
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
),
mounts=sync_proxy_mounts,
), # type: ignore
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
# streaming clients should have diff timeouts
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
**azure_client_params,
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
),
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
if should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
**azure_client_params,
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
),
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
else:
_api_key = api_key # type: ignore
if _api_key is not None and isinstance(_api_key, str):
# only show first 5 chars of api_key
_api_key = _api_key[:8] + "*" * 15
verbose_router_logger.debug(
f"Initializing OpenAI Client for {model_name}, Api Base:{str(api_base)}, Api Key:{_api_key}"
)
cache_key = f"{model_id}_async_client"
_client = openai.AsyncOpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
if should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_client"
_client = openai.OpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
# streaming clients should have diff timeouts
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncOpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts,
), # type: ignore
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
if should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
# streaming clients should have diff timeouts
cache_key = f"{model_id}_stream_client"
_client = openai.OpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.Client(
transport=CustomHTTPTransport(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr


@@ -1,16 +1,22 @@
# this tests if the router is initialized correctly
-import sys, os, time
-import traceback, asyncio
+import asyncio
+import os
+import sys
+import time
+import traceback
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from dotenv import load_dotenv
import litellm
from litellm import Router
-from concurrent.futures import ThreadPoolExecutor
-from collections import defaultdict
-from dotenv import load_dotenv
load_dotenv()
@@ -24,6 +30,7 @@ load_dotenv()
def test_init_clients():
litellm.set_verbose = True
import logging
from litellm._logging import verbose_router_logger
verbose_router_logger.setLevel(logging.DEBUG)
@@ -489,6 +496,7 @@ def test_init_clients_azure_command_r_plus():
# For azure/command-r-plus we need to use openai.OpenAI because of how the Azure provider requires requests being sent
litellm.set_verbose = True
import logging
from litellm._logging import verbose_router_logger
verbose_router_logger.setLevel(logging.DEBUG)
@@ -585,3 +593,46 @@ async def test_text_completion_with_organization():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_init_clients_async_mode():
litellm.set_verbose = True
import logging
from litellm._logging import verbose_router_logger
from litellm.types.router import RouterGeneralSettings
verbose_router_logger.setLevel(logging.DEBUG)
try:
print("testing init 4 clients with diff timeouts")
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
"timeout": 0.01,
"stream_timeout": 0.000_001,
"max_retries": 7,
},
},
]
router = Router(
model_list=model_list,
set_verbose=True,
router_general_settings=RouterGeneralSettings(async_only_mode=True),
)
for elem in router.model_list:
model_id = elem["model_info"]["id"]
# sync clients not initialized in async_only_mode=True
assert router.cache.get_cache(f"{model_id}_client") is None
assert router.cache.get_cache(f"{model_id}_stream_client") is None
# only async clients initialized in async_only_mode=True
assert router.cache.get_cache(f"{model_id}_async_client") is not None
assert router.cache.get_cache(f"{model_id}_stream_async_client") is not None
except Exception as e:
pytest.fail(f"Error occurred: {e}")


@@ -1,15 +1,22 @@
-import sys, os, time
-import traceback, asyncio
+import asyncio
+import os
+import sys
+import time
+import traceback
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
-from litellm import completion, stream_chunk_builder
-import litellm
-import os, dotenv
-from openai import OpenAI
+import os
+import dotenv
import pytest
+from openai import OpenAI
+import litellm
+from litellm import completion, stream_chunk_builder
dotenv.load_dotenv()
@@ -147,3 +154,45 @@ def test_stream_chunk_builder_litellm_tool_call_regular_message():
# test_stream_chunk_builder_litellm_tool_call_regular_message()
def test_stream_chunk_builder_litellm_usage_chunks():
"""
Checks if stream_chunk_builder is able to correctly rebuild with given metadata from streaming chunks
"""
messages = [
{"role": "user", "content": "Tell me the funniest joke you know."},
{
"role": "assistant",
"content": "Why did the chicken cross the road?\nYou will not guess this one I bet\n",
},
{"role": "user", "content": "I do not know, why?"},
{"role": "assistant", "content": "uhhhh\n\n\nhmmmm.....\nthinking....\n"},
{"role": "user", "content": "\nI am waiting...\n\n...\n"},
]
# make a regular gemini call
response = completion(
model="gemini/gemini-1.5-flash",
messages=messages,
)
usage: litellm.Usage = response.usage
gemini_pt = usage.prompt_tokens
# make a streaming gemini call
response = completion(
model="gemini/gemini-1.5-flash",
messages=messages,
stream=True,
complete_response=True,
stream_options={"include_usage": True},
)
usage: litellm.Usage = response.usage
stream_rebuilt_pt = usage.prompt_tokens
# assert prompt tokens are the same
assert gemini_pt == stream_rebuilt_pt


@@ -324,7 +324,12 @@ class DeploymentTypedDict(TypedDict):
litellm_params: LiteLLMParamsTypedDict
-SPECIAL_MODEL_INFO_PARAMS = ["input_cost_per_token", "output_cost_per_token"]
+SPECIAL_MODEL_INFO_PARAMS = [
+"input_cost_per_token",
+"output_cost_per_token",
+"input_cost_per_character",
+"output_cost_per_character",
+]
class Deployment(BaseModel):
@@ -517,3 +522,9 @@ class CustomRoutingStrategyBase:
"""
pass
class RouterGeneralSettings(BaseModel):
async_only_mode: bool = Field(
default=False
) # this will only initialize async clients. Good for memory utils
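
A minimal sketch of opting a router into async-only client initialization; the deployment and API key below are placeholders, and constructing the Router makes no network calls:

```python
# Hedged sketch: with async_only_mode=True the router skips building sync
# OpenAI/Azure clients, which the new should_initialize_sync_client helper
# added in this commit reports.
from litellm import Router
from litellm.router_utils.client_initalization_utils import (
    should_initialize_sync_client,
)
from litellm.types.router import RouterGeneralSettings

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "openai/gpt-3.5-turbo",
                "api_key": "sk-placeholder",  # illustrative key, never called
            },
        }
    ],
    router_general_settings=RouterGeneralSettings(async_only_mode=True),
)

assert should_initialize_sync_client(litellm_router_instance=router) is False
```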


@@ -743,7 +743,12 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
}
const fetchModelMap = async () => {
-const data = await modelCostMap();
+const data = await modelCostMap(accessToken);
console.log(`received model cost map data: ${Object.keys(data)}`);
setModelMap(data);
};


@@ -12,11 +12,19 @@ export interface Model {
model_info: Object | null;
}
-export const modelCostMap = async () => {
+export const modelCostMap = async (
+accessToken: string,
+) => {
try {
const url = proxyBaseUrl ? `${proxyBaseUrl}/get/litellm_model_cost_map` : `/get/litellm_model_cost_map`;
const response = await fetch(
-url
+url, {
+method: "GET",
+headers: {
+Authorization: `Bearer ${accessToken}`,
+"Content-Type": "application/json",
+},
+}
);
const jsonData = await response.json();
console.log(`received litellm model cost data: ${jsonData}`);
@@ -693,6 +701,9 @@ export const claimOnboardingToken = async (
throw error;
}
};
+let ModelListerrorShown = false;
+let errorTimer: NodeJS.Timeout | null = null;
export const modelInfoCall = async (
accessToken: String,
userID: String,
@@ -714,8 +725,21 @@ export const modelInfoCall = async (
});
if (!response.ok) {
-const errorData = await response.text();
-message.error(errorData, 10);
+let errorData = await response.text();
+errorData += `error shown=${ModelListerrorShown}`
+if (!ModelListerrorShown) {
+if (errorData.includes("No model list passed")) {
+errorData = "No Models Exist. Click Add Model to get started.";
+}
+message.info(errorData, 10);
+ModelListerrorShown = true;
+if (errorTimer) clearTimeout(errorTimer);
+errorTimer = setTimeout(() => {
+ModelListerrorShown = false;
+}, 10000);
+}
throw new Error("Network response was not ok");
}
@@ -750,7 +774,6 @@ export const modelHubCall = async (accessToken: String) => {
if (!response.ok) {
const errorData = await response.text();
-message.error(errorData, 10);
throw new Error("Network response was not ok");
}


@@ -32,7 +32,6 @@ import {
allTagNamesCall,
modelMetricsCall,
modelAvailableCall,
-modelInfoCall,
adminspendByProvider,
adminGlobalActivity,
adminGlobalActivityPerModel,