Compare commits

...
Sign in to create a new pull request.

21 commits

Author SHA1 Message Date
Krrish Dholakia
b06b0248ff fix: fix pass through testing 2024-11-23 14:11:54 +05:30
Krish Dholakia
4021206ac2
Merge branch 'main' into litellm_dev_11_22_2024 2024-11-23 13:52:17 +05:30
Krrish Dholakia
88db948d29 test: fix test 2024-11-23 12:35:55 +05:30
Krrish Dholakia
b5c5f87e25 fix: fix linting errors 2024-11-23 03:12:59 +05:30
Krrish Dholakia
ed64dd7f9d test: fix test 2024-11-23 03:06:51 +05:30
Krrish Dholakia
e2ccf6b7f4 build: add conftest to proxy_admin_ui_tests/ 2024-11-23 02:44:10 +05:30
Krrish Dholakia
f8a46b5950 fix: fix typing 2024-11-23 02:36:09 +05:30
Krrish Dholakia
31943bf2ad fix(handler.py): fix linting error 2024-11-23 02:08:50 +05:30
Krrish Dholakia
541326731f fix(litellm_logging.py): don't disable global callbacks when dynamic callbacks are set
Fixes issue where global callbacks - e.g. prometheus were overriden when langfuse was set dynamically
2024-11-23 02:00:45 +05:30
Krrish Dholakia
dfb34dfe92 fix(main.py): pass extra_headers param to openai
Fixes https://github.com/BerriAI/litellm/issues/6836
2024-11-23 01:23:03 +05:30
Krrish Dholakia
250d66b335 feat(rerank.py): add /v1/rerank if missing for cohere base url
Closes https://github.com/BerriAI/litellm/issues/6844
2024-11-23 01:07:39 +05:30
Krrish Dholakia
94fe135524 feat(cohere/embed): pass client to async embedding 2024-11-23 00:47:26 +05:30
Krrish Dholakia
1a3fb18a64 test(test_embedding.py): add test for passing extra headers via embedding 2024-11-23 00:32:40 +05:30
Krrish Dholakia
d788c3c37f docs(input.md): cleanup anthropic supported params
Closes https://github.com/BerriAI/litellm/issues/6856
2024-11-23 00:21:04 +05:30
Krrish Dholakia
4beb48829c docs(finetuned_models.md): add new guide on calling finetuned models 2024-11-23 00:08:03 +05:30
Krrish Dholakia
eb0a357eda docs: add docs on restricting key creation 2024-11-22 23:11:58 +05:30
Krrish Dholakia
463fa0c9d5 test(test_key_management.py): add unit testing for personal / team key restriction checks 2024-11-22 23:03:38 +05:30
Krrish Dholakia
1014216d73 feat(key_management_endpoints.py): add support for restricting access to /key/generate by team/proxy level role
Enables admin to restrict key creation, and assign team admins to handle distributing keys
2024-11-22 22:59:01 +05:30
Krrish Dholakia
97d8aa0b3a docs(configs.md): add disable_end_user_cost_tracking reference to docs 2024-11-22 16:43:35 +05:30
Krrish Dholakia
5a698c678a fix(utils.py): allow disabling end user cost tracking with new param
Allows proxy admin to disable cost tracking for end user - keeps prometheus metrics small
2024-11-22 16:41:58 +05:30
Krrish Dholakia
1c9a8c0b68 feat(pass_through_endpoints/): support logging anthropic/gemini pass through calls to langfuse/s3/etc. 2024-11-22 16:21:57 +05:30
35 changed files with 871 additions and 248 deletions

View file

@ -41,7 +41,7 @@ Use `litellm.get_supported_openai_params()` for an updated list of params for ea
| Provider | temperature | max_completion_tokens | max_tokens | top_p | stream | stream_options | stop | n | presence_penalty | frequency_penalty | functions | function_call | logit_bias | user | response_format | seed | tools | tool_choice | logprobs | top_logprobs | extra_headers |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|Anthropic| ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | | | | | | |✅ | ✅ | | ✅ | ✅ | | | ✅ |
|Anthropic| ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | | | | | | |✅ | ✅ | | ✅ | ✅ | | | ✅ |
|OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | ✅ |
|Azure OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | | | ✅ |
|Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |

View file

@ -0,0 +1,74 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Calling Finetuned Models
## OpenAI
| Model Name | Function Call |
|---------------------------|-----------------------------------------------------------------|
| fine tuned `gpt-4-0613` | `response = completion(model="ft:gpt-4-0613", messages=messages)` |
| fine tuned `gpt-4o-2024-05-13` | `response = completion(model="ft:gpt-4o-2024-05-13", messages=messages)` |
| fine tuned `gpt-3.5-turbo-0125` | `response = completion(model="ft:gpt-3.5-turbo-0125", messages=messages)` |
| fine tuned `gpt-3.5-turbo-1106` | `response = completion(model="ft:gpt-3.5-turbo-1106", messages=messages)` |
| fine tuned `gpt-3.5-turbo-0613` | `response = completion(model="ft:gpt-3.5-turbo-0613", messages=messages)` |
## Vertex AI
Fine tuned models on vertex have a numerical model/endpoint id.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
## set ENV variables
os.environ["VERTEXAI_PROJECT"] = "hardy-device-38811"
os.environ["VERTEXAI_LOCATION"] = "us-central1"
response = completion(
model="vertex_ai/<your-finetuned-model>", # e.g. vertex_ai/4965075652664360960
messages=[{ "content": "Hello, how are you?","role": "user"}],
base_model="vertex_ai/gemini-1.5-pro" # the base model - used for routing
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add Vertex Credentials to your env
```bash
!gcloud auth application-default login
```
2. Setup config.yaml
```yaml
- model_name: finetuned-gemini
litellm_params:
model: vertex_ai/<ENDPOINT_ID>
vertex_project: <PROJECT_ID>
vertex_location: <LOCATION>
model_info:
base_model: vertex_ai/gemini-1.5-pro # IMPORTANT
```
3. Test it!
```bash
curl --location 'https://0.0.0.0:4000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: <LITELLM_KEY>' \
--data '{"model": "finetuned-gemini" ,"messages":[{"role": "user", "content":[{"type": "text", "text": "hi"}]}]}'
```
</TabItem>
</Tabs>

View file

@ -754,6 +754,8 @@ general_settings:
| cache_params.s3_endpoint_url | string | Optional - The endpoint URL for the S3 bucket. |
| cache_params.supported_call_types | array of strings | The types of calls to cache. [Further docs](./caching) |
| cache_params.mode | string | The mode of the cache. [Further docs](./caching) |
| disable_end_user_cost_tracking | boolean | If true, turns off end user cost tracking on prometheus metrics + litellm spend logs table on proxy. |
| key_generation_settings | object | Restricts who can generate keys. [Further docs](./virtual_keys.md#restricting-key-generation) |
### general_settings - Reference

View file

@ -217,4 +217,10 @@ litellm_settings:
max_parallel_requests: 1000 # (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None.
tpm_limit: 1000 #(Optional[int], optional): Tpm limit. Defaults to None.
rpm_limit: 1000 #(Optional[int], optional): Rpm limit. Defaults to None.
key_generation_settings: # Restricts who can generate keys. [Further docs](./virtual_keys.md#restricting-key-generation)
team_key_generation:
allowed_team_member_roles: ["admin"]
personal_key_generation: # maps to 'Default Team' on UI
allowed_user_roles: ["proxy_admin"]
```

View file

@ -811,6 +811,75 @@ litellm_settings:
team_id: "core-infra"
```
### Restricting Key Generation
Use this to control who can generate keys. Useful when letting others create keys on the UI.
```yaml
litellm_settings:
key_generation_settings:
team_key_generation:
allowed_team_member_roles: ["admin"]
personal_key_generation: # maps to 'Default Team' on UI
allowed_user_roles: ["proxy_admin"]
```
#### Spec
```python
class TeamUIKeyGenerationConfig(TypedDict):
allowed_team_member_roles: List[str]
class PersonalUIKeyGenerationConfig(TypedDict):
allowed_user_roles: List[LitellmUserRoles]
class StandardKeyGenerationConfig(TypedDict, total=False):
team_key_generation: TeamUIKeyGenerationConfig
personal_key_generation: PersonalUIKeyGenerationConfig
class LitellmUserRoles(str, enum.Enum):
"""
Admin Roles:
PROXY_ADMIN: admin over the platform
PROXY_ADMIN_VIEW_ONLY: can login, view all own keys, view all spend
ORG_ADMIN: admin over a specific organization, can create teams, users only within their organization
Internal User Roles:
INTERNAL_USER: can login, view/create/delete their own keys, view their spend
INTERNAL_USER_VIEW_ONLY: can login, view their own keys, view their own spend
Team Roles:
TEAM: used for JWT auth
Customer Roles:
CUSTOMER: External users -> these are customers
"""
# Admin Roles
PROXY_ADMIN = "proxy_admin"
PROXY_ADMIN_VIEW_ONLY = "proxy_admin_viewer"
# Organization admins
ORG_ADMIN = "org_admin"
# Internal User Roles
INTERNAL_USER = "internal_user"
INTERNAL_USER_VIEW_ONLY = "internal_user_viewer"
# Team Roles
TEAM = "team"
# Customer Roles - External users of proxy
CUSTOMER = "customer"
```
## **Next Steps - Set Budgets, Rate Limits per Virtual Key**
[Follow this doc to set budgets, rate limiters per virtual key with LiteLLM](users)

View file

@ -199,6 +199,31 @@ const sidebars = {
],
},
{
type: "category",
label: "Guides",
items: [
"exception_mapping",
"completion/provider_specific_params",
"guides/finetuned_models",
"completion/audio",
"completion/vision",
"completion/json_mode",
"completion/prompt_caching",
"completion/predict_outputs",
"completion/prefix",
"completion/drop_params",
"completion/prompt_formatting",
"completion/stream",
"completion/message_trimming",
"completion/function_call",
"completion/model_alias",
"completion/batching",
"completion/mock_requests",
"completion/reliable_completions",
]
},
{
type: "category",
label: "Supported Endpoints",
@ -214,25 +239,8 @@ const sidebars = {
},
items: [
"completion/input",
"completion/provider_specific_params",
"completion/json_mode",
"completion/prompt_caching",
"completion/audio",
"completion/vision",
"completion/predict_outputs",
"completion/prefix",
"completion/drop_params",
"completion/prompt_formatting",
"completion/output",
"completion/usage",
"exception_mapping",
"completion/stream",
"completion/message_trimming",
"completion/function_call",
"completion/model_alias",
"completion/batching",
"completion/mock_requests",
"completion/reliable_completions",
],
},
"embedding/supported_embedding",

View file

@ -24,6 +24,7 @@ from litellm.proxy._types import (
KeyManagementSettings,
LiteLLM_UpperboundKeyGenerateParams,
)
from litellm.types.utils import StandardKeyGenerationConfig
import httpx
import dotenv
from enum import Enum
@ -273,6 +274,7 @@ s3_callback_params: Optional[Dict] = None
generic_logger_headers: Optional[Dict] = None
default_key_generate_params: Optional[Dict] = None
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
key_generation_settings: Optional[StandardKeyGenerationConfig] = None
default_internal_user_params: Optional[Dict] = None
default_team_settings: Optional[List] = None
max_user_budget: Optional[float] = None
@ -280,6 +282,7 @@ default_max_internal_user_budget: Optional[float] = None
max_internal_user_budget: Optional[float] = None
internal_user_budget_duration: Optional[str] = None
max_end_user_budget: Optional[float] = None
disable_end_user_cost_tracking: Optional[bool] = None
#### REQUEST PRIORITIZATION ####
priority_reservation: Optional[Dict[str, float]] = None
#### RELIABILITY ####

View file

@ -18,6 +18,7 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.integrations.prometheus import *
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import get_end_user_id_for_cost_tracking
class PrometheusLogger(CustomLogger):
@ -364,8 +365,7 @@ class PrometheusLogger(CustomLogger):
model = kwargs.get("model", "")
litellm_params = kwargs.get("litellm_params", {}) or {}
_metadata = litellm_params.get("metadata", {})
proxy_server_request = litellm_params.get("proxy_server_request") or {}
end_user_id = proxy_server_request.get("body", {}).get("user", None)
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
@ -664,13 +664,11 @@ class PrometheusLogger(CustomLogger):
# unpack kwargs
model = kwargs.get("model", "")
litellm_params = kwargs.get("litellm_params", {}) or {}
standard_logging_payload: StandardLoggingPayload = kwargs.get(
"standard_logging_object", {}
)
proxy_server_request = litellm_params.get("proxy_server_request") or {}
end_user_id = proxy_server_request.get("body", {}).get("user", None)
litellm_params = kwargs.get("litellm_params", {}) or {}
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]

