LiteLLM Minor Fixes and Improvements (09/13/2024) (#5689)

* refactor: cleanup unused variables + fix pyright errors

* feat(health_check.py): Closes https://github.com/BerriAI/litellm/issues/5686

* fix(o1_reasoning.py): add stricter check for o-1 reasoning model

* refactor(mistral/): make it easier to see mistral transformation logic

* fix(openai.py): fix openai o-1 model param mapping

Fixes https://github.com/BerriAI/litellm/issues/5685

* feat(main.py): infer finetuned gemini model from base model

Fixes https://github.com/BerriAI/litellm/issues/5678

* docs(vertex.md): update docs to call finetuned gemini models

* feat(proxy_server.py): allow admin to hide proxy model aliases

Closes https://github.com/BerriAI/litellm/issues/5692

* docs(load_balancing.md): add docs on hiding alias models from proxy config

* fix(base.py): don't raise NotImplementedError

* fix(user_api_key_auth.py): fix model max budget check

* fix(router.py): fix elif

* fix(user_api_key_auth.py): don't set team_id to empty str

* fix(team_endpoints.py): fix response type

* test(test_completion.py): handle predibase error

* test(test_proxy_server.py): fix test

* fix(o1_transformation.py): fix max_completion_tokens mapping

* test(test_image_generation.py): mark flaky test
Krish Dholakia, 2024-09-14 10:02:55 -07:00, committed by GitHub
parent db3af20d84, commit 60709a0753
35 changed files with 1020 additions and 539 deletions


@ -1,9 +1,9 @@
repos:
- repo: local
hooks:
- id: mypy
name: mypy
entry: python3 -m mypy --ignore-missing-imports
- id: pyright
name: pyright
entry: pyright
language: system
types: [python]
files: ^litellm/


@ -5,7 +5,7 @@ import TabItem from '@theme/TabItem';
https://github.com/BerriAI/litellm
## **Call 100+ LLMs using the same Input/Output Format**
## **Call 100+ LLMs using the OpenAI Input/Output Format**
- Translate inputs to provider's `completion`, `embedding`, and `image_generation` endpoints
- [Consistent output](https://docs.litellm.ai/docs/completion/output), text responses will always be available at `['choices'][0]['message']['content']`


@ -1180,9 +1180,58 @@ response = completion(
Fine-tuned models on Vertex AI have a numerical model/endpoint ID.
| Model Name | Function Call |
|------------------|--------------------------------------|
| your fine tuned model | `completion(model='vertex_ai/4965075652664360960', messages)`|
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
## set ENV variables
os.environ["VERTEXAI_PROJECT"] = "hardy-device-38811"
os.environ["VERTEXAI_LOCATION"] = "us-central1"
response = completion(
model="vertex_ai/<your-finetuned-model>", # e.g. vertex_ai/4965075652664360960
messages=[{ "content": "Hello, how are you?","role": "user"}],
base_model="vertex_ai/gemini-1.5-pro" # the base model - used for routing
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add Vertex Credentials to your env
```bash
!gcloud auth application-default login
```
2. Setup config.yaml
```yaml
- model_name: finetuned-gemini
litellm_params:
model: vertex_ai/<ENDPOINT_ID>
vertex_project: <PROJECT_ID>
vertex_location: <LOCATION>
model_info:
base_model: vertex_ai/gemini-1.5-pro # IMPORTANT
```
3. Test it!
```bash
curl --location 'https://0.0.0.0:4000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: <LITELLM_KEY>' \
--data '{"model": "finetuned-gemini" ,"messages":[{"role": "user", "content":[{"type": "text", "text": "hi"}]}]}'
```
</TabItem>
</Tabs>
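For reference (not part of this diff), a minimal sketch of calling the `finetuned-gemini` alias above through the proxy with the OpenAI Python SDK; the base URL and key below are placeholders.

```python
import openai

# Point the OpenAI SDK at the LiteLLM proxy (URL and key are placeholders)
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="finetuned-gemini",  # the model_name set in config.yaml above
    messages=[{"role": "user", "content": "hi"}],
)
print(response.choices[0].message.content)
```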
## Gemini Pro Vision
| Model Name | Function Call |


@ -38,9 +38,24 @@ router_settings:
## Router settings on config - routing_strategy, model_group_alias
Expose an 'alias' for a 'model_name' on the proxy server.
```
model_group_alias: {
"gpt-4": "gpt-3.5-turbo"
}
```
These aliases are shown on `/v1/models`, `/v1/model/info`, and `/v1/model_group/info` by default.
litellm.Router() settings can be set under `router_settings`. You can set `model_group_alias`, `routing_strategy`, `num_retries`, and `timeout`. See all supported Router params [here](https://github.com/BerriAI/litellm/blob/1b942568897a48f014fa44618ec3ce54d7570a46/litellm/router.py#L64)
### Usage
Example config with `router_settings`
```yaml
model_list:
- model_name: gpt-3.5-turbo
@ -48,19 +63,41 @@ model_list:
model: azure/<your-deployment-name>
api_base: <your-azure-endpoint>
api_key: <your-azure-api-key>
rpm: 6 # Rate limit for this deployment: in requests per minute (rpm)
router_settings:
model_group_alias: {"gpt-4": "gpt-3.5-turbo"} # all requests with `gpt-4` will be routed to models with `gpt-3.5-turbo`
```
### Hide Alias Models
Use this if you want to set up aliases for:
1. typos
2. minor model version changes
3. case-sensitive changes between updates
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-turbo-small-ca
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
model: azure/<your-deployment-name>
api_base: <your-azure-endpoint>
api_key: <your-azure-api-key>
rpm: 6
router_settings:
model_group_alias: {"gpt-4": "gpt-3.5-turbo"} # all requests with `gpt-4` will be routed to models with `gpt-3.5-turbo`
routing_strategy: least-busy # Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"]
num_retries: 2
timeout: 30 # 30 seconds
redis_host: <your redis host>
redis_password: <your redis password>
redis_port: 1992
model_group_alias:
"GPT-3.5-turbo": # alias
model: "gpt-3.5-turbo" # Actual model name in 'model_list'
hidden: true # Exclude from `/v1/models`, `/v1/model/info`, `/v1/model_group/info`
```
### Complete Spec
```python
model_group_alias: Optional[Dict[str, Union[str, RouterModelGroupAliasItem]]] = {}
class RouterModelGroupAliasItem(TypedDict):
model: str
hidden: bool # if 'True', don't return on `/v1/models`, `/v1/model/info`, `/v1/model_group/info`
```
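A minimal sketch (not part of this diff) of the spec above applied to `litellm.Router`, assuming the same Azure deployment placeholders as the config earlier:

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/<your-deployment-name>",
                "api_base": "<your-azure-endpoint>",
                "api_key": "<your-azure-api-key>",
            },
        }
    ],
    model_group_alias={
        "gpt-4": "gpt-3.5-turbo",  # plain alias - returned on /v1/models
        "GPT-3.5-turbo": {"model": "gpt-3.5-turbo", "hidden": True},  # hidden alias
    },
)

# Requests against either alias resolve to the 'gpt-3.5-turbo' deployment
response = router.completion(
    model="GPT-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)
```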


@ -5,7 +5,7 @@ import TabItem from '@theme/TabItem';
https://github.com/BerriAI/litellm
## **Call 100+ LLMs using the same Input/Output Format**
## **Call 100+ LLMs using the OpenAI Input/Output Format**
- Translate inputs to provider's `completion`, `embedding`, and `image_generation` endpoints
- [Consistent output](https://docs.litellm.ai/docs/completion/output), text responses will always be available at `['choices'][0]['message']['content']`


@ -939,15 +939,18 @@ from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConf
from .llms.OpenAI.openai import (
OpenAIConfig,
OpenAITextCompletionConfig,
MistralConfig,
MistralEmbeddingConfig,
DeepInfraConfig,
GroqConfig,
AzureAIStudioConfig,
)
from .llms.OpenAI.o1_reasoning import (
from .llms.mistral.mistral_chat_transformation import MistralConfig
from .llms.OpenAI.o1_transformation import (
OpenAIO1Config,
)
from .llms.OpenAI.gpt_transformation import (
OpenAIGPTConfig,
)
from .llms.nvidia_nim import NvidiaNimConfig
from .llms.cerebras.chat import CerebrasConfig
from .llms.AI21.chat import AI21ChatConfig


@ -52,7 +52,9 @@ class SlackAlerting(CustomBatchLogger):
def __init__(
self,
internal_usage_cache: Optional[DualCache] = None,
alerting_threshold: float = 300, # threshold for slow / hanging llm responses (in seconds)
alerting_threshold: Optional[
float
] = None, # threshold for slow / hanging llm responses (in seconds)
alerting: Optional[List] = [],
alert_types: List[AlertType] = [
"llm_exceptions",
@ -74,6 +76,8 @@ class SlackAlerting(CustomBatchLogger):
default_webhook_url: Optional[str] = None,
**kwargs,
):
if alerting_threshold is None:
alerting_threshold = 300
self.alerting_threshold = alerting_threshold
self.alerting = alerting
self.alert_types = alert_types
@ -99,6 +103,7 @@ class SlackAlerting(CustomBatchLogger):
):
if alerting is not None:
self.alerting = alerting
asyncio.create_task(self.periodic_flush())
if alerting_threshold is not None:
self.alerting_threshold = alerting_threshold
if alert_types is not None:
@ -114,8 +119,6 @@ class SlackAlerting(CustomBatchLogger):
if llm_router is not None:
self.llm_router = llm_router
asyncio.create_task(self.periodic_flush())
async def deployment_in_cooldown(self):
pass
@ -208,15 +211,20 @@ class SlackAlerting(CustomBatchLogger):
_deployment_latencies = metadata["_latency_per_deployment"]
if len(_deployment_latencies) == 0:
return None
_deployment_latency_map: Optional[dict] = None
try:
# try sorting deployments by latency
_deployment_latencies = sorted(
_deployment_latencies.items(), key=lambda x: x[1]
)
_deployment_latencies = dict(_deployment_latencies)
except:
_deployment_latency_map = dict(_deployment_latencies)
except Exception:
pass
for api_base, latency in _deployment_latencies.items():
if _deployment_latency_map is None:
return
for api_base, latency in _deployment_latency_map.items():
_message_to_send += f"\n{api_base}: {round(latency,2)}s"
_message_to_send = "```" + _message_to_send + "```"
return _message_to_send
@ -475,6 +483,7 @@ class SlackAlerting(CustomBatchLogger):
):
if self.alerting is None or self.alert_types is None:
return
model: str = ""
if request_data is not None:
model = request_data.get("model", "")
messages = request_data.get("messages", None)
@ -619,6 +628,7 @@ class SlackAlerting(CustomBatchLogger):
return
_id: Optional[str] = "default_id" # used for caching
user_info_json = user_info.model_dump(exclude_none=True)
user_info_str = ""
for k, v in user_info_json.items():
user_info_str = "\n{}: {}\n".format(k, v)
@ -1475,10 +1485,10 @@ Model Info:
if isinstance(response_obj, litellm.ModelResponse) and (
hasattr(response_obj, "usage")
and response_obj.usage is not None
and hasattr(response_obj.usage, "completion_tokens")
and response_obj.usage is not None # type: ignore
and hasattr(response_obj.usage, "completion_tokens") # type: ignore
):
completion_tokens = response_obj.usage.completion_tokens
completion_tokens = response_obj.usage.completion_tokens # type: ignore
if completion_tokens is not None and completion_tokens > 0:
final_value = float(
response_s.total_seconds() / completion_tokens
@ -1608,10 +1618,14 @@ Model Info:
todays_date = datetime.datetime.now().date()
start_date = todays_date - datetime.timedelta(days=days)
spend_per_team, spend_per_tag = await _get_spend_report_for_time_range(
_resp = await _get_spend_report_for_time_range(
start_date=start_date.strftime("%Y-%m-%d"),
end_date=todays_date.strftime("%Y-%m-%d"),
)
if _resp is None:
return
spend_per_team, spend_per_tag = _resp
_spend_message = f"*💸 Spend Report for `{start_date.strftime('%m-%d-%Y')} - {todays_date.strftime('%m-%d-%Y')}` ({days} days)*\n"
@ -1656,13 +1670,16 @@ Model Info:
days=last_day_of_month - 1
)
monthly_spend_per_team, monthly_spend_per_tag = (
await _get_spend_report_for_time_range(
start_date=first_day_of_month.strftime("%Y-%m-%d"),
end_date=last_day_of_month.strftime("%Y-%m-%d"),
)
_resp = await _get_spend_report_for_time_range(
start_date=first_day_of_month.strftime("%Y-%m-%d"),
end_date=last_day_of_month.strftime("%Y-%m-%d"),
)
if _resp is None:
return
monthly_spend_per_team, monthly_spend_per_tag = _resp
_spend_message = f"*💸 Monthly Spend Report for `{first_day_of_month.strftime('%m-%d-%Y')} - {last_day_of_month.strftime('%m-%d-%Y')}` *\n"
if monthly_spend_per_team is not None:


@ -0,0 +1,142 @@
"""
Support for gpt model family
"""
import types
from typing import Optional, Union
import litellm
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
class OpenAIGPTConfig:
"""
Reference: https://platform.openai.com/docs/api-reference/chat/create
The class `OpenAIGPTConfig` provides configuration for OpenAI's Chat API interface. Below are the parameters:
- `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
- `function_call` (string or object): This optional parameter controls how the model calls functions.
- `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
- `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
- `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
"""
frequency_penalty: Optional[int] = None
function_call: Optional[Union[str, dict]] = None
functions: Optional[list] = None
logit_bias: Optional[dict] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[int] = None
stop: Optional[Union[str, list]] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
response_format: Optional[dict] = None
def __init__(
self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]] = None,
functions: Optional[list] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
response_format: Optional[dict] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self, model: str) -> list:
base_params = [
"frequency_penalty",
"logit_bias",
"logprobs",
"top_logprobs",
"max_tokens",
"n",
"presence_penalty",
"seed",
"stop",
"stream",
"stream_options",
"temperature",
"top_p",
"tools",
"tool_choice",
"function_call",
"functions",
"max_retries",
"extra_headers",
"parallel_tool_calls",
] # works across all models
model_specific_params = []
if (
model != "gpt-3.5-turbo-16k" and model != "gpt-4"
): # gpt-4 does not support 'response_format'
model_specific_params.append("response_format")
if (
model in litellm.open_ai_chat_completion_models
) or model in litellm.open_ai_text_completion_models:
model_specific_params.append(
"user"
) # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
return base_params + model_specific_params
def _map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
) -> dict:
supported_openai_params = self.get_supported_openai_params(model)
for param, value in non_default_params.items():
if param in supported_openai_params:
optional_params[param] = value
return optional_params
def map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
) -> dict:
return self._map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
)
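For context (not part of the diff), a rough sketch of how `OpenAIGPTConfig` is meant to be used, assuming the class as shown above:

```python
# Minimal sketch: map user-supplied OpenAI params for a gpt-* model.
config = OpenAIGPTConfig()

supported = config.get_supported_openai_params(model="gpt-4o")
assert "max_tokens" in supported and "temperature" in supported

optional_params = config.map_openai_params(
    non_default_params={"temperature": 0.2, "max_tokens": 100},
    optional_params={},
    model="gpt-4o",
)
# expected: {"temperature": 0.2, "max_tokens": 100}
```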


@ -17,13 +17,12 @@ from typing import Any, List, Optional, Union
import litellm
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
from .openai import OpenAIConfig
from .gpt_transformation import OpenAIGPTConfig
class OpenAIO1Config(OpenAIConfig):
class OpenAIO1Config(OpenAIGPTConfig):
"""
Reference: https://platform.openai.com/docs/guides/reasoning
"""
@classmethod
@ -50,9 +49,7 @@ class OpenAIO1Config(OpenAIConfig):
"""
all_openai_params = litellm.OpenAIConfig().get_supported_openai_params(
model="gpt-4o"
)
all_openai_params = super().get_supported_openai_params(model=model)
non_supported_params = [
"logprobs",
"tools",
@ -69,13 +66,14 @@ class OpenAIO1Config(OpenAIConfig):
def map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_completion_tokens"] = value
return optional_params
if "max_tokens" in non_default_params:
optional_params["max_completion_tokens"] = non_default_params.pop(
"max_tokens"
)
return super()._map_openai_params(non_default_params, optional_params, model)
def is_model_o1_reasoning_model(self, model: str) -> bool:
if "o1" in model:
if model in litellm.open_ai_chat_completion_models and "o1" in model:
return True
return False
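For context (not part of the diff), a rough sketch of the `max_tokens` remapping this hunk fixes, assuming the `OpenAIO1Config` methods above:

```python
# Minimal sketch: o1 models take `max_completion_tokens`, so `max_tokens` is remapped
# before the remaining params are handled by the parent OpenAIGPTConfig logic.
optional_params = OpenAIO1Config().map_openai_params(
    non_default_params={"max_tokens": 256},
    optional_params={},
    model="o1-preview",
)
# expected: {"max_completion_tokens": 256}
```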
@ -93,7 +91,7 @@ class OpenAIO1Config(OpenAIConfig):
)
messages[i] = new_message # Replace the old message with the new one
if isinstance(message["content"], list):
if "content" in message and isinstance(message["content"], list):
new_content = []
for content_item in message["content"]:
if content_item.get("type") == "image_url":


@ -60,122 +60,6 @@ class OpenAIError(Exception):
) # Call the base class constructor with the parameters it needs
class MistralConfig:
"""
Reference: https://docs.mistral.ai/api/
The class `MistralConfig` provides configuration for the Mistral's Chat API interface. Below are the parameters:
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. API Default - 0.7.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. API Default - 1.
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. API Default - null.
- `tools` (list or null): A list of available tools for the model. Use this to specify functions for which the model can generate JSON inputs.
- `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.
- `stop` (string or array of strings): Stop generation if this token is detected. Or if one of these tokens is detected when providing an array
- `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.
- `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.
- `response_format` (object or null): An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message.
"""
temperature: Optional[int] = None
top_p: Optional[int] = None
max_tokens: Optional[int] = None
tools: Optional[list] = None
tool_choice: Optional[Literal["auto", "any", "none"]] = None
random_seed: Optional[int] = None
safe_prompt: Optional[bool] = None
response_format: Optional[dict] = None
stop: Optional[Union[str, list]] = None
def __init__(
self,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
max_tokens: Optional[int] = None,
tools: Optional[list] = None,
tool_choice: Optional[Literal["auto", "any", "none"]] = None,
random_seed: Optional[int] = None,
safe_prompt: Optional[bool] = None,
response_format: Optional[dict] = None,
stop: Optional[Union[str, list]] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"stream",
"temperature",
"top_p",
"max_tokens",
"tools",
"tool_choice",
"seed",
"stop",
"response_format",
]
def _map_tool_choice(self, tool_choice: str) -> str:
if tool_choice == "auto" or tool_choice == "none":
return tool_choice
elif tool_choice == "required":
return "any"
else: # openai 'tool_choice' object param not supported by Mistral API
return "any"
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "stream" and value is True:
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "stop":
optional_params["stop"] = value
if param == "tool_choice" and isinstance(value, str):
optional_params["tool_choice"] = self._map_tool_choice(
tool_choice=value
)
if param == "seed":
optional_params["extra_body"] = {"random_seed": value}
if param == "response_format":
optional_params["response_format"] = value
return optional_params
class MistralEmbeddingConfig:
"""
Reference: https://docs.mistral.ai/api/#operation/createEmbedding
@ -526,44 +410,19 @@ class OpenAIConfig:
}
def get_supported_openai_params(self, model: str) -> list:
base_params = [
"frequency_penalty",
"logit_bias",
"logprobs",
"top_logprobs",
"max_tokens",
"n",
"presence_penalty",
"seed",
"stop",
"stream",
"stream_options",
"temperature",
"top_p",
"tools",
"tool_choice",
"function_call",
"functions",
"max_retries",
"extra_headers",
"parallel_tool_calls",
] # works across all models
model_specific_params = []
if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
return litellm.OpenAIO1Config().get_supported_openai_params(model=model)
if (
model != "gpt-3.5-turbo-16k" and model != "gpt-4"
): # gpt-4 does not support 'response_format'
model_specific_params.append("response_format")
else:
return litellm.OpenAIGPTConfig().get_supported_openai_params(model=model)
if (
model in litellm.open_ai_chat_completion_models
) or model in litellm.open_ai_text_completion_models:
model_specific_params.append(
"user"
) # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
return base_params + model_specific_params
def _map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
) -> dict:
supported_openai_params = self.get_supported_openai_params(model)
for param, value in non_default_params.items():
if param in supported_openai_params:
optional_params[param] = value
return optional_params
def map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
@ -575,11 +434,11 @@ class OpenAIConfig:
optional_params=optional_params,
model=model,
)
supported_openai_params = self.get_supported_openai_params(model)
for param, value in non_default_params.items():
if param in supported_openai_params:
optional_params[param] = value
return optional_params
return litellm.OpenAIGPTConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
)
class OpenAITextCompletionConfig:
@ -816,18 +675,18 @@ class OpenAIChatCompletion(BaseLLM):
except Exception as e:
raise e
def completion(
def completion( # type: ignore
self,
model_response: ModelResponse,
timeout: Union[float, httpx.Timeout],
optional_params: dict,
logging_obj: Any,
model: Optional[str] = None,
messages: Optional[list] = None,
print_verbose: Optional[Callable] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
logging_obj=None,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
@ -858,14 +717,14 @@ class OpenAIChatCompletion(BaseLLM):
# process all OpenAI compatible provider logic here
if custom_llm_provider == "mistral":
# check if message content passed in as list, and not string
messages = prompt_factory(
messages = prompt_factory( # type: ignore
model=model,
messages=messages,
custom_llm_provider=custom_llm_provider,
)
if custom_llm_provider == "perplexity" and messages is not None:
# check if messages.name is passed + supported, if not supported remove
messages = prompt_factory(
messages = prompt_factory( # type: ignore
model=model,
messages=messages,
custom_llm_provider=custom_llm_provider,
@ -933,7 +792,7 @@ class OpenAIChatCompletion(BaseLLM):
status_code=422, message="max retries must be an int"
)
openai_client = self._get_openai_client(
openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False,
api_key=api_key,
api_base=api_base,
@ -1068,7 +927,7 @@ class OpenAIChatCompletion(BaseLLM):
2
): # if call fails due to alternating messages, retry with reformatted message
try:
openai_aclient = self._get_openai_client(
openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore
is_async=True,
api_key=api_key,
api_base=api_base,
@ -1156,7 +1015,7 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=None,
headers=None,
):
openai_client = self._get_openai_client(
openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False,
api_key=api_key,
api_base=api_base,
@ -1210,7 +1069,7 @@ class OpenAIChatCompletion(BaseLLM):
response = None
for _ in range(2):
try:
openai_aclient = self._get_openai_client(
openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore
is_async=True,
api_key=api_key,
api_base=api_base,
@ -1282,7 +1141,7 @@ class OpenAIChatCompletion(BaseLLM):
error_headers = getattr(e, "headers", None)
raise OpenAIError(
status_code=500,
message=f"{str(e)}\n\nOriginal Response: {response.text}",
message=f"{str(e)}\n\nOriginal Response: {response.text}", # type: ignore
headers=error_headers,
)
else:
@ -1294,7 +1153,7 @@ class OpenAIChatCompletion(BaseLLM):
)
elif hasattr(e, "status_code"):
raise OpenAIError(
status_code=e.status_code,
status_code=getattr(e, "status_code", 500),
message=str(e),
headers=error_headers,
)
@ -1361,7 +1220,7 @@ class OpenAIChatCompletion(BaseLLM):
):
response = None
try:
openai_aclient = self._get_openai_client(
openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore
is_async=True,
api_key=api_key,
api_base=api_base,
@ -1410,16 +1269,16 @@ class OpenAIChatCompletion(BaseLLM):
status_code=status_code, message=str(e), headers=error_headers
)
def embedding(
def embedding( # type: ignore
self,
model: str,
input: list,
timeout: float,
logging_obj,
model_response: litellm.utils.EmbeddingResponse,
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
optional_params=None,
client=None,
aembedding=None,
):
@ -1452,7 +1311,7 @@ class OpenAIChatCompletion(BaseLLM):
)
return response
openai_client = self._get_openai_client(
openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False,
api_key=api_key,
api_base=api_base,
@ -1496,11 +1355,11 @@ class OpenAIChatCompletion(BaseLLM):
data: dict,
model_response: ModelResponse,
timeout: float,
logging_obj: Any,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
):
response = None
try:
@ -1538,15 +1397,16 @@ class OpenAIChatCompletion(BaseLLM):
model: Optional[str],
prompt: str,
timeout: float,
optional_params: dict,
logging_obj: Any,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None,
logging_obj=None,
optional_params=None,
client=None,
aimg_generation=None,
):
exception_mapping_worked = False
data = {}
try:
model = model
data = {"model": model, "prompt": prompt, **optional_params}
@ -1611,7 +1471,9 @@ class OpenAIChatCompletion(BaseLLM):
original_response=str(e),
)
if hasattr(e, "status_code"):
raise OpenAIError(status_code=e.status_code, message=str(e))
raise OpenAIError(
status_code=getattr(e, "status_code", 500), message=str(e)
)
else:
raise OpenAIError(status_code=500, message=str(e))
@ -1661,7 +1523,7 @@ class OpenAIChatCompletion(BaseLLM):
input=input,
**optional_params,
)
return response
return response # type: ignore
async def async_audio_speech(
self,
@ -1784,11 +1646,8 @@ class OpenAIChatCompletion(BaseLLM):
class OpenAITextCompletion(BaseLLM):
_client_session: httpx.Client
def __init__(self) -> None:
super().__init__()
self._client_session = self.create_client_session()
def validate_environment(self, api_key):
headers = {
@ -1806,10 +1665,10 @@ class OpenAITextCompletion(BaseLLM):
messages: list,
timeout: float,
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
print_verbose: Optional[Callable] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
optional_params=None,
litellm_params=None,
logger_fn=None,
client=None,
@ -1921,7 +1780,7 @@ class OpenAITextCompletion(BaseLLM):
api_key: str,
model: str,
timeout: float,
max_retries=None,
max_retries: int,
organization: Optional[str] = None,
client=None,
):
@ -2017,9 +1876,9 @@ class OpenAITextCompletion(BaseLLM):
model_response: ModelResponse,
model: str,
timeout: float,
max_retries: int,
api_base: Optional[str] = None,
client=None,
max_retries=None,
organization=None,
):
if client is None:

litellm/llms/README.md (new file, 12 lines)

@ -0,0 +1,12 @@
## File Structure
### August 27th, 2024
To make it easy to see how calls are transformed for each model/provider, we are working on moving all supported litellm providers to a folder structure, where the folder name is the supported litellm provider name (e.g. `cohere/`, `bedrock/`).
Each folder will contain a `*_transformation.py` file with all the request/response transformation logic, making it easy to see how calls are modified.
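For example, the Mistral chat transformation now lives under its provider folder, matching the `litellm/__init__.py` change in this commit:

```python
# New-style import: transformation logic sits under the provider's folder.
from litellm.llms.mistral.mistral_chat_transformation import MistralConfig
```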


@ -66,22 +66,24 @@ class BaseLLM:
return _aclient_session
def __exit__(self):
if hasattr(self, "_client_session"):
if hasattr(self, "_client_session") and self._client_session is not None:
self._client_session.close()
async def __aexit__(self, exc_type, exc_val, exc_tb):
if hasattr(self, "_aclient_session"):
await self._aclient_session.aclose()
await self._aclient_session.aclose() # type: ignore
def validate_environment(self): # set up the environment required to run the model
pass
def validate_environment(
self, *args, **kwargs
) -> Optional[Any]: # set up the environment required to run the model
return None
def completion(
self, *args, **kwargs
): # logic for parsing in - calling - parsing out model completion calls
pass
) -> Any: # logic for parsing in - calling - parsing out model completion calls
return None
def embedding(
self, *args, **kwargs
): # logic for parsing in - calling - parsing out model embedding calls
pass
) -> Any: # logic for parsing in - calling - parsing out model embedding calls
return None


@ -0,0 +1,5 @@
"""
Calls handled in openai/
as mistral is an openai-compatible endpoint.
"""


@ -0,0 +1,5 @@
"""
Calls handled in openai/
as mistral is an openai-compatible endpoint.
"""


@ -0,0 +1,126 @@
"""
Transformation logic from OpenAI /v1/chat/completion format to Mistral's /chat/completion format.
Why a separate file? To make it easy to see how the transformation works
Docs - https://docs.mistral.ai/api/
"""
import types
from typing import List, Literal, Optional, Union
class MistralConfig:
"""
Reference: https://docs.mistral.ai/api/
The class `MistralConfig` provides configuration for Mistral's Chat API interface. Below are the parameters:
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. API Default - 0.7.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. API Default - 1.
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. API Default - null.
- `tools` (list or null): A list of available tools for the model. Use this to specify functions for which the model can generate JSON inputs.
- `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.
- `stop` (string or array of strings): Stop generation if this token is detected, or if one of these tokens is detected when an array is provided.
- `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.
- `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.
- `response_format` (object or null): An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message.
"""
temperature: Optional[int] = None
top_p: Optional[int] = None
max_tokens: Optional[int] = None
tools: Optional[list] = None
tool_choice: Optional[Literal["auto", "any", "none"]] = None
random_seed: Optional[int] = None
safe_prompt: Optional[bool] = None
response_format: Optional[dict] = None
stop: Optional[Union[str, list]] = None
def __init__(
self,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
max_tokens: Optional[int] = None,
tools: Optional[list] = None,
tool_choice: Optional[Literal["auto", "any", "none"]] = None,
random_seed: Optional[int] = None,
safe_prompt: Optional[bool] = None,
response_format: Optional[dict] = None,
stop: Optional[Union[str, list]] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"stream",
"temperature",
"top_p",
"max_tokens",
"tools",
"tool_choice",
"seed",
"stop",
"response_format",
]
def _map_tool_choice(self, tool_choice: str) -> str:
if tool_choice == "auto" or tool_choice == "none":
return tool_choice
elif tool_choice == "required":
return "any"
else: # openai 'tool_choice' object param not supported by Mistral API
return "any"
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "stream" and value is True:
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "stop":
optional_params["stop"] = value
if param == "tool_choice" and isinstance(value, str):
optional_params["tool_choice"] = self._map_tool_choice(
tool_choice=value
)
if param == "seed":
optional_params["extra_body"] = {"random_seed": value}
if param == "response_format":
optional_params["response_format"] = value
return optional_params
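For context (not part of the diff), a rough sketch of the OpenAI-to-Mistral param mapping performed by the `MistralConfig` above:

```python
# Minimal sketch: `seed` becomes `extra_body.random_seed`, and an OpenAI
# tool_choice value Mistral doesn't support falls back to "any".
optional_params = MistralConfig().map_openai_params(
    non_default_params={"max_tokens": 100, "seed": 42, "tool_choice": "required"},
    optional_params={},
)
# expected: {"max_tokens": 100, "extra_body": {"random_seed": 42}, "tool_choice": "any"}
```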


@ -737,6 +737,7 @@ def completion(
preset_cache_key = kwargs.get("preset_cache_key", None)
hf_model_name = kwargs.get("hf_model_name", None)
supports_system_message = kwargs.get("supports_system_message", None)
base_model = kwargs.get("base_model", None)
### TEXT COMPLETION CALLS ###
text_completion = kwargs.get("text_completion", False)
atext_completion = kwargs.get("atext_completion", False)
@ -782,11 +783,9 @@ def completion(
"top_logprobs",
"extra_headers",
]
litellm_params = (
all_litellm_params # use the external var., used in creating cache key as well.
)
default_params = openai_params + litellm_params
default_params = openai_params + all_litellm_params
litellm_params = {} # used to prevent unbound var errors
non_default_params = {
k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider
@ -973,6 +972,7 @@ def completion(
text_completion=kwargs.get("text_completion"),
azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
user_continue_message=kwargs.get("user_continue_message"),
base_model=base_model,
)
logging.update_environment_variables(
model=model,
@ -2123,7 +2123,10 @@ def completion(
timeout=timeout,
client=client,
)
elif "gemini" in model:
elif "gemini" in model or (
litellm_params.get("base_model") is not None
and "gemini" in litellm_params["base_model"]
):
model_response = vertex_chat_completion.completion( # type: ignore
model=model,
messages=messages,
@ -2820,7 +2823,7 @@ def completion_with_retries(*args, **kwargs):
)
num_retries = kwargs.pop("num_retries", 3)
retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry") # type: ignore
original_function = kwargs.pop("original_function", completion)
if retry_strategy == "constant_retry":
retryer = tenacity.Retrying(
@ -4997,7 +5000,9 @@ def speech(
async def ahealth_check(
model_params: dict,
mode: Optional[
Literal["completion", "embedding", "image_generation", "chat", "batch"]
Literal[
"completion", "embedding", "image_generation", "chat", "batch", "rerank"
]
] = None,
prompt: Optional[str] = None,
input: Optional[List] = None,
@ -5113,6 +5118,12 @@ async def ahealth_check(
model_params["prompt"] = prompt
await litellm.aimage_generation(**model_params)
response = {}
elif mode == "rerank":
model_params.pop("messages", None)
model_params["query"] = prompt
model_params["documents"] = ["my sample text"]
await litellm.arerank(**model_params)
response = {}
elif "*" in model:
from litellm.litellm_core_utils.llm_request_utils import (
pick_cheapest_model_from_llm_provider,


@ -1,9 +1,7 @@
model_list:
- model_name: "gpt-4o"
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-4o
model: gpt-3.5-turbo
litellm_settings:
cache: true
cache_params:
type: local
router_settings:
model_group_alias: {"gpt-4": {"model": "gpt-3.5-turbo", "hidden": false}}


@ -99,15 +99,15 @@ class LitellmUserRoles(str, enum.Enum):
return ui_labels.get(self.value, "")
class LitellmTableNames(str, enum.Enum):
class LitellmTableNames(enum.Enum):
"""
Enum for Table Names used by LiteLLM
"""
TEAM_TABLE_NAME: str = "LiteLLM_TeamTable"
USER_TABLE_NAME: str = "LiteLLM_UserTable"
KEY_TABLE_NAME: str = "LiteLLM_VerificationToken"
PROXY_MODEL_TABLE_NAME: str = "LiteLLM_ModelTable"
TEAM_TABLE_NAME = "LiteLLM_TeamTable"
USER_TABLE_NAME = "LiteLLM_UserTable"
KEY_TABLE_NAME = "LiteLLM_VerificationToken"
PROXY_MODEL_TABLE_NAME = "LiteLLM_ModelTable"
AlertType = Literal[
@ -140,7 +140,7 @@ class LiteLLMBase(BaseModel):
Implements default functions, all pydantic objects should have.
"""
def json(self, **kwargs):
def json(self, **kwargs): # type: ignore
try:
return self.model_dump(**kwargs) # noqa
except Exception as e:
@ -170,7 +170,7 @@ class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
class LiteLLMRoutes(enum.Enum):
openai_route_names: List = [
openai_route_names = [
"chat_completion",
"completion",
"embeddings",
@ -179,7 +179,7 @@ class LiteLLMRoutes(enum.Enum):
"moderations",
"model_list", # OpenAI /v1/models route
]
openai_routes: List = [
openai_routes = [
# chat completions
"/engines/{model}/chat/completions",
"/openai/deployments/{model}/chat/completions",
@ -247,18 +247,18 @@ class LiteLLMRoutes(enum.Enum):
"/v1/rerank",
]
mapped_pass_through_routes: List = [
mapped_pass_through_routes = [
"/bedrock",
"/vertex-ai",
"/gemini",
"/langfuse",
]
anthropic_routes: List = [
anthropic_routes = [
"/v1/messages",
]
info_routes: List = [
info_routes = [
"/key/info",
"/team/info",
"/team/list",
@ -271,9 +271,9 @@ class LiteLLMRoutes(enum.Enum):
]
# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
master_key_only_routes: List = ["/global/spend/reset", "/key/list"]
master_key_only_routes = ["/global/spend/reset", "/key/list"]
sso_only_routes: List = [
sso_only_routes = [
"/key/generate",
"/key/update",
"/key/delete",
@ -282,7 +282,7 @@ class LiteLLMRoutes(enum.Enum):
"/sso/get/logout_url",
]
management_routes: List = [ # key
management_routes = [ # key
"/key/generate",
"/key/update",
"/key/delete",
@ -307,7 +307,7 @@ class LiteLLMRoutes(enum.Enum):
"/model/info",
]
spend_tracking_routes: List = [
spend_tracking_routes = [
# spend
"/spend/keys",
"/spend/users",
@ -316,7 +316,7 @@ class LiteLLMRoutes(enum.Enum):
"/spend/logs",
]
global_spend_tracking_routes: List = [
global_spend_tracking_routes = [
# global spend
"/global/spend/logs",
"/global/spend",
@ -328,7 +328,7 @@ class LiteLLMRoutes(enum.Enum):
"/global/spend/report",
]
public_routes: List = [
public_routes = [
"/routes",
"/",
"/health/liveliness",
@ -339,7 +339,7 @@ class LiteLLMRoutes(enum.Enum):
"/metrics",
]
internal_user_routes: List = (
internal_user_routes = (
[
"/key/generate",
"/key/update",
@ -357,7 +357,7 @@ class LiteLLMRoutes(enum.Enum):
+ sso_only_routes
)
self_managed_routes: List = [
self_managed_routes = [
"/team/member_add",
"/team/member_delete",
] # routes that manage their own allowed/disallowed logic
@ -581,7 +581,9 @@ class ModelParams(LiteLLMBase):
@classmethod
def set_model_info(cls, values):
if values.get("model_info") is None:
values.update({"model_info": ModelInfo()})
values.update(
{"model_info": ModelInfo(id=None, mode="chat", base_model=None)}
)
return values
@ -627,7 +629,7 @@ class GenerateKeyRequest(_GenerateKeyRequest):
class GenerateKeyResponse(_GenerateKeyRequest):
key: str
key: str # type: ignore
key_name: Optional[str] = None
expires: Optional[datetime]
user_id: Optional[str] = None
@ -659,7 +661,7 @@ class GenerateKeyResponse(_GenerateKeyRequest):
class UpdateKeyRequest(GenerateKeyRequest):
# Note: the defaults of all Params here MUST BE NONE
# else they will get overwritten
key: str
key: str # type: ignore
duration: Optional[str] = None
spend: Optional[float] = None
metadata: Optional[dict] = None
@ -976,6 +978,7 @@ class TeamCallbackMetadata(LiteLLMBase):
class LiteLLM_TeamTable(TeamBase):
team_id: str # type: ignore
spend: Optional[float] = None
max_parallel_requests: Optional[int] = None
budget_duration: Optional[str] = None
@ -1061,7 +1064,7 @@ class LiteLLM_OrganizationTable(LiteLLMBase):
class NewOrganizationResponse(LiteLLM_OrganizationTable):
organization_id: str
organization_id: str # type: ignore
created_at: datetime
updated_at: datetime
@ -1388,16 +1391,7 @@ class UserAPIKeyAuth(
"""
api_key: Optional[str] = None
user_role: Optional[
Literal[
LitellmUserRoles.PROXY_ADMIN,
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
LitellmUserRoles.INTERNAL_USER,
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
LitellmUserRoles.TEAM,
LitellmUserRoles.CUSTOMER,
]
] = None
user_role: Optional[LitellmUserRoles] = None
allowed_model_region: Optional[Literal["eu"]] = None
parent_otel_span: Optional[Span] = None
rpm_limit_per_model: Optional[Dict[str, int]] = None
@ -1716,9 +1710,9 @@ class SpendLogsPayload(TypedDict):
total_tokens: int
prompt_tokens: int
completion_tokens: int
startTime: datetime
endTime: datetime
completionStartTime: Optional[datetime]
startTime: Union[datetime, str]
endTime: Union[datetime, str]
completionStartTime: Optional[Union[datetime, str]]
model: str
model_id: Optional[str]
model_group: Optional[str]
@ -1891,6 +1885,6 @@ class TeamAddMemberResponse(LiteLLM_TeamTable):
class TeamInfoResponseObject(TypedDict):
team_id: str
team_info: TeamBase
team_info: LiteLLM_TeamTable
keys: List
team_memberships: List[LiteLLM_TeamMembership]


@ -109,11 +109,8 @@ async def user_api_key_auth(
),
) -> UserAPIKeyAuth:
from litellm.proxy.proxy_server import (
allowed_routes_check,
common_checks,
custom_db_client,
general_settings,
get_actual_routes,
jwt_handler,
litellm_proxy_admin_name,
llm_model_list,
@ -125,6 +122,8 @@ async def user_api_key_auth(
user_custom_auth,
)
parent_otel_span: Optional[Span] = None
try:
route: str = get_request_route(request=request)
# get the request body
@ -137,6 +136,7 @@ async def user_api_key_auth(
pass_through_endpoints: Optional[List[dict]] = general_settings.get(
"pass_through_endpoints", None
)
passed_in_key: Optional[str] = None
if isinstance(api_key, str):
passed_in_key = api_key
api_key = _get_bearer_token(api_key=api_key)
@ -161,7 +161,6 @@ async def user_api_key_auth(
custom_litellm_key_header_name=custom_litellm_key_header_name,
)
parent_otel_span: Optional[Span] = None
if open_telemetry_logger is not None:
parent_otel_span = open_telemetry_logger.tracer.start_span(
name="Received Proxy Server Request",
@ -189,7 +188,7 @@ async def user_api_key_auth(
######## Route Checks Before Reading DB / Cache for "token" ################
if (
route in LiteLLMRoutes.public_routes.value
route in LiteLLMRoutes.public_routes.value # type: ignore
or route_in_additonal_public_routes(current_route=route)
):
# check if public endpoint
@ -410,7 +409,7 @@ async def user_api_key_auth(
#### ELSE ####
## CHECK PASS-THROUGH ENDPOINTS ##
is_mapped_pass_through_route: bool = False
for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value:
for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value: # type: ignore
if route.startswith(mapped_route):
is_mapped_pass_through_route = True
if is_mapped_pass_through_route:
@ -444,9 +443,9 @@ async def user_api_key_auth(
header_key = headers.get("litellm_user_api_key", "")
if (
isinstance(request.headers, dict)
and request.headers.get(key=header_key) is not None
and request.headers.get(key=header_key) is not None # type: ignore
):
api_key = request.headers.get(key=header_key)
api_key = request.headers.get(key=header_key) # type: ignore
if master_key is None:
if isinstance(api_key, str):
@ -606,7 +605,7 @@ async def user_api_key_auth(
## IF it's not a master key
## Route should not be in master_key_only_routes
if route in LiteLLMRoutes.master_key_only_routes.value:
if route in LiteLLMRoutes.master_key_only_routes.value: # type: ignore
raise Exception(
f"Tried to access route={route}, which is only for MASTER KEY"
)
@ -669,8 +668,9 @@ async def user_api_key_auth(
"allowed_model_region"
)
user_obj: Optional[LiteLLM_UserTable] = None
valid_token_dict: dict = {}
if valid_token is not None:
user_obj: Optional[LiteLLM_UserTable] = None
# Got Valid Token from Cache, DB
# Run checks for
# 1. If token can call model
@ -686,6 +686,7 @@ async def user_api_key_auth(
# Check 1. If token can call model
_model_alias_map = {}
model: Optional[str] = None
if (
hasattr(valid_token, "team_model_aliases")
and valid_token.team_model_aliases is not None
@ -698,6 +699,7 @@ async def user_api_key_auth(
_model_alias_map = {**valid_token.aliases}
litellm.model_alias_map = _model_alias_map
config = valid_token.config
if config != {}:
model_list = config.get("model_list", [])
llm_model_list = model_list
@ -887,7 +889,10 @@ async def user_api_key_auth(
and max_budget_per_model.get(current_model, None) is not None
):
if (
model_spend[0]["model"] == current_model
"model" in model_spend[0]
and model_spend[0].get("model") == current_model
and "_sum" in model_spend[0]
and "spend" in model_spend[0]["_sum"]
and model_spend[0]["_sum"]["spend"]
>= max_budget_per_model[current_model]
):
@ -927,16 +932,19 @@ async def user_api_key_auth(
)
# Check 8: Additional Common Checks across jwt + key auth
_team_obj = LiteLLM_TeamTable(
team_id=valid_token.team_id,
max_budget=valid_token.team_max_budget,
spend=valid_token.team_spend,
tpm_limit=valid_token.team_tpm_limit,
rpm_limit=valid_token.team_rpm_limit,
blocked=valid_token.team_blocked,
models=valid_token.team_models,
metadata=valid_token.team_metadata,
)
if valid_token.team_id is not None:
_team_obj: Optional[LiteLLM_TeamTable] = LiteLLM_TeamTable(
team_id=valid_token.team_id,
max_budget=valid_token.team_max_budget,
spend=valid_token.team_spend,
tpm_limit=valid_token.team_tpm_limit,
rpm_limit=valid_token.team_rpm_limit,
blocked=valid_token.team_blocked,
models=valid_token.team_models,
metadata=valid_token.team_metadata,
)
else:
_team_obj = None
user_api_key_cache.set_cache(
key=valid_token.team_id, value=_team_obj
@ -1045,7 +1053,7 @@ async def user_api_key_auth(
"/global/predict/spend/logs",
"/global/activity",
"/health/services",
] + LiteLLMRoutes.info_routes.value
] + LiteLLMRoutes.info_routes.value # type: ignore
# check if the current route startswith any of the allowed routes
if (
route is not None
@ -1106,7 +1114,7 @@ async def user_api_key_auth(
# Log this exception to OTEL
if open_telemetry_logger is not None:
await open_telemetry_logger.async_post_call_failure_hook(
await open_telemetry_logger.async_post_call_failure_hook( # type: ignore
original_exception=e,
user_api_key_dict=UserAPIKeyAuth(parent_otel_span=parent_otel_span),
)


@ -4,10 +4,11 @@ import json
import traceback
import uuid
from datetime import datetime, timedelta, timezone
from typing import List, Optional
from typing import List, Optional, Union
import fastapi
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_proxy_logger
@ -28,6 +29,7 @@ from litellm.proxy._types import (
ProxyErrorTypes,
ProxyException,
TeamAddMemberResponse,
TeamBase,
TeamInfoResponseObject,
TeamMemberAddRequest,
TeamMemberDeleteRequest,
@ -36,6 +38,7 @@ from litellm.proxy._types import (
UpdateTeamRequest,
UserAPIKeyAuth,
)
from litellm.proxy.auth.auth_checks import get_team_object
from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin, user_api_key_auth
from litellm.proxy.management_helpers.utils import (
add_new_member,
@ -240,7 +243,7 @@ async def new_team(
reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
complete_team_data.budget_reset_at = reset_at
team_row = await prisma_client.insert_data(
team_row: LiteLLM_TeamTable = await prisma_client.insert_data( # type: ignore
data=complete_team_data.json(exclude_none=True), table_name="team"
)
@ -462,9 +465,6 @@ async def team_member_add(
```
"""
from litellm.proxy.proxy_server import (
_duration_in_seconds,
create_audit_log_for_update,
get_team_object,
litellm_proxy_admin_name,
prisma_client,
proxy_logging_obj,
@ -932,10 +932,13 @@ async def delete_team(
if litellm.store_audit_logs is True:
# make an audit log for each team deleted
for team_id in data.team_ids:
team_row = await prisma_client.get_data( # type: ignore
team_row: Optional[LiteLLM_TeamTable] = await prisma_client.get_data( # type: ignore
team_id=team_id, table_name="team", query_type="find_unique"
)
if team_row is None:
continue
_team_row = team_row.json(exclude_none=True)
asyncio.create_task(
@ -1027,8 +1030,10 @@ async def team_info(
),
)
team_info = await prisma_client.get_data(
team_id=team_id, table_name="team", query_type="find_unique"
team_info: Optional[Union[LiteLLM_TeamTable, dict]] = (
await prisma_client.get_data(
team_id=team_id, table_name="team", query_type="find_unique"
)
)
if team_info is None:
raise HTTPException(
@ -1044,6 +1049,9 @@ async def team_info(
expires=datetime.now(),
)
if keys is None:
keys = []
if team_info is None:
## make sure we still return a total spend ##
spend = 0
@ -1055,7 +1063,7 @@ async def team_info(
for key in keys:
try:
key = key.model_dump() # noqa
except:
except Exception:
# if using pydantic v1
key = key.dict()
key.pop("token", None)
@ -1070,9 +1078,16 @@ async def team_info(
for tm in team_memberships:
returned_tm.append(LiteLLM_TeamMembership(**tm.model_dump()))
if isinstance(team_info, dict):
_team_info = LiteLLM_TeamTable(**team_info)
elif isinstance(team_info, BaseModel):
_team_info = LiteLLM_TeamTable(**team_info.model_dump())
else:
_team_info = LiteLLM_TeamTable()
response_object = TeamInfoResponseObject(
team_id=team_id,
team_info=team_info,
team_info=_team_info,
keys=keys,
team_memberships=returned_tm,
)


@ -125,16 +125,7 @@ from litellm.proxy._types import *
from litellm.proxy.analytics_endpoints.analytics_endpoints import (
router as analytics_router,
)
from litellm.proxy.auth.auth_checks import (
allowed_routes_check,
common_checks,
get_actual_routes,
get_end_user_object,
get_org_object,
get_team_object,
get_user_object,
log_to_opentelemetry,
)
from litellm.proxy.auth.auth_checks import log_to_opentelemetry
from litellm.proxy.auth.auth_utils import check_response_size_is_safe
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.auth.litellm_license import LicenseCheck
@ -260,6 +251,7 @@ from litellm.secret_managers.aws_secret_manager import (
load_aws_secret_manager,
)
from litellm.secret_managers.google_kms import load_google_kms
from litellm.secret_managers.main import get_secret
from litellm.types.llms.anthropic import (
AnthropicMessagesRequest,
AnthropicResponse,
@ -484,7 +476,7 @@ general_settings: dict = {}
callback_settings: dict = {}
log_file = "api_log.json"
worker_config = None
master_key = None
master_key: Optional[str] = None
otel_logging = False
prisma_client: Optional[PrismaClient] = None
custom_db_client: Optional[DBClient] = None
@ -874,7 +866,9 @@ def error_tracking():
def _set_spend_logs_payload(
payload: dict, prisma_client: PrismaClient, spend_logs_url: Optional[str] = None
payload: Union[dict, SpendLogsPayload],
prisma_client: PrismaClient,
spend_logs_url: Optional[str] = None,
):
if prisma_client is not None and spend_logs_url is not None:
if isinstance(payload["startTime"], datetime):
@ -1341,6 +1335,9 @@ async def _run_background_health_check():
# make 1 deep copy of llm_model_list -> use this for all background health checks
_llm_model_list = copy.deepcopy(llm_model_list)
if _llm_model_list is None:
return
while True:
healthy_endpoints, unhealthy_endpoints = await perform_health_check(
model_list=_llm_model_list, details=health_check_details
@ -1352,7 +1349,10 @@ async def _run_background_health_check():
health_check_results["healthy_count"] = len(healthy_endpoints)
health_check_results["unhealthy_count"] = len(unhealthy_endpoints)
await asyncio.sleep(health_check_interval)
if health_check_interval is not None and isinstance(
health_check_interval, float
):
await asyncio.sleep(health_check_interval)
class ProxyConfig:
@ -1467,7 +1467,7 @@ class ProxyConfig:
break
for k, v in team_config.items():
if isinstance(v, str) and v.startswith("os.environ/"):
team_config[k] = litellm.get_secret(v)
team_config[k] = get_secret(v)
return team_config
def _init_cache(
@ -1513,6 +1513,9 @@ class ProxyConfig:
config = get_file_contents_from_s3(
bucket_name=bucket_name, object_key=object_key
)
if config is None:
raise Exception("Unable to load config from given source.")
else:
# default to file
config = await self.get_config(config_file_path=config_file_path)
@ -1528,9 +1531,7 @@ class ProxyConfig:
environment_variables = config.get("environment_variables", None)
if environment_variables:
for key, value in environment_variables.items():
os.environ[key] = str(
litellm.get_secret(secret_name=key, default_value=value)
)
os.environ[key] = str(get_secret(secret_name=key, default_value=value))
# check if litellm_license in general_settings
if "LITELLM_LICENSE" in environment_variables:
@ -1566,8 +1567,8 @@ class ProxyConfig:
if (
cache_type == "redis" or cache_type == "redis-semantic"
) and len(cache_params.keys()) == 0:
cache_host = litellm.get_secret("REDIS_HOST", None)
cache_port = litellm.get_secret("REDIS_PORT", None)
cache_host = get_secret("REDIS_HOST", None)
cache_port = get_secret("REDIS_PORT", None)
cache_password = None
cache_params.update(
{
@ -1577,8 +1578,8 @@ class ProxyConfig:
}
)
if litellm.get_secret("REDIS_PASSWORD", None) is not None:
cache_password = litellm.get_secret("REDIS_PASSWORD", None)
if get_secret("REDIS_PASSWORD", None) is not None:
cache_password = get_secret("REDIS_PASSWORD", None)
cache_params.update(
{
"password": cache_password,
@ -1617,7 +1618,7 @@ class ProxyConfig:
# users can pass os.environ/ variables on the proxy - we should read them from the env
for key, value in cache_params.items():
if type(value) is str and value.startswith("os.environ/"):
cache_params[key] = litellm.get_secret(value)
cache_params[key] = get_secret(value)
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
self._init_cache(cache_params=cache_params)
@ -1738,7 +1739,7 @@ class ProxyConfig:
if value is not None and isinstance(value, dict):
for _k, _v in value.items():
if isinstance(_v, str) and _v.startswith("os.environ/"):
value[_k] = litellm.get_secret(_v)
value[_k] = get_secret(_v)
litellm.upperbound_key_generate_params = (
LiteLLM_UpperboundKeyGenerateParams(**value)
)
@ -1812,15 +1813,15 @@ class ProxyConfig:
database_url = general_settings.get("database_url", None)
if database_url and database_url.startswith("os.environ/"):
verbose_proxy_logger.debug("GOING INTO LITELLM.GET_SECRET!")
database_url = litellm.get_secret(database_url)
database_url = get_secret(database_url)
verbose_proxy_logger.debug("RETRIEVED DB URL: %s", database_url)
### MASTER KEY ###
master_key = general_settings.get(
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
"master_key", get_secret("LITELLM_MASTER_KEY", None)
)
if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key)
master_key = get_secret(master_key) # type: ignore
if not isinstance(master_key, str):
raise Exception(
"Master key must be a string. Current type - {}".format(
@ -1861,33 +1862,6 @@ class ProxyConfig:
await initialize_pass_through_endpoints(
pass_through_endpoints=general_settings["pass_through_endpoints"]
)
## dynamodb
database_type = general_settings.get("database_type", None)
if database_type is not None and (
database_type == "dynamo_db" or database_type == "dynamodb"
):
database_args = general_settings.get("database_args", None)
### LOAD FROM os.environ/ ###
for k, v in database_args.items():
if isinstance(v, str) and v.startswith("os.environ/"):
database_args[k] = litellm.get_secret(v)
if isinstance(k, str) and k == "aws_web_identity_token":
value = database_args[k]
verbose_proxy_logger.debug(
f"Loading AWS Web Identity Token from file: {value}"
)
if os.path.exists(value):
with open(value, "r") as file:
token_content = file.read()
database_args[k] = token_content
else:
verbose_proxy_logger.info(
f"DynamoDB Loading - {value} is not a valid file path"
)
verbose_proxy_logger.debug("database_args: %s", database_args)
custom_db_client = DBClient(
custom_db_args=database_args, custom_db_type=database_type
)
## ADMIN UI ACCESS ##
ui_access_mode = general_settings.get(
"ui_access_mode", "all"
@ -1951,7 +1925,7 @@ class ProxyConfig:
### LOAD FROM os.environ/ ###
for k, v in model["litellm_params"].items():
if isinstance(v, str) and v.startswith("os.environ/"):
model["litellm_params"][k] = litellm.get_secret(v)
model["litellm_params"][k] = get_secret(v)
print(f"\033[32m {model.get('model_name', '')}\033[0m") # noqa
litellm_model_name = model["litellm_params"]["model"]
litellm_model_api_base = model["litellm_params"].get("api_base", None)
@ -2005,7 +1979,10 @@ class ProxyConfig:
) # type:ignore
# Guardrail settings
guardrails_v2 = config.get("guardrails", None)
guardrails_v2: Optional[dict] = None
if config is not None:
guardrails_v2 = config.get("guardrails", None)
if guardrails_v2:
init_guardrails_v2(
all_guardrails=guardrails_v2, config_file_path=config_file_path
@ -2074,7 +2051,7 @@ class ProxyConfig:
### LOAD FROM os.environ/ ###
for k, v in model["litellm_params"].items():
if isinstance(v, str) and v.startswith("os.environ/"):
model["litellm_params"][k] = litellm.get_secret(v)
model["litellm_params"][k] = get_secret(v)
## check if they have model-id's ##
model_id = model.get("model_info", {}).get("id", None)
@ -2234,7 +2211,8 @@ class ProxyConfig:
for k, v in environment_variables.items():
try:
decrypted_value = decrypt_value_helper(value=v)
os.environ[k] = decrypted_value
if decrypted_value is not None:
os.environ[k] = decrypted_value
except Exception as e:
verbose_proxy_logger.error(
"Error setting env variable: %s - %s", k, str(e)
@ -2536,7 +2514,7 @@ async def async_assistants_data_generator(
)
# chunk = chunk.model_dump_json(exclude_none=True)
async for c in chunk:
async for c in chunk: # type: ignore
c = c.model_dump_json(exclude_none=True)
try:
yield f"data: {c}\n\n"
@ -2745,17 +2723,22 @@ async def startup_event():
### LOAD MASTER KEY ###
# check if master key set in environment - load from there
master_key = litellm.get_secret("LITELLM_MASTER_KEY", None)
master_key = get_secret("LITELLM_MASTER_KEY", None) # type: ignore
# check if DATABASE_URL in environment - load from there
if prisma_client is None:
prisma_setup(database_url=litellm.get_secret("DATABASE_URL", None))
_db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore
prisma_setup(database_url=_db_url)
### LOAD CONFIG ###
worker_config = litellm.get_secret("WORKER_CONFIG")
worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore
verbose_proxy_logger.debug("worker_config: %s", worker_config)
# check if it's a valid file path
if os.path.isfile(worker_config):
if proxy_config.is_yaml(config_file_path=worker_config):
if worker_config is not None:
if (
isinstance(worker_config, str)
and os.path.isfile(worker_config)
and proxy_config.is_yaml(config_file_path=worker_config)
):
(
llm_router,
llm_model_list,
@ -2763,21 +2746,23 @@ async def startup_event():
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config
)
else:
elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance(
worker_config, str
):
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config
)
elif isinstance(worker_config, dict):
await initialize(**worker_config)
elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config
)
else:
# if not, assume it's a json string
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
await initialize(**worker_config)
else:
# if not, assume it's a json string
worker_config = json.loads(worker_config)
if isinstance(worker_config, dict):
await initialize(**worker_config)
## CHECK PREMIUM USER
verbose_proxy_logger.debug(
@ -2825,7 +2810,7 @@ async def startup_event():
if general_settings.get("litellm_jwtauth", None) is not None:
for k, v in general_settings["litellm_jwtauth"].items():
if isinstance(v, str) and v.startswith("os.environ/"):
general_settings["litellm_jwtauth"][k] = litellm.get_secret(v)
general_settings["litellm_jwtauth"][k] = get_secret(v)
litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"])
else:
litellm_jwtauth = LiteLLM_JWTAuth()
@ -2948,8 +2933,7 @@ async def startup_event():
### ADD NEW MODELS ###
store_model_in_db = (
litellm.get_secret("STORE_MODEL_IN_DB", store_model_in_db)
or store_model_in_db
get_secret("STORE_MODEL_IN_DB", store_model_in_db) or store_model_in_db
) # type: ignore
if store_model_in_db == True:
scheduler.add_job(
@ -3498,7 +3482,7 @@ async def completion(
)
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response
data=data, user_api_key_dict=user_api_key_dict, response=response # type: ignore
)
fastapi_response.headers.update(
@ -4000,7 +3984,7 @@ async def audio_speech(
request_data=data,
)
return StreamingResponse(
generate(response), media_type="audio/mpeg", headers=custom_headers
generate(response), media_type="audio/mpeg", headers=custom_headers # type: ignore
)
except Exception as e:
@ -4288,6 +4272,7 @@ async def create_assistant(
API Reference docs - https://platform.openai.com/docs/api-reference/assistants/createAssistant
"""
global proxy_logging_obj
data = {} # ensure data always dict
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
@ -7642,6 +7627,7 @@ async def model_group_info(
)
model_groups: List[ModelGroupInfo] = []
for model in all_models_str:
_model_group_info = llm_router.get_model_group_info(model_group=model)
@ -8051,7 +8037,8 @@ async def google_login(request: Request):
with microsoft_sso:
return await microsoft_sso.get_login_redirect()
elif generic_client_id is not None:
from fastapi_sso.sso.generic import DiscoveryDocument, create_provider
from fastapi_sso.sso.base import DiscoveryDocument
from fastapi_sso.sso.generic import create_provider
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
@ -8616,6 +8603,8 @@ async def auth_callback(request: Request):
redirect_url += "sso/callback"
else:
redirect_url += "/sso/callback"
result = None
if google_client_id is not None:
from fastapi_sso.sso.google import GoogleSSO
@ -8662,7 +8651,8 @@ async def auth_callback(request: Request):
result = await microsoft_sso.verify_and_process(request)
elif generic_client_id is not None:
# make generic sso provider
from fastapi_sso.sso.generic import DiscoveryDocument, OpenID, create_provider
from fastapi_sso.sso.base import DiscoveryDocument, OpenID
from fastapi_sso.sso.generic import create_provider
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
@ -8766,8 +8756,8 @@ async def auth_callback(request: Request):
verbose_proxy_logger.debug("generic result: %s", result)
# User is Authe'd in - generate key for the UI to access Proxy
user_email = getattr(result, "email", None)
user_id = getattr(result, "id", None)
user_email: Optional[str] = getattr(result, "email", None)
user_id: Optional[str] = getattr(result, "id", None) if result is not None else None
if user_email is not None and os.getenv("ALLOWED_EMAIL_DOMAINS") is not None:
email_domain = user_email.split("@")[1]
@ -8783,12 +8773,12 @@ async def auth_callback(request: Request):
)
# generic client id
if generic_client_id is not None:
if generic_client_id is not None and result is not None:
user_id = getattr(result, "id", None)
user_email = getattr(result, "email", None)
user_role = getattr(result, generic_user_role_attribute_name, None)
user_role = getattr(result, generic_user_role_attribute_name, None) # type: ignore
if user_id is None:
if user_id is None and result is not None:
_first_name = getattr(result, "first_name", "") or ""
_last_name = getattr(result, "last_name", "") or ""
user_id = _first_name + _last_name
@ -8811,54 +8801,45 @@ async def auth_callback(request: Request):
"spend": 0,
"team_id": "litellm-dashboard",
}
user_defined_values: SSOUserDefinedValues = {
"models": user_id_models,
"user_id": user_id,
"user_email": user_email,
"max_budget": max_internal_user_budget,
"user_role": None,
"budget_duration": internal_user_budget_duration,
}
user_defined_values: Optional[SSOUserDefinedValues] = None
if user_id is not None:
user_defined_values = SSOUserDefinedValues(
models=user_id_models,
user_id=user_id,
user_email=user_email,
max_budget=max_internal_user_budget,
user_role=None,
budget_duration=internal_user_budget_duration,
)
_user_id_from_sso = user_id
user_role = None
try:
user_role = None
if prisma_client is not None:
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
verbose_proxy_logger.debug(
f"user_info: {user_info}; litellm.default_user_params: {litellm.default_user_params}"
)
if user_info is not None:
user_defined_values = {
"models": getattr(user_info, "models", user_id_models),
"user_id": getattr(user_info, "user_id", user_id),
"user_email": getattr(user_info, "user_id", user_email),
"user_role": getattr(user_info, "user_role", None),
"max_budget": getattr(
user_info, "max_budget", max_internal_user_budget
),
"budget_duration": getattr(
user_info, "budget_duration", internal_user_budget_duration
),
}
user_role = getattr(user_info, "user_role", None)
if user_info is None:
## check if user-email in db ##
user_info = await prisma_client.db.litellm_usertable.find_first(
where={"user_email": user_email}
)
## check if user-email in db ##
user_info = await prisma_client.db.litellm_usertable.find_first(
where={"user_email": user_email}
)
if user_info is not None:
user_defined_values = {
"models": getattr(user_info, "models", user_id_models),
"user_id": user_id,
"user_email": getattr(user_info, "user_id", user_email),
"user_role": getattr(user_info, "user_role", None),
"max_budget": getattr(
if user_info is not None and user_id is not None:
user_defined_values = SSOUserDefinedValues(
models=getattr(user_info, "models", user_id_models),
user_id=user_id,
user_email=getattr(user_info, "user_email", user_email),
user_role=getattr(user_info, "user_role", None),
max_budget=getattr(
user_info, "max_budget", max_internal_user_budget
),
"budget_duration": getattr(
budget_duration=getattr(
user_info, "budget_duration", internal_user_budget_duration
),
}
)
user_role = getattr(user_info, "user_role", None)
# update id
@ -8886,6 +8867,11 @@ async def auth_callback(request: Request):
except Exception as e:
pass
if user_defined_values is None:
raise Exception(
"Unable to map user identity to known values. 'user_defined_values' is None. File an issue - https://github.com/BerriAI/litellm/issues"
)
is_internal_user = False
if (
user_defined_values["user_role"] is not None
@ -8960,7 +8946,8 @@ async def auth_callback(request: Request):
master_key,
algorithm="HS256",
)
litellm_dashboard_ui += "?userID=" + user_id
if user_id is not None and isinstance(user_id, str):
litellm_dashboard_ui += "?userID=" + user_id
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
redirect_response.set_cookie(key="token", value=jwt_token)
return redirect_response
@ -9023,6 +9010,7 @@ async def new_invitation(
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
} # type: ignore
)
return response
except Exception as e:
if "Foreign key constraint failed on the field" in str(e):
raise HTTPException(
@ -9031,7 +9019,7 @@ async def new_invitation(
"error": "User id does not exist in 'LiteLLM_UserTable'. Fix this by creating user via `/user/new`."
},
)
return response
raise HTTPException(status_code=500, detail={"error": str(e)})
@router.get(
@ -9951,44 +9939,46 @@ async def get_routes():
"""
routes = []
for route in app.routes:
route_info = {
"path": getattr(route, "path", None),
"methods": getattr(route, "methods", None),
"name": getattr(route, "name", None),
"endpoint": (
getattr(route, "endpoint", None).__name__
if getattr(route, "endpoint", None)
else None
),
}
routes.append(route_info)
endpoint_route = getattr(route, "endpoint", None)
if endpoint_route is not None:
route_info = {
"path": getattr(route, "path", None),
"methods": getattr(route, "methods", None),
"name": getattr(route, "name", None),
"endpoint": (
endpoint_route.__name__
if getattr(route, "endpoint", None)
else None
),
}
routes.append(route_info)
return {"routes": routes}
#### TEST ENDPOINTS ####
@router.get(
"/token/generate",
dependencies=[Depends(user_api_key_auth)],
include_in_schema=False,
)
async def token_generate():
"""
Test endpoint. Admin-only access. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
"""
# Initialize AuthJWTSSO with your OpenID Provider configuration
from fastapi_sso import AuthJWTSSO
# @router.get(
# "/token/generate",
# dependencies=[Depends(user_api_key_auth)],
# include_in_schema=False,
# )
# async def token_generate():
# """
# Test endpoint. Admin-only access. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
# """
# # Initialize AuthJWTSSO with your OpenID Provider configuration
# from fastapi_sso import AuthJWTSSO
auth_jwt_sso = AuthJWTSSO(
issuer=os.getenv("OPENID_BASE_URL"),
client_id=os.getenv("OPENID_CLIENT_ID"),
client_secret=os.getenv("OPENID_CLIENT_SECRET"),
scopes=["litellm_proxy_admin"],
)
# auth_jwt_sso = AuthJWTSSO(
# issuer=os.getenv("OPENID_BASE_URL"),
# client_id=os.getenv("OPENID_CLIENT_ID"),
# client_secret=os.getenv("OPENID_CLIENT_SECRET"),
# scopes=["litellm_proxy_admin"],
# )
token = auth_jwt_sso.create_access_token()
# token = auth_jwt_sso.create_access_token()
return {"token": token}
# return {"token": token}
@router.on_event("shutdown")
@ -10013,7 +10003,8 @@ async def shutdown_event():
# flush langfuse logs on shutdown
from litellm.utils import langFuseLogger
langFuseLogger.Langfuse.flush()
if langFuseLogger is not None:
langFuseLogger.Langfuse.flush()
except:
# [DO NOT BLOCK shutdown events for this]
pass

View file

@ -92,6 +92,7 @@ from litellm.types.router import (
RetryPolicy,
RouterErrors,
RouterGeneralSettings,
RouterModelGroupAliasItem,
RouterRateLimitError,
RouterRateLimitErrorBasic,
updateDeployment,
@ -105,6 +106,7 @@ from litellm.utils import (
calculate_max_parallel_requests,
create_proxy_transport_and_mounts,
get_llm_provider,
get_secret,
get_utc_datetime,
)
@ -156,7 +158,9 @@ class Router:
fallbacks: List = [],
context_window_fallbacks: List = [],
content_policy_fallbacks: List = [],
model_group_alias: Optional[dict] = {},
model_group_alias: Optional[
Dict[str, Union[str, RouterModelGroupAliasItem]]
] = {},
enable_pre_call_checks: bool = False,
enable_tag_filtering: bool = False,
retry_after: int = 0, # min time to wait before retrying a failed request
@ -331,7 +335,8 @@ class Router:
self.set_model_list(model_list)
self.healthy_deployments: List = self.model_list # type: ignore
for m in model_list:
self.deployment_latency_map[m["litellm_params"]["model"]] = 0
if "model" in m["litellm_params"]:
self.deployment_latency_map[m["litellm_params"]["model"]] = 0
else:
self.model_list: List = (
[]
@ -398,7 +403,7 @@ class Router:
self.previous_models: List = (
[]
) # list to store failed calls (passed in as metadata to next call)
self.model_group_alias: dict = (
self.model_group_alias: Dict[str, Union[str, RouterModelGroupAliasItem]] = (
model_group_alias or {}
) # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group
@ -1179,6 +1184,7 @@ class Router:
raise e
def _image_generation(self, prompt: str, model: str, **kwargs):
model_name = ""
try:
verbose_router_logger.debug(
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
@ -1269,6 +1275,7 @@ class Router:
raise e
async def _aimage_generation(self, prompt: str, model: str, **kwargs):
model_name = ""
try:
verbose_router_logger.debug(
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
@ -1401,6 +1408,7 @@ class Router:
raise e
async def _atranscription(self, file: FileTypes, model: str, **kwargs):
model_name = model
try:
verbose_router_logger.debug(
f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
@ -1781,6 +1789,7 @@ class Router:
is_async: Optional[bool] = False,
**kwargs,
):
messages = [{"role": "user", "content": prompt}]
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
@ -1789,7 +1798,6 @@ class Router:
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
messages = [{"role": "user", "content": prompt}]
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(
model=model,
@ -2534,13 +2542,13 @@ class Router:
try:
# Update kwargs with the current model name or any other model-specific adjustments
## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ##
_, custom_llm_provider, _, _ = get_llm_provider(
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
model=model_name["litellm_params"]["model"]
)
new_kwargs = copy.deepcopy(kwargs)
new_kwargs.pop("custom_llm_provider", None)
return await litellm.aretrieve_batch(
custom_llm_provider=custom_llm_provider, **new_kwargs
custom_llm_provider=custom_llm_provider, **new_kwargs # type: ignore
)
except Exception as e:
receieved_exceptions.append(e)
@ -2616,13 +2624,13 @@ class Router:
for result in results:
if result is not None:
## check batch id
if final_results["first_id"] is None:
final_results["first_id"] = result.first_id
final_results["last_id"] = result.last_id
if final_results["first_id"] is None and hasattr(result, "first_id"):
final_results["first_id"] = getattr(result, "first_id")
final_results["last_id"] = getattr(result, "last_id")
final_results["data"].extend(result.data) # type: ignore
## check 'has_more'
if result.has_more is True:
if getattr(result, "has_more", False) is True:
final_results["has_more"] = True
return final_results
@ -2874,8 +2882,12 @@ class Router:
verbose_router_logger.debug(f"Traceback{traceback.format_exc()}")
original_exception = e
fallback_model_group = None
original_model_group = kwargs.get("model")
original_model_group: Optional[str] = kwargs.get("model") # type: ignore
fallback_failure_exception_str = ""
if original_model_group is None:
raise e
try:
verbose_router_logger.debug("Trying to fallback b/w models")
if isinstance(e, litellm.ContextWindowExceededError):
@ -2972,7 +2984,7 @@ class Router:
f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
)
if hasattr(original_exception, "message"):
original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" # type: ignore
raise original_exception
response = await run_async_fallback(
@ -2996,12 +3008,12 @@ class Router:
if hasattr(original_exception, "message"):
# add the available fallbacks to the exception
original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format(
original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format( # type: ignore
model_group,
fallback_model_group,
)
if len(fallback_failure_exception_str) > 0:
original_exception.message += (
original_exception.message += ( # type: ignore
"\nError doing the fallback: {}".format(
fallback_failure_exception_str
)
@ -3117,9 +3129,15 @@ class Router:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e)
remaining_retries = num_retries - current_attempt
_healthy_deployments, _ = await self._async_get_healthy_deployments(
model=kwargs.get("model"),
)
_model: Optional[str] = kwargs.get("model") # type: ignore
if _model is not None:
_healthy_deployments, _ = (
await self._async_get_healthy_deployments(
model=_model,
)
)
else:
_healthy_deployments = []
_timeout = self._time_to_sleep_before_retry(
e=original_exception,
remaining_retries=remaining_retries,
@ -3129,8 +3147,8 @@ class Router:
await asyncio.sleep(_timeout)
if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
original_exception.max_retries = num_retries
original_exception.num_retries = current_attempt
setattr(original_exception, "max_retries", num_retries)
setattr(original_exception, "num_retries", current_attempt)
raise original_exception
@ -3225,8 +3243,12 @@ class Router:
return response
except Exception as e:
original_exception = e
original_model_group = kwargs.get("model")
original_model_group: Optional[str] = kwargs.get("model")
verbose_router_logger.debug(f"An exception occurs {original_exception}")
if original_model_group is None:
raise e
try:
verbose_router_logger.debug(
f"Trying to fallback b/w models. Initial model group: {model_group}"
@ -3336,10 +3358,10 @@ class Router:
return 0
response_headers: Optional[httpx.Headers] = None
if hasattr(e, "response") and hasattr(e.response, "headers"):
response_headers = e.response.headers
if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore
response_headers = e.response.headers # type: ignore
elif hasattr(e, "litellm_response_headers"):
response_headers = e.litellm_response_headers
response_headers = e.litellm_response_headers # type: ignore
if response_headers is not None:
timeout = litellm._calculate_retry_after(
@ -3398,9 +3420,13 @@ class Router:
except Exception as e:
current_attempt = None
original_exception = e
_model: Optional[str] = kwargs.get("model") # type: ignore
if _model is None:
raise e # re-raise error, if model can't be determined for loadbalancing
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
_healthy_deployments, _all_deployments = self._get_healthy_deployments(
model=kwargs.get("model"),
model=_model,
)
# raises an exception if this error should not be retried
@ -3438,8 +3464,12 @@ class Router:
except Exception as e:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e)
_model: Optional[str] = kwargs.get("model") # type: ignore
if _model is None:
raise e # re-raise error, if model can't be determined for loadbalancing
_healthy_deployments, _ = self._get_healthy_deployments(
model=kwargs.get("model"),
model=_model,
)
remaining_retries = num_retries - current_attempt
_timeout = self._time_to_sleep_before_retry(
@ -4055,7 +4085,7 @@ class Router:
if isinstance(_litellm_params, dict):
for k, v in _litellm_params.items():
if isinstance(v, str) and v.startswith("os.environ/"):
_litellm_params[k] = litellm.get_secret(v)
_litellm_params[k] = get_secret(v)
_model_info: dict = model.pop("model_info", {})
@ -4392,7 +4422,6 @@ class Router:
- ModelGroupInfo if able to construct a model group
- None if error constructing model group info
"""
model_group_info: Optional[ModelGroupInfo] = None
total_tpm: Optional[int] = None
@ -4557,12 +4586,23 @@ class Router:
Returns:
- ModelGroupInfo if able to construct a model group
- None if error constructing model group info
- None if error constructing model group info or hidden model group
"""
## Check if model group alias
if model_group in self.model_group_alias:
item = self.model_group_alias[model_group]
if isinstance(item, str):
_router_model_group = item
elif isinstance(item, dict):
if item["hidden"] is True:
return None
else:
_router_model_group = item["model"]
else:
return None
return self._set_model_group_info(
model_group=self.model_group_alias[model_group],
model_group=_router_model_group,
user_facing_model_group_name=model_group,
)
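With this change a `model_group_alias` entry can be either a plain string or a dict carrying a `hidden` flag; hidden aliases stay routable but are dropped from listings. A usage sketch against the Router API shown in the surrounding diff (model names are placeholders):

```python
# Usage sketch for the two accepted alias forms. A hidden alias can still be
# called, but is excluded from get_model_list() / get_model_names().
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        },
    ],
    model_group_alias={
        "gpt-4": "gpt-3.5-turbo",  # plain string alias -> listed
        "internal-gpt": {"model": "gpt-3.5-turbo", "hidden": True},  # routable, not listed
    },
)

print(router.get_model_names())  # "internal-gpt" should not appear here
```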
@ -4666,7 +4706,14 @@ class Router:
Includes model_group_alias models too.
"""
return self.model_names + list(self.model_group_alias.keys())
model_list = self.get_model_list()
if model_list is None:
return []
model_names = []
for m in model_list:
model_names.append(m["model_name"])
return model_names
def get_model_list(
self, model_name: Optional[str] = None
@ -4678,9 +4725,21 @@ class Router:
returned_models: List[DeploymentTypedDict] = []
for model_alias, model_value in self.model_group_alias.items():
if isinstance(model_value, str):
_router_model_name: str = model_value
elif isinstance(model_value, dict):
_model_value = RouterModelGroupAliasItem(**model_value) # type: ignore
if _model_value["hidden"] is True:
continue
else:
_router_model_name = _model_value["model"]
else:
continue
returned_models.extend(
self._get_all_deployments(
model_name=model_value, model_alias=model_alias
model_name=_router_model_name, model_alias=model_alias
)
)
@ -5078,10 +5137,11 @@ class Router:
)
if model in self.model_group_alias:
verbose_router_logger.debug(
f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}"
)
model = self.model_group_alias[model]
_item = self.model_group_alias[model]
if isinstance(_item, str):
model = _item
else:
model = _item["model"]
if model not in self.model_names:
# check if provider/ specific wildcard routing
@ -5124,7 +5184,9 @@ class Router:
m for m in self.model_list if m["litellm_params"]["model"] == model
]
litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")
verbose_router_logger.debug(
f"initial list of deployments: {healthy_deployments}"
)
if len(healthy_deployments) == 0:
raise ValueError(
@ -5208,7 +5270,7 @@ class Router:
)
# check if user wants to do tag based routing
healthy_deployments = await get_deployments_for_tag(
healthy_deployments = await get_deployments_for_tag( # type: ignore
llm_router_instance=self,
request_kwargs=request_kwargs,
healthy_deployments=healthy_deployments,
@ -5241,7 +5303,7 @@ class Router:
input=input,
)
)
if (
elif (
self.routing_strategy == "cost-based-routing"
and self.lowestcost_logger is not None
):
@ -5326,6 +5388,8 @@ class Router:
############## No RPM/TPM passed, we do a random pick #################
item = random.choice(healthy_deployments)
return item or item[0]
else:
deployment = None
if deployment is None:
verbose_router_logger.info(
f"get_available_deployment for model: {model}, No deployment available"
@ -5515,6 +5579,9 @@ class Router:
messages=messages,
input=input,
)
else:
deployment = None
if deployment is None:
verbose_router_logger.info(
f"get_available_deployment for model: {model}, No deployment available"
@ -5690,6 +5757,9 @@ class Router:
def _initialize_alerting(self):
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
if self.alerting_config is None:
return
router_alerting_config: AlertingConfig = self.alerting_config
_slack_alerting_logger = SlackAlerting(
@ -5700,7 +5770,7 @@ class Router:
self.slack_alerting_logger = _slack_alerting_logger
litellm.callbacks.append(_slack_alerting_logger)
litellm.callbacks.append(_slack_alerting_logger) # type: ignore
litellm.success_callback.append(
_slack_alerting_logger.response_taking_too_long_callback
)

View file

@ -21,8 +21,8 @@ else:
async def get_deployments_for_tag(
llm_router_instance: LitellmRouter,
healthy_deployments: Union[List[Any], Dict[Any, Any]],
request_kwargs: Optional[Dict[Any, Any]] = None,
healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
):
if llm_router_instance.enable_tag_filtering is not True:
return healthy_deployments

View file

@ -2828,3 +2828,44 @@ def test_gemini_function_call_parameter_in_messages_2():
]
},
]
@pytest.mark.parametrize(
"base_model, metadata",
[
(None, {"model_info": {"base_model": "vertex_ai/gemini-1.5-pro"}}),
("vertex_ai/gemini-1.5-pro", None),
],
)
def test_gemini_finetuned_endpoint(base_model, metadata):
litellm.set_verbose = True
load_vertex_ai_credentials()
from litellm.llms.custom_httpx.http_handler import HTTPHandler
# Set up the messages
messages = [
{"role": "system", "content": """Use search for most queries."""},
{"role": "user", "content": """search for weather in boston (use `search`)"""},
]
client = HTTPHandler(concurrent_limit=1)
with patch.object(client, "post", new=MagicMock()) as mock_client:
try:
response = completion(
model="vertex_ai/4965075652664360960",
messages=messages,
tool_choice="auto",
client=client,
metadata=metadata,
base_model=base_model,
)
except Exception as e:
print(e)
print(mock_client.call_args.kwargs)
mock_client.assert_called()
assert mock_client.call_args.kwargs["url"].endswith(
"endpoints/4965075652664360960:generateContent"
)

View file

@ -311,6 +311,8 @@ async def test_completion_predibase():
pass
except litellm.ServiceUnavailableError as e:
pass
except litellm.InternalServerError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@ -133,3 +133,21 @@ async def test_fireworks_health_check():
assert response == {}
return response
@pytest.mark.asyncio
async def test_cohere_rerank_health_check():
response = await litellm.ahealth_check(
model_params={
"model": "cohere/rerank-english-v3.0",
"query": "Hey, how's it going",
"documents": ["my sample text"],
"api_key": os.getenv("COHERE_API_KEY"),
},
mode="rerank",
prompt="Hey, how's it going",
)
assert "error" not in response
print(response)

View file

@ -50,6 +50,7 @@ def test_image_generation_openai():
], # False
) #
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_image_generation_azure(sync_mode):
try:
if sync_mode:

View file

@ -533,7 +533,7 @@ def test_call_with_user_over_budget(prisma_client):
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
pytest.fail(f"This should have failed!. They key crossed it's budget")
pytest.fail("This should have failed!. They key crossed it's budget")
asyncio.run(test())
except Exception as e:
@ -1755,7 +1755,7 @@ def test_call_with_key_over_model_budget(prisma_client):
# use generated key to auth in
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("result from user auth with new key", result)
pytest.fail(f"This should have failed!. They key crossed it's budget")
pytest.fail("This should have failed!. They key crossed it's budget")
asyncio.run(test())
except Exception as e:

View file

@ -589,3 +589,14 @@ def test_parse_additional_properties_json_schema(model, provider, expectedAddPro
elif provider == "openai":
schema = optional_params["response_format"]["json_schema"]["schema"]
assert ("additionalProperties" in schema) == expectedAddProp
def test_o1_model_params():
optional_params = get_optional_params(
model="o1-preview-2024-09-12",
custom_llm_provider="openai",
seed=10,
user="John",
)
assert optional_params["seed"] == 10
assert optional_params["user"] == "John"

View file

@ -762,7 +762,7 @@ async def test_team_update_redis():
) as mock_client:
await _cache_team_object(
team_id="1234",
team_table=LiteLLM_TeamTableCachedObj(),
team_table=LiteLLM_TeamTableCachedObj(team_id="1234"),
user_api_key_cache=DualCache(),
proxy_logging_obj=proxy_logging_obj,
)
@ -776,7 +776,7 @@ async def test_get_team_redis(client_no_auth):
Tests if get_team_object gets value from redis cache, if set
"""
from litellm.caching import DualCache, RedisCache
from litellm.proxy.auth.auth_checks import _cache_team_object, get_team_object
from litellm.proxy.auth.auth_checks import get_team_object
proxy_logging_obj: ProxyLogging = getattr(
litellm.proxy.proxy_server, "proxy_logging_obj"
@ -917,7 +917,9 @@ async def test_create_team_member_add(prisma_client, new_member_method):
)
litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = team_mock_client
team_mock_client.update = AsyncMock(return_value=LiteLLM_TeamTableCachedObj())
team_mock_client.update = AsyncMock(
return_value=LiteLLM_TeamTableCachedObj(team_id="1234")
)
await team_member_add(
data=team_member_add_request,
@ -1095,7 +1097,9 @@ async def test_create_team_member_add_team_admin(
)
litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = team_mock_client
team_mock_client.update = AsyncMock(return_value=LiteLLM_TeamTableCachedObj())
team_mock_client.update = AsyncMock(
return_value=LiteLLM_TeamTableCachedObj(team_id="1234")
)
try:
await team_member_add(
@ -1434,8 +1438,9 @@ async def test_gemini_pass_through_endpoint():
print(resp.body)
@pytest.mark.parametrize("hidden", [True, False])
@pytest.mark.asyncio
async def test_proxy_model_group_alias_checks(prisma_client):
async def test_proxy_model_group_alias_checks(prisma_client, hidden):
"""
Check if model group alias is returned on
@ -1465,7 +1470,7 @@ async def test_proxy_model_group_alias_checks(prisma_client):
model_alias = "gpt-4"
router = litellm.Router(
model_list=_model_list,
model_group_alias={model_alias: "gpt-3.5-turbo"},
model_group_alias={model_alias: {"model": "gpt-3.5-turbo", "hidden": hidden}},
)
setattr(litellm.proxy.proxy_server, "llm_router", router)
setattr(litellm.proxy.proxy_server, "llm_model_list", _model_list)
@ -1477,7 +1482,10 @@ async def test_proxy_model_group_alias_checks(prisma_client):
user_api_key_dict=UserAPIKeyAuth(models=[]),
)
assert len(resp) == 2
if hidden:
assert len(resp["data"]) == 1
else:
assert len(resp["data"]) == 2
print(resp)
resp = await model_info_v1(
@ -1489,7 +1497,10 @@ async def test_proxy_model_group_alias_checks(prisma_client):
if model_alias == item["model_name"]:
is_model_alias_in_list = True
assert is_model_alias_in_list
if hidden:
assert is_model_alias_in_list is False
else:
assert is_model_alias_in_list
resp = await model_group_info(
user_api_key_dict=UserAPIKeyAuth(models=[]),
@ -1500,4 +1511,7 @@ async def test_proxy_model_group_alias_checks(prisma_client):
if model_alias == item.model_group:
is_model_alias_in_list = True
assert is_model_alias_in_list
if hidden:
assert is_model_alias_in_list is False
else:
assert is_model_alias_in_list, f"models: {models}"

View file

@ -2481,3 +2481,31 @@ async def test_router_batch_endpoints(provider):
model="my-custom-name", custom_llm_provider=provider, limit=2
)
print("list_batches=", list_batches)
@pytest.mark.parametrize("hidden", [True, False])
def test_model_group_alias(hidden):
_model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo"},
},
{"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}},
]
router = Router(
model_list=_model_list,
model_group_alias={
"gpt-4.5-turbo": {"model": "gpt-3.5-turbo", "hidden": hidden}
},
)
models = router.get_model_list()
model_names = router.get_model_names()
if hidden:
assert len(models) == len(_model_list)
assert len(model_names) == len(_model_list)
else:
assert len(models) == len(_model_list) + 1
assert len(model_names) == len(_model_list) + 1

View file

@ -582,3 +582,8 @@ class RouterRateLimitError(ValueError):
self.cooldown_list = cooldown_list
_message = f"{RouterErrors.no_deployments_available.value}, Try again in {cooldown_time} seconds. Passed model={model}. pre-call-checks={enable_pre_call_checks}, cooldown_list={cooldown_list}"
super().__init__(_message)
class RouterModelGroupAliasItem(TypedDict):
model: str
hidden: bool # if 'True', don't return on `.get_model_list`

View file

@ -76,6 +76,7 @@ from litellm.types.llms.openai import (
ChatCompletionNamedToolChoiceParam,
ChatCompletionToolParam,
)
from litellm.types.utils import FileTypes # type: ignore
from litellm.types.utils import (
CallTypes,
ChatCompletionDeltaToolCall,
@ -84,7 +85,6 @@ from litellm.types.utils import (
Delta,
Embedding,
EmbeddingResponse,
FileTypes,
ImageResponse,
Message,
ModelInfo,
@ -2339,6 +2339,7 @@ def get_litellm_params(
text_completion=None,
azure_ad_token_provider=None,
user_continue_message=None,
base_model=None,
):
litellm_params = {
"acompletion": acompletion,
@ -2365,6 +2366,8 @@ def get_litellm_params(
"text_completion": text_completion,
"azure_ad_token_provider": azure_ad_token_provider,
"user_continue_message": user_continue_message,
"base_model": base_model
or _get_base_model_from_litellm_call_metadata(metadata=metadata),
}
return litellm_params
@ -6063,11 +6066,11 @@ def _calculate_retry_after(
max_retries: int,
response_headers: Optional[httpx.Headers] = None,
min_timeout: int = 0,
):
) -> Union[float, int]:
retry_after = _get_retry_after_from_exception_header(response_headers)
# If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
if 0 < retry_after <= 60:
if retry_after is not None and 0 < retry_after <= 60:
return retry_after
initial_retry_delay = 0.5
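The return type is now `Union[float, int]` and the header value is only honored when it is present and between 0 and 60 seconds. A sketch of that decision, with the fallback past `initial_retry_delay` written as an illustrative assumption rather than the library's exact backoff:

```python
# Sketch of the retry-after decision above. Honoring a 0-60s Retry-After header
# matches the diff; the capped exponential backoff + jitter below is an
# illustrative assumption, not litellm's exact fallback.
import random
from typing import Optional, Union


def calculate_retry_after(
    remaining_retries: int,
    max_retries: int,
    retry_after_header: Optional[float] = None,
    min_timeout: int = 0,
) -> Union[float, int]:
    if retry_after_header is not None and 0 < retry_after_header <= 60:
        return retry_after_header  # server gave a reasonable wait -> use it
    initial_retry_delay = 0.5
    max_retry_delay = 8.0  # assumed cap
    nb_retries = max_retries - remaining_retries
    sleep_seconds = min(initial_retry_delay * (2**nb_retries), max_retry_delay)
    jitter = 1 - 0.25 * random.random()  # avoid retrying in lockstep
    return max(sleep_seconds * jitter, min_timeout)
```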
@ -10962,6 +10965,22 @@ def get_logging_id(start_time, response_obj):
return None
def _get_base_model_from_litellm_call_metadata(
metadata: Optional[dict],
) -> Optional[str]:
if metadata is None:
return None
if metadata is not None:
model_info = metadata.get("model_info", {})
if model_info is not None:
base_model = model_info.get("base_model", None)
if base_model is not None:
return base_model
return None
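A quick usage check for the helper above, showing the `model_info.base_model` metadata shape it expects (the same shape exercised by the fine-tuned Gemini test earlier in this commit); the helper body is restated in condensed form so the snippet runs on its own:

```python
# Usage check for _get_base_model_from_litellm_call_metadata(); condensed
# restatement of the helper above so the snippet is self-contained.
from typing import Optional


def _get_base_model_from_litellm_call_metadata(metadata: Optional[dict]) -> Optional[str]:
    if metadata is not None:
        model_info = metadata.get("model_info", {})
        if model_info is not None:
            return model_info.get("base_model", None)
    return None


# metadata carrying model_info.base_model resolves to the base model string
assert (
    _get_base_model_from_litellm_call_metadata(
        metadata={"model_info": {"base_model": "vertex_ai/gemini-1.5-pro"}}
    )
    == "vertex_ai/gemini-1.5-pro"
)
# missing metadata resolves to None
assert _get_base_model_from_litellm_call_metadata(metadata=None) is None
```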
def _get_base_model_from_metadata(model_call_details=None):
if model_call_details is None:
return None
@ -10970,13 +10989,7 @@ def _get_base_model_from_metadata(model_call_details=None):
if litellm_params is not None:
metadata = litellm_params.get("metadata", {})
if metadata is not None:
model_info = metadata.get("model_info", {})
if model_info is not None:
base_model = model_info.get("base_model", None)
if base_model is not None:
return base_model
return _get_base_model_from_litellm_call_metadata(metadata=metadata)
return None

pyrightconfig.json (new file, 6 additions)
View file

@ -0,0 +1,6 @@
{
"ignore": [],
"exclude": ["**/node_modules", "**/__pycache__", "litellm/tests", "litellm/main.py", "litellm/utils.py"],
"reportMissingImports": false
}