forked from phoenix/litellm-mirror
LiteLLM Minor Fixes and Improvements (09/13/2024) (#5689)
* refactor: cleanup unused variables + fix pyright errors
* feat(health_check.py): Closes https://github.com/BerriAI/litellm/issues/5686
* fix(o1_reasoning.py): add stricter check for o-1 reasoning model
* refactor(mistral/): make it easier to see mistral transformation logic
* fix(openai.py): fix openai o-1 model param mapping. Fixes https://github.com/BerriAI/litellm/issues/5685
* feat(main.py): infer finetuned gemini model from base model. Fixes https://github.com/BerriAI/litellm/issues/5678
* docs(vertex.md): update docs to call finetuned gemini models
* feat(proxy_server.py): allow admin to hide proxy model aliases. Closes https://github.com/BerriAI/litellm/issues/5692
* docs(load_balancing.md): add docs on hiding alias models from proxy config
* fix(base.py): don't raise notimplemented error
* fix(user_api_key_auth.py): fix model max budget check
* fix(router.py): fix elif
* fix(user_api_key_auth.py): don't set team_id to empty str
* fix(team_endpoints.py): fix response type
* test(test_completion.py): handle predibase error
* test(test_proxy_server.py): fix test
* fix(o1_transformation.py): fix max_completion_token mapping
* test(test_image_generation.py): mark flaky test
Parent: db3af20d84
Commit: 60709a0753
35 changed files with 1020 additions and 539 deletions
@@ -1,9 +1,9 @@
 repos:
 - repo: local
   hooks:
-    - id: mypy
-      name: mypy
-      entry: python3 -m mypy --ignore-missing-imports
+    - id: pyright
+      name: pyright
+      entry: pyright
       language: system
       types: [python]
       files: ^litellm/
@@ -5,7 +5,7 @@ import TabItem from '@theme/TabItem';
 
 https://github.com/BerriAI/litellm
 
-## **Call 100+ LLMs using the same Input/Output Format**
+## **Call 100+ LLMs using the OpenAI Input/Output Format**
 
 - Translate inputs to provider's `completion`, `embedding`, and `image_generation` endpoints
 - [Consistent output](https://docs.litellm.ai/docs/completion/output), text responses will always be available at `['choices'][0]['message']['content']`
@@ -1180,9 +1180,58 @@ response = completion(
 Fine tuned models on vertex have a numerical model/endpoint id.
 
-| Model Name | Function Call |
-|------------------|--------------------------------------|
-| your fine tuned model | `completion(model='vertex_ai/4965075652664360960', messages)`|
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from litellm import completion
+import os
+
+## set ENV variables
+os.environ["VERTEXAI_PROJECT"] = "hardy-device-38811"
+os.environ["VERTEXAI_LOCATION"] = "us-central1"
+
+response = completion(
+  model="vertex_ai/<your-finetuned-model>",  # e.g. vertex_ai/4965075652664360960
+  messages=[{ "content": "Hello, how are you?","role": "user"}],
+  base_model="vertex_ai/gemini-1.5-pro" # the base model - used for routing
+)
+```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Add Vertex Credentials to your env
+
+```bash
+!gcloud auth application-default login
+```
+
+2. Setup config.yaml
+
+```yaml
+- model_name: finetuned-gemini
+  litellm_params:
+    model: vertex_ai/<ENDPOINT_ID>
+    vertex_project: <PROJECT_ID>
+    vertex_location: <LOCATION>
+  model_info:
+    base_model: vertex_ai/gemini-1.5-pro # IMPORTANT
+```
+
+3. Test it!
+
+```bash
+curl --location 'https://0.0.0.0:4000/v1/chat/completions' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: <LITELLM_KEY>' \
+--data '{"model": "finetuned-gemini" ,"messages":[{"role": "user", "content":[{"type": "text", "text": "hi"}]}]}'
+```
+
+</TabItem>
+</Tabs>
+
+
 ## Gemini Pro Vision
 | Model Name | Function Call |
@@ -38,9 +38,24 @@ router_settings:
 
 ## Router settings on config - routing_strategy, model_group_alias
 
+Expose an 'alias' for a 'model_name' on the proxy server.
+
+```
+model_group_alias: {
+   "gpt-4": "gpt-3.5-turbo"
+}
+```
+
+These aliases are shown on `/v1/models`, `/v1/model/info`, and `/v1/model_group/info` by default.
+
 litellm.Router() settings can be set under `router_settings`. You can set `model_group_alias`, `routing_strategy`, `num_retries`, `timeout`. See all Router supported params [here](https://github.com/BerriAI/litellm/blob/1b942568897a48f014fa44618ec3ce54d7570a46/litellm/router.py#L64)
 
+### Usage
+
 Example config with `router_settings`
 
 ```yaml
 model_list:
 - model_name: gpt-3.5-turbo
@@ -48,19 +63,41 @@ model_list:
     model: azure/<your-deployment-name>
     api_base: <your-azure-endpoint>
     api_key: <your-azure-api-key>
-    rpm: 6 # Rate limit for this deployment: in requests per minute (rpm)
+
+router_settings:
+  model_group_alias: {"gpt-4": "gpt-3.5-turbo"} # all requests with `gpt-4` will be routed to models with `gpt-3.5-turbo`
+```
+
+### Hide Alias Models
+
+Use this if you want to set-up aliases for:
+
+1. typos
+2. minor model version changes
+3. case sensitive changes between updates
+
+```yaml
+model_list:
 - model_name: gpt-3.5-turbo
   litellm_params:
-    model: azure/gpt-turbo-small-ca
-    api_base: https://my-endpoint-canada-berri992.openai.azure.com/
+    model: azure/<your-deployment-name>
+    api_base: <your-azure-endpoint>
     api_key: <your-azure-api-key>
-    rpm: 6
+
 router_settings:
-  model_group_alias: {"gpt-4": "gpt-3.5-turbo"} # all requests with `gpt-4` will be routed to models with `gpt-3.5-turbo`
-  routing_strategy: least-busy # Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"]
-  num_retries: 2
-  timeout: 30 # 30 seconds
-  redis_host: <your redis host>
-  redis_password: <your redis password>
-  redis_port: 1992
+  model_group_alias:
+    "GPT-3.5-turbo": # alias
+      model: "gpt-3.5-turbo" # Actual model name in 'model_list'
+      hidden: true # Exclude from `/v1/models`, `/v1/model/info`, `/v1/model_group/info`
+```
+
+### Complete Spec
+
+```python
+model_group_alias: Optional[Dict[str, Union[str, RouterModelGroupAliasItem]]] = {}
+
+
+class RouterModelGroupAliasItem(TypedDict):
+    model: str
+    hidden: bool  # if 'True', don't return on `/v1/models`, `/v1/model/info`, `/v1/model_group/info`
+```
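For SDK users, the same alias spec can be passed straight to `litellm.Router`. A minimal sketch, assuming a single Azure deployment is defined in `model_list`; the credential values are placeholders:

```python
from litellm import Router

# Placeholders below stand in for a real Azure deployment.
router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/<your-deployment-name>",
                "api_base": "<your-azure-endpoint>",
                "api_key": "<your-azure-api-key>",
            },
        }
    ],
    model_group_alias={
        # plain string alias: requests for `gpt-4` route to `gpt-3.5-turbo`
        "gpt-4": "gpt-3.5-turbo",
        # RouterModelGroupAliasItem shape from the spec above
        "GPT-3.5-turbo": {"model": "gpt-3.5-turbo", "hidden": True},
    },
)
```

Per the spec above, the `hidden: true` alias stays routable but is excluded from `/v1/models`, `/v1/model/info`, and `/v1/model_group/info`.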
@@ -5,7 +5,7 @@ import TabItem from '@theme/TabItem';
 
 https://github.com/BerriAI/litellm
 
-## **Call 100+ LLMs using the same Input/Output Format**
+## **Call 100+ LLMs using the OpenAI Input/Output Format**
 
 - Translate inputs to provider's `completion`, `embedding`, and `image_generation` endpoints
 - [Consistent output](https://docs.litellm.ai/docs/completion/output), text responses will always be available at `['choices'][0]['message']['content']`
@@ -939,15 +939,18 @@ from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConf
 from .llms.OpenAI.openai import (
     OpenAIConfig,
     OpenAITextCompletionConfig,
-    MistralConfig,
     MistralEmbeddingConfig,
     DeepInfraConfig,
     GroqConfig,
     AzureAIStudioConfig,
 )
-from .llms.OpenAI.o1_reasoning import (
+from .llms.mistral.mistral_chat_transformation import MistralConfig
+from .llms.OpenAI.o1_transformation import (
     OpenAIO1Config,
 )
+from .llms.OpenAI.gpt_transformation import (
+    OpenAIGPTConfig,
+)
 from .llms.nvidia_nim import NvidiaNimConfig
 from .llms.cerebras.chat import CerebrasConfig
 from .llms.AI21.chat import AI21ChatConfig
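Callers that imported these configs from the package root keep working after the move; the classes simply resolve from the new modules. A quick sanity check, assuming a local checkout at this commit:

```python
import litellm

# MistralConfig now lives in llms/mistral/, OpenAIGPTConfig in llms/OpenAI/,
# but both remain exposed at the top level via the imports above.
print(litellm.MistralConfig().get_supported_openai_params())
print(litellm.OpenAIGPTConfig().get_supported_openai_params(model="gpt-4o"))
```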
@@ -52,7 +52,9 @@ class SlackAlerting(CustomBatchLogger):
     def __init__(
         self,
         internal_usage_cache: Optional[DualCache] = None,
-        alerting_threshold: float = 300,  # threshold for slow / hanging llm responses (in seconds)
+        alerting_threshold: Optional[
+            float
+        ] = None,  # threshold for slow / hanging llm responses (in seconds)
         alerting: Optional[List] = [],
         alert_types: List[AlertType] = [
             "llm_exceptions",

@@ -74,6 +76,8 @@ class SlackAlerting(CustomBatchLogger):
         default_webhook_url: Optional[str] = None,
         **kwargs,
     ):
+        if alerting_threshold is None:
+            alerting_threshold = 300
         self.alerting_threshold = alerting_threshold
         self.alerting = alerting
         self.alert_types = alert_types

@@ -99,6 +103,7 @@ class SlackAlerting(CustomBatchLogger):
     ):
         if alerting is not None:
             self.alerting = alerting
+            asyncio.create_task(self.periodic_flush())
         if alerting_threshold is not None:
             self.alerting_threshold = alerting_threshold
         if alert_types is not None:

@@ -114,8 +119,6 @@ class SlackAlerting(CustomBatchLogger):
         if llm_router is not None:
             self.llm_router = llm_router
 
-        asyncio.create_task(self.periodic_flush())
-
     async def deployment_in_cooldown(self):
         pass

@@ -208,15 +211,20 @@ class SlackAlerting(CustomBatchLogger):
         _deployment_latencies = metadata["_latency_per_deployment"]
         if len(_deployment_latencies) == 0:
             return None
+        _deployment_latency_map: Optional[dict] = None
         try:
             # try sorting deployments by latency
             _deployment_latencies = sorted(
                 _deployment_latencies.items(), key=lambda x: x[1]
             )
-            _deployment_latencies = dict(_deployment_latencies)
-        except:
+            _deployment_latency_map = dict(_deployment_latencies)
+        except Exception:
             pass
-        for api_base, latency in _deployment_latencies.items():
+
+        if _deployment_latency_map is None:
+            return
+
+        for api_base, latency in _deployment_latency_map.items():
             _message_to_send += f"\n{api_base}: {round(latency,2)}s"
         _message_to_send = "```" + _message_to_send + "```"
         return _message_to_send

@@ -475,6 +483,7 @@ class SlackAlerting(CustomBatchLogger):
     ):
         if self.alerting is None or self.alert_types is None:
             return
+        model: str = ""
         if request_data is not None:
             model = request_data.get("model", "")
             messages = request_data.get("messages", None)

@@ -619,6 +628,7 @@ class SlackAlerting(CustomBatchLogger):
             return
         _id: Optional[str] = "default_id"  # used for caching
         user_info_json = user_info.model_dump(exclude_none=True)
+        user_info_str = ""
         for k, v in user_info_json.items():
             user_info_str = "\n{}: {}\n".format(k, v)
 

@@ -1475,10 +1485,10 @@ Model Info:
 
         if isinstance(response_obj, litellm.ModelResponse) and (
             hasattr(response_obj, "usage")
-            and response_obj.usage is not None
-            and hasattr(response_obj.usage, "completion_tokens")
+            and response_obj.usage is not None  # type: ignore
+            and hasattr(response_obj.usage, "completion_tokens")  # type: ignore
         ):
-            completion_tokens = response_obj.usage.completion_tokens
+            completion_tokens = response_obj.usage.completion_tokens  # type: ignore
             if completion_tokens is not None and completion_tokens > 0:
                 final_value = float(
                     response_s.total_seconds() / completion_tokens

@@ -1608,10 +1618,14 @@ Model Info:
         todays_date = datetime.datetime.now().date()
         start_date = todays_date - datetime.timedelta(days=days)
 
-        spend_per_team, spend_per_tag = await _get_spend_report_for_time_range(
+        _resp = await _get_spend_report_for_time_range(
             start_date=start_date.strftime("%Y-%m-%d"),
             end_date=todays_date.strftime("%Y-%m-%d"),
         )
+        if _resp is None:
+            return
+
+        spend_per_team, spend_per_tag = _resp
+
         _spend_message = f"*💸 Spend Report for `{start_date.strftime('%m-%d-%Y')} - {todays_date.strftime('%m-%d-%Y')}` ({days} days)*\n"
 

@@ -1656,13 +1670,16 @@ Model Info:
             days=last_day_of_month - 1
         )
 
-        monthly_spend_per_team, monthly_spend_per_tag = (
-            await _get_spend_report_for_time_range(
-                start_date=first_day_of_month.strftime("%Y-%m-%d"),
-                end_date=last_day_of_month.strftime("%Y-%m-%d"),
-            )
+        _resp = await _get_spend_report_for_time_range(
+            start_date=first_day_of_month.strftime("%Y-%m-%d"),
+            end_date=last_day_of_month.strftime("%Y-%m-%d"),
         )
 
+        if _resp is None:
+            return
+
+        monthly_spend_per_team, monthly_spend_per_tag = _resp
+
         _spend_message = f"*💸 Monthly Spend Report for `{first_day_of_month.strftime('%m-%d-%Y')} - {last_day_of_month.strftime('%m-%d-%Y')}` *\n"
 
         if monthly_spend_per_team is not None:
litellm/llms/OpenAI/gpt_transformation.py (new file, 142 lines)
@@ -0,0 +1,142 @@
"""
Support for gpt model family
"""

import types
from typing import Optional, Union

import litellm
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage


class OpenAIGPTConfig:
    """
    Reference: https://platform.openai.com/docs/api-reference/chat/create

    The class `OpenAIConfig` provides configuration for the OpenAI's Chat API interface. Below are the parameters:

    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.

    - `function_call` (string or object): This optional parameter controls how the model calls functions.

    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.

    - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.

    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    frequency_penalty: Optional[int] = None
    function_call: Optional[Union[str, dict]] = None
    functions: Optional[list] = None
    logit_bias: Optional[dict] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    response_format: Optional[dict] = None

    def __init__(
        self,
        frequency_penalty: Optional[int] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        response_format: Optional[dict] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> list:
        base_params = [
            "frequency_penalty",
            "logit_bias",
            "logprobs",
            "top_logprobs",
            "max_tokens",
            "n",
            "presence_penalty",
            "seed",
            "stop",
            "stream",
            "stream_options",
            "temperature",
            "top_p",
            "tools",
            "tool_choice",
            "function_call",
            "functions",
            "max_retries",
            "extra_headers",
            "parallel_tool_calls",
        ]  # works across all models

        model_specific_params = []
        if (
            model != "gpt-3.5-turbo-16k" and model != "gpt-4"
        ):  # gpt-4 does not support 'response_format'
            model_specific_params.append("response_format")

        if (
            model in litellm.open_ai_chat_completion_models
        ) or model in litellm.open_ai_text_completion_models:
            model_specific_params.append(
                "user"
            )  # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
        return base_params + model_specific_params

    def _map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ) -> dict:
        supported_openai_params = self.get_supported_openai_params(model)
        for param, value in non_default_params.items():
            if param in supported_openai_params:
                optional_params[param] = value
        return optional_params

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ) -> dict:
        return self._map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
        )
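A short sketch of what the new config does in isolation, using only the methods defined above; the model name and parameter values are illustrative:

```python
import litellm

config = litellm.OpenAIGPTConfig()
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 100, "response_format": {"type": "json_object"}},
    optional_params={},
    model="gpt-4o",
)
# Supported params pass through unchanged; anything not returned by
# get_supported_openai_params(model) is silently dropped.
print(optional_params)  # {'max_tokens': 100, 'response_format': {'type': 'json_object'}}
```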
@@ -17,13 +17,12 @@ from typing import Any, List, Optional, Union
 import litellm
 from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
 
-from .openai import OpenAIConfig
+from .gpt_transformation import OpenAIGPTConfig
 
 
-class OpenAIO1Config(OpenAIConfig):
+class OpenAIO1Config(OpenAIGPTConfig):
     """
     Reference: https://platform.openai.com/docs/guides/reasoning
 
     """
 
     @classmethod

@@ -50,9 +49,7 @@ class OpenAIO1Config(OpenAIConfig):
 
         """
 
-        all_openai_params = litellm.OpenAIConfig().get_supported_openai_params(
-            model="gpt-4o"
-        )
+        all_openai_params = super().get_supported_openai_params(model=model)
         non_supported_params = [
             "logprobs",
             "tools",

@@ -69,13 +66,14 @@ class OpenAIO1Config(OpenAIConfig):
     def map_openai_params(
         self, non_default_params: dict, optional_params: dict, model: str
     ):
-        for param, value in non_default_params.items():
-            if param == "max_tokens":
-                optional_params["max_completion_tokens"] = value
-        return optional_params
+        if "max_tokens" in non_default_params:
+            optional_params["max_completion_tokens"] = non_default_params.pop(
+                "max_tokens"
+            )
+        return super()._map_openai_params(non_default_params, optional_params, model)
 
     def is_model_o1_reasoning_model(self, model: str) -> bool:
-        if "o1" in model:
+        if model in litellm.open_ai_chat_completion_models and "o1" in model:
             return True
         return False
 

@@ -93,7 +91,7 @@ class OpenAIO1Config(OpenAIConfig):
                 )
                 messages[i] = new_message  # Replace the old message with the new one
 
-            if isinstance(message["content"], list):
+            if "content" in message and isinstance(message["content"], list):
                 new_content = []
                 for content_item in message["content"]:
                     if content_item.get("type") == "image_url":
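Since the commit message calls out the `max_completion_tokens` mapping fix, here is a minimal sketch of the new behaviour in isolation. It assumes `OpenAIO1Config` is exported at the package root, as the `__init__.py` hunk earlier shows; the values are illustrative:

```python
import litellm

config = litellm.OpenAIO1Config()
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 256},
    optional_params={},
    model="o1-preview",
)
# `max_tokens` is popped and re-emitted as `max_completion_tokens`; any
# remaining supported params fall through to the shared GPT mapping.
print(optional_params)  # expected: {'max_completion_tokens': 256}
```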
@@ -60,122 +60,6 @@ class OpenAIError(Exception):
         )  # Call the base class constructor with the parameters it needs
 
 
-class MistralConfig:
-    """
-    Reference: https://docs.mistral.ai/api/
-    ...
-    """
[The full `MistralConfig` class body is deleted from this file; it moves unchanged to litellm/llms/mistral/mistral_chat_transformation.py, shown as a new file later in this diff.]
 
 
 class MistralEmbeddingConfig:
     """
     Reference: https://docs.mistral.ai/api/#operation/createEmbedding
@@ -526,44 +410,19 @@ class OpenAIConfig:
         }
 
     def get_supported_openai_params(self, model: str) -> list:
-        base_params = [
-            "frequency_penalty",
-            "logit_bias",
-            "logprobs",
-            "top_logprobs",
-            "max_tokens",
-            "n",
-            "presence_penalty",
-            "seed",
-            "stop",
-            "stream",
-            "stream_options",
-            "temperature",
-            "top_p",
-            "tools",
-            "tool_choice",
-            "function_call",
-            "functions",
-            "max_retries",
-            "extra_headers",
-            "parallel_tool_calls",
-        ]  # works across all models
-
-        model_specific_params = []
         if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
             return litellm.OpenAIO1Config().get_supported_openai_params(model=model)
-        if (
-            model != "gpt-3.5-turbo-16k" and model != "gpt-4"
-        ):  # gpt-4 does not support 'response_format'
-            model_specific_params.append("response_format")
+        else:
+            return litellm.OpenAIGPTConfig().get_supported_openai_params(model=model)
 
-        if (
-            model in litellm.open_ai_chat_completion_models
-        ) or model in litellm.open_ai_text_completion_models:
-            model_specific_params.append(
-                "user"
-            )  # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
-        return base_params + model_specific_params
+    def _map_openai_params(
+        self, non_default_params: dict, optional_params: dict, model: str
+    ) -> dict:
+        supported_openai_params = self.get_supported_openai_params(model)
+        for param, value in non_default_params.items():
+            if param in supported_openai_params:
+                optional_params[param] = value
+        return optional_params
 
     def map_openai_params(
         self, non_default_params: dict, optional_params: dict, model: str

@@ -575,11 +434,11 @@ class OpenAIConfig:
             optional_params=optional_params,
             model=model,
         )
-        supported_openai_params = self.get_supported_openai_params(model)
-        for param, value in non_default_params.items():
-            if param in supported_openai_params:
-                optional_params[param] = value
-        return optional_params
+        return litellm.OpenAIGPTConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+        )
 
 
 class OpenAITextCompletionConfig:
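The net effect of the two `OpenAIConfig` changes above is pure dispatch: o1 models get the restricted o1 parameter set, everything else falls through to `OpenAIGPTConfig`. A quick sketch; the model names are illustrative:

```python
import litellm

cfg = litellm.OpenAIConfig()
o1_params = cfg.get_supported_openai_params(model="o1-preview")
gpt_params = cfg.get_supported_openai_params(model="gpt-4o")
# o1 drops params such as `logprobs` and `tools` (per non_supported_params above);
# gpt-4o keeps the full GPT param set.
print(sorted(set(gpt_params) - set(o1_params)))
```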
@@ -816,18 +675,18 @@ class OpenAIChatCompletion(BaseLLM):
         except Exception as e:
             raise e
 
-    def completion(
+    def completion(  # type: ignore
         self,
         model_response: ModelResponse,
         timeout: Union[float, httpx.Timeout],
         optional_params: dict,
+        logging_obj: Any,
         model: Optional[str] = None,
         messages: Optional[list] = None,
         print_verbose: Optional[Callable] = None,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         acompletion: bool = False,
-        logging_obj=None,
         litellm_params=None,
         logger_fn=None,
         headers: Optional[dict] = None,

@@ -858,14 +717,14 @@ class OpenAIChatCompletion(BaseLLM):
             # process all OpenAI compatible provider logic here
             if custom_llm_provider == "mistral":
                 # check if message content passed in as list, and not string
-                messages = prompt_factory(
+                messages = prompt_factory(  # type: ignore
                     model=model,
                     messages=messages,
                     custom_llm_provider=custom_llm_provider,
                 )
             if custom_llm_provider == "perplexity" and messages is not None:
                 # check if messages.name is passed + supported, if not supported remove
-                messages = prompt_factory(
+                messages = prompt_factory(  # type: ignore
                     model=model,
                     messages=messages,
                     custom_llm_provider=custom_llm_provider,

@@ -933,7 +792,7 @@ class OpenAIChatCompletion(BaseLLM):
                     status_code=422, message="max retries must be an int"
                 )
 
-            openai_client = self._get_openai_client(
+            openai_client: OpenAI = self._get_openai_client(  # type: ignore
                 is_async=False,
                 api_key=api_key,
                 api_base=api_base,

@@ -1068,7 +927,7 @@ class OpenAIChatCompletion(BaseLLM):
             2
         ):  # if call fails due to alternating messages, retry with reformatted message
             try:
-                openai_aclient = self._get_openai_client(
+                openai_aclient: AsyncOpenAI = self._get_openai_client(  # type: ignore
                     is_async=True,
                     api_key=api_key,
                     api_base=api_base,

@@ -1156,7 +1015,7 @@ class OpenAIChatCompletion(BaseLLM):
         max_retries=None,
         headers=None,
     ):
-        openai_client = self._get_openai_client(
+        openai_client: OpenAI = self._get_openai_client(  # type: ignore
             is_async=False,
             api_key=api_key,
             api_base=api_base,

@@ -1210,7 +1069,7 @@ class OpenAIChatCompletion(BaseLLM):
         response = None
         for _ in range(2):
             try:
-                openai_aclient = self._get_openai_client(
+                openai_aclient: AsyncOpenAI = self._get_openai_client(  # type: ignore
                     is_async=True,
                     api_key=api_key,
                     api_base=api_base,

@@ -1282,7 +1141,7 @@ class OpenAIChatCompletion(BaseLLM):
                 error_headers = getattr(e, "headers", None)
                 raise OpenAIError(
                     status_code=500,
-                    message=f"{str(e)}\n\nOriginal Response: {response.text}",
+                    message=f"{str(e)}\n\nOriginal Response: {response.text}",  # type: ignore
                     headers=error_headers,
                 )
             else:

@@ -1294,7 +1153,7 @@ class OpenAIChatCompletion(BaseLLM):
                 )
             elif hasattr(e, "status_code"):
                 raise OpenAIError(
-                    status_code=e.status_code,
+                    status_code=getattr(e, "status_code", 500),
                     message=str(e),
                     headers=error_headers,
                 )

@@ -1361,7 +1220,7 @@ class OpenAIChatCompletion(BaseLLM):
     ):
         response = None
         try:
-            openai_aclient = self._get_openai_client(
+            openai_aclient: AsyncOpenAI = self._get_openai_client(  # type: ignore
                 is_async=True,
                 api_key=api_key,
                 api_base=api_base,

@@ -1410,16 +1269,16 @@ class OpenAIChatCompletion(BaseLLM):
                 status_code=status_code, message=str(e), headers=error_headers
             )
 
-    def embedding(
+    def embedding(  # type: ignore
         self,
         model: str,
         input: list,
         timeout: float,
         logging_obj,
         model_response: litellm.utils.EmbeddingResponse,
+        optional_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
-        optional_params=None,
         client=None,
         aembedding=None,
     ):

@@ -1452,7 +1311,7 @@ class OpenAIChatCompletion(BaseLLM):
                 )
                 return response
 
-            openai_client = self._get_openai_client(
+            openai_client: OpenAI = self._get_openai_client(  # type: ignore
                 is_async=False,
                 api_key=api_key,
                 api_base=api_base,

@@ -1496,11 +1355,11 @@ class OpenAIChatCompletion(BaseLLM):
         data: dict,
         model_response: ModelResponse,
         timeout: float,
+        logging_obj: Any,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         client=None,
         max_retries=None,
-        logging_obj=None,
     ):
         response = None
         try:

@@ -1538,15 +1397,16 @@ class OpenAIChatCompletion(BaseLLM):
         model: Optional[str],
         prompt: str,
         timeout: float,
+        optional_params: dict,
+        logging_obj: Any,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         model_response: Optional[litellm.utils.ImageResponse] = None,
-        logging_obj=None,
-        optional_params=None,
         client=None,
         aimg_generation=None,
     ):
         exception_mapping_worked = False
+        data = {}
         try:
             model = model
             data = {"model": model, "prompt": prompt, **optional_params}

@@ -1611,7 +1471,9 @@ class OpenAIChatCompletion(BaseLLM):
                 original_response=str(e),
             )
             if hasattr(e, "status_code"):
-                raise OpenAIError(status_code=e.status_code, message=str(e))
+                raise OpenAIError(
+                    status_code=getattr(e, "status_code", 500), message=str(e)
+                )
             else:
                 raise OpenAIError(status_code=500, message=str(e))
 

@@ -1661,7 +1523,7 @@ class OpenAIChatCompletion(BaseLLM):
                 input=input,
                 **optional_params,
             )
-            return response
+            return response  # type: ignore
 
     async def async_audio_speech(
         self,

@@ -1784,11 +1646,8 @@ class OpenAIChatCompletion(BaseLLM):
 
 
 class OpenAITextCompletion(BaseLLM):
-    _client_session: httpx.Client
-
     def __init__(self) -> None:
         super().__init__()
-        self._client_session = self.create_client_session()
 
     def validate_environment(self, api_key):
         headers = {

@@ -1806,10 +1665,10 @@ class OpenAITextCompletion(BaseLLM):
         messages: list,
         timeout: float,
         logging_obj: LiteLLMLoggingObj,
+        optional_params: dict,
         print_verbose: Optional[Callable] = None,
         api_base: Optional[str] = None,
         acompletion: bool = False,
-        optional_params=None,
         litellm_params=None,
         logger_fn=None,
         client=None,

@@ -1921,7 +1780,7 @@ class OpenAITextCompletion(BaseLLM):
         api_key: str,
         model: str,
         timeout: float,
-        max_retries=None,
+        max_retries: int,
         organization: Optional[str] = None,
         client=None,
     ):

@@ -2017,9 +1876,9 @@ class OpenAITextCompletion(BaseLLM):
         model_response: ModelResponse,
         model: str,
         timeout: float,
+        max_retries: int,
         api_base: Optional[str] = None,
         client=None,
-        max_retries=None,
         organization=None,
     ):
         if client is None:
litellm/llms/README.md (new file, 12 lines)
@@ -0,0 +1,12 @@
## File Structure

### August 27th, 2024

To make it easy to see how calls are transformed for each model/provider:

we are working on moving all supported litellm providers to a folder structure, where folder name is the supported litellm provider name.

Each folder will contain a `*_transformation.py` file, which has all the request/response transformation logic, making it easy to see how calls are modified.

E.g. `cohere/`, `bedrock/`.
@@ -66,22 +66,24 @@ class BaseLLM:
         return _aclient_session
 
     def __exit__(self):
-        if hasattr(self, "_client_session"):
+        if hasattr(self, "_client_session") and self._client_session is not None:
             self._client_session.close()
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         if hasattr(self, "_aclient_session"):
-            await self._aclient_session.aclose()
+            await self._aclient_session.aclose()  # type: ignore
 
-    def validate_environment(self):  # set up the environment required to run the model
-        pass
+    def validate_environment(
+        self, *args, **kwargs
+    ) -> Optional[Any]:  # set up the environment required to run the model
+        return None
 
     def completion(
         self, *args, **kwargs
-    ):  # logic for parsing in - calling - parsing out model completion calls
-        pass
+    ) -> Any:  # logic for parsing in - calling - parsing out model completion calls
+        return None
 
     def embedding(
         self, *args, **kwargs
-    ):  # logic for parsing in - calling - parsing out model embedding calls
-        pass
+    ) -> Any:  # logic for parsing in - calling - parsing out model embedding calls
+        return None
litellm/llms/mistral/chat.py (new file, 5 lines)
@@ -0,0 +1,5 @@
"""
Calls handled in openai/

as mistral is an openai-compatible endpoint.
"""

litellm/llms/mistral/embedding.py (new file, 5 lines)
@@ -0,0 +1,5 @@
"""
Calls handled in openai/

as mistral is an openai-compatible endpoint.
"""

litellm/llms/mistral/mistral_chat_transformation.py (new file, 126 lines)
@@ -0,0 +1,126 @@
"""
Transformation logic from OpenAI /v1/chat/completion format to Mistral's /chat/completion format.

Why separate file? Make it easy to see how transformation works

Docs - https://docs.mistral.ai/api/
"""

import types
from typing import List, Literal, Optional, Union


class MistralConfig:
    """
    Reference: https://docs.mistral.ai/api/

    The class `MistralConfig` provides configuration for the Mistral's Chat API interface. Below are the parameters:

    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. API Default - 0.7.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. API Default - 1.

    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. API Default - null.

    - `tools` (list or null): A list of available tools for the model. Use this to specify functions for which the model can generate JSON inputs.

    - `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.

    - `stop` (string or array of strings): Stop generation if this token is detected. Or if one of these tokens is detected when providing an array

    - `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.

    - `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.

    - `response_format` (object or null): An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message.
    """

    temperature: Optional[int] = None
    top_p: Optional[int] = None
    max_tokens: Optional[int] = None
    tools: Optional[list] = None
    tool_choice: Optional[Literal["auto", "any", "none"]] = None
    random_seed: Optional[int] = None
    safe_prompt: Optional[bool] = None
    response_format: Optional[dict] = None
    stop: Optional[Union[str, list]] = None

    def __init__(
        self,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        max_tokens: Optional[int] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[Literal["auto", "any", "none"]] = None,
        random_seed: Optional[int] = None,
        safe_prompt: Optional[bool] = None,
        response_format: Optional[dict] = None,
        stop: Optional[Union[str, list]] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return [
            "stream",
            "temperature",
            "top_p",
            "max_tokens",
            "tools",
            "tool_choice",
            "seed",
            "stop",
            "response_format",
        ]

    def _map_tool_choice(self, tool_choice: str) -> str:
        if tool_choice == "auto" or tool_choice == "none":
            return tool_choice
        elif tool_choice == "required":
            return "any"
        else:  # openai 'tool_choice' object param not supported by Mistral API
            return "any"

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        for param, value in non_default_params.items():
            if param == "max_tokens":
                optional_params["max_tokens"] = value
            if param == "tools":
                optional_params["tools"] = value
            if param == "stream" and value is True:
                optional_params["stream"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "tool_choice" and isinstance(value, str):
                optional_params["tool_choice"] = self._map_tool_choice(
                    tool_choice=value
                )
            if param == "seed":
                optional_params["extra_body"] = {"random_seed": value}
            if param == "response_format":
                optional_params["response_format"] = value
        return optional_params

litellm/llms/mistral/mistral_embedding_transformation.py (new file, empty)
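To see the moved transformation logic in action, a minimal sketch exercising `map_openai_params` from the new module; the parameter values are illustrative:

```python
import litellm

cfg = litellm.MistralConfig()
optional_params = cfg.map_openai_params(
    non_default_params={"max_tokens": 100, "seed": 42, "tool_choice": "required"},
    optional_params={},
)
# `seed` becomes Mistral's `random_seed` via extra_body, and the OpenAI
# "required" tool_choice is translated to Mistral's "any".
print(optional_params)
# {'max_tokens': 100, 'tool_choice': 'any', 'extra_body': {'random_seed': 42}}
```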
@ -737,6 +737,7 @@ def completion(
     preset_cache_key = kwargs.get("preset_cache_key", None)
     hf_model_name = kwargs.get("hf_model_name", None)
     supports_system_message = kwargs.get("supports_system_message", None)
+    base_model = kwargs.get("base_model", None)
     ### TEXT COMPLETION CALLS ###
     text_completion = kwargs.get("text_completion", False)
     atext_completion = kwargs.get("atext_completion", False)
@ -782,11 +783,9 @@ def completion(
         "top_logprobs",
         "extra_headers",
     ]
-    litellm_params = (
-        all_litellm_params  # use the external var., used in creating cache key as well.
-    )

-    default_params = openai_params + litellm_params
+    default_params = openai_params + all_litellm_params
+    litellm_params = {}  # used to prevent unbound var errors
     non_default_params = {
         k: v for k, v in kwargs.items() if k not in default_params
     }  # model-specific params - pass them straight to the model/provider
@ -973,6 +972,7 @@ def completion(
         text_completion=kwargs.get("text_completion"),
         azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
         user_continue_message=kwargs.get("user_continue_message"),
+        base_model=base_model,
     )
     logging.update_environment_variables(
         model=model,
@ -2123,7 +2123,10 @@ def completion(
                 timeout=timeout,
                 client=client,
             )
-        elif "gemini" in model:
+        elif "gemini" in model or (
+            litellm_params.get("base_model") is not None
+            and "gemini" in litellm_params["base_model"]
+        ):
             model_response = vertex_chat_completion.completion(  # type: ignore
                 model=model,
                 messages=messages,
@ -2820,7 +2823,7 @@ def completion_with_retries(*args, **kwargs):
         )

     num_retries = kwargs.pop("num_retries", 3)
-    retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
+    retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry")  # type: ignore
     original_function = kwargs.pop("original_function", completion)
     if retry_strategy == "constant_retry":
         retryer = tenacity.Retrying(
@ -4997,7 +5000,9 @@ def speech(
 async def ahealth_check(
     model_params: dict,
     mode: Optional[
-        Literal["completion", "embedding", "image_generation", "chat", "batch"]
+        Literal[
+            "completion", "embedding", "image_generation", "chat", "batch", "rerank"
+        ]
     ] = None,
     prompt: Optional[str] = None,
     input: Optional[List] = None,
@ -5113,6 +5118,12 @@ async def ahealth_check(
             model_params["prompt"] = prompt
             await litellm.aimage_generation(**model_params)
             response = {}
+        elif mode == "rerank":
+            model_params.pop("messages", None)
+            model_params["query"] = prompt
+            model_params["documents"] = ["my sample text"]
+            await litellm.arerank(**model_params)
+            response = {}
         elif "*" in model:
             from litellm.litellm_core_utils.llm_request_utils import (
                 pick_cheapest_model_from_llm_provider,
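A rough sketch of exercising the new "rerank" health-check mode added above. The model name is a placeholder, and ahealth_check is assumed to be importable from the litellm package:

    import asyncio

    import litellm

    async def main():
        # Placeholder rerank deployment; any provider supported by litellm.arerank should work.
        result = await litellm.ahealth_check(
            model_params={"model": "cohere/rerank-english-v3.0"},
            mode="rerank",
            prompt="hello world",
        )
        print(result)

    asyncio.run(main())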
@ -1,9 +1,7 @@
 model_list:
-  - model_name: "gpt-4o"
+  - model_name: gpt-3.5-turbo
     litellm_params:
-      model: gpt-4o
+      model: gpt-3.5-turbo

-litellm_settings:
-  cache: true
-  cache_params:
-    type: local
+router_settings:
+  model_group_alias: {"gpt-4": {"model": "gpt-3.5-turbo", "hidden": false}}
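For context on the model_group_alias entry above: per the commit message, an alias can also be marked "hidden": true to keep it out of the proxy's public model listings while still routing requests. A plain-Python illustration of that filtering idea (not the proxy's actual implementation; the second alias is hypothetical):

    model_group_alias = {
        "gpt-4": {"model": "gpt-3.5-turbo", "hidden": False},
        "internal-alias": {"model": "gpt-3.5-turbo", "hidden": True},  # hypothetical hidden alias
    }

    # Aliases a /v1/models-style listing would expose under this scheme:
    visible = [name for name, cfg in model_group_alias.items() if not cfg.get("hidden", False)]
    print(visible)  # ['gpt-4']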
@ -99,15 +99,15 @@ class LitellmUserRoles(str, enum.Enum):
         return ui_labels.get(self.value, "")


-class LitellmTableNames(str, enum.Enum):
+class LitellmTableNames(enum.Enum):
     """
     Enum for Table Names used by LiteLLM
     """

-    TEAM_TABLE_NAME: str = "LiteLLM_TeamTable"
-    USER_TABLE_NAME: str = "LiteLLM_UserTable"
-    KEY_TABLE_NAME: str = "LiteLLM_VerificationToken"
-    PROXY_MODEL_TABLE_NAME: str = "LiteLLM_ModelTable"
+    TEAM_TABLE_NAME = "LiteLLM_TeamTable"
+    USER_TABLE_NAME = "LiteLLM_UserTable"
+    KEY_TABLE_NAME = "LiteLLM_VerificationToken"
+    PROXY_MODEL_TABLE_NAME = "LiteLLM_ModelTable"


 AlertType = Literal[
@ -140,7 +140,7 @@ class LiteLLMBase(BaseModel):
     Implements default functions, all pydantic objects should have.
     """

-    def json(self, **kwargs):
+    def json(self, **kwargs):  # type: ignore
         try:
             return self.model_dump(**kwargs)  # noqa
         except Exception as e:
@ -170,7 +170,7 @@ class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):


 class LiteLLMRoutes(enum.Enum):
-    openai_route_names: List = [
+    openai_route_names = [
         "chat_completion",
         "completion",
         "embeddings",
@ -179,7 +179,7 @@ class LiteLLMRoutes(enum.Enum):
         "moderations",
         "model_list",  # OpenAI /v1/models route
     ]
-    openai_routes: List = [
+    openai_routes = [
         # chat completions
         "/engines/{model}/chat/completions",
         "/openai/deployments/{model}/chat/completions",
@ -247,18 +247,18 @@ class LiteLLMRoutes(enum.Enum):
         "/v1/rerank",
     ]

-    mapped_pass_through_routes: List = [
+    mapped_pass_through_routes = [
         "/bedrock",
         "/vertex-ai",
         "/gemini",
         "/langfuse",
     ]

-    anthropic_routes: List = [
+    anthropic_routes = [
         "/v1/messages",
     ]

-    info_routes: List = [
+    info_routes = [
         "/key/info",
         "/team/info",
         "/team/list",
@ -271,9 +271,9 @@ class LiteLLMRoutes(enum.Enum):
     ]

     # NOTE: ROUTES ONLY FOR MASTER KEY - only the Master Key should be able to Reset Spend
-    master_key_only_routes: List = ["/global/spend/reset", "/key/list"]
+    master_key_only_routes = ["/global/spend/reset", "/key/list"]

-    sso_only_routes: List = [
+    sso_only_routes = [
         "/key/generate",
         "/key/update",
         "/key/delete",
@ -282,7 +282,7 @@ class LiteLLMRoutes(enum.Enum):
         "/sso/get/logout_url",
     ]

-    management_routes: List = [  # key
+    management_routes = [  # key
         "/key/generate",
         "/key/update",
         "/key/delete",
@ -307,7 +307,7 @@ class LiteLLMRoutes(enum.Enum):
         "/model/info",
     ]

-    spend_tracking_routes: List = [
+    spend_tracking_routes = [
         # spend
         "/spend/keys",
         "/spend/users",
@ -316,7 +316,7 @@ class LiteLLMRoutes(enum.Enum):
         "/spend/logs",
     ]

-    global_spend_tracking_routes: List = [
+    global_spend_tracking_routes = [
         # global spend
         "/global/spend/logs",
         "/global/spend",
@ -328,7 +328,7 @@ class LiteLLMRoutes(enum.Enum):
         "/global/spend/report",
     ]

-    public_routes: List = [
+    public_routes = [
         "/routes",
         "/",
         "/health/liveliness",
@ -339,7 +339,7 @@ class LiteLLMRoutes(enum.Enum):
         "/metrics",
     ]

-    internal_user_routes: List = (
+    internal_user_routes = (
         [
             "/key/generate",
             "/key/update",
@ -357,7 +357,7 @@ class LiteLLMRoutes(enum.Enum):
         + sso_only_routes
     )

-    self_managed_routes: List = [
+    self_managed_routes = [
         "/team/member_add",
         "/team/member_delete",
     ]  # routes that manage their own allowed/disallowed logic
@ -581,7 +581,9 @@ class ModelParams(LiteLLMBase):
     @classmethod
     def set_model_info(cls, values):
         if values.get("model_info") is None:
-            values.update({"model_info": ModelInfo()})
+            values.update(
+                {"model_info": ModelInfo(id=None, mode="chat", base_model=None)}
+            )
         return values


@ -627,7 +629,7 @@ class GenerateKeyRequest(_GenerateKeyRequest):


 class GenerateKeyResponse(_GenerateKeyRequest):
-    key: str
+    key: str  # type: ignore
     key_name: Optional[str] = None
     expires: Optional[datetime]
     user_id: Optional[str] = None
@ -659,7 +661,7 @@ class GenerateKeyResponse(_GenerateKeyRequest):
 class UpdateKeyRequest(GenerateKeyRequest):
     # Note: the defaults of all Params here MUST BE NONE
     # else they will get overwritten
-    key: str
+    key: str  # type: ignore
     duration: Optional[str] = None
     spend: Optional[float] = None
     metadata: Optional[dict] = None
@ -976,6 +978,7 @@ class TeamCallbackMetadata(LiteLLMBase):


 class LiteLLM_TeamTable(TeamBase):
+    team_id: str  # type: ignore
     spend: Optional[float] = None
     max_parallel_requests: Optional[int] = None
     budget_duration: Optional[str] = None
@ -1061,7 +1064,7 @@ class LiteLLM_OrganizationTable(LiteLLMBase):


 class NewOrganizationResponse(LiteLLM_OrganizationTable):
-    organization_id: str
+    organization_id: str  # type: ignore
     created_at: datetime
     updated_at: datetime

@ -1388,16 +1391,7 @@ class UserAPIKeyAuth(
     """

     api_key: Optional[str] = None
-    user_role: Optional[
-        Literal[
-            LitellmUserRoles.PROXY_ADMIN,
-            LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
-            LitellmUserRoles.INTERNAL_USER,
-            LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
-            LitellmUserRoles.TEAM,
-            LitellmUserRoles.CUSTOMER,
-        ]
-    ] = None
+    user_role: Optional[LitellmUserRoles] = None
     allowed_model_region: Optional[Literal["eu"]] = None
     parent_otel_span: Optional[Span] = None
     rpm_limit_per_model: Optional[Dict[str, int]] = None
@ -1716,9 +1710,9 @@ class SpendLogsPayload(TypedDict):
     total_tokens: int
     prompt_tokens: int
     completion_tokens: int
-    startTime: datetime
-    endTime: datetime
-    completionStartTime: Optional[datetime]
+    startTime: Union[datetime, str]
+    endTime: Union[datetime, str]
+    completionStartTime: Optional[Union[datetime, str]]
     model: str
     model_id: Optional[str]
     model_group: Optional[str]
@ -1891,6 +1885,6 @@ class TeamAddMemberResponse(LiteLLM_TeamTable):

 class TeamInfoResponseObject(TypedDict):
     team_id: str
-    team_info: TeamBase
+    team_info: LiteLLM_TeamTable
     keys: List
     team_memberships: List[LiteLLM_TeamMembership]
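One behavioral nuance of the LitellmTableNames change above: dropping the str mixin and the ": str" annotations turns the attributes into true enum members, so callers need .value to get the table-name string. A minimal stand-alone reproduction (not the real import):

    import enum

    class LitellmTableNames(enum.Enum):
        TEAM_TABLE_NAME = "LiteLLM_TeamTable"

    print(LitellmTableNames.TEAM_TABLE_NAME.value)             # LiteLLM_TeamTable
    print(isinstance(LitellmTableNames.TEAM_TABLE_NAME, str))  # False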
@ -109,11 +109,8 @@ async def user_api_key_auth(
     ),
 ) -> UserAPIKeyAuth:
     from litellm.proxy.proxy_server import (
-        allowed_routes_check,
-        common_checks,
         custom_db_client,
         general_settings,
-        get_actual_routes,
         jwt_handler,
         litellm_proxy_admin_name,
         llm_model_list,
@ -125,6 +122,8 @@ async def user_api_key_auth(
         user_custom_auth,
     )

+    parent_otel_span: Optional[Span] = None
+
     try:
         route: str = get_request_route(request=request)
         # get the request body
@ -137,6 +136,7 @@ async def user_api_key_auth(
         pass_through_endpoints: Optional[List[dict]] = general_settings.get(
             "pass_through_endpoints", None
         )
+        passed_in_key: Optional[str] = None
         if isinstance(api_key, str):
             passed_in_key = api_key
             api_key = _get_bearer_token(api_key=api_key)
@ -161,7 +161,6 @@ async def user_api_key_auth(
             custom_litellm_key_header_name=custom_litellm_key_header_name,
         )

-        parent_otel_span: Optional[Span] = None
         if open_telemetry_logger is not None:
             parent_otel_span = open_telemetry_logger.tracer.start_span(
                 name="Received Proxy Server Request",
@ -189,7 +188,7 @@ async def user_api_key_auth(

         ######## Route Checks Before Reading DB / Cache for "token" ################
         if (
-            route in LiteLLMRoutes.public_routes.value
+            route in LiteLLMRoutes.public_routes.value  # type: ignore
             or route_in_additonal_public_routes(current_route=route)
         ):
             # check if public endpoint
@ -410,7 +409,7 @@ async def user_api_key_auth(
         #### ELSE ####
         ## CHECK PASS-THROUGH ENDPOINTS ##
         is_mapped_pass_through_route: bool = False
-        for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value:
+        for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value:  # type: ignore
             if route.startswith(mapped_route):
                 is_mapped_pass_through_route = True
         if is_mapped_pass_through_route:
@ -444,9 +443,9 @@ async def user_api_key_auth(
             header_key = headers.get("litellm_user_api_key", "")
             if (
                 isinstance(request.headers, dict)
-                and request.headers.get(key=header_key) is not None
+                and request.headers.get(key=header_key) is not None  # type: ignore
             ):
-                api_key = request.headers.get(key=header_key)
+                api_key = request.headers.get(key=header_key)  # type: ignore

         if master_key is None:
             if isinstance(api_key, str):
@ -606,7 +605,7 @@ async def user_api_key_auth(

             ## IF it's not a master key
             ## Route should not be in master_key_only_routes
-            if route in LiteLLMRoutes.master_key_only_routes.value:
+            if route in LiteLLMRoutes.master_key_only_routes.value:  # type: ignore
                 raise Exception(
                     f"Tried to access route={route}, which is only for MASTER KEY"
                 )
@ -669,8 +668,9 @@ async def user_api_key_auth(
                 "allowed_model_region"
             )

+        user_obj: Optional[LiteLLM_UserTable] = None
+        valid_token_dict: dict = {}
         if valid_token is not None:
-            user_obj: Optional[LiteLLM_UserTable] = None
             # Got Valid Token from Cache, DB
             # Run checks for
             # 1. If token can call model
@ -686,6 +686,7 @@ async def user_api_key_auth(

             # Check 1. If token can call model
             _model_alias_map = {}
+            model: Optional[str] = None
             if (
                 hasattr(valid_token, "team_model_aliases")
                 and valid_token.team_model_aliases is not None
@ -698,6 +699,7 @@ async def user_api_key_auth(
                     _model_alias_map = {**valid_token.aliases}
                 litellm.model_alias_map = _model_alias_map
                 config = valid_token.config
+
                 if config != {}:
                     model_list = config.get("model_list", [])
                     llm_model_list = model_list
@ -887,7 +889,10 @@ async def user_api_key_auth(
                     and max_budget_per_model.get(current_model, None) is not None
                 ):
                     if (
-                        model_spend[0]["model"] == current_model
+                        "model" in model_spend[0]
+                        and model_spend[0].get("model") == current_model
+                        and "_sum" in model_spend[0]
+                        and "spend" in model_spend[0]["_sum"]
                         and model_spend[0]["_sum"]["spend"]
                         >= max_budget_per_model[current_model]
                     ):
@ -927,16 +932,19 @@ async def user_api_key_auth(
             )

             # Check 8: Additional Common Checks across jwt + key auth
-            _team_obj = LiteLLM_TeamTable(
-                team_id=valid_token.team_id,
-                max_budget=valid_token.team_max_budget,
-                spend=valid_token.team_spend,
-                tpm_limit=valid_token.team_tpm_limit,
-                rpm_limit=valid_token.team_rpm_limit,
-                blocked=valid_token.team_blocked,
-                models=valid_token.team_models,
-                metadata=valid_token.team_metadata,
-            )
+            if valid_token.team_id is not None:
+                _team_obj: Optional[LiteLLM_TeamTable] = LiteLLM_TeamTable(
+                    team_id=valid_token.team_id,
+                    max_budget=valid_token.team_max_budget,
+                    spend=valid_token.team_spend,
+                    tpm_limit=valid_token.team_tpm_limit,
+                    rpm_limit=valid_token.team_rpm_limit,
+                    blocked=valid_token.team_blocked,
+                    models=valid_token.team_models,
+                    metadata=valid_token.team_metadata,
+                )
+            else:
+                _team_obj = None

             user_api_key_cache.set_cache(
                 key=valid_token.team_id, value=_team_obj
@ -1045,7 +1053,7 @@ async def user_api_key_auth(
                     "/global/predict/spend/logs",
                     "/global/activity",
                     "/health/services",
-                ] + LiteLLMRoutes.info_routes.value
+                ] + LiteLLMRoutes.info_routes.value  # type: ignore
                 # check if the current route startswith any of the allowed routes
                 if (
                     route is not None
@ -1106,7 +1114,7 @@ async def user_api_key_auth(

         # Log this exception to OTEL
         if open_telemetry_logger is not None:
-            await open_telemetry_logger.async_post_call_failure_hook(
+            await open_telemetry_logger.async_post_call_failure_hook(  # type: ignore
                 original_exception=e,
                 user_api_key_dict=UserAPIKeyAuth(parent_otel_span=parent_otel_span),
             )
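The hardened per-model budget check above assumes spend rows shaped like a Prisma group-by result. A small stand-alone sketch of the same guard with illustrative values:

    model_spend = [{"model": "gpt-4o", "_sum": {"spend": 12.5}}]  # illustrative row
    max_budget_per_model = {"gpt-4o": 10.0}
    current_model = "gpt-4o"

    over_budget = (
        "model" in model_spend[0]
        and model_spend[0].get("model") == current_model
        and "_sum" in model_spend[0]
        and "spend" in model_spend[0]["_sum"]
        and model_spend[0]["_sum"]["spend"] >= max_budget_per_model[current_model]
    )
    print(over_budget)  # True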
@ -4,10 +4,11 @@ import json
 import traceback
 import uuid
 from datetime import datetime, timedelta, timezone
-from typing import List, Optional
+from typing import List, Optional, Union

 import fastapi
 from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
+from pydantic import BaseModel

 import litellm
 from litellm._logging import verbose_proxy_logger
@ -28,6 +29,7 @@ from litellm.proxy._types import (
     ProxyErrorTypes,
     ProxyException,
     TeamAddMemberResponse,
+    TeamBase,
     TeamInfoResponseObject,
     TeamMemberAddRequest,
     TeamMemberDeleteRequest,
@ -36,6 +38,7 @@ from litellm.proxy._types import (
     UpdateTeamRequest,
     UserAPIKeyAuth,
 )
+from litellm.proxy.auth.auth_checks import get_team_object
 from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin, user_api_key_auth
 from litellm.proxy.management_helpers.utils import (
     add_new_member,
@ -240,7 +243,7 @@ async def new_team(
         reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
         complete_team_data.budget_reset_at = reset_at

-    team_row = await prisma_client.insert_data(
+    team_row: LiteLLM_TeamTable = await prisma_client.insert_data(  # type: ignore
         data=complete_team_data.json(exclude_none=True), table_name="team"
     )

@ -462,9 +465,6 @@ async def team_member_add(
     ```
     """
     from litellm.proxy.proxy_server import (
-        _duration_in_seconds,
-        create_audit_log_for_update,
-        get_team_object,
         litellm_proxy_admin_name,
         prisma_client,
         proxy_logging_obj,
@ -932,10 +932,13 @@ async def delete_team(
     if litellm.store_audit_logs is True:
         # make an audit log for each team deleted
         for team_id in data.team_ids:
-            team_row = await prisma_client.get_data(  # type: ignore
+            team_row: Optional[LiteLLM_TeamTable] = await prisma_client.get_data(  # type: ignore
                 team_id=team_id, table_name="team", query_type="find_unique"
             )

+            if team_row is None:
+                continue
+
             _team_row = team_row.json(exclude_none=True)

             asyncio.create_task(
@ -1027,8 +1030,10 @@ async def team_info(
             ),
         )

-    team_info = await prisma_client.get_data(
-        team_id=team_id, table_name="team", query_type="find_unique"
+    team_info: Optional[Union[LiteLLM_TeamTable, dict]] = (
+        await prisma_client.get_data(
+            team_id=team_id, table_name="team", query_type="find_unique"
+        )
     )
     if team_info is None:
         raise HTTPException(
@ -1044,6 +1049,9 @@ async def team_info(
         expires=datetime.now(),
     )

+    if keys is None:
+        keys = []
+
     if team_info is None:
         ## make sure we still return a total spend ##
         spend = 0
@ -1055,7 +1063,7 @@ async def team_info(
     for key in keys:
         try:
             key = key.model_dump()  # noqa
-        except:
+        except Exception:
             # if using pydantic v1
             key = key.dict()
         key.pop("token", None)
@ -1070,9 +1078,16 @@ async def team_info(
     for tm in team_memberships:
         returned_tm.append(LiteLLM_TeamMembership(**tm.model_dump()))

+    if isinstance(team_info, dict):
+        _team_info = LiteLLM_TeamTable(**team_info)
+    elif isinstance(team_info, BaseModel):
+        _team_info = LiteLLM_TeamTable(**team_info.model_dump())
+    else:
+        _team_info = LiteLLM_TeamTable()
+
     response_object = TeamInfoResponseObject(
         team_id=team_id,
-        team_info=team_info,
+        team_info=_team_info,
         keys=keys,
         team_memberships=returned_tm,
     )
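The /team/info change above normalizes team_info (which may come back as a dict or a pydantic model) into LiteLLM_TeamTable before building the response. A generic sketch of that pattern, with Team standing in for the real model:

    from typing import Optional, Union

    from pydantic import BaseModel

    class Team(BaseModel):
        team_id: Optional[str] = None

    def normalize_team_info(team_info: Optional[Union[Team, dict]]) -> Team:
        if isinstance(team_info, dict):
            return Team(**team_info)
        elif isinstance(team_info, BaseModel):
            return Team(**team_info.model_dump())
        return Team()

    print(normalize_team_info({"team_id": "my-team"}).team_id)  # my-team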
@ -125,16 +125,7 @@ from litellm.proxy._types import *
 from litellm.proxy.analytics_endpoints.analytics_endpoints import (
     router as analytics_router,
 )
-from litellm.proxy.auth.auth_checks import (
-    allowed_routes_check,
-    common_checks,
-    get_actual_routes,
-    get_end_user_object,
-    get_org_object,
-    get_team_object,
-    get_user_object,
-    log_to_opentelemetry,
-)
+from litellm.proxy.auth.auth_checks import log_to_opentelemetry
 from litellm.proxy.auth.auth_utils import check_response_size_is_safe
 from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.auth.litellm_license import LicenseCheck
@ -260,6 +251,7 @@ from litellm.secret_managers.aws_secret_manager import (
     load_aws_secret_manager,
 )
 from litellm.secret_managers.google_kms import load_google_kms
+from litellm.secret_managers.main import get_secret
 from litellm.types.llms.anthropic import (
     AnthropicMessagesRequest,
     AnthropicResponse,
@ -484,7 +476,7 @@ general_settings: dict = {}
 callback_settings: dict = {}
 log_file = "api_log.json"
 worker_config = None
-master_key = None
+master_key: Optional[str] = None
 otel_logging = False
 prisma_client: Optional[PrismaClient] = None
 custom_db_client: Optional[DBClient] = None
@ -874,7 +866,9 @@ def error_tracking():


 def _set_spend_logs_payload(
-    payload: dict, prisma_client: PrismaClient, spend_logs_url: Optional[str] = None
+    payload: Union[dict, SpendLogsPayload],
+    prisma_client: PrismaClient,
+    spend_logs_url: Optional[str] = None,
 ):
     if prisma_client is not None and spend_logs_url is not None:
         if isinstance(payload["startTime"], datetime):
@ -1341,6 +1335,9 @@ async def _run_background_health_check():
     # make 1 deep copy of llm_model_list -> use this for all background health checks
     _llm_model_list = copy.deepcopy(llm_model_list)

+    if _llm_model_list is None:
+        return
+
     while True:
         healthy_endpoints, unhealthy_endpoints = await perform_health_check(
             model_list=_llm_model_list, details=health_check_details
@ -1352,7 +1349,10 @@ async def _run_background_health_check():
         health_check_results["healthy_count"] = len(healthy_endpoints)
         health_check_results["unhealthy_count"] = len(unhealthy_endpoints)

-        await asyncio.sleep(health_check_interval)
+        if health_check_interval is not None and isinstance(
+            health_check_interval, float
+        ):
+            await asyncio.sleep(health_check_interval)


 class ProxyConfig:
@ -1467,7 +1467,7 @@ class ProxyConfig:
                 break
         for k, v in team_config.items():
             if isinstance(v, str) and v.startswith("os.environ/"):
-                team_config[k] = litellm.get_secret(v)
+                team_config[k] = get_secret(v)
         return team_config

     def _init_cache(
@ -1513,6 +1513,9 @@ class ProxyConfig:
             config = get_file_contents_from_s3(
                 bucket_name=bucket_name, object_key=object_key
             )
+
+            if config is None:
+                raise Exception("Unable to load config from given source.")
         else:
             # default to file
             config = await self.get_config(config_file_path=config_file_path)
@ -1528,9 +1531,7 @@ class ProxyConfig:
         environment_variables = config.get("environment_variables", None)
         if environment_variables:
             for key, value in environment_variables.items():
-                os.environ[key] = str(
-                    litellm.get_secret(secret_name=key, default_value=value)
-                )
+                os.environ[key] = str(get_secret(secret_name=key, default_value=value))

             # check if litellm_license in general_settings
             if "LITELLM_LICENSE" in environment_variables:
@ -1566,8 +1567,8 @@ class ProxyConfig:
                 if (
                     cache_type == "redis" or cache_type == "redis-semantic"
                 ) and len(cache_params.keys()) == 0:
-                    cache_host = litellm.get_secret("REDIS_HOST", None)
-                    cache_port = litellm.get_secret("REDIS_PORT", None)
+                    cache_host = get_secret("REDIS_HOST", None)
+                    cache_port = get_secret("REDIS_PORT", None)
                     cache_password = None
                     cache_params.update(
                         {
@ -1577,8 +1578,8 @@ class ProxyConfig:
                         }
                     )

-                    if litellm.get_secret("REDIS_PASSWORD", None) is not None:
-                        cache_password = litellm.get_secret("REDIS_PASSWORD", None)
+                    if get_secret("REDIS_PASSWORD", None) is not None:
+                        cache_password = get_secret("REDIS_PASSWORD", None)
                         cache_params.update(
                             {
                                 "password": cache_password,
@ -1617,7 +1618,7 @@ class ProxyConfig:
                 # users can pass os.environ/ variables on the proxy - we should read them from the env
                 for key, value in cache_params.items():
                     if type(value) is str and value.startswith("os.environ/"):
-                        cache_params[key] = litellm.get_secret(value)
+                        cache_params[key] = get_secret(value)

                 ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
                 self._init_cache(cache_params=cache_params)
@ -1738,7 +1739,7 @@ class ProxyConfig:
                 if value is not None and isinstance(value, dict):
                     for _k, _v in value.items():
                         if isinstance(_v, str) and _v.startswith("os.environ/"):
-                            value[_k] = litellm.get_secret(_v)
+                            value[_k] = get_secret(_v)
                     litellm.upperbound_key_generate_params = (
                         LiteLLM_UpperboundKeyGenerateParams(**value)
                     )
@ -1812,15 +1813,15 @@ class ProxyConfig:
         database_url = general_settings.get("database_url", None)
         if database_url and database_url.startswith("os.environ/"):
             verbose_proxy_logger.debug("GOING INTO LITELLM.GET_SECRET!")
-            database_url = litellm.get_secret(database_url)
+            database_url = get_secret(database_url)
             verbose_proxy_logger.debug("RETRIEVED DB URL: %s", database_url)
         ### MASTER KEY ###
         master_key = general_settings.get(
-            "master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
+            "master_key", get_secret("LITELLM_MASTER_KEY", None)
         )

         if master_key and master_key.startswith("os.environ/"):
-            master_key = litellm.get_secret(master_key)
+            master_key = get_secret(master_key)  # type: ignore
             if not isinstance(master_key, str):
                 raise Exception(
                     "Master key must be a string. Current type - {}".format(
@ -1861,33 +1862,6 @@ class ProxyConfig:
             await initialize_pass_through_endpoints(
                 pass_through_endpoints=general_settings["pass_through_endpoints"]
             )
-        ## dynamodb
-        database_type = general_settings.get("database_type", None)
-        if database_type is not None and (
-            database_type == "dynamo_db" or database_type == "dynamodb"
-        ):
-            database_args = general_settings.get("database_args", None)
-            ### LOAD FROM os.environ/ ###
-            for k, v in database_args.items():
-                if isinstance(v, str) and v.startswith("os.environ/"):
-                    database_args[k] = litellm.get_secret(v)
-                if isinstance(k, str) and k == "aws_web_identity_token":
-                    value = database_args[k]
-                    verbose_proxy_logger.debug(
-                        f"Loading AWS Web Identity Token from file: {value}"
-                    )
-                    if os.path.exists(value):
-                        with open(value, "r") as file:
-                            token_content = file.read()
-                            database_args[k] = token_content
-                    else:
-                        verbose_proxy_logger.info(
-                            f"DynamoDB Loading - {value} is not a valid file path"
-                        )
-            verbose_proxy_logger.debug("database_args: %s", database_args)
-            custom_db_client = DBClient(
-                custom_db_args=database_args, custom_db_type=database_type
-            )
         ## ADMIN UI ACCESS ##
         ui_access_mode = general_settings.get(
             "ui_access_mode", "all"
@ -1951,7 +1925,7 @@ class ProxyConfig:
                 ### LOAD FROM os.environ/ ###
                 for k, v in model["litellm_params"].items():
                     if isinstance(v, str) and v.startswith("os.environ/"):
-                        model["litellm_params"][k] = litellm.get_secret(v)
+                        model["litellm_params"][k] = get_secret(v)
                 print(f"\033[32m {model.get('model_name', '')}\033[0m")  # noqa
                 litellm_model_name = model["litellm_params"]["model"]
                 litellm_model_api_base = model["litellm_params"].get("api_base", None)
@ -2005,7 +1979,10 @@ class ProxyConfig:
             )  # type:ignore

         # Guardrail settings
-        guardrails_v2 = config.get("guardrails", None)
+        guardrails_v2: Optional[dict] = None
+
+        if config is not None:
+            guardrails_v2 = config.get("guardrails", None)
         if guardrails_v2:
             init_guardrails_v2(
                 all_guardrails=guardrails_v2, config_file_path=config_file_path
@ -2074,7 +2051,7 @@ class ProxyConfig:
             ### LOAD FROM os.environ/ ###
             for k, v in model["litellm_params"].items():
                 if isinstance(v, str) and v.startswith("os.environ/"):
-                    model["litellm_params"][k] = litellm.get_secret(v)
+                    model["litellm_params"][k] = get_secret(v)

             ## check if they have model-id's ##
             model_id = model.get("model_info", {}).get("id", None)
@ -2234,7 +2211,8 @@ class ProxyConfig:
         for k, v in environment_variables.items():
             try:
                 decrypted_value = decrypt_value_helper(value=v)
-                os.environ[k] = decrypted_value
+                if decrypted_value is not None:
+                    os.environ[k] = decrypted_value
             except Exception as e:
                 verbose_proxy_logger.error(
                     "Error setting env variable: %s - %s", k, str(e)
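The litellm.get_secret to get_secret swap above keeps the same "os.environ/" convention for config values. A toy resolver illustrating just that convention (the real get_secret also supports secret managers, so this is only a stand-in):

    import os

    def resolve(value, default_value=None):
        # Stand-in for get_secret's os.environ/ handling only.
        if isinstance(value, str) and value.startswith("os.environ/"):
            return os.environ.get(value.replace("os.environ/", ""), default_value)
        return value if value is not None else default_value

    os.environ["REDIS_HOST"] = "localhost"
    print(resolve("os.environ/REDIS_HOST"))  # localhost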
@ -2536,7 +2514,7 @@ async def async_assistants_data_generator(
         )

         # chunk = chunk.model_dump_json(exclude_none=True)
-        async for c in chunk:
+        async for c in chunk:  # type: ignore
             c = c.model_dump_json(exclude_none=True)
             try:
                 yield f"data: {c}\n\n"
@ -2745,17 +2723,22 @@ async def startup_event():

     ### LOAD MASTER KEY ###
     # check if master key set in environment - load from there
-    master_key = litellm.get_secret("LITELLM_MASTER_KEY", None)
+    master_key = get_secret("LITELLM_MASTER_KEY", None)  # type: ignore
     # check if DATABASE_URL in environment - load from there
     if prisma_client is None:
-        prisma_setup(database_url=litellm.get_secret("DATABASE_URL", None))
+        _db_url: Optional[str] = get_secret("DATABASE_URL", None)  # type: ignore
+        prisma_setup(database_url=_db_url)

     ### LOAD CONFIG ###
-    worker_config = litellm.get_secret("WORKER_CONFIG")
+    worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG")  # type: ignore
     verbose_proxy_logger.debug("worker_config: %s", worker_config)
     # check if it's a valid file path
-    if os.path.isfile(worker_config):
-        if proxy_config.is_yaml(config_file_path=worker_config):
+    if worker_config is not None:
+        if (
+            isinstance(worker_config, str)
+            and os.path.isfile(worker_config)
+            and proxy_config.is_yaml(config_file_path=worker_config)
+        ):
             (
                 llm_router,
                 llm_model_list,
@ -2763,21 +2746,23 @@ async def startup_event():
             ) = await proxy_config.load_config(
                 router=llm_router, config_file_path=worker_config
             )
-        else:
-            await initialize(**worker_config)
-    elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
-        (
-            llm_router,
-            llm_model_list,
-            general_settings,
-        ) = await proxy_config.load_config(
-            router=llm_router, config_file_path=worker_config
-        )
-
-    else:
-        # if not, assume it's a json string
-        worker_config = json.loads(os.getenv("WORKER_CONFIG"))
-        await initialize(**worker_config)
+        elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance(
+            worker_config, str
+        ):
+            (
+                llm_router,
+                llm_model_list,
+                general_settings,
+            ) = await proxy_config.load_config(
+                router=llm_router, config_file_path=worker_config
+            )
+        elif isinstance(worker_config, dict):
+            await initialize(**worker_config)
+        else:
+            # if not, assume it's a json string
+            worker_config = json.loads(worker_config)
+            if isinstance(worker_config, dict):
+                await initialize(**worker_config)

     ## CHECK PREMIUM USER
     verbose_proxy_logger.debug(
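A small sketch of the WORKER_CONFIG fallback path above: when the value is not a file on disk, it is treated as a JSON object string (the keys here are placeholders, not the proxy's real initialize arguments):

    import json
    import os

    os.environ["WORKER_CONFIG"] = json.dumps({"config": "/app/config.yaml"})

    worker_config = os.environ["WORKER_CONFIG"]
    if not (isinstance(worker_config, str) and os.path.isfile(worker_config)):
        worker_config = json.loads(worker_config)

    assert isinstance(worker_config, dict)
    print(worker_config)  # {'config': '/app/config.yaml'}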
@ -2825,7 +2810,7 @@ async def startup_event():
     if general_settings.get("litellm_jwtauth", None) is not None:
         for k, v in general_settings["litellm_jwtauth"].items():
             if isinstance(v, str) and v.startswith("os.environ/"):
-                general_settings["litellm_jwtauth"][k] = litellm.get_secret(v)
+                general_settings["litellm_jwtauth"][k] = get_secret(v)
         litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"])
     else:
         litellm_jwtauth = LiteLLM_JWTAuth()
@ -2948,8 +2933,7 @@ async def startup_event():

         ### ADD NEW MODELS ###
         store_model_in_db = (
-            litellm.get_secret("STORE_MODEL_IN_DB", store_model_in_db)
-            or store_model_in_db
+            get_secret("STORE_MODEL_IN_DB", store_model_in_db) or store_model_in_db
         )  # type: ignore
         if store_model_in_db == True:
             scheduler.add_job(
@ -3498,7 +3482,7 @@ async def completion(
         )
         ### CALL HOOKS ### - modify outgoing data
         response = await proxy_logging_obj.post_call_success_hook(
-            data=data, user_api_key_dict=user_api_key_dict, response=response
+            data=data, user_api_key_dict=user_api_key_dict, response=response  # type: ignore
         )

         fastapi_response.headers.update(
@ -4000,7 +3984,7 @@ async def audio_speech(
             request_data=data,
         )
         return StreamingResponse(
-            generate(response), media_type="audio/mpeg", headers=custom_headers
+            generate(response), media_type="audio/mpeg", headers=custom_headers  # type: ignore
         )

     except Exception as e:
@ -4288,6 +4272,7 @@ async def create_assistant(
     API Reference docs - https://platform.openai.com/docs/api-reference/assistants/createAssistant
     """
     global proxy_logging_obj
+    data = {}  # ensure data always dict
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
         body = await request.body()
@ -7642,6 +7627,7 @@ async def model_group_info(
         )

     model_groups: List[ModelGroupInfo] = []
+
     for model in all_models_str:

         _model_group_info = llm_router.get_model_group_info(model_group=model)
@ -8051,7 +8037,8 @@ async def google_login(request: Request):
         with microsoft_sso:
             return await microsoft_sso.get_login_redirect()
     elif generic_client_id is not None:
-        from fastapi_sso.sso.generic import DiscoveryDocument, create_provider
+        from fastapi_sso.sso.base import DiscoveryDocument
+        from fastapi_sso.sso.generic import create_provider

         generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
         generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
@ -8616,6 +8603,8 @@ async def auth_callback(request: Request):
         redirect_url += "sso/callback"
     else:
         redirect_url += "/sso/callback"

+    result = None
+
     if google_client_id is not None:
         from fastapi_sso.sso.google import GoogleSSO

@ -8662,7 +8651,8 @@ async def auth_callback(request: Request):
         result = await microsoft_sso.verify_and_process(request)
     elif generic_client_id is not None:
         # make generic sso provider
-        from fastapi_sso.sso.generic import DiscoveryDocument, OpenID, create_provider
+        from fastapi_sso.sso.base import DiscoveryDocument, OpenID
+        from fastapi_sso.sso.generic import create_provider

         generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
         generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
@ -8766,8 +8756,8 @@ async def auth_callback(request: Request):
         verbose_proxy_logger.debug("generic result: %s", result)

     # User is Authe'd in - generate key for the UI to access Proxy
-    user_email = getattr(result, "email", None)
-    user_id = getattr(result, "id", None)
+    user_email: Optional[str] = getattr(result, "email", None)
+    user_id: Optional[str] = getattr(result, "id", None) if result is not None else None

     if user_email is not None and os.getenv("ALLOWED_EMAIL_DOMAINS") is not None:
         email_domain = user_email.split("@")[1]
@ -8783,12 +8773,12 @@ async def auth_callback(request: Request):
         )

     # generic client id
-    if generic_client_id is not None:
+    if generic_client_id is not None and result is not None:
         user_id = getattr(result, "id", None)
         user_email = getattr(result, "email", None)
-        user_role = getattr(result, generic_user_role_attribute_name, None)
+        user_role = getattr(result, generic_user_role_attribute_name, None)  # type: ignore

-    if user_id is None:
+    if user_id is None and result is not None:
         _first_name = getattr(result, "first_name", "") or ""
         _last_name = getattr(result, "last_name", "") or ""
         user_id = _first_name + _last_name
@ -8811,54 +8801,45 @@ async def auth_callback(request: Request):
         "spend": 0,
         "team_id": "litellm-dashboard",
     }
-    user_defined_values: SSOUserDefinedValues = {
-        "models": user_id_models,
-        "user_id": user_id,
-        "user_email": user_email,
-        "max_budget": max_internal_user_budget,
-        "user_role": None,
-        "budget_duration": internal_user_budget_duration,
-    }
+    user_defined_values: Optional[SSOUserDefinedValues] = None
+    if user_id is not None:
+        user_defined_values = SSOUserDefinedValues(
+            models=user_id_models,
+            user_id=user_id,
+            user_email=user_email,
+            max_budget=max_internal_user_budget,
+            user_role=None,
+            budget_duration=internal_user_budget_duration,
+        )

     _user_id_from_sso = user_id
+    user_role = None
     try:
-        user_role = None
         if prisma_client is not None:
             user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
             verbose_proxy_logger.debug(
                 f"user_info: {user_info}; litellm.default_user_params: {litellm.default_user_params}"
             )
-            if user_info is not None:
-                user_defined_values = {
-                    "models": getattr(user_info, "models", user_id_models),
-                    "user_id": getattr(user_info, "user_id", user_id),
-                    "user_email": getattr(user_info, "user_id", user_email),
-                    "user_role": getattr(user_info, "user_role", None),
-                    "max_budget": getattr(
-                        user_info, "max_budget", max_internal_user_budget
-                    ),
-                    "budget_duration": getattr(
-                        user_info, "budget_duration", internal_user_budget_duration
-                    ),
-                }
-                user_role = getattr(user_info, "user_role", None)
-
-            ## check if user-email in db ##
-            user_info = await prisma_client.db.litellm_usertable.find_first(
-                where={"user_email": user_email}
-            )
-            if user_info is not None:
-                user_defined_values = {
-                    "models": getattr(user_info, "models", user_id_models),
-                    "user_id": user_id,
-                    "user_email": getattr(user_info, "user_id", user_email),
-                    "user_role": getattr(user_info, "user_role", None),
-                    "max_budget": getattr(
-                        user_info, "max_budget", max_internal_user_budget
-                    ),
-                    "budget_duration": getattr(
-                        user_info, "budget_duration", internal_user_budget_duration
-                    ),
-                }
-                user_role = getattr(user_info, "user_role", None)
+            if user_info is None:
+                ## check if user-email in db ##
+                user_info = await prisma_client.db.litellm_usertable.find_first(
+                    where={"user_email": user_email}
+                )
+
+            if user_info is not None and user_id is not None:
+                user_defined_values = SSOUserDefinedValues(
+                    models=getattr(user_info, "models", user_id_models),
+                    user_id=user_id,
+                    user_email=getattr(user_info, "user_email", user_email),
+                    user_role=getattr(user_info, "user_role", None),
+                    max_budget=getattr(
+                        user_info, "max_budget", max_internal_user_budget
+                    ),
+                    budget_duration=getattr(
+                        user_info, "budget_duration", internal_user_budget_duration
+                    ),
+                )
+                user_role = getattr(user_info, "user_role", None)

             # update id
@ -8886,6 +8867,11 @@ async def auth_callback(request: Request):
     except Exception as e:
         pass

+    if user_defined_values is None:
+        raise Exception(
+            "Unable to map user identity to known values. 'user_defined_values' is None. File an issue - https://github.com/BerriAI/litellm/issues"
+        )
+
     is_internal_user = False
     if (
         user_defined_values["user_role"] is not None
@ -8960,7 +8946,8 @@ async def auth_callback(request: Request):
             master_key,
             algorithm="HS256",
         )
-        litellm_dashboard_ui += "?userID=" + user_id
+        if user_id is not None and isinstance(user_id, str):
+            litellm_dashboard_ui += "?userID=" + user_id
         redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
         redirect_response.set_cookie(key="token", value=jwt_token)
         return redirect_response
@ -9023,6 +9010,7 @@ async def new_invitation(
                 "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
             }  # type: ignore
         )
+        return response
     except Exception as e:
         if "Foreign key constraint failed on the field" in str(e):
             raise HTTPException(
@ -9031,7 +9019,7 @@ async def new_invitation(
                     "error": "User id does not exist in 'LiteLLM_UserTable'. Fix this by creating user via `/user/new`."
                 },
             )
-        return response
+        raise HTTPException(status_code=500, detail={"error": str(e)})


 @router.get(
@ -9951,44 +9939,46 @@ async def get_routes():
|
||||||
"""
|
"""
|
||||||
routes = []
|
routes = []
|
||||||
for route in app.routes:
|
for route in app.routes:
|
||||||
route_info = {
|
endpoint_route = getattr(route, "endpoint", None)
|
||||||
"path": getattr(route, "path", None),
|
if endpoint_route is not None:
|
||||||
"methods": getattr(route, "methods", None),
|
route_info = {
|
||||||
"name": getattr(route, "name", None),
|
"path": getattr(route, "path", None),
|
||||||
"endpoint": (
|
"methods": getattr(route, "methods", None),
|
||||||
getattr(route, "endpoint", None).__name__
|
"name": getattr(route, "name", None),
|
||||||
if getattr(route, "endpoint", None)
|
"endpoint": (
|
||||||
else None
|
endpoint_route.__name__
|
||||||
),
|
if getattr(route, "endpoint", None)
|
||||||
}
|
else None
|
||||||
routes.append(route_info)
|
),
|
||||||
|
}
|
||||||
|
routes.append(route_info)
|
||||||
|
|
||||||
return {"routes": routes}
|
return {"routes": routes}
|
||||||
|
|
||||||
|
|
||||||
#### TEST ENDPOINTS ####
|
#### TEST ENDPOINTS ####
|
||||||
@router.get(
|
# @router.get(
|
||||||
"/token/generate",
|
# "/token/generate",
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
# dependencies=[Depends(user_api_key_auth)],
|
||||||
include_in_schema=False,
|
# include_in_schema=False,
|
||||||
)
|
# )
|
||||||
async def token_generate():
|
# async def token_generate():
|
||||||
"""
|
# """
|
||||||
Test endpoint. Admin-only access. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
|
# Test endpoint. Admin-only access. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
|
||||||
"""
|
# """
|
||||||
# Initialize AuthJWTSSO with your OpenID Provider configuration
|
# # Initialize AuthJWTSSO with your OpenID Provider configuration
|
||||||
from fastapi_sso import AuthJWTSSO
|
# from fastapi_sso import AuthJWTSSO
|
||||||
|
|
||||||
auth_jwt_sso = AuthJWTSSO(
|
# auth_jwt_sso = AuthJWTSSO(
|
||||||
issuer=os.getenv("OPENID_BASE_URL"),
|
# issuer=os.getenv("OPENID_BASE_URL"),
|
||||||
client_id=os.getenv("OPENID_CLIENT_ID"),
|
# client_id=os.getenv("OPENID_CLIENT_ID"),
|
||||||
client_secret=os.getenv("OPENID_CLIENT_SECRET"),
|
# client_secret=os.getenv("OPENID_CLIENT_SECRET"),
|
||||||
scopes=["litellm_proxy_admin"],
|
# scopes=["litellm_proxy_admin"],
|
||||||
)
|
# )
|
||||||
|
|
||||||
token = auth_jwt_sso.create_access_token()
|
# token = auth_jwt_sso.create_access_token()
|
||||||
|
|
||||||
return {"token": token}
|
# return {"token": token}
|
||||||
|
|
||||||
|
|
||||||
@router.on_event("shutdown")
|
@router.on_event("shutdown")
|
||||||
|
@ -10013,7 +10003,8 @@ async def shutdown_event():
|
||||||
# flush langfuse logs on shutdow
|
# flush langfuse logs on shutdow
|
||||||
from litellm.utils import langFuseLogger
|
from litellm.utils import langFuseLogger
|
||||||
|
|
||||||
langFuseLogger.Langfuse.flush()
|
if langFuseLogger is not None:
|
||||||
|
langFuseLogger.Langfuse.flush()
|
||||||
except:
|
except:
|
||||||
# [DO NOT BLOCK shutdown events for this]
|
# [DO NOT BLOCK shutdown events for this]
|
||||||
pass
|
pass
|
||||||
|
|
|
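The auth_callback changes above make `user_defined_values` optional, only build it once a user id is actually known, and raise if no identity could be mapped. A minimal sketch of that guard pattern, using a simplified stand-in for `SSOUserDefinedValues` (the fields here are illustrative, not the proxy's full type):

```python
# Minimal sketch of the guard pattern used in auth_callback above.
# `SSOUserDefinedValues` is a simplified stand-in, not the proxy's full type.
from typing import Optional, TypedDict


class SSOUserDefinedValues(TypedDict):
    user_id: str
    user_email: Optional[str]


def map_sso_result(user_id: Optional[str], user_email: Optional[str]) -> SSOUserDefinedValues:
    user_defined_values: Optional[SSOUserDefinedValues] = None
    if user_id is not None:
        user_defined_values = SSOUserDefinedValues(user_id=user_id, user_email=user_email)

    if user_defined_values is None:
        # mirrors the new error raised when no identity could be mapped
        raise Exception("Unable to map user identity to known values.")
    return user_defined_values
```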
@@ -92,6 +92,7 @@ from litellm.types.router import (
 RetryPolicy,
 RouterErrors,
 RouterGeneralSettings,
+RouterModelGroupAliasItem,
 RouterRateLimitError,
 RouterRateLimitErrorBasic,
 updateDeployment,

@@ -105,6 +106,7 @@ from litellm.utils import (
 calculate_max_parallel_requests,
 create_proxy_transport_and_mounts,
 get_llm_provider,
+get_secret,
 get_utc_datetime,
 )

@@ -156,7 +158,9 @@ class Router:
 fallbacks: List = [],
 context_window_fallbacks: List = [],
 content_policy_fallbacks: List = [],
-model_group_alias: Optional[dict] = {},
+model_group_alias: Optional[
+Dict[str, Union[str, RouterModelGroupAliasItem]]
+] = {},
 enable_pre_call_checks: bool = False,
 enable_tag_filtering: bool = False,
 retry_after: int = 0,  # min time to wait before retrying a failed request

@@ -331,7 +335,8 @@ class Router:
 self.set_model_list(model_list)
 self.healthy_deployments: List = self.model_list  # type: ignore
 for m in model_list:
-self.deployment_latency_map[m["litellm_params"]["model"]] = 0
+if "model" in m["litellm_params"]:
+self.deployment_latency_map[m["litellm_params"]["model"]] = 0
 else:
 self.model_list: List = (
 []

@@ -398,7 +403,7 @@ class Router:
 self.previous_models: List = (
 []
 )  # list to store failed calls (passed in as metadata to next call)
-self.model_group_alias: dict = (
+self.model_group_alias: Dict[str, Union[str, RouterModelGroupAliasItem]] = (
 model_group_alias or {}
 )  # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group

@@ -1179,6 +1184,7 @@ class Router:
 raise e

 def _image_generation(self, prompt: str, model: str, **kwargs):
+model_name = ""
 try:
 verbose_router_logger.debug(
 f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"

@@ -1269,6 +1275,7 @@ class Router:
 raise e

 async def _aimage_generation(self, prompt: str, model: str, **kwargs):
+model_name = ""
 try:
 verbose_router_logger.debug(
 f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"

@@ -1401,6 +1408,7 @@ class Router:
 raise e

 async def _atranscription(self, file: FileTypes, model: str, **kwargs):
+model_name = model
 try:
 verbose_router_logger.debug(
 f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"

@@ -1781,6 +1789,7 @@ class Router:
 is_async: Optional[bool] = False,
 **kwargs,
 ):
+messages = [{"role": "user", "content": prompt}]
 try:
 kwargs["model"] = model
 kwargs["prompt"] = prompt

@@ -1789,7 +1798,6 @@ class Router:
 timeout = kwargs.get("request_timeout", self.timeout)
 kwargs.setdefault("metadata", {}).update({"model_group": model})

-messages = [{"role": "user", "content": prompt}]
 # pick the one that is available (lowest TPM/RPM)
 deployment = self.get_available_deployment(
 model=model,

@@ -2534,13 +2542,13 @@ class Router:
 try:
 # Update kwargs with the current model name or any other model-specific adjustments
 ## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ##
-_, custom_llm_provider, _, _ = get_llm_provider(
+_, custom_llm_provider, _, _ = get_llm_provider(  # type: ignore
 model=model_name["litellm_params"]["model"]
 )
 new_kwargs = copy.deepcopy(kwargs)
 new_kwargs.pop("custom_llm_provider", None)
 return await litellm.aretrieve_batch(
-custom_llm_provider=custom_llm_provider, **new_kwargs
+custom_llm_provider=custom_llm_provider, **new_kwargs  # type: ignore
 )
 except Exception as e:
 receieved_exceptions.append(e)

@@ -2616,13 +2624,13 @@ class Router:
 for result in results:
 if result is not None:
 ## check batch id
-if final_results["first_id"] is None:
-final_results["first_id"] = result.first_id
-final_results["last_id"] = result.last_id
+if final_results["first_id"] is None and hasattr(result, "first_id"):
+final_results["first_id"] = getattr(result, "first_id")
+final_results["last_id"] = getattr(result, "last_id")
 final_results["data"].extend(result.data)  # type: ignore

 ## check 'has_more'
-if result.has_more is True:
+if getattr(result, "has_more", False) is True:
 final_results["has_more"] = True

 return final_results

@@ -2874,8 +2882,12 @@ class Router:
 verbose_router_logger.debug(f"Traceback{traceback.format_exc()}")
 original_exception = e
 fallback_model_group = None
-original_model_group = kwargs.get("model")
+original_model_group: Optional[str] = kwargs.get("model")  # type: ignore
 fallback_failure_exception_str = ""

+if original_model_group is None:
+raise e

 try:
 verbose_router_logger.debug("Trying to fallback b/w models")
 if isinstance(e, litellm.ContextWindowExceededError):

@@ -2972,7 +2984,7 @@ class Router:
 f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
 )
 if hasattr(original_exception, "message"):
-original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
+original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"  # type: ignore
 raise original_exception

 response = await run_async_fallback(

@@ -2996,12 +3008,12 @@ class Router:

 if hasattr(original_exception, "message"):
 # add the available fallbacks to the exception
-original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format(
+original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format(  # type: ignore
 model_group,
 fallback_model_group,
 )
 if len(fallback_failure_exception_str) > 0:
-original_exception.message += (
+original_exception.message += (  # type: ignore
 "\nError doing the fallback: {}".format(
 fallback_failure_exception_str
 )

@@ -3117,9 +3129,15 @@ class Router:
 ## LOGGING
 kwargs = self.log_retry(kwargs=kwargs, e=e)
 remaining_retries = num_retries - current_attempt
-_healthy_deployments, _ = await self._async_get_healthy_deployments(
-model=kwargs.get("model"),
-)
+_model: Optional[str] = kwargs.get("model")  # type: ignore
+if _model is not None:
+_healthy_deployments, _ = (
+await self._async_get_healthy_deployments(
+model=_model,
+)
+)
+else:
+_healthy_deployments = []
 _timeout = self._time_to_sleep_before_retry(
 e=original_exception,
 remaining_retries=remaining_retries,

@@ -3129,8 +3147,8 @@ class Router:
 await asyncio.sleep(_timeout)

 if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
-original_exception.max_retries = num_retries
-original_exception.num_retries = current_attempt
+setattr(original_exception, "max_retries", num_retries)
+setattr(original_exception, "num_retries", current_attempt)

 raise original_exception

@@ -3225,8 +3243,12 @@ class Router:
 return response
 except Exception as e:
 original_exception = e
-original_model_group = kwargs.get("model")
+original_model_group: Optional[str] = kwargs.get("model")
 verbose_router_logger.debug(f"An exception occurs {original_exception}")

+if original_model_group is None:
+raise e

 try:
 verbose_router_logger.debug(
 f"Trying to fallback b/w models. Initial model group: {model_group}"

@@ -3336,10 +3358,10 @@ class Router:
 return 0

 response_headers: Optional[httpx.Headers] = None
-if hasattr(e, "response") and hasattr(e.response, "headers"):
-response_headers = e.response.headers
+if hasattr(e, "response") and hasattr(e.response, "headers"):  # type: ignore
+response_headers = e.response.headers  # type: ignore
 elif hasattr(e, "litellm_response_headers"):
-response_headers = e.litellm_response_headers
+response_headers = e.litellm_response_headers  # type: ignore

 if response_headers is not None:
 timeout = litellm._calculate_retry_after(

@@ -3398,9 +3420,13 @@ class Router:
 except Exception as e:
 current_attempt = None
 original_exception = e
+_model: Optional[str] = kwargs.get("model")  # type: ignore
+
+if _model is None:
+raise e  # re-raise error, if model can't be determined for loadbalancing
 ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
 _healthy_deployments, _all_deployments = self._get_healthy_deployments(
-model=kwargs.get("model"),
+model=_model,
 )

 # raises an exception if this error should not be retries

@@ -3438,8 +3464,12 @@ class Router:
 except Exception as e:
 ## LOGGING
 kwargs = self.log_retry(kwargs=kwargs, e=e)
+_model: Optional[str] = kwargs.get("model")  # type: ignore
+
+if _model is None:
+raise e  # re-raise error, if model can't be determined for loadbalancing
 _healthy_deployments, _ = self._get_healthy_deployments(
-model=kwargs.get("model"),
+model=_model,
 )
 remaining_retries = num_retries - current_attempt
 _timeout = self._time_to_sleep_before_retry(

@@ -4055,7 +4085,7 @@ class Router:
 if isinstance(_litellm_params, dict):
 for k, v in _litellm_params.items():
 if isinstance(v, str) and v.startswith("os.environ/"):
-_litellm_params[k] = litellm.get_secret(v)
+_litellm_params[k] = get_secret(v)

 _model_info: dict = model.pop("model_info", {})

@@ -4392,7 +4422,6 @@ class Router:
 - ModelGroupInfo if able to construct a model group
 - None if error constructing model group info
 """
-
 model_group_info: Optional[ModelGroupInfo] = None

 total_tpm: Optional[int] = None

@@ -4557,12 +4586,23 @@ class Router:

 Returns:
 - ModelGroupInfo if able to construct a model group
-- None if error constructing model group info
+- None if error constructing model group info or hidden model group
 """
 ## Check if model group alias
 if model_group in self.model_group_alias:
+item = self.model_group_alias[model_group]
+if isinstance(item, str):
+_router_model_group = item
+elif isinstance(item, dict):
+if item["hidden"] is True:
+return None
+else:
+_router_model_group = item["model"]
+else:
+return None
+
 return self._set_model_group_info(
-model_group=self.model_group_alias[model_group],
+model_group=_router_model_group,
 user_facing_model_group_name=model_group,
 )

@@ -4666,7 +4706,14 @@ class Router:

 Includes model_group_alias models too.
 """
-return self.model_names + list(self.model_group_alias.keys())
+model_list = self.get_model_list()
+if model_list is None:
+return []
+
+model_names = []
+for m in model_list:
+model_names.append(m["model_name"])
+return model_names

 def get_model_list(
 self, model_name: Optional[str] = None

@@ -4678,9 +4725,21 @@ class Router:
 returned_models: List[DeploymentTypedDict] = []

 for model_alias, model_value in self.model_group_alias.items():
+
+if isinstance(model_value, str):
+_router_model_name: str = model_value
+elif isinstance(model_value, dict):
+_model_value = RouterModelGroupAliasItem(**model_value)  # type: ignore
+if _model_value["hidden"] is True:
+continue
+else:
+_router_model_name = _model_value["model"]
+else:
+continue
+
 returned_models.extend(
 self._get_all_deployments(
-model_name=model_value, model_alias=model_alias
+model_name=_router_model_name, model_alias=model_alias
 )
 )

@@ -5078,10 +5137,11 @@ class Router:
 )

 if model in self.model_group_alias:
-verbose_router_logger.debug(
-f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}"
-)
-model = self.model_group_alias[model]
+_item = self.model_group_alias[model]
+if isinstance(_item, str):
+model = _item
+else:
+model = _item["model"]

 if model not in self.model_names:
 # check if provider/ specific wildcard routing

@@ -5124,7 +5184,9 @@ class Router:
 m for m in self.model_list if m["litellm_params"]["model"] == model
 ]

-litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")
+verbose_router_logger.debug(
+f"initial list of deployments: {healthy_deployments}"
+)

 if len(healthy_deployments) == 0:
 raise ValueError(

@@ -5208,7 +5270,7 @@ class Router:
 )

 # check if user wants to do tag based routing
-healthy_deployments = await get_deployments_for_tag(
+healthy_deployments = await get_deployments_for_tag(  # type: ignore
 llm_router_instance=self,
 request_kwargs=request_kwargs,
 healthy_deployments=healthy_deployments,

@@ -5241,7 +5303,7 @@ class Router:
 input=input,
 )
 )
-if (
+elif (
 self.routing_strategy == "cost-based-routing"
 and self.lowestcost_logger is not None
 ):

@@ -5326,6 +5388,8 @@ class Router:
 ############## No RPM/TPM passed, we do a random pick #################
 item = random.choice(healthy_deployments)
 return item or item[0]
+else:
+deployment = None
 if deployment is None:
 verbose_router_logger.info(
 f"get_available_deployment for model: {model}, No deployment available"

@@ -5515,6 +5579,9 @@ class Router:
 messages=messages,
 input=input,
 )
+else:
+deployment = None
+
 if deployment is None:
 verbose_router_logger.info(
 f"get_available_deployment for model: {model}, No deployment available"

@@ -5690,6 +5757,9 @@ class Router:
 def _initialize_alerting(self):
 from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting

+if self.alerting_config is None:
+return
+
 router_alerting_config: AlertingConfig = self.alerting_config

 _slack_alerting_logger = SlackAlerting(

@@ -5700,7 +5770,7 @@ class Router:

 self.slack_alerting_logger = _slack_alerting_logger

-litellm.callbacks.append(_slack_alerting_logger)
+litellm.callbacks.append(_slack_alerting_logger)  # type: ignore
 litellm.success_callback.append(
 _slack_alerting_logger.response_taking_too_long_callback
 )
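With the router changes above, a `model_group_alias` entry can be either a plain string or a `RouterModelGroupAliasItem` dict with a `hidden` flag; hidden aliases still route requests but are dropped from `get_model_list()` / `get_model_names()`. A short usage sketch (model names are placeholders and provider credentials, e.g. `OPENAI_API_KEY`, are assumed to be set):

```python
# Usage sketch for the hidden model_group_alias support added above.
# Assumes OPENAI_API_KEY (or other provider credentials) is set in the environment.
from litellm import Router

router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
    ],
    model_group_alias={
        # dict form: routed like an alias, but hidden from model listings
        "gpt-4.5-turbo": {"model": "gpt-3.5-turbo", "hidden": True},
        # plain string aliases keep the old behavior and stay visible
        "my-alias": "gpt-3.5-turbo",
    },
)

print(router.get_model_names())  # "my-alias" appears, the hidden alias does not
```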
@@ -21,8 +21,8 @@ else:

 async def get_deployments_for_tag(
 llm_router_instance: LitellmRouter,
+healthy_deployments: Union[List[Any], Dict[Any, Any]],
 request_kwargs: Optional[Dict[Any, Any]] = None,
-healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
 ):
 if llm_router_instance.enable_tag_filtering is not True:
 return healthy_deployments
@@ -2828,3 +2828,44 @@ def test_gemini_function_call_parameter_in_messages_2():
 ]
 },
 ]
+
+
+@pytest.mark.parametrize(
+"base_model, metadata",
+[
+(None, {"model_info": {"base_model": "vertex_ai/gemini-1.5-pro"}}),
+("vertex_ai/gemini-1.5-pro", None),
+],
+)
+def test_gemini_finetuned_endpoint(base_model, metadata):
+litellm.set_verbose = True
+load_vertex_ai_credentials()
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+# Set up the messages
+messages = [
+{"role": "system", "content": """Use search for most queries."""},
+{"role": "user", "content": """search for weather in boston (use `search`)"""},
+]
+
+client = HTTPHandler(concurrent_limit=1)
+
+with patch.object(client, "post", new=MagicMock()) as mock_client:
+try:
+response = completion(
+model="vertex_ai/4965075652664360960",
+messages=messages,
+tool_choice="auto",
+client=client,
+metadata=metadata,
+base_model=base_model,
+)
+except Exception as e:
+print(e)
+
+print(mock_client.call_args.kwargs)
+
+mock_client.assert_called()
+assert mock_client.call_args.kwargs["url"].endswith(
+"endpoints/4965075652664360960:generateContent"
+)
@@ -311,6 +311,8 @@ async def test_completion_predibase():
 pass
 except litellm.ServiceUnavailableError as e:
 pass
+except litellm.InternalServerError:
+pass
 except Exception as e:
 pytest.fail(f"Error occurred: {e}")
@@ -133,3 +133,21 @@ async def test_fireworks_health_check():
 assert response == {}

 return response
+
+
+@pytest.mark.asyncio
+async def test_cohere_rerank_health_check():
+response = await litellm.ahealth_check(
+model_params={
+"model": "cohere/rerank-english-v3.0",
+"query": "Hey, how's it going",
+"documents": ["my sample text"],
+"api_key": os.getenv("COHERE_API_KEY"),
+},
+mode="rerank",
+prompt="Hey, how's it going",
+)
+
+assert "error" not in response
+
+print(response)
|
||||||
], # False
|
], # False
|
||||||
) #
|
) #
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.flaky(retries=3, delay=1)
|
||||||
async def test_image_generation_azure(sync_mode):
|
async def test_image_generation_azure(sync_mode):
|
||||||
try:
|
try:
|
||||||
if sync_mode:
|
if sync_mode:
|
||||||
|
|
|
@@ -533,7 +533,7 @@ def test_call_with_user_over_budget(prisma_client):
 # use generated key to auth in
 result = await user_api_key_auth(request=request, api_key=bearer_token)
 print("result from user auth with new key", result)
-pytest.fail(f"This should have failed!. They key crossed it's budget")
+pytest.fail("This should have failed!. They key crossed it's budget")

 asyncio.run(test())
 except Exception as e:

@@ -1755,7 +1755,7 @@ def test_call_with_key_over_model_budget(prisma_client):
 # use generated key to auth in
 result = await user_api_key_auth(request=request, api_key=bearer_token)
 print("result from user auth with new key", result)
-pytest.fail(f"This should have failed!. They key crossed it's budget")
+pytest.fail("This should have failed!. They key crossed it's budget")

 asyncio.run(test())
 except Exception as e:
@@ -589,3 +589,14 @@ def test_parse_additional_properties_json_schema(model, provider, expectedAddProp):
 elif provider == "openai":
 schema = optional_params["response_format"]["json_schema"]["schema"]
 assert ("additionalProperties" in schema) == expectedAddProp
+
+
+def test_o1_model_params():
+optional_params = get_optional_params(
+model="o1-preview-2024-09-12",
+custom_llm_provider="openai",
+seed=10,
+user="John",
+)
+assert optional_params["seed"] == 10
+assert optional_params["user"] == "John"
@@ -762,7 +762,7 @@ async def test_team_update_redis():
 ) as mock_client:
 await _cache_team_object(
 team_id="1234",
-team_table=LiteLLM_TeamTableCachedObj(),
+team_table=LiteLLM_TeamTableCachedObj(team_id="1234"),
 user_api_key_cache=DualCache(),
 proxy_logging_obj=proxy_logging_obj,
 )

@@ -776,7 +776,7 @@ async def test_get_team_redis(client_no_auth):
 Tests if get_team_object gets value from redis cache, if set
 """
 from litellm.caching import DualCache, RedisCache
-from litellm.proxy.auth.auth_checks import _cache_team_object, get_team_object
+from litellm.proxy.auth.auth_checks import get_team_object

 proxy_logging_obj: ProxyLogging = getattr(
 litellm.proxy.proxy_server, "proxy_logging_obj"

@@ -917,7 +917,9 @@ async def test_create_team_member_add(prisma_client, new_member_method):
 )
 litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = team_mock_client

-team_mock_client.update = AsyncMock(return_value=LiteLLM_TeamTableCachedObj())
+team_mock_client.update = AsyncMock(
+return_value=LiteLLM_TeamTableCachedObj(team_id="1234")
+)

 await team_member_add(
 data=team_member_add_request,

@@ -1095,7 +1097,9 @@ async def test_create_team_member_add_team_admin(
 )
 litellm.proxy.proxy_server.prisma_client.db.litellm_teamtable = team_mock_client

-team_mock_client.update = AsyncMock(return_value=LiteLLM_TeamTableCachedObj())
+team_mock_client.update = AsyncMock(
+return_value=LiteLLM_TeamTableCachedObj(team_id="1234")
+)

 try:
 await team_member_add(

@@ -1434,8 +1438,9 @@ async def test_gemini_pass_through_endpoint():
 print(resp.body)


+@pytest.mark.parametrize("hidden", [True, False])
 @pytest.mark.asyncio
-async def test_proxy_model_group_alias_checks(prisma_client):
+async def test_proxy_model_group_alias_checks(prisma_client, hidden):
 """
 Check if model group alias is returned on

@@ -1465,7 +1470,7 @@ async def test_proxy_model_group_alias_checks(prisma_client):
 model_alias = "gpt-4"
 router = litellm.Router(
 model_list=_model_list,
-model_group_alias={model_alias: "gpt-3.5-turbo"},
+model_group_alias={model_alias: {"model": "gpt-3.5-turbo", "hidden": hidden}},
 )
 setattr(litellm.proxy.proxy_server, "llm_router", router)
 setattr(litellm.proxy.proxy_server, "llm_model_list", _model_list)

@@ -1477,7 +1482,10 @@ async def test_proxy_model_group_alias_checks(prisma_client):
 user_api_key_dict=UserAPIKeyAuth(models=[]),
 )

-assert len(resp) == 2
+if hidden:
+assert len(resp["data"]) == 1
+else:
+assert len(resp["data"]) == 2
 print(resp)

 resp = await model_info_v1(

@@ -1489,7 +1497,10 @@ async def test_proxy_model_group_alias_checks(prisma_client):
 if model_alias == item["model_name"]:
 is_model_alias_in_list = True

-assert is_model_alias_in_list
+if hidden:
+assert is_model_alias_in_list is False
+else:
+assert is_model_alias_in_list

 resp = await model_group_info(
 user_api_key_dict=UserAPIKeyAuth(models=[]),

@@ -1500,4 +1511,7 @@ async def test_proxy_model_group_alias_checks(prisma_client):
 if model_alias == item.model_group:
 is_model_alias_in_list = True

-assert is_model_alias_in_list
+if hidden:
+assert is_model_alias_in_list is False
+else:
+assert is_model_alias_in_list, f"models: {models}"
|
||||||
model="my-custom-name", custom_llm_provider=provider, limit=2
|
model="my-custom-name", custom_llm_provider=provider, limit=2
|
||||||
)
|
)
|
||||||
print("list_batches=", list_batches)
|
print("list_batches=", list_batches)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("hidden", [True, False])
|
||||||
|
def test_model_group_alias(hidden):
|
||||||
|
_model_list = [
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {"model": "gpt-3.5-turbo"},
|
||||||
|
},
|
||||||
|
{"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}},
|
||||||
|
]
|
||||||
|
router = Router(
|
||||||
|
model_list=_model_list,
|
||||||
|
model_group_alias={
|
||||||
|
"gpt-4.5-turbo": {"model": "gpt-3.5-turbo", "hidden": hidden}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
models = router.get_model_list()
|
||||||
|
|
||||||
|
model_names = router.get_model_names()
|
||||||
|
|
||||||
|
if hidden:
|
||||||
|
assert len(models) == len(_model_list)
|
||||||
|
assert len(model_names) == len(_model_list)
|
||||||
|
else:
|
||||||
|
assert len(models) == len(_model_list) + 1
|
||||||
|
assert len(model_names) == len(_model_list) + 1
|
||||||
|
|
|
@@ -582,3 +582,8 @@ class RouterRateLimitError(ValueError):
 self.cooldown_list = cooldown_list
 _message = f"{RouterErrors.no_deployments_available.value}, Try again in {cooldown_time} seconds. Passed model={model}. pre-call-checks={enable_pre_call_checks}, cooldown_list={cooldown_list}"
 super().__init__(_message)
+
+
+class RouterModelGroupAliasItem(TypedDict):
+model: str
+hidden: bool  # if 'True', don't return on `.get_model_list`
@@ -76,6 +76,7 @@ from litellm.types.llms.openai import (
 ChatCompletionNamedToolChoiceParam,
 ChatCompletionToolParam,
 )
+from litellm.types.utils import FileTypes  # type: ignore
 from litellm.types.utils import (
 CallTypes,
 ChatCompletionDeltaToolCall,

@@ -84,7 +85,6 @@ from litellm.types.utils import (
 Delta,
 Embedding,
 EmbeddingResponse,
-FileTypes,
 ImageResponse,
 Message,
 ModelInfo,

@@ -2339,6 +2339,7 @@ def get_litellm_params(
 text_completion=None,
 azure_ad_token_provider=None,
 user_continue_message=None,
+base_model=None,
 ):
 litellm_params = {
 "acompletion": acompletion,

@@ -2365,6 +2366,8 @@ def get_litellm_params(
 "text_completion": text_completion,
 "azure_ad_token_provider": azure_ad_token_provider,
 "user_continue_message": user_continue_message,
+"base_model": base_model
+or _get_base_model_from_litellm_call_metadata(metadata=metadata),
 }

 return litellm_params

@@ -6063,11 +6066,11 @@ def _calculate_retry_after(
 max_retries: int,
 response_headers: Optional[httpx.Headers] = None,
 min_timeout: int = 0,
-):
+) -> Union[float, int]:
 retry_after = _get_retry_after_from_exception_header(response_headers)

 # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
-if 0 < retry_after <= 60:
+if retry_after is not None and 0 < retry_after <= 60:
 return retry_after

 initial_retry_delay = 0.5

@@ -10962,6 +10965,22 @@ def get_logging_id(start_time, response_obj):
 return None


+def _get_base_model_from_litellm_call_metadata(
+metadata: Optional[dict],
+) -> Optional[str]:
+if metadata is None:
+return None
+
+if metadata is not None:
+model_info = metadata.get("model_info", {})
+
+if model_info is not None:
+base_model = model_info.get("base_model", None)
+if base_model is not None:
+return base_model
+return None
+
+
 def _get_base_model_from_metadata(model_call_details=None):
 if model_call_details is None:
 return None

@@ -10970,13 +10989,7 @@ def _get_base_model_from_metadata(model_call_details=None):
 if litellm_params is not None:
 metadata = litellm_params.get("metadata", {})

-if metadata is not None:
-model_info = metadata.get("model_info", {})
-
-if model_info is not None:
-base_model = model_info.get("base_model", None)
-if base_model is not None:
-return base_model
+return _get_base_model_from_litellm_call_metadata(metadata=metadata)
 return None
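With the utils.py changes above, `get_litellm_params` accepts an explicit `base_model` and otherwise falls back to `_get_base_model_from_litellm_call_metadata`, which reads `metadata["model_info"]["base_model"]`. A standalone sketch of that lookup (re-implemented here rather than imported, since the helper is private to `litellm.utils`):

```python
# Standalone sketch mirroring the metadata lookup added in litellm/utils.py above.
from typing import Optional


def base_model_from_call_metadata(metadata: Optional[dict]) -> Optional[str]:
    """Return metadata['model_info']['base_model'] if present, else None."""
    if metadata is None:
        return None
    model_info = metadata.get("model_info", {}) or {}
    return model_info.get("base_model", None)


# e.g. a fine-tuned Vertex AI endpoint can advertise its base model via metadata:
metadata = {"model_info": {"base_model": "vertex_ai/gemini-1.5-pro"}}
assert base_model_from_call_metadata(metadata) == "vertex_ai/gemini-1.5-pro"
assert base_model_from_call_metadata(None) is None
```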
pyrightconfig.json (new file, 6 lines)
@@ -0,0 +1,6 @@
+{
+"ignore": [],
+"exclude": ["**/node_modules", "**/__pycache__", "litellm/tests", "litellm/main.py", "litellm/utils.py"],
+"reportMissingImports": false
+}