View file

@ -934,19 +934,10 @@ class Logging:
status="success",
)
)
if self.dynamic_success_callbacks is not None and isinstance(
self.dynamic_success_callbacks, list
):
callbacks = self.dynamic_success_callbacks
## keep the internal functions ##
for callback in litellm.success_callback:
if (
isinstance(callback, CustomLogger)
and "_PROXY_" in callback.__class__.__name__
):
callbacks.append(callback)
else:
callbacks = litellm.success_callback
callbacks = get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_success_callbacks,
global_callbacks=litellm.success_callback,
)
## REDACT MESSAGES ##
result = redact_message_input_output_from_logging(
@ -1368,8 +1359,11 @@ class Logging:
and customLogger is not None
): # custom logger functions
print_verbose(
"success callbacks: Running Custom Callback Function"
"success callbacks: Running Custom Callback Function - {}".format(
callback
)
)
customLogger.log_event(
kwargs=self.model_call_details,
response_obj=result,
@ -1466,21 +1460,10 @@ class Logging:
status="success",
)
)
if self.dynamic_async_success_callbacks is not None and isinstance(
self.dynamic_async_success_callbacks, list
):
callbacks = self.dynamic_async_success_callbacks
## keep the internal functions ##
for callback in litellm._async_success_callback:
callback_name = ""
if isinstance(callback, CustomLogger):
callback_name = callback.__class__.__name__
if callable(callback):
callback_name = callback.__name__
if "_PROXY_" in callback_name:
callbacks.append(callback)
else:
callbacks = litellm._async_success_callback
callbacks = get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_async_success_callbacks,
global_callbacks=litellm._async_success_callback,
)
result = redact_message_input_output_from_logging(
model_call_details=(
@ -1747,21 +1730,10 @@ class Logging:
start_time=start_time,
end_time=end_time,
)
callbacks = [] # init this to empty incase it's not created
if self.dynamic_failure_callbacks is not None and isinstance(
self.dynamic_failure_callbacks, list
):
callbacks = self.dynamic_failure_callbacks
## keep the internal functions ##
for callback in litellm.failure_callback:
if (
isinstance(callback, CustomLogger)
and "_PROXY_" in callback.__class__.__name__
):
callbacks.append(callback)
else:
callbacks = litellm.failure_callback
callbacks = get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_failure_callbacks,
global_callbacks=litellm.failure_callback,
)
result = None # result sent to all loggers, init this to None incase it's not created
@ -1944,21 +1916,10 @@ class Logging:
end_time=end_time,
)
callbacks = [] # init this to empty incase it's not created
if self.dynamic_async_failure_callbacks is not None and isinstance(
self.dynamic_async_failure_callbacks, list
):
callbacks = self.dynamic_async_failure_callbacks
## keep the internal functions ##
for callback in litellm._async_failure_callback:
if (
isinstance(callback, CustomLogger)
and "_PROXY_" in callback.__class__.__name__
):
callbacks.append(callback)
else:
callbacks = litellm._async_failure_callback
callbacks = get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_async_failure_callbacks,
global_callbacks=litellm._async_failure_callback,
)
result = None # result sent to all loggers, init this to None incase it's not created
for callback in callbacks:
@ -2359,6 +2320,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
_in_memory_loggers.append(_mlflow_logger)
return _mlflow_logger # type: ignore
def get_custom_logger_compatible_class(
logging_integration: litellm._custom_logger_compatible_callbacks_literal,
) -> Optional[CustomLogger]:
@ -2949,3 +2911,11 @@ def modify_integration(integration_name, integration_params):
if integration_name == "supabase":
if "table_name" in integration_params:
Supabase.supabase_table_name = integration_params["table_name"]
def get_combined_callback_list(
dynamic_success_callbacks: Optional[List], global_callbacks: List
) -> List:
if dynamic_success_callbacks is None:
return global_callbacks
return list(set(dynamic_success_callbacks + global_callbacks))

View file

@ -4,6 +4,7 @@ import httpx
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.cohere.rerank import CohereRerank
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.rerank import RerankResponse
@ -73,6 +74,7 @@ class AzureAIRerank(CohereRerank):
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
_is_async: Optional[bool] = False,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
if headers is None:

View file

@ -74,6 +74,7 @@ async def async_embedding(
},
)
## COMPLETION CALL
if client is None:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.COHERE,
@ -151,6 +152,11 @@ def embedding(
api_key=api_key,
headers=headers,
encoding=encoding,
client=(
client
if client is not None and isinstance(client, AsyncHTTPHandler)
else None
),
)
## LOGGING

View file

