forked from phoenix/litellm-mirror
Compare commits
21 commits
main
...
litellm_de
Author | SHA1 | Date | |
---|---|---|---|
|
b06b0248ff | ||
|
4021206ac2 | ||
|
88db948d29 | ||
|
b5c5f87e25 | ||
|
ed64dd7f9d | ||
|
e2ccf6b7f4 | ||
|
f8a46b5950 | ||
|
31943bf2ad | ||
|
541326731f | ||
|
dfb34dfe92 | ||
|
250d66b335 | ||
|
94fe135524 | ||
|
1a3fb18a64 | ||
|
d788c3c37f | ||
|
4beb48829c | ||
|
eb0a357eda | ||
|
463fa0c9d5 | ||
|
1014216d73 | ||
|
97d8aa0b3a | ||
|
5a698c678a | ||
|
1c9a8c0b68 |
35 changed files with 871 additions and 248 deletions
|
@ -41,7 +41,7 @@ Use `litellm.get_supported_openai_params()` for an updated list of params for ea
|
|||
|
||||
| Provider | temperature | max_completion_tokens | max_tokens | top_p | stream | stream_options | stop | n | presence_penalty | frequency_penalty | functions | function_call | logit_bias | user | response_format | seed | tools | tool_choice | logprobs | top_logprobs | extra_headers |
|
||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
||||
|Anthropic| ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | | | | | | |✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ |
|
||||
|Anthropic| ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | | | | | | |✅ | ✅ | | ✅ | ✅ | | | ✅ |
|
||||
|OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
|Azure OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | | | ✅ |
|
||||
|Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|
||||
|
|
74
docs/my-website/docs/guides/finetuned_models.md
Normal file
74
docs/my-website/docs/guides/finetuned_models.md
Normal file
|
@ -0,0 +1,74 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
|
||||
# Calling Finetuned Models
|
||||
|
||||
## OpenAI
|
||||
|
||||
|
||||
| Model Name | Function Call |
|
||||
|---------------------------|-----------------------------------------------------------------|
|
||||
| fine tuned `gpt-4-0613` | `response = completion(model="ft:gpt-4-0613", messages=messages)` |
|
||||
| fine tuned `gpt-4o-2024-05-13` | `response = completion(model="ft:gpt-4o-2024-05-13", messages=messages)` |
|
||||
| fine tuned `gpt-3.5-turbo-0125` | `response = completion(model="ft:gpt-3.5-turbo-0125", messages=messages)` |
|
||||
| fine tuned `gpt-3.5-turbo-1106` | `response = completion(model="ft:gpt-3.5-turbo-1106", messages=messages)` |
|
||||
| fine tuned `gpt-3.5-turbo-0613` | `response = completion(model="ft:gpt-3.5-turbo-0613", messages=messages)` |
|
||||
|
||||
|
||||
## Vertex AI
|
||||
|
||||
Fine tuned models on vertex have a numerical model/endpoint id.
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
## set ENV variables
|
||||
os.environ["VERTEXAI_PROJECT"] = "hardy-device-38811"
|
||||
os.environ["VERTEXAI_LOCATION"] = "us-central1"
|
||||
|
||||
response = completion(
|
||||
model="vertex_ai/<your-finetuned-model>", # e.g. vertex_ai/4965075652664360960
|
||||
messages=[{ "content": "Hello, how are you?","role": "user"}],
|
||||
base_model="vertex_ai/gemini-1.5-pro" # the base model - used for routing
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
1. Add Vertex Credentials to your env
|
||||
|
||||
```bash
|
||||
!gcloud auth application-default login
|
||||
```
|
||||
|
||||
2. Setup config.yaml
|
||||
|
||||
```yaml
|
||||
- model_name: finetuned-gemini
|
||||
litellm_params:
|
||||
model: vertex_ai/<ENDPOINT_ID>
|
||||
vertex_project: <PROJECT_ID>
|
||||
vertex_location: <LOCATION>
|
||||
model_info:
|
||||
base_model: vertex_ai/gemini-1.5-pro # IMPORTANT
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
```bash
|
||||
curl --location 'https://0.0.0.0:4000/v1/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: <LITELLM_KEY>' \
|
||||
--data '{"model": "finetuned-gemini" ,"messages":[{"role": "user", "content":[{"type": "text", "text": "hi"}]}]}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
|
@ -754,6 +754,8 @@ general_settings:
|
|||
| cache_params.s3_endpoint_url | string | Optional - The endpoint URL for the S3 bucket. |
|
||||
| cache_params.supported_call_types | array of strings | The types of calls to cache. [Further docs](./caching) |
|
||||
| cache_params.mode | string | The mode of the cache. [Further docs](./caching) |
|
||||
| disable_end_user_cost_tracking | boolean | If true, turns off end user cost tracking on prometheus metrics + litellm spend logs table on proxy. |
|
||||
| key_generation_settings | object | Restricts who can generate keys. [Further docs](./virtual_keys.md#restricting-key-generation) |
|
||||
|
||||
### general_settings - Reference
|
||||
|
||||
|
|
|
@ -217,4 +217,10 @@ litellm_settings:
|
|||
max_parallel_requests: 1000 # (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None.
|
||||
tpm_limit: 1000 #(Optional[int], optional): Tpm limit. Defaults to None.
|
||||
rpm_limit: 1000 #(Optional[int], optional): Rpm limit. Defaults to None.
|
||||
```
|
||||
|
||||
key_generation_settings: # Restricts who can generate keys. [Further docs](./virtual_keys.md#restricting-key-generation)
|
||||
team_key_generation:
|
||||
allowed_team_member_roles: ["admin"]
|
||||
personal_key_generation: # maps to 'Default Team' on UI
|
||||
allowed_user_roles: ["proxy_admin"]
|
||||
```
|
||||
|
|
|
@ -811,6 +811,75 @@ litellm_settings:
|
|||
team_id: "core-infra"
|
||||
```
|
||||
|
||||
### Restricting Key Generation
|
||||
|
||||
Use this to control who can generate keys. Useful when letting others create keys on the UI.
|
||||
|
||||
```yaml
|
||||
litellm_settings:
|
||||
key_generation_settings:
|
||||
team_key_generation:
|
||||
allowed_team_member_roles: ["admin"]
|
||||
personal_key_generation: # maps to 'Default Team' on UI
|
||||
allowed_user_roles: ["proxy_admin"]
|
||||
```
|
||||
|
||||
#### Spec
|
||||
|
||||
```python
|
||||
class TeamUIKeyGenerationConfig(TypedDict):
|
||||
allowed_team_member_roles: List[str]
|
||||
|
||||
|
||||
class PersonalUIKeyGenerationConfig(TypedDict):
|
||||
allowed_user_roles: List[LitellmUserRoles]
|
||||
|
||||
|
||||
class StandardKeyGenerationConfig(TypedDict, total=False):
|
||||
team_key_generation: TeamUIKeyGenerationConfig
|
||||
personal_key_generation: PersonalUIKeyGenerationConfig
|
||||
|
||||
|
||||
class LitellmUserRoles(str, enum.Enum):
|
||||
"""
|
||||
Admin Roles:
|
||||
PROXY_ADMIN: admin over the platform
|
||||
PROXY_ADMIN_VIEW_ONLY: can login, view all own keys, view all spend
|
||||
ORG_ADMIN: admin over a specific organization, can create teams, users only within their organization
|
||||
|
||||
Internal User Roles:
|
||||
INTERNAL_USER: can login, view/create/delete their own keys, view their spend
|
||||
INTERNAL_USER_VIEW_ONLY: can login, view their own keys, view their own spend
|
||||
|
||||
|
||||
Team Roles:
|
||||
TEAM: used for JWT auth
|
||||
|
||||
|
||||
Customer Roles:
|
||||
CUSTOMER: External users -> these are customers
|
||||
|
||||
"""
|
||||
|
||||
# Admin Roles
|
||||
PROXY_ADMIN = "proxy_admin"
|
||||
PROXY_ADMIN_VIEW_ONLY = "proxy_admin_viewer"
|
||||
|
||||
# Organization admins
|
||||
ORG_ADMIN = "org_admin"
|
||||
|
||||
# Internal User Roles
|
||||
INTERNAL_USER = "internal_user"
|
||||
INTERNAL_USER_VIEW_ONLY = "internal_user_viewer"
|
||||
|
||||
# Team Roles
|
||||
TEAM = "team"
|
||||
|
||||
# Customer Roles - External users of proxy
|
||||
CUSTOMER = "customer"
|
||||
```
|
||||
|
||||
|
||||
## **Next Steps - Set Budgets, Rate Limits per Virtual Key**
|
||||
|
||||
[Follow this doc to set budgets, rate limiters per virtual key with LiteLLM](users)
|
||||
|
|
|
@ -199,6 +199,31 @@ const sidebars = {
|
|||
|
||||
],
|
||||
},
|
||||
{
|
||||
type: "category",
|
||||
label: "Guides",
|
||||
items: [
|
||||
"exception_mapping",
|
||||
"completion/provider_specific_params",
|
||||
"guides/finetuned_models",
|
||||
"completion/audio",
|
||||
"completion/vision",
|
||||
"completion/json_mode",
|
||||
"completion/prompt_caching",
|
||||
"completion/predict_outputs",
|
||||
"completion/prefix",
|
||||
"completion/drop_params",
|
||||
"completion/prompt_formatting",
|
||||
"completion/stream",
|
||||
"completion/message_trimming",
|
||||
"completion/function_call",
|
||||
"completion/model_alias",
|
||||
"completion/batching",
|
||||
"completion/mock_requests",
|
||||
"completion/reliable_completions",
|
||||
|
||||
]
|
||||
},
|
||||
{
|
||||
type: "category",
|
||||
label: "Supported Endpoints",
|
||||
|
@ -214,25 +239,8 @@ const sidebars = {
|
|||
},
|
||||
items: [
|
||||
"completion/input",
|
||||
"completion/provider_specific_params",
|
||||
"completion/json_mode",
|
||||
"completion/prompt_caching",
|
||||
"completion/audio",
|
||||
"completion/vision",
|
||||
"completion/predict_outputs",
|
||||
"completion/prefix",
|
||||
"completion/drop_params",
|
||||
"completion/prompt_formatting",
|
||||
"completion/output",
|
||||
"completion/usage",
|
||||
"exception_mapping",
|
||||
"completion/stream",
|
||||
"completion/message_trimming",
|
||||
"completion/function_call",
|
||||
"completion/model_alias",
|
||||
"completion/batching",
|
||||
"completion/mock_requests",
|
||||
"completion/reliable_completions",
|
||||
],
|
||||
},
|
||||
"embedding/supported_embedding",
|
||||
|
|
|
@ -24,6 +24,7 @@ from litellm.proxy._types import (
|
|||
KeyManagementSettings,
|
||||
LiteLLM_UpperboundKeyGenerateParams,
|
||||
)
|
||||
from litellm.types.utils import StandardKeyGenerationConfig
|
||||
import httpx
|
||||
import dotenv
|
||||
from enum import Enum
|
||||
|
@ -273,6 +274,7 @@ s3_callback_params: Optional[Dict] = None
|
|||
generic_logger_headers: Optional[Dict] = None
|
||||
default_key_generate_params: Optional[Dict] = None
|
||||
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
|
||||
key_generation_settings: Optional[StandardKeyGenerationConfig] = None
|
||||
default_internal_user_params: Optional[Dict] = None
|
||||
default_team_settings: Optional[List] = None
|
||||
max_user_budget: Optional[float] = None
|
||||
|
@ -280,6 +282,7 @@ default_max_internal_user_budget: Optional[float] = None
|
|||
max_internal_user_budget: Optional[float] = None
|
||||
internal_user_budget_duration: Optional[str] = None
|
||||
max_end_user_budget: Optional[float] = None
|
||||
disable_end_user_cost_tracking: Optional[bool] = None
|
||||
#### REQUEST PRIORITIZATION ####
|
||||
priority_reservation: Optional[Dict[str, float]] = None
|
||||
#### RELIABILITY ####
|
||||
|
|
|
@ -18,6 +18,7 @@ from litellm.integrations.custom_logger import CustomLogger
|
|||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.types.integrations.prometheus import *
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||
|
||||
|
||||
class PrometheusLogger(CustomLogger):
|
||||
|
@ -364,8 +365,7 @@ class PrometheusLogger(CustomLogger):
|
|||
model = kwargs.get("model", "")
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
_metadata = litellm_params.get("metadata", {})
|
||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
|
||||
user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
|
||||
user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
|
||||
user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
|
||||
|
@ -664,13 +664,11 @@ class PrometheusLogger(CustomLogger):
|
|||
|
||||
# unpack kwargs
|
||||
model = kwargs.get("model", "")
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
standard_logging_payload: StandardLoggingPayload = kwargs.get(
|
||||
"standard_logging_object", {}
|
||||
)
|
||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||
|
||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
|
||||
user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
|
||||
user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
|
||||
user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
|
||||
|
|
|
@ -934,19 +934,10 @@ class Logging:
|
|||
status="success",
|
||||
)
|
||||
)
|
||||
if self.dynamic_success_callbacks is not None and isinstance(
|
||||
self.dynamic_success_callbacks, list
|
||||
):
|
||||
callbacks = self.dynamic_success_callbacks
|
||||
## keep the internal functions ##
|
||||
for callback in litellm.success_callback:
|
||||
if (
|
||||
isinstance(callback, CustomLogger)
|
||||
and "_PROXY_" in callback.__class__.__name__
|
||||
):
|
||||
callbacks.append(callback)
|
||||
else:
|
||||
callbacks = litellm.success_callback
|
||||
callbacks = get_combined_callback_list(
|
||||
dynamic_success_callbacks=self.dynamic_success_callbacks,
|
||||
global_callbacks=litellm.success_callback,
|
||||
)
|
||||
|
||||
## REDACT MESSAGES ##
|
||||
result = redact_message_input_output_from_logging(
|
||||
|
@ -1368,8 +1359,11 @@ class Logging:
|
|||
and customLogger is not None
|
||||
): # custom logger functions
|
||||
print_verbose(
|
||||
"success callbacks: Running Custom Callback Function"
|
||||
"success callbacks: Running Custom Callback Function - {}".format(
|
||||
callback
|
||||
)
|
||||
)
|
||||
|
||||
customLogger.log_event(
|
||||
kwargs=self.model_call_details,
|
||||
response_obj=result,
|
||||
|
@ -1466,21 +1460,10 @@ class Logging:
|
|||
status="success",
|
||||
)
|
||||
)
|
||||
if self.dynamic_async_success_callbacks is not None and isinstance(
|
||||
self.dynamic_async_success_callbacks, list
|
||||
):
|
||||
callbacks = self.dynamic_async_success_callbacks
|
||||
## keep the internal functions ##
|
||||
for callback in litellm._async_success_callback:
|
||||
callback_name = ""
|
||||
if isinstance(callback, CustomLogger):
|
||||
callback_name = callback.__class__.__name__
|
||||
if callable(callback):
|
||||
callback_name = callback.__name__
|
||||
if "_PROXY_" in callback_name:
|
||||
callbacks.append(callback)
|
||||
else:
|
||||
callbacks = litellm._async_success_callback
|
||||
callbacks = get_combined_callback_list(
|
||||
dynamic_success_callbacks=self.dynamic_async_success_callbacks,
|
||||
global_callbacks=litellm._async_success_callback,
|
||||
)
|
||||
|
||||
result = redact_message_input_output_from_logging(
|
||||
model_call_details=(
|
||||
|
@ -1747,21 +1730,10 @@ class Logging:
|
|||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
callbacks = [] # init this to empty incase it's not created
|
||||
|
||||
if self.dynamic_failure_callbacks is not None and isinstance(
|
||||
self.dynamic_failure_callbacks, list
|
||||
):
|
||||
callbacks = self.dynamic_failure_callbacks
|
||||
## keep the internal functions ##
|
||||
for callback in litellm.failure_callback:
|
||||
if (
|
||||
isinstance(callback, CustomLogger)
|
||||
and "_PROXY_" in callback.__class__.__name__
|
||||
):
|
||||
callbacks.append(callback)
|
||||
else:
|
||||
callbacks = litellm.failure_callback
|
||||
callbacks = get_combined_callback_list(
|
||||
dynamic_success_callbacks=self.dynamic_failure_callbacks,
|
||||
global_callbacks=litellm.failure_callback,
|
||||
)
|
||||
|
||||
result = None # result sent to all loggers, init this to None incase it's not created
|
||||
|
||||
|
@ -1944,21 +1916,10 @@ class Logging:
|
|||
end_time=end_time,
|
||||
)
|
||||
|
||||
callbacks = [] # init this to empty incase it's not created
|
||||
|
||||
if self.dynamic_async_failure_callbacks is not None and isinstance(
|
||||
self.dynamic_async_failure_callbacks, list
|
||||
):
|
||||
callbacks = self.dynamic_async_failure_callbacks
|
||||
## keep the internal functions ##
|
||||
for callback in litellm._async_failure_callback:
|
||||
if (
|
||||
isinstance(callback, CustomLogger)
|
||||
and "_PROXY_" in callback.__class__.__name__
|
||||
):
|
||||
callbacks.append(callback)
|
||||
else:
|
||||
callbacks = litellm._async_failure_callback
|
||||
callbacks = get_combined_callback_list(
|
||||
dynamic_success_callbacks=self.dynamic_async_failure_callbacks,
|
||||
global_callbacks=litellm._async_failure_callback,
|
||||
)
|
||||
|
||||
result = None # result sent to all loggers, init this to None incase it's not created
|
||||
for callback in callbacks:
|
||||
|
@ -2359,6 +2320,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
|
|||
_in_memory_loggers.append(_mlflow_logger)
|
||||
return _mlflow_logger # type: ignore
|
||||
|
||||
|
||||
def get_custom_logger_compatible_class(
|
||||
logging_integration: litellm._custom_logger_compatible_callbacks_literal,
|
||||
) -> Optional[CustomLogger]:
|
||||
|
@ -2949,3 +2911,11 @@ def modify_integration(integration_name, integration_params):
|
|||
if integration_name == "supabase":
|
||||
if "table_name" in integration_params:
|
||||
Supabase.supabase_table_name = integration_params["table_name"]
|
||||
|
||||
|
||||
def get_combined_callback_list(
|
||||
dynamic_success_callbacks: Optional[List], global_callbacks: List
|
||||
) -> List:
|
||||
if dynamic_success_callbacks is None:
|
||||
return global_callbacks
|
||||
return list(set(dynamic_success_callbacks + global_callbacks))
|
||||
|
|
|
@ -4,6 +4,7 @@ import httpx
|
|||
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.cohere.rerank import CohereRerank
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.types.rerank import RerankResponse
|
||||
|
||||
|
||||
|
@ -73,6 +74,7 @@ class AzureAIRerank(CohereRerank):
|
|||
return_documents: Optional[bool] = True,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
_is_async: Optional[bool] = False,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
) -> RerankResponse:
|
||||
|
||||
if headers is None:
|
||||
|
|
|
@ -74,6 +74,7 @@ async def async_embedding(
|
|||
},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
|
||||
if client is None:
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.COHERE,
|
||||
|
@ -151,6 +152,11 @@ def embedding(
|
|||
api_key=api_key,
|
||||
headers=headers,
|
||||
encoding=encoding,
|
||||
client=(
|
||||
client
|
||||
if client is not None and isinstance(client, AsyncHTTPHandler)
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
|
|
|
@ -6,10 +6,14 @@ LiteLLM supports the re rank API format, no paramter transformation occurs
|
|||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base import BaseLLM
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
|
@ -34,6 +38,23 @@ class CohereRerank(BaseLLM):
|
|||
# Merge other headers, overriding any default ones except Authorization
|
||||
return {**default_headers, **headers}
|
||||
|
||||
def ensure_rerank_endpoint(self, api_base: str) -> str:
|
||||
"""
|
||||
Ensures the `/v1/rerank` endpoint is appended to the given `api_base`.
|
||||
If `/v1/rerank` is already present, the original URL is returned.
|
||||
|
||||
:param api_base: The base API URL.
|
||||
:return: A URL with `/v1/rerank` appended if missing.
|
||||
"""
|
||||
# Parse the base URL to ensure proper structure
|
||||
url = httpx.URL(api_base)
|
||||
|
||||
# Check if the URL already ends with `/v1/rerank`
|
||||
if not url.path.endswith("/v1/rerank"):
|
||||
url = url.copy_with(path=f"{url.path.rstrip('/')}/v1/rerank")
|
||||
|
||||
return str(url)
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -48,9 +69,10 @@ class CohereRerank(BaseLLM):
|
|||
return_documents: Optional[bool] = True,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
_is_async: Optional[bool] = False, # New parameter
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
) -> RerankResponse:
|
||||
headers = self.validate_environment(api_key=api_key, headers=headers)
|
||||
|
||||
api_base = self.ensure_rerank_endpoint(api_base)
|
||||
request_data = RerankRequest(
|
||||
model=model,
|
||||
query=query,
|
||||
|
@ -76,9 +98,13 @@ class CohereRerank(BaseLLM):
|
|||
if _is_async:
|
||||
return self.async_rerank(request_data=request_data, api_key=api_key, api_base=api_base, headers=headers) # type: ignore # Call async method
|
||||
|
||||
client = _get_httpx_client()
|
||||
if client is not None and isinstance(client, HTTPHandler):
|
||||
client = client
|
||||
else:
|
||||
client = _get_httpx_client()
|
||||
|
||||
response = client.post(
|
||||
api_base,
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
json=request_data_dict,
|
||||
)
|
||||
|
@ -100,10 +126,13 @@ class CohereRerank(BaseLLM):
|
|||
api_key: str,
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> RerankResponse:
|
||||
request_data_dict = request_data.dict(exclude_none=True)
|
||||
|
||||
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.COHERE)
|
||||
client = client or get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.COHERE
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
api_base,
|
||||
|
|
|
@ -3440,6 +3440,10 @@ def embedding( # noqa: PLR0915
|
|||
or litellm.openai_key
|
||||
or get_secret_str("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
if extra_headers is not None:
|
||||
optional_params["extra_headers"] = extra_headers
|
||||
|
||||
api_type = "openai"
|
||||
api_version = None
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ model_list:
|
|||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
|
||||
|
||||
router_settings:
|
||||
model_group_alias:
|
||||
"gpt-4-turbo": # Aliased model name
|
||||
|
@ -35,4 +35,4 @@ litellm_settings:
|
|||
failure_callback: ["langfuse"]
|
||||
langfuse_public_key: os.environ/LANGFUSE_PROJECT2_PUBLIC # Project 2
|
||||
langfuse_secret: os.environ/LANGFUSE_PROJECT2_SECRET # Project 2
|
||||
langfuse_host: https://us.cloud.langfuse.com
|
||||
langfuse_host: https://us.cloud.langfuse.com
|
||||
|
|
|
@ -2,6 +2,7 @@ import enum
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import uuid
|
||||
from dataclasses import fields
|
||||
from datetime import datetime
|
||||
|
@ -12,7 +13,15 @@ from typing_extensions import Annotated, TypedDict
|
|||
|
||||
from litellm.types.integrations.slack_alerting import AlertType
|
||||
from litellm.types.router import RouterErrors, UpdateRouterConfig
|
||||
from litellm.types.utils import ProviderField, StandardCallbackDynamicParams
|
||||
from litellm.types.utils import (
|
||||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
ModelResponse,
|
||||
ProviderField,
|
||||
StandardCallbackDynamicParams,
|
||||
StandardPassThroughResponseObject,
|
||||
TextCompletionResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
@ -882,15 +891,7 @@ class DeleteCustomerRequest(LiteLLMBase):
|
|||
user_ids: List[str]
|
||||
|
||||
|
||||
class Member(LiteLLMBase):
|
||||
role: Literal[
|
||||
LitellmUserRoles.ORG_ADMIN,
|
||||
LitellmUserRoles.INTERNAL_USER,
|
||||
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
|
||||
# older Member roles
|
||||
"admin",
|
||||
"user",
|
||||
]
|
||||
class MemberBase(LiteLLMBase):
|
||||
user_id: Optional[str] = None
|
||||
user_email: Optional[str] = None
|
||||
|
||||
|
@ -904,6 +905,21 @@ class Member(LiteLLMBase):
|
|||
return values
|
||||
|
||||
|
||||
class Member(MemberBase):
|
||||
role: Literal[
|
||||
"admin",
|
||||
"user",
|
||||
]
|
||||
|
||||
|
||||
class OrgMember(MemberBase):
|
||||
role: Literal[
|
||||
LitellmUserRoles.ORG_ADMIN,
|
||||
LitellmUserRoles.INTERNAL_USER,
|
||||
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
|
||||
]
|
||||
|
||||
|
||||
class TeamBase(LiteLLMBase):
|
||||
team_alias: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
|
@ -1966,6 +1982,26 @@ class MemberAddRequest(LiteLLMBase):
|
|||
# Replace member_data with the single Member object
|
||||
data["member"] = member
|
||||
# Call the superclass __init__ method to initialize the object
|
||||
traceback.print_stack()
|
||||
super().__init__(**data)
|
||||
|
||||
|
||||
class OrgMemberAddRequest(LiteLLMBase):
|
||||
member: Union[List[OrgMember], OrgMember]
|
||||
|
||||
def __init__(self, **data):
|
||||
member_data = data.get("member")
|
||||
if isinstance(member_data, list):
|
||||
# If member is a list of dictionaries, convert each dictionary to a Member object
|
||||
members = [OrgMember(**item) for item in member_data]
|
||||
# Replace member_data with the list of Member objects
|
||||
data["member"] = members
|
||||
elif isinstance(member_data, dict):
|
||||
# If member is a dictionary, convert it to a single Member object
|
||||
member = OrgMember(**member_data)
|
||||
# Replace member_data with the single Member object
|
||||
data["member"] = member
|
||||
# Call the superclass __init__ method to initialize the object
|
||||
super().__init__(**data)
|
||||
|
||||
|
||||
|
@ -2017,7 +2053,7 @@ class TeamMemberUpdateResponse(MemberUpdateResponse):
|
|||
|
||||
|
||||
# Organization Member Requests
|
||||
class OrganizationMemberAddRequest(MemberAddRequest):
|
||||
class OrganizationMemberAddRequest(OrgMemberAddRequest):
|
||||
organization_id: str
|
||||
max_budget_in_organization: Optional[float] = (
|
||||
None # Users max budget within the organization
|
||||
|
@ -2133,3 +2169,17 @@ class UserManagementEndpointParamDocStringEnums(str, enum.Enum):
|
|||
spend_doc_str = """Optional[float] - Amount spent by user. Default is 0. Will be updated by proxy whenever user is used."""
|
||||
team_id_doc_str = """Optional[str] - [DEPRECATED PARAM] The team id of the user. Default is None."""
|
||||
duration_doc_str = """Optional[str] - Duration for the key auto-created on `/user/new`. Default is None."""
|
||||
|
||||
|
||||
PassThroughEndpointLoggingResultValues = Union[
|
||||
ModelResponse,
|
||||
TextCompletionResponse,
|
||||
ImageResponse,
|
||||
EmbeddingResponse,
|
||||
StandardPassThroughResponseObject,
|
||||
]
|
||||
|
||||
|
||||
class PassThroughEndpointLoggingTypedDict(TypedDict):
|
||||
result: Optional[PassThroughEndpointLoggingResultValues]
|
||||
kwargs: dict
|
||||
|
|
|
@ -40,6 +40,77 @@ from litellm.proxy.utils import (
|
|||
)
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
|
||||
def _is_team_key(data: GenerateKeyRequest):
|
||||
return data.team_id is not None
|
||||
|
||||
|
||||
def _team_key_generation_check(user_api_key_dict: UserAPIKeyAuth):
|
||||
if (
|
||||
litellm.key_generation_settings is None
|
||||
or litellm.key_generation_settings.get("team_key_generation") is None
|
||||
):
|
||||
return True
|
||||
|
||||
if user_api_key_dict.team_member is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"User not assigned to team. Got team_member={user_api_key_dict.team_member}",
|
||||
)
|
||||
|
||||
team_member_role = user_api_key_dict.team_member.role
|
||||
if (
|
||||
team_member_role
|
||||
not in litellm.key_generation_settings["team_key_generation"][ # type: ignore
|
||||
"allowed_team_member_roles"
|
||||
]
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Team member role {team_member_role} not in allowed_team_member_roles={litellm.key_generation_settings['team_key_generation']['allowed_team_member_roles']}", # type: ignore
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def _personal_key_generation_check(user_api_key_dict: UserAPIKeyAuth):
|
||||
|
||||
if (
|
||||
litellm.key_generation_settings is None
|
||||
or litellm.key_generation_settings.get("personal_key_generation") is None
|
||||
):
|
||||
return True
|
||||
|
||||
if (
|
||||
user_api_key_dict.user_role
|
||||
not in litellm.key_generation_settings["personal_key_generation"][ # type: ignore
|
||||
"allowed_user_roles"
|
||||
]
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Personal key creation has been restricted by admin. Allowed roles={litellm.key_generation_settings['personal_key_generation']['allowed_user_roles']}. Your role={user_api_key_dict.user_role}", # type: ignore
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def key_generation_check(
|
||||
user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest
|
||||
) -> bool:
|
||||
"""
|
||||
Check if admin has restricted key creation to certain roles for teams or individuals
|
||||
"""
|
||||
if litellm.key_generation_settings is None:
|
||||
return True
|
||||
|
||||
## check if key is for team or individual
|
||||
is_team_key = _is_team_key(data=data)
|
||||
|
||||
if is_team_key:
|
||||
return _team_key_generation_check(user_api_key_dict)
|
||||
else:
|
||||
return _personal_key_generation_check(user_api_key_dict=user_api_key_dict)
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
|
@ -131,6 +202,8 @@ async def generate_key_fn( # noqa: PLR0915
|
|||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=message
|
||||
)
|
||||
elif litellm.key_generation_settings is not None:
|
||||
key_generation_check(user_api_key_dict=user_api_key_dict, data=data)
|
||||
# check if user set default key/generate params on config.yaml
|
||||
if litellm.default_key_generate_params is not None:
|
||||
for elem in data:
|
||||
|
|
|
@ -352,7 +352,7 @@ async def organization_member_add(
|
|||
},
|
||||
)
|
||||
|
||||
members: List[Member]
|
||||
members: List[OrgMember]
|
||||
if isinstance(data.member, List):
|
||||
members = data.member
|
||||
else:
|
||||
|
@ -397,7 +397,7 @@ async def organization_member_add(
|
|||
|
||||
|
||||
async def add_member_to_organization(
|
||||
member: Member,
|
||||
member: OrgMember,
|
||||
organization_id: str,
|
||||
prisma_client: PrismaClient,
|
||||
) -> Tuple[LiteLLM_UserTable, LiteLLM_OrganizationMembershipTable]:
|
||||
|
|
|
@ -14,6 +14,7 @@ from litellm.llms.anthropic.chat.handler import (
|
|||
ModelResponseIterator as AnthropicModelResponseIterator,
|
||||
)
|
||||
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
|
||||
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..success_handler import PassThroughEndpointLogging
|
||||
|
@ -26,7 +27,7 @@ else:
|
|||
class AnthropicPassthroughLoggingHandler:
|
||||
|
||||
@staticmethod
|
||||
async def anthropic_passthrough_handler(
|
||||
def anthropic_passthrough_handler(
|
||||
httpx_response: httpx.Response,
|
||||
response_body: dict,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
|
@ -36,7 +37,7 @@ class AnthropicPassthroughLoggingHandler:
|
|||
end_time: datetime,
|
||||
cache_hit: bool,
|
||||
**kwargs,
|
||||
):
|
||||
) -> PassThroughEndpointLoggingTypedDict:
|
||||
"""
|
||||
Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled
|
||||
"""
|
||||
|
@ -67,15 +68,10 @@ class AnthropicPassthroughLoggingHandler:
|
|||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
await logging_obj.async_success_handler(
|
||||
result=litellm_model_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=cache_hit,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
pass
|
||||
return {
|
||||
"result": litellm_model_response,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _create_anthropic_response_logging_payload(
|
||||
|
@ -123,7 +119,7 @@ class AnthropicPassthroughLoggingHandler:
|
|||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
async def _handle_logging_anthropic_collected_chunks(
|
||||
def _handle_logging_anthropic_collected_chunks(
|
||||
litellm_logging_obj: LiteLLMLoggingObj,
|
||||
passthrough_success_handler_obj: PassThroughEndpointLogging,
|
||||
url_route: str,
|
||||
|
@ -132,7 +128,7 @@ class AnthropicPassthroughLoggingHandler:
|
|||
start_time: datetime,
|
||||
all_chunks: List[str],
|
||||
end_time: datetime,
|
||||
):
|
||||
) -> PassThroughEndpointLoggingTypedDict:
|
||||
"""
|
||||
Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks
|
||||
|
||||
|
@ -152,7 +148,10 @@ class AnthropicPassthroughLoggingHandler:
|
|||
verbose_proxy_logger.error(
|
||||
"Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..."
|
||||
)
|
||||
return
|
||||
return {
|
||||
"result": None,
|
||||
"kwargs": {},
|
||||
}
|
||||
kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
|
||||
litellm_model_response=complete_streaming_response,
|
||||
model=model,
|
||||
|
@ -161,13 +160,11 @@ class AnthropicPassthroughLoggingHandler:
|
|||
end_time=end_time,
|
||||
logging_obj=litellm_logging_obj,
|
||||
)
|
||||
await litellm_logging_obj.async_success_handler(
|
||||
result=complete_streaming_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return {
|
||||
"result": complete_streaming_response,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_complete_streaming_response(
|
||||
|
|
|
@ -14,6 +14,7 @@ from litellm.litellm_core_utils.litellm_logging import (
|
|||
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
ModelResponseIterator as VertexModelResponseIterator,
|
||||
)
|
||||
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..success_handler import PassThroughEndpointLogging
|
||||
|
@ -25,7 +26,7 @@ else:
|
|||
|
||||
class VertexPassthroughLoggingHandler:
|
||||
@staticmethod
|
||||
async def vertex_passthrough_handler(
|
||||
def vertex_passthrough_handler(
|
||||
httpx_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
url_route: str,
|
||||
|
@ -34,7 +35,7 @@ class VertexPassthroughLoggingHandler:
|
|||
end_time: datetime,
|
||||
cache_hit: bool,
|
||||
**kwargs,
|
||||
):
|
||||
) -> PassThroughEndpointLoggingTypedDict:
|
||||
if "generateContent" in url_route:
|
||||
model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
|
||||
|
||||
|
@ -65,13 +66,11 @@ class VertexPassthroughLoggingHandler:
|
|||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
await logging_obj.async_success_handler(
|
||||
result=litellm_model_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=cache_hit,
|
||||
**kwargs,
|
||||
)
|
||||
return {
|
||||
"result": litellm_model_response,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
|
||||
elif "predict" in url_route:
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import (
|
||||
VertexImageGeneration,
|
||||
|
@ -112,16 +111,18 @@ class VertexPassthroughLoggingHandler:
|
|||
logging_obj.model = model
|
||||
logging_obj.model_call_details["model"] = logging_obj.model
|
||||
|
||||
await logging_obj.async_success_handler(
|
||||
result=litellm_prediction_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=cache_hit,
|
||||
**kwargs,
|
||||
)
|
||||
return {
|
||||
"result": litellm_prediction_response,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"result": None,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _handle_logging_vertex_collected_chunks(
|
||||
def _handle_logging_vertex_collected_chunks(
|
||||
litellm_logging_obj: LiteLLMLoggingObj,
|
||||
passthrough_success_handler_obj: PassThroughEndpointLogging,
|
||||
url_route: str,
|
||||
|
@ -130,7 +131,7 @@ class VertexPassthroughLoggingHandler:
|
|||
start_time: datetime,
|
||||
all_chunks: List[str],
|
||||
end_time: datetime,
|
||||
):
|
||||
) -> PassThroughEndpointLoggingTypedDict:
|
||||
"""
|
||||
Takes raw chunks from Vertex passthrough endpoint and logs them in litellm callbacks
|
||||
|
||||
|
@ -152,7 +153,11 @@ class VertexPassthroughLoggingHandler:
|
|||
verbose_proxy_logger.error(
|
||||
"Unable to build complete streaming response for Vertex passthrough endpoint, not logging..."
|
||||
)
|
||||
return
|
||||
return {
|
||||
"result": None,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
|
||||
kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
|
||||
litellm_model_response=complete_streaming_response,
|
||||
model=model,
|
||||
|
@ -161,13 +166,11 @@ class VertexPassthroughLoggingHandler:
|
|||
end_time=end_time,
|
||||
logging_obj=litellm_logging_obj,
|
||||
)
|
||||
await litellm_logging_obj.async_success_handler(
|
||||
result=complete_streaming_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return {
|
||||
"result": complete_streaming_response,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_complete_streaming_response(
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import AsyncIterable, Dict, List, Optional, Union
|
||||
|
@ -15,7 +16,12 @@ from litellm.llms.anthropic.chat.handler import (
|
|||
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
ModelResponseIterator as VertexAIIterator,
|
||||
)
|
||||
from litellm.types.utils import GenericStreamingChunk
|
||||
from litellm.proxy._types import PassThroughEndpointLoggingResultValues
|
||||
from litellm.types.utils import (
|
||||
GenericStreamingChunk,
|
||||
ModelResponse,
|
||||
StandardPassThroughResponseObject,
|
||||
)
|
||||
|
||||
from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
|
||||
AnthropicPassthroughLoggingHandler,
|
||||
|
@ -87,8 +93,12 @@ class PassThroughStreamingHandler:
|
|||
all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(
|
||||
raw_bytes
|
||||
)
|
||||
standard_logging_response_object: Optional[
|
||||
PassThroughEndpointLoggingResultValues
|
||||
] = None
|
||||
kwargs: dict = {}
|
||||
if endpoint_type == EndpointType.ANTHROPIC:
|
||||
await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
|
||||
anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
|
||||
litellm_logging_obj=litellm_logging_obj,
|
||||
passthrough_success_handler_obj=passthrough_success_handler_obj,
|
||||
url_route=url_route,
|
||||
|
@ -98,20 +108,48 @@ class PassThroughStreamingHandler:
|
|||
all_chunks=all_chunks,
|
||||
end_time=end_time,
|
||||
)
|
||||
standard_logging_response_object = anthropic_passthrough_logging_handler_result[
|
||||
"result"
|
||||
]
|
||||
kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
|
||||
elif endpoint_type == EndpointType.VERTEX_AI:
|
||||
await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
|
||||
litellm_logging_obj=litellm_logging_obj,
|
||||
passthrough_success_handler_obj=passthrough_success_handler_obj,
|
||||
url_route=url_route,
|
||||
request_body=request_body,
|
||||
endpoint_type=endpoint_type,
|
||||
start_time=start_time,
|
||||
all_chunks=all_chunks,
|
||||
end_time=end_time,
|
||||
vertex_passthrough_logging_handler_result = (
|
||||
VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
|
||||
litellm_logging_obj=litellm_logging_obj,
|
||||
passthrough_success_handler_obj=passthrough_success_handler_obj,
|
||||
url_route=url_route,
|
||||
request_body=request_body,
|
||||
endpoint_type=endpoint_type,
|
||||
start_time=start_time,
|
||||
all_chunks=all_chunks,
|
||||
end_time=end_time,
|
||||
)
|
||||
)
|
||||
elif endpoint_type == EndpointType.GENERIC:
|
||||
# No logging is supported for generic streaming endpoints
|
||||
pass
|
||||
standard_logging_response_object = vertex_passthrough_logging_handler_result[
|
||||
"result"
|
||||
]
|
||||
kwargs = vertex_passthrough_logging_handler_result["kwargs"]
|
||||
|
||||
if standard_logging_response_object is None:
|
||||
standard_logging_response_object = StandardPassThroughResponseObject(
|
||||
response=f"cannot parse chunks to standard response object. Chunks={all_chunks}"
|
||||
)
|
||||
threading.Thread(
|
||||
target=litellm_logging_obj.success_handler,
|
||||
args=(
|
||||
standard_logging_response_object,
|
||||
start_time,
|
||||
end_time,
|
||||
False,
|
||||
),
|
||||
).start()
|
||||
await litellm_logging_obj.async_success_handler(
|
||||
result=standard_logging_response_object,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]:
|
||||
|
@ -130,4 +168,4 @@ class PassThroughStreamingHandler:
|
|||
# Split by newlines and filter out empty lines
|
||||
lines = [line.strip() for line in combined_str.split("\n") if line.strip()]
|
||||
|
||||
return lines
|
||||
return lines
|
|
@ -15,6 +15,7 @@ from litellm.litellm_core_utils.litellm_logging import (
|
|||
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
VertexLLM,
|
||||
)
|
||||
from litellm.proxy._types import PassThroughEndpointLoggingResultValues
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.types.utils import StandardPassThroughResponseObject
|
||||
|
||||
|
@ -49,53 +50,69 @@ class PassThroughEndpointLogging:
|
|||
cache_hit: bool,
|
||||
**kwargs,
|
||||
):
|
||||
standard_logging_response_object: Optional[
|
||||
PassThroughEndpointLoggingResultValues
|
||||
] = None
|
||||
if self.is_vertex_route(url_route):
|
||||
await VertexPassthroughLoggingHandler.vertex_passthrough_handler(
|
||||
httpx_response=httpx_response,
|
||||
logging_obj=logging_obj,
|
||||
url_route=url_route,
|
||||
result=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=cache_hit,
|
||||
**kwargs,
|
||||
vertex_passthrough_logging_handler_result = (
|
||||
VertexPassthroughLoggingHandler.vertex_passthrough_handler(
|
||||
httpx_response=httpx_response,
|
||||
logging_obj=logging_obj,
|
||||
url_route=url_route,
|
||||
result=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=cache_hit,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
standard_logging_response_object = (
|
||||
vertex_passthrough_logging_handler_result["result"]
|
||||
)
|
||||
kwargs = vertex_passthrough_logging_handler_result["kwargs"]
|
||||
elif self.is_anthropic_route(url_route):
|
||||
await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
|
||||
httpx_response=httpx_response,
|
||||
response_body=response_body or {},
|
||||
logging_obj=logging_obj,
|
||||
url_route=url_route,
|
||||
result=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=cache_hit,
|
||||
**kwargs,
|
||||
anthropic_passthrough_logging_handler_result = (
|
||||
AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
|
||||
httpx_response=httpx_response,
|
||||
response_body=response_body or {},
|
||||
logging_obj=logging_obj,
|
||||
url_route=url_route,
|
||||
result=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=cache_hit,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
standard_logging_response_object = (
|
||||
anthropic_passthrough_logging_handler_result["result"]
|
||||
)
|
||||
kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
|
||||
if standard_logging_response_object is None:
|
||||
standard_logging_response_object = StandardPassThroughResponseObject(
|
||||
response=httpx_response.text
|
||||
)
|
||||
threading.Thread(
|
||||
target=logging_obj.success_handler,
|
||||
args=(
|
||||
standard_logging_response_object,
|
||||
start_time,
|
||||
end_time,
|
||||
cache_hit,
|
||||
),
|
||||
).start()
|
||||
await logging_obj.async_success_handler(
|
||||
result=(
|
||||
json.dumps(result)
|
||||
if isinstance(result, dict)
|
||||
else standard_logging_response_object
|
||||
),
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=False,
|
||||
**kwargs,
|
||||
)
|
||||
threading.Thread(
|
||||
target=logging_obj.success_handler,
|
||||
args=(
|
||||
standard_logging_response_object,
|
||||
start_time,
|
||||
end_time,
|
||||
cache_hit,
|
||||
),
|
||||
).start()
|
||||
await logging_obj.async_success_handler(
|
||||
result=(
|
||||
json.dumps(result)
|
||||
if isinstance(result, dict)
|
||||
else standard_logging_response_object
|
||||
),
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def is_vertex_route(self, url_route: str):
|
||||
for route in self.TRACKED_VERTEX_ROUTES:
|
||||
|
|
|
@ -268,6 +268,7 @@ from litellm.types.llms.anthropic import (
|
|||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
from litellm.types.router import RouterGeneralSettings
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||
|
||||
try:
|
||||
from litellm._version import version
|
||||
|
@ -763,8 +764,7 @@ async def _PROXY_track_cost_callback(
|
|||
)
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
|
||||
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
|
||||
user_id = metadata.get("user_api_key_user_id", None)
|
||||
team_id = metadata.get("user_api_key_team_id", None)
|
||||
|
|
|
@ -337,14 +337,14 @@ class ProxyLogging:
|
|||
alert_to_webhook_url=self.alert_to_webhook_url,
|
||||
)
|
||||
|
||||
if (
|
||||
self.alerting is not None
|
||||
and "slack" in self.alerting
|
||||
and "daily_reports" in self.alert_types
|
||||
):
|
||||
if self.alerting is not None and "slack" in self.alerting:
|
||||
# NOTE: ENSURE we only add callbacks when alerting is on
|
||||
# We should NOT add callbacks when alerting is off
|
||||
litellm.callbacks.append(self.slack_alerting_instance) # type: ignore
|
||||
if "daily_reports" in self.alert_types:
|
||||
litellm.callbacks.append(self.slack_alerting_instance) # type: ignore
|
||||
litellm.success_callback.append(
|
||||
self.slack_alerting_instance.response_taking_too_long_callback
|
||||
)
|
||||
|
||||
if redis_cache is not None:
|
||||
self.internal_usage_cache.dual_cache.redis_cache = redis_cache
|
||||
|
@ -354,9 +354,6 @@ class ProxyLogging:
|
|||
litellm.callbacks.append(self.max_budget_limiter) # type: ignore
|
||||
litellm.callbacks.append(self.cache_control_check) # type: ignore
|
||||
litellm.callbacks.append(self.service_logging_obj) # type: ignore
|
||||
litellm.success_callback.append(
|
||||
self.slack_alerting_instance.response_taking_too_long_callback
|
||||
)
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, str):
|
||||
callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore
|
||||
|
|
|
@ -91,6 +91,7 @@ def rerank(
|
|||
model_info = kwargs.get("model_info", None)
|
||||
metadata = kwargs.get("metadata", {})
|
||||
user = kwargs.get("user", None)
|
||||
client = kwargs.get("client", None)
|
||||
try:
|
||||
_is_async = kwargs.pop("arerank", False) is True
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
|
@ -150,7 +151,7 @@ def rerank(
|
|||
or optional_params.api_base
|
||||
or litellm.api_base
|
||||
or get_secret("COHERE_API_BASE") # type: ignore
|
||||
or "https://api.cohere.com/v1/rerank"
|
||||
or "https://api.cohere.com"
|
||||
)
|
||||
|
||||
if api_base is None:
|
||||
|
@ -173,6 +174,7 @@ def rerank(
|
|||
_is_async=_is_async,
|
||||
headers=headers,
|
||||
litellm_logging_obj=litellm_logging_obj,
|
||||
client=client,
|
||||
)
|
||||
elif _custom_llm_provider == "azure_ai":
|
||||
api_base = (
|
||||
|
|
|
@ -1602,3 +1602,16 @@ class StandardCallbackDynamicParams(TypedDict, total=False):
|
|||
langsmith_api_key: Optional[str]
|
||||
langsmith_project: Optional[str]
|
||||
langsmith_base_url: Optional[str]
|
||||
|
||||
|
||||
class TeamUIKeyGenerationConfig(TypedDict):
|
||||
allowed_team_member_roles: List[str]
|
||||
|
||||
|
||||
class PersonalUIKeyGenerationConfig(TypedDict):
|
||||
allowed_user_roles: List[str]
|
||||
|
||||
|
||||
class StandardKeyGenerationConfig(TypedDict, total=False):
|
||||
team_key_generation: TeamUIKeyGenerationConfig
|
||||
personal_key_generation: PersonalUIKeyGenerationConfig
|
||||
|
|
|
@ -6170,3 +6170,13 @@ class ProviderConfigManager:
|
|||
return litellm.GroqChatConfig()
|
||||
|
||||
return OpenAIGPTConfig()
|
||||
|
||||
|
||||
def get_end_user_id_for_cost_tracking(litellm_params: dict) -> Optional[str]:
|
||||
"""
|
||||
Used for enforcing `disable_end_user_cost_tracking` param.
|
||||
"""
|
||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||
if litellm.disable_end_user_cost_tracking:
|
||||
return None
|
||||
return proxy_server_request.get("body", {}).get("user", None)
|
||||
|
|
|
@ -1080,3 +1080,34 @@ def test_cohere_img_embeddings(input, input_type):
|
|||
assert response.usage.prompt_tokens_details.image_tokens > 0
|
||||
else:
|
||||
assert response.usage.prompt_tokens_details.text_tokens > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedding_with_extra_headers(sync_mode):
|
||||
|
||||
input = ["hello world"]
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
|
||||
|
||||
if sync_mode:
|
||||
client = HTTPHandler()
|
||||
else:
|
||||
client = AsyncHTTPHandler()
|
||||
|
||||
data = {
|
||||
"model": "cohere/embed-english-v3.0",
|
||||
"input": input,
|
||||
"extra_headers": {"my-test-param": "hello-world"},
|
||||
"client": client,
|
||||
}
|
||||
with patch.object(client, "post") as mock_post:
|
||||
try:
|
||||
if sync_mode:
|
||||
embedding(**data)
|
||||
else:
|
||||
await litellm.aembedding(**data)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
mock_post.assert_called_once()
|
||||
assert "my-test-param" in mock_post.call_args.kwargs["headers"]
|
||||
|
|
|
@ -215,7 +215,10 @@ async def test_rerank_custom_api_base():
|
|||
args_to_api = kwargs["json"]
|
||||
print("Arguments passed to API=", args_to_api)
|
||||
print("url = ", _url)
|
||||
assert _url[0] == "https://exampleopenaiendpoint-production.up.railway.app/"
|
||||
assert (
|
||||
_url[0]
|
||||
== "https://exampleopenaiendpoint-production.up.railway.app/v1/rerank"
|
||||
)
|
||||
assert args_to_api == expected_payload
|
||||
assert response.id is not None
|
||||
assert response.results is not None
|
||||
|
@ -258,3 +261,32 @@ async def test_rerank_custom_callbacks():
|
|||
assert custom_logger.kwargs.get("response_cost") > 0.0
|
||||
assert custom_logger.response_obj is not None
|
||||
assert custom_logger.response_obj.results is not None
|
||||
|
||||
|
||||
def test_complete_base_url_cohere():
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
client = HTTPHandler()
|
||||
litellm.api_base = "http://localhost:4000"
|
||||
litellm.set_verbose = True
|
||||
|
||||
text = "Hello there!"
|
||||
list_texts = ["Hello there!", "How are you?", "How do you do?"]
|
||||
|
||||
rerank_model = "rerank-multilingual-v3.0"
|
||||
|
||||
with patch.object(client, "post") as mock_post:
|
||||
try:
|
||||
litellm.rerank(
|
||||
model=rerank_model,
|
||||
query=text,
|
||||
documents=list_texts,
|
||||
custom_llm_provider="cohere",
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
print("mock_post.call_args", mock_post.call_args)
|
||||
mock_post.assert_called_once()
|
||||
assert "http://localhost:4000/v1/rerank" in mock_post.call_args.kwargs["url"]
|
||||
|
|
|
@ -1012,3 +1012,23 @@ def test_models_by_provider():
|
|||
|
||||
for provider in providers:
|
||||
assert provider in models_by_provider.keys()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"litellm_params, disable_end_user_cost_tracking, expected_end_user_id",
|
||||
[
|
||||
({}, False, None),
|
||||
({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"),
|
||||
({"proxy_server_request": {"body": {"user": "123"}}}, True, None),
|
||||
],
|
||||
)
|
||||
def test_get_end_user_id_for_cost_tracking(
|
||||
litellm_params, disable_end_user_cost_tracking, expected_end_user_id
|
||||
):
|
||||
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||
|
||||
litellm.disable_end_user_cost_tracking = disable_end_user_cost_tracking
|
||||
assert (
|
||||
get_end_user_id_for_cost_tracking(litellm_params=litellm_params)
|
||||
== expected_end_user_id
|
||||
)
|
||||
|
|
|
@ -216,3 +216,78 @@ async def test_init_custom_logger_compatible_class_as_callback():
|
|||
await use_callback_in_llm_call(callback, used_in="success_callback")
|
||||
|
||||
reset_env_vars()
|
||||
|
||||
|
||||
def test_dynamic_logging_global_callback():
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.utils import ModelResponse, Choices, Message, Usage
|
||||
|
||||
cl = CustomLogger()
|
||||
|
||||
litellm_logging = LiteLLMLoggingObj(
|
||||
model="claude-3-opus-20240229",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=False,
|
||||
call_type="completion",
|
||||
start_time=datetime.now(),
|
||||
litellm_call_id="123",
|
||||
function_id="456",
|
||||
kwargs={
|
||||
"langfuse_public_key": "my-mock-public-key",
|
||||
"langfuse_secret_key": "my-mock-secret-key",
|
||||
},
|
||||
dynamic_success_callbacks=["langfuse"],
|
||||
)
|
||||
|
||||
with patch.object(cl, "log_success_event") as mock_log_success_event:
|
||||
cl.log_success_event = mock_log_success_event
|
||||
litellm.success_callback = [cl]
|
||||
|
||||
try:
|
||||
litellm_logging.success_handler(
|
||||
result=ModelResponse(
|
||||
id="chatcmpl-5418737b-ab14-420b-b9c5-b278b6681b70",
|
||||
created=1732306261,
|
||||
model="claude-3-opus-20240229",
|
||||
object="chat.completion",
|
||||
system_fingerprint=None,
|
||||
choices=[
|
||||
Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=Message(
|
||||
content="hello",
|
||||
role="assistant",
|
||||
tool_calls=None,
|
||||
function_call=None,
|
||||
),
|
||||
)
|
||||
],
|
||||
usage=Usage(
|
||||
completion_tokens=20,
|
||||
prompt_tokens=10,
|
||||
total_tokens=30,
|
||||
completion_tokens_details=None,
|
||||
prompt_tokens_details=None,
|
||||
),
|
||||
),
|
||||
start_time=datetime.now(),
|
||||
end_time=datetime.now(),
|
||||
cache_hit=False,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
mock_log_success_event.assert_called_once()
|
||||
|
||||
|
||||
def test_get_combined_callback_list():
|
||||
from litellm.litellm_core_utils.litellm_logging import get_combined_callback_list
|
||||
|
||||
assert "langfuse" in get_combined_callback_list(
|
||||
dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"]
|
||||
)
|
||||
assert "lago" in get_combined_callback_list(
|
||||
dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"]
|
||||
)
|
||||
|
|
|
@ -73,7 +73,7 @@ async def test_anthropic_passthrough_handler(
|
|||
start_time = datetime.now()
|
||||
end_time = datetime.now()
|
||||
|
||||
await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
|
||||
result = AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
|
||||
httpx_response=mock_httpx_response,
|
||||
response_body=mock_response,
|
||||
logging_obj=mock_logging_obj,
|
||||
|
@ -84,30 +84,7 @@ async def test_anthropic_passthrough_handler(
|
|||
cache_hit=False,
|
||||
)
|
||||
|
||||
# Assert that async_success_handler was called
|
||||
assert mock_logging_obj.async_success_handler.called
|
||||
|
||||
call_args = mock_logging_obj.async_success_handler.call_args
|
||||
call_kwargs = call_args.kwargs
|
||||
print("call_kwargs", call_kwargs)
|
||||
|
||||
# Assert required fields are present in call_kwargs
|
||||
assert "result" in call_kwargs
|
||||
assert "start_time" in call_kwargs
|
||||
assert "end_time" in call_kwargs
|
||||
assert "cache_hit" in call_kwargs
|
||||
assert "response_cost" in call_kwargs
|
||||
assert "model" in call_kwargs
|
||||
assert "standard_logging_object" in call_kwargs
|
||||
|
||||
# Assert specific values and types
|
||||
assert isinstance(call_kwargs["result"], litellm.ModelResponse)
|
||||
assert isinstance(call_kwargs["start_time"], datetime)
|
||||
assert isinstance(call_kwargs["end_time"], datetime)
|
||||
assert isinstance(call_kwargs["cache_hit"], bool)
|
||||
assert isinstance(call_kwargs["response_cost"], float)
|
||||
assert call_kwargs["model"] == "claude-3-opus-20240229"
|
||||
assert isinstance(call_kwargs["standard_logging_object"], dict)
|
||||
assert isinstance(result["result"], litellm.ModelResponse)
|
||||
|
||||
|
||||
def test_create_anthropic_response_logging_payload(mock_logging_obj):
|
||||
|
|
|
@ -64,6 +64,7 @@ async def test_chunk_processor_yields_raw_bytes(endpoint_type, url_route):
|
|||
litellm_logging_obj = MagicMock()
|
||||
start_time = datetime.now()
|
||||
passthrough_success_handler_obj = MagicMock()
|
||||
litellm_logging_obj.async_success_handler = AsyncMock()
|
||||
|
||||
# Capture yielded chunks and perform detailed assertions
|
||||
received_chunks = []
|
||||
|
|
54
tests/proxy_admin_ui_tests/conftest.py
Normal file
54
tests/proxy_admin_ui_tests/conftest.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
# conftest.py
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def setup_and_teardown():
|
||||
"""
|
||||
This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.
|
||||
"""
|
||||
curr_dir = os.getcwd() # Get the current working directory
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the project directory to the system path
|
||||
|
||||
import litellm
|
||||
from litellm import Router
|
||||
|
||||
importlib.reload(litellm)
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
print(litellm)
|
||||
# from litellm import Router, completion, aembedding, acompletion, embedding
|
||||
yield
|
||||
|
||||
# Teardown code (executes after the yield point)
|
||||
loop.close() # Close the loop created earlier
|
||||
asyncio.set_event_loop(None) # Remove the reference to the loop
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
# Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
|
||||
custom_logger_tests = [
|
||||
item for item in items if "custom_logger" in item.parent.name
|
||||
]
|
||||
other_tests = [item for item in items if "custom_logger" not in item.parent.name]
|
||||
|
||||
# Sort tests based on their names
|
||||
custom_logger_tests.sort(key=lambda x: x.name)
|
||||
other_tests.sort(key=lambda x: x.name)
|
||||
|
||||
# Reorder the items list
|
||||
items[:] = custom_logger_tests + other_tests
|
|
@ -542,3 +542,65 @@ async def test_list_teams(prisma_client):
|
|||
|
||||
# Clean up
|
||||
await prisma_client.delete_data(team_id_list=[team_id], table_name="team")
|
||||
|
||||
|
||||
def test_is_team_key():
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import _is_team_key
|
||||
|
||||
assert _is_team_key(GenerateKeyRequest(team_id="test_team_id"))
|
||||
assert not _is_team_key(GenerateKeyRequest(user_id="test_user_id"))
|
||||
|
||||
|
||||
def test_team_key_generation_check():
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
_team_key_generation_check,
|
||||
)
|
||||
from fastapi import HTTPException
|
||||
|
||||
litellm.key_generation_settings = {
|
||||
"team_key_generation": {"allowed_team_member_roles": ["admin"]}
|
||||
}
|
||||
|
||||
assert _team_key_generation_check(
|
||||
UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||
api_key="sk-1234",
|
||||
team_member=Member(role="admin", user_id="test_user_id"),
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
_team_key_generation_check(
|
||||
UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||
api_key="sk-1234",
|
||||
user_id="test_user_id",
|
||||
team_member=Member(role="user", user_id="test_user_id"),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_personal_key_generation_check():
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
_personal_key_generation_check,
|
||||
)
|
||||
from fastapi import HTTPException
|
||||
|
||||
litellm.key_generation_settings = {
|
||||
"personal_key_generation": {"allowed_user_roles": ["proxy_admin"]}
|
||||
}
|
||||
|
||||
assert _personal_key_generation_check(
|
||||
UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin"
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
_personal_key_generation_check(
|
||||
UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||
api_key="sk-1234",
|
||||
user_id="admin",
|
||||
)
|
||||
)
|
||||
|
|
|
@ -160,7 +160,7 @@ async def test_create_new_user_in_organization(prisma_client, user_role):
|
|||
response = await organization_member_add(
|
||||
data=OrganizationMemberAddRequest(
|
||||
organization_id=org_id,
|
||||
member=Member(role=user_role, user_id=created_user_id),
|
||||
member=OrgMember(role=user_role, user_id=created_user_id),
|
||||
),
|
||||
http_request=None,
|
||||
)
|
||||
|
@ -220,7 +220,7 @@ async def test_org_admin_create_team_permissions(prisma_client):
|
|||
response = await organization_member_add(
|
||||
data=OrganizationMemberAddRequest(
|
||||
organization_id=org_id,
|
||||
member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
|
||||
member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
|
||||
),
|
||||
http_request=None,
|
||||
)
|
||||
|
@ -292,7 +292,7 @@ async def test_org_admin_create_user_permissions(prisma_client):
|
|||
response = await organization_member_add(
|
||||
data=OrganizationMemberAddRequest(
|
||||
organization_id=org_id,
|
||||
member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
|
||||
member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
|
||||
),
|
||||
http_request=None,
|
||||
)
|
||||
|
@ -323,7 +323,7 @@ async def test_org_admin_create_user_permissions(prisma_client):
|
|||
response = await organization_member_add(
|
||||
data=OrganizationMemberAddRequest(
|
||||
organization_id=org_id,
|
||||
member=Member(
|
||||
member=OrgMember(
|
||||
role=LitellmUserRoles.INTERNAL_USER, user_id=new_internal_user_for_org
|
||||
),
|
||||
),
|
||||
|
@ -375,7 +375,7 @@ async def test_org_admin_create_user_team_wrong_org_permissions(prisma_client):
|
|||
response = await organization_member_add(
|
||||
data=OrganizationMemberAddRequest(
|
||||
organization_id=org1_id,
|
||||
member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
|
||||
member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
|
||||
),
|
||||
http_request=None,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue