LiteLLM Minor Fixes and Improvements (09/13/2024) (#5689)

* refactor: cleanup unused variables + fix pyright errors

* feat(health_check.py): Closes https://github.com/BerriAI/litellm/issues/5686

* fix(o1_reasoning.py): add stricter check for o1 reasoning model

* refactor(mistral/): make it easier to see mistral transformation logic

* fix(openai.py): fix openai o1 model param mapping

Fixes https://github.com/BerriAI/litellm/issues/5685

* feat(main.py): infer finetuned gemini model from base model

Fixes https://github.com/BerriAI/litellm/issues/5678

* docs(vertex.md): update docs to call finetuned gemini models

* feat(proxy_server.py): allow admin to hide proxy model aliases

Closes https://github.com/BerriAI/litellm/issues/5692

* docs(load_balancing.md): add docs on hiding alias models from proxy config

* fix(base.py): don't raise notimplemented error

* fix(user_api_key_auth.py): fix model max budget check

* fix(router.py): fix elif

* fix(user_api_key_auth.py): don't set team_id to empty str

* fix(team_endpoints.py): fix response type

* test(test_completion.py): handle predibase error

* test(test_proxy_server.py): fix test

* fix(o1_transformation.py): fix max_completion_tokens mapping

* test(test_image_generation.py): mark flaky test
Krish Dholakia, 2024-09-14 10:02:55 -07:00 (committed via GitHub)
parent db3af20d84
commit 60709a0753
35 changed files with 1020 additions and 539 deletions


@ -1,9 +1,9 @@
 repos:
   - repo: local
     hooks:
-      - id: mypy
-        name: mypy
-        entry: python3 -m mypy --ignore-missing-imports
+      - id: pyright
+        name: pyright
+        entry: pyright
         language: system
         types: [python]
         files: ^litellm/


@ -5,7 +5,7 @@ import TabItem from '@theme/TabItem';
 https://github.com/BerriAI/litellm

-## **Call 100+ LLMs using the same Input/Output Format**
+## **Call 100+ LLMs using the OpenAI Input/Output Format**

 - Translate inputs to provider's `completion`, `embedding`, and `image_generation` endpoints
 - [Consistent output](https://docs.litellm.ai/docs/completion/output), text responses will always be available at `['choices'][0]['message']['content']`


@ -1180,9 +1180,58 @@ response = completion(
 Fine tuned models on vertex have a numerical model/endpoint id.

-| Model Name | Function Call |
-|------------------|--------------------------------------|
-| your fine tuned model | `completion(model='vertex_ai/4965075652664360960', messages)`|
+<Tabs>
+<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
## set ENV variables
os.environ["VERTEXAI_PROJECT"] = "hardy-device-38811"
os.environ["VERTEXAI_LOCATION"] = "us-central1"
response = completion(
model="vertex_ai/<your-finetuned-model>", # e.g. vertex_ai/4965075652664360960
messages=[{ "content": "Hello, how are you?","role": "user"}],
base_model="vertex_ai/gemini-1.5-pro" # the base model - used for routing
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add Vertex Credentials to your env
```bash
!gcloud auth application-default login
```
2. Setup config.yaml
```yaml
- model_name: finetuned-gemini
litellm_params:
model: vertex_ai/<ENDPOINT_ID>
vertex_project: <PROJECT_ID>
vertex_location: <LOCATION>
model_info:
base_model: vertex_ai/gemini-1.5-pro # IMPORTANT
```
3. Test it!
```bash
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: <LITELLM_KEY>' \
--data '{"model": "finetuned-gemini" ,"messages":[{"role": "user", "content":[{"type": "text", "text": "hi"}]}]}'
```
</TabItem>
</Tabs>
## Gemini Pro Vision
| Model Name | Function Call |


@ -38,9 +38,24 @@ router_settings:
## Router settings on config - routing_strategy, model_group_alias
Expose an 'alias' for a 'model_name' on the proxy server.
```
model_group_alias: {
"gpt-4": "gpt-3.5-turbo"
}
```
These aliases are shown on `/v1/models`, `/v1/model/info`, and `/v1/model_group/info` by default.
litellm.Router() settings can be set under `router_settings`. You can set `model_group_alias`, `routing_strategy`, `num_retries`, `timeout`. See all Router supported params [here](https://github.com/BerriAI/litellm/blob/1b942568897a48f014fa44618ec3ce54d7570a46/litellm/router.py#L64)
### Usage
Example config with `router_settings`

```yaml
model_list:
  - model_name: gpt-3.5-turbo
@ -48,19 +63,41 @@ model_list:
       model: azure/<your-deployment-name>
       api_base: <your-azure-endpoint>
       api_key: <your-azure-api-key>
+      rpm: 6      # Rate limit for this deployment: in requests per minute (rpm)
+
+router_settings:
+  model_group_alias: {"gpt-4": "gpt-3.5-turbo"} # all requests with `gpt-4` will be routed to models
+```
### Hide Alias Models
Use this if you want to set-up aliases for:
1. typos
2. minor model version changes
3. case sensitive changes between updates
```yaml
model_list:
   - model_name: gpt-3.5-turbo
     litellm_params:
-      model: azure/gpt-turbo-small-ca
-      api_base: https://my-endpoint-canada-berri992.openai.azure.com/
+      model: azure/<your-deployment-name>
+      api_base: <your-azure-endpoint>
       api_key: <your-azure-api-key>
-      rpm: 6

 router_settings:
-  model_group_alias: {"gpt-4": "gpt-3.5-turbo"} # all requests with `gpt-4` will be routed to models with `gpt-3.5-turbo`
-  routing_strategy: least-busy # Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"]
-  num_retries: 2
-  timeout: 30 # 30 seconds
-  redis_host: <your redis host>
-  redis_password: <your redis password>
-  redis_port: 1992
+  model_group_alias:
+    "GPT-3.5-turbo": # alias
+      model: "gpt-3.5-turbo"  # Actual model name in 'model_list'
+      hidden: true            # Exclude from `/v1/models`, `/v1/model/info`, `/v1/model_group/info`
+```
+
+### Complete Spec
```python
model_group_alias: Optional[Dict[str, Union[str, RouterModelGroupAliasItem]]] = {}
class RouterModelGroupAliasItem(TypedDict):
    model: str
    hidden: bool  # if 'True', don't return on `/v1/models`, `/v1/model/info`, `/v1/model_group/info`
```
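
For example, with the `Usage` alias above, a request to the proxy for `gpt-4` is served by the `gpt-3.5-turbo` model group. A minimal sketch (the proxy URL and key are placeholders):

```bash
# Request the alias "gpt-4"; the router serves it with the `gpt-3.5-turbo` deployment
curl -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
  -H 'Content-Type: application/json' \
  -H 'Authorization: Bearer sk-1234' \
  -d '{
    "model": "gpt-4",
    "messages": [{"role": "user", "content": "hi"}]
  }'
```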


@ -5,7 +5,7 @@ import TabItem from '@theme/TabItem';
 https://github.com/BerriAI/litellm

-## **Call 100+ LLMs using the same Input/Output Format**
+## **Call 100+ LLMs using the OpenAI Input/Output Format**

 - Translate inputs to provider's `completion`, `embedding`, and `image_generation` endpoints
 - [Consistent output](https://docs.litellm.ai/docs/completion/output), text responses will always be available at `['choices'][0]['message']['content']`


@ -939,15 +939,18 @@ from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConf
 from .llms.OpenAI.openai import (
     OpenAIConfig,
     OpenAITextCompletionConfig,
-    MistralConfig,
     MistralEmbeddingConfig,
     DeepInfraConfig,
     GroqConfig,
     AzureAIStudioConfig,
 )
-from .llms.OpenAI.o1_reasoning import (
+from .llms.mistral.mistral_chat_transformation import MistralConfig
+from .llms.OpenAI.o1_transformation import (
     OpenAIO1Config,
 )
+from .llms.OpenAI.gpt_transformation import (
+    OpenAIGPTConfig,
+)
 from .llms.nvidia_nim import NvidiaNimConfig
 from .llms.cerebras.chat import CerebrasConfig
 from .llms.AI21.chat import AI21ChatConfig


@ -52,7 +52,9 @@ class SlackAlerting(CustomBatchLogger):
     def __init__(
         self,
         internal_usage_cache: Optional[DualCache] = None,
-        alerting_threshold: float = 300,  # threshold for slow / hanging llm responses (in seconds)
+        alerting_threshold: Optional[
+            float
+        ] = None,  # threshold for slow / hanging llm responses (in seconds)
         alerting: Optional[List] = [],
         alert_types: List[AlertType] = [
             "llm_exceptions",
@ -74,6 +76,8 @@ class SlackAlerting(CustomBatchLogger):
         default_webhook_url: Optional[str] = None,
         **kwargs,
     ):
+        if alerting_threshold is None:
+            alerting_threshold = 300
         self.alerting_threshold = alerting_threshold
         self.alerting = alerting
         self.alert_types = alert_types
@ -99,6 +103,7 @@ class SlackAlerting(CustomBatchLogger):
     ):
         if alerting is not None:
             self.alerting = alerting
+            asyncio.create_task(self.periodic_flush())
         if alerting_threshold is not None:
             self.alerting_threshold = alerting_threshold
         if alert_types is not None:
@ -114,8 +119,6 @@ class SlackAlerting(CustomBatchLogger):
         if llm_router is not None:
             self.llm_router = llm_router

-        asyncio.create_task(self.periodic_flush())

     async def deployment_in_cooldown(self):
         pass
@ -208,15 +211,20 @@ class SlackAlerting(CustomBatchLogger):
             _deployment_latencies = metadata["_latency_per_deployment"]
             if len(_deployment_latencies) == 0:
                 return None
+            _deployment_latency_map: Optional[dict] = None
             try:
                 # try sorting deployments by latency
                 _deployment_latencies = sorted(
                     _deployment_latencies.items(), key=lambda x: x[1]
                 )
-                _deployment_latencies = dict(_deployment_latencies)
-            except:
+                _deployment_latency_map = dict(_deployment_latencies)
+            except Exception:
                 pass

-            for api_base, latency in _deployment_latencies.items():
+            if _deployment_latency_map is None:
+                return
+
+            for api_base, latency in _deployment_latency_map.items():
                 _message_to_send += f"\n{api_base}: {round(latency,2)}s"
             _message_to_send = "```" + _message_to_send + "```"
             return _message_to_send
@ -475,6 +483,7 @@ class SlackAlerting(CustomBatchLogger):
     ):
         if self.alerting is None or self.alert_types is None:
             return
+        model: str = ""
         if request_data is not None:
             model = request_data.get("model", "")
             messages = request_data.get("messages", None)
@ -619,6 +628,7 @@ class SlackAlerting(CustomBatchLogger):
             return
         _id: Optional[str] = "default_id"  # used for caching
         user_info_json = user_info.model_dump(exclude_none=True)
+        user_info_str = ""
         for k, v in user_info_json.items():
             user_info_str = "\n{}: {}\n".format(k, v)
@ -1475,10 +1485,10 @@ Model Info:
                 if isinstance(response_obj, litellm.ModelResponse) and (
                     hasattr(response_obj, "usage")
-                    and response_obj.usage is not None
-                    and hasattr(response_obj.usage, "completion_tokens")
+                    and response_obj.usage is not None  # type: ignore
+                    and hasattr(response_obj.usage, "completion_tokens")  # type: ignore
                 ):
-                    completion_tokens = response_obj.usage.completion_tokens
+                    completion_tokens = response_obj.usage.completion_tokens  # type: ignore
                     if completion_tokens is not None and completion_tokens > 0:
                         final_value = float(
                             response_s.total_seconds() / completion_tokens
@ -1608,10 +1618,14 @@ Model Info:
         todays_date = datetime.datetime.now().date()
         start_date = todays_date - datetime.timedelta(days=days)

-        spend_per_team, spend_per_tag = await _get_spend_report_for_time_range(
+        _resp = await _get_spend_report_for_time_range(
             start_date=start_date.strftime("%Y-%m-%d"),
             end_date=todays_date.strftime("%Y-%m-%d"),
         )
+        if _resp is None:
+            return
+        spend_per_team, spend_per_tag = _resp

         _spend_message = f"*💸 Spend Report for `{start_date.strftime('%m-%d-%Y')} - {todays_date.strftime('%m-%d-%Y')}` ({days} days)*\n"
@ -1656,13 +1670,16 @@ Model Info:
             days=last_day_of_month - 1
         )

-        monthly_spend_per_team, monthly_spend_per_tag = (
-            await _get_spend_report_for_time_range(
-                start_date=first_day_of_month.strftime("%Y-%m-%d"),
-                end_date=last_day_of_month.strftime("%Y-%m-%d"),
-            )
+        _resp = await _get_spend_report_for_time_range(
+            start_date=first_day_of_month.strftime("%Y-%m-%d"),
+            end_date=last_day_of_month.strftime("%Y-%m-%d"),
         )
+        if _resp is None:
+            return
+        monthly_spend_per_team, monthly_spend_per_tag = _resp

         _spend_message = f"*💸 Monthly Spend Report for `{first_day_of_month.strftime('%m-%d-%Y')} - {last_day_of_month.strftime('%m-%d-%Y')}` *\n"
         if monthly_spend_per_team is not None:


@ -0,0 +1,142 @@
"""
Support for gpt model family
"""
import types
from typing import Optional, Union
import litellm
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
class OpenAIGPTConfig:
"""
Reference: https://platform.openai.com/docs/api-reference/chat/create
The class `OpenAIGPTConfig` provides configuration for OpenAI's Chat API interface. Below are the parameters:
- `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
- `function_call` (string or object): This optional parameter controls how the model calls functions.
- `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
- `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
- `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
"""
frequency_penalty: Optional[int] = None
function_call: Optional[Union[str, dict]] = None
functions: Optional[list] = None
logit_bias: Optional[dict] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[int] = None
stop: Optional[Union[str, list]] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
response_format: Optional[dict] = None
def __init__(
self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]] = None,
functions: Optional[list] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
response_format: Optional[dict] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self, model: str) -> list:
base_params = [
"frequency_penalty",
"logit_bias",
"logprobs",
"top_logprobs",
"max_tokens",
"n",
"presence_penalty",
"seed",
"stop",
"stream",
"stream_options",
"temperature",
"top_p",
"tools",
"tool_choice",
"function_call",
"functions",
"max_retries",
"extra_headers",
"parallel_tool_calls",
] # works across all models
model_specific_params = []
if (
model != "gpt-3.5-turbo-16k" and model != "gpt-4"
): # gpt-4 does not support 'response_format'
model_specific_params.append("response_format")
if (
model in litellm.open_ai_chat_completion_models
) or model in litellm.open_ai_text_completion_models:
model_specific_params.append(
"user"
) # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
return base_params + model_specific_params
def _map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
) -> dict:
supported_openai_params = self.get_supported_openai_params(model)
for param, value in non_default_params.items():
if param in supported_openai_params:
optional_params[param] = value
return optional_params
def map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
) -> dict:
return self._map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
)
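
A minimal sketch of how this config behaves, based only on the methods above (the `unsupported_param` key is a made-up illustration):

```python
from litellm.llms.OpenAI.gpt_transformation import OpenAIGPTConfig

config = OpenAIGPTConfig()
optional_params = config.map_openai_params(
    non_default_params={"temperature": 0.2, "max_tokens": 256, "unsupported_param": 1},
    optional_params={},
    model="gpt-4o",
)
# Params not in get_supported_openai_params() are dropped; the rest pass through:
# optional_params == {"temperature": 0.2, "max_tokens": 256}
```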


@ -17,13 +17,12 @@ from typing import Any, List, Optional, Union
 import litellm
 from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage

-from .openai import OpenAIConfig
+from .gpt_transformation import OpenAIGPTConfig


-class OpenAIO1Config(OpenAIConfig):
+class OpenAIO1Config(OpenAIGPTConfig):
     """
     Reference: https://platform.openai.com/docs/guides/reasoning
     """

     @classmethod
@ -50,9 +49,7 @@ class OpenAIO1Config(OpenAIConfig):
""" """
all_openai_params = litellm.OpenAIConfig().get_supported_openai_params( all_openai_params = super().get_supported_openai_params(model=model)
model="gpt-4o"
)
non_supported_params = [ non_supported_params = [
"logprobs", "logprobs",
"tools", "tools",
@ -69,13 +66,14 @@ class OpenAIO1Config(OpenAIConfig):
     def map_openai_params(
         self, non_default_params: dict, optional_params: dict, model: str
     ):
-        for param, value in non_default_params.items():
-            if param == "max_tokens":
-                optional_params["max_completion_tokens"] = value
-        return optional_params
+        if "max_tokens" in non_default_params:
+            optional_params["max_completion_tokens"] = non_default_params.pop(
+                "max_tokens"
+            )
+        return super()._map_openai_params(non_default_params, optional_params, model)

     def is_model_o1_reasoning_model(self, model: str) -> bool:
-        if "o1" in model:
+        if model in litellm.open_ai_chat_completion_models and "o1" in model:
             return True
         return False
@ -93,7 +91,7 @@ class OpenAIO1Config(OpenAIConfig):
                 )
                 messages[i] = new_message  # Replace the old message with the new one

-            if isinstance(message["content"], list):
+            if "content" in message and isinstance(message["content"], list):
                 new_content = []
                 for content_item in message["content"]:
                     if content_item.get("type") == "image_url":
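
A short sketch of the corrected `max_tokens` mapping above (o1 models accept `max_completion_tokens` rather than `max_tokens`):

```python
import litellm

config = litellm.OpenAIO1Config()
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 256},
    optional_params={},
    model="o1-preview",
)
# `max_tokens` is popped and re-mapped before the base GPT mapping runs:
# optional_params == {"max_completion_tokens": 256}
```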


@ -60,122 +60,6 @@ class OpenAIError(Exception):
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class MistralConfig:
"""
Reference: https://docs.mistral.ai/api/
The class `MistralConfig` provides configuration for the Mistral's Chat API interface. Below are the parameters:
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. API Default - 0.7.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. API Default - 1.
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. API Default - null.
- `tools` (list or null): A list of available tools for the model. Use this to specify functions for which the model can generate JSON inputs.
- `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.
- `stop` (string or array of strings): Stop generation if this token is detected. Or if one of these tokens is detected when providing an array
- `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.
- `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.
- `response_format` (object or null): An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message.
"""
temperature: Optional[int] = None
top_p: Optional[int] = None
max_tokens: Optional[int] = None
tools: Optional[list] = None
tool_choice: Optional[Literal["auto", "any", "none"]] = None
random_seed: Optional[int] = None
safe_prompt: Optional[bool] = None
response_format: Optional[dict] = None
stop: Optional[Union[str, list]] = None
def __init__(
self,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
max_tokens: Optional[int] = None,
tools: Optional[list] = None,
tool_choice: Optional[Literal["auto", "any", "none"]] = None,
random_seed: Optional[int] = None,
safe_prompt: Optional[bool] = None,
response_format: Optional[dict] = None,
stop: Optional[Union[str, list]] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"stream",
"temperature",
"top_p",
"max_tokens",
"tools",
"tool_choice",
"seed",
"stop",
"response_format",
]
def _map_tool_choice(self, tool_choice: str) -> str:
if tool_choice == "auto" or tool_choice == "none":
return tool_choice
elif tool_choice == "required":
return "any"
else: # openai 'tool_choice' object param not supported by Mistral API
return "any"
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "stream" and value is True:
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "stop":
optional_params["stop"] = value
if param == "tool_choice" and isinstance(value, str):
optional_params["tool_choice"] = self._map_tool_choice(
tool_choice=value
)
if param == "seed":
optional_params["extra_body"] = {"random_seed": value}
if param == "response_format":
optional_params["response_format"] = value
return optional_params
class MistralEmbeddingConfig: class MistralEmbeddingConfig:
""" """
Reference: https://docs.mistral.ai/api/#operation/createEmbedding Reference: https://docs.mistral.ai/api/#operation/createEmbedding
@ -526,44 +410,19 @@ class OpenAIConfig:
} }
def get_supported_openai_params(self, model: str) -> list: def get_supported_openai_params(self, model: str) -> list:
base_params = [
"frequency_penalty",
"logit_bias",
"logprobs",
"top_logprobs",
"max_tokens",
"n",
"presence_penalty",
"seed",
"stop",
"stream",
"stream_options",
"temperature",
"top_p",
"tools",
"tool_choice",
"function_call",
"functions",
"max_retries",
"extra_headers",
"parallel_tool_calls",
] # works across all models
model_specific_params = []
         if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
             return litellm.OpenAIO1Config().get_supported_openai_params(model=model)
-        if (
-            model != "gpt-3.5-turbo-16k" and model != "gpt-4"
-        ):  # gpt-4 does not support 'response_format'
-            model_specific_params.append("response_format")
-        if (
-            model in litellm.open_ai_chat_completion_models
-        ) or model in litellm.open_ai_text_completion_models:
-            model_specific_params.append(
-                "user"
-            )  # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
-        return base_params + model_specific_params
+        else:
+            return litellm.OpenAIGPTConfig().get_supported_openai_params(model=model)
+
+    def _map_openai_params(
+        self, non_default_params: dict, optional_params: dict, model: str
+    ) -> dict:
+        supported_openai_params = self.get_supported_openai_params(model)
+        for param, value in non_default_params.items():
+            if param in supported_openai_params:
+                optional_params[param] = value
+        return optional_params
def map_openai_params( def map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str self, non_default_params: dict, optional_params: dict, model: str
@ -575,11 +434,11 @@ class OpenAIConfig:
                 optional_params=optional_params,
                 model=model,
             )
-        supported_openai_params = self.get_supported_openai_params(model)
-        for param, value in non_default_params.items():
-            if param in supported_openai_params:
-                optional_params[param] = value
-        return optional_params
+        return litellm.OpenAIGPTConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+        )
class OpenAITextCompletionConfig: class OpenAITextCompletionConfig:
@ -816,18 +675,18 @@ class OpenAIChatCompletion(BaseLLM):
         except Exception as e:
             raise e

-    def completion(
+    def completion(  # type: ignore
         self,
         model_response: ModelResponse,
         timeout: Union[float, httpx.Timeout],
         optional_params: dict,
+        logging_obj: Any,
         model: Optional[str] = None,
         messages: Optional[list] = None,
         print_verbose: Optional[Callable] = None,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         acompletion: bool = False,
-        logging_obj=None,
         litellm_params=None,
         logger_fn=None,
         headers: Optional[dict] = None,
@ -858,14 +717,14 @@ class OpenAIChatCompletion(BaseLLM):
# process all OpenAI compatible provider logic here # process all OpenAI compatible provider logic here
if custom_llm_provider == "mistral": if custom_llm_provider == "mistral":
# check if message content passed in as list, and not string # check if message content passed in as list, and not string
messages = prompt_factory( messages = prompt_factory( # type: ignore
model=model, model=model,
messages=messages, messages=messages,
custom_llm_provider=custom_llm_provider, custom_llm_provider=custom_llm_provider,
) )
if custom_llm_provider == "perplexity" and messages is not None: if custom_llm_provider == "perplexity" and messages is not None:
# check if messages.name is passed + supported, if not supported remove # check if messages.name is passed + supported, if not supported remove
messages = prompt_factory( messages = prompt_factory( # type: ignore
model=model, model=model,
messages=messages, messages=messages,
custom_llm_provider=custom_llm_provider, custom_llm_provider=custom_llm_provider,
@ -933,7 +792,7 @@ class OpenAIChatCompletion(BaseLLM):
status_code=422, message="max retries must be an int" status_code=422, message="max retries must be an int"
) )
openai_client = self._get_openai_client( openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False, is_async=False,
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
@ -1068,7 +927,7 @@ class OpenAIChatCompletion(BaseLLM):
2 2
): # if call fails due to alternating messages, retry with reformatted message ): # if call fails due to alternating messages, retry with reformatted message
try: try:
openai_aclient = self._get_openai_client( openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore
is_async=True, is_async=True,
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
@ -1156,7 +1015,7 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=None, max_retries=None,
headers=None, headers=None,
): ):
openai_client = self._get_openai_client( openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False, is_async=False,
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
@ -1210,7 +1069,7 @@ class OpenAIChatCompletion(BaseLLM):
response = None response = None
for _ in range(2): for _ in range(2):
try: try:
openai_aclient = self._get_openai_client( openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore
is_async=True, is_async=True,
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
@ -1282,7 +1141,7 @@ class OpenAIChatCompletion(BaseLLM):
error_headers = getattr(e, "headers", None) error_headers = getattr(e, "headers", None)
raise OpenAIError( raise OpenAIError(
status_code=500, status_code=500,
message=f"{str(e)}\n\nOriginal Response: {response.text}", message=f"{str(e)}\n\nOriginal Response: {response.text}", # type: ignore
headers=error_headers, headers=error_headers,
) )
else: else:
@ -1294,7 +1153,7 @@ class OpenAIChatCompletion(BaseLLM):
) )
elif hasattr(e, "status_code"): elif hasattr(e, "status_code"):
raise OpenAIError( raise OpenAIError(
status_code=e.status_code, status_code=getattr(e, "status_code", 500),
message=str(e), message=str(e),
headers=error_headers, headers=error_headers,
) )
@ -1361,7 +1220,7 @@ class OpenAIChatCompletion(BaseLLM):
): ):
response = None response = None
try: try:
openai_aclient = self._get_openai_client( openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore
is_async=True, is_async=True,
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
@ -1410,16 +1269,16 @@ class OpenAIChatCompletion(BaseLLM):
status_code=status_code, message=str(e), headers=error_headers status_code=status_code, message=str(e), headers=error_headers
) )
def embedding( def embedding( # type: ignore
self, self,
model: str, model: str,
input: list, input: list,
timeout: float, timeout: float,
logging_obj, logging_obj,
model_response: litellm.utils.EmbeddingResponse, model_response: litellm.utils.EmbeddingResponse,
optional_params: dict,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
optional_params=None,
client=None, client=None,
aembedding=None, aembedding=None,
): ):
@ -1452,7 +1311,7 @@ class OpenAIChatCompletion(BaseLLM):
) )
return response return response
openai_client = self._get_openai_client( openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False, is_async=False,
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
@ -1496,11 +1355,11 @@ class OpenAIChatCompletion(BaseLLM):
data: dict, data: dict,
model_response: ModelResponse, model_response: ModelResponse,
timeout: float, timeout: float,
logging_obj: Any,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
client=None, client=None,
max_retries=None, max_retries=None,
logging_obj=None,
): ):
response = None response = None
try: try:
@ -1538,15 +1397,16 @@ class OpenAIChatCompletion(BaseLLM):
model: Optional[str], model: Optional[str],
prompt: str, prompt: str,
timeout: float, timeout: float,
optional_params: dict,
logging_obj: Any,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None, model_response: Optional[litellm.utils.ImageResponse] = None,
logging_obj=None,
optional_params=None,
client=None, client=None,
aimg_generation=None, aimg_generation=None,
): ):
exception_mapping_worked = False exception_mapping_worked = False
data = {}
try: try:
model = model model = model
data = {"model": model, "prompt": prompt, **optional_params} data = {"model": model, "prompt": prompt, **optional_params}
@ -1611,7 +1471,9 @@ class OpenAIChatCompletion(BaseLLM):
original_response=str(e), original_response=str(e),
) )
if hasattr(e, "status_code"): if hasattr(e, "status_code"):
raise OpenAIError(status_code=e.status_code, message=str(e)) raise OpenAIError(
status_code=getattr(e, "status_code", 500), message=str(e)
)
else: else:
raise OpenAIError(status_code=500, message=str(e)) raise OpenAIError(status_code=500, message=str(e))
@ -1661,7 +1523,7 @@ class OpenAIChatCompletion(BaseLLM):
input=input, input=input,
**optional_params, **optional_params,
) )
return response return response # type: ignore
async def async_audio_speech( async def async_audio_speech(
self, self,
@ -1784,11 +1646,8 @@ class OpenAIChatCompletion(BaseLLM):
class OpenAITextCompletion(BaseLLM): class OpenAITextCompletion(BaseLLM):
_client_session: httpx.Client
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._client_session = self.create_client_session()
def validate_environment(self, api_key): def validate_environment(self, api_key):
headers = { headers = {
@ -1806,10 +1665,10 @@ class OpenAITextCompletion(BaseLLM):
messages: list, messages: list,
timeout: float, timeout: float,
logging_obj: LiteLLMLoggingObj, logging_obj: LiteLLMLoggingObj,
optional_params: dict,
print_verbose: Optional[Callable] = None, print_verbose: Optional[Callable] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
acompletion: bool = False, acompletion: bool = False,
optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
client=None, client=None,
@ -1921,7 +1780,7 @@ class OpenAITextCompletion(BaseLLM):
api_key: str, api_key: str,
model: str, model: str,
timeout: float, timeout: float,
max_retries=None, max_retries: int,
organization: Optional[str] = None, organization: Optional[str] = None,
client=None, client=None,
): ):
@ -2017,9 +1876,9 @@ class OpenAITextCompletion(BaseLLM):
model_response: ModelResponse, model_response: ModelResponse,
model: str, model: str,
timeout: float, timeout: float,
max_retries: int,
api_base: Optional[str] = None, api_base: Optional[str] = None,
client=None, client=None,
max_retries=None,
organization=None, organization=None,
): ):
if client is None: if client is None:

litellm/llms/README.md (new file)

@ -0,0 +1,12 @@
## File Structure

### August 27th, 2024

To make it easy to see how calls are transformed for each model/provider, we are moving all supported litellm providers to a folder structure, where the folder name is the supported litellm provider name.

Each folder will contain a `*_transformation.py` file, which has all the request/response transformation logic, making it easy to see how calls are modified.

E.g. `cohere/`, `bedrock/`.
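
For example, with this commit the Mistral transformation logic lives in its provider folder and is imported from there (import path taken from the `litellm/__init__.py` change above):

```python
# provider-folder layout: litellm/llms/mistral/mistral_chat_transformation.py
from litellm.llms.mistral.mistral_chat_transformation import MistralConfig

print(MistralConfig().get_supported_openai_params())
# -> ['stream', 'temperature', 'top_p', 'max_tokens', 'tools', 'tool_choice', 'seed', 'stop', 'response_format']
```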


@ -66,22 +66,24 @@ class BaseLLM:
         return _aclient_session

     def __exit__(self):
-        if hasattr(self, "_client_session"):
+        if hasattr(self, "_client_session") and self._client_session is not None:
             self._client_session.close()

     async def __aexit__(self, exc_type, exc_val, exc_tb):
         if hasattr(self, "_aclient_session"):
-            await self._aclient_session.aclose()
+            await self._aclient_session.aclose()  # type: ignore

-    def validate_environment(self):  # set up the environment required to run the model
-        pass
+    def validate_environment(
+        self, *args, **kwargs
+    ) -> Optional[Any]:  # set up the environment required to run the model
+        return None

     def completion(
         self, *args, **kwargs
-    ):  # logic for parsing in - calling - parsing out model completion calls
-        pass
+    ) -> Any:  # logic for parsing in - calling - parsing out model completion calls
+        return None

     def embedding(
         self, *args, **kwargs
-    ):  # logic for parsing in - calling - parsing out model embedding calls
-        pass
+    ) -> Any:  # logic for parsing in - calling - parsing out model embedding calls
+        return None


@ -0,0 +1,5 @@
"""
Calls handled in openai/
as mistral is an openai-compatible endpoint.
"""


@ -0,0 +1,5 @@
"""
Calls handled in openai/
as mistral is an openai-compatible endpoint.
"""


@ -0,0 +1,126 @@
"""
Transformation logic from OpenAI /v1/chat/completion format to Mistral's /chat/completion format.
Why a separate file? To make it easy to see how the transformation works.
Docs - https://docs.mistral.ai/api/
"""
import types
from typing import List, Literal, Optional, Union
class MistralConfig:
"""
Reference: https://docs.mistral.ai/api/
The class `MistralConfig` provides configuration for the Mistral's Chat API interface. Below are the parameters:
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. API Default - 0.7.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. API Default - 1.
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. API Default - null.
- `tools` (list or null): A list of available tools for the model. Use this to specify functions for which the model can generate JSON inputs.
- `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.
- `stop` (string or array of strings): Stop generation if this token is detected. Or if one of these tokens is detected when providing an array
- `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.
- `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.
- `response_format` (object or null): An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message.
"""
temperature: Optional[int] = None
top_p: Optional[int] = None
max_tokens: Optional[int] = None
tools: Optional[list] = None
tool_choice: Optional[Literal["auto", "any", "none"]] = None
random_seed: Optional[int] = None
safe_prompt: Optional[bool] = None
response_format: Optional[dict] = None
stop: Optional[Union[str, list]] = None
def __init__(
self,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
max_tokens: Optional[int] = None,
tools: Optional[list] = None,
tool_choice: Optional[Literal["auto", "any", "none"]] = None,
random_seed: Optional[int] = None,
safe_prompt: Optional[bool] = None,
response_format: Optional[dict] = None,
stop: Optional[Union[str, list]] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"stream",
"temperature",
"top_p",
"max_tokens",
"tools",
"tool_choice",
"seed",
"stop",
"response_format",
]
def _map_tool_choice(self, tool_choice: str) -> str:
if tool_choice == "auto" or tool_choice == "none":
return tool_choice
elif tool_choice == "required":
return "any"
else: # openai 'tool_choice' object param not supported by Mistral API
return "any"
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "stream" and value is True:
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "stop":
optional_params["stop"] = value
if param == "tool_choice" and isinstance(value, str):
optional_params["tool_choice"] = self._map_tool_choice(
tool_choice=value
)
if param == "seed":
optional_params["extra_body"] = {"random_seed": value}
if param == "response_format":
optional_params["response_format"] = value
return optional_params
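
A minimal sketch of the mapping above, using only the parameters defined in this file:

```python
from litellm.llms.mistral.mistral_chat_transformation import MistralConfig

optional_params = MistralConfig().map_openai_params(
    non_default_params={"max_tokens": 100, "seed": 42, "tool_choice": "required"},
    optional_params={},
)
# OpenAI-style params are rewritten into Mistral's equivalents:
# {"max_tokens": 100, "extra_body": {"random_seed": 42}, "tool_choice": "any"}
```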


@ -737,6 +737,7 @@ def completion(
     preset_cache_key = kwargs.get("preset_cache_key", None)
     hf_model_name = kwargs.get("hf_model_name", None)
     supports_system_message = kwargs.get("supports_system_message", None)
+    base_model = kwargs.get("base_model", None)
     ### TEXT COMPLETION CALLS ###
     text_completion = kwargs.get("text_completion", False)
     atext_completion = kwargs.get("atext_completion", False)
@ -782,11 +783,9 @@ def completion(
"top_logprobs", "top_logprobs",
"extra_headers", "extra_headers",
] ]
litellm_params = (
all_litellm_params # use the external var., used in creating cache key as well.
)
default_params = openai_params + litellm_params default_params = openai_params + all_litellm_params
litellm_params = {} # used to prevent unbound var errors
non_default_params = { non_default_params = {
k: v for k, v in kwargs.items() if k not in default_params k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider } # model-specific params - pass them straight to the model/provider
@ -973,6 +972,7 @@ def completion(
         text_completion=kwargs.get("text_completion"),
         azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
         user_continue_message=kwargs.get("user_continue_message"),
+        base_model=base_model,
     )
     logging.update_environment_variables(
         model=model,
@ -2123,7 +2123,10 @@ def completion(
                 timeout=timeout,
                 client=client,
             )
-        elif "gemini" in model:
+        elif "gemini" in model or (
+            litellm_params.get("base_model") is not None
+            and "gemini" in litellm_params["base_model"]
+        ):
             model_response = vertex_chat_completion.completion(  # type: ignore
                 model=model,
                 messages=messages,
@ -2820,7 +2823,7 @@ def completion_with_retries(*args, **kwargs):
     )
     num_retries = kwargs.pop("num_retries", 3)
-    retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
+    retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry")  # type: ignore
     original_function = kwargs.pop("original_function", completion)
     if retry_strategy == "constant_retry":
         retryer = tenacity.Retrying(
@ -4997,7 +5000,9 @@ def speech(
 async def ahealth_check(
     model_params: dict,
     mode: Optional[
-        Literal["completion", "embedding", "image_generation", "chat", "batch"]
+        Literal[
+            "completion", "embedding", "image_generation", "chat", "batch", "rerank"
+        ]
     ] = None,
     prompt: Optional[str] = None,
     input: Optional[List] = None,
@ -5113,6 +5118,12 @@ async def ahealth_check(
model_params["prompt"] = prompt model_params["prompt"] = prompt
await litellm.aimage_generation(**model_params) await litellm.aimage_generation(**model_params)
response = {} response = {}
elif mode == "rerank":
model_params.pop("messages", None)
model_params["query"] = prompt
model_params["documents"] = ["my sample text"]
await litellm.arerank(**model_params)
response = {}
elif "*" in model: elif "*" in model:
from litellm.litellm_core_utils.llm_request_utils import ( from litellm.litellm_core_utils.llm_request_utils import (
pick_cheapest_model_from_llm_provider, pick_cheapest_model_from_llm_provider,
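
For the new `rerank` health-check mode, a proxy config entry might look like the sketch below. The `mode` flag under `model_info` follows the existing health-check convention for other modes, so a rerank deployment would presumably be flagged the same way; the model name and key are placeholders:

```yaml
model_list:
  - model_name: cohere-rerank
    litellm_params:
      model: cohere/rerank-english-v3.0
      api_key: os.environ/COHERE_API_KEY
    model_info:
      mode: rerank   # health check calls litellm.arerank() for this deployment
```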


@ -1,9 +1,7 @@
 model_list:
-  - model_name: "gpt-4o"
+  - model_name: gpt-3.5-turbo
     litellm_params:
-      model: gpt-4o
+      model: gpt-3.5-turbo

-litellm_settings:
-  cache: true
-  cache_params:
-    type: local
+router_settings:
+  model_group_alias: {"gpt-4": {"model": "gpt-3.5-turbo", "hidden": false}}


@ -99,15 +99,15 @@ class LitellmUserRoles(str, enum.Enum):
return ui_labels.get(self.value, "") return ui_labels.get(self.value, "")
class LitellmTableNames(str, enum.Enum): class LitellmTableNames(enum.Enum):
""" """
Enum for Table Names used by LiteLLM Enum for Table Names used by LiteLLM
""" """
TEAM_TABLE_NAME: str = "LiteLLM_TeamTable" TEAM_TABLE_NAME = "LiteLLM_TeamTable"
USER_TABLE_NAME: str = "LiteLLM_UserTable" USER_TABLE_NAME = "LiteLLM_UserTable"
KEY_TABLE_NAME: str = "LiteLLM_VerificationToken" KEY_TABLE_NAME = "LiteLLM_VerificationToken"
PROXY_MODEL_TABLE_NAME: str = "LiteLLM_ModelTable" PROXY_MODEL_TABLE_NAME = "LiteLLM_ModelTable"
AlertType = Literal[ AlertType = Literal[
@ -140,7 +140,7 @@ class LiteLLMBase(BaseModel):
Implements default functions, all pydantic objects should have. Implements default functions, all pydantic objects should have.
""" """
def json(self, **kwargs): def json(self, **kwargs): # type: ignore
try: try:
return self.model_dump(**kwargs) # noqa return self.model_dump(**kwargs) # noqa
except Exception as e: except Exception as e:
@ -170,7 +170,7 @@ class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
class LiteLLMRoutes(enum.Enum): class LiteLLMRoutes(enum.Enum):
openai_route_names: List = [ openai_route_names = [
"chat_completion", "chat_completion",
"completion", "completion",
"embeddings", "embeddings",
@ -179,7 +179,7 @@ class LiteLLMRoutes(enum.Enum):
"moderations", "moderations",
"model_list", # OpenAI /v1/models route "model_list", # OpenAI /v1/models route
] ]
openai_routes: List = [ openai_routes = [
# chat completions # chat completions
"/engines/{model}/chat/completions", "/engines/{model}/chat/completions",
"/openai/deployments/{model}/chat/completions", "/openai/deployments/{model}/chat/completions",
@ -247,18 +247,18 @@ class LiteLLMRoutes(enum.Enum):
"/v1/rerank", "/v1/rerank",
] ]
mapped_pass_through_routes: List = [ mapped_pass_through_routes = [
"/bedrock", "/bedrock",
"/vertex-ai", "/vertex-ai",
"/gemini", "/gemini",
"/langfuse", "/langfuse",
] ]
anthropic_routes: List = [ anthropic_routes = [
"/v1/messages", "/v1/messages",
] ]
info_routes: List = [ info_routes = [
"/key/info", "/key/info",
"/team/info", "/team/info",
"/team/list", "/team/list",
@ -271,9 +271,9 @@ class LiteLLMRoutes(enum.Enum):
] ]
# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend # NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
master_key_only_routes: List = ["/global/spend/reset", "/key/list"] master_key_only_routes = ["/global/spend/reset", "/key/list"]
sso_only_routes: List = [ sso_only_routes = [
"/key/generate", "/key/generate",
"/key/update", "/key/update",
"/key/delete", "/key/delete",
@ -282,7 +282,7 @@ class LiteLLMRoutes(enum.Enum):
"/sso/get/logout_url", "/sso/get/logout_url",
] ]
management_routes: List = [ # key management_routes = [ # key
"/key/generate", "/key/generate",
"/key/update", "/key/update",
"/key/delete", "/key/delete",
@ -307,7 +307,7 @@ class LiteLLMRoutes(enum.Enum):
"/model/info", "/model/info",
] ]
spend_tracking_routes: List = [ spend_tracking_routes = [
# spend # spend
"/spend/keys", "/spend/keys",
"/spend/users", "/spend/users",
@ -316,7 +316,7 @@ class LiteLLMRoutes(enum.Enum):
"/spend/logs", "/spend/logs",
] ]
global_spend_tracking_routes: List = [ global_spend_tracking_routes = [
# global spend # global spend
"/global/spend/logs", "/global/spend/logs",
"/global/spend", "/global/spend",
@ -328,7 +328,7 @@ class LiteLLMRoutes(enum.Enum):
"/global/spend/report", "/global/spend/report",
] ]
public_routes: List = [ public_routes = [
"/routes", "/routes",
"/", "/",
"/health/liveliness", "/health/liveliness",
@ -339,7 +339,7 @@ class LiteLLMRoutes(enum.Enum):
"/metrics", "/metrics",
] ]
internal_user_routes: List = ( internal_user_routes = (
[ [
"/key/generate", "/key/generate",
"/key/update", "/key/update",
@ -357,7 +357,7 @@ class LiteLLMRoutes(enum.Enum):
+ sso_only_routes + sso_only_routes
) )
self_managed_routes: List = [ self_managed_routes = [
"/team/member_add", "/team/member_add",
"/team/member_delete", "/team/member_delete",
] # routes that manage their own allowed/disallowed logic ] # routes that manage their own allowed/disallowed logic
@ -581,7 +581,9 @@ class ModelParams(LiteLLMBase):
@classmethod @classmethod
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("model_info") is None: if values.get("model_info") is None:
values.update({"model_info": ModelInfo()}) values.update(
{"model_info": ModelInfo(id=None, mode="chat", base_model=None)}
)
return values return values
@ -627,7 +629,7 @@ class GenerateKeyRequest(_GenerateKeyRequest):
class GenerateKeyResponse(_GenerateKeyRequest): class GenerateKeyResponse(_GenerateKeyRequest):
key: str key: str # type: ignore
key_name: Optional[str] = None key_name: Optional[str] = None
expires: Optional[datetime] expires: Optional[datetime]
user_id: Optional[str] = None user_id: Optional[str] = None
@ -659,7 +661,7 @@ class GenerateKeyResponse(_GenerateKeyRequest):
class UpdateKeyRequest(GenerateKeyRequest): class UpdateKeyRequest(GenerateKeyRequest):
# Note: the defaults of all Params here MUST BE NONE # Note: the defaults of all Params here MUST BE NONE
# else they will get overwritten # else they will get overwritten
key: str key: str # type: ignore
duration: Optional[str] = None duration: Optional[str] = None
spend: Optional[float] = None spend: Optional[float] = None
metadata: Optional[dict] = None metadata: Optional[dict] = None
@ -976,6 +978,7 @@ class TeamCallbackMetadata(LiteLLMBase):
class LiteLLM_TeamTable(TeamBase): class LiteLLM_TeamTable(TeamBase):
team_id: str # type: ignore
spend: Optional[float] = None spend: Optional[float] = None
max_parallel_requests: Optional[int] = None max_parallel_requests: Optional[int] = None
budget_duration: Optional[str] = None budget_duration: Optional[str] = None
@ -1061,7 +1064,7 @@ class LiteLLM_OrganizationTable(LiteLLMBase):
class NewOrganizationResponse(LiteLLM_OrganizationTable): class NewOrganizationResponse(LiteLLM_OrganizationTable):
organization_id: str organization_id: str # type: ignore
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@ -1388,16 +1391,7 @@ class UserAPIKeyAuth(
""" """
api_key: Optional[str] = None api_key: Optional[str] = None
user_role: Optional[ user_role: Optional[LitellmUserRoles] = None
Literal[
LitellmUserRoles.PROXY_ADMIN,
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
LitellmUserRoles.INTERNAL_USER,
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
LitellmUserRoles.TEAM,
LitellmUserRoles.CUSTOMER,
]
] = None
allowed_model_region: Optional[Literal["eu"]] = None allowed_model_region: Optional[Literal["eu"]] = None
parent_otel_span: Optional[Span] = None parent_otel_span: Optional[Span] = None
rpm_limit_per_model: Optional[Dict[str, int]] = None rpm_limit_per_model: Optional[Dict[str, int]] = None
@ -1716,9 +1710,9 @@ class SpendLogsPayload(TypedDict):
total_tokens: int total_tokens: int
prompt_tokens: int prompt_tokens: int
completion_tokens: int completion_tokens: int
startTime: datetime startTime: Union[datetime, str]
endTime: datetime endTime: Union[datetime, str]
completionStartTime: Optional[datetime] completionStartTime: Optional[Union[datetime, str]]
model: str model: str
model_id: Optional[str] model_id: Optional[str]
model_group: Optional[str] model_group: Optional[str]
@ -1891,6 +1885,6 @@ class TeamAddMemberResponse(LiteLLM_TeamTable):
class TeamInfoResponseObject(TypedDict): class TeamInfoResponseObject(TypedDict):
team_id: str team_id: str
team_info: TeamBase team_info: LiteLLM_TeamTable
keys: List keys: List
team_memberships: List[LiteLLM_TeamMembership] team_memberships: List[LiteLLM_TeamMembership]


@ -109,11 +109,8 @@ async def user_api_key_auth(
), ),
) -> UserAPIKeyAuth: ) -> UserAPIKeyAuth:
from litellm.proxy.proxy_server import ( from litellm.proxy.proxy_server import (
allowed_routes_check,
common_checks,
custom_db_client, custom_db_client,
general_settings, general_settings,
get_actual_routes,
jwt_handler, jwt_handler,
litellm_proxy_admin_name, litellm_proxy_admin_name,
llm_model_list, llm_model_list,
@ -125,6 +122,8 @@ async def user_api_key_auth(
user_custom_auth, user_custom_auth,
) )
parent_otel_span: Optional[Span] = None
try: try:
route: str = get_request_route(request=request) route: str = get_request_route(request=request)
# get the request body # get the request body
@ -137,6 +136,7 @@ async def user_api_key_auth(
pass_through_endpoints: Optional[List[dict]] = general_settings.get( pass_through_endpoints: Optional[List[dict]] = general_settings.get(
"pass_through_endpoints", None "pass_through_endpoints", None
) )
passed_in_key: Optional[str] = None
if isinstance(api_key, str): if isinstance(api_key, str):
passed_in_key = api_key passed_in_key = api_key
api_key = _get_bearer_token(api_key=api_key) api_key = _get_bearer_token(api_key=api_key)
@ -161,7 +161,6 @@ async def user_api_key_auth(
custom_litellm_key_header_name=custom_litellm_key_header_name, custom_litellm_key_header_name=custom_litellm_key_header_name,
) )
parent_otel_span: Optional[Span] = None
if open_telemetry_logger is not None: if open_telemetry_logger is not None:
parent_otel_span = open_telemetry_logger.tracer.start_span( parent_otel_span = open_telemetry_logger.tracer.start_span(
name="Received Proxy Server Request", name="Received Proxy Server Request",
@ -189,7 +188,7 @@ async def user_api_key_auth(
######## Route Checks Before Reading DB / Cache for "token" ################ ######## Route Checks Before Reading DB / Cache for "token" ################
if ( if (
route in LiteLLMRoutes.public_routes.value route in LiteLLMRoutes.public_routes.value # type: ignore
or route_in_additonal_public_routes(current_route=route) or route_in_additonal_public_routes(current_route=route)
): ):
# check if public endpoint # check if public endpoint
@ -410,7 +409,7 @@ async def user_api_key_auth(
#### ELSE #### #### ELSE ####
## CHECK PASS-THROUGH ENDPOINTS ## ## CHECK PASS-THROUGH ENDPOINTS ##
is_mapped_pass_through_route: bool = False is_mapped_pass_through_route: bool = False
for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value: for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value: # type: ignore
if route.startswith(mapped_route): if route.startswith(mapped_route):
is_mapped_pass_through_route = True is_mapped_pass_through_route = True
if is_mapped_pass_through_route: if is_mapped_pass_through_route:
@ -444,9 +443,9 @@ async def user_api_key_auth(
header_key = headers.get("litellm_user_api_key", "") header_key = headers.get("litellm_user_api_key", "")
if ( if (
isinstance(request.headers, dict) isinstance(request.headers, dict)
and request.headers.get(key=header_key) is not None and request.headers.get(key=header_key) is not None # type: ignore
): ):
api_key = request.headers.get(key=header_key) api_key = request.headers.get(key=header_key) # type: ignore
if master_key is None: if master_key is None:
if isinstance(api_key, str): if isinstance(api_key, str):
@ -606,7 +605,7 @@ async def user_api_key_auth(
## IF it's not a master key ## IF it's not a master key
## Route should not be in master_key_only_routes ## Route should not be in master_key_only_routes
if route in LiteLLMRoutes.master_key_only_routes.value: if route in LiteLLMRoutes.master_key_only_routes.value: # type: ignore
raise Exception( raise Exception(
f"Tried to access route={route}, which is only for MASTER KEY" f"Tried to access route={route}, which is only for MASTER KEY"
) )
@ -669,8 +668,9 @@ async def user_api_key_auth(
"allowed_model_region" "allowed_model_region"
) )
user_obj: Optional[LiteLLM_UserTable] = None
valid_token_dict: dict = {}
if valid_token is not None: if valid_token is not None:
user_obj: Optional[LiteLLM_UserTable] = None
# Got Valid Token from Cache, DB # Got Valid Token from Cache, DB
# Run checks for # Run checks for
# 1. If token can call model # 1. If token can call model
@ -686,6 +686,7 @@ async def user_api_key_auth(
# Check 1. If token can call model # Check 1. If token can call model
_model_alias_map = {} _model_alias_map = {}
model: Optional[str] = None
if ( if (
hasattr(valid_token, "team_model_aliases") hasattr(valid_token, "team_model_aliases")
and valid_token.team_model_aliases is not None and valid_token.team_model_aliases is not None
@ -698,6 +699,7 @@ async def user_api_key_auth(
_model_alias_map = {**valid_token.aliases} _model_alias_map = {**valid_token.aliases}
litellm.model_alias_map = _model_alias_map litellm.model_alias_map = _model_alias_map
config = valid_token.config config = valid_token.config
if config != {}: if config != {}:
model_list = config.get("model_list", []) model_list = config.get("model_list", [])
llm_model_list = model_list llm_model_list = model_list
@ -887,7 +889,10 @@ async def user_api_key_auth(
and max_budget_per_model.get(current_model, None) is not None and max_budget_per_model.get(current_model, None) is not None
): ):
if ( if (
model_spend[0]["model"] == current_model "model" in model_spend[0]
and model_spend[0].get("model") == current_model
and "_sum" in model_spend[0]
and "spend" in model_spend[0]["_sum"]
and model_spend[0]["_sum"]["spend"] and model_spend[0]["_sum"]["spend"]
>= max_budget_per_model[current_model] >= max_budget_per_model[current_model]
): ):
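The max-budget check above now validates the shape of the spend aggregation row (`"model"`, `"_sum"`, `"spend"`) before comparing against the per-model cap, instead of indexing into it blindly. A hedged sketch of that defensive pattern; the row layout is inferred from the diff, not from the full query code:

```python
def model_over_budget(
    model_spend_rows: list, current_model: str, max_budget_per_model: dict
) -> bool:
    """Return True only when a well-formed spend row shows the model over its cap."""
    if not model_spend_rows or current_model not in max_budget_per_model:
        return False
    row = model_spend_rows[0]
    return (
        "model" in row
        and row.get("model") == current_model
        and "_sum" in row
        and "spend" in row["_sum"]
        and row["_sum"]["spend"] is not None
        and row["_sum"]["spend"] >= max_budget_per_model[current_model]
    )


# Malformed rows no longer raise KeyError - they simply fail the check.
print(model_over_budget([{"model": "gpt-4o", "_sum": {"spend": 12.5}}], "gpt-4o", {"gpt-4o": 10.0}))  # True
print(model_over_budget([{"_sum": {}}], "gpt-4o", {"gpt-4o": 10.0}))  # False
```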
@ -927,16 +932,19 @@ async def user_api_key_auth(
) )
# Check 8: Additional Common Checks across jwt + key auth # Check 8: Additional Common Checks across jwt + key auth
-            _team_obj = LiteLLM_TeamTable(
-                team_id=valid_token.team_id,
-                max_budget=valid_token.team_max_budget,
-                spend=valid_token.team_spend,
-                tpm_limit=valid_token.team_tpm_limit,
-                rpm_limit=valid_token.team_rpm_limit,
-                blocked=valid_token.team_blocked,
-                models=valid_token.team_models,
-                metadata=valid_token.team_metadata,
-            )
+            if valid_token.team_id is not None:
+                _team_obj: Optional[LiteLLM_TeamTable] = LiteLLM_TeamTable(
+                    team_id=valid_token.team_id,
+                    max_budget=valid_token.team_max_budget,
+                    spend=valid_token.team_spend,
+                    tpm_limit=valid_token.team_tpm_limit,
+                    rpm_limit=valid_token.team_rpm_limit,
+                    blocked=valid_token.team_blocked,
+                    models=valid_token.team_models,
+                    metadata=valid_token.team_metadata,
+                )
+            else:
+                _team_obj = None
user_api_key_cache.set_cache( user_api_key_cache.set_cache(
key=valid_token.team_id, value=_team_obj key=valid_token.team_id, value=_team_obj
@ -1045,7 +1053,7 @@ async def user_api_key_auth(
"/global/predict/spend/logs", "/global/predict/spend/logs",
"/global/activity", "/global/activity",
"/health/services", "/health/services",
] + LiteLLMRoutes.info_routes.value ] + LiteLLMRoutes.info_routes.value # type: ignore
# check if the current route startswith any of the allowed routes # check if the current route startswith any of the allowed routes
if ( if (
route is not None route is not None
@ -1106,7 +1114,7 @@ async def user_api_key_auth(
# Log this exception to OTEL # Log this exception to OTEL
if open_telemetry_logger is not None: if open_telemetry_logger is not None:
await open_telemetry_logger.async_post_call_failure_hook( await open_telemetry_logger.async_post_call_failure_hook( # type: ignore
original_exception=e, original_exception=e,
user_api_key_dict=UserAPIKeyAuth(parent_otel_span=parent_otel_span), user_api_key_dict=UserAPIKeyAuth(parent_otel_span=parent_otel_span),
) )


@ -4,10 +4,11 @@ import json
import traceback import traceback
import uuid import uuid
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import List, Optional from typing import List, Optional, Union
import fastapi import fastapi
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from pydantic import BaseModel
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
@ -28,6 +29,7 @@ from litellm.proxy._types import (
ProxyErrorTypes, ProxyErrorTypes,
ProxyException, ProxyException,
TeamAddMemberResponse, TeamAddMemberResponse,
TeamBase,
TeamInfoResponseObject, TeamInfoResponseObject,
TeamMemberAddRequest, TeamMemberAddRequest,
TeamMemberDeleteRequest, TeamMemberDeleteRequest,
@ -36,6 +38,7 @@ from litellm.proxy._types import (
UpdateTeamRequest, UpdateTeamRequest,
UserAPIKeyAuth, UserAPIKeyAuth,
) )
from litellm.proxy.auth.auth_checks import get_team_object
from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin, user_api_key_auth from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin, user_api_key_auth
from litellm.proxy.management_helpers.utils import ( from litellm.proxy.management_helpers.utils import (
add_new_member, add_new_member,
@ -240,7 +243,7 @@ async def new_team(
reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
complete_team_data.budget_reset_at = reset_at complete_team_data.budget_reset_at = reset_at
team_row = await prisma_client.insert_data( team_row: LiteLLM_TeamTable = await prisma_client.insert_data( # type: ignore
data=complete_team_data.json(exclude_none=True), table_name="team" data=complete_team_data.json(exclude_none=True), table_name="team"
) )
@ -462,9 +465,6 @@ async def team_member_add(
``` ```
""" """
from litellm.proxy.proxy_server import ( from litellm.proxy.proxy_server import (
_duration_in_seconds,
create_audit_log_for_update,
get_team_object,
litellm_proxy_admin_name, litellm_proxy_admin_name,
prisma_client, prisma_client,
proxy_logging_obj, proxy_logging_obj,
@ -932,10 +932,13 @@ async def delete_team(
if litellm.store_audit_logs is True: if litellm.store_audit_logs is True:
# make an audit log for each team deleted # make an audit log for each team deleted
for team_id in data.team_ids: for team_id in data.team_ids:
team_row = await prisma_client.get_data( # type: ignore team_row: Optional[LiteLLM_TeamTable] = await prisma_client.get_data( # type: ignore
team_id=team_id, table_name="team", query_type="find_unique" team_id=team_id, table_name="team", query_type="find_unique"
) )
if team_row is None:
continue
_team_row = team_row.json(exclude_none=True) _team_row = team_row.json(exclude_none=True)
asyncio.create_task( asyncio.create_task(
@ -1027,8 +1030,10 @@ async def team_info(
), ),
) )
team_info = await prisma_client.get_data( team_info: Optional[Union[LiteLLM_TeamTable, dict]] = (
team_id=team_id, table_name="team", query_type="find_unique" await prisma_client.get_data(
team_id=team_id, table_name="team", query_type="find_unique"
)
) )
if team_info is None: if team_info is None:
raise HTTPException( raise HTTPException(
@ -1044,6 +1049,9 @@ async def team_info(
expires=datetime.now(), expires=datetime.now(),
) )
if keys is None:
keys = []
if team_info is None: if team_info is None:
## make sure we still return a total spend ## ## make sure we still return a total spend ##
spend = 0 spend = 0
@ -1055,7 +1063,7 @@ async def team_info(
for key in keys: for key in keys:
try: try:
key = key.model_dump() # noqa key = key.model_dump() # noqa
except: except Exception:
# if using pydantic v1 # if using pydantic v1
key = key.dict() key = key.dict()
key.pop("token", None) key.pop("token", None)
@ -1070,9 +1078,16 @@ async def team_info(
for tm in team_memberships: for tm in team_memberships:
returned_tm.append(LiteLLM_TeamMembership(**tm.model_dump())) returned_tm.append(LiteLLM_TeamMembership(**tm.model_dump()))
if isinstance(team_info, dict):
_team_info = LiteLLM_TeamTable(**team_info)
elif isinstance(team_info, BaseModel):
_team_info = LiteLLM_TeamTable(**team_info.model_dump())
else:
_team_info = LiteLLM_TeamTable()
response_object = TeamInfoResponseObject( response_object = TeamInfoResponseObject(
team_id=team_id, team_id=team_id,
team_info=team_info, team_info=_team_info,
keys=keys, keys=keys,
team_memberships=returned_tm, team_memberships=returned_tm,
) )
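`/team/info` now coerces whatever the DB layer hands back (a dict or a pydantic model) into one `LiteLLM_TeamTable` before building `TeamInfoResponseObject`. A standalone sketch of that normalization step, using a simplified stand-in model rather than the real table class:

```python
from typing import Optional, Union

from pydantic import BaseModel


class TeamTableSketch(BaseModel):
    # Simplified stand-in for LiteLLM_TeamTable.
    team_id: Optional[str] = None
    spend: Optional[float] = None


def normalize_team_info(team_info: Optional[Union[dict, BaseModel]]) -> TeamTableSketch:
    if isinstance(team_info, dict):
        return TeamTableSketch(**team_info)
    if isinstance(team_info, BaseModel):
        return TeamTableSketch(**team_info.model_dump())
    return TeamTableSketch()  # fall back to an empty row rather than failing the response


print(normalize_team_info({"team_id": "team-1", "spend": 3.2}))
print(normalize_team_info(None))
```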


@ -125,16 +125,7 @@ from litellm.proxy._types import *
from litellm.proxy.analytics_endpoints.analytics_endpoints import ( from litellm.proxy.analytics_endpoints.analytics_endpoints import (
router as analytics_router, router as analytics_router,
) )
from litellm.proxy.auth.auth_checks import ( from litellm.proxy.auth.auth_checks import log_to_opentelemetry
allowed_routes_check,
common_checks,
get_actual_routes,
get_end_user_object,
get_org_object,
get_team_object,
get_user_object,
log_to_opentelemetry,
)
from litellm.proxy.auth.auth_utils import check_response_size_is_safe from litellm.proxy.auth.auth_utils import check_response_size_is_safe
from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.auth.litellm_license import LicenseCheck from litellm.proxy.auth.litellm_license import LicenseCheck
@ -260,6 +251,7 @@ from litellm.secret_managers.aws_secret_manager import (
load_aws_secret_manager, load_aws_secret_manager,
) )
from litellm.secret_managers.google_kms import load_google_kms from litellm.secret_managers.google_kms import load_google_kms
from litellm.secret_managers.main import get_secret
from litellm.types.llms.anthropic import ( from litellm.types.llms.anthropic import (
AnthropicMessagesRequest, AnthropicMessagesRequest,
AnthropicResponse, AnthropicResponse,
@ -484,7 +476,7 @@ general_settings: dict = {}
callback_settings: dict = {} callback_settings: dict = {}
log_file = "api_log.json" log_file = "api_log.json"
worker_config = None worker_config = None
master_key = None master_key: Optional[str] = None
otel_logging = False otel_logging = False
prisma_client: Optional[PrismaClient] = None prisma_client: Optional[PrismaClient] = None
custom_db_client: Optional[DBClient] = None custom_db_client: Optional[DBClient] = None
@ -874,7 +866,9 @@ def error_tracking():
def _set_spend_logs_payload( def _set_spend_logs_payload(
payload: dict, prisma_client: PrismaClient, spend_logs_url: Optional[str] = None payload: Union[dict, SpendLogsPayload],
prisma_client: PrismaClient,
spend_logs_url: Optional[str] = None,
): ):
if prisma_client is not None and spend_logs_url is not None: if prisma_client is not None and spend_logs_url is not None:
if isinstance(payload["startTime"], datetime): if isinstance(payload["startTime"], datetime):
@ -1341,6 +1335,9 @@ async def _run_background_health_check():
# make 1 deep copy of llm_model_list -> use this for all background health checks # make 1 deep copy of llm_model_list -> use this for all background health checks
_llm_model_list = copy.deepcopy(llm_model_list) _llm_model_list = copy.deepcopy(llm_model_list)
if _llm_model_list is None:
return
while True: while True:
healthy_endpoints, unhealthy_endpoints = await perform_health_check( healthy_endpoints, unhealthy_endpoints = await perform_health_check(
model_list=_llm_model_list, details=health_check_details model_list=_llm_model_list, details=health_check_details
@ -1352,7 +1349,10 @@ async def _run_background_health_check():
health_check_results["healthy_count"] = len(healthy_endpoints) health_check_results["healthy_count"] = len(healthy_endpoints)
health_check_results["unhealthy_count"] = len(unhealthy_endpoints) health_check_results["unhealthy_count"] = len(unhealthy_endpoints)
await asyncio.sleep(health_check_interval) if health_check_interval is not None and isinstance(
health_check_interval, float
):
await asyncio.sleep(health_check_interval)
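The background health check now exits when there is no model list to probe and only sleeps when the configured interval is a usable number. A hedged sketch of that loop shape; `perform_health_check` here is a synchronous stand-in for the proxy's own async checker, and the "run once then stop" fallback is a simplification for the demo:

```python
import asyncio
from typing import Callable, List, Optional


async def run_background_health_check(
    model_list: Optional[List[dict]],
    health_check_interval: Optional[float],
    perform_health_check: Callable[[List[dict]], tuple],
) -> None:
    if model_list is None:
        return  # nothing to check - exit instead of looping forever
    models = list(model_list)  # one copy reused for every iteration
    while True:
        healthy, unhealthy = perform_health_check(models)
        print(f"healthy={len(healthy)} unhealthy={len(unhealthy)}")
        if not isinstance(health_check_interval, (int, float)):
            return  # no valid interval -> run a single pass and stop
        await asyncio.sleep(health_check_interval)


asyncio.run(
    run_background_health_check(
        model_list=[{"model_name": "gpt-4o"}],
        health_check_interval=None,  # single pass for the demo
        perform_health_check=lambda models: (models, []),
    )
)
```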
class ProxyConfig: class ProxyConfig:
@ -1467,7 +1467,7 @@ class ProxyConfig:
break break
for k, v in team_config.items(): for k, v in team_config.items():
if isinstance(v, str) and v.startswith("os.environ/"): if isinstance(v, str) and v.startswith("os.environ/"):
team_config[k] = litellm.get_secret(v) team_config[k] = get_secret(v)
return team_config return team_config
def _init_cache( def _init_cache(
@ -1513,6 +1513,9 @@ class ProxyConfig:
config = get_file_contents_from_s3( config = get_file_contents_from_s3(
bucket_name=bucket_name, object_key=object_key bucket_name=bucket_name, object_key=object_key
) )
if config is None:
raise Exception("Unable to load config from given source.")
else: else:
# default to file # default to file
config = await self.get_config(config_file_path=config_file_path) config = await self.get_config(config_file_path=config_file_path)
@ -1528,9 +1531,7 @@ class ProxyConfig:
environment_variables = config.get("environment_variables", None) environment_variables = config.get("environment_variables", None)
if environment_variables: if environment_variables:
for key, value in environment_variables.items(): for key, value in environment_variables.items():
os.environ[key] = str( os.environ[key] = str(get_secret(secret_name=key, default_value=value))
litellm.get_secret(secret_name=key, default_value=value)
)
# check if litellm_license in general_settings # check if litellm_license in general_settings
if "LITELLM_LICENSE" in environment_variables: if "LITELLM_LICENSE" in environment_variables:
@ -1566,8 +1567,8 @@ class ProxyConfig:
if ( if (
cache_type == "redis" or cache_type == "redis-semantic" cache_type == "redis" or cache_type == "redis-semantic"
) and len(cache_params.keys()) == 0: ) and len(cache_params.keys()) == 0:
cache_host = litellm.get_secret("REDIS_HOST", None) cache_host = get_secret("REDIS_HOST", None)
cache_port = litellm.get_secret("REDIS_PORT", None) cache_port = get_secret("REDIS_PORT", None)
cache_password = None cache_password = None
cache_params.update( cache_params.update(
{ {
@ -1577,8 +1578,8 @@ class ProxyConfig:
} }
) )
if litellm.get_secret("REDIS_PASSWORD", None) is not None: if get_secret("REDIS_PASSWORD", None) is not None:
cache_password = litellm.get_secret("REDIS_PASSWORD", None) cache_password = get_secret("REDIS_PASSWORD", None)
cache_params.update( cache_params.update(
{ {
"password": cache_password, "password": cache_password,
@ -1617,7 +1618,7 @@ class ProxyConfig:
# users can pass os.environ/ variables on the proxy - we should read them from the env # users can pass os.environ/ variables on the proxy - we should read them from the env
for key, value in cache_params.items(): for key, value in cache_params.items():
if type(value) is str and value.startswith("os.environ/"): if type(value) is str and value.startswith("os.environ/"):
cache_params[key] = litellm.get_secret(value) cache_params[key] = get_secret(value)
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
self._init_cache(cache_params=cache_params) self._init_cache(cache_params=cache_params)
@ -1738,7 +1739,7 @@ class ProxyConfig:
if value is not None and isinstance(value, dict): if value is not None and isinstance(value, dict):
for _k, _v in value.items(): for _k, _v in value.items():
if isinstance(_v, str) and _v.startswith("os.environ/"): if isinstance(_v, str) and _v.startswith("os.environ/"):
value[_k] = litellm.get_secret(_v) value[_k] = get_secret(_v)
litellm.upperbound_key_generate_params = ( litellm.upperbound_key_generate_params = (
LiteLLM_UpperboundKeyGenerateParams(**value) LiteLLM_UpperboundKeyGenerateParams(**value)
) )
@ -1812,15 +1813,15 @@ class ProxyConfig:
database_url = general_settings.get("database_url", None) database_url = general_settings.get("database_url", None)
if database_url and database_url.startswith("os.environ/"): if database_url and database_url.startswith("os.environ/"):
verbose_proxy_logger.debug("GOING INTO LITELLM.GET_SECRET!") verbose_proxy_logger.debug("GOING INTO LITELLM.GET_SECRET!")
database_url = litellm.get_secret(database_url) database_url = get_secret(database_url)
verbose_proxy_logger.debug("RETRIEVED DB URL: %s", database_url) verbose_proxy_logger.debug("RETRIEVED DB URL: %s", database_url)
### MASTER KEY ### ### MASTER KEY ###
master_key = general_settings.get( master_key = general_settings.get(
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None) "master_key", get_secret("LITELLM_MASTER_KEY", None)
) )
if master_key and master_key.startswith("os.environ/"): if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key) master_key = get_secret(master_key) # type: ignore
if not isinstance(master_key, str): if not isinstance(master_key, str):
raise Exception( raise Exception(
"Master key must be a string. Current type - {}".format( "Master key must be a string. Current type - {}".format(
@ -1861,33 +1862,6 @@ class ProxyConfig:
await initialize_pass_through_endpoints( await initialize_pass_through_endpoints(
pass_through_endpoints=general_settings["pass_through_endpoints"] pass_through_endpoints=general_settings["pass_through_endpoints"]
) )
## dynamodb
database_type = general_settings.get("database_type", None)
if database_type is not None and (
database_type == "dynamo_db" or database_type == "dynamodb"
):
database_args = general_settings.get("database_args", None)
### LOAD FROM os.environ/ ###
for k, v in database_args.items():
if isinstance(v, str) and v.startswith("os.environ/"):
database_args[k] = litellm.get_secret(v)
if isinstance(k, str) and k == "aws_web_identity_token":
value = database_args[k]
verbose_proxy_logger.debug(
f"Loading AWS Web Identity Token from file: {value}"
)
if os.path.exists(value):
with open(value, "r") as file:
token_content = file.read()
database_args[k] = token_content
else:
verbose_proxy_logger.info(
f"DynamoDB Loading - {value} is not a valid file path"
)
verbose_proxy_logger.debug("database_args: %s", database_args)
custom_db_client = DBClient(
custom_db_args=database_args, custom_db_type=database_type
)
## ADMIN UI ACCESS ## ## ADMIN UI ACCESS ##
ui_access_mode = general_settings.get( ui_access_mode = general_settings.get(
"ui_access_mode", "all" "ui_access_mode", "all"
@ -1951,7 +1925,7 @@ class ProxyConfig:
### LOAD FROM os.environ/ ### ### LOAD FROM os.environ/ ###
for k, v in model["litellm_params"].items(): for k, v in model["litellm_params"].items():
if isinstance(v, str) and v.startswith("os.environ/"): if isinstance(v, str) and v.startswith("os.environ/"):
model["litellm_params"][k] = litellm.get_secret(v) model["litellm_params"][k] = get_secret(v)
print(f"\033[32m {model.get('model_name', '')}\033[0m") # noqa print(f"\033[32m {model.get('model_name', '')}\033[0m") # noqa
litellm_model_name = model["litellm_params"]["model"] litellm_model_name = model["litellm_params"]["model"]
litellm_model_api_base = model["litellm_params"].get("api_base", None) litellm_model_api_base = model["litellm_params"].get("api_base", None)
@ -2005,7 +1979,10 @@ class ProxyConfig:
) # type:ignore ) # type:ignore
# Guardrail settings # Guardrail settings
guardrails_v2 = config.get("guardrails", None) guardrails_v2: Optional[dict] = None
if config is not None:
guardrails_v2 = config.get("guardrails", None)
if guardrails_v2: if guardrails_v2:
init_guardrails_v2( init_guardrails_v2(
all_guardrails=guardrails_v2, config_file_path=config_file_path all_guardrails=guardrails_v2, config_file_path=config_file_path
@ -2074,7 +2051,7 @@ class ProxyConfig:
### LOAD FROM os.environ/ ### ### LOAD FROM os.environ/ ###
for k, v in model["litellm_params"].items(): for k, v in model["litellm_params"].items():
if isinstance(v, str) and v.startswith("os.environ/"): if isinstance(v, str) and v.startswith("os.environ/"):
model["litellm_params"][k] = litellm.get_secret(v) model["litellm_params"][k] = get_secret(v)
## check if they have model-id's ## ## check if they have model-id's ##
model_id = model.get("model_info", {}).get("id", None) model_id = model.get("model_info", {}).get("id", None)
@ -2234,7 +2211,8 @@ class ProxyConfig:
for k, v in environment_variables.items(): for k, v in environment_variables.items():
try: try:
decrypted_value = decrypt_value_helper(value=v) decrypted_value = decrypt_value_helper(value=v)
os.environ[k] = decrypted_value if decrypted_value is not None:
os.environ[k] = decrypted_value
except Exception as e: except Exception as e:
verbose_proxy_logger.error( verbose_proxy_logger.error(
"Error setting env variable: %s - %s", k, str(e) "Error setting env variable: %s - %s", k, str(e)
@ -2536,7 +2514,7 @@ async def async_assistants_data_generator(
) )
# chunk = chunk.model_dump_json(exclude_none=True) # chunk = chunk.model_dump_json(exclude_none=True)
async for c in chunk: async for c in chunk: # type: ignore
c = c.model_dump_json(exclude_none=True) c = c.model_dump_json(exclude_none=True)
try: try:
yield f"data: {c}\n\n" yield f"data: {c}\n\n"
@ -2745,17 +2723,22 @@ async def startup_event():
### LOAD MASTER KEY ### ### LOAD MASTER KEY ###
# check if master key set in environment - load from there # check if master key set in environment - load from there
master_key = litellm.get_secret("LITELLM_MASTER_KEY", None) master_key = get_secret("LITELLM_MASTER_KEY", None) # type: ignore
# check if DATABASE_URL in environment - load from there # check if DATABASE_URL in environment - load from there
if prisma_client is None: if prisma_client is None:
prisma_setup(database_url=litellm.get_secret("DATABASE_URL", None)) _db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore
prisma_setup(database_url=_db_url)
### LOAD CONFIG ### ### LOAD CONFIG ###
worker_config = litellm.get_secret("WORKER_CONFIG") worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore
verbose_proxy_logger.debug("worker_config: %s", worker_config) verbose_proxy_logger.debug("worker_config: %s", worker_config)
# check if it's a valid file path # check if it's a valid file path
if os.path.isfile(worker_config): if worker_config is not None:
if proxy_config.is_yaml(config_file_path=worker_config): if (
isinstance(worker_config, str)
and os.path.isfile(worker_config)
and proxy_config.is_yaml(config_file_path=worker_config)
):
( (
llm_router, llm_router,
llm_model_list, llm_model_list,
@ -2763,21 +2746,23 @@ async def startup_event():
) = await proxy_config.load_config( ) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config router=llm_router, config_file_path=worker_config
) )
else: elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance(
worker_config, str
):
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config
)
elif isinstance(worker_config, dict):
await initialize(**worker_config) await initialize(**worker_config)
elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None: else:
( # if not, assume it's a json string
llm_router, worker_config = json.loads(worker_config)
llm_model_list, if isinstance(worker_config, dict):
general_settings, await initialize(**worker_config)
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config
)
else:
# if not, assume it's a json string
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
await initialize(**worker_config)
## CHECK PREMIUM USER ## CHECK PREMIUM USER
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
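The startup hunk above is hard to follow in this flattened rendering; the net effect is that `WORKER_CONFIG` can now be a YAML file path, a bucket-backed config path (when `LITELLM_CONFIG_BUCKET_NAME` is set), an already-parsed dict, or a JSON string, with each case dispatched explicitly. A hedged sketch of that branch order; the returned strings stand in for the proxy's own loaders:

```python
import json
import os
from typing import Union


def dispatch_worker_config(worker_config: Union[str, dict, None]) -> str:
    """Mirror the branch order: file path -> bucket-backed path -> dict -> JSON string."""
    if worker_config is None:
        return "no config"
    if isinstance(worker_config, str) and os.path.isfile(worker_config):
        return f"load yaml config from file {worker_config}"
    if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance(worker_config, str):
        return f"load config {worker_config} from bucket"
    if isinstance(worker_config, dict):
        return "initialize(**worker_config)"
    worker_config = json.loads(worker_config)  # last resort: treat it as a JSON string
    if isinstance(worker_config, dict):
        return "initialize(**parsed json)"
    return "unrecognized config"


print(dispatch_worker_config('{"model": "gpt-4o"}'))
```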
@ -2825,7 +2810,7 @@ async def startup_event():
if general_settings.get("litellm_jwtauth", None) is not None: if general_settings.get("litellm_jwtauth", None) is not None:
for k, v in general_settings["litellm_jwtauth"].items(): for k, v in general_settings["litellm_jwtauth"].items():
if isinstance(v, str) and v.startswith("os.environ/"): if isinstance(v, str) and v.startswith("os.environ/"):
general_settings["litellm_jwtauth"][k] = litellm.get_secret(v) general_settings["litellm_jwtauth"][k] = get_secret(v)
litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"]) litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"])
else: else:
litellm_jwtauth = LiteLLM_JWTAuth() litellm_jwtauth = LiteLLM_JWTAuth()
@ -2948,8 +2933,7 @@ async def startup_event():
### ADD NEW MODELS ### ### ADD NEW MODELS ###
store_model_in_db = ( store_model_in_db = (
litellm.get_secret("STORE_MODEL_IN_DB", store_model_in_db) get_secret("STORE_MODEL_IN_DB", store_model_in_db) or store_model_in_db
or store_model_in_db
) # type: ignore ) # type: ignore
if store_model_in_db == True: if store_model_in_db == True:
scheduler.add_job( scheduler.add_job(
@ -3498,7 +3482,7 @@ async def completion(
) )
### CALL HOOKS ### - modify outgoing data ### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook( response = await proxy_logging_obj.post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response data=data, user_api_key_dict=user_api_key_dict, response=response # type: ignore
) )
fastapi_response.headers.update( fastapi_response.headers.update(
@ -4000,7 +3984,7 @@ async def audio_speech(
request_data=data, request_data=data,
) )
return StreamingResponse( return StreamingResponse(
generate(response), media_type="audio/mpeg", headers=custom_headers generate(response), media_type="audio/mpeg", headers=custom_headers # type: ignore
) )
except Exception as e: except Exception as e:
@ -4288,6 +4272,7 @@ async def create_assistant(
API Reference docs - https://platform.openai.com/docs/api-reference/assistants/createAssistant API Reference docs - https://platform.openai.com/docs/api-reference/assistants/createAssistant
""" """
global proxy_logging_obj global proxy_logging_obj
data = {} # ensure data always dict
try: try:
# Use orjson to parse JSON data, orjson speeds up requests significantly # Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body() body = await request.body()
@ -7642,6 +7627,7 @@ async def model_group_info(
) )
model_groups: List[ModelGroupInfo] = [] model_groups: List[ModelGroupInfo] = []
for model in all_models_str: for model in all_models_str:
_model_group_info = llm_router.get_model_group_info(model_group=model) _model_group_info = llm_router.get_model_group_info(model_group=model)
@ -8051,7 +8037,8 @@ async def google_login(request: Request):
with microsoft_sso: with microsoft_sso:
return await microsoft_sso.get_login_redirect() return await microsoft_sso.get_login_redirect()
elif generic_client_id is not None: elif generic_client_id is not None:
from fastapi_sso.sso.generic import DiscoveryDocument, create_provider from fastapi_sso.sso.base import DiscoveryDocument
from fastapi_sso.sso.generic import create_provider
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
@ -8616,6 +8603,8 @@ async def auth_callback(request: Request):
redirect_url += "sso/callback" redirect_url += "sso/callback"
else: else:
redirect_url += "/sso/callback" redirect_url += "/sso/callback"
result = None
if google_client_id is not None: if google_client_id is not None:
from fastapi_sso.sso.google import GoogleSSO from fastapi_sso.sso.google import GoogleSSO
@ -8662,7 +8651,8 @@ async def auth_callback(request: Request):
result = await microsoft_sso.verify_and_process(request) result = await microsoft_sso.verify_and_process(request)
elif generic_client_id is not None: elif generic_client_id is not None:
# make generic sso provider # make generic sso provider
from fastapi_sso.sso.generic import DiscoveryDocument, OpenID, create_provider from fastapi_sso.sso.base import DiscoveryDocument, OpenID
from fastapi_sso.sso.generic import create_provider
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
@ -8766,8 +8756,8 @@ async def auth_callback(request: Request):
verbose_proxy_logger.debug("generic result: %s", result) verbose_proxy_logger.debug("generic result: %s", result)
# User is Authe'd in - generate key for the UI to access Proxy # User is Authe'd in - generate key for the UI to access Proxy
user_email = getattr(result, "email", None) user_email: Optional[str] = getattr(result, "email", None)
user_id = getattr(result, "id", None) user_id: Optional[str] = getattr(result, "id", None) if result is not None else None
if user_email is not None and os.getenv("ALLOWED_EMAIL_DOMAINS") is not None: if user_email is not None and os.getenv("ALLOWED_EMAIL_DOMAINS") is not None:
email_domain = user_email.split("@")[1] email_domain = user_email.split("@")[1]
@ -8783,12 +8773,12 @@ async def auth_callback(request: Request):
) )
# generic client id # generic client id
if generic_client_id is not None: if generic_client_id is not None and result is not None:
user_id = getattr(result, "id", None) user_id = getattr(result, "id", None)
user_email = getattr(result, "email", None) user_email = getattr(result, "email", None)
user_role = getattr(result, generic_user_role_attribute_name, None) user_role = getattr(result, generic_user_role_attribute_name, None) # type: ignore
if user_id is None: if user_id is None and result is not None:
_first_name = getattr(result, "first_name", "") or "" _first_name = getattr(result, "first_name", "") or ""
_last_name = getattr(result, "last_name", "") or "" _last_name = getattr(result, "last_name", "") or ""
user_id = _first_name + _last_name user_id = _first_name + _last_name
@ -8811,54 +8801,45 @@ async def auth_callback(request: Request):
"spend": 0, "spend": 0,
"team_id": "litellm-dashboard", "team_id": "litellm-dashboard",
} }
user_defined_values: SSOUserDefinedValues = { user_defined_values: Optional[SSOUserDefinedValues] = None
"models": user_id_models, if user_id is not None:
"user_id": user_id, user_defined_values = SSOUserDefinedValues(
"user_email": user_email, models=user_id_models,
"max_budget": max_internal_user_budget, user_id=user_id,
"user_role": None, user_email=user_email,
"budget_duration": internal_user_budget_duration, max_budget=max_internal_user_budget,
} user_role=None,
budget_duration=internal_user_budget_duration,
)
_user_id_from_sso = user_id _user_id_from_sso = user_id
user_role = None
try: try:
user_role = None
if prisma_client is not None: if prisma_client is not None:
user_info = await prisma_client.get_data(user_id=user_id, table_name="user") user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"user_info: {user_info}; litellm.default_user_params: {litellm.default_user_params}" f"user_info: {user_info}; litellm.default_user_params: {litellm.default_user_params}"
) )
if user_info is not None: if user_info is None:
user_defined_values = { ## check if user-email in db ##
"models": getattr(user_info, "models", user_id_models), user_info = await prisma_client.db.litellm_usertable.find_first(
"user_id": getattr(user_info, "user_id", user_id), where={"user_email": user_email}
"user_email": getattr(user_info, "user_id", user_email), )
"user_role": getattr(user_info, "user_role", None),
"max_budget": getattr(
user_info, "max_budget", max_internal_user_budget
),
"budget_duration": getattr(
user_info, "budget_duration", internal_user_budget_duration
),
}
user_role = getattr(user_info, "user_role", None)
## check if user-email in db ## if user_info is not None and user_id is not None:
user_info = await prisma_client.db.litellm_usertable.find_first( user_defined_values = SSOUserDefinedValues(
where={"user_email": user_email} models=getattr(user_info, "models", user_id_models),
) user_id=user_id,
if user_info is not None: user_email=getattr(user_info, "user_email", user_email),
user_defined_values = { user_role=getattr(user_info, "user_role", None),
"models": getattr(user_info, "models", user_id_models), max_budget=getattr(
"user_id": user_id,
"user_email": getattr(user_info, "user_id", user_email),
"user_role": getattr(user_info, "user_role", None),
"max_budget": getattr(
user_info, "max_budget", max_internal_user_budget user_info, "max_budget", max_internal_user_budget
), ),
"budget_duration": getattr( budget_duration=getattr(
user_info, "budget_duration", internal_user_budget_duration user_info, "budget_duration", internal_user_budget_duration
), ),
} )
user_role = getattr(user_info, "user_role", None) user_role = getattr(user_info, "user_role", None)
# update id # update id
@ -8886,6 +8867,11 @@ async def auth_callback(request: Request):
except Exception as e: except Exception as e:
pass pass
if user_defined_values is None:
raise Exception(
"Unable to map user identity to known values. 'user_defined_values' is None. File an issue - https://github.com/BerriAI/litellm/issues"
)
is_internal_user = False is_internal_user = False
if ( if (
user_defined_values["user_role"] is not None user_defined_values["user_role"] is not None
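The SSO callback refactor (also mangled by the side-by-side rendering above) makes `user_defined_values` optional: it is only built once a user id is resolved, preferred values are re-read from the DB row when one exists, and the callback raises if nothing could be mapped. A simplified sketch of that guard, with a stand-in dataclass for `SSOUserDefinedValues`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SSOUserValuesSketch:
    # Stand-in for the proxy's SSOUserDefinedValues TypedDict.
    user_id: str
    user_email: Optional[str]
    user_role: Optional[str] = None


def build_sso_user(
    user_id: Optional[str], user_email: Optional[str], db_row: Optional[dict]
) -> SSOUserValuesSketch:
    user_defined_values: Optional[SSOUserValuesSketch] = None
    if user_id is not None:
        user_defined_values = SSOUserValuesSketch(user_id=user_id, user_email=user_email)
    if db_row is not None and user_id is not None:
        # Prefer values already stored for this user in the DB.
        user_defined_values = SSOUserValuesSketch(
            user_id=user_id,
            user_email=db_row.get("user_email", user_email),
            user_role=db_row.get("user_role"),
        )
    if user_defined_values is None:
        raise Exception("Unable to map user identity to known values.")
    return user_defined_values


print(build_sso_user("u-1", "a@b.com", {"user_role": "internal_user"}))
```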
@ -8960,7 +8946,8 @@ async def auth_callback(request: Request):
master_key, master_key,
algorithm="HS256", algorithm="HS256",
) )
litellm_dashboard_ui += "?userID=" + user_id if user_id is not None and isinstance(user_id, str):
litellm_dashboard_ui += "?userID=" + user_id
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303) redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
redirect_response.set_cookie(key="token", value=jwt_token) redirect_response.set_cookie(key="token", value=jwt_token)
return redirect_response return redirect_response
@ -9023,6 +9010,7 @@ async def new_invitation(
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
} # type: ignore } # type: ignore
) )
return response
except Exception as e: except Exception as e:
if "Foreign key constraint failed on the field" in str(e): if "Foreign key constraint failed on the field" in str(e):
raise HTTPException( raise HTTPException(
@ -9031,7 +9019,7 @@ async def new_invitation(
"error": "User id does not exist in 'LiteLLM_UserTable'. Fix this by creating user via `/user/new`." "error": "User id does not exist in 'LiteLLM_UserTable'. Fix this by creating user via `/user/new`."
}, },
) )
return response raise HTTPException(status_code=500, detail={"error": str(e)})
@router.get( @router.get(
@ -9951,44 +9939,46 @@ async def get_routes():
""" """
routes = [] routes = []
for route in app.routes: for route in app.routes:
-        route_info = {
-            "path": getattr(route, "path", None),
-            "methods": getattr(route, "methods", None),
-            "name": getattr(route, "name", None),
-            "endpoint": (
-                getattr(route, "endpoint", None).__name__
-                if getattr(route, "endpoint", None)
-                else None
-            ),
-        }
-        routes.append(route_info)
+        endpoint_route = getattr(route, "endpoint", None)
+        if endpoint_route is not None:
+            route_info = {
+                "path": getattr(route, "path", None),
+                "methods": getattr(route, "methods", None),
+                "name": getattr(route, "name", None),
+                "endpoint": (
+                    endpoint_route.__name__
+                    if getattr(route, "endpoint", None)
+                    else None
+                ),
+            }
+            routes.append(route_info)
return {"routes": routes} return {"routes": routes}
#### TEST ENDPOINTS #### #### TEST ENDPOINTS ####
@router.get( # @router.get(
"/token/generate", # "/token/generate",
dependencies=[Depends(user_api_key_auth)], # dependencies=[Depends(user_api_key_auth)],
include_in_schema=False, # include_in_schema=False,
) # )
async def token_generate(): # async def token_generate():
""" # """
Test endpoint. Admin-only access. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc. # Test endpoint. Admin-only access. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
""" # """
# Initialize AuthJWTSSO with your OpenID Provider configuration # # Initialize AuthJWTSSO with your OpenID Provider configuration
from fastapi_sso import AuthJWTSSO # from fastapi_sso import AuthJWTSSO
auth_jwt_sso = AuthJWTSSO( # auth_jwt_sso = AuthJWTSSO(
issuer=os.getenv("OPENID_BASE_URL"), # issuer=os.getenv("OPENID_BASE_URL"),
client_id=os.getenv("OPENID_CLIENT_ID"), # client_id=os.getenv("OPENID_CLIENT_ID"),
client_secret=os.getenv("OPENID_CLIENT_SECRET"), # client_secret=os.getenv("OPENID_CLIENT_SECRET"),
scopes=["litellm_proxy_admin"], # scopes=["litellm_proxy_admin"],
) # )
token = auth_jwt_sso.create_access_token() # token = auth_jwt_sso.create_access_token()
return {"token": token} # return {"token": token}
@router.on_event("shutdown") @router.on_event("shutdown")
@ -10013,7 +10003,8 @@ async def shutdown_event():
# flush langfuse logs on shutdow # flush langfuse logs on shutdow
from litellm.utils import langFuseLogger from litellm.utils import langFuseLogger
langFuseLogger.Langfuse.flush() if langFuseLogger is not None:
langFuseLogger.Langfuse.flush()
except: except:
# [DO NOT BLOCK shutdown events for this] # [DO NOT BLOCK shutdown events for this]
pass pass


@ -92,6 +92,7 @@ from litellm.types.router import (
RetryPolicy, RetryPolicy,
RouterErrors, RouterErrors,
RouterGeneralSettings, RouterGeneralSettings,
RouterModelGroupAliasItem,
RouterRateLimitError, RouterRateLimitError,
RouterRateLimitErrorBasic, RouterRateLimitErrorBasic,
updateDeployment, updateDeployment,
@ -105,6 +106,7 @@ from litellm.utils import (
calculate_max_parallel_requests, calculate_max_parallel_requests,
create_proxy_transport_and_mounts, create_proxy_transport_and_mounts,
get_llm_provider, get_llm_provider,
get_secret,
get_utc_datetime, get_utc_datetime,
) )
@ -156,7 +158,9 @@ class Router:
fallbacks: List = [], fallbacks: List = [],
context_window_fallbacks: List = [], context_window_fallbacks: List = [],
content_policy_fallbacks: List = [], content_policy_fallbacks: List = [],
model_group_alias: Optional[dict] = {}, model_group_alias: Optional[
Dict[str, Union[str, RouterModelGroupAliasItem]]
] = {},
enable_pre_call_checks: bool = False, enable_pre_call_checks: bool = False,
enable_tag_filtering: bool = False, enable_tag_filtering: bool = False,
retry_after: int = 0, # min time to wait before retrying a failed request retry_after: int = 0, # min time to wait before retrying a failed request
@ -331,7 +335,8 @@ class Router:
self.set_model_list(model_list) self.set_model_list(model_list)
self.healthy_deployments: List = self.model_list # type: ignore self.healthy_deployments: List = self.model_list # type: ignore
for m in model_list: for m in model_list:
self.deployment_latency_map[m["litellm_params"]["model"]] = 0 if "model" in m["litellm_params"]:
self.deployment_latency_map[m["litellm_params"]["model"]] = 0
else: else:
self.model_list: List = ( self.model_list: List = (
[] []
@ -398,7 +403,7 @@ class Router:
self.previous_models: List = ( self.previous_models: List = (
[] []
) # list to store failed calls (passed in as metadata to next call) ) # list to store failed calls (passed in as metadata to next call)
self.model_group_alias: dict = ( self.model_group_alias: Dict[str, Union[str, RouterModelGroupAliasItem]] = (
model_group_alias or {} model_group_alias or {}
) # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group ) # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group
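`model_group_alias` is now typed as `Dict[str, Union[str, RouterModelGroupAliasItem]]`, which is what enables the hidden-alias behaviour: a dict-shaped alias with `hidden: True` is still routable but is dropped from model listings. A hedged sketch of configuring both alias styles against the Router API as of this commit; the model names and placeholder key are illustrative only:

```python
from litellm import Router

# One real deployment; the alias entries below only remap incoming model-group names.
router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"},
        }
    ],
    model_group_alias={
        # plain string alias: routed and listed
        "gpt-4": "gpt-3.5-turbo",
        # dict alias: still routed, but hidden from model listings
        "internal-alias": {"model": "gpt-3.5-turbo", "hidden": True},
    },
)

print(router.get_model_names())  # the hidden alias should not show up here
```

The load-balancing docs added alongside this change cover the equivalent proxy-config setup.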
@ -1179,6 +1184,7 @@ class Router:
raise e raise e
def _image_generation(self, prompt: str, model: str, **kwargs): def _image_generation(self, prompt: str, model: str, **kwargs):
model_name = ""
try: try:
verbose_router_logger.debug( verbose_router_logger.debug(
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}" f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
@ -1269,6 +1275,7 @@ class Router:
raise e raise e
async def _aimage_generation(self, prompt: str, model: str, **kwargs): async def _aimage_generation(self, prompt: str, model: str, **kwargs):
model_name = ""
try: try:
verbose_router_logger.debug( verbose_router_logger.debug(
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}" f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
@ -1401,6 +1408,7 @@ class Router:
raise e raise e
async def _atranscription(self, file: FileTypes, model: str, **kwargs): async def _atranscription(self, file: FileTypes, model: str, **kwargs):
model_name = model
try: try:
verbose_router_logger.debug( verbose_router_logger.debug(
f"Inside _atranscription()- model: {model}; kwargs: {kwargs}" f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
@ -1781,6 +1789,7 @@ class Router:
is_async: Optional[bool] = False, is_async: Optional[bool] = False,
**kwargs, **kwargs,
): ):
messages = [{"role": "user", "content": prompt}]
try: try:
kwargs["model"] = model kwargs["model"] = model
kwargs["prompt"] = prompt kwargs["prompt"] = prompt
@ -1789,7 +1798,6 @@ class Router:
timeout = kwargs.get("request_timeout", self.timeout) timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model}) kwargs.setdefault("metadata", {}).update({"model_group": model})
messages = [{"role": "user", "content": prompt}]
# pick the one that is available (lowest TPM/RPM) # pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment( deployment = self.get_available_deployment(
model=model, model=model,
@ -2534,13 +2542,13 @@ class Router:
try: try:
# Update kwargs with the current model name or any other model-specific adjustments # Update kwargs with the current model name or any other model-specific adjustments
## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ## ## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ##
_, custom_llm_provider, _, _ = get_llm_provider( _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
model=model_name["litellm_params"]["model"] model=model_name["litellm_params"]["model"]
) )
new_kwargs = copy.deepcopy(kwargs) new_kwargs = copy.deepcopy(kwargs)
new_kwargs.pop("custom_llm_provider", None) new_kwargs.pop("custom_llm_provider", None)
return await litellm.aretrieve_batch( return await litellm.aretrieve_batch(
custom_llm_provider=custom_llm_provider, **new_kwargs custom_llm_provider=custom_llm_provider, **new_kwargs # type: ignore
) )
except Exception as e: except Exception as e:
receieved_exceptions.append(e) receieved_exceptions.append(e)
@ -2616,13 +2624,13 @@ class Router:
for result in results: for result in results:
if result is not None: if result is not None:
## check batch id ## check batch id
if final_results["first_id"] is None: if final_results["first_id"] is None and hasattr(result, "first_id"):
final_results["first_id"] = result.first_id final_results["first_id"] = getattr(result, "first_id")
final_results["last_id"] = result.last_id final_results["last_id"] = getattr(result, "last_id")
final_results["data"].extend(result.data) # type: ignore final_results["data"].extend(result.data) # type: ignore
## check 'has_more' ## check 'has_more'
if result.has_more is True: if getattr(result, "has_more", False) is True:
final_results["has_more"] = True final_results["has_more"] = True
return final_results return final_results
@ -2874,8 +2882,12 @@ class Router:
verbose_router_logger.debug(f"Traceback{traceback.format_exc()}") verbose_router_logger.debug(f"Traceback{traceback.format_exc()}")
original_exception = e original_exception = e
fallback_model_group = None fallback_model_group = None
original_model_group = kwargs.get("model") original_model_group: Optional[str] = kwargs.get("model") # type: ignore
fallback_failure_exception_str = "" fallback_failure_exception_str = ""
if original_model_group is None:
raise e
try: try:
verbose_router_logger.debug("Trying to fallback b/w models") verbose_router_logger.debug("Trying to fallback b/w models")
if isinstance(e, litellm.ContextWindowExceededError): if isinstance(e, litellm.ContextWindowExceededError):
@ -2972,7 +2984,7 @@ class Router:
f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
) )
if hasattr(original_exception, "message"): if hasattr(original_exception, "message"):
original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" # type: ignore
raise original_exception raise original_exception
response = await run_async_fallback( response = await run_async_fallback(
@ -2996,12 +3008,12 @@ class Router:
if hasattr(original_exception, "message"): if hasattr(original_exception, "message"):
# add the available fallbacks to the exception # add the available fallbacks to the exception
original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format( original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format( # type: ignore
model_group, model_group,
fallback_model_group, fallback_model_group,
) )
if len(fallback_failure_exception_str) > 0: if len(fallback_failure_exception_str) > 0:
original_exception.message += ( original_exception.message += ( # type: ignore
"\nError doing the fallback: {}".format( "\nError doing the fallback: {}".format(
fallback_failure_exception_str fallback_failure_exception_str
) )
@ -3117,9 +3129,15 @@ class Router:
## LOGGING ## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e) kwargs = self.log_retry(kwargs=kwargs, e=e)
remaining_retries = num_retries - current_attempt remaining_retries = num_retries - current_attempt
_healthy_deployments, _ = await self._async_get_healthy_deployments( _model: Optional[str] = kwargs.get("model") # type: ignore
model=kwargs.get("model"), if _model is not None:
) _healthy_deployments, _ = (
await self._async_get_healthy_deployments(
model=_model,
)
)
else:
_healthy_deployments = []
_timeout = self._time_to_sleep_before_retry( _timeout = self._time_to_sleep_before_retry(
e=original_exception, e=original_exception,
remaining_retries=remaining_retries, remaining_retries=remaining_retries,
@ -3129,8 +3147,8 @@ class Router:
await asyncio.sleep(_timeout) await asyncio.sleep(_timeout)
if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES: if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
original_exception.max_retries = num_retries setattr(original_exception, "max_retries", num_retries)
original_exception.num_retries = current_attempt setattr(original_exception, "num_retries", current_attempt)
raise original_exception raise original_exception
@ -3225,8 +3243,12 @@ class Router:
return response return response
except Exception as e: except Exception as e:
original_exception = e original_exception = e
original_model_group = kwargs.get("model") original_model_group: Optional[str] = kwargs.get("model")
verbose_router_logger.debug(f"An exception occurs {original_exception}") verbose_router_logger.debug(f"An exception occurs {original_exception}")
if original_model_group is None:
raise e
try: try:
verbose_router_logger.debug( verbose_router_logger.debug(
f"Trying to fallback b/w models. Initial model group: {model_group}" f"Trying to fallback b/w models. Initial model group: {model_group}"
@ -3336,10 +3358,10 @@ class Router:
return 0 return 0
response_headers: Optional[httpx.Headers] = None response_headers: Optional[httpx.Headers] = None
if hasattr(e, "response") and hasattr(e.response, "headers"): if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore
response_headers = e.response.headers response_headers = e.response.headers # type: ignore
elif hasattr(e, "litellm_response_headers"): elif hasattr(e, "litellm_response_headers"):
response_headers = e.litellm_response_headers response_headers = e.litellm_response_headers # type: ignore
if response_headers is not None: if response_headers is not None:
timeout = litellm._calculate_retry_after( timeout = litellm._calculate_retry_after(
@ -3398,9 +3420,13 @@ class Router:
except Exception as e: except Exception as e:
current_attempt = None current_attempt = None
original_exception = e original_exception = e
_model: Optional[str] = kwargs.get("model") # type: ignore
if _model is None:
raise e # re-raise error, if model can't be determined for loadbalancing
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
_healthy_deployments, _all_deployments = self._get_healthy_deployments( _healthy_deployments, _all_deployments = self._get_healthy_deployments(
model=kwargs.get("model"), model=_model,
) )
# raises an exception if this error should not be retries # raises an exception if this error should not be retries
@ -3438,8 +3464,12 @@ class Router:
except Exception as e: except Exception as e:
## LOGGING ## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e) kwargs = self.log_retry(kwargs=kwargs, e=e)
_model: Optional[str] = kwargs.get("model") # type: ignore
if _model is None:
raise e # re-raise error, if model can't be determined for loadbalancing
_healthy_deployments, _ = self._get_healthy_deployments( _healthy_deployments, _ = self._get_healthy_deployments(
model=kwargs.get("model"), model=_model,
) )
remaining_retries = num_retries - current_attempt remaining_retries = num_retries - current_attempt
_timeout = self._time_to_sleep_before_retry( _timeout = self._time_to_sleep_before_retry(
@ -4055,7 +4085,7 @@ class Router:
if isinstance(_litellm_params, dict): if isinstance(_litellm_params, dict):
for k, v in _litellm_params.items(): for k, v in _litellm_params.items():
if isinstance(v, str) and v.startswith("os.environ/"): if isinstance(v, str) and v.startswith("os.environ/"):
_litellm_params[k] = litellm.get_secret(v) _litellm_params[k] = get_secret(v)
_model_info: dict = model.pop("model_info", {}) _model_info: dict = model.pop("model_info", {})
@ -4392,7 +4422,6 @@ class Router:
- ModelGroupInfo if able to construct a model group - ModelGroupInfo if able to construct a model group
- None if error constructing model group info - None if error constructing model group info
""" """
model_group_info: Optional[ModelGroupInfo] = None model_group_info: Optional[ModelGroupInfo] = None
total_tpm: Optional[int] = None total_tpm: Optional[int] = None
@ -4557,12 +4586,23 @@ class Router:
Returns: Returns:
- ModelGroupInfo if able to construct a model group - ModelGroupInfo if able to construct a model group
- None if error constructing model group info - None if error constructing model group info or hidden model group
""" """
## Check if model group alias ## Check if model group alias
if model_group in self.model_group_alias: if model_group in self.model_group_alias:
item = self.model_group_alias[model_group]
if isinstance(item, str):
_router_model_group = item
elif isinstance(item, dict):
if item["hidden"] is True:
return None
else:
_router_model_group = item["model"]
else:
return None
return self._set_model_group_info( return self._set_model_group_info(
model_group=self.model_group_alias[model_group], model_group=_router_model_group,
user_facing_model_group_name=model_group, user_facing_model_group_name=model_group,
) )
@ -4666,7 +4706,14 @@ class Router:
Includes model_group_alias models too. Includes model_group_alias models too.
""" """
return self.model_names + list(self.model_group_alias.keys()) model_list = self.get_model_list()
if model_list is None:
return []
model_names = []
for m in model_list:
model_names.append(m["model_name"])
return model_names
def get_model_list( def get_model_list(
self, model_name: Optional[str] = None self, model_name: Optional[str] = None
@ -4678,9 +4725,21 @@ class Router:
returned_models: List[DeploymentTypedDict] = [] returned_models: List[DeploymentTypedDict] = []
for model_alias, model_value in self.model_group_alias.items(): for model_alias, model_value in self.model_group_alias.items():
if isinstance(model_value, str):
_router_model_name: str = model_value
elif isinstance(model_value, dict):
_model_value = RouterModelGroupAliasItem(**model_value) # type: ignore
if _model_value["hidden"] is True:
continue
else:
_router_model_name = _model_value["model"]
else:
continue
returned_models.extend( returned_models.extend(
self._get_all_deployments( self._get_all_deployments(
model_name=model_value, model_alias=model_alias model_name=_router_model_name, model_alias=model_alias
) )
) )
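`get_model_names` and `get_model_list` now expand alias entries through the same string-or-dict check and silently drop the ones marked hidden. A standalone, pure-Python sketch of that expansion:

```python
from typing import Dict, List, Tuple, Union

ModelGroupAlias = Dict[str, Union[str, dict]]


def visible_alias_targets(model_group_alias: ModelGroupAlias) -> List[Tuple[str, str]]:
    """Return (alias, underlying_model_group) pairs, skipping hidden aliases."""
    results = []
    for alias, value in model_group_alias.items():
        if isinstance(value, str):
            results.append((alias, value))
        elif isinstance(value, dict):
            if value.get("hidden") is True:
                continue
            results.append((alias, value["model"]))
    return results


aliases = {
    "gpt-4": "gpt-3.5-turbo",
    "internal-only": {"model": "gpt-3.5-turbo", "hidden": True},
}
print(visible_alias_targets(aliases))  # [('gpt-4', 'gpt-3.5-turbo')]
```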
@ -5078,10 +5137,11 @@ class Router:
) )
if model in self.model_group_alias: if model in self.model_group_alias:
verbose_router_logger.debug( _item = self.model_group_alias[model]
f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}" if isinstance(_item, str):
) model = _item
model = self.model_group_alias[model] else:
model = _item["model"]
if model not in self.model_names: if model not in self.model_names:
# check if provider/ specific wildcard routing # check if provider/ specific wildcard routing
@ -5124,7 +5184,9 @@ class Router:
m for m in self.model_list if m["litellm_params"]["model"] == model m for m in self.model_list if m["litellm_params"]["model"] == model
] ]
litellm.print_verbose(f"initial list of deployments: {healthy_deployments}") verbose_router_logger.debug(
f"initial list of deployments: {healthy_deployments}"
)
if len(healthy_deployments) == 0: if len(healthy_deployments) == 0:
raise ValueError( raise ValueError(
@ -5208,7 +5270,7 @@ class Router:
) )
# check if user wants to do tag based routing # check if user wants to do tag based routing
healthy_deployments = await get_deployments_for_tag( healthy_deployments = await get_deployments_for_tag( # type: ignore
llm_router_instance=self, llm_router_instance=self,
request_kwargs=request_kwargs, request_kwargs=request_kwargs,
healthy_deployments=healthy_deployments, healthy_deployments=healthy_deployments,
@ -5241,7 +5303,7 @@ class Router:
input=input, input=input,
) )
) )
if ( elif (
self.routing_strategy == "cost-based-routing" self.routing_strategy == "cost-based-routing"
and self.lowestcost_logger is not None and self.lowestcost_logger is not None
): ):
@ -5326,6 +5388,8 @@ class Router:
############## No RPM/TPM passed, we do a random pick ################# ############## No RPM/TPM passed, we do a random pick #################
item = random.choice(healthy_deployments) item = random.choice(healthy_deployments)
return item or item[0] return item or item[0]
else:
deployment = None
if deployment is None: if deployment is None:
verbose_router_logger.info( verbose_router_logger.info(
f"get_available_deployment for model: {model}, No deployment available" f"get_available_deployment for model: {model}, No deployment available"
@ -5515,6 +5579,9 @@ class Router:
messages=messages, messages=messages,
input=input, input=input,
) )
else:
deployment = None
if deployment is None: if deployment is None:
verbose_router_logger.info( verbose_router_logger.info(
f"get_available_deployment for model: {model}, No deployment available" f"get_available_deployment for model: {model}, No deployment available"
@ -5690,6 +5757,9 @@ class Router:
def _initialize_alerting(self): def _initialize_alerting(self):
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
if self.alerting_config is None:
return
router_alerting_config: AlertingConfig = self.alerting_config router_alerting_config: AlertingConfig = self.alerting_config
_slack_alerting_logger = SlackAlerting( _slack_alerting_logger = SlackAlerting(
@ -5700,7 +5770,7 @@ class Router:
self.slack_alerting_logger = _slack_alerting_logger self.slack_alerting_logger = _slack_alerting_logger
litellm.callbacks.append(_slack_alerting_logger) litellm.callbacks.append(_slack_alerting_logger) # type: ignore
litellm.success_callback.append( litellm.success_callback.append(
_slack_alerting_logger.response_taking_too_long_callback _slack_alerting_logger.response_taking_too_long_callback
) )
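The hunks above let a `model_group_alias` entry be either a plain string or a dict with a `hidden` flag. A minimal sketch of the new dict form, mirroring the router tests further down (the alias name and deployment are illustrative, and this assumes a litellm build that includes this change):

```python
import litellm

# One real deployment plus an alias that should not be advertised to clients.
router = litellm.Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
    ],
    # Dict form added in this change; the plain string form still works.
    model_group_alias={"gpt-4.5-turbo": {"model": "gpt-3.5-turbo", "hidden": True}},
)

# Hidden aliases remain routable, but are dropped from model listings.
print(router.get_model_names())  # expected: ["gpt-3.5-turbo"]
```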
@@ -21,8 +21,8 @@ else:
 async def get_deployments_for_tag(
     llm_router_instance: LitellmRouter,
+    healthy_deployments: Union[List[Any], Dict[Any, Any]],
     request_kwargs: Optional[Dict[Any, Any]] = None,
-    healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
 ):
     if llm_router_instance.enable_tag_filtering is not True:
         return healthy_deployments
@@ -2828,3 +2828,44 @@ def test_gemini_function_call_parameter_in_messages_2():
             ]
         },
     ]
+
+
+@pytest.mark.parametrize(
+    "base_model, metadata",
+    [
+        (None, {"model_info": {"base_model": "vertex_ai/gemini-1.5-pro"}}),
+        ("vertex_ai/gemini-1.5-pro", None),
+    ],
+)
+def test_gemini_finetuned_endpoint(base_model, metadata):
+    litellm.set_verbose = True
+    load_vertex_ai_credentials()
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    # Set up the messages
+    messages = [
+        {"role": "system", "content": """Use search for most queries."""},
+        {"role": "user", "content": """search for weather in boston (use `search`)"""},
+    ]
+
+    client = HTTPHandler(concurrent_limit=1)
+
+    with patch.object(client, "post", new=MagicMock()) as mock_client:
+        try:
+            response = completion(
+                model="vertex_ai/4965075652664360960",
+                messages=messages,
+                tool_choice="auto",
+                client=client,
+                metadata=metadata,
+                base_model=base_model,
+            )
+        except Exception as e:
+            print(e)
+
+        print(mock_client.call_args.kwargs)
+
+        mock_client.assert_called()
+        assert mock_client.call_args.kwargs["url"].endswith(
+            "endpoints/4965075652664360960:generateContent"
+        )
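This test exercises both ways of telling litellm which base model a numeric fine-tuned Vertex endpoint maps to. A hedged sketch of the metadata route only (requires Vertex credentials; the endpoint id is the placeholder used in the test, not a real deployment):

```python
from litellm import completion

# Assumes Vertex project/location and application-default credentials are configured.
response = completion(
    model="vertex_ai/4965075652664360960",  # numeric fine-tuned endpoint id (placeholder)
    messages=[{"role": "user", "content": "what's the weather in sf"}],
    # Equivalent to passing base_model= directly; used to infer routing/cost info.
    metadata={"model_info": {"base_model": "vertex_ai/gemini-1.5-pro"}},
)
print(response.choices[0].message.content)
```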
@@ -311,6 +311,8 @@ async def test_completion_predibase():
         pass
     except litellm.ServiceUnavailableError as e:
         pass
+    except litellm.InternalServerError:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -133,3 +133,21 @@ async def test_fireworks_health_check():
     assert response == {}

     return response
+
+
+@pytest.mark.asyncio
+async def test_cohere_rerank_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "cohere/rerank-english-v3.0",
+            "query": "Hey, how's it going",
+            "documents": ["my sample text"],
+            "api_key": os.getenv("COHERE_API_KEY"),
+        },
+        mode="rerank",
+        prompt="Hey, how's it going",
+    )
+
+    assert "error" not in response
+
+    print(response)
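For reference, the same rerank health check can be run outside pytest. A minimal sketch, assuming `COHERE_API_KEY` is set in the environment:

```python
import asyncio
import os

import litellm


async def main():
    # mode="rerank" is the health-check path exercised by the test above.
    response = await litellm.ahealth_check(
        model_params={
            "model": "cohere/rerank-english-v3.0",
            "query": "Hey, how's it going",
            "documents": ["my sample text"],
            "api_key": os.getenv("COHERE_API_KEY"),
        },
        mode="rerank",
        prompt="Hey, how's it going",
    )
    print(response)  # no "error" key means the deployment is healthy


asyncio.run(main())
```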
@@ -50,6 +50,7 @@ def test_image_generation_openai():
     ],  # False
 )  #
 @pytest.mark.asyncio
+@pytest.mark.flaky(retries=3, delay=1)
 async def test_image_generation_azure(sync_mode):
     try:
         if sync_mode:
@@ -533,7 +533,7 @@ def test_call_with_user_over_budget(prisma_client):
             # use generated key to auth in
             result = await user_api_key_auth(request=request, api_key=bearer_token)
             print("result from user auth with new key", result)
-            pytest.fail(f"This should have failed!. They key crossed it's budget")
+            pytest.fail("This should have failed!. They key crossed it's budget")

         asyncio.run(test())
     except Exception as e:

@@ -1755,7 +1755,7 @@ def test_call_with_key_over_model_budget(prisma_client):
             # use generated key to auth in
             result = await user_api_key_auth(request=request, api_key=bearer_token)
             print("result from user auth with new key", result)
-            pytest.fail(f"This should have failed!. They key crossed it's budget")
+            pytest.fail("This should have failed!. They key crossed it's budget")

         asyncio.run(test())
     except Exception as e:
@@ -589,3 +589,14 @@ def test_parse_additional_properties_json_schema(model, provider, expectedAddProp
     elif provider == "openai":
         schema = optional_params["response_format"]["json_schema"]["schema"]
         assert ("additionalProperties" in schema) == expectedAddProp
+
+
+def test_o1_model_params():
+    optional_params = get_optional_params(
+        model="o1-preview-2024-09-12",
+        custom_llm_provider="openai",
+        seed=10,
+        user="John",
+    )
+    assert optional_params["seed"] == 10
+    assert optional_params["user"] == "John"
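Per this test, `seed` and `user` now survive the o-1 parameter mapping. A hedged end-to-end sketch, not part of the commit, assuming an OpenAI key with o1 access:

```python
from litellm import completion

# seed/user are forwarded for o1 models after this fix; other params the o1
# endpoint does not accept are still handled by the o1 transformation layer.
response = completion(
    model="o1-preview-2024-09-12",
    messages=[{"role": "user", "content": "Say hello"}],
    seed=10,
    user="John",
)
print(response.choices[0].message.content)
```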
@@ -762,7 +762,7 @@ async def test_team_update_redis():
     ) as mock_client:
         await _cache_team_object(
             team_id="1234",
-            team_table=LiteLLM_TeamTableCachedObj(),
+            team_table=LiteLLM_TeamTableCachedObj(team_id="1234"),
             user_api_key_cache=DualCache(),
             proxy_logging_obj=proxy_logging_obj,
         )

@@ -776,7 +776,7 @@ async def test_get_team_redis(client_no_auth):
     Tests if get_team_object gets value from redis cache, if set
     """
     from litellm.caching import DualCache, RedisCache
-    from litellm.proxy.auth.auth_checks import _cache_team_object, get_team_object
+    from litellm.proxy.auth.auth_checks import get_team_object

     proxy_logging_obj: ProxyLogging = getattr(
         litellm.proxy.proxy_server, "proxy_logging_obj"

@@ -917,7 +917,9 @@ async def test_create_team_member_add(prisma_client, new_member_method):
     )
     litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = team_mock_client

-    team_mock_client.update = AsyncMock(return_value=LiteLLM_TeamTableCachedObj())
+    team_mock_client.update = AsyncMock(
+        return_value=LiteLLM_TeamTableCachedObj(team_id="1234")
+    )

     await team_member_add(
         data=team_member_add_request,

@@ -1095,7 +1097,9 @@ async def test_create_team_member_add_team_admin(
     )
     litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = team_mock_client

-    team_mock_client.update = AsyncMock(return_value=LiteLLM_TeamTableCachedObj())
+    team_mock_client.update = AsyncMock(
+        return_value=LiteLLM_TeamTableCachedObj(team_id="1234")
+    )

     try:
         await team_member_add(

@@ -1434,8 +1438,9 @@ async def test_gemini_pass_through_endpoint():
     print(resp.body)


+@pytest.mark.parametrize("hidden", [True, False])
 @pytest.mark.asyncio
-async def test_proxy_model_group_alias_checks(prisma_client):
+async def test_proxy_model_group_alias_checks(prisma_client, hidden):
     """
     Check if model group alias is returned on

@@ -1465,7 +1470,7 @@ async def test_proxy_model_group_alias_checks(prisma_client):
     model_alias = "gpt-4"
     router = litellm.Router(
         model_list=_model_list,
-        model_group_alias={model_alias: "gpt-3.5-turbo"},
+        model_group_alias={model_alias: {"model": "gpt-3.5-turbo", "hidden": hidden}},
     )
     setattr(litellm.proxy.proxy_server, "llm_router", router)
     setattr(litellm.proxy.proxy_server, "llm_model_list", _model_list)

@@ -1477,7 +1482,10 @@ async def test_proxy_model_group_alias_checks(prisma_client):
         user_api_key_dict=UserAPIKeyAuth(models=[]),
     )
-    assert len(resp) == 2
+    if hidden:
+        assert len(resp["data"]) == 1
+    else:
+        assert len(resp["data"]) == 2
     print(resp)

     resp = await model_info_v1(

@@ -1489,7 +1497,10 @@ async def test_proxy_model_group_alias_checks(prisma_client):
         if model_alias == item["model_name"]:
             is_model_alias_in_list = True

-    assert is_model_alias_in_list
+    if hidden:
+        assert is_model_alias_in_list is False
+    else:
+        assert is_model_alias_in_list

     resp = await model_group_info(
         user_api_key_dict=UserAPIKeyAuth(models=[]),

@@ -1500,4 +1511,7 @@ async def test_proxy_model_group_alias_checks(prisma_client):
         if model_alias == item.model_group:
             is_model_alias_in_list = True

-    assert is_model_alias_in_list
+    if hidden:
+        assert is_model_alias_in_list is False
+    else:
+        assert is_model_alias_in_list, f"models: {models}"
@@ -2481,3 +2481,31 @@ async def test_router_batch_endpoints(provider):
         model="my-custom-name", custom_llm_provider=provider, limit=2
     )
     print("list_batches=", list_batches)
+
+
+@pytest.mark.parametrize("hidden", [True, False])
+def test_model_group_alias(hidden):
+    _model_list = [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {"model": "gpt-3.5-turbo"},
+        },
+        {"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}},
+    ]
+    router = Router(
+        model_list=_model_list,
+        model_group_alias={
+            "gpt-4.5-turbo": {"model": "gpt-3.5-turbo", "hidden": hidden}
+        },
+    )
+
+    models = router.get_model_list()
+
+    model_names = router.get_model_names()
+
+    if hidden:
+        assert len(models) == len(_model_list)
+        assert len(model_names) == len(_model_list)
+    else:
+        assert len(models) == len(_model_list) + 1
+        assert len(model_names) == len(_model_list) + 1
@@ -582,3 +582,8 @@ class RouterRateLimitError(ValueError):
         self.cooldown_list = cooldown_list
         _message = f"{RouterErrors.no_deployments_available.value}, Try again in {cooldown_time} seconds. Passed model={model}. pre-call-checks={enable_pre_call_checks}, cooldown_list={cooldown_list}"
         super().__init__(_message)
+
+
+class RouterModelGroupAliasItem(TypedDict):
+    model: str
+    hidden: bool  # if 'True', don't return on `.get_model_list`
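The new TypedDict documents the dict form of a model group alias. A short sketch of constructing one; the import path is an assumption (the type is expected to live next to `RouterRateLimitError` shown above):

```python
from litellm.types.router import RouterModelGroupAliasItem  # assumed module path

# Routable under the alias, but excluded from `.get_model_list()` / `.get_model_names()`.
alias: RouterModelGroupAliasItem = {"model": "gpt-3.5-turbo", "hidden": True}
```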
@@ -76,6 +76,7 @@ from litellm.types.llms.openai import (
     ChatCompletionNamedToolChoiceParam,
     ChatCompletionToolParam,
 )
+from litellm.types.utils import FileTypes  # type: ignore
 from litellm.types.utils import (
     CallTypes,
     ChatCompletionDeltaToolCall,

@@ -84,7 +85,6 @@ from litellm.types.utils import (
     Delta,
     Embedding,
     EmbeddingResponse,
-    FileTypes,
     ImageResponse,
     Message,
     ModelInfo,

@@ -2339,6 +2339,7 @@ def get_litellm_params(
     text_completion=None,
     azure_ad_token_provider=None,
     user_continue_message=None,
+    base_model=None,
 ):
     litellm_params = {
         "acompletion": acompletion,

@@ -2365,6 +2366,8 @@ def get_litellm_params(
         "text_completion": text_completion,
         "azure_ad_token_provider": azure_ad_token_provider,
         "user_continue_message": user_continue_message,
+        "base_model": base_model
+        or _get_base_model_from_litellm_call_metadata(metadata=metadata),
     }
     return litellm_params

@@ -6063,11 +6066,11 @@ def _calculate_retry_after(
     max_retries: int,
     response_headers: Optional[httpx.Headers] = None,
     min_timeout: int = 0,
-):
+) -> Union[float, int]:
     retry_after = _get_retry_after_from_exception_header(response_headers)

     # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
-    if 0 < retry_after <= 60:
+    if retry_after is not None and 0 < retry_after <= 60:
         return retry_after

     initial_retry_delay = 0.5

@@ -10962,6 +10965,22 @@ def get_logging_id(start_time, response_obj):
         return None


+def _get_base_model_from_litellm_call_metadata(
+    metadata: Optional[dict],
+) -> Optional[str]:
+    if metadata is None:
+        return None
+
+    if metadata is not None:
+        model_info = metadata.get("model_info", {})
+
+        if model_info is not None:
+            base_model = model_info.get("base_model", None)
+            if base_model is not None:
+                return base_model
+    return None
+
+
 def _get_base_model_from_metadata(model_call_details=None):
     if model_call_details is None:
         return None

@@ -10970,13 +10989,7 @@ def _get_base_model_from_metadata(model_call_details=None):
     if litellm_params is not None:
         metadata = litellm_params.get("metadata", {})
-        if metadata is not None:
-            model_info = metadata.get("model_info", {})
-            if model_info is not None:
-                base_model = model_info.get("base_model", None)
-                if base_model is not None:
-                    return base_model
+        return _get_base_model_from_litellm_call_metadata(metadata=metadata)
     return None
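The metadata lookup is now factored into a standalone helper, so its behavior is easy to check in isolation. A small sketch, assuming the helper stays importable from `litellm.utils`:

```python
from litellm.utils import _get_base_model_from_litellm_call_metadata  # assumed location

# No metadata, or metadata without model_info.base_model, resolves to None.
assert _get_base_model_from_litellm_call_metadata(metadata=None) is None

# Otherwise the configured base model is returned.
assert (
    _get_base_model_from_litellm_call_metadata(
        metadata={"model_info": {"base_model": "vertex_ai/gemini-1.5-pro"}}
    )
    == "vertex_ai/gemini-1.5-pro"
)
```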
pyrightconfig.json (new file)
@@ -0,0 +1,6 @@
+{
+  "ignore": [],
+  "exclude": ["**/node_modules", "**/__pycache__", "litellm/tests", "litellm/main.py", "litellm/utils.py"],
+  "reportMissingImports": false
+}