Compare commits


21 commits

Author SHA1 Message Date
Krrish Dholakia
b06b0248ff fix: fix pass through testing 2024-11-23 14:11:54 +05:30
Krish Dholakia
4021206ac2 Merge branch 'main' into litellm_dev_11_22_2024 2024-11-23 13:52:17 +05:30
Krrish Dholakia
88db948d29 test: fix test 2024-11-23 12:35:55 +05:30
Krrish Dholakia
b5c5f87e25 fix: fix linting errors 2024-11-23 03:12:59 +05:30
Krrish Dholakia
ed64dd7f9d test: fix test 2024-11-23 03:06:51 +05:30
Krrish Dholakia
e2ccf6b7f4 build: add conftest to proxy_admin_ui_tests/ 2024-11-23 02:44:10 +05:30
Krrish Dholakia
f8a46b5950 fix: fix typing 2024-11-23 02:36:09 +05:30
Krrish Dholakia
31943bf2ad fix(handler.py): fix linting error 2024-11-23 02:08:50 +05:30
Krrish Dholakia
541326731f fix(litellm_logging.py): don't disable global callbacks when dynamic callbacks are set
Fixes issue where global callbacks - e.g. prometheus - were overridden when langfuse was set dynamically
2024-11-23 02:00:45 +05:30
Krrish Dholakia
dfb34dfe92 fix(main.py): pass extra_headers param to openai
Fixes https://github.com/BerriAI/litellm/issues/6836
2024-11-23 01:23:03 +05:30
Krrish Dholakia
250d66b335 feat(rerank.py): add /v1/rerank if missing for cohere base url
Closes https://github.com/BerriAI/litellm/issues/6844
2024-11-23 01:07:39 +05:30
Krrish Dholakia
94fe135524 feat(cohere/embed): pass client to async embedding 2024-11-23 00:47:26 +05:30
Krrish Dholakia
1a3fb18a64 test(test_embedding.py): add test for passing extra headers via embedding 2024-11-23 00:32:40 +05:30
Krrish Dholakia
d788c3c37f docs(input.md): cleanup anthropic supported params
Closes https://github.com/BerriAI/litellm/issues/6856
2024-11-23 00:21:04 +05:30
Krrish Dholakia
4beb48829c docs(finetuned_models.md): add new guide on calling finetuned models 2024-11-23 00:08:03 +05:30
Krrish Dholakia
eb0a357eda docs: add docs on restricting key creation 2024-11-22 23:11:58 +05:30
Krrish Dholakia
463fa0c9d5 test(test_key_management.py): add unit testing for personal / team key restriction checks 2024-11-22 23:03:38 +05:30
Krrish Dholakia
1014216d73 feat(key_management_endpoints.py): add support for restricting access to /key/generate by team/proxy level role
Enables admin to restrict key creation, and assign team admins to handle distributing keys
2024-11-22 22:59:01 +05:30
Krrish Dholakia
97d8aa0b3a docs(configs.md): add disable_end_user_cost_tracking reference to docs 2024-11-22 16:43:35 +05:30
Krrish Dholakia
5a698c678a fix(utils.py): allow disabling end user cost tracking with new param
Allows proxy admin to disable cost tracking for end user - keeps prometheus metrics small
2024-11-22 16:41:58 +05:30
Krrish Dholakia
1c9a8c0b68 feat(pass_through_endpoints/): support logging anthropic/gemini pass through calls to langfuse/s3/etc. 2024-11-22 16:21:57 +05:30
35 changed files with 871 additions and 248 deletions

View file

@@ -41,7 +41,7 @@ Use `litellm.get_supported_openai_params()` for an updated list of params for ea
| Provider | temperature | max_completion_tokens | max_tokens | top_p | stream | stream_options | stop | n | presence_penalty | frequency_penalty | functions | function_call | logit_bias | user | response_format | seed | tools | tool_choice | logprobs | top_logprobs | extra_headers |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|Anthropic| ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | | | | | | |✅ | ✅ | | ✅ | ✅ | | | ✅ |
|OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | ✅ |
|Azure OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | | | ✅ |
|Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
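
The same information is available programmatically; a minimal sketch (the model name and exact return value are only illustrative):

```python
# Query supported OpenAI-compatible params for a provider/model instead of
# reading the table above. Model name here is an example, not part of the diff.
import litellm

supported_params = litellm.get_supported_openai_params(
    model="claude-3-5-sonnet-20240620", custom_llm_provider="anthropic"
)
print(supported_params)  # e.g. ["stream", "stop", "temperature", ...]
```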

View file

@@ -0,0 +1,74 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# Calling Finetuned Models

## OpenAI

| Model Name | Function Call |
|---------------------------|-----------------------------------------------------------------|
| fine tuned `gpt-4-0613` | `response = completion(model="ft:gpt-4-0613", messages=messages)` |
| fine tuned `gpt-4o-2024-05-13` | `response = completion(model="ft:gpt-4o-2024-05-13", messages=messages)` |
| fine tuned `gpt-3.5-turbo-0125` | `response = completion(model="ft:gpt-3.5-turbo-0125", messages=messages)` |
| fine tuned `gpt-3.5-turbo-1106` | `response = completion(model="ft:gpt-3.5-turbo-1106", messages=messages)` |
| fine tuned `gpt-3.5-turbo-0613` | `response = completion(model="ft:gpt-3.5-turbo-0613", messages=messages)` |

## Vertex AI

Fine tuned models on Vertex have a numerical model/endpoint id.

<Tabs>
<TabItem value="sdk" label="SDK">

```python
from litellm import completion
import os

## set ENV variables
os.environ["VERTEXAI_PROJECT"] = "hardy-device-38811"
os.environ["VERTEXAI_LOCATION"] = "us-central1"

response = completion(
  model="vertex_ai/<your-finetuned-model>",  # e.g. vertex_ai/4965075652664360960
  messages=[{ "content": "Hello, how are you?","role": "user"}],
  base_model="vertex_ai/gemini-1.5-pro" # the base model - used for routing
)
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Add Vertex credentials to your env

```bash
gcloud auth application-default login
```

2. Setup config.yaml

```yaml
- model_name: finetuned-gemini
  litellm_params:
    model: vertex_ai/<ENDPOINT_ID>
    vertex_project: <PROJECT_ID>
    vertex_location: <LOCATION>
  model_info:
    base_model: vertex_ai/gemini-1.5-pro # IMPORTANT
```

3. Test it!

```bash
curl --location 'https://0.0.0.0:4000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: <LITELLM_KEY>' \
--data '{"model": "finetuned-gemini" ,"messages":[{"role": "user", "content":[{"type": "text", "text": "hi"}]}]}'
```

</TabItem>
</Tabs>
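
Once the proxy alias above is configured, any OpenAI-compatible client can call it. A minimal sketch, assuming the proxy runs locally on port 4000 and `LITELLM_KEY` holds a valid virtual key (both are assumptions, not part of the diff):

```python
# Call the proxy's OpenAI-compatible endpoint for the `finetuned-gemini`
# alias defined in config.yaml above.
import os
from openai import OpenAI

client = OpenAI(
    api_key=os.environ["LITELLM_KEY"],  # virtual key issued by the proxy (assumed)
    base_url="http://0.0.0.0:4000",     # LiteLLM proxy address (assumed)
)

response = client.chat.completions.create(
    model="finetuned-gemini",
    messages=[{"role": "user", "content": "hi"}],
)
print(response.choices[0].message.content)
```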

View file

@@ -754,6 +754,8 @@ general_settings:
| cache_params.s3_endpoint_url | string | Optional - The endpoint URL for the S3 bucket. |
| cache_params.supported_call_types | array of strings | The types of calls to cache. [Further docs](./caching) |
| cache_params.mode | string | The mode of the cache. [Further docs](./caching) |
+| disable_end_user_cost_tracking | boolean | If true, turns off end user cost tracking on prometheus metrics + litellm spend logs table on proxy. |
+| key_generation_settings | object | Restricts who can generate keys. [Further docs](./virtual_keys.md#restricting-key-generation) |

### general_settings - Reference

View file

@@ -217,4 +217,10 @@ litellm_settings:
  max_parallel_requests: 1000 # (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None.
  tpm_limit: 1000 #(Optional[int], optional): Tpm limit. Defaults to None.
  rpm_limit: 1000 #(Optional[int], optional): Rpm limit. Defaults to None.
+  key_generation_settings: # Restricts who can generate keys. [Further docs](./virtual_keys.md#restricting-key-generation)
+    team_key_generation:
+      allowed_team_member_roles: ["admin"]
+    personal_key_generation: # maps to 'Default Team' on UI
+      allowed_user_roles: ["proxy_admin"]
```

View file

@@ -811,6 +811,75 @@ litellm_settings:
      team_id: "core-infra"
```

### Restricting Key Generation

Use this to control who can generate keys. Useful when letting others create keys on the UI.

```yaml
litellm_settings:
  key_generation_settings:
    team_key_generation:
      allowed_team_member_roles: ["admin"]
    personal_key_generation: # maps to 'Default Team' on UI
      allowed_user_roles: ["proxy_admin"]
```

#### Spec

```python
class TeamUIKeyGenerationConfig(TypedDict):
    allowed_team_member_roles: List[str]


class PersonalUIKeyGenerationConfig(TypedDict):
    allowed_user_roles: List[LitellmUserRoles]


class StandardKeyGenerationConfig(TypedDict, total=False):
    team_key_generation: TeamUIKeyGenerationConfig
    personal_key_generation: PersonalUIKeyGenerationConfig


class LitellmUserRoles(str, enum.Enum):
    """
    Admin Roles:
    PROXY_ADMIN: admin over the platform
    PROXY_ADMIN_VIEW_ONLY: can login, view all own keys, view all spend

    ORG_ADMIN: admin over a specific organization, can create teams, users only within their organization

    Internal User Roles:
    INTERNAL_USER: can login, view/create/delete their own keys, view their spend
    INTERNAL_USER_VIEW_ONLY: can login, view their own keys, view their own spend

    Team Roles:
    TEAM: used for JWT auth

    Customer Roles:
    CUSTOMER: External users -> these are customers
    """

    # Admin Roles
    PROXY_ADMIN = "proxy_admin"
    PROXY_ADMIN_VIEW_ONLY = "proxy_admin_viewer"

    # Organization admins
    ORG_ADMIN = "org_admin"

    # Internal User Roles
    INTERNAL_USER = "internal_user"
    INTERNAL_USER_VIEW_ONLY = "internal_user_viewer"

    # Team Roles
    TEAM = "team"

    # Customer Roles - External users of proxy
    CUSTOMER = "customer"
```

## **Next Steps - Set Budgets, Rate Limits per Virtual Key**

[Follow this doc to set budgets, rate limiters per virtual key with LiteLLM](users)
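
For reference, a hedged sketch of what these restrictions mean for callers of the proxy's `/key/generate` endpoint (the proxy URL, keys, and team id below are placeholders, not values from the diff):

```python
# With the settings above, team keys can only be created by team admins,
# and personal keys (no team_id) only by proxy admins.
import requests

PROXY_BASE = "http://0.0.0.0:4000"              # assumption: local proxy
headers = {"Authorization": "Bearer sk-...."}   # caller's virtual/admin key (placeholder)

# Team key request - allowed when the caller's team role is in allowed_team_member_roles
resp = requests.post(
    f"{PROXY_BASE}/key/generate",
    headers=headers,
    json={"team_id": "core-infra"},
)
print(resp.status_code, resp.json())

# Personal key request (no team_id) - rejected unless the caller's user role
# is in allowed_user_roles (here: proxy_admin)
resp = requests.post(f"{PROXY_BASE}/key/generate", headers=headers, json={})
print(resp.status_code, resp.json())
```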

View file

@@ -199,6 +199,31 @@ const sidebars = {
      ],
    },
+   {
+     type: "category",
+     label: "Guides",
+     items: [
+       "exception_mapping",
+       "completion/provider_specific_params",
+       "guides/finetuned_models",
+       "completion/audio",
+       "completion/vision",
+       "completion/json_mode",
+       "completion/prompt_caching",
+       "completion/predict_outputs",
+       "completion/prefix",
+       "completion/drop_params",
+       "completion/prompt_formatting",
+       "completion/stream",
+       "completion/message_trimming",
+       "completion/function_call",
+       "completion/model_alias",
+       "completion/batching",
+       "completion/mock_requests",
+       "completion/reliable_completions",
+     ]
+   },
    {
      type: "category",
      label: "Supported Endpoints",
@@ -214,25 +239,8 @@ const sidebars = {
      },
      items: [
        "completion/input",
-       "completion/provider_specific_params",
-       "completion/json_mode",
-       "completion/prompt_caching",
-       "completion/audio",
-       "completion/vision",
-       "completion/predict_outputs",
-       "completion/prefix",
-       "completion/drop_params",
-       "completion/prompt_formatting",
        "completion/output",
        "completion/usage",
-       "exception_mapping",
-       "completion/stream",
-       "completion/message_trimming",
-       "completion/function_call",
-       "completion/model_alias",
-       "completion/batching",
-       "completion/mock_requests",
-       "completion/reliable_completions",
      ],
    },
    "embedding/supported_embedding",

View file

@@ -24,6 +24,7 @@ from litellm.proxy._types import (
    KeyManagementSettings,
    LiteLLM_UpperboundKeyGenerateParams,
)
+from litellm.types.utils import StandardKeyGenerationConfig
import httpx
import dotenv
from enum import Enum
@@ -273,6 +274,7 @@ s3_callback_params: Optional[Dict] = None
generic_logger_headers: Optional[Dict] = None
default_key_generate_params: Optional[Dict] = None
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
+key_generation_settings: Optional[StandardKeyGenerationConfig] = None
default_internal_user_params: Optional[Dict] = None
default_team_settings: Optional[List] = None
max_user_budget: Optional[float] = None
@@ -280,6 +282,7 @@ default_max_internal_user_budget: Optional[float] = None
max_internal_user_budget: Optional[float] = None
internal_user_budget_duration: Optional[str] = None
max_end_user_budget: Optional[float] = None
+disable_end_user_cost_tracking: Optional[bool] = None
#### REQUEST PRIORITIZATION ####
priority_reservation: Optional[Dict[str, float]] = None
#### RELIABILITY ####

View file

@@ -18,6 +18,7 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.integrations.prometheus import *
from litellm.types.utils import StandardLoggingPayload
+from litellm.utils import get_end_user_id_for_cost_tracking


class PrometheusLogger(CustomLogger):
@@ -364,8 +365,7 @@ class PrometheusLogger(CustomLogger):
        model = kwargs.get("model", "")
        litellm_params = kwargs.get("litellm_params", {}) or {}
        _metadata = litellm_params.get("metadata", {})
-       proxy_server_request = litellm_params.get("proxy_server_request") or {}
-       end_user_id = proxy_server_request.get("body", {}).get("user", None)
+       end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
        user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
        user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
        user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
@@ -664,13 +664,11 @@ class PrometheusLogger(CustomLogger):
        # unpack kwargs
        model = kwargs.get("model", "")
-       litellm_params = kwargs.get("litellm_params", {}) or {}
        standard_logging_payload: StandardLoggingPayload = kwargs.get(
            "standard_logging_object", {}
        )
-       proxy_server_request = litellm_params.get("proxy_server_request") or {}
-       end_user_id = proxy_server_request.get("body", {}).get("user", None)
+       litellm_params = kwargs.get("litellm_params", {}) or {}
+       end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
        user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
        user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
        user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]

View file

@@ -934,19 +934,10 @@ class Logging:
                    status="success",
                )
            )
-           if self.dynamic_success_callbacks is not None and isinstance(
-               self.dynamic_success_callbacks, list
-           ):
-               callbacks = self.dynamic_success_callbacks
-               ## keep the internal functions ##
-               for callback in litellm.success_callback:
-                   if (
-                       isinstance(callback, CustomLogger)
-                       and "_PROXY_" in callback.__class__.__name__
-                   ):
-                       callbacks.append(callback)
-           else:
-               callbacks = litellm.success_callback
+           callbacks = get_combined_callback_list(
+               dynamic_success_callbacks=self.dynamic_success_callbacks,
+               global_callbacks=litellm.success_callback,
+           )

            ## REDACT MESSAGES ##
            result = redact_message_input_output_from_logging(
@@ -1368,8 +1359,11 @@
                        and customLogger is not None
                    ):  # custom logger functions
                        print_verbose(
-                           "success callbacks: Running Custom Callback Function"
+                           "success callbacks: Running Custom Callback Function - {}".format(
+                               callback
+                           )
                        )
                        customLogger.log_event(
                            kwargs=self.model_call_details,
                            response_obj=result,
@@ -1466,21 +1460,10 @@
                    status="success",
                )
            )
-           if self.dynamic_async_success_callbacks is not None and isinstance(
-               self.dynamic_async_success_callbacks, list
-           ):
-               callbacks = self.dynamic_async_success_callbacks
-               ## keep the internal functions ##
-               for callback in litellm._async_success_callback:
-                   callback_name = ""
-                   if isinstance(callback, CustomLogger):
-                       callback_name = callback.__class__.__name__
-                   if callable(callback):
-                       callback_name = callback.__name__
-                   if "_PROXY_" in callback_name:
-                       callbacks.append(callback)
-           else:
-               callbacks = litellm._async_success_callback
+           callbacks = get_combined_callback_list(
+               dynamic_success_callbacks=self.dynamic_async_success_callbacks,
+               global_callbacks=litellm._async_success_callback,
+           )

            result = redact_message_input_output_from_logging(
                model_call_details=(
@@ -1747,21 +1730,10 @@
                    start_time=start_time,
                    end_time=end_time,
                )
-           callbacks = []  # init this to empty incase it's not created
-
-           if self.dynamic_failure_callbacks is not None and isinstance(
-               self.dynamic_failure_callbacks, list
-           ):
-               callbacks = self.dynamic_failure_callbacks
-               ## keep the internal functions ##
-               for callback in litellm.failure_callback:
-                   if (
-                       isinstance(callback, CustomLogger)
-                       and "_PROXY_" in callback.__class__.__name__
-                   ):
-                       callbacks.append(callback)
-           else:
-               callbacks = litellm.failure_callback
+           callbacks = get_combined_callback_list(
+               dynamic_success_callbacks=self.dynamic_failure_callbacks,
+               global_callbacks=litellm.failure_callback,
+           )

            result = None  # result sent to all loggers, init this to None incase it's not created
@@ -1944,21 +1916,10 @@
                    end_time=end_time,
                )
-           callbacks = []  # init this to empty incase it's not created
-
-           if self.dynamic_async_failure_callbacks is not None and isinstance(
-               self.dynamic_async_failure_callbacks, list
-           ):
-               callbacks = self.dynamic_async_failure_callbacks
-               ## keep the internal functions ##
-               for callback in litellm._async_failure_callback:
-                   if (
-                       isinstance(callback, CustomLogger)
-                       and "_PROXY_" in callback.__class__.__name__
-                   ):
-                       callbacks.append(callback)
-           else:
-               callbacks = litellm._async_failure_callback
+           callbacks = get_combined_callback_list(
+               dynamic_success_callbacks=self.dynamic_async_failure_callbacks,
+               global_callbacks=litellm._async_failure_callback,
+           )

            result = None  # result sent to all loggers, init this to None incase it's not created

            for callback in callbacks:
@@ -2359,6 +2320,7 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
        _in_memory_loggers.append(_mlflow_logger)
        return _mlflow_logger  # type: ignore


def get_custom_logger_compatible_class(
    logging_integration: litellm._custom_logger_compatible_callbacks_literal,
) -> Optional[CustomLogger]:
@@ -2949,3 +2911,11 @@ def modify_integration(integration_name, integration_params):
    if integration_name == "supabase":
        if "table_name" in integration_params:
            Supabase.supabase_table_name = integration_params["table_name"]


def get_combined_callback_list(
    dynamic_success_callbacks: Optional[List], global_callbacks: List
) -> List:
    if dynamic_success_callbacks is None:
        return global_callbacks
    return list(set(dynamic_success_callbacks + global_callbacks))
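
A quick illustration of the new helper's behavior (the function body is copied from the hunk above; the callback names are illustrative, and note the `set()` union deduplicates but does not preserve order):

```python
from typing import List, Optional


def get_combined_callback_list(
    dynamic_success_callbacks: Optional[List], global_callbacks: List
) -> List:
    if dynamic_success_callbacks is None:
        return global_callbacks
    return list(set(dynamic_success_callbacks + global_callbacks))


# Dynamic callbacks (e.g. Langfuse set per-request) no longer replace globally
# configured callbacks such as "prometheus" - both are kept.
print(get_combined_callback_list(["langfuse"], ["prometheus"]))    # both entries, order not guaranteed
print(get_combined_callback_list(None, ["prometheus"]))            # ['prometheus']
print(get_combined_callback_list(["prometheus"], ["prometheus"]))  # ['prometheus'] (deduplicated)
```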

View file

@@ -4,6 +4,7 @@ import httpx

from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.cohere.rerank import CohereRerank
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.rerank import RerankResponse
@@ -73,6 +74,7 @@ class AzureAIRerank(CohereRerank):
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        _is_async: Optional[bool] = False,
+       client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ) -> RerankResponse:

        if headers is None:

View file

@@ -74,6 +74,7 @@ async def async_embedding(
        },
    )
    ## COMPLETION CALL
    if client is None:
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.COHERE,
@@ -151,6 +152,11 @@ def embedding(
            api_key=api_key,
            headers=headers,
            encoding=encoding,
+           client=(
+               client
+               if client is not None and isinstance(client, AsyncHTTPHandler)
+               else None
+           ),
        )

    ## LOGGING

View file

@@ -6,10 +6,14 @@ LiteLLM supports the re rank API format, no paramter transformation occurs

from typing import Any, Dict, List, Optional, Union

+import httpx
+
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import (
+   AsyncHTTPHandler,
+   HTTPHandler,
    _get_httpx_client,
    get_async_httpx_client,
)
@@ -34,6 +38,23 @@ class CohereRerank(BaseLLM):
        # Merge other headers, overriding any default ones except Authorization
        return {**default_headers, **headers}

+   def ensure_rerank_endpoint(self, api_base: str) -> str:
+       """
+       Ensures the `/v1/rerank` endpoint is appended to the given `api_base`.
+       If `/v1/rerank` is already present, the original URL is returned.
+
+       :param api_base: The base API URL.
+       :return: A URL with `/v1/rerank` appended if missing.
+       """
+       # Parse the base URL to ensure proper structure
+       url = httpx.URL(api_base)
+
+       # Check if the URL already ends with `/v1/rerank`
+       if not url.path.endswith("/v1/rerank"):
+           url = url.copy_with(path=f"{url.path.rstrip('/')}/v1/rerank")
+
+       return str(url)
+
    def rerank(
        self,
        model: str,
@@ -48,9 +69,10 @@ class CohereRerank(BaseLLM):
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        _is_async: Optional[bool] = False,  # New parameter
+       client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ) -> RerankResponse:
        headers = self.validate_environment(api_key=api_key, headers=headers)
+       api_base = self.ensure_rerank_endpoint(api_base)
        request_data = RerankRequest(
            model=model,
            query=query,
@@ -76,9 +98,13 @@ class CohereRerank(BaseLLM):
        if _is_async:
            return self.async_rerank(request_data=request_data, api_key=api_key, api_base=api_base, headers=headers)  # type: ignore # Call async method

-       client = _get_httpx_client()
+       if client is not None and isinstance(client, HTTPHandler):
+           client = client
+       else:
+           client = _get_httpx_client()
+
        response = client.post(
-           api_base,
+           url=api_base,
            headers=headers,
            json=request_data_dict,
        )
@@ -100,10 +126,13 @@ class CohereRerank(BaseLLM):
        api_key: str,
        api_base: str,
        headers: dict,
+       client: Optional[AsyncHTTPHandler] = None,
    ) -> RerankResponse:
        request_data_dict = request_data.dict(exclude_none=True)

-       client = get_async_httpx_client(llm_provider=litellm.LlmProviders.COHERE)
+       client = client or get_async_httpx_client(
+           llm_provider=litellm.LlmProviders.COHERE
+       )

        response = await client.post(
            api_base,
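
A standalone sketch of the URL normalization the new `ensure_rerank_endpoint` helper performs (logic copied from the method above; the example URLs are illustrative):

```python
import httpx


def ensure_rerank_endpoint(api_base: str) -> str:
    # Append /v1/rerank only when the path does not already end with it.
    url = httpx.URL(api_base)
    if not url.path.endswith("/v1/rerank"):
        url = url.copy_with(path=f"{url.path.rstrip('/')}/v1/rerank")
    return str(url)


print(ensure_rerank_endpoint("https://api.cohere.com"))            # https://api.cohere.com/v1/rerank
print(ensure_rerank_endpoint("http://localhost:4000/"))            # http://localhost:4000/v1/rerank
print(ensure_rerank_endpoint("https://api.cohere.com/v1/rerank"))  # unchanged
```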

View file

@@ -3440,6 +3440,10 @@ def embedding(  # noqa: PLR0915
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )

+       if extra_headers is not None:
+           optional_params["extra_headers"] = extra_headers

        api_type = "openai"
        api_version = None

View file

@@ -16,7 +16,7 @@ model_list:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/

router_settings:
  model_group_alias:
    "gpt-4-turbo": # Aliased model name
@@ -35,4 +35,4 @@ litellm_settings:
  failure_callback: ["langfuse"]
  langfuse_public_key: os.environ/LANGFUSE_PROJECT2_PUBLIC # Project 2
  langfuse_secret: os.environ/LANGFUSE_PROJECT2_SECRET # Project 2
  langfuse_host: https://us.cloud.langfuse.com

View file

@@ -2,6 +2,7 @@ import enum
import json
import os
import sys
+import traceback
import uuid
from dataclasses import fields
from datetime import datetime
@@ -12,7 +13,15 @@ from typing_extensions import Annotated, TypedDict

from litellm.types.integrations.slack_alerting import AlertType
from litellm.types.router import RouterErrors, UpdateRouterConfig
-from litellm.types.utils import ProviderField, StandardCallbackDynamicParams
+from litellm.types.utils import (
+   EmbeddingResponse,
+   ImageResponse,
+   ModelResponse,
+   ProviderField,
+   StandardCallbackDynamicParams,
+   StandardPassThroughResponseObject,
+   TextCompletionResponse,
+)

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span
@@ -882,15 +891,7 @@ class DeleteCustomerRequest(LiteLLMBase):
    user_ids: List[str]


-class Member(LiteLLMBase):
-   role: Literal[
-       LitellmUserRoles.ORG_ADMIN,
-       LitellmUserRoles.INTERNAL_USER,
-       LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
-       # older Member roles
-       "admin",
-       "user",
-   ]
+class MemberBase(LiteLLMBase):
    user_id: Optional[str] = None
    user_email: Optional[str] = None
@@ -904,6 +905,21 @@ class Member(LiteLLMBase):
        return values


+class Member(MemberBase):
+   role: Literal[
+       "admin",
+       "user",
+   ]
+
+
+class OrgMember(MemberBase):
+   role: Literal[
+       LitellmUserRoles.ORG_ADMIN,
+       LitellmUserRoles.INTERNAL_USER,
+       LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
+   ]
+
+
class TeamBase(LiteLLMBase):
    team_alias: Optional[str] = None
    team_id: Optional[str] = None
@@ -1966,6 +1982,26 @@ class MemberAddRequest(LiteLLMBase):
            # Replace member_data with the single Member object
            data["member"] = member

        # Call the superclass __init__ method to initialize the object
+       traceback.print_stack()
+       super().__init__(**data)
+
+
+class OrgMemberAddRequest(LiteLLMBase):
+   member: Union[List[OrgMember], OrgMember]
+
+   def __init__(self, **data):
+       member_data = data.get("member")
+       if isinstance(member_data, list):
+           # If member is a list of dictionaries, convert each dictionary to a Member object
+           members = [OrgMember(**item) for item in member_data]
+           # Replace member_data with the list of Member objects
+           data["member"] = members
+       elif isinstance(member_data, dict):
+           # If member is a dictionary, convert it to a single Member object
+           member = OrgMember(**member_data)
+           # Replace member_data with the single Member object
+           data["member"] = member
+       # Call the superclass __init__ method to initialize the object
        super().__init__(**data)
@@ -2017,7 +2053,7 @@ class TeamMemberUpdateResponse(MemberUpdateResponse):


# Organization Member Requests
-class OrganizationMemberAddRequest(MemberAddRequest):
+class OrganizationMemberAddRequest(OrgMemberAddRequest):
    organization_id: str
    max_budget_in_organization: Optional[float] = (
        None  # Users max budget within the organization
@@ -2133,3 +2169,17 @@ class UserManagementEndpointParamDocStringEnums(str, enum.Enum):
    spend_doc_str = """Optional[float] - Amount spent by user. Default is 0. Will be updated by proxy whenever user is used."""
    team_id_doc_str = """Optional[str] - [DEPRECATED PARAM] The team id of the user. Default is None."""
    duration_doc_str = """Optional[str] - Duration for the key auto-created on `/user/new`. Default is None."""


PassThroughEndpointLoggingResultValues = Union[
    ModelResponse,
    TextCompletionResponse,
    ImageResponse,
    EmbeddingResponse,
    StandardPassThroughResponseObject,
]


class PassThroughEndpointLoggingTypedDict(TypedDict):
    result: Optional[PassThroughEndpointLoggingResultValues]
    kwargs: dict

View file

@@ -40,6 +40,77 @@ from litellm.proxy.utils import (
)
from litellm.secret_managers.main import get_secret


def _is_team_key(data: GenerateKeyRequest):
    return data.team_id is not None


def _team_key_generation_check(user_api_key_dict: UserAPIKeyAuth):
    if (
        litellm.key_generation_settings is None
        or litellm.key_generation_settings.get("team_key_generation") is None
    ):
        return True

    if user_api_key_dict.team_member is None:
        raise HTTPException(
            status_code=400,
            detail=f"User not assigned to team. Got team_member={user_api_key_dict.team_member}",
        )

    team_member_role = user_api_key_dict.team_member.role
    if (
        team_member_role
        not in litellm.key_generation_settings["team_key_generation"][  # type: ignore
            "allowed_team_member_roles"
        ]
    ):
        raise HTTPException(
            status_code=400,
            detail=f"Team member role {team_member_role} not in allowed_team_member_roles={litellm.key_generation_settings['team_key_generation']['allowed_team_member_roles']}",  # type: ignore
        )
    return True


def _personal_key_generation_check(user_api_key_dict: UserAPIKeyAuth):
    if (
        litellm.key_generation_settings is None
        or litellm.key_generation_settings.get("personal_key_generation") is None
    ):
        return True

    if (
        user_api_key_dict.user_role
        not in litellm.key_generation_settings["personal_key_generation"][  # type: ignore
            "allowed_user_roles"
        ]
    ):
        raise HTTPException(
            status_code=400,
            detail=f"Personal key creation has been restricted by admin. Allowed roles={litellm.key_generation_settings['personal_key_generation']['allowed_user_roles']}. Your role={user_api_key_dict.user_role}",  # type: ignore
        )
    return True


def key_generation_check(
    user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest
) -> bool:
    """
    Check if admin has restricted key creation to certain roles for teams or individuals
    """
    if litellm.key_generation_settings is None:
        return True

    ## check if key is for team or individual
    is_team_key = _is_team_key(data=data)

    if is_team_key:
        return _team_key_generation_check(user_api_key_dict)
    else:
        return _personal_key_generation_check(user_api_key_dict=user_api_key_dict)


router = APIRouter()
@@ -131,6 +202,8 @@ async def generate_key_fn(  # noqa: PLR0915
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN, detail=message
            )
+       elif litellm.key_generation_settings is not None:
+           key_generation_check(user_api_key_dict=user_api_key_dict, data=data)
        # check if user set default key/generate params on config.yaml
        if litellm.default_key_generate_params is not None:
            for elem in data:

View file

@@ -352,7 +352,7 @@ async def organization_member_add(
            },
        )

-   members: List[Member]
+   members: List[OrgMember]
    if isinstance(data.member, List):
        members = data.member
    else:
@@ -397,7 +397,7 @@ async def organization_member_add(


async def add_member_to_organization(
-   member: Member,
+   member: OrgMember,
    organization_id: str,
    prisma_client: PrismaClient,
) -> Tuple[LiteLLM_UserTable, LiteLLM_OrganizationMembershipTable]:

View file

@@ -14,6 +14,7 @@ from litellm.llms.anthropic.chat.handler import (
    ModelResponseIterator as AnthropicModelResponseIterator,
)
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
+from litellm.proxy._types import PassThroughEndpointLoggingTypedDict

if TYPE_CHECKING:
    from ..success_handler import PassThroughEndpointLogging
@@ -26,7 +27,7 @@ else:

class AnthropicPassthroughLoggingHandler:

    @staticmethod
-   async def anthropic_passthrough_handler(
+   def anthropic_passthrough_handler(
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
@@ -36,7 +37,7 @@ class AnthropicPassthroughLoggingHandler:
        end_time: datetime,
        cache_hit: bool,
        **kwargs,
-   ):
+   ) -> PassThroughEndpointLoggingTypedDict:
        """
        Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled
        """
@@ -67,15 +68,10 @@
            logging_obj=logging_obj,
        )

-       await logging_obj.async_success_handler(
-           result=litellm_model_response,
-           start_time=start_time,
-           end_time=end_time,
-           cache_hit=cache_hit,
-           **kwargs,
-       )
-
-       pass
+       return {
+           "result": litellm_model_response,
+           "kwargs": kwargs,
+       }

    @staticmethod
    def _create_anthropic_response_logging_payload(
@@ -123,7 +119,7 @@
        return kwargs

    @staticmethod
-   async def _handle_logging_anthropic_collected_chunks(
+   def _handle_logging_anthropic_collected_chunks(
        litellm_logging_obj: LiteLLMLoggingObj,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
@@ -132,7 +128,7 @@
        start_time: datetime,
        all_chunks: List[str],
        end_time: datetime,
-   ):
+   ) -> PassThroughEndpointLoggingTypedDict:
        """
        Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks
@@ -152,7 +148,10 @@
            verbose_proxy_logger.error(
                "Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..."
            )
-           return
+           return {
+               "result": None,
+               "kwargs": {},
+           }
        kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
            litellm_model_response=complete_streaming_response,
            model=model,
@@ -161,13 +160,11 @@
            end_time=end_time,
            logging_obj=litellm_logging_obj,
        )
-       await litellm_logging_obj.async_success_handler(
-           result=complete_streaming_response,
-           start_time=start_time,
-           end_time=end_time,
-           cache_hit=False,
-           **kwargs,
-       )
+
+       return {
+           "result": complete_streaming_response,
+           "kwargs": kwargs,
+       }

    @staticmethod
    def _build_complete_streaming_response(

View file

@@ -14,6 +14,7 @@ from litellm.litellm_core_utils.litellm_logging import (
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    ModelResponseIterator as VertexModelResponseIterator,
)
+from litellm.proxy._types import PassThroughEndpointLoggingTypedDict

if TYPE_CHECKING:
    from ..success_handler import PassThroughEndpointLogging
@@ -25,7 +26,7 @@ else:

class VertexPassthroughLoggingHandler:
    @staticmethod
-   async def vertex_passthrough_handler(
+   def vertex_passthrough_handler(
        httpx_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
@@ -34,7 +35,7 @@ class VertexPassthroughLoggingHandler:
        end_time: datetime,
        cache_hit: bool,
        **kwargs,
-   ):
+   ) -> PassThroughEndpointLoggingTypedDict:
        if "generateContent" in url_route:
            model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
@@ -65,13 +66,11 @@
                logging_obj=logging_obj,
            )

-           await logging_obj.async_success_handler(
-               result=litellm_model_response,
-               start_time=start_time,
-               end_time=end_time,
-               cache_hit=cache_hit,
-               **kwargs,
-           )
+           return {
+               "result": litellm_model_response,
+               "kwargs": kwargs,
+           }

        elif "predict" in url_route:
            from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import (
                VertexImageGeneration,
@@ -112,16 +111,18 @@
            logging_obj.model = model
            logging_obj.model_call_details["model"] = logging_obj.model

-           await logging_obj.async_success_handler(
-               result=litellm_prediction_response,
-               start_time=start_time,
-               end_time=end_time,
-               cache_hit=cache_hit,
-               **kwargs,
-           )
+           return {
+               "result": litellm_prediction_response,
+               "kwargs": kwargs,
+           }
+       else:
+           return {
+               "result": None,
+               "kwargs": kwargs,
+           }

    @staticmethod
-   async def _handle_logging_vertex_collected_chunks(
+   def _handle_logging_vertex_collected_chunks(
        litellm_logging_obj: LiteLLMLoggingObj,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
@@ -130,7 +131,7 @@
        start_time: datetime,
        all_chunks: List[str],
        end_time: datetime,
-   ):
+   ) -> PassThroughEndpointLoggingTypedDict:
        """
        Takes raw chunks from Vertex passthrough endpoint and logs them in litellm callbacks
@@ -152,7 +153,11 @@
            verbose_proxy_logger.error(
                "Unable to build complete streaming response for Vertex passthrough endpoint, not logging..."
            )
-           return
+           return {
+               "result": None,
+               "kwargs": kwargs,
+           }
        kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
            litellm_model_response=complete_streaming_response,
            model=model,
@@ -161,13 +166,11 @@
            end_time=end_time,
            logging_obj=litellm_logging_obj,
        )
-       await litellm_logging_obj.async_success_handler(
-           result=complete_streaming_response,
-           start_time=start_time,
-           end_time=end_time,
-           cache_hit=False,
-           **kwargs,
-       )
+
+       return {
+           "result": complete_streaming_response,
+           "kwargs": kwargs,
+       }

    @staticmethod
    def _build_complete_streaming_response(

View file

@@ -1,5 +1,6 @@
import asyncio
import json
+import threading
from datetime import datetime
from enum import Enum
from typing import AsyncIterable, Dict, List, Optional, Union
@@ -15,7 +16,12 @@ from litellm.llms.anthropic.chat.handler import (
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    ModelResponseIterator as VertexAIIterator,
)
-from litellm.types.utils import GenericStreamingChunk
+from litellm.proxy._types import PassThroughEndpointLoggingResultValues
+from litellm.types.utils import (
+   GenericStreamingChunk,
+   ModelResponse,
+   StandardPassThroughResponseObject,
+)

from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
    AnthropicPassthroughLoggingHandler,
@@ -87,8 +93,12 @@ class PassThroughStreamingHandler:
        all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(
            raw_bytes
        )
+       standard_logging_response_object: Optional[
+           PassThroughEndpointLoggingResultValues
+       ] = None
+       kwargs: dict = {}
        if endpoint_type == EndpointType.ANTHROPIC:
-           await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
+           anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
                litellm_logging_obj=litellm_logging_obj,
                passthrough_success_handler_obj=passthrough_success_handler_obj,
                url_route=url_route,
@@ -98,20 +108,48 @@
                all_chunks=all_chunks,
                end_time=end_time,
            )
+           standard_logging_response_object = anthropic_passthrough_logging_handler_result[
+               "result"
+           ]
+           kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
        elif endpoint_type == EndpointType.VERTEX_AI:
-           await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
-               litellm_logging_obj=litellm_logging_obj,
-               passthrough_success_handler_obj=passthrough_success_handler_obj,
-               url_route=url_route,
-               request_body=request_body,
-               endpoint_type=endpoint_type,
-               start_time=start_time,
-               all_chunks=all_chunks,
-               end_time=end_time,
-           )
-       elif endpoint_type == EndpointType.GENERIC:
-           # No logging is supported for generic streaming endpoints
-           pass
+           vertex_passthrough_logging_handler_result = (
+               VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
+                   litellm_logging_obj=litellm_logging_obj,
+                   passthrough_success_handler_obj=passthrough_success_handler_obj,
+                   url_route=url_route,
+                   request_body=request_body,
+                   endpoint_type=endpoint_type,
+                   start_time=start_time,
+                   all_chunks=all_chunks,
+                   end_time=end_time,
+               )
+           )
+           standard_logging_response_object = vertex_passthrough_logging_handler_result[
+               "result"
+           ]
+           kwargs = vertex_passthrough_logging_handler_result["kwargs"]
+
+       if standard_logging_response_object is None:
+           standard_logging_response_object = StandardPassThroughResponseObject(
+               response=f"cannot parse chunks to standard response object. Chunks={all_chunks}"
+           )
+       threading.Thread(
+           target=litellm_logging_obj.success_handler,
+           args=(
+               standard_logging_response_object,
+               start_time,
+               end_time,
+               False,
+           ),
+       ).start()
+       await litellm_logging_obj.async_success_handler(
+           result=standard_logging_response_object,
+           start_time=start_time,
+           end_time=end_time,
+           cache_hit=False,
+           **kwargs,
+       )

    @staticmethod
    def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]:
@@ -130,4 +168,4 @@ class PassThroughStreamingHandler:
        # Split by newlines and filter out empty lines
        lines = [line.strip() for line in combined_str.split("\n") if line.strip()]
        return lines

View file

@@ -15,6 +15,7 @@ from litellm.litellm_core_utils.litellm_logging import (
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    VertexLLM,
)
+from litellm.proxy._types import PassThroughEndpointLoggingResultValues
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.utils import StandardPassThroughResponseObject
@@ -49,53 +50,69 @@ class PassThroughEndpointLogging:
        cache_hit: bool,
        **kwargs,
    ):
+       standard_logging_response_object: Optional[
+           PassThroughEndpointLoggingResultValues
+       ] = None
        if self.is_vertex_route(url_route):
-           await VertexPassthroughLoggingHandler.vertex_passthrough_handler(
-               httpx_response=httpx_response,
-               logging_obj=logging_obj,
-               url_route=url_route,
-               result=result,
-               start_time=start_time,
-               end_time=end_time,
-               cache_hit=cache_hit,
-               **kwargs,
-           )
+           vertex_passthrough_logging_handler_result = (
+               VertexPassthroughLoggingHandler.vertex_passthrough_handler(
+                   httpx_response=httpx_response,
+                   logging_obj=logging_obj,
+                   url_route=url_route,
+                   result=result,
+                   start_time=start_time,
+                   end_time=end_time,
+                   cache_hit=cache_hit,
+                   **kwargs,
+               )
+           )
+           standard_logging_response_object = (
+               vertex_passthrough_logging_handler_result["result"]
+           )
+           kwargs = vertex_passthrough_logging_handler_result["kwargs"]
        elif self.is_anthropic_route(url_route):
-           await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
-               httpx_response=httpx_response,
-               response_body=response_body or {},
-               logging_obj=logging_obj,
-               url_route=url_route,
-               result=result,
-               start_time=start_time,
-               end_time=end_time,
-               cache_hit=cache_hit,
-               **kwargs,
-           )
-       else:
+           anthropic_passthrough_logging_handler_result = (
+               AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
+                   httpx_response=httpx_response,
+                   response_body=response_body or {},
+                   logging_obj=logging_obj,
+                   url_route=url_route,
+                   result=result,
+                   start_time=start_time,
+                   end_time=end_time,
+                   cache_hit=cache_hit,
+                   **kwargs,
+               )
+           )
+           standard_logging_response_object = (
+               anthropic_passthrough_logging_handler_result["result"]
+           )
+           kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
+       if standard_logging_response_object is None:
            standard_logging_response_object = StandardPassThroughResponseObject(
                response=httpx_response.text
            )
            threading.Thread(
                target=logging_obj.success_handler,
                args=(
                    standard_logging_response_object,
                    start_time,
                    end_time,
                    cache_hit,
                ),
            ).start()
            await logging_obj.async_success_handler(
                result=(
                    json.dumps(result)
                    if isinstance(result, dict)
                    else standard_logging_response_object
                ),
                start_time=start_time,
                end_time=end_time,
                cache_hit=False,
                **kwargs,
            )

    def is_vertex_route(self, url_route: str):
        for route in self.TRACKED_VERTEX_ROUTES:

View file

@@ -268,6 +268,7 @@ from litellm.types.llms.anthropic import (
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.router import RouterGeneralSettings
from litellm.types.utils import StandardLoggingPayload
+from litellm.utils import get_end_user_id_for_cost_tracking

try:
    from litellm._version import version
@@ -763,8 +764,7 @@ async def _PROXY_track_cost_callback(
        )
        parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
        litellm_params = kwargs.get("litellm_params", {}) or {}
-       proxy_server_request = litellm_params.get("proxy_server_request") or {}
-       end_user_id = proxy_server_request.get("body", {}).get("user", None)
+       end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
        metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
        user_id = metadata.get("user_api_key_user_id", None)
        team_id = metadata.get("user_api_key_team_id", None)

View file

@@ -337,14 +337,14 @@ class ProxyLogging:
                alert_to_webhook_url=self.alert_to_webhook_url,
            )

-       if (
-           self.alerting is not None
-           and "slack" in self.alerting
-           and "daily_reports" in self.alert_types
-       ):
+       if self.alerting is not None and "slack" in self.alerting:
            # NOTE: ENSURE we only add callbacks when alerting is on
            # We should NOT add callbacks when alerting is off
-           litellm.callbacks.append(self.slack_alerting_instance)  # type: ignore
+           if "daily_reports" in self.alert_types:
+               litellm.callbacks.append(self.slack_alerting_instance)  # type: ignore
+           litellm.success_callback.append(
+               self.slack_alerting_instance.response_taking_too_long_callback
+           )

        if redis_cache is not None:
            self.internal_usage_cache.dual_cache.redis_cache = redis_cache
@@ -354,9 +354,6 @@ class ProxyLogging:
        litellm.callbacks.append(self.max_budget_limiter)  # type: ignore
        litellm.callbacks.append(self.cache_control_check)  # type: ignore
        litellm.callbacks.append(self.service_logging_obj)  # type: ignore
-       litellm.success_callback.append(
-           self.slack_alerting_instance.response_taking_too_long_callback
-       )
        for callback in litellm.callbacks:
            if isinstance(callback, str):
                callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore

View file

@@ -91,6 +91,7 @@ def rerank(
    model_info = kwargs.get("model_info", None)
    metadata = kwargs.get("metadata", {})
    user = kwargs.get("user", None)
+   client = kwargs.get("client", None)
    try:
        _is_async = kwargs.pop("arerank", False) is True
        optional_params = GenericLiteLLMParams(**kwargs)
@@ -150,7 +151,7 @@ def rerank(
            or optional_params.api_base
            or litellm.api_base
            or get_secret("COHERE_API_BASE")  # type: ignore
-           or "https://api.cohere.com/v1/rerank"
+           or "https://api.cohere.com"
        )

        if api_base is None:
@@ -173,6 +174,7 @@ def rerank(
            _is_async=_is_async,
            headers=headers,
            litellm_logging_obj=litellm_logging_obj,
+           client=client,
        )
    elif _custom_llm_provider == "azure_ai":
        api_base = (

View file

@@ -1602,3 +1602,16 @@ class StandardCallbackDynamicParams(TypedDict, total=False):
    langsmith_api_key: Optional[str]
    langsmith_project: Optional[str]
    langsmith_base_url: Optional[str]


class TeamUIKeyGenerationConfig(TypedDict):
    allowed_team_member_roles: List[str]


class PersonalUIKeyGenerationConfig(TypedDict):
    allowed_user_roles: List[str]


class StandardKeyGenerationConfig(TypedDict, total=False):
    team_key_generation: TeamUIKeyGenerationConfig
    personal_key_generation: PersonalUIKeyGenerationConfig

View file

@@ -6170,3 +6170,13 @@ class ProviderConfigManager:
             return litellm.GroqChatConfig()

         return OpenAIGPTConfig()
+
+
+def get_end_user_id_for_cost_tracking(litellm_params: dict) -> Optional[str]:
+    """
+    Used for enforcing `disable_end_user_cost_tracking` param.
+    """
+    proxy_server_request = litellm_params.get("proxy_server_request") or {}
+    if litellm.disable_end_user_cost_tracking:
+        return None
+    return proxy_server_request.get("body", {}).get("user", None)
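A short usage sketch of the new helper and flag, mirroring the parametrized test added further down (the user id is illustrative):

# The helper reads the end-user id from the proxy request body, unless
# end-user cost tracking has been disabled by the proxy admin.
import litellm
from litellm.utils import get_end_user_id_for_cost_tracking

params = {"proxy_server_request": {"body": {"user": "end-user-123"}}}

litellm.disable_end_user_cost_tracking = False
print(get_end_user_id_for_cost_tracking(litellm_params=params))  # end-user-123

litellm.disable_end_user_cost_tracking = True
print(get_end_user_id_for_cost_tracking(litellm_params=params))  # None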


@@ -1080,3 +1080,34 @@ def test_cohere_img_embeddings(input, input_type):
         assert response.usage.prompt_tokens_details.image_tokens > 0
     else:
         assert response.usage.prompt_tokens_details.text_tokens > 0
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_embedding_with_extra_headers(sync_mode):
+    input = ["hello world"]
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
+
+    if sync_mode:
+        client = HTTPHandler()
+    else:
+        client = AsyncHTTPHandler()
+
+    data = {
+        "model": "cohere/embed-english-v3.0",
+        "input": input,
+        "extra_headers": {"my-test-param": "hello-world"},
+        "client": client,
+    }
+    with patch.object(client, "post") as mock_post:
+        try:
+            if sync_mode:
+                embedding(**data)
+            else:
+                await litellm.aembedding(**data)
+        except Exception as e:
+            print(e)
+
+        mock_post.assert_called_once()
+
+        assert "my-test-param" in mock_post.call_args.kwargs["headers"]


@@ -215,7 +215,10 @@ async def test_rerank_custom_api_base():
     args_to_api = kwargs["json"]
     print("Arguments passed to API=", args_to_api)
     print("url = ", _url)
-    assert _url[0] == "https://exampleopenaiendpoint-production.up.railway.app/"
+    assert (
+        _url[0]
+        == "https://exampleopenaiendpoint-production.up.railway.app/v1/rerank"
+    )
     assert args_to_api == expected_payload
     assert response.id is not None
     assert response.results is not None
@@ -258,3 +261,32 @@ async def test_rerank_custom_callbacks():
     assert custom_logger.kwargs.get("response_cost") > 0.0
     assert custom_logger.response_obj is not None
     assert custom_logger.response_obj.results is not None
+
+
+def test_complete_base_url_cohere():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+    litellm.api_base = "http://localhost:4000"
+    litellm.set_verbose = True
+
+    text = "Hello there!"
+    list_texts = ["Hello there!", "How are you?", "How do you do?"]
+    rerank_model = "rerank-multilingual-v3.0"
+
+    with patch.object(client, "post") as mock_post:
+        try:
+            litellm.rerank(
+                model=rerank_model,
+                query=text,
+                documents=list_texts,
+                custom_llm_provider="cohere",
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+
+        print("mock_post.call_args", mock_post.call_args)
+        mock_post.assert_called_once()
+
+        assert "http://localhost:4000/v1/rerank" in mock_post.call_args.kwargs["url"]


@@ -1012,3 +1012,23 @@ def test_models_by_provider():
     for provider in providers:
         assert provider in models_by_provider.keys()
+
+
+@pytest.mark.parametrize(
+    "litellm_params, disable_end_user_cost_tracking, expected_end_user_id",
+    [
+        ({}, False, None),
+        ({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"),
+        ({"proxy_server_request": {"body": {"user": "123"}}}, True, None),
+    ],
+)
+def test_get_end_user_id_for_cost_tracking(
+    litellm_params, disable_end_user_cost_tracking, expected_end_user_id
+):
+    from litellm.utils import get_end_user_id_for_cost_tracking
+
+    litellm.disable_end_user_cost_tracking = disable_end_user_cost_tracking
+    assert (
+        get_end_user_id_for_cost_tracking(litellm_params=litellm_params)
+        == expected_end_user_id
+    )


@@ -216,3 +216,78 @@ async def test_init_custom_logger_compatible_class_as_callback():
         await use_callback_in_llm_call(callback, used_in="success_callback")

     reset_env_vars()
+
+
+def test_dynamic_logging_global_callback():
+    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+    from litellm.integrations.custom_logger import CustomLogger
+    from litellm.types.utils import ModelResponse, Choices, Message, Usage
+
+    cl = CustomLogger()
+
+    litellm_logging = LiteLLMLoggingObj(
+        model="claude-3-opus-20240229",
+        messages=[{"role": "user", "content": "hi"}],
+        stream=False,
+        call_type="completion",
+        start_time=datetime.now(),
+        litellm_call_id="123",
+        function_id="456",
+        kwargs={
+            "langfuse_public_key": "my-mock-public-key",
+            "langfuse_secret_key": "my-mock-secret-key",
+        },
+        dynamic_success_callbacks=["langfuse"],
+    )
+
+    with patch.object(cl, "log_success_event") as mock_log_success_event:
+        cl.log_success_event = mock_log_success_event
+        litellm.success_callback = [cl]
+
+        try:
+            litellm_logging.success_handler(
+                result=ModelResponse(
+                    id="chatcmpl-5418737b-ab14-420b-b9c5-b278b6681b70",
+                    created=1732306261,
+                    model="claude-3-opus-20240229",
+                    object="chat.completion",
+                    system_fingerprint=None,
+                    choices=[
+                        Choices(
+                            finish_reason="stop",
+                            index=0,
+                            message=Message(
+                                content="hello",
+                                role="assistant",
+                                tool_calls=None,
+                                function_call=None,
+                            ),
+                        )
+                    ],
+                    usage=Usage(
+                        completion_tokens=20,
+                        prompt_tokens=10,
+                        total_tokens=30,
+                        completion_tokens_details=None,
+                        prompt_tokens_details=None,
+                    ),
+                ),
+                start_time=datetime.now(),
+                end_time=datetime.now(),
+                cache_hit=False,
+            )
+        except Exception as e:
+            print(f"Error: {e}")
+
+        mock_log_success_event.assert_called_once()
+
+
+def test_get_combined_callback_list():
+    from litellm.litellm_core_utils.litellm_logging import get_combined_callback_list
+
+    assert "langfuse" in get_combined_callback_list(
+        dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"]
+    )
+    assert "lago" in get_combined_callback_list(
+        dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"]
+    )
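The two tests above pin down the behavior change from the logging fix: callbacks set dynamically on a request (for example a per-request langfuse key) are merged with the globally configured callbacks instead of replacing them. A minimal call of the helper, grounded in the assertions above:

# Dynamic and global callbacks are combined, not swapped.
from litellm.litellm_core_utils.litellm_logging import get_combined_callback_list

combined = get_combined_callback_list(
    dynamic_success_callbacks=["langfuse"],
    global_callbacks=["lago"],
)
assert {"langfuse", "lago"} <= set(combined)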


@@ -73,7 +73,7 @@ async def test_anthropic_passthrough_handler(
     start_time = datetime.now()
     end_time = datetime.now()

-    await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
+    result = AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
         httpx_response=mock_httpx_response,
         response_body=mock_response,
         logging_obj=mock_logging_obj,
@@ -84,30 +84,7 @@ async def test_anthropic_passthrough_handler(
         cache_hit=False,
     )

-    # Assert that async_success_handler was called
-    assert mock_logging_obj.async_success_handler.called
-
-    call_args = mock_logging_obj.async_success_handler.call_args
-    call_kwargs = call_args.kwargs
-    print("call_kwargs", call_kwargs)
-
-    # Assert required fields are present in call_kwargs
-    assert "result" in call_kwargs
-    assert "start_time" in call_kwargs
-    assert "end_time" in call_kwargs
-    assert "cache_hit" in call_kwargs
-    assert "response_cost" in call_kwargs
-    assert "model" in call_kwargs
-    assert "standard_logging_object" in call_kwargs
-
-    # Assert specific values and types
-    assert isinstance(call_kwargs["result"], litellm.ModelResponse)
-    assert isinstance(call_kwargs["start_time"], datetime)
-    assert isinstance(call_kwargs["end_time"], datetime)
-    assert isinstance(call_kwargs["cache_hit"], bool)
-    assert isinstance(call_kwargs["response_cost"], float)
-    assert call_kwargs["model"] == "claude-3-opus-20240229"
-    assert isinstance(call_kwargs["standard_logging_object"], dict)
+    assert isinstance(result["result"], litellm.ModelResponse)


 def test_create_anthropic_response_logging_payload(mock_logging_obj):


@@ -64,6 +64,7 @@ async def test_chunk_processor_yields_raw_bytes(endpoint_type, url_route):
     litellm_logging_obj = MagicMock()
     start_time = datetime.now()
     passthrough_success_handler_obj = MagicMock()
+    litellm_logging_obj.async_success_handler = AsyncMock()

     # Capture yielded chunks and perform detailed assertions
     received_chunks = []


@@ -0,0 +1,54 @@
+# conftest.py
+
+import importlib
+import os
+import sys
+
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+
+
+@pytest.fixture(scope="function", autouse=True)
+def setup_and_teardown():
+    """
+    This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.
+    """
+    curr_dir = os.getcwd()  # Get the current working directory
+    sys.path.insert(
+        0, os.path.abspath("../..")
+    )  # Adds the project directory to the system path
+
+    import litellm
+    from litellm import Router
+
+    importlib.reload(litellm)
+    import asyncio
+
+    loop = asyncio.get_event_loop_policy().new_event_loop()
+    asyncio.set_event_loop(loop)
+    print(litellm)
+    # from litellm import Router, completion, aembedding, acompletion, embedding
+    yield
+
+    # Teardown code (executes after the yield point)
+    loop.close()  # Close the loop created earlier
+    asyncio.set_event_loop(None)  # Remove the reference to the loop
+
+
+def pytest_collection_modifyitems(config, items):
+    # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
+    custom_logger_tests = [
+        item for item in items if "custom_logger" in item.parent.name
+    ]
+    other_tests = [item for item in items if "custom_logger" not in item.parent.name]
+
+    # Sort tests based on their names
+    custom_logger_tests.sort(key=lambda x: x.name)
+    other_tests.sort(key=lambda x: x.name)
+
+    # Reorder the items list
+    items[:] = custom_logger_tests + other_tests


@@ -542,3 +542,65 @@ async def test_list_teams(prisma_client):
     # Clean up
     await prisma_client.delete_data(team_id_list=[team_id], table_name="team")
+
+
+def test_is_team_key():
+    from litellm.proxy.management_endpoints.key_management_endpoints import _is_team_key
+
+    assert _is_team_key(GenerateKeyRequest(team_id="test_team_id"))
+    assert not _is_team_key(GenerateKeyRequest(user_id="test_user_id"))
+
+
+def test_team_key_generation_check():
+    from litellm.proxy.management_endpoints.key_management_endpoints import (
+        _team_key_generation_check,
+    )
+    from fastapi import HTTPException
+
+    litellm.key_generation_settings = {
+        "team_key_generation": {"allowed_team_member_roles": ["admin"]}
+    }
+
+    assert _team_key_generation_check(
+        UserAPIKeyAuth(
+            user_role=LitellmUserRoles.INTERNAL_USER,
+            api_key="sk-1234",
+            team_member=Member(role="admin", user_id="test_user_id"),
+        )
+    )
+
+    with pytest.raises(HTTPException):
+        _team_key_generation_check(
+            UserAPIKeyAuth(
+                user_role=LitellmUserRoles.INTERNAL_USER,
+                api_key="sk-1234",
+                user_id="test_user_id",
+                team_member=Member(role="user", user_id="test_user_id"),
+            )
+        )
+
+
+def test_personal_key_generation_check():
+    from litellm.proxy.management_endpoints.key_management_endpoints import (
+        _personal_key_generation_check,
+    )
+    from fastapi import HTTPException
+
+    litellm.key_generation_settings = {
+        "personal_key_generation": {"allowed_user_roles": ["proxy_admin"]}
+    }
+
+    assert _personal_key_generation_check(
+        UserAPIKeyAuth(
+            user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin"
+        )
+    )
+
+    with pytest.raises(HTTPException):
+        _personal_key_generation_check(
+            UserAPIKeyAuth(
+                user_role=LitellmUserRoles.INTERNAL_USER,
+                api_key="sk-1234",
+                user_id="admin",
+            )
+        )


@@ -160,7 +160,7 @@ async def test_create_new_user_in_organization(prisma_client, user_role):
     response = await organization_member_add(
         data=OrganizationMemberAddRequest(
             organization_id=org_id,
-            member=Member(role=user_role, user_id=created_user_id),
+            member=OrgMember(role=user_role, user_id=created_user_id),
         ),
         http_request=None,
     )
@@ -220,7 +220,7 @@ async def test_org_admin_create_team_permissions(prisma_client):
     response = await organization_member_add(
         data=OrganizationMemberAddRequest(
             organization_id=org_id,
-            member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
+            member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
         ),
         http_request=None,
     )
@@ -292,7 +292,7 @@ async def test_org_admin_create_user_permissions(prisma_client):
     response = await organization_member_add(
         data=OrganizationMemberAddRequest(
            organization_id=org_id,
-            member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
+            member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
         ),
         http_request=None,
     )
@@ -323,7 +323,7 @@ async def test_org_admin_create_user_permissions(prisma_client):
     response = await organization_member_add(
         data=OrganizationMemberAddRequest(
             organization_id=org_id,
-            member=Member(
+            member=OrgMember(
                 role=LitellmUserRoles.INTERNAL_USER, user_id=new_internal_user_for_org
             ),
         ),
@@ -375,7 +375,7 @@ async def test_org_admin_create_user_team_wrong_org_permissions(prisma_client):
     response = await organization_member_add(
         data=OrganizationMemberAddRequest(
             organization_id=org1_id,
-            member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
+            member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id),
         ),
         http_request=None,
     )