@ -6,10 +6,14 @@ LiteLLM supports the re rank API format, no paramter transformation occurs
from typing import Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
@ -34,6 +38,23 @@ class CohereRerank(BaseLLM):
# Merge other headers, overriding any default ones except Authorization
return {**default_headers, **headers}
def ensure_rerank_endpoint(self, api_base: str) -> str:
"""
Ensures the `/v1/rerank` endpoint is appended to the given `api_base`.
If `/v1/rerank` is already present, the original URL is returned.
:param api_base: The base API URL.
:return: A URL with `/v1/rerank` appended if missing.
"""
# Parse the base URL to ensure proper structure
url = httpx.URL(api_base)
# Check if the URL already ends with `/v1/rerank`
if not url.path.endswith("/v1/rerank"):
url = url.copy_with(path=f"{url.path.rstrip('/')}/v1/rerank")
return str(url)
def rerank(
self,
model: str,
@ -48,9 +69,10 @@ class CohereRerank(BaseLLM):
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
_is_async: Optional[bool] = False, # New parameter
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
headers = self.validate_environment(api_key=api_key, headers=headers)
api_base = self.ensure_rerank_endpoint(api_base)
request_data = RerankRequest(
model=model,
query=query,
@ -76,9 +98,13 @@ class CohereRerank(BaseLLM):
if _is_async:
return self.async_rerank(request_data=request_data, api_key=api_key, api_base=api_base, headers=headers) # type: ignore # Call async method
client = _get_httpx_client()
if client is not None and isinstance(client, HTTPHandler):
client = client
else:
client = _get_httpx_client()
response = client.post(
api_base,
url=api_base,
headers=headers,
json=request_data_dict,
)
@ -100,10 +126,13 @@ class CohereRerank(BaseLLM):
api_key: str,
api_base: str,
headers: dict,
client: Optional[AsyncHTTPHandler] = None,
) -> RerankResponse:
request_data_dict = request_data.dict(exclude_none=True)
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.COHERE)
client = client or get_async_httpx_client(
llm_provider=litellm.LlmProviders.COHERE
)
response = await client.post(
api_base,

View file

@ -3440,6 +3440,10 @@ def embedding( # noqa: PLR0915
or litellm.openai_key
or get_secret_str("OPENAI_API_KEY")
)
if extra_headers is not None:
optional_params["extra_headers"] = extra_headers
api_type = "openai"
api_version = None

View file

@ -2,6 +2,7 @@ import enum
import json
import os
import sys
import traceback
import uuid
from dataclasses import fields
from datetime import datetime
@ -12,7 +13,15 @@ from typing_extensions import Annotated, TypedDict
from litellm.types.integrations.slack_alerting import AlertType
from litellm.types.router import RouterErrors, UpdateRouterConfig
from litellm.types.utils import ProviderField, StandardCallbackDynamicParams
from litellm.types.utils import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
ProviderField,
StandardCallbackDynamicParams,
StandardPassThroughResponseObject,
TextCompletionResponse,
)
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
@ -882,15 +891,7 @@ class DeleteCustomerRequest(LiteLLMBase):
user_ids: List[str]
class Member(LiteLLMBase):
role: Literal[
LitellmUserRoles.ORG_ADMIN,
LitellmUserRoles.INTERNAL_USER,
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
# older Member roles
"admin",
"user",
]
class MemberBase(LiteLLMBase):
user_id: Optional[str] = None
user_email: Optional[str] = None
@ -904,6 +905,21 @@ class Member(LiteLLMBase):
return values
class Member(MemberBase):
role: Literal[
"admin",
"user",
]
class OrgMember(MemberBase):
role: Literal[
LitellmUserRoles.ORG_ADMIN,
LitellmUserRoles.INTERNAL_USER,
LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
]
class TeamBase(LiteLLMBase):
team_alias: Optional[str] = None
team_id: Optional[str] = None
@ -1966,6 +1982,26 @@ class MemberAddRequest(LiteLLMBase):
# Replace member_data with the single Member object
data["member"] = member
# Call the superclass __init__ method to initialize the object
traceback.print_stack()
super().__init__(**data)
class OrgMemberAddRequest(LiteLLMBase):
member: Union[List[OrgMember], OrgMember]
def __init__(self, **data):
member_data = data.get("member")
if isinstance(member_data, list):
# If member is a list of dictionaries, convert each dictionary to a Member object
members = [OrgMember(**item) for item in member_data]
# Replace member_data with the list of Member objects
data["member"] = members
elif isinstance(member_data, dict):
# If member is a dictionary, convert it to a single Member object
member = OrgMember(**member_data)
# Replace member_data with the single Member object
data["member"] = member
# Call the superclass __init__ method to initialize the object
super().__init__(**data)
@ -2017,7 +2053,7 @@ class TeamMemberUpdateResponse(MemberUpdateResponse):
# Organization Member Requests
class OrganizationMemberAddRequest(MemberAddRequest):
class OrganizationMemberAddRequest(OrgMemberAddRequest):
organization_id: str
max_budget_in_organization: Optional[float] = (
None # Users max budget within the organization
@ -2133,3 +2169,17 @@ class UserManagementEndpointParamDocStringEnums(str, enum.Enum):
spend_doc_str = """Optional[float] - Amount spent by user. Default is 0. Will be updated by proxy whenever user is used."""
team_id_doc_str = """Optional[str] - [DEPRECATED PARAM] The team id of the user. Default is None."""
duration_doc_str = """Optional[str] - Duration for the key auto-created on `/user/new`. Default is None."""
PassThroughEndpointLoggingResultValues = Union[
ModelResponse,
TextCompletionResponse,
ImageResponse,
EmbeddingResponse,
StandardPassThroughResponseObject,
]
class PassThroughEndpointLoggingTypedDict(TypedDict):
result: Optional[PassThroughEndpointLoggingResultValues]
kwargs: dict

View file

@ -40,6 +40,77 @@ from litellm.proxy.utils import (
)
from litellm.secret_managers.main import get_secret
def _is_team_key(data: GenerateKeyRequest):
return data.team_id is not None
def _team_key_generation_check(user_api_key_dict: UserAPIKeyAuth):
if (
litellm.key_generation_settings is None
or litellm.key_generation_settings.get("team_key_generation") is None
):
return True
if user_api_key_dict.team_member is None:
raise HTTPException(
status_code=400,
detail=f"User not assigned to team. Got team_member={user_api_key_dict.team_member}",
)
team_member_role = user_api_key_dict.team_member.role
if (
team_member_role
not in litellm.key_generation_settings["team_key_generation"][ # type: ignore
"allowed_team_member_roles"
]
):
raise HTTPException(
status_code=400,
detail=f"Team member role {team_member_role} not in allowed_team_member_roles={litellm.key_generation_settings['team_key_generation']['allowed_team_member_roles']}", # type: ignore
)
return True
def _personal_key_generation_check(user_api_key_dict: UserAPIKeyAuth):
if (
litellm.key_generation_settings is None
or litellm.key_generation_settings.get("personal_key_generation") is None
):
return True
if (
user_api_key_dict.user_role
not in litellm.key_generation_settings["personal_key_generation"][ # type: ignore
"allowed_user_roles"
]
):
raise HTTPException(
status_code=400,
detail=f"Personal key creation has been restricted by admin. Allowed roles={litellm.key_generation_settings['personal_key_generation']['allowed_user_roles']}. Your role={user_api_key_dict.user_role}", # type: ignore
)
return True
def key_generation_check(
user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest
) -> bool:
"""
Check if admin has restricted key creation to certain roles for teams or individuals
"""
if litellm.key_generation_settings is None:
return True
## check if key is for team or individual
is_team_key = _is_team_key(data=data)
if is_team_key:
return _team_key_generation_check(user_api_key_dict)
else:
return _personal_key_generation_check(user_api_key_dict=user_api_key_dict)
router = APIRouter()
@ -131,6 +202,8 @@ async def generate_key_fn( # noqa: PLR0915
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=message
)
elif litellm.key_generation_settings is not None:
key_generation_check(user_api_key_dict=user_api_key_dict, data=data)
# check if user set default key/generate params on config.yaml
if litellm.default_key_generate_params is not None:
for elem in data:

View file

@ -352,7 +352,7 @@ async def organization_member_add(
},
)
members: List[Member]
members: List[OrgMember]
if isinstance(data.member, List):
members = data.member
else:
@ -397,7 +397,7 @@ async def organization_member_add(
async def add_member_to_organization(
member: Member,
member: OrgMember,
organization_id: str,
prisma_client: PrismaClient,
) -> Tuple[LiteLLM_UserTable, LiteLLM_OrganizationMembershipTable]:

View file

@ -14,6 +14,7 @@ from litellm.llms.anthropic.chat.handler import (
ModelResponseIterator as AnthropicModelResponseIterator,
)
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
if TYPE_CHECKING:
from ..success_handler import PassThroughEndpointLogging
@ -26,7 +27,7 @@ else:
class AnthropicPassthroughLoggingHandler:
@staticmethod
async def anthropic_passthrough_handler(
def anthropic_passthrough_handler(
httpx_response: httpx.Response,
response_body: dict,
logging_obj: LiteLLMLoggingObj,
@ -36,7 +37,7 @@ class AnthropicPassthroughLoggingHandler:
end_time: datetime,
cache_hit: bool,
**kwargs,
):
) -> PassThroughEndpointLoggingTypedDict:
"""
Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled
"""
@ -67,15 +68,10 @@ class AnthropicPassthroughLoggingHandler:
logging_obj=logging_obj,
)
await logging_obj.async_success_handler(
result=litellm_model_response,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
pass
return {
"result": litellm_model_response,
"kwargs": kwargs,
}
@staticmethod
def _create_anthropic_response_logging_payload(
@ -123,7 +119,7 @@ class AnthropicPassthroughLoggingHandler:
return kwargs
@staticmethod
async def _handle_logging_anthropic_collected_chunks(
def _handle_logging_anthropic_collected_chunks(
litellm_logging_obj: LiteLLMLoggingObj,
passthrough_success_handler_obj: PassThroughEndpointLogging,
url_route: str,
@ -132,7 +128,7 @@ class AnthropicPassthroughLoggingHandler:
start_time: datetime,
all_chunks: List[str],
end_time: datetime,
):
) -> PassThroughEndpointLoggingTypedDict:
"""
Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks
@ -152,7 +148,10 @@ class AnthropicPassthroughLoggingHandler:
verbose_proxy_logger.error(
"Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..."
)
return
return {
"result": None,
"kwargs": {},
}
kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
litellm_model_response=complete_streaming_response,
model=model,
@ -161,13 +160,11 @@ class AnthropicPassthroughLoggingHandler:
end_time=end_time,
logging_obj=litellm_logging_obj,
)
await litellm_logging_obj.async_success_handler(
result=complete_streaming_response,
start_time=start_time,
end_time=end_time,
cache_hit=False,
**kwargs,
)
return {
"result": complete_streaming_response,
"kwargs": kwargs,
}
@staticmethod
def _build_complete_streaming_response(

View file

@ -14,6 +14,7 @@ from litellm.litellm_core_utils.litellm_logging import (
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
ModelResponseIterator as VertexModelResponseIterator,
)
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
if TYPE_CHECKING:
from ..success_handler import PassThroughEndpointLogging
@ -25,7 +26,7 @@ else:
class VertexPassthroughLoggingHandler:
@staticmethod
async def vertex_passthrough_handler(
def vertex_passthrough_handler(
httpx_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
url_route: str,
@ -34,7 +35,7 @@ class VertexPassthroughLoggingHandler:
end_time: datetime,
cache_hit: bool,
**kwargs,
):
) -> PassThroughEndpointLoggingTypedDict:
if "generateContent" in url_route:
model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
@ -65,13 +66,11 @@ class VertexPassthroughLoggingHandler:
logging_obj=logging_obj,
)
await logging_obj.async_success_handler(
result=litellm_model_response,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
return {
"result": litellm_model_response,
"kwargs": kwargs,
}
elif "predict" in url_route:
from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import (
VertexImageGeneration,
@ -112,16 +111,18 @@ class VertexPassthroughLoggingHandler:
logging_obj.model = model
logging_obj.model_call_details["model"] = logging_obj.model
await logging_obj.async_success_handler(
result=litellm_prediction_response,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
return {
"result": litellm_prediction_response,
"kwargs": kwargs,
}
else:
return {
"result": None,
"kwargs": kwargs,
}
@staticmethod
async def _handle_logging_vertex_collected_chunks(
def _handle_logging_vertex_collected_chunks(
litellm_logging_obj: LiteLLMLoggingObj,
passthrough_success_handler_obj: PassThroughEndpointLogging,
url_route: str,
@ -130,7 +131,7 @@ class VertexPassthroughLoggingHandler:
start_time: datetime,
all_chunks: List[str],
end_time: datetime,
):
) -> PassThroughEndpointLoggingTypedDict:
"""
Takes raw chunks from Vertex passthrough endpoint and logs them in litellm callbacks
@ -152,7 +153,11 @@ class VertexPassthroughLoggingHandler:
verbose_proxy_logger.error(
"Unable to build complete streaming response for Vertex passthrough endpoint, not logging..."
)
return
return {
"result": None,
"kwargs": kwargs,
}
kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
litellm_model_response=complete_streaming_response,
model=model,
@ -161,13 +166,11 @@ class VertexPassthroughLoggingHandler:
end_time=end_time,
logging_obj=litellm_logging_obj,
)
await litellm_logging_obj.async_success_handler(
result=complete_streaming_response,
start_time=start_time,
end_time=end_time,
cache_hit=False,
**kwargs,
)
return {
"result": complete_streaming_response,
"kwargs": kwargs,
}
@staticmethod
def _build_complete_streaming_response(

View file

@ -1,5 +1,6 @@
import asyncio
import json
import threading
from datetime import datetime
from enum import Enum
from typing import AsyncIterable, Dict, List, Optional, Union
@ -15,7 +16,12 @@ from litellm.llms.anthropic.chat.handler import (
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
ModelResponseIterator as VertexAIIterator,
)
from litellm.types.utils import GenericStreamingChunk
from litellm.proxy._types import PassThroughEndpointLoggingResultValues
from litellm.types.utils import (
GenericStreamingChunk,
ModelResponse,
StandardPassThroughResponseObject,
)
from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
AnthropicPassthroughLoggingHandler,
@ -87,8 +93,12 @@ class PassThroughStreamingHandler:
all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(
raw_bytes
)
standard_logging_response_object: Optional[
PassThroughEndpointLoggingResultValues
] = None
kwargs: dict = {}
if endpoint_type == EndpointType.ANTHROPIC:
await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
@ -98,20 +108,48 @@ class PassThroughStreamingHandler:
all_chunks=all_chunks,
end_time=end_time,
)
standard_logging_response_object = anthropic_passthrough_logging_handler_result[
"result"
]
kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
elif endpoint_type == EndpointType.VERTEX_AI:
await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body,
endpoint_type=endpoint_type,
start_time=start_time,
all_chunks=all_chunks,
end_time=end_time,
vertex_passthrough_logging_handler_result = (
VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body,
endpoint_type=endpoint_type,
start_time=start_time,
all_chunks=all_chunks,
end_time=end_time,
)
)
elif endpoint_type == EndpointType.GENERIC:
# No logging is supported for generic streaming endpoints
pass
standard_logging_response_object = vertex_passthrough_logging_handler_result[
"result"
]
kwargs = vertex_passthrough_logging_handler_result["kwargs"]
if standard_logging_response_object is None:
standard_logging_response_object = StandardPassThroughResponseObject(
response=f"cannot parse chunks to standard response object. Chunks={all_chunks}"
)
threading.Thread(
target=litellm_logging_obj.success_handler,
args=(
standard_logging_response_object,
start_time,
end_time,
False,
),
).start()
await litellm_logging_obj.async_success_handler(
result=standard_logging_response_object,
start_time=start_time,
end_time=end_time,
cache_hit=False,
**kwargs,
)
@staticmethod
def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]:

