forked from phoenix/litellm-mirror
LiteLLM Minor Fixes and Improvements (09/13/2024) (#5689)
* refactor: cleanup unused variables + fix pyright errors
* feat(health_check.py): Closes https://github.com/BerriAI/litellm/issues/5686
* fix(o1_reasoning.py): add stricter check for o-1 reasoning model
* refactor(mistral/): make it easier to see mistral transformation logic
* fix(openai.py): fix openai o-1 model param mapping. Fixes https://github.com/BerriAI/litellm/issues/5685
* feat(main.py): infer finetuned gemini model from base model. Fixes https://github.com/BerriAI/litellm/issues/5678
* docs(vertex.md): update docs to call finetuned gemini models
* feat(proxy_server.py): allow admin to hide proxy model aliases. Closes https://github.com/BerriAI/litellm/issues/5692
* docs(load_balancing.md): add docs on hiding alias models from proxy config
* fix(base.py): don't raise notimplemented error
* fix(user_api_key_auth.py): fix model max budget check
* fix(router.py): fix elif
* fix(user_api_key_auth.py): don't set team_id to empty str
* fix(team_endpoints.py): fix response type
* test(test_completion.py): handle predibase error
* test(test_proxy_server.py): fix test
* fix(o1_transformation.py): fix max_completion_token mapping
* test(test_image_generation.py): mark flaky test

This commit is contained in:
parent db3af20d84
commit 60709a0753
35 changed files with 1020 additions and 539 deletions
@@ -1,9 +1,9 @@
repos:
- repo: local
  hooks:
    - id: mypy
      name: mypy
      entry: python3 -m mypy --ignore-missing-imports
    - id: pyright
      name: pyright
      entry: pyright
      language: system
      types: [python]
      files: ^litellm/
@@ -5,7 +5,7 @@ import TabItem from '@theme/TabItem';

https://github.com/BerriAI/litellm

## **Call 100+ LLMs using the same Input/Output Format**
## **Call 100+ LLMs using the OpenAI Input/Output Format**

- Translate inputs to provider's `completion`, `embedding`, and `image_generation` endpoints
- [Consistent output](https://docs.litellm.ai/docs/completion/output), text responses will always be available at `['choices'][0]['message']['content']`
@@ -1180,9 +1180,58 @@ response = completion(

Fine-tuned models on Vertex AI have a numerical model/endpoint ID.

| Model Name | Function Call |
|------------------|--------------------------------------|
| your fine-tuned model | `completion(model='vertex_ai/4965075652664360960', messages)` |

<Tabs>
<TabItem value="sdk" label="SDK">

```python
from litellm import completion
import os

## set ENV variables
os.environ["VERTEXAI_PROJECT"] = "hardy-device-38811"
os.environ["VERTEXAI_LOCATION"] = "us-central1"

response = completion(
  model="vertex_ai/<your-finetuned-model>",  # e.g. vertex_ai/4965075652664360960
  messages=[{"content": "Hello, how are you?", "role": "user"}],
  base_model="vertex_ai/gemini-1.5-pro",  # the base model - used for routing
)
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Add Vertex credentials to your env

```bash
!gcloud auth application-default login
```

2. Set up config.yaml

```yaml
- model_name: finetuned-gemini
  litellm_params:
    model: vertex_ai/<ENDPOINT_ID>
    vertex_project: <PROJECT_ID>
    vertex_location: <LOCATION>
  model_info:
    base_model: vertex_ai/gemini-1.5-pro # IMPORTANT
```

3. Test it!

```bash
curl --location 'https://0.0.0.0:4000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: <LITELLM_KEY>' \
--data '{"model": "finetuned-gemini", "messages": [{"role": "user", "content": [{"type": "text", "text": "hi"}]}]}'
```

</TabItem>
</Tabs>
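Since the LiteLLM proxy exposes an OpenAI-compatible API, the same proxy call can also be made from the OpenAI Python SDK. A minimal sketch, assuming the proxy runs locally on port 4000 and `sk-1234` is a valid proxy key:

```python
# Sketch: call the proxy's "finetuned-gemini" alias via the OpenAI SDK.
# The base_url and api_key below are assumptions for a local proxy setup.
from openai import OpenAI

client = OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="finetuned-gemini",  # model_name from config.yaml
    messages=[{"role": "user", "content": "hi"}],
)
print(response.choices[0].message.content)
```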
## Gemini Pro Vision
| Model Name | Function Call |
@@ -38,9 +38,24 @@ router_settings:

## Router settings on config - routing_strategy, model_group_alias

Expose an 'alias' for a 'model_name' on the proxy server.

```
model_group_alias: {
  "gpt-4": "gpt-3.5-turbo"
}
```

These aliases are shown on `/v1/models`, `/v1/model/info`, and `/v1/model_group/info` by default.

litellm.Router() settings can be set under `router_settings`. You can set `model_group_alias`, `routing_strategy`, `num_retries`, `timeout`. See all supported Router params [here](https://github.com/BerriAI/litellm/blob/1b942568897a48f014fa44618ec3ce54d7570a46/litellm/router.py#L64).

### Usage

Example config with `router_settings`:

```yaml
model_list:
  - model_name: gpt-3.5-turbo

@@ -48,19 +63,41 @@ model_list:
      model: azure/<your-deployment-name>
      api_base: <your-azure-endpoint>
      api_key: <your-azure-api-key>
      rpm: 6 # Rate limit for this deployment: in requests per minute (rpm)

router_settings:
  model_group_alias: {"gpt-4": "gpt-3.5-turbo"} # all requests with `gpt-4` will be routed to models with `gpt-3.5-turbo`
```

### Hide Alias Models

Use this if you want to set up aliases for:

1. typos
2. minor model version changes
3. case-sensitive changes between updates

```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/gpt-turbo-small-ca
      api_base: https://my-endpoint-canada-berri992.openai.azure.com/
      model: azure/<your-deployment-name>
      api_base: <your-azure-endpoint>
      api_key: <your-azure-api-key>
      rpm: 6

router_settings:
  model_group_alias: {"gpt-4": "gpt-3.5-turbo"} # all requests with `gpt-4` will be routed to models with `gpt-3.5-turbo`
  routing_strategy: least-busy # Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"]
  num_retries: 2
  timeout: 30 # 30 seconds
  redis_host: <your redis host>
  redis_password: <your redis password>
  redis_port: 1992
  model_group_alias:
    "GPT-3.5-turbo": # alias
      model: "gpt-3.5-turbo" # actual model name in 'model_list'
      hidden: true # Exclude from `/v1/models`, `/v1/model/info`, `/v1/model_group/info`
```

### Complete Spec

```python
model_group_alias: Optional[Dict[str, Union[str, RouterModelGroupAliasItem]]] = {}


class RouterModelGroupAliasItem(TypedDict):
    model: str
    hidden: bool  # if 'True', don't return on `/v1/models`, `/v1/model/info`, `/v1/model_group/info`
```
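The same spec can be exercised directly from the Python SDK. A rough sketch, with the deployment, API key, and alias values as placeholders:

```python
# Sketch: a hidden alias still routes requests but is excluded from /v1/models.
import os
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": os.environ["OPENAI_API_KEY"],
            },
        }
    ],
    model_group_alias={
        "gpt-4": "gpt-3.5-turbo",  # plain alias - shown on /v1/models
        "GPT-3.5-turbo": {"model": "gpt-3.5-turbo", "hidden": True},  # hidden alias
    },
)

response = router.completion(
    model="GPT-3.5-turbo",  # resolved to gpt-3.5-turbo via the alias
    messages=[{"role": "user", "content": "hi"}],
)
```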
@@ -5,7 +5,7 @@ import TabItem from '@theme/TabItem';

https://github.com/BerriAI/litellm

## **Call 100+ LLMs using the same Input/Output Format**
## **Call 100+ LLMs using the OpenAI Input/Output Format**

- Translate inputs to provider's `completion`, `embedding`, and `image_generation` endpoints
- [Consistent output](https://docs.litellm.ai/docs/completion/output), text responses will always be available at `['choices'][0]['message']['content']`
|
@@ -939,15 +939,18 @@ from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConf
from .llms.OpenAI.openai import (
    OpenAIConfig,
    OpenAITextCompletionConfig,
    MistralConfig,
    MistralEmbeddingConfig,
    DeepInfraConfig,
    GroqConfig,
    AzureAIStudioConfig,
)
from .llms.OpenAI.o1_reasoning import (
from .llms.mistral.mistral_chat_transformation import MistralConfig
from .llms.OpenAI.o1_transformation import (
    OpenAIO1Config,
)
from .llms.OpenAI.gpt_transformation import (
    OpenAIGPTConfig,
)
from .llms.nvidia_nim import NvidiaNimConfig
from .llms.cerebras.chat import CerebrasConfig
from .llms.AI21.chat import AI21ChatConfig
|
||||
|
|
|
@@ -52,7 +52,9 @@ class SlackAlerting(CustomBatchLogger):
def __init__(
self,
internal_usage_cache: Optional[DualCache] = None,
alerting_threshold: float = 300,  # threshold for slow / hanging llm responses (in seconds)
alerting_threshold: Optional[
float
] = None,  # threshold for slow / hanging llm responses (in seconds)
alerting: Optional[List] = [],
alert_types: List[AlertType] = [
"llm_exceptions",
@@ -74,6 +76,8 @@ class SlackAlerting(CustomBatchLogger):
default_webhook_url: Optional[str] = None,
**kwargs,
):
if alerting_threshold is None:
alerting_threshold = 300
self.alerting_threshold = alerting_threshold
self.alerting = alerting
self.alert_types = alert_types
@@ -99,6 +103,7 @@ class SlackAlerting(CustomBatchLogger):
):
if alerting is not None:
self.alerting = alerting
asyncio.create_task(self.periodic_flush())
if alerting_threshold is not None:
self.alerting_threshold = alerting_threshold
if alert_types is not None:
@@ -114,8 +119,6 @@ class SlackAlerting(CustomBatchLogger):
if llm_router is not None:
self.llm_router = llm_router

asyncio.create_task(self.periodic_flush())

async def deployment_in_cooldown(self):
pass

@@ -208,15 +211,20 @@ class SlackAlerting(CustomBatchLogger):
_deployment_latencies = metadata["_latency_per_deployment"]
if len(_deployment_latencies) == 0:
return None
_deployment_latency_map: Optional[dict] = None
try:
# try sorting deployments by latency
_deployment_latencies = sorted(
_deployment_latencies.items(), key=lambda x: x[1]
)
_deployment_latencies = dict(_deployment_latencies)
except:
_deployment_latency_map = dict(_deployment_latencies)
except Exception:
pass
for api_base, latency in _deployment_latencies.items():

if _deployment_latency_map is None:
return

for api_base, latency in _deployment_latency_map.items():
_message_to_send += f"\n{api_base}: {round(latency,2)}s"
_message_to_send = "```" + _message_to_send + "```"
return _message_to_send
@@ -475,6 +483,7 @@ class SlackAlerting(CustomBatchLogger):
):
if self.alerting is None or self.alert_types is None:
return
model: str = ""
if request_data is not None:
model = request_data.get("model", "")
messages = request_data.get("messages", None)
@@ -619,6 +628,7 @@ class SlackAlerting(CustomBatchLogger):
return
_id: Optional[str] = "default_id"  # used for caching
user_info_json = user_info.model_dump(exclude_none=True)
user_info_str = ""
for k, v in user_info_json.items():
user_info_str = "\n{}: {}\n".format(k, v)

@@ -1475,10 +1485,10 @@ Model Info:
if isinstance(response_obj, litellm.ModelResponse) and (
hasattr(response_obj, "usage")
and response_obj.usage is not None
and hasattr(response_obj.usage, "completion_tokens")
and response_obj.usage is not None  # type: ignore
and hasattr(response_obj.usage, "completion_tokens")  # type: ignore
):
completion_tokens = response_obj.usage.completion_tokens
completion_tokens = response_obj.usage.completion_tokens  # type: ignore
if completion_tokens is not None and completion_tokens > 0:
final_value = float(
response_s.total_seconds() / completion_tokens
@@ -1608,10 +1618,14 @@ Model Info:
todays_date = datetime.datetime.now().date()
start_date = todays_date - datetime.timedelta(days=days)

spend_per_team, spend_per_tag = await _get_spend_report_for_time_range(
_resp = await _get_spend_report_for_time_range(
start_date=start_date.strftime("%Y-%m-%d"),
end_date=todays_date.strftime("%Y-%m-%d"),
)
if _resp is None:
return

spend_per_team, spend_per_tag = _resp

_spend_message = f"*💸 Spend Report for `{start_date.strftime('%m-%d-%Y')} - {todays_date.strftime('%m-%d-%Y')}` ({days} days)*\n"

@@ -1656,13 +1670,16 @@ Model Info:
days=last_day_of_month - 1
)

monthly_spend_per_team, monthly_spend_per_tag = (
await _get_spend_report_for_time_range(
start_date=first_day_of_month.strftime("%Y-%m-%d"),
end_date=last_day_of_month.strftime("%Y-%m-%d"),
)
_resp = await _get_spend_report_for_time_range(
start_date=first_day_of_month.strftime("%Y-%m-%d"),
end_date=last_day_of_month.strftime("%Y-%m-%d"),
)

if _resp is None:
return

monthly_spend_per_team, monthly_spend_per_tag = _resp

_spend_message = f"*💸 Monthly Spend Report for `{first_day_of_month.strftime('%m-%d-%Y')} - {last_day_of_month.strftime('%m-%d-%Y')}` *\n"

if monthly_spend_per_team is not None:
|
||||
|
|
142  litellm/llms/OpenAI/gpt_transformation.py  (new file)

@@ -0,0 +1,142 @@
"""
Support for gpt model family
"""

import types
from typing import Optional, Union

import litellm
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage


class OpenAIGPTConfig:
    """
    Reference: https://platform.openai.com/docs/api-reference/chat/create

    The class `OpenAIConfig` provides configuration for the OpenAI's Chat API interface. Below are the parameters:

    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.

    - `function_call` (string or object): This optional parameter controls how the model calls functions.

    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.

    - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.

    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    frequency_penalty: Optional[int] = None
    function_call: Optional[Union[str, dict]] = None
    functions: Optional[list] = None
    logit_bias: Optional[dict] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    response_format: Optional[dict] = None

    def __init__(
        self,
        frequency_penalty: Optional[int] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        response_format: Optional[dict] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> list:
        base_params = [
            "frequency_penalty",
            "logit_bias",
            "logprobs",
            "top_logprobs",
            "max_tokens",
            "n",
            "presence_penalty",
            "seed",
            "stop",
            "stream",
            "stream_options",
            "temperature",
            "top_p",
            "tools",
            "tool_choice",
            "function_call",
            "functions",
            "max_retries",
            "extra_headers",
            "parallel_tool_calls",
        ]  # works across all models

        model_specific_params = []
        if (
            model != "gpt-3.5-turbo-16k" and model != "gpt-4"
        ):  # gpt-4 does not support 'response_format'
            model_specific_params.append("response_format")

        if (
            model in litellm.open_ai_chat_completion_models
        ) or model in litellm.open_ai_text_completion_models:
            model_specific_params.append(
                "user"
            )  # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
        return base_params + model_specific_params

    def _map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ) -> dict:
        supported_openai_params = self.get_supported_openai_params(model)
        for param, value in non_default_params.items():
            if param in supported_openai_params:
                optional_params[param] = value
        return optional_params

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ) -> dict:
        return self._map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
        )
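A quick sketch of how this config filters user-supplied params (the parameter values below are chosen for illustration):

```python
# Sketch: unsupported params are dropped, supported ones are passed through.
from litellm.llms.OpenAI.gpt_transformation import OpenAIGPTConfig

optional_params = OpenAIGPTConfig().map_openai_params(
    non_default_params={"temperature": 0.2, "max_tokens": 100, "not_a_real_param": 1},
    optional_params={},
    model="gpt-4o",
)
print(optional_params)  # {'temperature': 0.2, 'max_tokens': 100}
```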
|
|
@@ -17,13 +17,12 @@ from typing import Any, List, Optional, Union
import litellm
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage

from .openai import OpenAIConfig
from .gpt_transformation import OpenAIGPTConfig


class OpenAIO1Config(OpenAIConfig):
class OpenAIO1Config(OpenAIGPTConfig):
"""
Reference: https://platform.openai.com/docs/guides/reasoning
"""

@classmethod
@@ -50,9 +49,7 @@ class OpenAIO1Config(OpenAIConfig):
"""

all_openai_params = litellm.OpenAIConfig().get_supported_openai_params(
model="gpt-4o"
)
all_openai_params = super().get_supported_openai_params(model=model)
non_supported_params = [
"logprobs",
"tools",
@@ -69,13 +66,14 @@ class OpenAIO1Config(OpenAIConfig):
def map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_completion_tokens"] = value
return optional_params
if "max_tokens" in non_default_params:
optional_params["max_completion_tokens"] = non_default_params.pop(
"max_tokens"
)
return super()._map_openai_params(non_default_params, optional_params, model)

def is_model_o1_reasoning_model(self, model: str) -> bool:
if "o1" in model:
if model in litellm.open_ai_chat_completion_models and "o1" in model:
return True
return False

@@ -93,7 +91,7 @@ class OpenAIO1Config(OpenAIConfig):
)
messages[i] = new_message  # Replace the old message with the new one

if isinstance(message["content"], list):
if "content" in message and isinstance(message["content"], list):
new_content = []
for content_item in message["content"]:
if content_item.get("type") == "image_url":
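The practical effect of the `map_openai_params` change above is that a caller's `max_tokens` is forwarded to o1 models as `max_completion_tokens`. A minimal sketch (model name used for illustration):

```python
# Sketch: OpenAIO1Config remaps max_tokens -> max_completion_tokens.
from litellm.llms.OpenAI.o1_transformation import OpenAIO1Config

optional_params = OpenAIO1Config().map_openai_params(
    non_default_params={"max_tokens": 256},
    optional_params={},
    model="o1-preview",
)
print(optional_params)  # {'max_completion_tokens': 256}
```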
|
|
@@ -60,122 +60,6 @@ class OpenAIError(Exception):
)  # Call the base class constructor with the parameters it needs

class MistralConfig:
"""
Reference: https://docs.mistral.ai/api/

The class `MistralConfig` provides configuration for the Mistral's Chat API interface. Below are the parameters:

- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. API Default - 0.7.

- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. API Default - 1.

- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. API Default - null.

- `tools` (list or null): A list of available tools for the model. Use this to specify functions for which the model can generate JSON inputs.

- `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.

- `stop` (string or array of strings): Stop generation if this token is detected. Or if one of these tokens is detected when providing an array

- `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.

- `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.

- `response_format` (object or null): An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message.
"""

temperature: Optional[int] = None
top_p: Optional[int] = None
max_tokens: Optional[int] = None
tools: Optional[list] = None
tool_choice: Optional[Literal["auto", "any", "none"]] = None
random_seed: Optional[int] = None
safe_prompt: Optional[bool] = None
response_format: Optional[dict] = None
stop: Optional[Union[str, list]] = None

def __init__(
self,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
max_tokens: Optional[int] = None,
tools: Optional[list] = None,
tool_choice: Optional[Literal["auto", "any", "none"]] = None,
random_seed: Optional[int] = None,
safe_prompt: Optional[bool] = None,
response_format: Optional[dict] = None,
stop: Optional[Union[str, list]] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}

def get_supported_openai_params(self):
return [
"stream",
"temperature",
"top_p",
"max_tokens",
"tools",
"tool_choice",
"seed",
"stop",
"response_format",
]

def _map_tool_choice(self, tool_choice: str) -> str:
if tool_choice == "auto" or tool_choice == "none":
return tool_choice
elif tool_choice == "required":
return "any"
else:  # openai 'tool_choice' object param not supported by Mistral API
return "any"

def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "stream" and value is True:
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "stop":
optional_params["stop"] = value
if param == "tool_choice" and isinstance(value, str):
optional_params["tool_choice"] = self._map_tool_choice(
tool_choice=value
)
if param == "seed":
optional_params["extra_body"] = {"random_seed": value}
if param == "response_format":
optional_params["response_format"] = value
return optional_params

class MistralEmbeddingConfig:
"""
Reference: https://docs.mistral.ai/api/#operation/createEmbedding

@@ -526,44 +410,19 @@ class OpenAIConfig:
}

def get_supported_openai_params(self, model: str) -> list:
base_params = [
"frequency_penalty",
"logit_bias",
"logprobs",
"top_logprobs",
"max_tokens",
"n",
"presence_penalty",
"seed",
"stop",
"stream",
"stream_options",
"temperature",
"top_p",
"tools",
"tool_choice",
"function_call",
"functions",
"max_retries",
"extra_headers",
"parallel_tool_calls",
]  # works across all models

model_specific_params = []
if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
return litellm.OpenAIO1Config().get_supported_openai_params(model=model)
if (
model != "gpt-3.5-turbo-16k" and model != "gpt-4"
):  # gpt-4 does not support 'response_format'
model_specific_params.append("response_format")
else:
return litellm.OpenAIGPTConfig().get_supported_openai_params(model=model)

if (
model in litellm.open_ai_chat_completion_models
) or model in litellm.open_ai_text_completion_models:
model_specific_params.append(
"user"
)  # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
return base_params + model_specific_params
def _map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
) -> dict:
supported_openai_params = self.get_supported_openai_params(model)
for param, value in non_default_params.items():
if param in supported_openai_params:
optional_params[param] = value
return optional_params

def map_openai_params(
self, non_default_params: dict, optional_params: dict, model: str
@@ -575,11 +434,11 @@ class OpenAIConfig:
optional_params=optional_params,
model=model,
)
supported_openai_params = self.get_supported_openai_params(model)
for param, value in non_default_params.items():
if param in supported_openai_params:
optional_params[param] = value
return optional_params
return litellm.OpenAIGPTConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
)

class OpenAITextCompletionConfig:

@@ -816,18 +675,18 @@ class OpenAIChatCompletion(BaseLLM):
except Exception as e:
raise e

def completion(
def completion(  # type: ignore
self,
model_response: ModelResponse,
timeout: Union[float, httpx.Timeout],
optional_params: dict,
logging_obj: Any,
model: Optional[str] = None,
messages: Optional[list] = None,
print_verbose: Optional[Callable] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
logging_obj=None,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
@@ -858,14 +717,14 @@ class OpenAIChatCompletion(BaseLLM):
# process all OpenAI compatible provider logic here
if custom_llm_provider == "mistral":
# check if message content passed in as list, and not string
messages = prompt_factory(
messages = prompt_factory(  # type: ignore
model=model,
messages=messages,
custom_llm_provider=custom_llm_provider,
)
if custom_llm_provider == "perplexity" and messages is not None:
# check if messages.name is passed + supported, if not supported remove
messages = prompt_factory(
messages = prompt_factory(  # type: ignore
model=model,
messages=messages,
custom_llm_provider=custom_llm_provider,
@@ -933,7 +792,7 @@ class OpenAIChatCompletion(BaseLLM):
status_code=422, message="max retries must be an int"
)

openai_client = self._get_openai_client(
openai_client: OpenAI = self._get_openai_client(  # type: ignore
is_async=False,
api_key=api_key,
api_base=api_base,
@@ -1068,7 +927,7 @@ class OpenAIChatCompletion(BaseLLM):
2
):  # if call fails due to alternating messages, retry with reformatted message
try:
openai_aclient = self._get_openai_client(
openai_aclient: AsyncOpenAI = self._get_openai_client(  # type: ignore
is_async=True,
api_key=api_key,
api_base=api_base,
@@ -1156,7 +1015,7 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=None,
headers=None,
):
openai_client = self._get_openai_client(
openai_client: OpenAI = self._get_openai_client(  # type: ignore
is_async=False,
api_key=api_key,
api_base=api_base,
@@ -1210,7 +1069,7 @@ class OpenAIChatCompletion(BaseLLM):
response = None
for _ in range(2):
try:
openai_aclient = self._get_openai_client(
openai_aclient: AsyncOpenAI = self._get_openai_client(  # type: ignore
is_async=True,
api_key=api_key,
api_base=api_base,
@@ -1282,7 +1141,7 @@ class OpenAIChatCompletion(BaseLLM):
error_headers = getattr(e, "headers", None)
raise OpenAIError(
status_code=500,
message=f"{str(e)}\n\nOriginal Response: {response.text}",
message=f"{str(e)}\n\nOriginal Response: {response.text}",  # type: ignore
headers=error_headers,
)
else:
@@ -1294,7 +1153,7 @@ class OpenAIChatCompletion(BaseLLM):
)
elif hasattr(e, "status_code"):
raise OpenAIError(
status_code=e.status_code,
status_code=getattr(e, "status_code", 500),
message=str(e),
headers=error_headers,
)
@@ -1361,7 +1220,7 @@ class OpenAIChatCompletion(BaseLLM):
):
response = None
try:
openai_aclient = self._get_openai_client(
openai_aclient: AsyncOpenAI = self._get_openai_client(  # type: ignore
is_async=True,
api_key=api_key,
api_base=api_base,
@@ -1410,16 +1269,16 @@ class OpenAIChatCompletion(BaseLLM):
status_code=status_code, message=str(e), headers=error_headers
)

def embedding(
def embedding(  # type: ignore
self,
model: str,
input: list,
timeout: float,
logging_obj,
model_response: litellm.utils.EmbeddingResponse,
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
optional_params=None,
client=None,
aembedding=None,
):
@@ -1452,7 +1311,7 @@ class OpenAIChatCompletion(BaseLLM):
)
return response

openai_client = self._get_openai_client(
openai_client: OpenAI = self._get_openai_client(  # type: ignore
is_async=False,
api_key=api_key,
api_base=api_base,
@@ -1496,11 +1355,11 @@ class OpenAIChatCompletion(BaseLLM):
data: dict,
model_response: ModelResponse,
timeout: float,
logging_obj: Any,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
):
response = None
try:
@@ -1538,15 +1397,16 @@ class OpenAIChatCompletion(BaseLLM):
model: Optional[str],
prompt: str,
timeout: float,
optional_params: dict,
logging_obj: Any,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None,
logging_obj=None,
optional_params=None,
client=None,
aimg_generation=None,
):
exception_mapping_worked = False
data = {}
try:
model = model
data = {"model": model, "prompt": prompt, **optional_params}
@@ -1611,7 +1471,9 @@ class OpenAIChatCompletion(BaseLLM):
original_response=str(e),
)
if hasattr(e, "status_code"):
raise OpenAIError(status_code=e.status_code, message=str(e))
raise OpenAIError(
status_code=getattr(e, "status_code", 500), message=str(e)
)
else:
raise OpenAIError(status_code=500, message=str(e))

@@ -1661,7 +1523,7 @@ class OpenAIChatCompletion(BaseLLM):
input=input,
**optional_params,
)
return response
return response  # type: ignore

async def async_audio_speech(
self,
@@ -1784,11 +1646,8 @@ class OpenAIChatCompletion(BaseLLM):

class OpenAITextCompletion(BaseLLM):
_client_session: httpx.Client

def __init__(self) -> None:
super().__init__()
self._client_session = self.create_client_session()

def validate_environment(self, api_key):
headers = {
@@ -1806,10 +1665,10 @@ class OpenAITextCompletion(BaseLLM):
messages: list,
timeout: float,
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
print_verbose: Optional[Callable] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
optional_params=None,
litellm_params=None,
logger_fn=None,
client=None,
@@ -1921,7 +1780,7 @@ class OpenAITextCompletion(BaseLLM):
api_key: str,
model: str,
timeout: float,
max_retries=None,
max_retries: int,
organization: Optional[str] = None,
client=None,
):
@@ -2017,9 +1876,9 @@ class OpenAITextCompletion(BaseLLM):
model_response: ModelResponse,
model: str,
timeout: float,
max_retries: int,
api_base: Optional[str] = None,
client=None,
max_retries=None,
organization=None,
):
if client is None:
|
||||
|
|
12  litellm/llms/README.md  (new file)

@@ -0,0 +1,12 @@
## File Structure

### August 27th, 2024

To make it easy to see how calls are transformed for each model/provider, we are working on moving all supported LiteLLM providers to a folder structure, where the folder name is the supported LiteLLM provider name.

Each folder will contain a `*_transformation.py` file, which has all the request/response transformation logic, making it easy to see how calls are modified.

E.g. `cohere/`, `bedrock/`.
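As an illustration of this layout (the provider, file, and class names below are hypothetical, not part of this commit), a provider folder's `*_transformation.py` typically exposes a config class along these lines:

```python
# Hypothetical litellm/llms/acme/acme_chat_transformation.py - illustrates the
# per-provider transformation pattern described above; "acme" is not a real provider.
from typing import Optional


class AcmeChatConfig:
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None

    def get_supported_openai_params(self) -> list:
        # OpenAI params this provider understands
        return ["max_tokens", "temperature", "stream"]

    def map_openai_params(self, non_default_params: dict, optional_params: dict) -> dict:
        # translate OpenAI-style params into the provider's request format
        for param, value in non_default_params.items():
            if param in self.get_supported_openai_params():
                optional_params[param] = value
        return optional_params
```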
|
||||
|
|
@@ -66,22 +66,24 @@ class BaseLLM:
return _aclient_session

def __exit__(self):
if hasattr(self, "_client_session"):
if hasattr(self, "_client_session") and self._client_session is not None:
self._client_session.close()

async def __aexit__(self, exc_type, exc_val, exc_tb):
if hasattr(self, "_aclient_session"):
await self._aclient_session.aclose()
await self._aclient_session.aclose()  # type: ignore

def validate_environment(self):  # set up the environment required to run the model
pass
def validate_environment(
self, *args, **kwargs
) -> Optional[Any]:  # set up the environment required to run the model
return None

def completion(
self, *args, **kwargs
):  # logic for parsing in - calling - parsing out model completion calls
pass
) -> Any:  # logic for parsing in - calling - parsing out model completion calls
return None

def embedding(
self, *args, **kwargs
):  # logic for parsing in - calling - parsing out model embedding calls
pass
) -> Any:  # logic for parsing in - calling - parsing out model embedding calls
return None
|
||||
|
|
5
litellm/llms/mistral/chat.py
Normal file
5
litellm/llms/mistral/chat.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
"""
|
||||
Calls handled in openai/
|
||||
|
||||
as mistral is an openai-compatible endpoint.
|
||||
"""
|
5
litellm/llms/mistral/embedding.py
Normal file
5
litellm/llms/mistral/embedding.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
"""
|
||||
Calls handled in openai/
|
||||
|
||||
as mistral is an openai-compatible endpoint.
|
||||
"""
|
126
litellm/llms/mistral/mistral_chat_transformation.py
Normal file
126
litellm/llms/mistral/mistral_chat_transformation.py
Normal file
|
@ -0,0 +1,126 @@
|
|||
"""
|
||||
Transformation logic from OpenAI /v1/chat/completion format to Mistral's /chat/completion format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
Docs - https://docs.mistral.ai/api/
|
||||
"""
|
||||
|
||||
import types
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
|
||||
class MistralConfig:
|
||||
"""
|
||||
Reference: https://docs.mistral.ai/api/
|
||||
|
||||
The class `MistralConfig` provides configuration for the Mistral's Chat API interface. Below are the parameters:
|
||||
|
||||
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. API Default - 0.7.
|
||||
|
||||
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. API Default - 1.
|
||||
|
||||
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. API Default - null.
|
||||
|
||||
- `tools` (list or null): A list of available tools for the model. Use this to specify functions for which the model can generate JSON inputs.
|
||||
|
||||
- `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.
|
||||
|
||||
- `stop` (string or array of strings): Stop generation if this token is detected. Or if one of these tokens is detected when providing an array
|
||||
|
||||
- `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||
|
||||
- `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.
|
||||
|
||||
- `response_format` (object or null): An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message.
|
||||
"""
|
||||
|
||||
temperature: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
max_tokens: Optional[int] = None
|
||||
tools: Optional[list] = None
|
||||
tool_choice: Optional[Literal["auto", "any", "none"]] = None
|
||||
random_seed: Optional[int] = None
|
||||
safe_prompt: Optional[bool] = None
|
||||
response_format: Optional[dict] = None
|
||||
stop: Optional[Union[str, list]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[Literal["auto", "any", "none"]] = None,
|
||||
random_seed: Optional[int] = None,
|
||||
safe_prompt: Optional[bool] = None,
|
||||
response_format: Optional[dict] = None,
|
||||
stop: Optional[Union[str, list]] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"max_tokens",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"seed",
|
||||
"stop",
|
||||
"response_format",
|
||||
]
|
||||
|
||||
def _map_tool_choice(self, tool_choice: str) -> str:
|
||||
if tool_choice == "auto" or tool_choice == "none":
|
||||
return tool_choice
|
||||
elif tool_choice == "required":
|
||||
return "any"
|
||||
else: # openai 'tool_choice' object param not supported by Mistral API
|
||||
return "any"
|
||||
|
||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "tools":
|
||||
optional_params["tools"] = value
|
||||
if param == "stream" and value is True:
|
||||
optional_params["stream"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop"] = value
|
||||
if param == "tool_choice" and isinstance(value, str):
|
||||
optional_params["tool_choice"] = self._map_tool_choice(
|
||||
tool_choice=value
|
||||
)
|
||||
if param == "seed":
|
||||
optional_params["extra_body"] = {"random_seed": value}
|
||||
if param == "response_format":
|
||||
optional_params["response_format"] = value
|
||||
return optional_params
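A short sketch of how this transformation behaves for a few OpenAI-style params (values chosen for illustration):

```python
# Sketch: MistralConfig maps tool_choice="required" to "any" and seed to
# extra_body.random_seed, per the logic above.
from litellm.llms.mistral.mistral_chat_transformation import MistralConfig

optional_params = MistralConfig().map_openai_params(
    non_default_params={"max_tokens": 100, "tool_choice": "required", "seed": 42},
    optional_params={},
)
print(optional_params)
# {'max_tokens': 100, 'tool_choice': 'any', 'extra_body': {'random_seed': 42}}
```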
|
0
litellm/llms/mistral/mistral_embedding_transformation.py
Normal file
0
litellm/llms/mistral/mistral_embedding_transformation.py
Normal file
|
@@ -737,6 +737,7 @@ def completion(
preset_cache_key = kwargs.get("preset_cache_key", None)
hf_model_name = kwargs.get("hf_model_name", None)
supports_system_message = kwargs.get("supports_system_message", None)
base_model = kwargs.get("base_model", None)
### TEXT COMPLETION CALLS ###
text_completion = kwargs.get("text_completion", False)
atext_completion = kwargs.get("atext_completion", False)
@@ -782,11 +783,9 @@ def completion(
"top_logprobs",
"extra_headers",
]
litellm_params = (
all_litellm_params  # use the external var., used in creating cache key as well.
)

default_params = openai_params + litellm_params
default_params = openai_params + all_litellm_params
litellm_params = {}  # used to prevent unbound var errors
non_default_params = {
k: v for k, v in kwargs.items() if k not in default_params
}  # model-specific params - pass them straight to the model/provider
@@ -973,6 +972,7 @@ def completion(
text_completion=kwargs.get("text_completion"),
azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
user_continue_message=kwargs.get("user_continue_message"),
base_model=base_model,
)
logging.update_environment_variables(
model=model,
@@ -2123,7 +2123,10 @@ def completion(
timeout=timeout,
client=client,
)
elif "gemini" in model:
elif "gemini" in model or (
litellm_params.get("base_model") is not None
and "gemini" in litellm_params["base_model"]
):
model_response = vertex_chat_completion.completion(  # type: ignore
model=model,
messages=messages,
@@ -2820,7 +2823,7 @@ def completion_with_retries(*args, **kwargs):
)

num_retries = kwargs.pop("num_retries", 3)
retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry")  # type: ignore
original_function = kwargs.pop("original_function", completion)
if retry_strategy == "constant_retry":
retryer = tenacity.Retrying(
@@ -4997,7 +5000,9 @@ def speech(
async def ahealth_check(
model_params: dict,
mode: Optional[
Literal["completion", "embedding", "image_generation", "chat", "batch"]
Literal[
"completion", "embedding", "image_generation", "chat", "batch", "rerank"
]
] = None,
prompt: Optional[str] = None,
input: Optional[List] = None,
@@ -5113,6 +5118,12 @@ async def ahealth_check(
model_params["prompt"] = prompt
await litellm.aimage_generation(**model_params)
response = {}
elif mode == "rerank":
model_params.pop("messages", None)
model_params["query"] = prompt
model_params["documents"] = ["my sample text"]
await litellm.arerank(**model_params)
response = {}
elif "*" in model:
from litellm.litellm_core_utils.llm_request_utils import (
pick_cheapest_model_from_llm_provider,
|
||||
|
|
|
@@ -1,9 +1,7 @@
model_list:
  - model_name: "gpt-4o"
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-4o
      model: gpt-3.5-turbo

litellm_settings:
  cache: true
  cache_params:
    type: local
router_settings:
  model_group_alias: {"gpt-4": {"model": "gpt-3.5-turbo", "hidden": false}}
|
|
@@ -99,15 +99,15 @@ class LitellmUserRoles(str, enum.Enum):
return ui_labels.get(self.value, "")

class LitellmTableNames(str, enum.Enum):
class LitellmTableNames(enum.Enum):
"""
Enum for Table Names used by LiteLLM
"""
TEAM_TABLE_NAME: str = "LiteLLM_TeamTable"
USER_TABLE_NAME: str = "LiteLLM_UserTable"
KEY_TABLE_NAME: str = "LiteLLM_VerificationToken"
PROXY_MODEL_TABLE_NAME: str = "LiteLLM_ModelTable"
TEAM_TABLE_NAME = "LiteLLM_TeamTable"
USER_TABLE_NAME = "LiteLLM_UserTable"
KEY_TABLE_NAME = "LiteLLM_VerificationToken"
PROXY_MODEL_TABLE_NAME = "LiteLLM_ModelTable"

AlertType = Literal[
@@ -140,7 +140,7 @@ class LiteLLMBase(BaseModel):
Implements default functions, all pydantic objects should have.
"""

def json(self, **kwargs):
def json(self, **kwargs):  # type: ignore
try:
return self.model_dump(**kwargs)  # noqa
except Exception as e:
@@ -170,7 +170,7 @@ class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):

class LiteLLMRoutes(enum.Enum):
openai_route_names: List = [
openai_route_names = [
"chat_completion",
"completion",
"embeddings",
@@ -179,7 +179,7 @@ class LiteLLMRoutes(enum.Enum):
"moderations",
"model_list",  # OpenAI /v1/models route
]
openai_routes: List = [
openai_routes = [
# chat completions
"/engines/{model}/chat/completions",
"/openai/deployments/{model}/chat/completions",
@@ -247,18 +247,18 @@ class LiteLLMRoutes(enum.Enum):
"/v1/rerank",
]

mapped_pass_through_routes: List = [
mapped_pass_through_routes = [
"/bedrock",
"/vertex-ai",
"/gemini",
"/langfuse",
]

anthropic_routes: List = [
anthropic_routes = [
"/v1/messages",
]

info_routes: List = [
info_routes = [
"/key/info",
"/team/info",
"/team/list",
@@ -271,9 +271,9 @@ class LiteLLMRoutes(enum.Enum):
]

# NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
master_key_only_routes: List = ["/global/spend/reset", "/key/list"]
master_key_only_routes = ["/global/spend/reset", "/key/list"]

sso_only_routes: List = [
sso_only_routes = [
"/key/generate",
"/key/update",
"/key/delete",
@@ -282,7 +282,7 @@ class LiteLLMRoutes(enum.Enum):
"/sso/get/logout_url",
]

management_routes: List = [  # key
management_routes = [  # key
"/key/generate",
"/key/update",
"/key/delete",
@@ -307,7 +307,7 @@ class LiteLLMRoutes(enum.Enum):
"/model/info",
]

spend_tracking_routes: List = [
spend_tracking_routes = [
# spend
"/spend/keys",
"/spend/users",
@@ -316,7 +316,7 @@ class LiteLLMRoutes(enum.Enum):
"/spend/logs",
]

global_spend_tracking_routes: List = [
global_spend_tracking_routes = [
# global spend
"/global/spend/logs",
"/global/spend",
@@ -328,7 +328,7 @@ class LiteLLMRoutes(enum.Enum):
"/global/spend/report",
]

public_routes: List = [
public_routes = [
"/routes",
"/",
"/health/liveliness",
@@ -339,7 +339,7 @@ class LiteLLMRoutes(enum.Enum):
"/metrics",
]

internal_user_routes: List = (
internal_user_routes = (
[
"/key/generate",
"/key/update",
@@ -357,7 +357,7 @@ class LiteLLMRoutes(enum.Enum):
+ sso_only_routes
)

self_managed_routes: List = [
self_managed_routes = [
"/team/member_add",
"/team/member_delete",
]  # routes that manage their own allowed/disallowed logic
@@ -581,7 +581,9 @@ class ModelParams(LiteLLMBase):
@classmethod
def set_model_info(cls, values):
if values.get("model_info") is None:
values.update({"model_info": ModelInfo()})
values.update(
{"model_info": ModelInfo(id=None, mode="chat", base_model=None)}
)
return values

@@ -627,7 +629,7 @@ class GenerateKeyRequest(_GenerateKeyRequest):

class GenerateKeyResponse(_GenerateKeyRequest):
key: str
key: str  # type: ignore
key_name: Optional[str] = None
expires: Optional[datetime]
user_id: Optional[str] = None
@@ -659,7 +661,7 @@ class GenerateKeyResponse(_GenerateKeyRequest):
class UpdateKeyRequest(GenerateKeyRequest):
# Note: the defaults of all Params here MUST BE NONE
# else they will get overwritten
key: str
key: str  # type: ignore
duration: Optional[str] = None
spend: Optional[float] = None
metadata: Optional[dict] = None
@@ -976,6 +978,7 @@ class TeamCallbackMetadata(LiteLLMBase):

class LiteLLM_TeamTable(TeamBase):
team_id: str  # type: ignore
spend: Optional[float] = None
max_parallel_requests: Optional[int] = None
budget_duration: Optional[str] = None
@@ -1061,7 +1064,7 @@ class LiteLLM_OrganizationTable(LiteLLMBase):

class NewOrganizationResponse(LiteLLM_OrganizationTable):
organization_id: str
organization_id: str  # type: ignore
created_at: datetime
updated_at: datetime

@@ -1388,16 +1391,7 @@ class UserAPIKeyAuth(
"""

api_key: Optional[str] = None
user_role: Optional[
Literal[
LitellmUserRoles.PROXY_ADMIN,
LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
LitellmUserRoles.INTERNAL_USER,
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
LitellmUserRoles.TEAM,
LitellmUserRoles.CUSTOMER,
]
] = None
user_role: Optional[LitellmUserRoles] = None
allowed_model_region: Optional[Literal["eu"]] = None
parent_otel_span: Optional[Span] = None
rpm_limit_per_model: Optional[Dict[str, int]] = None
@@ -1716,9 +1710,9 @@ class SpendLogsPayload(TypedDict):
total_tokens: int
prompt_tokens: int
completion_tokens: int
startTime: datetime
endTime: datetime
completionStartTime: Optional[datetime]
startTime: Union[datetime, str]
endTime: Union[datetime, str]
completionStartTime: Optional[Union[datetime, str]]
model: str
model_id: Optional[str]
model_group: Optional[str]
@@ -1891,6 +1885,6 @@ class TeamAddMemberResponse(LiteLLM_TeamTable):

class TeamInfoResponseObject(TypedDict):
team_id: str
team_info: TeamBase
team_info: LiteLLM_TeamTable
keys: List
team_memberships: List[LiteLLM_TeamMembership]
|
||||
|
|
|
@ -109,11 +109,8 @@ async def user_api_key_auth(
|
|||
),
|
||||
) -> UserAPIKeyAuth:
|
||||
from litellm.proxy.proxy_server import (
|
||||
allowed_routes_check,
|
||||
common_checks,
|
||||
custom_db_client,
|
||||
general_settings,
|
||||
get_actual_routes,
|
||||
jwt_handler,
|
||||
litellm_proxy_admin_name,
|
||||
llm_model_list,
|
||||
|
@ -125,6 +122,8 @@ async def user_api_key_auth(
|
|||
user_custom_auth,
|
||||
)
|
||||
|
||||
parent_otel_span: Optional[Span] = None
|
||||
|
||||
try:
|
||||
route: str = get_request_route(request=request)
|
||||
# get the request body
|
||||
|
@ -137,6 +136,7 @@ async def user_api_key_auth(
|
|||
pass_through_endpoints: Optional[List[dict]] = general_settings.get(
|
||||
"pass_through_endpoints", None
|
||||
)
|
||||
passed_in_key: Optional[str] = None
|
||||
if isinstance(api_key, str):
|
||||
passed_in_key = api_key
|
||||
api_key = _get_bearer_token(api_key=api_key)
|
||||
|
@ -161,7 +161,6 @@ async def user_api_key_auth(
|
|||
custom_litellm_key_header_name=custom_litellm_key_header_name,
|
||||
)
|
||||
|
||||
parent_otel_span: Optional[Span] = None
|
||||
if open_telemetry_logger is not None:
|
||||
parent_otel_span = open_telemetry_logger.tracer.start_span(
|
||||
name="Received Proxy Server Request",
|
||||
|
@ -189,7 +188,7 @@ async def user_api_key_auth(
|
|||
|
||||
######## Route Checks Before Reading DB / Cache for "token" ################
|
||||
if (
|
||||
route in LiteLLMRoutes.public_routes.value
|
||||
route in LiteLLMRoutes.public_routes.value # type: ignore
|
||||
or route_in_additonal_public_routes(current_route=route)
|
||||
):
|
||||
# check if public endpoint
|
||||
|
@ -410,7 +409,7 @@ async def user_api_key_auth(
|
|||
#### ELSE ####
|
||||
## CHECK PASS-THROUGH ENDPOINTS ##
|
||||
is_mapped_pass_through_route: bool = False
|
||||
for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value:
|
||||
for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value: # type: ignore
|
||||
if route.startswith(mapped_route):
|
||||
is_mapped_pass_through_route = True
|
||||
if is_mapped_pass_through_route:
|
||||
|
@ -444,9 +443,9 @@ async def user_api_key_auth(
|
|||
header_key = headers.get("litellm_user_api_key", "")
|
||||
if (
|
||||
isinstance(request.headers, dict)
|
||||
and request.headers.get(key=header_key) is not None
|
||||
and request.headers.get(key=header_key) is not None # type: ignore
|
||||
):
|
||||
api_key = request.headers.get(key=header_key)
|
||||
api_key = request.headers.get(key=header_key) # type: ignore
|
||||
|
||||
if master_key is None:
|
||||
if isinstance(api_key, str):
|
||||
|
@ -606,7 +605,7 @@ async def user_api_key_auth(
|
|||
|
||||
## IF it's not a master key
|
||||
## Route should not be in master_key_only_routes
|
||||
if route in LiteLLMRoutes.master_key_only_routes.value:
|
||||
if route in LiteLLMRoutes.master_key_only_routes.value: # type: ignore
|
||||
raise Exception(
|
||||
f"Tried to access route={route}, which is only for MASTER KEY"
|
||||
)
|
||||
|
@ -669,8 +668,9 @@ async def user_api_key_auth(
|
|||
"allowed_model_region"
|
||||
)
|
||||
|
||||
user_obj: Optional[LiteLLM_UserTable] = None
|
||||
valid_token_dict: dict = {}
|
||||
if valid_token is not None:
|
||||
user_obj: Optional[LiteLLM_UserTable] = None
|
||||
# Got Valid Token from Cache, DB
|
||||
# Run checks for
|
||||
# 1. If token can call model
|
||||
|
@ -686,6 +686,7 @@ async def user_api_key_auth(
|
|||
|
||||
# Check 1. If token can call model
|
||||
_model_alias_map = {}
|
||||
model: Optional[str] = None
|
||||
if (
|
||||
hasattr(valid_token, "team_model_aliases")
|
||||
and valid_token.team_model_aliases is not None
|
||||
|
@ -698,6 +699,7 @@ async def user_api_key_auth(
|
|||
_model_alias_map = {**valid_token.aliases}
|
||||
litellm.model_alias_map = _model_alias_map
|
||||
config = valid_token.config
|
||||
|
||||
if config != {}:
|
||||
model_list = config.get("model_list", [])
|
||||
llm_model_list = model_list
|
||||
|
@@ -887,7 +889,10 @@ async def user_api_key_auth(
                     and max_budget_per_model.get(current_model, None) is not None
                 ):
                     if (
-                        model_spend[0]["model"] == current_model
+                        "model" in model_spend[0]
+                        and model_spend[0].get("model") == current_model
+                        and "_sum" in model_spend[0]
+                        and "spend" in model_spend[0]["_sum"]
                         and model_spend[0]["_sum"]["spend"]
                         >= max_budget_per_model[current_model]
                     ):
@@ -927,16 +932,19 @@ async def user_api_key_auth(
                 )

             # Check 8: Additional Common Checks across jwt + key auth
-            _team_obj = LiteLLM_TeamTable(
-                team_id=valid_token.team_id,
-                max_budget=valid_token.team_max_budget,
-                spend=valid_token.team_spend,
-                tpm_limit=valid_token.team_tpm_limit,
-                rpm_limit=valid_token.team_rpm_limit,
-                blocked=valid_token.team_blocked,
-                models=valid_token.team_models,
-                metadata=valid_token.team_metadata,
-            )
+            if valid_token.team_id is not None:
+                _team_obj: Optional[LiteLLM_TeamTable] = LiteLLM_TeamTable(
+                    team_id=valid_token.team_id,
+                    max_budget=valid_token.team_max_budget,
+                    spend=valid_token.team_spend,
+                    tpm_limit=valid_token.team_tpm_limit,
+                    rpm_limit=valid_token.team_rpm_limit,
+                    blocked=valid_token.team_blocked,
+                    models=valid_token.team_models,
+                    metadata=valid_token.team_metadata,
+                )
+            else:
+                _team_obj = None

             user_api_key_cache.set_cache(
                 key=valid_token.team_id, value=_team_obj
@ -1045,7 +1053,7 @@ async def user_api_key_auth(
|
|||
"/global/predict/spend/logs",
|
||||
"/global/activity",
|
||||
"/health/services",
|
||||
] + LiteLLMRoutes.info_routes.value
|
||||
] + LiteLLMRoutes.info_routes.value # type: ignore
|
||||
# check if the current route startswith any of the allowed routes
|
||||
if (
|
||||
route is not None
|
||||
|
@ -1106,7 +1114,7 @@ async def user_api_key_auth(
|
|||
|
||||
# Log this exception to OTEL
|
||||
if open_telemetry_logger is not None:
|
||||
await open_telemetry_logger.async_post_call_failure_hook(
|
||||
await open_telemetry_logger.async_post_call_failure_hook( # type: ignore
|
||||
original_exception=e,
|
||||
user_api_key_dict=UserAPIKeyAuth(parent_otel_span=parent_otel_span),
|
||||
)
|
||||
|
|
|
@ -4,10 +4,11 @@ import json
|
|||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import fastapi
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
@ -28,6 +29,7 @@ from litellm.proxy._types import (
|
|||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
TeamAddMemberResponse,
|
||||
TeamBase,
|
||||
TeamInfoResponseObject,
|
||||
TeamMemberAddRequest,
|
||||
TeamMemberDeleteRequest,
|
||||
|
@ -36,6 +38,7 @@ from litellm.proxy._types import (
|
|||
UpdateTeamRequest,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.auth_checks import get_team_object
|
||||
from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin, user_api_key_auth
|
||||
from litellm.proxy.management_helpers.utils import (
|
||||
add_new_member,
|
||||
|
@ -240,7 +243,7 @@ async def new_team(
|
|||
reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||
complete_team_data.budget_reset_at = reset_at
|
||||
|
||||
team_row = await prisma_client.insert_data(
|
||||
team_row: LiteLLM_TeamTable = await prisma_client.insert_data( # type: ignore
|
||||
data=complete_team_data.json(exclude_none=True), table_name="team"
|
||||
)
|
||||
|
||||
|
@ -462,9 +465,6 @@ async def team_member_add(
|
|||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import (
|
||||
_duration_in_seconds,
|
||||
create_audit_log_for_update,
|
||||
get_team_object,
|
||||
litellm_proxy_admin_name,
|
||||
prisma_client,
|
||||
proxy_logging_obj,
|
||||
|
@ -932,10 +932,13 @@ async def delete_team(
|
|||
if litellm.store_audit_logs is True:
|
||||
# make an audit log for each team deleted
|
||||
for team_id in data.team_ids:
|
||||
team_row = await prisma_client.get_data( # type: ignore
|
||||
team_row: Optional[LiteLLM_TeamTable] = await prisma_client.get_data( # type: ignore
|
||||
team_id=team_id, table_name="team", query_type="find_unique"
|
||||
)
|
||||
|
||||
if team_row is None:
|
||||
continue
|
||||
|
||||
_team_row = team_row.json(exclude_none=True)
|
||||
|
||||
asyncio.create_task(
|
||||
|
@ -1027,8 +1030,10 @@ async def team_info(
|
|||
),
|
||||
)
|
||||
|
||||
team_info = await prisma_client.get_data(
|
||||
team_id=team_id, table_name="team", query_type="find_unique"
|
||||
team_info: Optional[Union[LiteLLM_TeamTable, dict]] = (
|
||||
await prisma_client.get_data(
|
||||
team_id=team_id, table_name="team", query_type="find_unique"
|
||||
)
|
||||
)
|
||||
if team_info is None:
|
||||
raise HTTPException(
|
||||
|
@ -1044,6 +1049,9 @@ async def team_info(
|
|||
expires=datetime.now(),
|
||||
)
|
||||
|
||||
if keys is None:
|
||||
keys = []
|
||||
|
||||
if team_info is None:
|
||||
## make sure we still return a total spend ##
|
||||
spend = 0
|
||||
|
@ -1055,7 +1063,7 @@ async def team_info(
|
|||
for key in keys:
|
||||
try:
|
||||
key = key.model_dump() # noqa
|
||||
except:
|
||||
except Exception:
|
||||
# if using pydantic v1
|
||||
key = key.dict()
|
||||
key.pop("token", None)
|
||||
|
@@ -1070,9 +1078,16 @@ async def team_info(
         for tm in team_memberships:
             returned_tm.append(LiteLLM_TeamMembership(**tm.model_dump()))

+        if isinstance(team_info, dict):
+            _team_info = LiteLLM_TeamTable(**team_info)
+        elif isinstance(team_info, BaseModel):
+            _team_info = LiteLLM_TeamTable(**team_info.model_dump())
+        else:
+            _team_info = LiteLLM_TeamTable()
+
         response_object = TeamInfoResponseObject(
             team_id=team_id,
-            team_info=team_info,
+            team_info=_team_info,
             keys=keys,
             team_memberships=returned_tm,
         )

@ -125,16 +125,7 @@ from litellm.proxy._types import *
|
|||
from litellm.proxy.analytics_endpoints.analytics_endpoints import (
|
||||
router as analytics_router,
|
||||
)
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
allowed_routes_check,
|
||||
common_checks,
|
||||
get_actual_routes,
|
||||
get_end_user_object,
|
||||
get_org_object,
|
||||
get_team_object,
|
||||
get_user_object,
|
||||
log_to_opentelemetry,
|
||||
)
|
||||
from litellm.proxy.auth.auth_checks import log_to_opentelemetry
|
||||
from litellm.proxy.auth.auth_utils import check_response_size_is_safe
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.auth.litellm_license import LicenseCheck
|
||||
|
@ -260,6 +251,7 @@ from litellm.secret_managers.aws_secret_manager import (
|
|||
load_aws_secret_manager,
|
||||
)
|
||||
from litellm.secret_managers.google_kms import load_google_kms
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.llms.anthropic import (
|
||||
AnthropicMessagesRequest,
|
||||
AnthropicResponse,
|
||||
|
@ -484,7 +476,7 @@ general_settings: dict = {}
|
|||
callback_settings: dict = {}
|
||||
log_file = "api_log.json"
|
||||
worker_config = None
|
||||
master_key = None
|
||||
master_key: Optional[str] = None
|
||||
otel_logging = False
|
||||
prisma_client: Optional[PrismaClient] = None
|
||||
custom_db_client: Optional[DBClient] = None
|
||||
|
@@ -874,7 +866,9 @@ def error_tracking():


 def _set_spend_logs_payload(
-    payload: dict, prisma_client: PrismaClient, spend_logs_url: Optional[str] = None
+    payload: Union[dict, SpendLogsPayload],
+    prisma_client: PrismaClient,
+    spend_logs_url: Optional[str] = None,
 ):
     if prisma_client is not None and spend_logs_url is not None:
         if isinstance(payload["startTime"], datetime):
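Note: the `Union[dict, SpendLogsPayload]` signature above exists because the payload's timestamps may still be `datetime` objects when this helper runs and get turned into strings before the payload is shipped to an external `spend_logs_url`. A minimal sketch of that conversion, assuming ISO-8601 serialization; the helper name `_serialize_payload_times` is hypothetical and not part of this diff:

```python
from datetime import datetime, timezone


def _serialize_payload_times(payload: dict) -> dict:
    # Hypothetical helper: mirrors the isinstance(payload["startTime"], datetime)
    # check in _set_spend_logs_payload above, converting datetimes to ISO strings
    # so the payload stays JSON-serializable.
    for key in ("startTime", "endTime", "completionStartTime"):
        value = payload.get(key)
        if isinstance(value, datetime):
            payload[key] = value.isoformat()
    return payload


payload = {
    "startTime": datetime.now(timezone.utc),
    "endTime": datetime.now(timezone.utc),
    "completionStartTime": None,
}
print(_serialize_payload_times(payload))  # timestamps are now plain strings
```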
@@ -1341,6 +1335,9 @@ async def _run_background_health_check():
     # make 1 deep copy of llm_model_list -> use this for all background health checks
     _llm_model_list = copy.deepcopy(llm_model_list)

+    if _llm_model_list is None:
+        return
+
     while True:
         healthy_endpoints, unhealthy_endpoints = await perform_health_check(
             model_list=_llm_model_list, details=health_check_details
@@ -1352,7 +1349,10 @@ async def _run_background_health_check():
         health_check_results["healthy_count"] = len(healthy_endpoints)
         health_check_results["unhealthy_count"] = len(unhealthy_endpoints)

-        await asyncio.sleep(health_check_interval)
+        if health_check_interval is not None and isinstance(
+            health_check_interval, float
+        ):
+            await asyncio.sleep(health_check_interval)


 class ProxyConfig:
@ -1467,7 +1467,7 @@ class ProxyConfig:
|
|||
break
|
||||
for k, v in team_config.items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
team_config[k] = litellm.get_secret(v)
|
||||
team_config[k] = get_secret(v)
|
||||
return team_config
|
||||
|
||||
def _init_cache(
|
||||
|
@ -1513,6 +1513,9 @@ class ProxyConfig:
|
|||
config = get_file_contents_from_s3(
|
||||
bucket_name=bucket_name, object_key=object_key
|
||||
)
|
||||
|
||||
if config is None:
|
||||
raise Exception("Unable to load config from given source.")
|
||||
else:
|
||||
# default to file
|
||||
config = await self.get_config(config_file_path=config_file_path)
|
||||
|
@ -1528,9 +1531,7 @@ class ProxyConfig:
|
|||
environment_variables = config.get("environment_variables", None)
|
||||
if environment_variables:
|
||||
for key, value in environment_variables.items():
|
||||
os.environ[key] = str(
|
||||
litellm.get_secret(secret_name=key, default_value=value)
|
||||
)
|
||||
os.environ[key] = str(get_secret(secret_name=key, default_value=value))
|
||||
|
||||
# check if litellm_license in general_settings
|
||||
if "LITELLM_LICENSE" in environment_variables:
|
||||
|
@ -1566,8 +1567,8 @@ class ProxyConfig:
|
|||
if (
|
||||
cache_type == "redis" or cache_type == "redis-semantic"
|
||||
) and len(cache_params.keys()) == 0:
|
||||
cache_host = litellm.get_secret("REDIS_HOST", None)
|
||||
cache_port = litellm.get_secret("REDIS_PORT", None)
|
||||
cache_host = get_secret("REDIS_HOST", None)
|
||||
cache_port = get_secret("REDIS_PORT", None)
|
||||
cache_password = None
|
||||
cache_params.update(
|
||||
{
|
||||
|
@ -1577,8 +1578,8 @@ class ProxyConfig:
|
|||
}
|
||||
)
|
||||
|
||||
if litellm.get_secret("REDIS_PASSWORD", None) is not None:
|
||||
cache_password = litellm.get_secret("REDIS_PASSWORD", None)
|
||||
if get_secret("REDIS_PASSWORD", None) is not None:
|
||||
cache_password = get_secret("REDIS_PASSWORD", None)
|
||||
cache_params.update(
|
||||
{
|
||||
"password": cache_password,
|
||||
|
@ -1617,7 +1618,7 @@ class ProxyConfig:
|
|||
# users can pass os.environ/ variables on the proxy - we should read them from the env
|
||||
for key, value in cache_params.items():
|
||||
if type(value) is str and value.startswith("os.environ/"):
|
||||
cache_params[key] = litellm.get_secret(value)
|
||||
cache_params[key] = get_secret(value)
|
||||
|
||||
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
|
||||
self._init_cache(cache_params=cache_params)
|
||||
|
@ -1738,7 +1739,7 @@ class ProxyConfig:
|
|||
if value is not None and isinstance(value, dict):
|
||||
for _k, _v in value.items():
|
||||
if isinstance(_v, str) and _v.startswith("os.environ/"):
|
||||
value[_k] = litellm.get_secret(_v)
|
||||
value[_k] = get_secret(_v)
|
||||
litellm.upperbound_key_generate_params = (
|
||||
LiteLLM_UpperboundKeyGenerateParams(**value)
|
||||
)
|
||||
|
@ -1812,15 +1813,15 @@ class ProxyConfig:
|
|||
database_url = general_settings.get("database_url", None)
|
||||
if database_url and database_url.startswith("os.environ/"):
|
||||
verbose_proxy_logger.debug("GOING INTO LITELLM.GET_SECRET!")
|
||||
database_url = litellm.get_secret(database_url)
|
||||
database_url = get_secret(database_url)
|
||||
verbose_proxy_logger.debug("RETRIEVED DB URL: %s", database_url)
|
||||
### MASTER KEY ###
|
||||
master_key = general_settings.get(
|
||||
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
|
||||
"master_key", get_secret("LITELLM_MASTER_KEY", None)
|
||||
)
|
||||
|
||||
if master_key and master_key.startswith("os.environ/"):
|
||||
master_key = litellm.get_secret(master_key)
|
||||
master_key = get_secret(master_key) # type: ignore
|
||||
if not isinstance(master_key, str):
|
||||
raise Exception(
|
||||
"Master key must be a string. Current type - {}".format(
|
||||
|
@ -1861,33 +1862,6 @@ class ProxyConfig:
|
|||
await initialize_pass_through_endpoints(
|
||||
pass_through_endpoints=general_settings["pass_through_endpoints"]
|
||||
)
|
||||
## dynamodb
|
||||
database_type = general_settings.get("database_type", None)
|
||||
if database_type is not None and (
|
||||
database_type == "dynamo_db" or database_type == "dynamodb"
|
||||
):
|
||||
database_args = general_settings.get("database_args", None)
|
||||
### LOAD FROM os.environ/ ###
|
||||
for k, v in database_args.items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
database_args[k] = litellm.get_secret(v)
|
||||
if isinstance(k, str) and k == "aws_web_identity_token":
|
||||
value = database_args[k]
|
||||
verbose_proxy_logger.debug(
|
||||
f"Loading AWS Web Identity Token from file: {value}"
|
||||
)
|
||||
if os.path.exists(value):
|
||||
with open(value, "r") as file:
|
||||
token_content = file.read()
|
||||
database_args[k] = token_content
|
||||
else:
|
||||
verbose_proxy_logger.info(
|
||||
f"DynamoDB Loading - {value} is not a valid file path"
|
||||
)
|
||||
verbose_proxy_logger.debug("database_args: %s", database_args)
|
||||
custom_db_client = DBClient(
|
||||
custom_db_args=database_args, custom_db_type=database_type
|
||||
)
|
||||
## ADMIN UI ACCESS ##
|
||||
ui_access_mode = general_settings.get(
|
||||
"ui_access_mode", "all"
|
||||
|
@ -1951,7 +1925,7 @@ class ProxyConfig:
|
|||
### LOAD FROM os.environ/ ###
|
||||
for k, v in model["litellm_params"].items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
model["litellm_params"][k] = litellm.get_secret(v)
|
||||
model["litellm_params"][k] = get_secret(v)
|
||||
print(f"\033[32m {model.get('model_name', '')}\033[0m") # noqa
|
||||
litellm_model_name = model["litellm_params"]["model"]
|
||||
litellm_model_api_base = model["litellm_params"].get("api_base", None)
|
||||
|
@ -2005,7 +1979,10 @@ class ProxyConfig:
|
|||
) # type:ignore
|
||||
|
||||
# Guardrail settings
|
||||
guardrails_v2 = config.get("guardrails", None)
|
||||
guardrails_v2: Optional[dict] = None
|
||||
|
||||
if config is not None:
|
||||
guardrails_v2 = config.get("guardrails", None)
|
||||
if guardrails_v2:
|
||||
init_guardrails_v2(
|
||||
all_guardrails=guardrails_v2, config_file_path=config_file_path
|
||||
|
@ -2074,7 +2051,7 @@ class ProxyConfig:
|
|||
### LOAD FROM os.environ/ ###
|
||||
for k, v in model["litellm_params"].items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
model["litellm_params"][k] = litellm.get_secret(v)
|
||||
model["litellm_params"][k] = get_secret(v)
|
||||
|
||||
## check if they have model-id's ##
|
||||
model_id = model.get("model_info", {}).get("id", None)
|
||||
|
@ -2234,7 +2211,8 @@ class ProxyConfig:
|
|||
for k, v in environment_variables.items():
|
||||
try:
|
||||
decrypted_value = decrypt_value_helper(value=v)
|
||||
os.environ[k] = decrypted_value
|
||||
if decrypted_value is not None:
|
||||
os.environ[k] = decrypted_value
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
"Error setting env variable: %s - %s", k, str(e)
|
||||
|
@ -2536,7 +2514,7 @@ async def async_assistants_data_generator(
|
|||
)
|
||||
|
||||
# chunk = chunk.model_dump_json(exclude_none=True)
|
||||
async for c in chunk:
|
||||
async for c in chunk: # type: ignore
|
||||
c = c.model_dump_json(exclude_none=True)
|
||||
try:
|
||||
yield f"data: {c}\n\n"
|
||||
|
@ -2745,17 +2723,22 @@ async def startup_event():
|
|||
|
||||
### LOAD MASTER KEY ###
|
||||
# check if master key set in environment - load from there
|
||||
master_key = litellm.get_secret("LITELLM_MASTER_KEY", None)
|
||||
master_key = get_secret("LITELLM_MASTER_KEY", None) # type: ignore
|
||||
# check if DATABASE_URL in environment - load from there
|
||||
if prisma_client is None:
|
||||
prisma_setup(database_url=litellm.get_secret("DATABASE_URL", None))
|
||||
_db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore
|
||||
prisma_setup(database_url=_db_url)
|
||||
|
||||
### LOAD CONFIG ###
|
||||
worker_config = litellm.get_secret("WORKER_CONFIG")
|
||||
worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore
|
||||
verbose_proxy_logger.debug("worker_config: %s", worker_config)
|
||||
# check if it's a valid file path
|
||||
if os.path.isfile(worker_config):
|
||||
if proxy_config.is_yaml(config_file_path=worker_config):
|
||||
if worker_config is not None:
|
||||
if (
|
||||
isinstance(worker_config, str)
|
||||
and os.path.isfile(worker_config)
|
||||
and proxy_config.is_yaml(config_file_path=worker_config)
|
||||
):
|
||||
(
|
||||
llm_router,
|
||||
llm_model_list,
|
||||
|
@ -2763,21 +2746,23 @@ async def startup_event():
|
|||
) = await proxy_config.load_config(
|
||||
router=llm_router, config_file_path=worker_config
|
||||
)
|
||||
else:
|
||||
elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance(
|
||||
worker_config, str
|
||||
):
|
||||
(
|
||||
llm_router,
|
||||
llm_model_list,
|
||||
general_settings,
|
||||
) = await proxy_config.load_config(
|
||||
router=llm_router, config_file_path=worker_config
|
||||
)
|
||||
elif isinstance(worker_config, dict):
|
||||
await initialize(**worker_config)
|
||||
elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
|
||||
(
|
||||
llm_router,
|
||||
llm_model_list,
|
||||
general_settings,
|
||||
) = await proxy_config.load_config(
|
||||
router=llm_router, config_file_path=worker_config
|
||||
)
|
||||
|
||||
else:
|
||||
# if not, assume it's a json string
|
||||
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
||||
await initialize(**worker_config)
|
||||
else:
|
||||
# if not, assume it's a json string
|
||||
worker_config = json.loads(worker_config)
|
||||
if isinstance(worker_config, dict):
|
||||
await initialize(**worker_config)
|
||||
|
||||
## CHECK PREMIUM USER
|
||||
verbose_proxy_logger.debug(
|
||||
|
@ -2825,7 +2810,7 @@ async def startup_event():
|
|||
if general_settings.get("litellm_jwtauth", None) is not None:
|
||||
for k, v in general_settings["litellm_jwtauth"].items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
general_settings["litellm_jwtauth"][k] = litellm.get_secret(v)
|
||||
general_settings["litellm_jwtauth"][k] = get_secret(v)
|
||||
litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"])
|
||||
else:
|
||||
litellm_jwtauth = LiteLLM_JWTAuth()
|
||||
|
@ -2948,8 +2933,7 @@ async def startup_event():
|
|||
|
||||
### ADD NEW MODELS ###
|
||||
store_model_in_db = (
|
||||
litellm.get_secret("STORE_MODEL_IN_DB", store_model_in_db)
|
||||
or store_model_in_db
|
||||
get_secret("STORE_MODEL_IN_DB", store_model_in_db) or store_model_in_db
|
||||
) # type: ignore
|
||||
if store_model_in_db == True:
|
||||
scheduler.add_job(
|
||||
|
@ -3498,7 +3482,7 @@ async def completion(
|
|||
)
|
||||
### CALL HOOKS ### - modify outgoing data
|
||||
response = await proxy_logging_obj.post_call_success_hook(
|
||||
data=data, user_api_key_dict=user_api_key_dict, response=response
|
||||
data=data, user_api_key_dict=user_api_key_dict, response=response # type: ignore
|
||||
)
|
||||
|
||||
fastapi_response.headers.update(
|
||||
|
@ -4000,7 +3984,7 @@ async def audio_speech(
|
|||
request_data=data,
|
||||
)
|
||||
return StreamingResponse(
|
||||
generate(response), media_type="audio/mpeg", headers=custom_headers
|
||||
generate(response), media_type="audio/mpeg", headers=custom_headers # type: ignore
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
@ -4288,6 +4272,7 @@ async def create_assistant(
|
|||
API Reference docs - https://platform.openai.com/docs/api-reference/assistants/createAssistant
|
||||
"""
|
||||
global proxy_logging_obj
|
||||
data = {} # ensure data always dict
|
||||
try:
|
||||
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
||||
body = await request.body()
|
||||
|
@ -7642,6 +7627,7 @@ async def model_group_info(
|
|||
)
|
||||
|
||||
model_groups: List[ModelGroupInfo] = []
|
||||
|
||||
for model in all_models_str:
|
||||
|
||||
_model_group_info = llm_router.get_model_group_info(model_group=model)
|
||||
|
@ -8051,7 +8037,8 @@ async def google_login(request: Request):
|
|||
with microsoft_sso:
|
||||
return await microsoft_sso.get_login_redirect()
|
||||
elif generic_client_id is not None:
|
||||
from fastapi_sso.sso.generic import DiscoveryDocument, create_provider
|
||||
from fastapi_sso.sso.base import DiscoveryDocument
|
||||
from fastapi_sso.sso.generic import create_provider
|
||||
|
||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||
|
@ -8616,6 +8603,8 @@ async def auth_callback(request: Request):
|
|||
redirect_url += "sso/callback"
|
||||
else:
|
||||
redirect_url += "/sso/callback"
|
||||
|
||||
result = None
|
||||
if google_client_id is not None:
|
||||
from fastapi_sso.sso.google import GoogleSSO
|
||||
|
||||
|
@ -8662,7 +8651,8 @@ async def auth_callback(request: Request):
|
|||
result = await microsoft_sso.verify_and_process(request)
|
||||
elif generic_client_id is not None:
|
||||
# make generic sso provider
|
||||
from fastapi_sso.sso.generic import DiscoveryDocument, OpenID, create_provider
|
||||
from fastapi_sso.sso.base import DiscoveryDocument, OpenID
|
||||
from fastapi_sso.sso.generic import create_provider
|
||||
|
||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||
|
@ -8766,8 +8756,8 @@ async def auth_callback(request: Request):
|
|||
verbose_proxy_logger.debug("generic result: %s", result)
|
||||
|
||||
# User is Authe'd in - generate key for the UI to access Proxy
|
||||
user_email = getattr(result, "email", None)
|
||||
user_id = getattr(result, "id", None)
|
||||
user_email: Optional[str] = getattr(result, "email", None)
|
||||
user_id: Optional[str] = getattr(result, "id", None) if result is not None else None
|
||||
|
||||
if user_email is not None and os.getenv("ALLOWED_EMAIL_DOMAINS") is not None:
|
||||
email_domain = user_email.split("@")[1]
|
||||
|
@ -8783,12 +8773,12 @@ async def auth_callback(request: Request):
|
|||
)
|
||||
|
||||
# generic client id
|
||||
if generic_client_id is not None:
|
||||
if generic_client_id is not None and result is not None:
|
||||
user_id = getattr(result, "id", None)
|
||||
user_email = getattr(result, "email", None)
|
||||
user_role = getattr(result, generic_user_role_attribute_name, None)
|
||||
user_role = getattr(result, generic_user_role_attribute_name, None) # type: ignore
|
||||
|
||||
if user_id is None:
|
||||
if user_id is None and result is not None:
|
||||
_first_name = getattr(result, "first_name", "") or ""
|
||||
_last_name = getattr(result, "last_name", "") or ""
|
||||
user_id = _first_name + _last_name
|
||||
|
@ -8811,54 +8801,45 @@ async def auth_callback(request: Request):
|
|||
"spend": 0,
|
||||
"team_id": "litellm-dashboard",
|
||||
}
|
||||
user_defined_values: SSOUserDefinedValues = {
|
||||
"models": user_id_models,
|
||||
"user_id": user_id,
|
||||
"user_email": user_email,
|
||||
"max_budget": max_internal_user_budget,
|
||||
"user_role": None,
|
||||
"budget_duration": internal_user_budget_duration,
|
||||
}
|
||||
user_defined_values: Optional[SSOUserDefinedValues] = None
|
||||
if user_id is not None:
|
||||
user_defined_values = SSOUserDefinedValues(
|
||||
models=user_id_models,
|
||||
user_id=user_id,
|
||||
user_email=user_email,
|
||||
max_budget=max_internal_user_budget,
|
||||
user_role=None,
|
||||
budget_duration=internal_user_budget_duration,
|
||||
)
|
||||
|
||||
_user_id_from_sso = user_id
|
||||
user_role = None
|
||||
try:
|
||||
user_role = None
|
||||
if prisma_client is not None:
|
||||
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_info: {user_info}; litellm.default_user_params: {litellm.default_user_params}"
|
||||
)
|
||||
if user_info is not None:
|
||||
user_defined_values = {
|
||||
"models": getattr(user_info, "models", user_id_models),
|
||||
"user_id": getattr(user_info, "user_id", user_id),
|
||||
"user_email": getattr(user_info, "user_id", user_email),
|
||||
"user_role": getattr(user_info, "user_role", None),
|
||||
"max_budget": getattr(
|
||||
user_info, "max_budget", max_internal_user_budget
|
||||
),
|
||||
"budget_duration": getattr(
|
||||
user_info, "budget_duration", internal_user_budget_duration
|
||||
),
|
||||
}
|
||||
user_role = getattr(user_info, "user_role", None)
|
||||
if user_info is None:
|
||||
## check if user-email in db ##
|
||||
user_info = await prisma_client.db.litellm_usertable.find_first(
|
||||
where={"user_email": user_email}
|
||||
)
|
||||
|
||||
## check if user-email in db ##
|
||||
user_info = await prisma_client.db.litellm_usertable.find_first(
|
||||
where={"user_email": user_email}
|
||||
)
|
||||
if user_info is not None:
|
||||
user_defined_values = {
|
||||
"models": getattr(user_info, "models", user_id_models),
|
||||
"user_id": user_id,
|
||||
"user_email": getattr(user_info, "user_id", user_email),
|
||||
"user_role": getattr(user_info, "user_role", None),
|
||||
"max_budget": getattr(
|
||||
if user_info is not None and user_id is not None:
|
||||
user_defined_values = SSOUserDefinedValues(
|
||||
models=getattr(user_info, "models", user_id_models),
|
||||
user_id=user_id,
|
||||
user_email=getattr(user_info, "user_email", user_email),
|
||||
user_role=getattr(user_info, "user_role", None),
|
||||
max_budget=getattr(
|
||||
user_info, "max_budget", max_internal_user_budget
|
||||
),
|
||||
"budget_duration": getattr(
|
||||
budget_duration=getattr(
|
||||
user_info, "budget_duration", internal_user_budget_duration
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
user_role = getattr(user_info, "user_role", None)
|
||||
|
||||
# update id
|
||||
|
@ -8886,6 +8867,11 @@ async def auth_callback(request: Request):
|
|||
except Exception as e:
|
||||
pass
|
||||
|
||||
if user_defined_values is None:
|
||||
raise Exception(
|
||||
"Unable to map user identity to known values. 'user_defined_values' is None. File an issue - https://github.com/BerriAI/litellm/issues"
|
||||
)
|
||||
|
||||
is_internal_user = False
|
||||
if (
|
||||
user_defined_values["user_role"] is not None
|
||||
|
@ -8960,7 +8946,8 @@ async def auth_callback(request: Request):
|
|||
master_key,
|
||||
algorithm="HS256",
|
||||
)
|
||||
litellm_dashboard_ui += "?userID=" + user_id
|
||||
if user_id is not None and isinstance(user_id, str):
|
||||
litellm_dashboard_ui += "?userID=" + user_id
|
||||
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
||||
redirect_response.set_cookie(key="token", value=jwt_token)
|
||||
return redirect_response
|
||||
|
@ -9023,6 +9010,7 @@ async def new_invitation(
|
|||
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
} # type: ignore
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
if "Foreign key constraint failed on the field" in str(e):
|
||||
raise HTTPException(
|
||||
|
@ -9031,7 +9019,7 @@ async def new_invitation(
|
|||
"error": "User id does not exist in 'LiteLLM_UserTable'. Fix this by creating user via `/user/new`."
|
||||
},
|
||||
)
|
||||
return response
|
||||
raise HTTPException(status_code=500, detail={"error": str(e)})
|
||||
|
||||
|
||||
@router.get(
|
||||
|
@ -9951,44 +9939,46 @@ async def get_routes():
|
|||
"""
|
||||
routes = []
|
||||
for route in app.routes:
|
||||
route_info = {
|
||||
"path": getattr(route, "path", None),
|
||||
"methods": getattr(route, "methods", None),
|
||||
"name": getattr(route, "name", None),
|
||||
"endpoint": (
|
||||
getattr(route, "endpoint", None).__name__
|
||||
if getattr(route, "endpoint", None)
|
||||
else None
|
||||
),
|
||||
}
|
||||
routes.append(route_info)
|
||||
endpoint_route = getattr(route, "endpoint", None)
|
||||
if endpoint_route is not None:
|
||||
route_info = {
|
||||
"path": getattr(route, "path", None),
|
||||
"methods": getattr(route, "methods", None),
|
||||
"name": getattr(route, "name", None),
|
||||
"endpoint": (
|
||||
endpoint_route.__name__
|
||||
if getattr(route, "endpoint", None)
|
||||
else None
|
||||
),
|
||||
}
|
||||
routes.append(route_info)
|
||||
|
||||
return {"routes": routes}
|
||||
|
||||
|
||||
#### TEST ENDPOINTS ####
|
||||
@router.get(
|
||||
"/token/generate",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def token_generate():
|
||||
"""
|
||||
Test endpoint. Admin-only access. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
|
||||
"""
|
||||
# Initialize AuthJWTSSO with your OpenID Provider configuration
|
||||
from fastapi_sso import AuthJWTSSO
|
||||
# @router.get(
|
||||
# "/token/generate",
|
||||
# dependencies=[Depends(user_api_key_auth)],
|
||||
# include_in_schema=False,
|
||||
# )
|
||||
# async def token_generate():
|
||||
# """
|
||||
# Test endpoint. Admin-only access. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
|
||||
# """
|
||||
# # Initialize AuthJWTSSO with your OpenID Provider configuration
|
||||
# from fastapi_sso import AuthJWTSSO
|
||||
|
||||
auth_jwt_sso = AuthJWTSSO(
|
||||
issuer=os.getenv("OPENID_BASE_URL"),
|
||||
client_id=os.getenv("OPENID_CLIENT_ID"),
|
||||
client_secret=os.getenv("OPENID_CLIENT_SECRET"),
|
||||
scopes=["litellm_proxy_admin"],
|
||||
)
|
||||
# auth_jwt_sso = AuthJWTSSO(
|
||||
# issuer=os.getenv("OPENID_BASE_URL"),
|
||||
# client_id=os.getenv("OPENID_CLIENT_ID"),
|
||||
# client_secret=os.getenv("OPENID_CLIENT_SECRET"),
|
||||
# scopes=["litellm_proxy_admin"],
|
||||
# )
|
||||
|
||||
token = auth_jwt_sso.create_access_token()
|
||||
# token = auth_jwt_sso.create_access_token()
|
||||
|
||||
return {"token": token}
|
||||
# return {"token": token}
|
||||
|
||||
|
||||
@router.on_event("shutdown")
|
||||
|
@ -10013,7 +10003,8 @@ async def shutdown_event():
|
|||
# flush langfuse logs on shutdow
|
||||
from litellm.utils import langFuseLogger
|
||||
|
||||
langFuseLogger.Langfuse.flush()
|
||||
if langFuseLogger is not None:
|
||||
langFuseLogger.Langfuse.flush()
|
||||
except:
|
||||
# [DO NOT BLOCK shutdown events for this]
|
||||
pass
|
||||
|
|
|
@ -92,6 +92,7 @@ from litellm.types.router import (
|
|||
RetryPolicy,
|
||||
RouterErrors,
|
||||
RouterGeneralSettings,
|
||||
RouterModelGroupAliasItem,
|
||||
RouterRateLimitError,
|
||||
RouterRateLimitErrorBasic,
|
||||
updateDeployment,
|
||||
|
@ -105,6 +106,7 @@ from litellm.utils import (
|
|||
calculate_max_parallel_requests,
|
||||
create_proxy_transport_and_mounts,
|
||||
get_llm_provider,
|
||||
get_secret,
|
||||
get_utc_datetime,
|
||||
)
|
||||
|
||||
|
@@ -156,7 +158,9 @@ class Router:
         fallbacks: List = [],
         context_window_fallbacks: List = [],
         content_policy_fallbacks: List = [],
-        model_group_alias: Optional[dict] = {},
+        model_group_alias: Optional[
+            Dict[str, Union[str, RouterModelGroupAliasItem]]
+        ] = {},
         enable_pre_call_checks: bool = False,
         enable_tag_filtering: bool = False,
         retry_after: int = 0,  # min time to wait before retrying a failed request
@ -331,7 +335,8 @@ class Router:
|
|||
self.set_model_list(model_list)
|
||||
self.healthy_deployments: List = self.model_list # type: ignore
|
||||
for m in model_list:
|
||||
self.deployment_latency_map[m["litellm_params"]["model"]] = 0
|
||||
if "model" in m["litellm_params"]:
|
||||
self.deployment_latency_map[m["litellm_params"]["model"]] = 0
|
||||
else:
|
||||
self.model_list: List = (
|
||||
[]
|
||||
|
@ -398,7 +403,7 @@ class Router:
|
|||
self.previous_models: List = (
|
||||
[]
|
||||
) # list to store failed calls (passed in as metadata to next call)
|
||||
self.model_group_alias: dict = (
|
||||
self.model_group_alias: Dict[str, Union[str, RouterModelGroupAliasItem]] = (
|
||||
model_group_alias or {}
|
||||
) # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group
|
||||
|
||||
|
@ -1179,6 +1184,7 @@ class Router:
|
|||
raise e
|
||||
|
||||
def _image_generation(self, prompt: str, model: str, **kwargs):
|
||||
model_name = ""
|
||||
try:
|
||||
verbose_router_logger.debug(
|
||||
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
|
||||
|
@ -1269,6 +1275,7 @@ class Router:
|
|||
raise e
|
||||
|
||||
async def _aimage_generation(self, prompt: str, model: str, **kwargs):
|
||||
model_name = ""
|
||||
try:
|
||||
verbose_router_logger.debug(
|
||||
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
|
||||
|
@ -1401,6 +1408,7 @@ class Router:
|
|||
raise e
|
||||
|
||||
async def _atranscription(self, file: FileTypes, model: str, **kwargs):
|
||||
model_name = model
|
||||
try:
|
||||
verbose_router_logger.debug(
|
||||
f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
|
||||
|
@ -1781,6 +1789,7 @@ class Router:
|
|||
is_async: Optional[bool] = False,
|
||||
**kwargs,
|
||||
):
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
try:
|
||||
kwargs["model"] = model
|
||||
kwargs["prompt"] = prompt
|
||||
|
@ -1789,7 +1798,6 @@ class Router:
|
|||
timeout = kwargs.get("request_timeout", self.timeout)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
# pick the one that is available (lowest TPM/RPM)
|
||||
deployment = self.get_available_deployment(
|
||||
model=model,
|
||||
|
@ -2534,13 +2542,13 @@ class Router:
|
|||
try:
|
||||
# Update kwargs with the current model name or any other model-specific adjustments
|
||||
## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ##
|
||||
_, custom_llm_provider, _, _ = get_llm_provider(
|
||||
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
|
||||
model=model_name["litellm_params"]["model"]
|
||||
)
|
||||
new_kwargs = copy.deepcopy(kwargs)
|
||||
new_kwargs.pop("custom_llm_provider", None)
|
||||
return await litellm.aretrieve_batch(
|
||||
custom_llm_provider=custom_llm_provider, **new_kwargs
|
||||
custom_llm_provider=custom_llm_provider, **new_kwargs # type: ignore
|
||||
)
|
||||
except Exception as e:
|
||||
receieved_exceptions.append(e)
|
||||
|
@ -2616,13 +2624,13 @@ class Router:
|
|||
for result in results:
|
||||
if result is not None:
|
||||
## check batch id
|
||||
if final_results["first_id"] is None:
|
||||
final_results["first_id"] = result.first_id
|
||||
final_results["last_id"] = result.last_id
|
||||
if final_results["first_id"] is None and hasattr(result, "first_id"):
|
||||
final_results["first_id"] = getattr(result, "first_id")
|
||||
final_results["last_id"] = getattr(result, "last_id")
|
||||
final_results["data"].extend(result.data) # type: ignore
|
||||
|
||||
## check 'has_more'
|
||||
if result.has_more is True:
|
||||
if getattr(result, "has_more", False) is True:
|
||||
final_results["has_more"] = True
|
||||
|
||||
return final_results
|
||||
|
@ -2874,8 +2882,12 @@ class Router:
|
|||
verbose_router_logger.debug(f"Traceback{traceback.format_exc()}")
|
||||
original_exception = e
|
||||
fallback_model_group = None
|
||||
original_model_group = kwargs.get("model")
|
||||
original_model_group: Optional[str] = kwargs.get("model") # type: ignore
|
||||
fallback_failure_exception_str = ""
|
||||
|
||||
if original_model_group is None:
|
||||
raise e
|
||||
|
||||
try:
|
||||
verbose_router_logger.debug("Trying to fallback b/w models")
|
||||
if isinstance(e, litellm.ContextWindowExceededError):
|
||||
|
@ -2972,7 +2984,7 @@ class Router:
|
|||
f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
|
||||
)
|
||||
if hasattr(original_exception, "message"):
|
||||
original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
|
||||
original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" # type: ignore
|
||||
raise original_exception
|
||||
|
||||
response = await run_async_fallback(
|
||||
|
@ -2996,12 +3008,12 @@ class Router:
|
|||
|
||||
if hasattr(original_exception, "message"):
|
||||
# add the available fallbacks to the exception
|
||||
original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format(
|
||||
original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format( # type: ignore
|
||||
model_group,
|
||||
fallback_model_group,
|
||||
)
|
||||
if len(fallback_failure_exception_str) > 0:
|
||||
original_exception.message += (
|
||||
original_exception.message += ( # type: ignore
|
||||
"\nError doing the fallback: {}".format(
|
||||
fallback_failure_exception_str
|
||||
)
|
||||
|
@ -3117,9 +3129,15 @@ class Router:
|
|||
## LOGGING
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=e)
|
||||
remaining_retries = num_retries - current_attempt
|
||||
_healthy_deployments, _ = await self._async_get_healthy_deployments(
|
||||
model=kwargs.get("model"),
|
||||
)
|
||||
_model: Optional[str] = kwargs.get("model") # type: ignore
|
||||
if _model is not None:
|
||||
_healthy_deployments, _ = (
|
||||
await self._async_get_healthy_deployments(
|
||||
model=_model,
|
||||
)
|
||||
)
|
||||
else:
|
||||
_healthy_deployments = []
|
||||
_timeout = self._time_to_sleep_before_retry(
|
||||
e=original_exception,
|
||||
remaining_retries=remaining_retries,
|
||||
|
@ -3129,8 +3147,8 @@ class Router:
|
|||
await asyncio.sleep(_timeout)
|
||||
|
||||
if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
|
||||
original_exception.max_retries = num_retries
|
||||
original_exception.num_retries = current_attempt
|
||||
setattr(original_exception, "max_retries", num_retries)
|
||||
setattr(original_exception, "num_retries", current_attempt)
|
||||
|
||||
raise original_exception
|
||||
|
||||
|
@ -3225,8 +3243,12 @@ class Router:
|
|||
return response
|
||||
except Exception as e:
|
||||
original_exception = e
|
||||
original_model_group = kwargs.get("model")
|
||||
original_model_group: Optional[str] = kwargs.get("model")
|
||||
verbose_router_logger.debug(f"An exception occurs {original_exception}")
|
||||
|
||||
if original_model_group is None:
|
||||
raise e
|
||||
|
||||
try:
|
||||
verbose_router_logger.debug(
|
||||
f"Trying to fallback b/w models. Initial model group: {model_group}"
|
||||
|
@ -3336,10 +3358,10 @@ class Router:
|
|||
return 0
|
||||
|
||||
response_headers: Optional[httpx.Headers] = None
|
||||
if hasattr(e, "response") and hasattr(e.response, "headers"):
|
||||
response_headers = e.response.headers
|
||||
if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore
|
||||
response_headers = e.response.headers # type: ignore
|
||||
elif hasattr(e, "litellm_response_headers"):
|
||||
response_headers = e.litellm_response_headers
|
||||
response_headers = e.litellm_response_headers # type: ignore
|
||||
|
||||
if response_headers is not None:
|
||||
timeout = litellm._calculate_retry_after(
|
||||
|
@ -3398,9 +3420,13 @@ class Router:
|
|||
except Exception as e:
|
||||
current_attempt = None
|
||||
original_exception = e
|
||||
_model: Optional[str] = kwargs.get("model") # type: ignore
|
||||
|
||||
if _model is None:
|
||||
raise e # re-raise error, if model can't be determined for loadbalancing
|
||||
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
|
||||
_healthy_deployments, _all_deployments = self._get_healthy_deployments(
|
||||
model=kwargs.get("model"),
|
||||
model=_model,
|
||||
)
|
||||
|
||||
# raises an exception if this error should not be retries
|
||||
|
@ -3438,8 +3464,12 @@ class Router:
|
|||
except Exception as e:
|
||||
## LOGGING
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=e)
|
||||
_model: Optional[str] = kwargs.get("model") # type: ignore
|
||||
|
||||
if _model is None:
|
||||
raise e # re-raise error, if model can't be determined for loadbalancing
|
||||
_healthy_deployments, _ = self._get_healthy_deployments(
|
||||
model=kwargs.get("model"),
|
||||
model=_model,
|
||||
)
|
||||
remaining_retries = num_retries - current_attempt
|
||||
_timeout = self._time_to_sleep_before_retry(
|
||||
|
@ -4055,7 +4085,7 @@ class Router:
|
|||
if isinstance(_litellm_params, dict):
|
||||
for k, v in _litellm_params.items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
_litellm_params[k] = litellm.get_secret(v)
|
||||
_litellm_params[k] = get_secret(v)
|
||||
|
||||
_model_info: dict = model.pop("model_info", {})
|
||||
|
||||
|
@ -4392,7 +4422,6 @@ class Router:
|
|||
- ModelGroupInfo if able to construct a model group
|
||||
- None if error constructing model group info
|
||||
"""
|
||||
|
||||
model_group_info: Optional[ModelGroupInfo] = None
|
||||
|
||||
total_tpm: Optional[int] = None
|
||||
|
@@ -4557,12 +4586,23 @@ class Router:

         Returns:
         - ModelGroupInfo if able to construct a model group
-        - None if error constructing model group info
+        - None if error constructing model group info or hidden model group
         """
         ## Check if model group alias
         if model_group in self.model_group_alias:
+            item = self.model_group_alias[model_group]
+            if isinstance(item, str):
+                _router_model_group = item
+            elif isinstance(item, dict):
+                if item["hidden"] is True:
+                    return None
+                else:
+                    _router_model_group = item["model"]
+            else:
+                return None
+
             return self._set_model_group_info(
-                model_group=self.model_group_alias[model_group],
+                model_group=_router_model_group,
                 user_facing_model_group_name=model_group,
             )

@@ -4666,7 +4706,14 @@ class Router:

         Includes model_group_alias models too.
         """
-        return self.model_names + list(self.model_group_alias.keys())
+        model_list = self.get_model_list()
+        if model_list is None:
+            return []
+
+        model_names = []
+        for m in model_list:
+            model_names.append(m["model_name"])
+        return model_names

     def get_model_list(
         self, model_name: Optional[str] = None
@@ -4678,9 +4725,21 @@ class Router:
         returned_models: List[DeploymentTypedDict] = []

         for model_alias, model_value in self.model_group_alias.items():
+
+            if isinstance(model_value, str):
+                _router_model_name: str = model_value
+            elif isinstance(model_value, dict):
+                _model_value = RouterModelGroupAliasItem(**model_value)  # type: ignore
+                if _model_value["hidden"] is True:
+                    continue
+                else:
+                    _router_model_name = _model_value["model"]
+            else:
+                continue
+
             returned_models.extend(
                 self._get_all_deployments(
-                    model_name=model_value, model_alias=model_alias
+                    model_name=_router_model_name, model_alias=model_alias
                 )
             )

@@ -5078,10 +5137,11 @@ class Router:
         )

         if model in self.model_group_alias:
            verbose_router_logger.debug(
                f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}"
            )
-           model = self.model_group_alias[model]
+           _item = self.model_group_alias[model]
+           if isinstance(_item, str):
+               model = _item
+           else:
+               model = _item["model"]

        if model not in self.model_names:
            # check if provider/ specific wildcard routing
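For context, a short usage sketch of the alias handling above: a plain string alias keeps the old behavior, while the dict form with `"hidden": True` stays routable but is left out of `get_model_list()` / `get_model_names()`. Model names here are illustrative, not from this diff:

```python
from litellm import Router

router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
    ],
    model_group_alias={
        # plain alias: shows up in model listings as before
        "gpt-4": "gpt-3.5-turbo",
        # dict alias: requests to it still route, but it is hidden from listings
        "my-internal-alias": {"model": "gpt-3.5-turbo", "hidden": True},
    },
)

print(router.get_model_names())            # hidden alias should not appear here
print(len(router.get_model_list() or []))  # hidden alias excluded from deployments
```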
@ -5124,7 +5184,9 @@ class Router:
|
|||
m for m in self.model_list if m["litellm_params"]["model"] == model
|
||||
]
|
||||
|
||||
litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")
|
||||
verbose_router_logger.debug(
|
||||
f"initial list of deployments: {healthy_deployments}"
|
||||
)
|
||||
|
||||
if len(healthy_deployments) == 0:
|
||||
raise ValueError(
|
||||
|
@ -5208,7 +5270,7 @@ class Router:
|
|||
)
|
||||
|
||||
# check if user wants to do tag based routing
|
||||
healthy_deployments = await get_deployments_for_tag(
|
||||
healthy_deployments = await get_deployments_for_tag( # type: ignore
|
||||
llm_router_instance=self,
|
||||
request_kwargs=request_kwargs,
|
||||
healthy_deployments=healthy_deployments,
|
||||
|
@ -5241,7 +5303,7 @@ class Router:
|
|||
input=input,
|
||||
)
|
||||
)
|
||||
if (
|
||||
elif (
|
||||
self.routing_strategy == "cost-based-routing"
|
||||
and self.lowestcost_logger is not None
|
||||
):
|
||||
|
@ -5326,6 +5388,8 @@ class Router:
|
|||
############## No RPM/TPM passed, we do a random pick #################
|
||||
item = random.choice(healthy_deployments)
|
||||
return item or item[0]
|
||||
else:
|
||||
deployment = None
|
||||
if deployment is None:
|
||||
verbose_router_logger.info(
|
||||
f"get_available_deployment for model: {model}, No deployment available"
|
||||
|
@ -5515,6 +5579,9 @@ class Router:
|
|||
messages=messages,
|
||||
input=input,
|
||||
)
|
||||
else:
|
||||
deployment = None
|
||||
|
||||
if deployment is None:
|
||||
verbose_router_logger.info(
|
||||
f"get_available_deployment for model: {model}, No deployment available"
|
||||
|
@ -5690,6 +5757,9 @@ class Router:
|
|||
def _initialize_alerting(self):
|
||||
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
|
||||
|
||||
if self.alerting_config is None:
|
||||
return
|
||||
|
||||
router_alerting_config: AlertingConfig = self.alerting_config
|
||||
|
||||
_slack_alerting_logger = SlackAlerting(
|
||||
|
@ -5700,7 +5770,7 @@ class Router:
|
|||
|
||||
self.slack_alerting_logger = _slack_alerting_logger
|
||||
|
||||
litellm.callbacks.append(_slack_alerting_logger)
|
||||
litellm.callbacks.append(_slack_alerting_logger) # type: ignore
|
||||
litellm.success_callback.append(
|
||||
_slack_alerting_logger.response_taking_too_long_callback
|
||||
)
|
||||
|
|
|
@@ -21,8 +21,8 @@ else:

 async def get_deployments_for_tag(
     llm_router_instance: LitellmRouter,
-    healthy_deployments: Union[List[Any], Dict[Any, Any]],
     request_kwargs: Optional[Dict[Any, Any]] = None,
+    healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
 ):
     if llm_router_instance.enable_tag_filtering is not True:
         return healthy_deployments

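For reference, tag filtering is opt-in on the `Router`; a rough sketch of how a tagged deployment and a tagged request line up, assuming tags are read from `litellm_params` and from the request's `metadata` (model names and keys are illustrative, not from this diff):

```python
from litellm import Router

# Sketch only: one "free" and one "paid" deployment behind the same model group,
# with tag filtering enabled so get_deployments_for_tag() actually filters.
router = Router(
    model_list=[
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "openai/gpt-4o-mini",
                "api_key": "sk-...",  # illustrative
                "tags": ["free"],
            },
        },
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "openai/gpt-4o",
                "api_key": "sk-...",  # illustrative
                "tags": ["paid"],
            },
        },
    ],
    enable_tag_filtering=True,  # without this, the helper returns deployments unchanged
)

# Requests carry their tags in metadata; a "free" request should only be
# load-balanced across the "free"-tagged deployment.
response = router.completion(
    model="gpt-4",
    messages=[{"role": "user", "content": "hi"}],
    metadata={"tags": ["free"]},
)
```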
@@ -2828,3 +2828,44 @@ def test_gemini_function_call_parameter_in_messages_2():
             ]
         },
     ]
+
+
+@pytest.mark.parametrize(
+    "base_model, metadata",
+    [
+        (None, {"model_info": {"base_model": "vertex_ai/gemini-1.5-pro"}}),
+        ("vertex_ai/gemini-1.5-pro", None),
+    ],
+)
+def test_gemini_finetuned_endpoint(base_model, metadata):
+    litellm.set_verbose = True
+    load_vertex_ai_credentials()
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    # Set up the messages
+    messages = [
+        {"role": "system", "content": """Use search for most queries."""},
+        {"role": "user", "content": """search for weather in boston (use `search`)"""},
+    ]
+
+    client = HTTPHandler(concurrent_limit=1)
+
+    with patch.object(client, "post", new=MagicMock()) as mock_client:
+        try:
+            response = completion(
+                model="vertex_ai/4965075652664360960",
+                messages=messages,
+                tool_choice="auto",
+                client=client,
+                metadata=metadata,
+                base_model=base_model,
+            )
+        except Exception as e:
+            print(e)
+
+        print(mock_client.call_args.kwargs)
+
+        mock_client.assert_called()
+        assert mock_client.call_args.kwargs["url"].endswith(
+            "endpoints/4965075652664360960:generateContent"
+        )

@@ -311,6 +311,8 @@ async def test_completion_predibase():
         pass
     except litellm.ServiceUnavailableError as e:
         pass
+    except litellm.InternalServerError:
+        pass
     except Exception as e:
        pytest.fail(f"Error occurred: {e}")

@@ -133,3 +133,21 @@ async def test_fireworks_health_check():
     assert response == {}

     return response
+
+
+@pytest.mark.asyncio
+async def test_cohere_rerank_health_check():
+    response = await litellm.ahealth_check(
+        model_params={
+            "model": "cohere/rerank-english-v3.0",
+            "query": "Hey, how's it going",
+            "documents": ["my sample text"],
+            "api_key": os.getenv("COHERE_API_KEY"),
+        },
+        mode="rerank",
+        prompt="Hey, how's it going",
+    )
+
+    assert "error" not in response
+
+    print(response)

@@ -50,6 +50,7 @@ def test_image_generation_openai():
     ], # False
 ) #
 @pytest.mark.asyncio
+@pytest.mark.flaky(retries=3, delay=1)
 async def test_image_generation_azure(sync_mode):
     try:
         if sync_mode:

@@ -533,7 +533,7 @@ def test_call_with_user_over_budget(prisma_client):
             # use generated key to auth in
             result = await user_api_key_auth(request=request, api_key=bearer_token)
             print("result from user auth with new key", result)
-            pytest.fail(f"This should have failed!. They key crossed it's budget")
+            pytest.fail("This should have failed!. They key crossed it's budget")

         asyncio.run(test())
     except Exception as e:
@@ -1755,7 +1755,7 @@ def test_call_with_key_over_model_budget(prisma_client):
             # use generated key to auth in
             result = await user_api_key_auth(request=request, api_key=bearer_token)
             print("result from user auth with new key", result)
-            pytest.fail(f"This should have failed!. They key crossed it's budget")
+            pytest.fail("This should have failed!. They key crossed it's budget")

         asyncio.run(test())
     except Exception as e:

@@ -589,3 +589,14 @@ def test_parse_additional_properties_json_schema(model, provider, expectedAddProp):
     elif provider == "openai":
         schema = optional_params["response_format"]["json_schema"]["schema"]
         assert ("additionalProperties" in schema) == expectedAddProp
+
+
+def test_o1_model_params():
+    optional_params = get_optional_params(
+        model="o1-preview-2024-09-12",
+        custom_llm_provider="openai",
+        seed=10,
+        user="John",
+    )
+    assert optional_params["seed"] == 10
+    assert optional_params["user"] == "John"

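Related to the commit's o-1 param-mapping fixes, a hedged usage sketch: o-1 reasoning models do not take `max_tokens`, so the call below passes `max_completion_tokens` instead, assuming that argument is accepted the way the o1 transformation changes intend (model name and token budget are illustrative):

```python
import litellm

# Sketch only: o-1 reasoning models reject `max_tokens`; the o1 param mapping in
# this commit is meant to carry `max_completion_tokens` through correctly.
response = litellm.completion(
    model="o1-preview",
    messages=[{"role": "user", "content": "Summarize the router retry logic."}],
    max_completion_tokens=256,  # use this instead of max_tokens for o-1 models
)
print(response.choices[0].message.content)
```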
@ -762,7 +762,7 @@ async def test_team_update_redis():
|
|||
) as mock_client:
|
||||
await _cache_team_object(
|
||||
team_id="1234",
|
||||
team_table=LiteLLM_TeamTableCachedObj(),
|
||||
team_table=LiteLLM_TeamTableCachedObj(team_id="1234"),
|
||||
user_api_key_cache=DualCache(),
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
@ -776,7 +776,7 @@ async def test_get_team_redis(client_no_auth):
|
|||
Tests if get_team_object gets value from redis cache, if set
|
||||
"""
|
||||
from litellm.caching import DualCache, RedisCache
|
||||
from litellm.proxy.auth.auth_checks import _cache_team_object, get_team_object
|
||||
from litellm.proxy.auth.auth_checks import get_team_object
|
||||
|
||||
proxy_logging_obj: ProxyLogging = getattr(
|
||||
litellm.proxy.proxy_server, "proxy_logging_obj"
|
||||
|
@ -917,7 +917,9 @@ async def test_create_team_member_add(prisma_client, new_member_method):
|
|||
)
|
||||
litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = team_mock_client
|
||||
|
||||
team_mock_client.update = AsyncMock(return_value=LiteLLM_TeamTableCachedObj())
|
||||
team_mock_client.update = AsyncMock(
|
||||
return_value=LiteLLM_TeamTableCachedObj(team_id="1234")
|
||||
)
|
||||
|
||||
await team_member_add(
|
||||
data=team_member_add_request,
|
||||
|
@ -1095,7 +1097,9 @@ async def test_create_team_member_add_team_admin(
|
|||
)
|
||||
litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = team_mock_client
|
||||
|
||||
team_mock_client.update = AsyncMock(return_value=LiteLLM_TeamTableCachedObj())
|
||||
team_mock_client.update = AsyncMock(
|
||||
return_value=LiteLLM_TeamTableCachedObj(team_id="1234")
|
||||
)
|
||||
|
||||
try:
|
||||
await team_member_add(
|
||||
|
@ -1434,8 +1438,9 @@ async def test_gemini_pass_through_endpoint():
|
|||
print(resp.body)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hidden", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_model_group_alias_checks(prisma_client):
|
||||
async def test_proxy_model_group_alias_checks(prisma_client, hidden):
|
||||
"""
|
||||
Check if model group alias is returned on
|
||||
|
||||
|
@ -1465,7 +1470,7 @@ async def test_proxy_model_group_alias_checks(prisma_client):
|
|||
model_alias = "gpt-4"
|
||||
router = litellm.Router(
|
||||
model_list=_model_list,
|
||||
model_group_alias={model_alias: "gpt-3.5-turbo"},
|
||||
model_group_alias={model_alias: {"model": "gpt-3.5-turbo", "hidden": hidden}},
|
||||
)
|
||||
setattr(litellm.proxy.proxy_server, "llm_router", router)
|
||||
setattr(litellm.proxy.proxy_server, "llm_model_list", _model_list)
|
||||
|
@ -1477,7 +1482,10 @@ async def test_proxy_model_group_alias_checks(prisma_client):
|
|||
user_api_key_dict=UserAPIKeyAuth(models=[]),
|
||||
)
|
||||
|
||||
assert len(resp) == 2
|
||||
if hidden:
|
||||
assert len(resp["data"]) == 1
|
||||
else:
|
||||
assert len(resp["data"]) == 2
|
||||
print(resp)
|
||||
|
||||
resp = await model_info_v1(
|
||||
|
@ -1489,7 +1497,10 @@ async def test_proxy_model_group_alias_checks(prisma_client):
|
|||
if model_alias == item["model_name"]:
|
||||
is_model_alias_in_list = True
|
||||
|
||||
assert is_model_alias_in_list
|
||||
if hidden:
|
||||
assert is_model_alias_in_list is False
|
||||
else:
|
||||
assert is_model_alias_in_list
|
||||
|
||||
resp = await model_group_info(
|
||||
user_api_key_dict=UserAPIKeyAuth(models=[]),
|
||||
|
@@ -1500,4 +1511,7 @@ async def test_proxy_model_group_alias_checks(prisma_client):
         if model_alias == item.model_group:
             is_model_alias_in_list = True

-    assert is_model_alias_in_list
+    if hidden:
+        assert is_model_alias_in_list is False
+    else:
+        assert is_model_alias_in_list, f"models: {models}"

@@ -2481,3 +2481,31 @@ async def test_router_batch_endpoints(provider):
         model="my-custom-name", custom_llm_provider=provider, limit=2
     )
     print("list_batches=", list_batches)
+
+
+@pytest.mark.parametrize("hidden", [True, False])
+def test_model_group_alias(hidden):
+    _model_list = [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {"model": "gpt-3.5-turbo"},
+        },
+        {"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}},
+    ]
+    router = Router(
+        model_list=_model_list,
+        model_group_alias={
+            "gpt-4.5-turbo": {"model": "gpt-3.5-turbo", "hidden": hidden}
+        },
+    )
+
+    models = router.get_model_list()
+
+    model_names = router.get_model_names()
+
+    if hidden:
+        assert len(models) == len(_model_list)
+        assert len(model_names) == len(_model_list)
+    else:
+        assert len(models) == len(_model_list) + 1
+        assert len(model_names) == len(_model_list) + 1

@@ -582,3 +582,8 @@ class RouterRateLimitError(ValueError):
         self.cooldown_list = cooldown_list
         _message = f"{RouterErrors.no_deployments_available.value}, Try again in {cooldown_time} seconds. Passed model={model}. pre-call-checks={enable_pre_call_checks}, cooldown_list={cooldown_list}"
         super().__init__(_message)
+
+
+class RouterModelGroupAliasItem(TypedDict):
+    model: str
+    hidden: bool  # if 'True', don't return on `.get_model_list`

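With the `RouterModelGroupAliasItem` TypedDict above, a `model_group_alias` value can be either a plain model-group name or a dict carrying a `hidden` flag, as exercised by the router and proxy tests in this commit. A short usage sketch (model names are placeholders; no API call is made, so no keys are needed):

```python
from litellm import Router

model_list = [
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
    {"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}},
]

# Route requests for "gpt-4.5-turbo" to the "gpt-3.5-turbo" group, but keep the alias
# out of .get_model_list() / .get_model_names() (and therefore out of the proxy's
# /model/info and /v1/models responses) by marking it hidden.
router = Router(
    model_list=model_list,
    model_group_alias={"gpt-4.5-turbo": {"model": "gpt-3.5-turbo", "hidden": True}},
)

assert "gpt-4.5-turbo" not in router.get_model_names()
```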
@@ -76,6 +76,7 @@ from litellm.types.llms.openai import (
     ChatCompletionNamedToolChoiceParam,
     ChatCompletionToolParam,
 )
+from litellm.types.utils import FileTypes  # type: ignore
 from litellm.types.utils import (
     CallTypes,
     ChatCompletionDeltaToolCall,
@@ -84,7 +85,6 @@ from litellm.types.utils import (
     Delta,
     Embedding,
     EmbeddingResponse,
-    FileTypes,
     ImageResponse,
     Message,
     ModelInfo,
@@ -2339,6 +2339,7 @@ def get_litellm_params(
     text_completion=None,
     azure_ad_token_provider=None,
     user_continue_message=None,
+    base_model=None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2365,6 +2366,8 @@ def get_litellm_params(
         "text_completion": text_completion,
         "azure_ad_token_provider": azure_ad_token_provider,
         "user_continue_message": user_continue_message,
+        "base_model": base_model
+        or _get_base_model_from_litellm_call_metadata(metadata=metadata),
     }

     return litellm_params
@@ -6063,11 +6066,11 @@ def _calculate_retry_after(
     max_retries: int,
     response_headers: Optional[httpx.Headers] = None,
     min_timeout: int = 0,
-):
+) -> Union[float, int]:
     retry_after = _get_retry_after_from_exception_header(response_headers)

     # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
-    if 0 < retry_after <= 60:
+    if retry_after is not None and 0 < retry_after <= 60:
         return retry_after

     initial_retry_delay = 0.5
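The added `retry_after is not None` guard matters because `0 < None` raises a `TypeError` on Python 3 when no usable `Retry-After` value was parsed from the headers. A self-contained sketch of the overall pattern (honor a reasonable server-supplied delay, otherwise fall back to capped exponential backoff with jitter) — the function name and the backoff constants below are illustrative, not litellm's exact values:

```python
import random
from typing import Optional, Union


def retry_delay_sketch(
    remaining_retries: int,
    max_retries: int,
    retry_after: Optional[Union[int, float]] = None,  # parsed Retry-After, may be None
    min_timeout: int = 0,
) -> Union[float, int]:
    # If the API asked us to wait a reasonable amount of time, just do what it says.
    if retry_after is not None and 0 < retry_after <= 60:
        return retry_after

    # Otherwise: capped exponential backoff with jitter (constants are illustrative).
    initial_retry_delay = 0.5
    max_retry_delay = 8.0
    nb_retries = max_retries - remaining_retries
    sleep_seconds = min(initial_retry_delay * pow(2.0, nb_retries), max_retry_delay)
    jitter = 1 - 0.25 * random.random()
    return max(sleep_seconds * jitter, min_timeout)


print(retry_delay_sketch(remaining_retries=2, max_retries=3, retry_after=None))
```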
@@ -10962,6 +10965,22 @@ def get_logging_id(start_time, response_obj):
     return None


+def _get_base_model_from_litellm_call_metadata(
+    metadata: Optional[dict],
+) -> Optional[str]:
+    if metadata is None:
+        return None
+
+    if metadata is not None:
+        model_info = metadata.get("model_info", {})
+
+        if model_info is not None:
+            base_model = model_info.get("base_model", None)
+            if base_model is not None:
+                return base_model
+    return None
+
+
 def _get_base_model_from_metadata(model_call_details=None):
     if model_call_details is None:
         return None
@@ -10970,13 +10989,7 @@ def _get_base_model_from_metadata(model_call_details=None):
     if litellm_params is not None:
         metadata = litellm_params.get("metadata", {})

-        if metadata is not None:
-            model_info = metadata.get("model_info", {})
-
-            if model_info is not None:
-                base_model = model_info.get("base_model", None)
-                if base_model is not None:
-                    return base_model
+        return _get_base_model_from_litellm_call_metadata(metadata=metadata)
     return None

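The new helper centralizes the lookup that `get_litellm_params` and `_get_base_model_from_metadata` now share: an explicitly passed `base_model` wins, otherwise it is read from `metadata["model_info"]["base_model"]`. A standalone sketch of that resolution order (plain functions and placeholder model names, not litellm's internals):

```python
from typing import Optional


def base_model_from_metadata(metadata: Optional[dict]) -> Optional[str]:
    # Mirrors _get_base_model_from_litellm_call_metadata: tolerate missing keys / None values.
    if metadata is None:
        return None
    model_info = metadata.get("model_info") or {}
    return model_info.get("base_model")


def resolve_base_model(base_model: Optional[str], metadata: Optional[dict]) -> Optional[str]:
    # Same precedence as `base_model or _get_base_model_from_litellm_call_metadata(...)`
    # in get_litellm_params: explicit argument first, metadata fallback second.
    return base_model or base_model_from_metadata(metadata)


# Explicit argument wins.
assert resolve_base_model("vertex_ai/gemini-1.5-pro", None) == "vertex_ai/gemini-1.5-pro"
# Falls back to the model_info set on the call's metadata.
assert resolve_base_model(None, {"model_info": {"base_model": "azure/gpt-4o"}}) == "azure/gpt-4o"
```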
pyrightconfig.json — new file, 6 additions
@@ -0,0 +1,6 @@
+{
+    "ignore": [],
+    "exclude": ["**/node_modules", "**/__pycache__", "litellm/tests", "litellm/main.py", "litellm/utils.py"],
+    "reportMissingImports": false
+}