forked from phoenix/litellm-mirror
Compare commits
21 commits
main...litellm_de

Commits in this range:

- b06b0248ff
- 4021206ac2
- 88db948d29
- b5c5f87e25
- ed64dd7f9d
- e2ccf6b7f4
- f8a46b5950
- 31943bf2ad
- 541326731f
- dfb34dfe92
- 250d66b335
- 94fe135524
- 1a3fb18a64
- d788c3c37f
- 4beb48829c
- eb0a357eda
- 463fa0c9d5
- 1014216d73
- 97d8aa0b3a
- 5a698c678a
- 1c9a8c0b68
35 changed files with 871 additions and 248 deletions
```diff
@@ -41,7 +41,7 @@ Use `litellm.get_supported_openai_params()` for an updated list of params for each provider
 | Provider | temperature | max_completion_tokens | max_tokens | top_p | stream | stream_options | stop | n | presence_penalty | frequency_penalty | functions | function_call | logit_bias | user | response_format | seed | tools | tool_choice | logprobs | top_logprobs | extra_headers |
 |---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
-|Anthropic| ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | | | | | | |✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ |
+|Anthropic| ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | | | | | | |✅ | ✅ | | ✅ | ✅ | | | ✅ |
 |OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | ✅ |
 |Azure OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | | | ✅ |
 |Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
```
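The hunk header above points at `litellm.get_supported_openai_params()`. A quick way to check a provider's supported params from Python; the model and provider names below are examples only, not values taken from this diff:

```python
# Illustrative check of provider-supported OpenAI params.
import litellm

params = litellm.get_supported_openai_params(
    model="claude-3-5-sonnet-20240620",
    custom_llm_provider="anthropic",
) or []

print("response_format" in params, "seed" in params)
```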
docs/my-website/docs/guides/finetuned_models.md (new file, 74 lines)
@@ -0,0 +1,74 @@

import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# Calling Finetuned Models

## OpenAI

| Model Name | Function Call |
|---------------------------|-----------------------------------------------------------------|
| fine tuned `gpt-4-0613` | `response = completion(model="ft:gpt-4-0613", messages=messages)` |
| fine tuned `gpt-4o-2024-05-13` | `response = completion(model="ft:gpt-4o-2024-05-13", messages=messages)` |
| fine tuned `gpt-3.5-turbo-0125` | `response = completion(model="ft:gpt-3.5-turbo-0125", messages=messages)` |
| fine tuned `gpt-3.5-turbo-1106` | `response = completion(model="ft:gpt-3.5-turbo-1106", messages=messages)` |
| fine tuned `gpt-3.5-turbo-0613` | `response = completion(model="ft:gpt-3.5-turbo-0613", messages=messages)` |

## Vertex AI

Fine-tuned models on Vertex AI have a numerical model/endpoint ID.

<Tabs>
<TabItem value="sdk" label="SDK">

```python
from litellm import completion
import os

## set ENV variables
os.environ["VERTEXAI_PROJECT"] = "hardy-device-38811"
os.environ["VERTEXAI_LOCATION"] = "us-central1"

response = completion(
  model="vertex_ai/<your-finetuned-model>",  # e.g. vertex_ai/4965075652664360960
  messages=[{ "content": "Hello, how are you?","role": "user"}],
  base_model="vertex_ai/gemini-1.5-pro" # the base model - used for routing
)
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Add Vertex credentials to your env

```bash
!gcloud auth application-default login
```

2. Setup config.yaml

```yaml
- model_name: finetuned-gemini
  litellm_params:
    model: vertex_ai/<ENDPOINT_ID>
    vertex_project: <PROJECT_ID>
    vertex_location: <LOCATION>
  model_info:
    base_model: vertex_ai/gemini-1.5-pro # IMPORTANT
```

3. Test it!

```bash
curl --location 'https://0.0.0.0:4000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: <LITELLM_KEY>' \
--data '{"model": "finetuned-gemini" ,"messages":[{"role": "user", "content":[{"type": "text", "text": "hi"}]}]}'
```

</TabItem>
</Tabs>
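Once the proxy config above registers `finetuned-gemini`, any OpenAI-compatible client can call it through the proxy. A minimal sketch, assuming the proxy runs on `0.0.0.0:4000` and `sk-1234` is a valid virtual key (both are placeholders):

```python
# Hypothetical client-side call to the proxy alias defined in the config.yaml above.
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

response = client.chat.completions.create(
    model="finetuned-gemini",  # the model_name alias from the proxy config
    messages=[{"role": "user", "content": "hi"}],
)
print(response.choices[0].message.content)
```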
```diff
@@ -754,6 +754,8 @@ general_settings:
 | cache_params.s3_endpoint_url | string | Optional - The endpoint URL for the S3 bucket. |
 | cache_params.supported_call_types | array of strings | The types of calls to cache. [Further docs](./caching) |
 | cache_params.mode | string | The mode of the cache. [Further docs](./caching) |
+| disable_end_user_cost_tracking | boolean | If true, turns off end user cost tracking on prometheus metrics + litellm spend logs table on proxy. |
+| key_generation_settings | object | Restricts who can generate keys. [Further docs](./virtual_keys.md#restricting-key-generation) |
 
 ### general_settings - Reference
 
```
````diff
@@ -217,4 +217,10 @@ litellm_settings:
   max_parallel_requests: 1000 # (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None.
   tpm_limit: 1000 #(Optional[int], optional): Tpm limit. Defaults to None.
   rpm_limit: 1000 #(Optional[int], optional): Rpm limit. Defaults to None.
-```
+
+  key_generation_settings: # Restricts who can generate keys. [Further docs](./virtual_keys.md#restricting-key-generation)
+    team_key_generation:
+      allowed_team_member_roles: ["admin"]
+    personal_key_generation: # maps to 'Default Team' on UI
+      allowed_user_roles: ["proxy_admin"]
+```
````
````diff
@@ -811,6 +811,75 @@ litellm_settings:
       team_id: "core-infra"
 ```
 
+### Restricting Key Generation
+
+Use this to control who can generate keys. Useful when letting others create keys on the UI.
+
+```yaml
+litellm_settings:
+  key_generation_settings:
+    team_key_generation:
+      allowed_team_member_roles: ["admin"]
+    personal_key_generation: # maps to 'Default Team' on UI
+      allowed_user_roles: ["proxy_admin"]
+```
+
+#### Spec
+
+```python
+class TeamUIKeyGenerationConfig(TypedDict):
+    allowed_team_member_roles: List[str]
+
+
+class PersonalUIKeyGenerationConfig(TypedDict):
+    allowed_user_roles: List[LitellmUserRoles]
+
+
+class StandardKeyGenerationConfig(TypedDict, total=False):
+    team_key_generation: TeamUIKeyGenerationConfig
+    personal_key_generation: PersonalUIKeyGenerationConfig
+
+
+class LitellmUserRoles(str, enum.Enum):
+    """
+    Admin Roles:
+    PROXY_ADMIN: admin over the platform
+    PROXY_ADMIN_VIEW_ONLY: can login, view all own keys, view all spend
+    ORG_ADMIN: admin over a specific organization, can create teams, users only within their organization
+
+    Internal User Roles:
+    INTERNAL_USER: can login, view/create/delete their own keys, view their spend
+    INTERNAL_USER_VIEW_ONLY: can login, view their own keys, view their own spend
+
+    Team Roles:
+    TEAM: used for JWT auth
+
+    Customer Roles:
+    CUSTOMER: External users -> these are customers
+    """
+
+    # Admin Roles
+    PROXY_ADMIN = "proxy_admin"
+    PROXY_ADMIN_VIEW_ONLY = "proxy_admin_viewer"
+
+    # Organization admins
+    ORG_ADMIN = "org_admin"
+
+    # Internal User Roles
+    INTERNAL_USER = "internal_user"
+    INTERNAL_USER_VIEW_ONLY = "internal_user_viewer"
+
+    # Team Roles
+    TEAM = "team"
+
+    # Customer Roles - External users of proxy
+    CUSTOMER = "customer"
+```
+
 ## **Next Steps - Set Budgets, Rate Limits per Virtual Key**
 
 [Follow this doc to set budgets, rate limiters per virtual key with LiteLLM](users)
````
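With the `key_generation_settings` block above in place, a `/key/generate` call made by a disallowed role should come back as an HTTP 400. A rough sketch of checking that behavior; the proxy URL, keys, and team id are placeholders:

```python
# Illustrative request against the proxy's /key/generate endpoint; values are placeholders.
import requests

resp = requests.post(
    "http://0.0.0.0:4000/key/generate",
    headers={"Authorization": "Bearer sk-team-member-key"},  # key of a non-admin team member
    json={"team_id": "core-infra"},
)
print(resp.status_code)  # expected: 400, with an "allowed_team_member_roles" detail message
```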
```diff
@@ -199,6 +199,31 @@ const sidebars = {
       ],
     },
+    {
+      type: "category",
+      label: "Guides",
+      items: [
+        "exception_mapping",
+        "completion/provider_specific_params",
+        "guides/finetuned_models",
+        "completion/audio",
+        "completion/vision",
+        "completion/json_mode",
+        "completion/prompt_caching",
+        "completion/predict_outputs",
+        "completion/prefix",
+        "completion/drop_params",
+        "completion/prompt_formatting",
+        "completion/stream",
+        "completion/message_trimming",
+        "completion/function_call",
+        "completion/model_alias",
+        "completion/batching",
+        "completion/mock_requests",
+        "completion/reliable_completions",
+
+      ]
+    },
     {
       type: "category",
       label: "Supported Endpoints",
@@ -214,25 +239,8 @@ const sidebars = {
       },
       items: [
         "completion/input",
-        "completion/provider_specific_params",
-        "completion/json_mode",
-        "completion/prompt_caching",
-        "completion/audio",
-        "completion/vision",
-        "completion/predict_outputs",
-        "completion/prefix",
-        "completion/drop_params",
-        "completion/prompt_formatting",
         "completion/output",
         "completion/usage",
-        "exception_mapping",
-        "completion/stream",
-        "completion/message_trimming",
-        "completion/function_call",
-        "completion/model_alias",
-        "completion/batching",
-        "completion/mock_requests",
-        "completion/reliable_completions",
       ],
     },
     "embedding/supported_embedding",
```
```diff
@@ -24,6 +24,7 @@ from litellm.proxy._types import (
     KeyManagementSettings,
     LiteLLM_UpperboundKeyGenerateParams,
 )
+from litellm.types.utils import StandardKeyGenerationConfig
 import httpx
 import dotenv
 from enum import Enum
@@ -273,6 +274,7 @@ s3_callback_params: Optional[Dict] = None
 generic_logger_headers: Optional[Dict] = None
 default_key_generate_params: Optional[Dict] = None
 upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
+key_generation_settings: Optional[StandardKeyGenerationConfig] = None
 default_internal_user_params: Optional[Dict] = None
 default_team_settings: Optional[List] = None
 max_user_budget: Optional[float] = None
@@ -280,6 +282,7 @@ default_max_internal_user_budget: Optional[float] = None
 max_internal_user_budget: Optional[float] = None
 internal_user_budget_duration: Optional[str] = None
 max_end_user_budget: Optional[float] = None
+disable_end_user_cost_tracking: Optional[bool] = None
 #### REQUEST PRIORITIZATION ####
 priority_reservation: Optional[Dict[str, float]] = None
 #### RELIABILITY ####
```
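Both new module-level settings can also be set directly in Python, mirroring the YAML shown earlier in this diff; the values below are examples only:

```python
# Illustrative programmatic equivalents of the config.yaml settings added in this branch.
import litellm

litellm.disable_end_user_cost_tracking = True
litellm.key_generation_settings = {
    "team_key_generation": {"allowed_team_member_roles": ["admin"]},
    "personal_key_generation": {"allowed_user_roles": ["proxy_admin"]},
}
```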
```diff
@@ -18,6 +18,7 @@ from litellm.integrations.custom_logger import CustomLogger
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.types.integrations.prometheus import *
 from litellm.types.utils import StandardLoggingPayload
+from litellm.utils import get_end_user_id_for_cost_tracking
 
 
 class PrometheusLogger(CustomLogger):
@@ -364,8 +365,7 @@ class PrometheusLogger(CustomLogger):
         model = kwargs.get("model", "")
         litellm_params = kwargs.get("litellm_params", {}) or {}
         _metadata = litellm_params.get("metadata", {})
-        proxy_server_request = litellm_params.get("proxy_server_request") or {}
-        end_user_id = proxy_server_request.get("body", {}).get("user", None)
+        end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
         user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
         user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
         user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
@@ -664,13 +664,11 @@ class PrometheusLogger(CustomLogger):
 
         # unpack kwargs
         model = kwargs.get("model", "")
-        litellm_params = kwargs.get("litellm_params", {}) or {}
         standard_logging_payload: StandardLoggingPayload = kwargs.get(
             "standard_logging_object", {}
         )
-        proxy_server_request = litellm_params.get("proxy_server_request") or {}
-        end_user_id = proxy_server_request.get("body", {}).get("user", None)
+        litellm_params = kwargs.get("litellm_params", {}) or {}
+        end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
         user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
         user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
         user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
```
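Both Prometheus code paths now delegate end-user resolution to `get_end_user_id_for_cost_tracking`. The diff does not show that helper's body; a plausible sketch of what it does, under the assumption that it reads the proxy request body and honors the new `disable_end_user_cost_tracking` flag (the real implementation lives in `litellm.utils`):

```python
# Assumed behavior only - not copied from this diff.
from typing import Optional

import litellm


def get_end_user_id_for_cost_tracking(litellm_params: dict) -> Optional[str]:
    """Read the end-user id from the original proxy request body, unless
    end-user cost tracking has been switched off."""
    if litellm.disable_end_user_cost_tracking:
        return None
    proxy_server_request = litellm_params.get("proxy_server_request") or {}
    return proxy_server_request.get("body", {}).get("user", None)
```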
```diff
@@ -934,19 +934,10 @@ class Logging:
                     status="success",
                 )
             )
-            if self.dynamic_success_callbacks is not None and isinstance(
-                self.dynamic_success_callbacks, list
-            ):
-                callbacks = self.dynamic_success_callbacks
-                ## keep the internal functions ##
-                for callback in litellm.success_callback:
-                    if (
-                        isinstance(callback, CustomLogger)
-                        and "_PROXY_" in callback.__class__.__name__
-                    ):
-                        callbacks.append(callback)
-            else:
-                callbacks = litellm.success_callback
+        callbacks = get_combined_callback_list(
+            dynamic_success_callbacks=self.dynamic_success_callbacks,
+            global_callbacks=litellm.success_callback,
+        )
 
         ## REDACT MESSAGES ##
         result = redact_message_input_output_from_logging(
@@ -1368,8 +1359,11 @@ class Logging:
                         and customLogger is not None
                     ):  # custom logger functions
                         print_verbose(
-                            "success callbacks: Running Custom Callback Function"
+                            "success callbacks: Running Custom Callback Function - {}".format(
+                                callback
+                            )
                         )
+
                         customLogger.log_event(
                             kwargs=self.model_call_details,
                             response_obj=result,
@@ -1466,21 +1460,10 @@ class Logging:
                     status="success",
                 )
             )
-            if self.dynamic_async_success_callbacks is not None and isinstance(
-                self.dynamic_async_success_callbacks, list
-            ):
-                callbacks = self.dynamic_async_success_callbacks
-                ## keep the internal functions ##
-                for callback in litellm._async_success_callback:
-                    callback_name = ""
-                    if isinstance(callback, CustomLogger):
-                        callback_name = callback.__class__.__name__
-                    if callable(callback):
-                        callback_name = callback.__name__
-                    if "_PROXY_" in callback_name:
-                        callbacks.append(callback)
-            else:
-                callbacks = litellm._async_success_callback
+        callbacks = get_combined_callback_list(
+            dynamic_success_callbacks=self.dynamic_async_success_callbacks,
+            global_callbacks=litellm._async_success_callback,
+        )
 
         result = redact_message_input_output_from_logging(
             model_call_details=(
@@ -1747,21 +1730,10 @@ class Logging:
                 start_time=start_time,
                 end_time=end_time,
             )
-        callbacks = []  # init this to empty incase it's not created
-        if self.dynamic_failure_callbacks is not None and isinstance(
-            self.dynamic_failure_callbacks, list
-        ):
-            callbacks = self.dynamic_failure_callbacks
-            ## keep the internal functions ##
-            for callback in litellm.failure_callback:
-                if (
-                    isinstance(callback, CustomLogger)
-                    and "_PROXY_" in callback.__class__.__name__
-                ):
-                    callbacks.append(callback)
-        else:
-            callbacks = litellm.failure_callback
+        callbacks = get_combined_callback_list(
+            dynamic_success_callbacks=self.dynamic_failure_callbacks,
+            global_callbacks=litellm.failure_callback,
+        )
 
         result = None  # result sent to all loggers, init this to None incase it's not created
 
@@ -1944,21 +1916,10 @@ class Logging:
                 end_time=end_time,
             )
 
-        callbacks = []  # init this to empty incase it's not created
-        if self.dynamic_async_failure_callbacks is not None and isinstance(
-            self.dynamic_async_failure_callbacks, list
-        ):
-            callbacks = self.dynamic_async_failure_callbacks
-            ## keep the internal functions ##
-            for callback in litellm._async_failure_callback:
-                if (
-                    isinstance(callback, CustomLogger)
-                    and "_PROXY_" in callback.__class__.__name__
-                ):
-                    callbacks.append(callback)
-        else:
-            callbacks = litellm._async_failure_callback
+        callbacks = get_combined_callback_list(
+            dynamic_success_callbacks=self.dynamic_async_failure_callbacks,
+            global_callbacks=litellm._async_failure_callback,
+        )
 
         result = None  # result sent to all loggers, init this to None incase it's not created
         for callback in callbacks:
@@ -2359,6 +2320,7 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
         _in_memory_loggers.append(_mlflow_logger)
         return _mlflow_logger  # type: ignore
 
+
 def get_custom_logger_compatible_class(
     logging_integration: litellm._custom_logger_compatible_callbacks_literal,
 ) -> Optional[CustomLogger]:
@@ -2949,3 +2911,11 @@ def modify_integration(integration_name, integration_params):
     if integration_name == "supabase":
         if "table_name" in integration_params:
             Supabase.supabase_table_name = integration_params["table_name"]
+
+
+def get_combined_callback_list(
+    dynamic_success_callbacks: Optional[List], global_callbacks: List
+) -> List:
+    if dynamic_success_callbacks is None:
+        return global_callbacks
+    return list(set(dynamic_success_callbacks + global_callbacks))
```
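`get_combined_callback_list` replaces four hand-rolled merge blocks. One detail worth noting is that the `set()` union de-duplicates but does not preserve order; a standalone illustration of the merge semantics:

```python
# Standalone illustration of the merge semantics of the helper added above.
from typing import List, Optional


def get_combined_callback_list(
    dynamic_success_callbacks: Optional[List], global_callbacks: List
) -> List:
    if dynamic_success_callbacks is None:
        return global_callbacks
    return list(set(dynamic_success_callbacks + global_callbacks))


print(get_combined_callback_list(None, ["langfuse"]))                  # ['langfuse']
print(sorted(get_combined_callback_list(["s3"], ["langfuse", "s3"])))  # ['langfuse', 's3'], de-duplicated, order not preserved
```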
```diff
@@ -4,6 +4,7 @@ import httpx
 
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.llms.cohere.rerank import CohereRerank
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.types.rerank import RerankResponse
 
 
@@ -73,6 +74,7 @@ class AzureAIRerank(CohereRerank):
         return_documents: Optional[bool] = True,
         max_chunks_per_doc: Optional[int] = None,
         _is_async: Optional[bool] = False,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ) -> RerankResponse:
 
         if headers is None:
```
```diff
@@ -74,6 +74,7 @@ async def async_embedding(
         },
     )
     ## COMPLETION CALL
+
     if client is None:
         client = get_async_httpx_client(
             llm_provider=litellm.LlmProviders.COHERE,
@@ -151,6 +152,11 @@ def embedding(
             api_key=api_key,
             headers=headers,
             encoding=encoding,
+            client=(
+                client
+                if client is not None and isinstance(client, AsyncHTTPHandler)
+                else None
+            ),
         )
 
     ## LOGGING
```
```diff
@@ -6,10 +6,14 @@ LiteLLM supports the re rank API format, no paramter transformation occurs
 
 from typing import Any, Dict, List, Optional, Union
 
+import httpx
+
 import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
     _get_httpx_client,
     get_async_httpx_client,
 )
@@ -34,6 +38,23 @@ class CohereRerank(BaseLLM):
         # Merge other headers, overriding any default ones except Authorization
         return {**default_headers, **headers}
 
+    def ensure_rerank_endpoint(self, api_base: str) -> str:
+        """
+        Ensures the `/v1/rerank` endpoint is appended to the given `api_base`.
+        If `/v1/rerank` is already present, the original URL is returned.
+
+        :param api_base: The base API URL.
+        :return: A URL with `/v1/rerank` appended if missing.
+        """
+        # Parse the base URL to ensure proper structure
+        url = httpx.URL(api_base)
+
+        # Check if the URL already ends with `/v1/rerank`
+        if not url.path.endswith("/v1/rerank"):
+            url = url.copy_with(path=f"{url.path.rstrip('/')}/v1/rerank")
+
+        return str(url)
+
     def rerank(
         self,
         model: str,
@@ -48,9 +69,10 @@ class CohereRerank(BaseLLM):
         return_documents: Optional[bool] = True,
         max_chunks_per_doc: Optional[int] = None,
         _is_async: Optional[bool] = False,  # New parameter
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ) -> RerankResponse:
         headers = self.validate_environment(api_key=api_key, headers=headers)
+        api_base = self.ensure_rerank_endpoint(api_base)
         request_data = RerankRequest(
             model=model,
             query=query,
@@ -76,9 +98,13 @@ class CohereRerank(BaseLLM):
         if _is_async:
             return self.async_rerank(request_data=request_data, api_key=api_key, api_base=api_base, headers=headers)  # type: ignore # Call async method
 
-        client = _get_httpx_client()
+        if client is not None and isinstance(client, HTTPHandler):
+            client = client
+        else:
+            client = _get_httpx_client()
+
         response = client.post(
-            api_base,
+            url=api_base,
             headers=headers,
             json=request_data_dict,
         )
@@ -100,10 +126,13 @@ class CohereRerank(BaseLLM):
         api_key: str,
         api_base: str,
         headers: dict,
+        client: Optional[AsyncHTTPHandler] = None,
     ) -> RerankResponse:
         request_data_dict = request_data.dict(exclude_none=True)
 
-        client = get_async_httpx_client(llm_provider=litellm.LlmProviders.COHERE)
+        client = client or get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.COHERE
+        )
 
         response = await client.post(
             api_base,
```
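`ensure_rerank_endpoint` normalizes whatever `api_base` the caller passes before the request is made. A standalone illustration of the intended behavior (the URLs are examples):

```python
# Standalone copy of the normalization logic added above, for illustration.
import httpx


def ensure_rerank_endpoint(api_base: str) -> str:
    url = httpx.URL(api_base)
    if not url.path.endswith("/v1/rerank"):
        url = url.copy_with(path=f"{url.path.rstrip('/')}/v1/rerank")
    return str(url)


print(ensure_rerank_endpoint("https://api.cohere.ai"))            # https://api.cohere.ai/v1/rerank
print(ensure_rerank_endpoint("https://api.cohere.ai/v1/rerank"))  # unchanged
```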
```diff
@@ -3440,6 +3440,10 @@ def embedding(  # noqa: PLR0915
             or litellm.openai_key
             or get_secret_str("OPENAI_API_KEY")
         )
+
+        if extra_headers is not None:
+            optional_params["extra_headers"] = extra_headers
+
         api_type = "openai"
         api_version = None
 
```
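The hunk above forwards `extra_headers` into `optional_params` for OpenAI embeddings. A hedged example of passing that argument from the SDK; the header name and model are placeholders:

```python
# Illustrative embedding call; requires OPENAI_API_KEY in the environment,
# and the custom header is a made-up example.
import litellm

response = litellm.embedding(
    model="text-embedding-3-small",
    input=["hello world"],
    extra_headers={"X-Request-Source": "docs-example"},
)
print(len(response.data[0]["embedding"]))
```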
```diff
@@ -16,7 +16,7 @@ model_list:
       model: openai/fake
       api_key: fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
 
 router_settings:
   model_group_alias:
     "gpt-4-turbo": # Aliased model name
@@ -35,4 +35,4 @@ litellm_settings:
   failure_callback: ["langfuse"]
   langfuse_public_key: os.environ/LANGFUSE_PROJECT2_PUBLIC # Project 2
   langfuse_secret: os.environ/LANGFUSE_PROJECT2_SECRET # Project 2
   langfuse_host: https://us.cloud.langfuse.com
```
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import fields
|
from dataclasses import fields
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
@ -12,7 +13,15 @@ from typing_extensions import Annotated, TypedDict
|
||||||
|
|
||||||
from litellm.types.integrations.slack_alerting import AlertType
|
from litellm.types.integrations.slack_alerting import AlertType
|
||||||
from litellm.types.router import RouterErrors, UpdateRouterConfig
|
from litellm.types.router import RouterErrors, UpdateRouterConfig
|
||||||
from litellm.types.utils import ProviderField, StandardCallbackDynamicParams
|
from litellm.types.utils import (
|
||||||
|
EmbeddingResponse,
|
||||||
|
ImageResponse,
|
||||||
|
ModelResponse,
|
||||||
|
ProviderField,
|
||||||
|
StandardCallbackDynamicParams,
|
||||||
|
StandardPassThroughResponseObject,
|
||||||
|
TextCompletionResponse,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from opentelemetry.trace import Span as _Span
|
from opentelemetry.trace import Span as _Span
|
||||||
|
@ -882,15 +891,7 @@ class DeleteCustomerRequest(LiteLLMBase):
|
||||||
user_ids: List[str]
|
user_ids: List[str]
|
||||||
|
|
||||||
|
|
||||||
class Member(LiteLLMBase):
|
class MemberBase(LiteLLMBase):
|
||||||
role: Literal[
|
|
||||||
LitellmUserRoles.ORG_ADMIN,
|
|
||||||
LitellmUserRoles.INTERNAL_USER,
|
|
||||||
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
|
|
||||||
# older Member roles
|
|
||||||
"admin",
|
|
||||||
"user",
|
|
||||||
]
|
|
||||||
user_id: Optional[str] = None
|
user_id: Optional[str] = None
|
||||||
user_email: Optional[str] = None
|
user_email: Optional[str] = None
|
||||||
|
|
||||||
|
@ -904,6 +905,21 @@ class Member(LiteLLMBase):
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
class Member(MemberBase):
|
||||||
|
role: Literal[
|
||||||
|
"admin",
|
||||||
|
"user",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class OrgMember(MemberBase):
|
||||||
|
role: Literal[
|
||||||
|
LitellmUserRoles.ORG_ADMIN,
|
||||||
|
LitellmUserRoles.INTERNAL_USER,
|
||||||
|
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class TeamBase(LiteLLMBase):
|
class TeamBase(LiteLLMBase):
|
||||||
team_alias: Optional[str] = None
|
team_alias: Optional[str] = None
|
||||||
team_id: Optional[str] = None
|
team_id: Optional[str] = None
|
||||||
|
@ -1966,6 +1982,26 @@ class MemberAddRequest(LiteLLMBase):
|
||||||
# Replace member_data with the single Member object
|
# Replace member_data with the single Member object
|
||||||
data["member"] = member
|
data["member"] = member
|
||||||
# Call the superclass __init__ method to initialize the object
|
# Call the superclass __init__ method to initialize the object
|
||||||
|
traceback.print_stack()
|
||||||
|
super().__init__(**data)
|
||||||
|
|
||||||
|
|
||||||
|
class OrgMemberAddRequest(LiteLLMBase):
|
||||||
|
member: Union[List[OrgMember], OrgMember]
|
||||||
|
|
||||||
|
def __init__(self, **data):
|
||||||
|
member_data = data.get("member")
|
||||||
|
if isinstance(member_data, list):
|
||||||
|
# If member is a list of dictionaries, convert each dictionary to a Member object
|
||||||
|
members = [OrgMember(**item) for item in member_data]
|
||||||
|
# Replace member_data with the list of Member objects
|
||||||
|
data["member"] = members
|
||||||
|
elif isinstance(member_data, dict):
|
||||||
|
# If member is a dictionary, convert it to a single Member object
|
||||||
|
member = OrgMember(**member_data)
|
||||||
|
# Replace member_data with the single Member object
|
||||||
|
data["member"] = member
|
||||||
|
# Call the superclass __init__ method to initialize the object
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2017,7 +2053,7 @@ class TeamMemberUpdateResponse(MemberUpdateResponse):
|
||||||
|
|
||||||
|
|
||||||
# Organization Member Requests
|
# Organization Member Requests
|
||||||
class OrganizationMemberAddRequest(MemberAddRequest):
|
class OrganizationMemberAddRequest(OrgMemberAddRequest):
|
||||||
organization_id: str
|
organization_id: str
|
||||||
max_budget_in_organization: Optional[float] = (
|
max_budget_in_organization: Optional[float] = (
|
||||||
None # Users max budget within the organization
|
None # Users max budget within the organization
|
||||||
|
@ -2133,3 +2169,17 @@ class UserManagementEndpointParamDocStringEnums(str, enum.Enum):
|
||||||
spend_doc_str = """Optional[float] - Amount spent by user. Default is 0. Will be updated by proxy whenever user is used."""
|
spend_doc_str = """Optional[float] - Amount spent by user. Default is 0. Will be updated by proxy whenever user is used."""
|
||||||
team_id_doc_str = """Optional[str] - [DEPRECATED PARAM] The team id of the user. Default is None."""
|
team_id_doc_str = """Optional[str] - [DEPRECATED PARAM] The team id of the user. Default is None."""
|
||||||
duration_doc_str = """Optional[str] - Duration for the key auto-created on `/user/new`. Default is None."""
|
duration_doc_str = """Optional[str] - Duration for the key auto-created on `/user/new`. Default is None."""
|
||||||
|
|
||||||
|
|
||||||
|
PassThroughEndpointLoggingResultValues = Union[
|
||||||
|
ModelResponse,
|
||||||
|
TextCompletionResponse,
|
||||||
|
ImageResponse,
|
||||||
|
EmbeddingResponse,
|
||||||
|
StandardPassThroughResponseObject,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class PassThroughEndpointLoggingTypedDict(TypedDict):
|
||||||
|
result: Optional[PassThroughEndpointLoggingResultValues]
|
||||||
|
kwargs: dict
|
||||||
|
|
|
```diff
@@ -40,6 +40,77 @@ from litellm.proxy.utils import (
 )
 from litellm.secret_managers.main import get_secret
 
+
+def _is_team_key(data: GenerateKeyRequest):
+    return data.team_id is not None
+
+
+def _team_key_generation_check(user_api_key_dict: UserAPIKeyAuth):
+    if (
+        litellm.key_generation_settings is None
+        or litellm.key_generation_settings.get("team_key_generation") is None
+    ):
+        return True
+
+    if user_api_key_dict.team_member is None:
+        raise HTTPException(
+            status_code=400,
+            detail=f"User not assigned to team. Got team_member={user_api_key_dict.team_member}",
+        )
+
+    team_member_role = user_api_key_dict.team_member.role
+    if (
+        team_member_role
+        not in litellm.key_generation_settings["team_key_generation"][  # type: ignore
+            "allowed_team_member_roles"
+        ]
+    ):
+        raise HTTPException(
+            status_code=400,
+            detail=f"Team member role {team_member_role} not in allowed_team_member_roles={litellm.key_generation_settings['team_key_generation']['allowed_team_member_roles']}",  # type: ignore
+        )
+    return True
+
+
+def _personal_key_generation_check(user_api_key_dict: UserAPIKeyAuth):
+
+    if (
+        litellm.key_generation_settings is None
+        or litellm.key_generation_settings.get("personal_key_generation") is None
+    ):
+        return True
+
+    if (
+        user_api_key_dict.user_role
+        not in litellm.key_generation_settings["personal_key_generation"][  # type: ignore
+            "allowed_user_roles"
+        ]
+    ):
+        raise HTTPException(
+            status_code=400,
+            detail=f"Personal key creation has been restricted by admin. Allowed roles={litellm.key_generation_settings['personal_key_generation']['allowed_user_roles']}. Your role={user_api_key_dict.user_role}",  # type: ignore
+        )
+    return True
+
+
+def key_generation_check(
+    user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest
+) -> bool:
+    """
+    Check if admin has restricted key creation to certain roles for teams or individuals
+    """
+    if litellm.key_generation_settings is None:
+        return True
+
+    ## check if key is for team or individual
+    is_team_key = _is_team_key(data=data)
+
+    if is_team_key:
+        return _team_key_generation_check(user_api_key_dict)
+    else:
+        return _personal_key_generation_check(user_api_key_dict=user_api_key_dict)
+
+
 router = APIRouter()
 
 
@@ -131,6 +202,8 @@ async def generate_key_fn(  # noqa: PLR0915
             raise HTTPException(
                 status_code=status.HTTP_403_FORBIDDEN, detail=message
             )
+        elif litellm.key_generation_settings is not None:
+            key_generation_check(user_api_key_dict=user_api_key_dict, data=data)
         # check if user set default key/generate params on config.yaml
         if litellm.default_key_generate_params is not None:
             for elem in data:
```
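A quick standalone sketch of the decision flow the checks above implement once `litellm.key_generation_settings` is populated; the role strings follow the `LitellmUserRoles` values documented earlier in this diff, and everything else here is illustrative, not the proxy's actual objects:

```python
# Standalone illustration of the team-vs-personal key generation check added above.
key_generation_settings = {
    "team_key_generation": {"allowed_team_member_roles": ["admin"]},
    "personal_key_generation": {"allowed_user_roles": ["proxy_admin"]},
}


def can_generate(team_id, team_member_role, user_role) -> bool:
    is_team_key = team_id is not None  # mirrors _is_team_key
    if is_team_key:
        return team_member_role in key_generation_settings["team_key_generation"]["allowed_team_member_roles"]
    return user_role in key_generation_settings["personal_key_generation"]["allowed_user_roles"]


print(can_generate("core-infra", "admin", None))   # True
print(can_generate("core-infra", "user", None))    # False -> proxy raises HTTP 400
print(can_generate(None, None, "internal_user"))   # False -> personal keys restricted to proxy_admin
```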
```diff
@@ -352,7 +352,7 @@ async def organization_member_add(
         },
     )
 
-    members: List[Member]
+    members: List[OrgMember]
    if isinstance(data.member, List):
        members = data.member
    else:
@@ -397,7 +397,7 @@ async def organization_member_add(
 
 
 async def add_member_to_organization(
-    member: Member,
+    member: OrgMember,
     organization_id: str,
     prisma_client: PrismaClient,
 ) -> Tuple[LiteLLM_UserTable, LiteLLM_OrganizationMembershipTable]:
```
```diff
@@ -14,6 +14,7 @@ from litellm.llms.anthropic.chat.handler import (
     ModelResponseIterator as AnthropicModelResponseIterator,
 )
 from litellm.llms.anthropic.chat.transformation import AnthropicConfig
+from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
 
 if TYPE_CHECKING:
     from ..success_handler import PassThroughEndpointLogging
@@ -26,7 +27,7 @@ else:
 class AnthropicPassthroughLoggingHandler:
 
     @staticmethod
-    async def anthropic_passthrough_handler(
+    def anthropic_passthrough_handler(
         httpx_response: httpx.Response,
         response_body: dict,
         logging_obj: LiteLLMLoggingObj,
@@ -36,7 +37,7 @@ class AnthropicPassthroughLoggingHandler:
         end_time: datetime,
         cache_hit: bool,
         **kwargs,
-    ):
+    ) -> PassThroughEndpointLoggingTypedDict:
         """
         Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled
         """
@@ -67,15 +68,10 @@ class AnthropicPassthroughLoggingHandler:
             logging_obj=logging_obj,
         )
 
-        await logging_obj.async_success_handler(
-            result=litellm_model_response,
-            start_time=start_time,
-            end_time=end_time,
-            cache_hit=cache_hit,
-            **kwargs,
-        )
-
-        pass
+        return {
+            "result": litellm_model_response,
+            "kwargs": kwargs,
+        }
 
     @staticmethod
     def _create_anthropic_response_logging_payload(
@@ -123,7 +119,7 @@ class AnthropicPassthroughLoggingHandler:
         return kwargs
 
     @staticmethod
-    async def _handle_logging_anthropic_collected_chunks(
+    def _handle_logging_anthropic_collected_chunks(
         litellm_logging_obj: LiteLLMLoggingObj,
         passthrough_success_handler_obj: PassThroughEndpointLogging,
         url_route: str,
@@ -132,7 +128,7 @@ class AnthropicPassthroughLoggingHandler:
         start_time: datetime,
         all_chunks: List[str],
         end_time: datetime,
-    ):
+    ) -> PassThroughEndpointLoggingTypedDict:
         """
         Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks
 
@@ -152,7 +148,10 @@ class AnthropicPassthroughLoggingHandler:
             verbose_proxy_logger.error(
                 "Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..."
             )
-            return
+            return {
+                "result": None,
+                "kwargs": {},
+            }
         kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
             litellm_model_response=complete_streaming_response,
             model=model,
@@ -161,13 +160,11 @@ class AnthropicPassthroughLoggingHandler:
             end_time=end_time,
             logging_obj=litellm_logging_obj,
         )
-        await litellm_logging_obj.async_success_handler(
-            result=complete_streaming_response,
-            start_time=start_time,
-            end_time=end_time,
-            cache_hit=False,
-            **kwargs,
-        )
+
+        return {
+            "result": complete_streaming_response,
+            "kwargs": kwargs,
+        }
 
     @staticmethod
     def _build_complete_streaming_response(
```
```diff
@@ -14,6 +14,7 @@ from litellm.litellm_core_utils.litellm_logging import (
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     ModelResponseIterator as VertexModelResponseIterator,
 )
+from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
 
 if TYPE_CHECKING:
     from ..success_handler import PassThroughEndpointLogging
@@ -25,7 +26,7 @@ else:
 
 class VertexPassthroughLoggingHandler:
     @staticmethod
-    async def vertex_passthrough_handler(
+    def vertex_passthrough_handler(
         httpx_response: httpx.Response,
         logging_obj: LiteLLMLoggingObj,
         url_route: str,
@@ -34,7 +35,7 @@ class VertexPassthroughLoggingHandler:
         end_time: datetime,
         cache_hit: bool,
         **kwargs,
-    ):
+    ) -> PassThroughEndpointLoggingTypedDict:
         if "generateContent" in url_route:
             model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
 
@@ -65,13 +66,11 @@ class VertexPassthroughLoggingHandler:
                 logging_obj=logging_obj,
             )
 
-            await logging_obj.async_success_handler(
-                result=litellm_model_response,
-                start_time=start_time,
-                end_time=end_time,
-                cache_hit=cache_hit,
-                **kwargs,
-            )
+            return {
+                "result": litellm_model_response,
+                "kwargs": kwargs,
+            }
         elif "predict" in url_route:
             from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import (
                 VertexImageGeneration,
@@ -112,16 +111,18 @@ class VertexPassthroughLoggingHandler:
             logging_obj.model = model
             logging_obj.model_call_details["model"] = logging_obj.model
 
-            await logging_obj.async_success_handler(
-                result=litellm_prediction_response,
-                start_time=start_time,
-                end_time=end_time,
-                cache_hit=cache_hit,
-                **kwargs,
-            )
+            return {
+                "result": litellm_prediction_response,
+                "kwargs": kwargs,
+            }
+        else:
+            return {
+                "result": None,
+                "kwargs": kwargs,
+            }
 
     @staticmethod
-    async def _handle_logging_vertex_collected_chunks(
+    def _handle_logging_vertex_collected_chunks(
         litellm_logging_obj: LiteLLMLoggingObj,
         passthrough_success_handler_obj: PassThroughEndpointLogging,
         url_route: str,
@@ -130,7 +131,7 @@ class VertexPassthroughLoggingHandler:
         start_time: datetime,
         all_chunks: List[str],
         end_time: datetime,
-    ):
+    ) -> PassThroughEndpointLoggingTypedDict:
         """
         Takes raw chunks from Vertex passthrough endpoint and logs them in litellm callbacks
 
@@ -152,7 +153,11 @@ class VertexPassthroughLoggingHandler:
             verbose_proxy_logger.error(
                 "Unable to build complete streaming response for Vertex passthrough endpoint, not logging..."
            )
-            return
+            return {
+                "result": None,
+                "kwargs": kwargs,
+            }
 
         kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
             litellm_model_response=complete_streaming_response,
             model=model,
@@ -161,13 +166,11 @@ class VertexPassthroughLoggingHandler:
             end_time=end_time,
             logging_obj=litellm_logging_obj,
         )
-        await litellm_logging_obj.async_success_handler(
-            result=complete_streaming_response,
-            start_time=start_time,
-            end_time=end_time,
-            cache_hit=False,
-            **kwargs,
-        )
+
+        return {
+            "result": complete_streaming_response,
+            "kwargs": kwargs,
+        }
 
     @staticmethod
     def _build_complete_streaming_response(
```
```diff
@@ -1,5 +1,6 @@
 import asyncio
 import json
+import threading
 from datetime import datetime
 from enum import Enum
 from typing import AsyncIterable, Dict, List, Optional, Union
@@ -15,7 +16,12 @@ from litellm.llms.anthropic.chat.handler import (
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     ModelResponseIterator as VertexAIIterator,
 )
-from litellm.types.utils import GenericStreamingChunk
+from litellm.proxy._types import PassThroughEndpointLoggingResultValues
+from litellm.types.utils import (
+    GenericStreamingChunk,
+    ModelResponse,
+    StandardPassThroughResponseObject,
+)
 
 from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
     AnthropicPassthroughLoggingHandler,
@@ -87,8 +93,12 @@ class PassThroughStreamingHandler:
         all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(
             raw_bytes
         )
+        standard_logging_response_object: Optional[
+            PassThroughEndpointLoggingResultValues
+        ] = None
+        kwargs: dict = {}
         if endpoint_type == EndpointType.ANTHROPIC:
-            await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
+            anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
                 litellm_logging_obj=litellm_logging_obj,
                 passthrough_success_handler_obj=passthrough_success_handler_obj,
                 url_route=url_route,
@@ -98,20 +108,48 @@ class PassThroughStreamingHandler:
                 all_chunks=all_chunks,
                 end_time=end_time,
             )
+            standard_logging_response_object = anthropic_passthrough_logging_handler_result[
+                "result"
+            ]
+            kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
         elif endpoint_type == EndpointType.VERTEX_AI:
-            await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
-                litellm_logging_obj=litellm_logging_obj,
-                passthrough_success_handler_obj=passthrough_success_handler_obj,
-                url_route=url_route,
-                request_body=request_body,
-                endpoint_type=endpoint_type,
-                start_time=start_time,
-                all_chunks=all_chunks,
-                end_time=end_time,
+            vertex_passthrough_logging_handler_result = (
+                VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
+                    litellm_logging_obj=litellm_logging_obj,
+                    passthrough_success_handler_obj=passthrough_success_handler_obj,
+                    url_route=url_route,
+                    request_body=request_body,
+                    endpoint_type=endpoint_type,
+                    start_time=start_time,
+                    all_chunks=all_chunks,
+                    end_time=end_time,
+                )
             )
-        elif endpoint_type == EndpointType.GENERIC:
-            # No logging is supported for generic streaming endpoints
-            pass
+            standard_logging_response_object = vertex_passthrough_logging_handler_result[
+                "result"
+            ]
+            kwargs = vertex_passthrough_logging_handler_result["kwargs"]
+
+        if standard_logging_response_object is None:
+            standard_logging_response_object = StandardPassThroughResponseObject(
+                response=f"cannot parse chunks to standard response object. Chunks={all_chunks}"
+            )
+        threading.Thread(
+            target=litellm_logging_obj.success_handler,
+            args=(
+                standard_logging_response_object,
+                start_time,
+                end_time,
+                False,
+            ),
+        ).start()
+        await litellm_logging_obj.async_success_handler(
+            result=standard_logging_response_object,
+            start_time=start_time,
+            end_time=end_time,
+            cache_hit=False,
+            **kwargs,
+        )
 
     @staticmethod
     def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]:
@@ -130,4 +168,4 @@ class PassThroughStreamingHandler:
         # Split by newlines and filter out empty lines
         lines = [line.strip() for line in combined_str.split("\n") if line.strip()]
 
         return lines
```
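The streaming path now mirrors the non-streaming one: the synchronous success handler runs on a background thread while the async handler is awaited. The hand-off pattern in isolation, with illustrative handler bodies that are not the real signatures beyond what the diff shows:

```python
# Illustrative reduction of the logging hand-off pattern used above.
import asyncio
import threading


def success_handler(result, start_time, end_time, cache_hit):
    print("sync log:", result)


async def async_success_handler(result, start_time, end_time, cache_hit, **kwargs):
    print("async log:", result)


async def log_both(result, start_time, end_time):
    # fire the sync handler without blocking the event loop
    threading.Thread(
        target=success_handler, args=(result, start_time, end_time, False)
    ).start()
    # await the async handler in-line
    await async_success_handler(
        result=result, start_time=start_time, end_time=end_time, cache_hit=False
    )


asyncio.run(log_both({"response": "ok"}, 0.0, 1.0))
```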
@@ -15,6 +15,7 @@ from litellm.litellm_core_utils.litellm_logging import (
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexLLM,
 )
+from litellm.proxy._types import PassThroughEndpointLoggingResultValues
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 from litellm.types.utils import StandardPassThroughResponseObject
 
@@ -49,53 +50,69 @@ class PassThroughEndpointLogging:
         cache_hit: bool,
         **kwargs,
     ):
+        standard_logging_response_object: Optional[
+            PassThroughEndpointLoggingResultValues
+        ] = None
         if self.is_vertex_route(url_route):
-            await VertexPassthroughLoggingHandler.vertex_passthrough_handler(
-                httpx_response=httpx_response,
-                logging_obj=logging_obj,
-                url_route=url_route,
-                result=result,
-                start_time=start_time,
-                end_time=end_time,
-                cache_hit=cache_hit,
-                **kwargs,
+            vertex_passthrough_logging_handler_result = (
+                VertexPassthroughLoggingHandler.vertex_passthrough_handler(
+                    httpx_response=httpx_response,
+                    logging_obj=logging_obj,
+                    url_route=url_route,
+                    result=result,
+                    start_time=start_time,
+                    end_time=end_time,
+                    cache_hit=cache_hit,
+                    **kwargs,
+                )
             )
+            standard_logging_response_object = (
+                vertex_passthrough_logging_handler_result["result"]
+            )
+            kwargs = vertex_passthrough_logging_handler_result["kwargs"]
         elif self.is_anthropic_route(url_route):
-            await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
-                httpx_response=httpx_response,
-                response_body=response_body or {},
-                logging_obj=logging_obj,
-                url_route=url_route,
-                result=result,
-                start_time=start_time,
-                end_time=end_time,
-                cache_hit=cache_hit,
-                **kwargs,
+            anthropic_passthrough_logging_handler_result = (
+                AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
+                    httpx_response=httpx_response,
+                    response_body=response_body or {},
+                    logging_obj=logging_obj,
+                    url_route=url_route,
+                    result=result,
+                    start_time=start_time,
+                    end_time=end_time,
+                    cache_hit=cache_hit,
+                    **kwargs,
+                )
             )
-        else:
+            standard_logging_response_object = (
+                anthropic_passthrough_logging_handler_result["result"]
+            )
+            kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
+
+        if standard_logging_response_object is None:
             standard_logging_response_object = StandardPassThroughResponseObject(
                 response=httpx_response.text
             )
-            threading.Thread(
-                target=logging_obj.success_handler,
-                args=(
-                    standard_logging_response_object,
-                    start_time,
-                    end_time,
-                    cache_hit,
-                ),
-            ).start()
-            await logging_obj.async_success_handler(
-                result=(
-                    json.dumps(result)
-                    if isinstance(result, dict)
-                    else standard_logging_response_object
-                ),
-                start_time=start_time,
-                end_time=end_time,
-                cache_hit=False,
-                **kwargs,
-            )
+        threading.Thread(
+            target=logging_obj.success_handler,
+            args=(
+                standard_logging_response_object,
+                start_time,
+                end_time,
+                cache_hit,
+            ),
+        ).start()
+        await logging_obj.async_success_handler(
+            result=(
+                json.dumps(result)
+                if isinstance(result, dict)
+                else standard_logging_response_object
+            ),
+            start_time=start_time,
+            end_time=end_time,
+            cache_hit=False,
+            **kwargs,
+        )
 
     def is_vertex_route(self, url_route: str):
         for route in self.TRACKED_VERTEX_ROUTES:
@@ -268,6 +268,7 @@ from litellm.types.llms.anthropic import (
 from litellm.types.llms.openai import HttpxBinaryResponseContent
 from litellm.types.router import RouterGeneralSettings
 from litellm.types.utils import StandardLoggingPayload
+from litellm.utils import get_end_user_id_for_cost_tracking
 
 try:
     from litellm._version import version
@@ -763,8 +764,7 @@ async def _PROXY_track_cost_callback(
         )
         parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
         litellm_params = kwargs.get("litellm_params", {}) or {}
-        proxy_server_request = litellm_params.get("proxy_server_request") or {}
-        end_user_id = proxy_server_request.get("body", {}).get("user", None)
+        end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
         metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
         user_id = metadata.get("user_api_key_user_id", None)
         team_id = metadata.get("user_api_key_team_id", None)
@@ -337,14 +337,14 @@ class ProxyLogging:
             alert_to_webhook_url=self.alert_to_webhook_url,
         )
 
-        if (
-            self.alerting is not None
-            and "slack" in self.alerting
-            and "daily_reports" in self.alert_types
-        ):
+        if self.alerting is not None and "slack" in self.alerting:
             # NOTE: ENSURE we only add callbacks when alerting is on
             # We should NOT add callbacks when alerting is off
-            litellm.callbacks.append(self.slack_alerting_instance)  # type: ignore
+            if "daily_reports" in self.alert_types:
+                litellm.callbacks.append(self.slack_alerting_instance)  # type: ignore
+            litellm.success_callback.append(
+                self.slack_alerting_instance.response_taking_too_long_callback
+            )
 
         if redis_cache is not None:
             self.internal_usage_cache.dual_cache.redis_cache = redis_cache
@@ -354,9 +354,6 @@ class ProxyLogging:
         litellm.callbacks.append(self.max_budget_limiter)  # type: ignore
         litellm.callbacks.append(self.cache_control_check)  # type: ignore
         litellm.callbacks.append(self.service_logging_obj)  # type: ignore
-        litellm.success_callback.append(
-            self.slack_alerting_instance.response_taking_too_long_callback
-        )
         for callback in litellm.callbacks:
             if isinstance(callback, str):
                 callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
@@ -91,6 +91,7 @@ def rerank(
     model_info = kwargs.get("model_info", None)
     metadata = kwargs.get("metadata", {})
     user = kwargs.get("user", None)
+    client = kwargs.get("client", None)
     try:
         _is_async = kwargs.pop("arerank", False) is True
         optional_params = GenericLiteLLMParams(**kwargs)
@@ -150,7 +151,7 @@ def rerank(
             or optional_params.api_base
             or litellm.api_base
             or get_secret("COHERE_API_BASE")  # type: ignore
-            or "https://api.cohere.com/v1/rerank"
+            or "https://api.cohere.com"
         )
 
         if api_base is None:
@@ -173,6 +174,7 @@ def rerank(
             _is_async=_is_async,
             headers=headers,
             litellm_logging_obj=litellm_logging_obj,
+            client=client,
         )
     elif _custom_llm_provider == "azure_ai":
         api_base = (
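With the Cohere rerank default base URL stored without the `/v1/rerank` suffix and a `client` kwarg threaded through `rerank()`, a caller can pin the request to its own HTTP client and base URL, and the handler is expected to append the route itself (the tests later in this diff assert exactly that). A hedged usage sketch; the `api_base` value is a placeholder and `COHERE_API_KEY` is assumed to be set:

```python
import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()

# "http://localhost:4000" is a placeholder base URL; litellm is expected to call
# <api_base>/v1/rerank rather than treating api_base as the full endpoint.
response = litellm.rerank(
    model="cohere/rerank-english-v3.0",
    query="What is the capital of France?",
    documents=["Paris is the capital of France.", "Berlin is in Germany."],
    api_base="http://localhost:4000",
    client=client,
)
print(response.results)
```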
@@ -1602,3 +1602,16 @@ class StandardCallbackDynamicParams(TypedDict, total=False):
     langsmith_api_key: Optional[str]
     langsmith_project: Optional[str]
     langsmith_base_url: Optional[str]
+
+
+class TeamUIKeyGenerationConfig(TypedDict):
+    allowed_team_member_roles: List[str]
+
+
+class PersonalUIKeyGenerationConfig(TypedDict):
+    allowed_user_roles: List[str]
+
+
+class StandardKeyGenerationConfig(TypedDict, total=False):
+    team_key_generation: TeamUIKeyGenerationConfig
+    personal_key_generation: PersonalUIKeyGenerationConfig
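These TypedDicts describe the shape of `litellm.key_generation_settings`, which the key-generation checks added later in this diff read. A small sketch of setting such a config, assuming the types live alongside `StandardCallbackDynamicParams` in `litellm.types.utils`; the role values are examples only:

```python
import litellm
from litellm.types.utils import StandardKeyGenerationConfig

key_gen_config: StandardKeyGenerationConfig = {
    # Only team admins may generate team-scoped keys.
    "team_key_generation": {"allowed_team_member_roles": ["admin"]},
    # Only proxy admins may generate personal keys.
    "personal_key_generation": {"allowed_user_roles": ["proxy_admin"]},
}
litellm.key_generation_settings = key_gen_config
```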
@@ -6170,3 +6170,13 @@ class ProviderConfigManager:
             return litellm.GroqChatConfig()
 
         return OpenAIGPTConfig()
+
+
+def get_end_user_id_for_cost_tracking(litellm_params: dict) -> Optional[str]:
+    """
+    Used for enforcing `disable_end_user_cost_tracking` param.
+    """
+    proxy_server_request = litellm_params.get("proxy_server_request") or {}
+    if litellm.disable_end_user_cost_tracking:
+        return None
+    return proxy_server_request.get("body", {}).get("user", None)
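The new helper centralizes the end-user lookup that `_PROXY_track_cost_callback` previously inlined, and gates it on `litellm.disable_end_user_cost_tracking`. A short usage sketch mirroring the parametrized test added further down:

```python
import litellm
from litellm.utils import get_end_user_id_for_cost_tracking

litellm_params = {"proxy_server_request": {"body": {"user": "customer-123"}}}

litellm.disable_end_user_cost_tracking = False
print(get_end_user_id_for_cost_tracking(litellm_params))  # -> "customer-123"

litellm.disable_end_user_cost_tracking = True
print(get_end_user_id_for_cost_tracking(litellm_params))  # -> None
```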
@@ -1080,3 +1080,34 @@ def test_cohere_img_embeddings(input, input_type):
         assert response.usage.prompt_tokens_details.image_tokens > 0
     else:
         assert response.usage.prompt_tokens_details.text_tokens > 0
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_embedding_with_extra_headers(sync_mode):
+
+    input = ["hello world"]
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
+
+    if sync_mode:
+        client = HTTPHandler()
+    else:
+        client = AsyncHTTPHandler()
+
+    data = {
+        "model": "cohere/embed-english-v3.0",
+        "input": input,
+        "extra_headers": {"my-test-param": "hello-world"},
+        "client": client,
+    }
+    with patch.object(client, "post") as mock_post:
+        try:
+            if sync_mode:
+                embedding(**data)
+            else:
+                await litellm.aembedding(**data)
+        except Exception as e:
+            print(e)
+
+        mock_post.assert_called_once()
+        assert "my-test-param" in mock_post.call_args.kwargs["headers"]
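The test above drives `extra_headers` through a mocked HTTP client; outside of a test, the same parameter can presumably be passed directly on the embedding call (a `COHERE_API_KEY` environment variable is assumed):

```python
import litellm

# COHERE_API_KEY is assumed to be set in the environment.
response = litellm.embedding(
    model="cohere/embed-english-v3.0",
    input=["hello world"],
    extra_headers={"my-test-param": "hello-world"},  # forwarded on the HTTP request
)
print(response.usage)
```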
@@ -215,7 +215,10 @@ async def test_rerank_custom_api_base():
     args_to_api = kwargs["json"]
     print("Arguments passed to API=", args_to_api)
     print("url = ", _url)
-    assert _url[0] == "https://exampleopenaiendpoint-production.up.railway.app/"
+    assert (
+        _url[0]
+        == "https://exampleopenaiendpoint-production.up.railway.app/v1/rerank"
+    )
     assert args_to_api == expected_payload
     assert response.id is not None
     assert response.results is not None
@@ -258,3 +261,32 @@ async def test_rerank_custom_callbacks():
     assert custom_logger.kwargs.get("response_cost") > 0.0
     assert custom_logger.response_obj is not None
     assert custom_logger.response_obj.results is not None
+
+
+def test_complete_base_url_cohere():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+    litellm.api_base = "http://localhost:4000"
+    litellm.set_verbose = True
+
+    text = "Hello there!"
+    list_texts = ["Hello there!", "How are you?", "How do you do?"]
+
+    rerank_model = "rerank-multilingual-v3.0"
+
+    with patch.object(client, "post") as mock_post:
+        try:
+            litellm.rerank(
+                model=rerank_model,
+                query=text,
+                documents=list_texts,
+                custom_llm_provider="cohere",
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+
+        print("mock_post.call_args", mock_post.call_args)
+        mock_post.assert_called_once()
+        assert "http://localhost:4000/v1/rerank" in mock_post.call_args.kwargs["url"]
@@ -1012,3 +1012,23 @@ def test_models_by_provider():
 
     for provider in providers:
         assert provider in models_by_provider.keys()
+
+
+@pytest.mark.parametrize(
+    "litellm_params, disable_end_user_cost_tracking, expected_end_user_id",
+    [
+        ({}, False, None),
+        ({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"),
+        ({"proxy_server_request": {"body": {"user": "123"}}}, True, None),
+    ],
+)
+def test_get_end_user_id_for_cost_tracking(
+    litellm_params, disable_end_user_cost_tracking, expected_end_user_id
+):
+    from litellm.utils import get_end_user_id_for_cost_tracking
+
+    litellm.disable_end_user_cost_tracking = disable_end_user_cost_tracking
+    assert (
+        get_end_user_id_for_cost_tracking(litellm_params=litellm_params)
+        == expected_end_user_id
+    )
@@ -216,3 +216,78 @@ async def test_init_custom_logger_compatible_class_as_callback():
         await use_callback_in_llm_call(callback, used_in="success_callback")
 
     reset_env_vars()
+
+
+def test_dynamic_logging_global_callback():
+    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+    from litellm.integrations.custom_logger import CustomLogger
+    from litellm.types.utils import ModelResponse, Choices, Message, Usage
+
+    cl = CustomLogger()
+
+    litellm_logging = LiteLLMLoggingObj(
+        model="claude-3-opus-20240229",
+        messages=[{"role": "user", "content": "hi"}],
+        stream=False,
+        call_type="completion",
+        start_time=datetime.now(),
+        litellm_call_id="123",
+        function_id="456",
+        kwargs={
+            "langfuse_public_key": "my-mock-public-key",
+            "langfuse_secret_key": "my-mock-secret-key",
+        },
+        dynamic_success_callbacks=["langfuse"],
+    )
+
+    with patch.object(cl, "log_success_event") as mock_log_success_event:
+        cl.log_success_event = mock_log_success_event
+        litellm.success_callback = [cl]
+
+        try:
+            litellm_logging.success_handler(
+                result=ModelResponse(
+                    id="chatcmpl-5418737b-ab14-420b-b9c5-b278b6681b70",
+                    created=1732306261,
+                    model="claude-3-opus-20240229",
+                    object="chat.completion",
+                    system_fingerprint=None,
+                    choices=[
+                        Choices(
+                            finish_reason="stop",
+                            index=0,
+                            message=Message(
+                                content="hello",
+                                role="assistant",
+                                tool_calls=None,
+                                function_call=None,
+                            ),
+                        )
+                    ],
+                    usage=Usage(
+                        completion_tokens=20,
+                        prompt_tokens=10,
+                        total_tokens=30,
+                        completion_tokens_details=None,
+                        prompt_tokens_details=None,
+                    ),
+                ),
+                start_time=datetime.now(),
+                end_time=datetime.now(),
+                cache_hit=False,
+            )
+        except Exception as e:
+            print(f"Error: {e}")
+
+        mock_log_success_event.assert_called_once()
+
+
+def test_get_combined_callback_list():
+    from litellm.litellm_core_utils.litellm_logging import get_combined_callback_list
+
+    assert "langfuse" in get_combined_callback_list(
+        dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"]
+    )
+    assert "lago" in get_combined_callback_list(
+        dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"]
+    )
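The new tests check that per-request (dynamic) callbacks such as Langfuse are combined with globally registered callbacks rather than replacing them. A hedged sketch of the same idea at the SDK level; the Langfuse keys are mock values and an `ANTHROPIC_API_KEY` would be needed for the call to actually succeed:

```python
import litellm
from litellm.integrations.custom_logger import CustomLogger


class MyLogger(CustomLogger):
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("global callback fired for", kwargs.get("model"))


litellm.success_callback = [MyLogger()]  # globally registered callback

# Per-request (dynamic) Langfuse credentials; both the global logger and the
# dynamic Langfuse callback are expected to fire for this one call.
response = litellm.completion(
    model="claude-3-opus-20240229",
    messages=[{"role": "user", "content": "hi"}],
    langfuse_public_key="my-mock-public-key",
    langfuse_secret_key="my-mock-secret-key",
)
```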
@@ -73,7 +73,7 @@ async def test_anthropic_passthrough_handler(
     start_time = datetime.now()
     end_time = datetime.now()
 
-    await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
+    result = AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
         httpx_response=mock_httpx_response,
         response_body=mock_response,
         logging_obj=mock_logging_obj,
@@ -84,30 +84,7 @@ async def test_anthropic_passthrough_handler(
         cache_hit=False,
     )
 
-    # Assert that async_success_handler was called
-    assert mock_logging_obj.async_success_handler.called
-
-    call_args = mock_logging_obj.async_success_handler.call_args
-    call_kwargs = call_args.kwargs
-    print("call_kwargs", call_kwargs)
-
-    # Assert required fields are present in call_kwargs
-    assert "result" in call_kwargs
-    assert "start_time" in call_kwargs
-    assert "end_time" in call_kwargs
-    assert "cache_hit" in call_kwargs
-    assert "response_cost" in call_kwargs
-    assert "model" in call_kwargs
-    assert "standard_logging_object" in call_kwargs
-
-    # Assert specific values and types
-    assert isinstance(call_kwargs["result"], litellm.ModelResponse)
-    assert isinstance(call_kwargs["start_time"], datetime)
-    assert isinstance(call_kwargs["end_time"], datetime)
-    assert isinstance(call_kwargs["cache_hit"], bool)
-    assert isinstance(call_kwargs["response_cost"], float)
-    assert call_kwargs["model"] == "claude-3-opus-20240229"
-    assert isinstance(call_kwargs["standard_logging_object"], dict)
+    assert isinstance(result["result"], litellm.ModelResponse)
 
 
 def test_create_anthropic_response_logging_payload(mock_logging_obj):
@@ -64,6 +64,7 @@ async def test_chunk_processor_yields_raw_bytes(endpoint_type, url_route):
     litellm_logging_obj = MagicMock()
     start_time = datetime.now()
     passthrough_success_handler_obj = MagicMock()
+    litellm_logging_obj.async_success_handler = AsyncMock()
 
     # Capture yielded chunks and perform detailed assertions
     received_chunks = []
tests/proxy_admin_ui_tests/conftest.py (new file, 54 lines)
@@ -0,0 +1,54 @@
+# conftest.py
+
+import importlib
+import os
+import sys
+
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+
+
+@pytest.fixture(scope="function", autouse=True)
+def setup_and_teardown():
+    """
+    This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.
+    """
+    curr_dir = os.getcwd()  # Get the current working directory
+    sys.path.insert(
+        0, os.path.abspath("../..")
+    )  # Adds the project directory to the system path
+
+    import litellm
+    from litellm import Router
+
+    importlib.reload(litellm)
+    import asyncio
+
+    loop = asyncio.get_event_loop_policy().new_event_loop()
+    asyncio.set_event_loop(loop)
+    print(litellm)
+    # from litellm import Router, completion, aembedding, acompletion, embedding
+    yield
+
+    # Teardown code (executes after the yield point)
+    loop.close()  # Close the loop created earlier
+    asyncio.set_event_loop(None)  # Remove the reference to the loop
+
+
+def pytest_collection_modifyitems(config, items):
+    # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
+    custom_logger_tests = [
+        item for item in items if "custom_logger" in item.parent.name
+    ]
+    other_tests = [item for item in items if "custom_logger" not in item.parent.name]
+
+    # Sort tests based on their names
+    custom_logger_tests.sort(key=lambda x: x.name)
+    other_tests.sort(key=lambda x: x.name)
+
+    # Reorder the items list
+    items[:] = custom_logger_tests + other_tests
@@ -542,3 +542,65 @@ async def test_list_teams(prisma_client):
 
     # Clean up
     await prisma_client.delete_data(team_id_list=[team_id], table_name="team")
+
+
+def test_is_team_key():
+    from litellm.proxy.management_endpoints.key_management_endpoints import _is_team_key
+
+    assert _is_team_key(GenerateKeyRequest(team_id="test_team_id"))
+    assert not _is_team_key(GenerateKeyRequest(user_id="test_user_id"))
+
+
+def test_team_key_generation_check():
+    from litellm.proxy.management_endpoints.key_management_endpoints import (
+        _team_key_generation_check,
+    )
+    from fastapi import HTTPException
+
+    litellm.key_generation_settings = {
+        "team_key_generation": {"allowed_team_member_roles": ["admin"]}
+    }
+
+    assert _team_key_generation_check(
+        UserAPIKeyAuth(
+            user_role=LitellmUserRoles.INTERNAL_USER,
+            api_key="sk-1234",
+            team_member=Member(role="admin", user_id="test_user_id"),
+        )
+    )
+
+    with pytest.raises(HTTPException):
+        _team_key_generation_check(
+            UserAPIKeyAuth(
+                user_role=LitellmUserRoles.INTERNAL_USER,
+                api_key="sk-1234",
+                user_id="test_user_id",
+                team_member=Member(role="user", user_id="test_user_id"),
+            )
+        )
+
+
+def test_personal_key_generation_check():
+    from litellm.proxy.management_endpoints.key_management_endpoints import (
+        _personal_key_generation_check,
+    )
+    from fastapi import HTTPException
+
+    litellm.key_generation_settings = {
+        "personal_key_generation": {"allowed_user_roles": ["proxy_admin"]}
+    }
+
+    assert _personal_key_generation_check(
+        UserAPIKeyAuth(
+            user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin"
+        )
+    )
+
+    with pytest.raises(HTTPException):
+        _personal_key_generation_check(
+            UserAPIKeyAuth(
+                user_role=LitellmUserRoles.INTERNAL_USER,
+                api_key="sk-1234",
+                user_id="admin",
+            )
+        )
@@ -160,7 +160,7 @@ async def test_create_new_user_in_organization(prisma_client, user_role):
     response = await organization_member_add(
         data=OrganizationMemberAddRequest(
             organization_id=org_id,
-            member=Member(role=user_role, user_id=created_user_id),
+            member=OrgMember(role=user_role, user_id=created_user_id),
         ),
         http_request=None,
     )
@@ -220,7 +220,7 @@ async def test_org_admin_create_team_permissions(prisma_client):
     response = await organization_member_add(
         data=OrganizationMemberAddRequest(
             organization_id=org_id,
-            member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
+            member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
         ),
         http_request=None,
     )
@@ -292,7 +292,7 @@ async def test_org_admin_create_user_permissions(prisma_client):
     response = await organization_member_add(
         data=OrganizationMemberAddRequest(
            organization_id=org_id,
-            member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
+            member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
         ),
         http_request=None,
     )
@@ -323,7 +323,7 @@ async def test_org_admin_create_user_permissions(prisma_client):
     response = await organization_member_add(
         data=OrganizationMemberAddRequest(
             organization_id=org_id,
-            member=Member(
+            member=OrgMember(
                 role=LitellmUserRoles.INTERNAL_USER, user_id=new_internal_user_for_org
             ),
         ),
@@ -375,7 +375,7 @@ async def test_org_admin_create_user_team_wrong_org_permissions(prisma_client):
     response = await organization_member_add(
         data=OrganizationMemberAddRequest(
             organization_id=org1_id,
-            member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
+            member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
         ),
         http_request=None,
     )
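The remaining hunks swap `Member` for `OrgMember` when adding organization members. A hedged sketch of the request shape these tests build; the IDs are placeholders and the imports assume all three types are exported from `litellm.proxy._types`:

```python
from litellm.proxy._types import (
    LitellmUserRoles,
    OrganizationMemberAddRequest,
    OrgMember,
)

request = OrganizationMemberAddRequest(
    organization_id="org-1234",  # placeholder id
    member=OrgMember(
        role=LitellmUserRoles.INTERNAL_USER,
        user_id="user-5678",  # placeholder id
    ),
)
# `request` would then be passed to organization_member_add(data=request, http_request=None)
```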