View file

@ -15,6 +15,7 @@ from litellm.litellm_core_utils.litellm_logging import (
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
from litellm.proxy._types import PassThroughEndpointLoggingResultValues
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.utils import StandardPassThroughResponseObject
@ -49,53 +50,69 @@ class PassThroughEndpointLogging:
cache_hit: bool,
**kwargs,
):
standard_logging_response_object: Optional[
PassThroughEndpointLoggingResultValues
] = None
if self.is_vertex_route(url_route):
await VertexPassthroughLoggingHandler.vertex_passthrough_handler(
httpx_response=httpx_response,
logging_obj=logging_obj,
url_route=url_route,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
vertex_passthrough_logging_handler_result = (
VertexPassthroughLoggingHandler.vertex_passthrough_handler(
httpx_response=httpx_response,
logging_obj=logging_obj,
url_route=url_route,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
)
standard_logging_response_object = (
vertex_passthrough_logging_handler_result["result"]
)
kwargs = vertex_passthrough_logging_handler_result["kwargs"]
elif self.is_anthropic_route(url_route):
await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
httpx_response=httpx_response,
response_body=response_body or {},
logging_obj=logging_obj,
url_route=url_route,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
anthropic_passthrough_logging_handler_result = (
AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
httpx_response=httpx_response,
response_body=response_body or {},
logging_obj=logging_obj,
url_route=url_route,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
)
else:
standard_logging_response_object = (
anthropic_passthrough_logging_handler_result["result"]
)
kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
if standard_logging_response_object is None:
standard_logging_response_object = StandardPassThroughResponseObject(
response=httpx_response.text
)
threading.Thread(
target=logging_obj.success_handler,
args=(
standard_logging_response_object,
start_time,
end_time,
cache_hit,
),
).start()
await logging_obj.async_success_handler(
result=(
json.dumps(result)
if isinstance(result, dict)
else standard_logging_response_object
),
start_time=start_time,
end_time=end_time,
cache_hit=False,
**kwargs,
)
threading.Thread(
target=logging_obj.success_handler,
args=(
standard_logging_response_object,
start_time,
end_time,
cache_hit,
),
).start()
await logging_obj.async_success_handler(
result=(
json.dumps(result)
if isinstance(result, dict)
else standard_logging_response_object
),
start_time=start_time,
end_time=end_time,
cache_hit=False,
**kwargs,
)
def is_vertex_route(self, url_route: str):
for route in self.TRACKED_VERTEX_ROUTES:

View file

@ -268,6 +268,7 @@ from litellm.types.llms.anthropic import (
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.router import RouterGeneralSettings
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import get_end_user_id_for_cost_tracking
try:
from litellm._version import version
@ -763,8 +764,7 @@ async def _PROXY_track_cost_callback(
)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request") or {}
end_user_id = proxy_server_request.get("body", {}).get("user", None)
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
user_id = metadata.get("user_api_key_user_id", None)
team_id = metadata.get("user_api_key_team_id", None)

View file

@ -337,14 +337,14 @@ class ProxyLogging:
alert_to_webhook_url=self.alert_to_webhook_url,
)
if (
self.alerting is not None
and "slack" in self.alerting
and "daily_reports" in self.alert_types
):
if self.alerting is not None and "slack" in self.alerting:
# NOTE: ENSURE we only add callbacks when alerting is on
# We should NOT add callbacks when alerting is off
litellm.callbacks.append(self.slack_alerting_instance) # type: ignore
if "daily_reports" in self.alert_types:
litellm.callbacks.append(self.slack_alerting_instance) # type: ignore
litellm.success_callback.append(
self.slack_alerting_instance.response_taking_too_long_callback
)
if redis_cache is not None:
self.internal_usage_cache.dual_cache.redis_cache = redis_cache
@ -354,9 +354,6 @@ class ProxyLogging:
litellm.callbacks.append(self.max_budget_limiter) # type: ignore
litellm.callbacks.append(self.cache_control_check) # type: ignore
litellm.callbacks.append(self.service_logging_obj) # type: ignore
litellm.success_callback.append(
self.slack_alerting_instance.response_taking_too_long_callback
)
for callback in litellm.callbacks:
if isinstance(callback, str):
callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore

View file

@ -91,6 +91,7 @@ def rerank(
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
user = kwargs.get("user", None)
client = kwargs.get("client", None)
try:
_is_async = kwargs.pop("arerank", False) is True
optional_params = GenericLiteLLMParams(**kwargs)
@ -150,7 +151,7 @@ def rerank(
or optional_params.api_base
or litellm.api_base
or get_secret("COHERE_API_BASE") # type: ignore
or "https://api.cohere.com/v1/rerank"
or "https://api.cohere.com"
)
if api_base is None:
@ -173,6 +174,7 @@ def rerank(
_is_async=_is_async,
headers=headers,
litellm_logging_obj=litellm_logging_obj,
client=client,
)
elif _custom_llm_provider == "azure_ai":
api_base = (

View file

@ -1602,3 +1602,16 @@ class StandardCallbackDynamicParams(TypedDict, total=False):
langsmith_api_key: Optional[str]
langsmith_project: Optional[str]
langsmith_base_url: Optional[str]
class TeamUIKeyGenerationConfig(TypedDict):
allowed_team_member_roles: List[str]
class PersonalUIKeyGenerationConfig(TypedDict):
allowed_user_roles: List[str]
class StandardKeyGenerationConfig(TypedDict, total=False):
team_key_generation: TeamUIKeyGenerationConfig
personal_key_generation: PersonalUIKeyGenerationConfig

View file

@ -6170,3 +6170,13 @@ class ProviderConfigManager:
return litellm.GroqChatConfig()
return OpenAIGPTConfig()
def get_end_user_id_for_cost_tracking(litellm_params: dict) -> Optional[str]:
"""
Used for enforcing `disable_end_user_cost_tracking` param.
"""
proxy_server_request = litellm_params.get("proxy_server_request") or {}
if litellm.disable_end_user_cost_tracking:
return None
return proxy_server_request.get("body", {}).get("user", None)

View file

@ -1080,3 +1080,34 @@ def test_cohere_img_embeddings(input, input_type):
assert response.usage.prompt_tokens_details.image_tokens > 0
else:
assert response.usage.prompt_tokens_details.text_tokens > 0
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_embedding_with_extra_headers(sync_mode):
input = ["hello world"]
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
if sync_mode:
client = HTTPHandler()
else:
client = AsyncHTTPHandler()
data = {
"model": "cohere/embed-english-v3.0",
"input": input,
"extra_headers": {"my-test-param": "hello-world"},
"client": client,
}
with patch.object(client, "post") as mock_post:
try:
if sync_mode:
embedding(**data)
else:
await litellm.aembedding(**data)
except Exception as e:
print(e)
mock_post.assert_called_once()
assert "my-test-param" in mock_post.call_args.kwargs["headers"]

View file

@ -215,7 +215,10 @@ async def test_rerank_custom_api_base():
args_to_api = kwargs["json"]
print("Arguments passed to API=", args_to_api)
print("url = ", _url)
assert _url[0] == "https://exampleopenaiendpoint-production.up.railway.app/"
assert (
_url[0]
== "https://exampleopenaiendpoint-production.up.railway.app/v1/rerank"
)
assert args_to_api == expected_payload
assert response.id is not None
assert response.results is not None
@ -258,3 +261,32 @@ async def test_rerank_custom_callbacks():
assert custom_logger.kwargs.get("response_cost") > 0.0
assert custom_logger.response_obj is not None
assert custom_logger.response_obj.results is not None
def test_complete_base_url_cohere():
from litellm.llms.custom_httpx.http_handler import HTTPHandler
client = HTTPHandler()
litellm.api_base = "http://localhost:4000"
litellm.set_verbose = True
text = "Hello there!"
list_texts = ["Hello there!", "How are you?", "How do you do?"]
rerank_model = "rerank-multilingual-v3.0"
with patch.object(client, "post") as mock_post:
try:
litellm.rerank(
model=rerank_model,
query=text,
documents=list_texts,
custom_llm_provider="cohere",
client=client,
)
except Exception as e:
print(e)
print("mock_post.call_args", mock_post.call_args)
mock_post.assert_called_once()
assert "http://localhost:4000/v1/rerank" in mock_post.call_args.kwargs["url"]

View file

@ -1012,3 +1012,23 @@ def test_models_by_provider():
for provider in providers:
assert provider in models_by_provider.keys()
@pytest.mark.parametrize(
"litellm_params, disable_end_user_cost_tracking, expected_end_user_id",
[
({}, False, None),
({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"),
({"proxy_server_request": {"body": {"user": "123"}}}, True, None),
],
)
def test_get_end_user_id_for_cost_tracking(
litellm_params, disable_end_user_cost_tracking, expected_end_user_id
):
from litellm.utils import get_end_user_id_for_cost_tracking
litellm.disable_end_user_cost_tracking = disable_end_user_cost_tracking
assert (
get_end_user_id_for_cost_tracking(litellm_params=litellm_params)
== expected_end_user_id
)

View file

@ -216,3 +216,78 @@ async def test_init_custom_logger_compatible_class_as_callback():
await use_callback_in_llm_call(callback, used_in="success_callback")
reset_env_vars()
def test_dynamic_logging_global_callback():
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import ModelResponse, Choices, Message, Usage
cl = CustomLogger()
litellm_logging = LiteLLMLoggingObj(
model="claude-3-opus-20240229",
messages=[{"role": "user", "content": "hi"}],
stream=False,
call_type="completion",
start_time=datetime.now(),
litellm_call_id="123",
function_id="456",
kwargs={
"langfuse_public_key": "my-mock-public-key",
"langfuse_secret_key": "my-mock-secret-key",
},
dynamic_success_callbacks=["langfuse"],
)
with patch.object(cl, "log_success_event") as mock_log_success_event:
cl.log_success_event = mock_log_success_event
litellm.success_callback = [cl]
try:
litellm_logging.success_handler(
result=ModelResponse(
id="chatcmpl-5418737b-ab14-420b-b9c5-b278b6681b70",
created=1732306261,
model="claude-3-opus-20240229",
object="chat.completion",
system_fingerprint=None,
choices=[
Choices(
finish_reason="stop",
index=0,
message=Message(
content="hello",
role="assistant",
tool_calls=None,
function_call=None,
),
)
],
usage=Usage(
completion_tokens=20,
prompt_tokens=10,
total_tokens=30,
completion_tokens_details=None,
prompt_tokens_details=None,
),
),
start_time=datetime.now(),
end_time=datetime.now(),
cache_hit=False,
)
except Exception as e:
print(f"Error: {e}")
mock_log_success_event.assert_called_once()
def test_get_combined_callback_list():
from litellm.litellm_core_utils.litellm_logging import get_combined_callback_list
assert "langfuse" in get_combined_callback_list(
dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"]
)
assert "lago" in get_combined_callback_list(
dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"]
)

View file

@ -73,7 +73,7 @@ async def test_anthropic_passthrough_handler(
start_time = datetime.now()
end_time = datetime.now()
await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
result = AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
httpx_response=mock_httpx_response,
response_body=mock_response,
logging_obj=mock_logging_obj,
@ -84,30 +84,7 @@ async def test_anthropic_passthrough_handler(
cache_hit=False,
)
# Assert that async_success_handler was called
assert mock_logging_obj.async_success_handler.called
call_args = mock_logging_obj.async_success_handler.call_args
call_kwargs = call_args.kwargs
print("call_kwargs", call_kwargs)
# Assert required fields are present in call_kwargs
assert "result" in call_kwargs
assert "start_time" in call_kwargs
assert "end_time" in call_kwargs
assert "cache_hit" in call_kwargs
assert "response_cost" in call_kwargs
assert "model" in call_kwargs
assert "standard_logging_object" in call_kwargs
# Assert specific values and types
assert isinstance(call_kwargs["result"], litellm.ModelResponse)
assert isinstance(call_kwargs["start_time"], datetime)
assert isinstance(call_kwargs["end_time"], datetime)
assert isinstance(call_kwargs["cache_hit"], bool)
assert isinstance(call_kwargs["response_cost"], float)
assert call_kwargs["model"] == "claude-3-opus-20240229"
assert isinstance(call_kwargs["standard_logging_object"], dict)
assert isinstance(result["result"], litellm.ModelResponse)
def test_create_anthropic_response_logging_payload(mock_logging_obj):

View file

@ -64,6 +64,7 @@ async def test_chunk_processor_yields_raw_bytes(endpoint_type, url_route):
litellm_logging_obj = MagicMock()
start_time = datetime.now()
passthrough_success_handler_obj = MagicMock()
litellm_logging_obj.async_success_handler = AsyncMock()
# Capture yielded chunks and perform detailed assertions
received_chunks = []

View file

@ -0,0 +1,54 @@
# conftest.py
import importlib
import os
import sys
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown():
"""
This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.
"""
curr_dir = os.getcwd() # Get the current working directory
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the project directory to the system path
import litellm
from litellm import Router
importlib.reload(litellm)
import asyncio
loop = asyncio.get_event_loop_policy().new_event_loop()
asyncio.set_event_loop(loop)
print(litellm)
# from litellm import Router, completion, aembedding, acompletion, embedding
yield
# Teardown code (executes after the yield point)
loop.close() # Close the loop created earlier
asyncio.set_event_loop(None) # Remove the reference to the loop
def pytest_collection_modifyitems(config, items):
# Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
custom_logger_tests = [
item for item in items if "custom_logger" in item.parent.name
]
other_tests = [item for item in items if "custom_logger" not in item.parent.name]
# Sort tests based on their names
custom_logger_tests.sort(key=lambda x: x.name)
other_tests.sort(key=lambda x: x.name)
# Reorder the items list
items[:] = custom_logger_tests + other_tests

View file

@ -542,3 +542,65 @@ async def test_list_teams(prisma_client):
# Clean up
await prisma_client.delete_data(team_id_list=[team_id], table_name="team")
def test_is_team_key():
from litellm.proxy.management_endpoints.key_management_endpoints import _is_team_key
assert _is_team_key(GenerateKeyRequest(team_id="test_team_id"))
assert not _is_team_key(GenerateKeyRequest(user_id="test_user_id"))
def test_team_key_generation_check():
from litellm.proxy.management_endpoints.key_management_endpoints import (
_team_key_generation_check,
)
from fastapi import HTTPException
litellm.key_generation_settings = {
"team_key_generation": {"allowed_team_member_roles": ["admin"]}
}
assert _team_key_generation_check(
UserAPIKeyAuth(
user_role=LitellmUserRoles.INTERNAL_USER,
api_key="sk-1234",
team_member=Member(role="admin", user_id="test_user_id"),
)
)
with pytest.raises(HTTPException):
_team_key_generation_check(
UserAPIKeyAuth(
user_role=LitellmUserRoles.INTERNAL_USER,
api_key="sk-1234",
user_id="test_user_id",
team_member=Member(role="user", user_id="test_user_id"),
)
)
def test_personal_key_generation_check():
from litellm.proxy.management_endpoints.key_management_endpoints import (
_personal_key_generation_check,
)
from fastapi import HTTPException
litellm.key_generation_settings = {
"personal_key_generation": {"allowed_user_roles": ["proxy_admin"]}
}
assert _personal_key_generation_check(
UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin"
)
)
with pytest.raises(HTTPException):
_personal_key_generation_check(
UserAPIKeyAuth(
user_role=LitellmUserRoles.INTERNAL_USER,
api_key="sk-1234",
user_id="admin",
)
)

View file

@ -160,7 +160,7 @@ async def test_create_new_user_in_organization(prisma_client, user_role):
response = await organization_member_add(
data=OrganizationMemberAddRequest(
organization_id=org_id,
member=Member(role=user_role, user_id=created_user_id),
member=OrgMember(role=user_role, user_id=created_user_id),
),
http_request=None,
)
@ -220,7 +220,7 @@ async def test_org_admin_create_team_permissions(prisma_client):
response = await organization_member_add(
data=OrganizationMemberAddRequest(
organization_id=org_id,
member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
),
http_request=None,
)
@ -292,7 +292,7 @@ async def test_org_admin_create_user_permissions(prisma_client):
response = await organization_member_add(
data=OrganizationMemberAddRequest(
organization_id=org_id,
member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
),
http_request=None,
)
@ -323,7 +323,7 @@ async def test_org_admin_create_user_permissions(prisma_client):
response = await organization_member_add(
data=OrganizationMemberAddRequest(
organization_id=org_id,
member=Member(
member=OrgMember(
role=LitellmUserRoles.INTERNAL_USER, user_id=new_internal_user_for_org
),
),
@ -375,7 +375,7 @@ async def test_org_admin_create_user_team_wrong_org_permissions(prisma_client):
response = await organization_member_add(
data=OrganizationMemberAddRequest(
organization_id=org1_id,
member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
),
http_request=None,
)