From 7e9d8b58f6e9f5c622513f22a26d5952427af8c9 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Sat, 23 Nov 2024 15:17:40 +0530 Subject: [PATCH] LiteLLM Minor Fixes & Improvements (11/23/2024) (#6870) * feat(pass_through_endpoints/): support logging anthropic/gemini pass through calls to langfuse/s3/etc. * fix(utils.py): allow disabling end user cost tracking with new param Allows proxy admin to disable cost tracking for end user - keeps prometheus metrics small * docs(configs.md): add disable_end_user_cost_tracking reference to docs * feat(key_management_endpoints.py): add support for restricting access to `/key/generate` by team/proxy level role Enables admin to restrict key creation, and assign team admins to handle distributing keys * test(test_key_management.py): add unit testing for personal / team key restriction checks * docs: add docs on restricting key creation * docs(finetuned_models.md): add new guide on calling finetuned models * docs(input.md): cleanup anthropic supported params Closes https://github.com/BerriAI/litellm/issues/6856 * test(test_embedding.py): add test for passing extra headers via embedding * feat(cohere/embed): pass client to async embedding * feat(rerank.py): add `/v1/rerank` if missing for cohere base url Closes https://github.com/BerriAI/litellm/issues/6844 * fix(main.py): pass extra_headers param to openai Fixes https://github.com/BerriAI/litellm/issues/6836 * fix(litellm_logging.py): don't disable global callbacks when dynamic callbacks are set Fixes issue where global callbacks - e.g. prometheus were overriden when langfuse was set dynamically * fix(handler.py): fix linting error * fix: fix typing * build: add conftest to proxy_admin_ui_tests/ * test: fix test * fix: fix linting errors * test: fix test * fix: fix pass through testing --- docs/my-website/docs/completion/input.md | 2 +- .../docs/guides/finetuned_models.md | 74 ++++++++++++++ docs/my-website/docs/proxy/configs.md | 2 + docs/my-website/docs/proxy/self_serve.md | 8 +- docs/my-website/docs/proxy/virtual_keys.md | 69 +++++++++++++ docs/my-website/sidebars.js | 42 ++++---- litellm/__init__.py | 3 + litellm/integrations/prometheus.py | 10 +- litellm/litellm_core_utils/litellm_logging.py | 88 ++++++----------- litellm/llms/azure_ai/rerank/handler.py | 2 + litellm/llms/cohere/embed/handler.py | 6 ++ litellm/llms/cohere/rerank.py | 37 ++++++- litellm/main.py | 4 + litellm/proxy/_new_secret_config.yaml | 4 +- litellm/proxy/_types.py | 72 +++++++++++--- .../key_management_endpoints.py | 73 ++++++++++++++ .../organization_endpoints.py | 4 +- .../anthropic_passthrough_logging_handler.py | 39 ++++---- .../vertex_passthrough_logging_handler.py | 55 ++++++----- .../streaming_handler.py | 68 ++++++++++--- .../pass_through_endpoints/success_handler.py | 97 +++++++++++-------- litellm/proxy/proxy_server.py | 4 +- litellm/proxy/utils.py | 15 ++- litellm/rerank_api/main.py | 4 +- litellm/types/utils.py | 13 +++ litellm/utils.py | 10 ++ tests/local_testing/test_embedding.py | 31 ++++++ tests/local_testing/test_rerank.py | 34 ++++++- tests/local_testing/test_utils.py | 20 ++++ .../test_unit_tests_init_callbacks.py | 75 ++++++++++++++ .../test_unit_test_anthropic_pass_through.py | 27 +----- .../test_unit_test_streaming.py | 1 + tests/proxy_admin_ui_tests/conftest.py | 54 +++++++++++ .../test_key_management.py | 62 ++++++++++++ .../test_role_based_access.py | 10 +- 35 files changed, 871 insertions(+), 248 deletions(-) create mode 100644 docs/my-website/docs/guides/finetuned_models.md create mode 100644 tests/proxy_admin_ui_tests/conftest.py diff --git a/docs/my-website/docs/completion/input.md b/docs/my-website/docs/completion/input.md index c563a5bf0..e55c160e0 100644 --- a/docs/my-website/docs/completion/input.md +++ b/docs/my-website/docs/completion/input.md @@ -41,7 +41,7 @@ Use `litellm.get_supported_openai_params()` for an updated list of params for ea | Provider | temperature | max_completion_tokens | max_tokens | top_p | stream | stream_options | stop | n | presence_penalty | frequency_penalty | functions | function_call | logit_bias | user | response_format | seed | tools | tool_choice | logprobs | top_logprobs | extra_headers | |---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---| -|Anthropic| ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | | | | | | |✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ | +|Anthropic| ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | | | | | | |✅ | ✅ | | ✅ | ✅ | | | ✅ | |OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | ✅ | |Azure OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | | | ✅ | |Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | diff --git a/docs/my-website/docs/guides/finetuned_models.md b/docs/my-website/docs/guides/finetuned_models.md new file mode 100644 index 000000000..cb0d49b44 --- /dev/null +++ b/docs/my-website/docs/guides/finetuned_models.md @@ -0,0 +1,74 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + + +# Calling Finetuned Models + +## OpenAI + + +| Model Name | Function Call | +|---------------------------|-----------------------------------------------------------------| +| fine tuned `gpt-4-0613` | `response = completion(model="ft:gpt-4-0613", messages=messages)` | +| fine tuned `gpt-4o-2024-05-13` | `response = completion(model="ft:gpt-4o-2024-05-13", messages=messages)` | +| fine tuned `gpt-3.5-turbo-0125` | `response = completion(model="ft:gpt-3.5-turbo-0125", messages=messages)` | +| fine tuned `gpt-3.5-turbo-1106` | `response = completion(model="ft:gpt-3.5-turbo-1106", messages=messages)` | +| fine tuned `gpt-3.5-turbo-0613` | `response = completion(model="ft:gpt-3.5-turbo-0613", messages=messages)` | + + +## Vertex AI + +Fine tuned models on vertex have a numerical model/endpoint id. + + + + +```python +from litellm import completion +import os + +## set ENV variables +os.environ["VERTEXAI_PROJECT"] = "hardy-device-38811" +os.environ["VERTEXAI_LOCATION"] = "us-central1" + +response = completion( + model="vertex_ai/", # e.g. vertex_ai/4965075652664360960 + messages=[{ "content": "Hello, how are you?","role": "user"}], + base_model="vertex_ai/gemini-1.5-pro" # the base model - used for routing +) +``` + + + + +1. Add Vertex Credentials to your env + +```bash +!gcloud auth application-default login +``` + +2. Setup config.yaml + +```yaml +- model_name: finetuned-gemini + litellm_params: + model: vertex_ai/ + vertex_project: + vertex_location: + model_info: + base_model: vertex_ai/gemini-1.5-pro # IMPORTANT +``` + +3. Test it! + +```bash +curl --location 'https://0.0.0.0:4000/v1/chat/completions' \ +--header 'Content-Type: application/json' \ +--header 'Authorization: ' \ +--data '{"model": "finetuned-gemini" ,"messages":[{"role": "user", "content":[{"type": "text", "text": "hi"}]}]}' +``` + + + + + diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md index 3b6b336d6..df22a29e3 100644 --- a/docs/my-website/docs/proxy/configs.md +++ b/docs/my-website/docs/proxy/configs.md @@ -754,6 +754,8 @@ general_settings: | cache_params.s3_endpoint_url | string | Optional - The endpoint URL for the S3 bucket. | | cache_params.supported_call_types | array of strings | The types of calls to cache. [Further docs](./caching) | | cache_params.mode | string | The mode of the cache. [Further docs](./caching) | +| disable_end_user_cost_tracking | boolean | If true, turns off end user cost tracking on prometheus metrics + litellm spend logs table on proxy. | +| key_generation_settings | object | Restricts who can generate keys. [Further docs](./virtual_keys.md#restricting-key-generation) | ### general_settings - Reference diff --git a/docs/my-website/docs/proxy/self_serve.md b/docs/my-website/docs/proxy/self_serve.md index e04aa4b44..494d9e60d 100644 --- a/docs/my-website/docs/proxy/self_serve.md +++ b/docs/my-website/docs/proxy/self_serve.md @@ -217,4 +217,10 @@ litellm_settings: max_parallel_requests: 1000 # (Optional[int], optional): Max number of requests that can be made in parallel. Defaults to None. tpm_limit: 1000 #(Optional[int], optional): Tpm limit. Defaults to None. rpm_limit: 1000 #(Optional[int], optional): Rpm limit. Defaults to None. -``` \ No newline at end of file + + key_generation_settings: # Restricts who can generate keys. [Further docs](./virtual_keys.md#restricting-key-generation) + team_key_generation: + allowed_team_member_roles: ["admin"] + personal_key_generation: # maps to 'Default Team' on UI + allowed_user_roles: ["proxy_admin"] +``` diff --git a/docs/my-website/docs/proxy/virtual_keys.md b/docs/my-website/docs/proxy/virtual_keys.md index 3b9a2a03e..98b06d33b 100644 --- a/docs/my-website/docs/proxy/virtual_keys.md +++ b/docs/my-website/docs/proxy/virtual_keys.md @@ -811,6 +811,75 @@ litellm_settings: team_id: "core-infra" ``` +### Restricting Key Generation + +Use this to control who can generate keys. Useful when letting others create keys on the UI. + +```yaml +litellm_settings: + key_generation_settings: + team_key_generation: + allowed_team_member_roles: ["admin"] + personal_key_generation: # maps to 'Default Team' on UI + allowed_user_roles: ["proxy_admin"] +``` + +#### Spec + +```python +class TeamUIKeyGenerationConfig(TypedDict): + allowed_team_member_roles: List[str] + + +class PersonalUIKeyGenerationConfig(TypedDict): + allowed_user_roles: List[LitellmUserRoles] + + +class StandardKeyGenerationConfig(TypedDict, total=False): + team_key_generation: TeamUIKeyGenerationConfig + personal_key_generation: PersonalUIKeyGenerationConfig + + +class LitellmUserRoles(str, enum.Enum): + """ + Admin Roles: + PROXY_ADMIN: admin over the platform + PROXY_ADMIN_VIEW_ONLY: can login, view all own keys, view all spend + ORG_ADMIN: admin over a specific organization, can create teams, users only within their organization + + Internal User Roles: + INTERNAL_USER: can login, view/create/delete their own keys, view their spend + INTERNAL_USER_VIEW_ONLY: can login, view their own keys, view their own spend + + + Team Roles: + TEAM: used for JWT auth + + + Customer Roles: + CUSTOMER: External users -> these are customers + + """ + + # Admin Roles + PROXY_ADMIN = "proxy_admin" + PROXY_ADMIN_VIEW_ONLY = "proxy_admin_viewer" + + # Organization admins + ORG_ADMIN = "org_admin" + + # Internal User Roles + INTERNAL_USER = "internal_user" + INTERNAL_USER_VIEW_ONLY = "internal_user_viewer" + + # Team Roles + TEAM = "team" + + # Customer Roles - External users of proxy + CUSTOMER = "customer" +``` + + ## **Next Steps - Set Budgets, Rate Limits per Virtual Key** [Follow this doc to set budgets, rate limiters per virtual key with LiteLLM](users) diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index f01402299..f2bb1c5e9 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -199,6 +199,31 @@ const sidebars = { ], }, + { + type: "category", + label: "Guides", + items: [ + "exception_mapping", + "completion/provider_specific_params", + "guides/finetuned_models", + "completion/audio", + "completion/vision", + "completion/json_mode", + "completion/prompt_caching", + "completion/predict_outputs", + "completion/prefix", + "completion/drop_params", + "completion/prompt_formatting", + "completion/stream", + "completion/message_trimming", + "completion/function_call", + "completion/model_alias", + "completion/batching", + "completion/mock_requests", + "completion/reliable_completions", + + ] + }, { type: "category", label: "Supported Endpoints", @@ -214,25 +239,8 @@ const sidebars = { }, items: [ "completion/input", - "completion/provider_specific_params", - "completion/json_mode", - "completion/prompt_caching", - "completion/audio", - "completion/vision", - "completion/predict_outputs", - "completion/prefix", - "completion/drop_params", - "completion/prompt_formatting", "completion/output", "completion/usage", - "exception_mapping", - "completion/stream", - "completion/message_trimming", - "completion/function_call", - "completion/model_alias", - "completion/batching", - "completion/mock_requests", - "completion/reliable_completions", ], }, "embedding/supported_embedding", diff --git a/litellm/__init__.py b/litellm/__init__.py index c978b24ee..65b1b3465 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -24,6 +24,7 @@ from litellm.proxy._types import ( KeyManagementSettings, LiteLLM_UpperboundKeyGenerateParams, ) +from litellm.types.utils import StandardKeyGenerationConfig import httpx import dotenv from enum import Enum @@ -273,6 +274,7 @@ s3_callback_params: Optional[Dict] = None generic_logger_headers: Optional[Dict] = None default_key_generate_params: Optional[Dict] = None upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None +key_generation_settings: Optional[StandardKeyGenerationConfig] = None default_internal_user_params: Optional[Dict] = None default_team_settings: Optional[List] = None max_user_budget: Optional[float] = None @@ -280,6 +282,7 @@ default_max_internal_user_budget: Optional[float] = None max_internal_user_budget: Optional[float] = None internal_user_budget_duration: Optional[str] = None max_end_user_budget: Optional[float] = None +disable_end_user_cost_tracking: Optional[bool] = None #### REQUEST PRIORITIZATION #### priority_reservation: Optional[Dict[str, float]] = None #### RELIABILITY #### diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index bb28719a3..1460a1d7f 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -18,6 +18,7 @@ from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import UserAPIKeyAuth from litellm.types.integrations.prometheus import * from litellm.types.utils import StandardLoggingPayload +from litellm.utils import get_end_user_id_for_cost_tracking class PrometheusLogger(CustomLogger): @@ -364,8 +365,7 @@ class PrometheusLogger(CustomLogger): model = kwargs.get("model", "") litellm_params = kwargs.get("litellm_params", {}) or {} _metadata = litellm_params.get("metadata", {}) - proxy_server_request = litellm_params.get("proxy_server_request") or {} - end_user_id = proxy_server_request.get("body", {}).get("user", None) + end_user_id = get_end_user_id_for_cost_tracking(litellm_params) user_id = standard_logging_payload["metadata"]["user_api_key_user_id"] user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"] user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"] @@ -664,13 +664,11 @@ class PrometheusLogger(CustomLogger): # unpack kwargs model = kwargs.get("model", "") - litellm_params = kwargs.get("litellm_params", {}) or {} standard_logging_payload: StandardLoggingPayload = kwargs.get( "standard_logging_object", {} ) - proxy_server_request = litellm_params.get("proxy_server_request") or {} - - end_user_id = proxy_server_request.get("body", {}).get("user", None) + litellm_params = kwargs.get("litellm_params", {}) or {} + end_user_id = get_end_user_id_for_cost_tracking(litellm_params) user_id = standard_logging_payload["metadata"]["user_api_key_user_id"] user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"] user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"] diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 69d6adca4..298e28974 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -934,19 +934,10 @@ class Logging: status="success", ) ) - if self.dynamic_success_callbacks is not None and isinstance( - self.dynamic_success_callbacks, list - ): - callbacks = self.dynamic_success_callbacks - ## keep the internal functions ## - for callback in litellm.success_callback: - if ( - isinstance(callback, CustomLogger) - and "_PROXY_" in callback.__class__.__name__ - ): - callbacks.append(callback) - else: - callbacks = litellm.success_callback + callbacks = get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_success_callbacks, + global_callbacks=litellm.success_callback, + ) ## REDACT MESSAGES ## result = redact_message_input_output_from_logging( @@ -1368,8 +1359,11 @@ class Logging: and customLogger is not None ): # custom logger functions print_verbose( - "success callbacks: Running Custom Callback Function" + "success callbacks: Running Custom Callback Function - {}".format( + callback + ) ) + customLogger.log_event( kwargs=self.model_call_details, response_obj=result, @@ -1466,21 +1460,10 @@ class Logging: status="success", ) ) - if self.dynamic_async_success_callbacks is not None and isinstance( - self.dynamic_async_success_callbacks, list - ): - callbacks = self.dynamic_async_success_callbacks - ## keep the internal functions ## - for callback in litellm._async_success_callback: - callback_name = "" - if isinstance(callback, CustomLogger): - callback_name = callback.__class__.__name__ - if callable(callback): - callback_name = callback.__name__ - if "_PROXY_" in callback_name: - callbacks.append(callback) - else: - callbacks = litellm._async_success_callback + callbacks = get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_async_success_callbacks, + global_callbacks=litellm._async_success_callback, + ) result = redact_message_input_output_from_logging( model_call_details=( @@ -1747,21 +1730,10 @@ class Logging: start_time=start_time, end_time=end_time, ) - callbacks = [] # init this to empty incase it's not created - - if self.dynamic_failure_callbacks is not None and isinstance( - self.dynamic_failure_callbacks, list - ): - callbacks = self.dynamic_failure_callbacks - ## keep the internal functions ## - for callback in litellm.failure_callback: - if ( - isinstance(callback, CustomLogger) - and "_PROXY_" in callback.__class__.__name__ - ): - callbacks.append(callback) - else: - callbacks = litellm.failure_callback + callbacks = get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_failure_callbacks, + global_callbacks=litellm.failure_callback, + ) result = None # result sent to all loggers, init this to None incase it's not created @@ -1944,21 +1916,10 @@ class Logging: end_time=end_time, ) - callbacks = [] # init this to empty incase it's not created - - if self.dynamic_async_failure_callbacks is not None and isinstance( - self.dynamic_async_failure_callbacks, list - ): - callbacks = self.dynamic_async_failure_callbacks - ## keep the internal functions ## - for callback in litellm._async_failure_callback: - if ( - isinstance(callback, CustomLogger) - and "_PROXY_" in callback.__class__.__name__ - ): - callbacks.append(callback) - else: - callbacks = litellm._async_failure_callback + callbacks = get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_async_failure_callbacks, + global_callbacks=litellm._async_failure_callback, + ) result = None # result sent to all loggers, init this to None incase it's not created for callback in callbacks: @@ -2359,6 +2320,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 _in_memory_loggers.append(_mlflow_logger) return _mlflow_logger # type: ignore + def get_custom_logger_compatible_class( logging_integration: litellm._custom_logger_compatible_callbacks_literal, ) -> Optional[CustomLogger]: @@ -2949,3 +2911,11 @@ def modify_integration(integration_name, integration_params): if integration_name == "supabase": if "table_name" in integration_params: Supabase.supabase_table_name = integration_params["table_name"] + + +def get_combined_callback_list( + dynamic_success_callbacks: Optional[List], global_callbacks: List +) -> List: + if dynamic_success_callbacks is None: + return global_callbacks + return list(set(dynamic_success_callbacks + global_callbacks)) diff --git a/litellm/llms/azure_ai/rerank/handler.py b/litellm/llms/azure_ai/rerank/handler.py index a67c893f2..60edfd296 100644 --- a/litellm/llms/azure_ai/rerank/handler.py +++ b/litellm/llms/azure_ai/rerank/handler.py @@ -4,6 +4,7 @@ import httpx from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.llms.cohere.rerank import CohereRerank +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.types.rerank import RerankResponse @@ -73,6 +74,7 @@ class AzureAIRerank(CohereRerank): return_documents: Optional[bool] = True, max_chunks_per_doc: Optional[int] = None, _is_async: Optional[bool] = False, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ) -> RerankResponse: if headers is None: diff --git a/litellm/llms/cohere/embed/handler.py b/litellm/llms/cohere/embed/handler.py index 5b224c375..afeba10b5 100644 --- a/litellm/llms/cohere/embed/handler.py +++ b/litellm/llms/cohere/embed/handler.py @@ -74,6 +74,7 @@ async def async_embedding( }, ) ## COMPLETION CALL + if client is None: client = get_async_httpx_client( llm_provider=litellm.LlmProviders.COHERE, @@ -151,6 +152,11 @@ def embedding( api_key=api_key, headers=headers, encoding=encoding, + client=( + client + if client is not None and isinstance(client, AsyncHTTPHandler) + else None + ), ) ## LOGGING diff --git a/litellm/llms/cohere/rerank.py b/litellm/llms/cohere/rerank.py index 022ffc6f9..8de2dfbb4 100644 --- a/litellm/llms/cohere/rerank.py +++ b/litellm/llms/cohere/rerank.py @@ -6,10 +6,14 @@ LiteLLM supports the re rank API format, no paramter transformation occurs from typing import Any, Dict, List, Optional, Union +import httpx + import litellm from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.llms.base import BaseLLM from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + HTTPHandler, _get_httpx_client, get_async_httpx_client, ) @@ -34,6 +38,23 @@ class CohereRerank(BaseLLM): # Merge other headers, overriding any default ones except Authorization return {**default_headers, **headers} + def ensure_rerank_endpoint(self, api_base: str) -> str: + """ + Ensures the `/v1/rerank` endpoint is appended to the given `api_base`. + If `/v1/rerank` is already present, the original URL is returned. + + :param api_base: The base API URL. + :return: A URL with `/v1/rerank` appended if missing. + """ + # Parse the base URL to ensure proper structure + url = httpx.URL(api_base) + + # Check if the URL already ends with `/v1/rerank` + if not url.path.endswith("/v1/rerank"): + url = url.copy_with(path=f"{url.path.rstrip('/')}/v1/rerank") + + return str(url) + def rerank( self, model: str, @@ -48,9 +69,10 @@ class CohereRerank(BaseLLM): return_documents: Optional[bool] = True, max_chunks_per_doc: Optional[int] = None, _is_async: Optional[bool] = False, # New parameter + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ) -> RerankResponse: headers = self.validate_environment(api_key=api_key, headers=headers) - + api_base = self.ensure_rerank_endpoint(api_base) request_data = RerankRequest( model=model, query=query, @@ -76,9 +98,13 @@ class CohereRerank(BaseLLM): if _is_async: return self.async_rerank(request_data=request_data, api_key=api_key, api_base=api_base, headers=headers) # type: ignore # Call async method - client = _get_httpx_client() + if client is not None and isinstance(client, HTTPHandler): + client = client + else: + client = _get_httpx_client() + response = client.post( - api_base, + url=api_base, headers=headers, json=request_data_dict, ) @@ -100,10 +126,13 @@ class CohereRerank(BaseLLM): api_key: str, api_base: str, headers: dict, + client: Optional[AsyncHTTPHandler] = None, ) -> RerankResponse: request_data_dict = request_data.dict(exclude_none=True) - client = get_async_httpx_client(llm_provider=litellm.LlmProviders.COHERE) + client = client or get_async_httpx_client( + llm_provider=litellm.LlmProviders.COHERE + ) response = await client.post( api_base, diff --git a/litellm/main.py b/litellm/main.py index 5d433eb36..5095ce518 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3440,6 +3440,10 @@ def embedding( # noqa: PLR0915 or litellm.openai_key or get_secret_str("OPENAI_API_KEY") ) + + if extra_headers is not None: + optional_params["extra_headers"] = extra_headers + api_type = "openai" api_version = None diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index ce9bd1d2f..7baf2224c 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -16,7 +16,7 @@ model_list: model: openai/fake api_key: fake-key api_base: https://exampleopenaiendpoint-production.up.railway.app/ - + router_settings: model_group_alias: "gpt-4-turbo": # Aliased model name @@ -35,4 +35,4 @@ litellm_settings: failure_callback: ["langfuse"] langfuse_public_key: os.environ/LANGFUSE_PROJECT2_PUBLIC # Project 2 langfuse_secret: os.environ/LANGFUSE_PROJECT2_SECRET # Project 2 - langfuse_host: https://us.cloud.langfuse.com \ No newline at end of file + langfuse_host: https://us.cloud.langfuse.com diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 8b8dbf2e5..74e82b0ea 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2,6 +2,7 @@ import enum import json import os import sys +import traceback import uuid from dataclasses import fields from datetime import datetime @@ -12,7 +13,15 @@ from typing_extensions import Annotated, TypedDict from litellm.types.integrations.slack_alerting import AlertType from litellm.types.router import RouterErrors, UpdateRouterConfig -from litellm.types.utils import ProviderField, StandardCallbackDynamicParams +from litellm.types.utils import ( + EmbeddingResponse, + ImageResponse, + ModelResponse, + ProviderField, + StandardCallbackDynamicParams, + StandardPassThroughResponseObject, + TextCompletionResponse, +) if TYPE_CHECKING: from opentelemetry.trace import Span as _Span @@ -882,15 +891,7 @@ class DeleteCustomerRequest(LiteLLMBase): user_ids: List[str] -class Member(LiteLLMBase): - role: Literal[ - LitellmUserRoles.ORG_ADMIN, - LitellmUserRoles.INTERNAL_USER, - LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, - # older Member roles - "admin", - "user", - ] +class MemberBase(LiteLLMBase): user_id: Optional[str] = None user_email: Optional[str] = None @@ -904,6 +905,21 @@ class Member(LiteLLMBase): return values +class Member(MemberBase): + role: Literal[ + "admin", + "user", + ] + + +class OrgMember(MemberBase): + role: Literal[ + LitellmUserRoles.ORG_ADMIN, + LitellmUserRoles.INTERNAL_USER, + LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, + ] + + class TeamBase(LiteLLMBase): team_alias: Optional[str] = None team_id: Optional[str] = None @@ -1966,6 +1982,26 @@ class MemberAddRequest(LiteLLMBase): # Replace member_data with the single Member object data["member"] = member # Call the superclass __init__ method to initialize the object + traceback.print_stack() + super().__init__(**data) + + +class OrgMemberAddRequest(LiteLLMBase): + member: Union[List[OrgMember], OrgMember] + + def __init__(self, **data): + member_data = data.get("member") + if isinstance(member_data, list): + # If member is a list of dictionaries, convert each dictionary to a Member object + members = [OrgMember(**item) for item in member_data] + # Replace member_data with the list of Member objects + data["member"] = members + elif isinstance(member_data, dict): + # If member is a dictionary, convert it to a single Member object + member = OrgMember(**member_data) + # Replace member_data with the single Member object + data["member"] = member + # Call the superclass __init__ method to initialize the object super().__init__(**data) @@ -2017,7 +2053,7 @@ class TeamMemberUpdateResponse(MemberUpdateResponse): # Organization Member Requests -class OrganizationMemberAddRequest(MemberAddRequest): +class OrganizationMemberAddRequest(OrgMemberAddRequest): organization_id: str max_budget_in_organization: Optional[float] = ( None # Users max budget within the organization @@ -2133,3 +2169,17 @@ class UserManagementEndpointParamDocStringEnums(str, enum.Enum): spend_doc_str = """Optional[float] - Amount spent by user. Default is 0. Will be updated by proxy whenever user is used.""" team_id_doc_str = """Optional[str] - [DEPRECATED PARAM] The team id of the user. Default is None.""" duration_doc_str = """Optional[str] - Duration for the key auto-created on `/user/new`. Default is None.""" + + +PassThroughEndpointLoggingResultValues = Union[ + ModelResponse, + TextCompletionResponse, + ImageResponse, + EmbeddingResponse, + StandardPassThroughResponseObject, +] + + +class PassThroughEndpointLoggingTypedDict(TypedDict): + result: Optional[PassThroughEndpointLoggingResultValues] + kwargs: dict diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index e4493a28c..ab13616d5 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -40,6 +40,77 @@ from litellm.proxy.utils import ( ) from litellm.secret_managers.main import get_secret + +def _is_team_key(data: GenerateKeyRequest): + return data.team_id is not None + + +def _team_key_generation_check(user_api_key_dict: UserAPIKeyAuth): + if ( + litellm.key_generation_settings is None + or litellm.key_generation_settings.get("team_key_generation") is None + ): + return True + + if user_api_key_dict.team_member is None: + raise HTTPException( + status_code=400, + detail=f"User not assigned to team. Got team_member={user_api_key_dict.team_member}", + ) + + team_member_role = user_api_key_dict.team_member.role + if ( + team_member_role + not in litellm.key_generation_settings["team_key_generation"][ # type: ignore + "allowed_team_member_roles" + ] + ): + raise HTTPException( + status_code=400, + detail=f"Team member role {team_member_role} not in allowed_team_member_roles={litellm.key_generation_settings['team_key_generation']['allowed_team_member_roles']}", # type: ignore + ) + return True + + +def _personal_key_generation_check(user_api_key_dict: UserAPIKeyAuth): + + if ( + litellm.key_generation_settings is None + or litellm.key_generation_settings.get("personal_key_generation") is None + ): + return True + + if ( + user_api_key_dict.user_role + not in litellm.key_generation_settings["personal_key_generation"][ # type: ignore + "allowed_user_roles" + ] + ): + raise HTTPException( + status_code=400, + detail=f"Personal key creation has been restricted by admin. Allowed roles={litellm.key_generation_settings['personal_key_generation']['allowed_user_roles']}. Your role={user_api_key_dict.user_role}", # type: ignore + ) + return True + + +def key_generation_check( + user_api_key_dict: UserAPIKeyAuth, data: GenerateKeyRequest +) -> bool: + """ + Check if admin has restricted key creation to certain roles for teams or individuals + """ + if litellm.key_generation_settings is None: + return True + + ## check if key is for team or individual + is_team_key = _is_team_key(data=data) + + if is_team_key: + return _team_key_generation_check(user_api_key_dict) + else: + return _personal_key_generation_check(user_api_key_dict=user_api_key_dict) + + router = APIRouter() @@ -131,6 +202,8 @@ async def generate_key_fn( # noqa: PLR0915 raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=message ) + elif litellm.key_generation_settings is not None: + key_generation_check(user_api_key_dict=user_api_key_dict, data=data) # check if user set default key/generate params on config.yaml if litellm.default_key_generate_params is not None: for elem in data: diff --git a/litellm/proxy/management_endpoints/organization_endpoints.py b/litellm/proxy/management_endpoints/organization_endpoints.py index 81d135097..363384375 100644 --- a/litellm/proxy/management_endpoints/organization_endpoints.py +++ b/litellm/proxy/management_endpoints/organization_endpoints.py @@ -352,7 +352,7 @@ async def organization_member_add( }, ) - members: List[Member] + members: List[OrgMember] if isinstance(data.member, List): members = data.member else: @@ -397,7 +397,7 @@ async def organization_member_add( async def add_member_to_organization( - member: Member, + member: OrgMember, organization_id: str, prisma_client: PrismaClient, ) -> Tuple[LiteLLM_UserTable, LiteLLM_OrganizationMembershipTable]: diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py index ad5a98258..d155174a7 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py @@ -14,6 +14,7 @@ from litellm.llms.anthropic.chat.handler import ( ModelResponseIterator as AnthropicModelResponseIterator, ) from litellm.llms.anthropic.chat.transformation import AnthropicConfig +from litellm.proxy._types import PassThroughEndpointLoggingTypedDict if TYPE_CHECKING: from ..success_handler import PassThroughEndpointLogging @@ -26,7 +27,7 @@ else: class AnthropicPassthroughLoggingHandler: @staticmethod - async def anthropic_passthrough_handler( + def anthropic_passthrough_handler( httpx_response: httpx.Response, response_body: dict, logging_obj: LiteLLMLoggingObj, @@ -36,7 +37,7 @@ class AnthropicPassthroughLoggingHandler: end_time: datetime, cache_hit: bool, **kwargs, - ): + ) -> PassThroughEndpointLoggingTypedDict: """ Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled """ @@ -67,15 +68,10 @@ class AnthropicPassthroughLoggingHandler: logging_obj=logging_obj, ) - await logging_obj.async_success_handler( - result=litellm_model_response, - start_time=start_time, - end_time=end_time, - cache_hit=cache_hit, - **kwargs, - ) - - pass + return { + "result": litellm_model_response, + "kwargs": kwargs, + } @staticmethod def _create_anthropic_response_logging_payload( @@ -123,7 +119,7 @@ class AnthropicPassthroughLoggingHandler: return kwargs @staticmethod - async def _handle_logging_anthropic_collected_chunks( + def _handle_logging_anthropic_collected_chunks( litellm_logging_obj: LiteLLMLoggingObj, passthrough_success_handler_obj: PassThroughEndpointLogging, url_route: str, @@ -132,7 +128,7 @@ class AnthropicPassthroughLoggingHandler: start_time: datetime, all_chunks: List[str], end_time: datetime, - ): + ) -> PassThroughEndpointLoggingTypedDict: """ Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks @@ -152,7 +148,10 @@ class AnthropicPassthroughLoggingHandler: verbose_proxy_logger.error( "Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..." ) - return + return { + "result": None, + "kwargs": {}, + } kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload( litellm_model_response=complete_streaming_response, model=model, @@ -161,13 +160,11 @@ class AnthropicPassthroughLoggingHandler: end_time=end_time, logging_obj=litellm_logging_obj, ) - await litellm_logging_obj.async_success_handler( - result=complete_streaming_response, - start_time=start_time, - end_time=end_time, - cache_hit=False, - **kwargs, - ) + + return { + "result": complete_streaming_response, + "kwargs": kwargs, + } @staticmethod def _build_complete_streaming_response( diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py index 275a0a119..2773979ad 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/vertex_passthrough_logging_handler.py @@ -14,6 +14,7 @@ from litellm.litellm_core_utils.litellm_logging import ( from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( ModelResponseIterator as VertexModelResponseIterator, ) +from litellm.proxy._types import PassThroughEndpointLoggingTypedDict if TYPE_CHECKING: from ..success_handler import PassThroughEndpointLogging @@ -25,7 +26,7 @@ else: class VertexPassthroughLoggingHandler: @staticmethod - async def vertex_passthrough_handler( + def vertex_passthrough_handler( httpx_response: httpx.Response, logging_obj: LiteLLMLoggingObj, url_route: str, @@ -34,7 +35,7 @@ class VertexPassthroughLoggingHandler: end_time: datetime, cache_hit: bool, **kwargs, - ): + ) -> PassThroughEndpointLoggingTypedDict: if "generateContent" in url_route: model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route) @@ -65,13 +66,11 @@ class VertexPassthroughLoggingHandler: logging_obj=logging_obj, ) - await logging_obj.async_success_handler( - result=litellm_model_response, - start_time=start_time, - end_time=end_time, - cache_hit=cache_hit, - **kwargs, - ) + return { + "result": litellm_model_response, + "kwargs": kwargs, + } + elif "predict" in url_route: from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import ( VertexImageGeneration, @@ -112,16 +111,18 @@ class VertexPassthroughLoggingHandler: logging_obj.model = model logging_obj.model_call_details["model"] = logging_obj.model - await logging_obj.async_success_handler( - result=litellm_prediction_response, - start_time=start_time, - end_time=end_time, - cache_hit=cache_hit, - **kwargs, - ) + return { + "result": litellm_prediction_response, + "kwargs": kwargs, + } + else: + return { + "result": None, + "kwargs": kwargs, + } @staticmethod - async def _handle_logging_vertex_collected_chunks( + def _handle_logging_vertex_collected_chunks( litellm_logging_obj: LiteLLMLoggingObj, passthrough_success_handler_obj: PassThroughEndpointLogging, url_route: str, @@ -130,7 +131,7 @@ class VertexPassthroughLoggingHandler: start_time: datetime, all_chunks: List[str], end_time: datetime, - ): + ) -> PassThroughEndpointLoggingTypedDict: """ Takes raw chunks from Vertex passthrough endpoint and logs them in litellm callbacks @@ -152,7 +153,11 @@ class VertexPassthroughLoggingHandler: verbose_proxy_logger.error( "Unable to build complete streaming response for Vertex passthrough endpoint, not logging..." ) - return + return { + "result": None, + "kwargs": kwargs, + } + kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content( litellm_model_response=complete_streaming_response, model=model, @@ -161,13 +166,11 @@ class VertexPassthroughLoggingHandler: end_time=end_time, logging_obj=litellm_logging_obj, ) - await litellm_logging_obj.async_success_handler( - result=complete_streaming_response, - start_time=start_time, - end_time=end_time, - cache_hit=False, - **kwargs, - ) + + return { + "result": complete_streaming_response, + "kwargs": kwargs, + } @staticmethod def _build_complete_streaming_response( diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index 522319aaa..dc6aae3af 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -1,5 +1,6 @@ import asyncio import json +import threading from datetime import datetime from enum import Enum from typing import AsyncIterable, Dict, List, Optional, Union @@ -15,7 +16,12 @@ from litellm.llms.anthropic.chat.handler import ( from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( ModelResponseIterator as VertexAIIterator, ) -from litellm.types.utils import GenericStreamingChunk +from litellm.proxy._types import PassThroughEndpointLoggingResultValues +from litellm.types.utils import ( + GenericStreamingChunk, + ModelResponse, + StandardPassThroughResponseObject, +) from .llm_provider_handlers.anthropic_passthrough_logging_handler import ( AnthropicPassthroughLoggingHandler, @@ -87,8 +93,12 @@ class PassThroughStreamingHandler: all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines( raw_bytes ) + standard_logging_response_object: Optional[ + PassThroughEndpointLoggingResultValues + ] = None + kwargs: dict = {} if endpoint_type == EndpointType.ANTHROPIC: - await AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( + anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks( litellm_logging_obj=litellm_logging_obj, passthrough_success_handler_obj=passthrough_success_handler_obj, url_route=url_route, @@ -98,20 +108,48 @@ class PassThroughStreamingHandler: all_chunks=all_chunks, end_time=end_time, ) + standard_logging_response_object = anthropic_passthrough_logging_handler_result[ + "result" + ] + kwargs = anthropic_passthrough_logging_handler_result["kwargs"] elif endpoint_type == EndpointType.VERTEX_AI: - await VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body, - endpoint_type=endpoint_type, - start_time=start_time, - all_chunks=all_chunks, - end_time=end_time, + vertex_passthrough_logging_handler_result = ( + VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body, + endpoint_type=endpoint_type, + start_time=start_time, + all_chunks=all_chunks, + end_time=end_time, + ) ) - elif endpoint_type == EndpointType.GENERIC: - # No logging is supported for generic streaming endpoints - pass + standard_logging_response_object = vertex_passthrough_logging_handler_result[ + "result" + ] + kwargs = vertex_passthrough_logging_handler_result["kwargs"] + + if standard_logging_response_object is None: + standard_logging_response_object = StandardPassThroughResponseObject( + response=f"cannot parse chunks to standard response object. Chunks={all_chunks}" + ) + threading.Thread( + target=litellm_logging_obj.success_handler, + args=( + standard_logging_response_object, + start_time, + end_time, + False, + ), + ).start() + await litellm_logging_obj.async_success_handler( + result=standard_logging_response_object, + start_time=start_time, + end_time=end_time, + cache_hit=False, + **kwargs, + ) @staticmethod def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]: @@ -130,4 +168,4 @@ class PassThroughStreamingHandler: # Split by newlines and filter out empty lines lines = [line.strip() for line in combined_str.split("\n") if line.strip()] - return lines + return lines \ No newline at end of file diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index e22a37052..c9c7707f0 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -15,6 +15,7 @@ from litellm.litellm_core_utils.litellm_logging import ( from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, ) +from litellm.proxy._types import PassThroughEndpointLoggingResultValues from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.types.utils import StandardPassThroughResponseObject @@ -49,53 +50,69 @@ class PassThroughEndpointLogging: cache_hit: bool, **kwargs, ): + standard_logging_response_object: Optional[ + PassThroughEndpointLoggingResultValues + ] = None if self.is_vertex_route(url_route): - await VertexPassthroughLoggingHandler.vertex_passthrough_handler( - httpx_response=httpx_response, - logging_obj=logging_obj, - url_route=url_route, - result=result, - start_time=start_time, - end_time=end_time, - cache_hit=cache_hit, - **kwargs, + vertex_passthrough_logging_handler_result = ( + VertexPassthroughLoggingHandler.vertex_passthrough_handler( + httpx_response=httpx_response, + logging_obj=logging_obj, + url_route=url_route, + result=result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + **kwargs, + ) ) + standard_logging_response_object = ( + vertex_passthrough_logging_handler_result["result"] + ) + kwargs = vertex_passthrough_logging_handler_result["kwargs"] elif self.is_anthropic_route(url_route): - await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler( - httpx_response=httpx_response, - response_body=response_body or {}, - logging_obj=logging_obj, - url_route=url_route, - result=result, - start_time=start_time, - end_time=end_time, - cache_hit=cache_hit, - **kwargs, + anthropic_passthrough_logging_handler_result = ( + AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler( + httpx_response=httpx_response, + response_body=response_body or {}, + logging_obj=logging_obj, + url_route=url_route, + result=result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + **kwargs, + ) ) - else: + + standard_logging_response_object = ( + anthropic_passthrough_logging_handler_result["result"] + ) + kwargs = anthropic_passthrough_logging_handler_result["kwargs"] + if standard_logging_response_object is None: standard_logging_response_object = StandardPassThroughResponseObject( response=httpx_response.text ) - threading.Thread( - target=logging_obj.success_handler, - args=( - standard_logging_response_object, - start_time, - end_time, - cache_hit, - ), - ).start() - await logging_obj.async_success_handler( - result=( - json.dumps(result) - if isinstance(result, dict) - else standard_logging_response_object - ), - start_time=start_time, - end_time=end_time, - cache_hit=False, - **kwargs, - ) + threading.Thread( + target=logging_obj.success_handler, + args=( + standard_logging_response_object, + start_time, + end_time, + cache_hit, + ), + ).start() + await logging_obj.async_success_handler( + result=( + json.dumps(result) + if isinstance(result, dict) + else standard_logging_response_object + ), + start_time=start_time, + end_time=end_time, + cache_hit=False, + **kwargs, + ) def is_vertex_route(self, url_route: str): for route in self.TRACKED_VERTEX_ROUTES: diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 9d7c120a7..70bf5b523 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -268,6 +268,7 @@ from litellm.types.llms.anthropic import ( from litellm.types.llms.openai import HttpxBinaryResponseContent from litellm.types.router import RouterGeneralSettings from litellm.types.utils import StandardLoggingPayload +from litellm.utils import get_end_user_id_for_cost_tracking try: from litellm._version import version @@ -763,8 +764,7 @@ async def _PROXY_track_cost_callback( ) parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs) litellm_params = kwargs.get("litellm_params", {}) or {} - proxy_server_request = litellm_params.get("proxy_server_request") or {} - end_user_id = proxy_server_request.get("body", {}).get("user", None) + end_user_id = get_end_user_id_for_cost_tracking(litellm_params) metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs) user_id = metadata.get("user_api_key_user_id", None) team_id = metadata.get("user_api_key_team_id", None) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 74bf398e7..0f7d6c3e0 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -337,14 +337,14 @@ class ProxyLogging: alert_to_webhook_url=self.alert_to_webhook_url, ) - if ( - self.alerting is not None - and "slack" in self.alerting - and "daily_reports" in self.alert_types - ): + if self.alerting is not None and "slack" in self.alerting: # NOTE: ENSURE we only add callbacks when alerting is on # We should NOT add callbacks when alerting is off - litellm.callbacks.append(self.slack_alerting_instance) # type: ignore + if "daily_reports" in self.alert_types: + litellm.callbacks.append(self.slack_alerting_instance) # type: ignore + litellm.success_callback.append( + self.slack_alerting_instance.response_taking_too_long_callback + ) if redis_cache is not None: self.internal_usage_cache.dual_cache.redis_cache = redis_cache @@ -354,9 +354,6 @@ class ProxyLogging: litellm.callbacks.append(self.max_budget_limiter) # type: ignore litellm.callbacks.append(self.cache_control_check) # type: ignore litellm.callbacks.append(self.service_logging_obj) # type: ignore - litellm.success_callback.append( - self.slack_alerting_instance.response_taking_too_long_callback - ) for callback in litellm.callbacks: if isinstance(callback, str): callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py index 9cc8a8c1d..7e6dc7503 100644 --- a/litellm/rerank_api/main.py +++ b/litellm/rerank_api/main.py @@ -91,6 +91,7 @@ def rerank( model_info = kwargs.get("model_info", None) metadata = kwargs.get("metadata", {}) user = kwargs.get("user", None) + client = kwargs.get("client", None) try: _is_async = kwargs.pop("arerank", False) is True optional_params = GenericLiteLLMParams(**kwargs) @@ -150,7 +151,7 @@ def rerank( or optional_params.api_base or litellm.api_base or get_secret("COHERE_API_BASE") # type: ignore - or "https://api.cohere.com/v1/rerank" + or "https://api.cohere.com" ) if api_base is None: @@ -173,6 +174,7 @@ def rerank( _is_async=_is_async, headers=headers, litellm_logging_obj=litellm_logging_obj, + client=client, ) elif _custom_llm_provider == "azure_ai": api_base = ( diff --git a/litellm/types/utils.py b/litellm/types/utils.py index d02129681..334894320 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1602,3 +1602,16 @@ class StandardCallbackDynamicParams(TypedDict, total=False): langsmith_api_key: Optional[str] langsmith_project: Optional[str] langsmith_base_url: Optional[str] + + +class TeamUIKeyGenerationConfig(TypedDict): + allowed_team_member_roles: List[str] + + +class PersonalUIKeyGenerationConfig(TypedDict): + allowed_user_roles: List[str] + + +class StandardKeyGenerationConfig(TypedDict, total=False): + team_key_generation: TeamUIKeyGenerationConfig + personal_key_generation: PersonalUIKeyGenerationConfig diff --git a/litellm/utils.py b/litellm/utils.py index 003971142..262af3418 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6170,3 +6170,13 @@ class ProviderConfigManager: return litellm.GroqChatConfig() return OpenAIGPTConfig() + + +def get_end_user_id_for_cost_tracking(litellm_params: dict) -> Optional[str]: + """ + Used for enforcing `disable_end_user_cost_tracking` param. + """ + proxy_server_request = litellm_params.get("proxy_server_request") or {} + if litellm.disable_end_user_cost_tracking: + return None + return proxy_server_request.get("body", {}).get("user", None) diff --git a/tests/local_testing/test_embedding.py b/tests/local_testing/test_embedding.py index d7988e690..096dfc419 100644 --- a/tests/local_testing/test_embedding.py +++ b/tests/local_testing/test_embedding.py @@ -1080,3 +1080,34 @@ def test_cohere_img_embeddings(input, input_type): assert response.usage.prompt_tokens_details.image_tokens > 0 else: assert response.usage.prompt_tokens_details.text_tokens > 0 + + +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_embedding_with_extra_headers(sync_mode): + + input = ["hello world"] + from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler + + if sync_mode: + client = HTTPHandler() + else: + client = AsyncHTTPHandler() + + data = { + "model": "cohere/embed-english-v3.0", + "input": input, + "extra_headers": {"my-test-param": "hello-world"}, + "client": client, + } + with patch.object(client, "post") as mock_post: + try: + if sync_mode: + embedding(**data) + else: + await litellm.aembedding(**data) + except Exception as e: + print(e) + + mock_post.assert_called_once() + assert "my-test-param" in mock_post.call_args.kwargs["headers"] diff --git a/tests/local_testing/test_rerank.py b/tests/local_testing/test_rerank.py index c5ed1efe5..5fca6f135 100644 --- a/tests/local_testing/test_rerank.py +++ b/tests/local_testing/test_rerank.py @@ -215,7 +215,10 @@ async def test_rerank_custom_api_base(): args_to_api = kwargs["json"] print("Arguments passed to API=", args_to_api) print("url = ", _url) - assert _url[0] == "https://exampleopenaiendpoint-production.up.railway.app/" + assert ( + _url[0] + == "https://exampleopenaiendpoint-production.up.railway.app/v1/rerank" + ) assert args_to_api == expected_payload assert response.id is not None assert response.results is not None @@ -258,3 +261,32 @@ async def test_rerank_custom_callbacks(): assert custom_logger.kwargs.get("response_cost") > 0.0 assert custom_logger.response_obj is not None assert custom_logger.response_obj.results is not None + + +def test_complete_base_url_cohere(): + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + client = HTTPHandler() + litellm.api_base = "http://localhost:4000" + litellm.set_verbose = True + + text = "Hello there!" + list_texts = ["Hello there!", "How are you?", "How do you do?"] + + rerank_model = "rerank-multilingual-v3.0" + + with patch.object(client, "post") as mock_post: + try: + litellm.rerank( + model=rerank_model, + query=text, + documents=list_texts, + custom_llm_provider="cohere", + client=client, + ) + except Exception as e: + print(e) + + print("mock_post.call_args", mock_post.call_args) + mock_post.assert_called_once() + assert "http://localhost:4000/v1/rerank" in mock_post.call_args.kwargs["url"] diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py index 52946ca30..cf1db27e8 100644 --- a/tests/local_testing/test_utils.py +++ b/tests/local_testing/test_utils.py @@ -1012,3 +1012,23 @@ def test_models_by_provider(): for provider in providers: assert provider in models_by_provider.keys() + + +@pytest.mark.parametrize( + "litellm_params, disable_end_user_cost_tracking, expected_end_user_id", + [ + ({}, False, None), + ({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"), + ({"proxy_server_request": {"body": {"user": "123"}}}, True, None), + ], +) +def test_get_end_user_id_for_cost_tracking( + litellm_params, disable_end_user_cost_tracking, expected_end_user_id +): + from litellm.utils import get_end_user_id_for_cost_tracking + + litellm.disable_end_user_cost_tracking = disable_end_user_cost_tracking + assert ( + get_end_user_id_for_cost_tracking(litellm_params=litellm_params) + == expected_end_user_id + ) diff --git a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py index 38883fa38..15c2118d8 100644 --- a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py +++ b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py @@ -216,3 +216,78 @@ async def test_init_custom_logger_compatible_class_as_callback(): await use_callback_in_llm_call(callback, used_in="success_callback") reset_env_vars() + + +def test_dynamic_logging_global_callback(): + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + from litellm.integrations.custom_logger import CustomLogger + from litellm.types.utils import ModelResponse, Choices, Message, Usage + + cl = CustomLogger() + + litellm_logging = LiteLLMLoggingObj( + model="claude-3-opus-20240229", + messages=[{"role": "user", "content": "hi"}], + stream=False, + call_type="completion", + start_time=datetime.now(), + litellm_call_id="123", + function_id="456", + kwargs={ + "langfuse_public_key": "my-mock-public-key", + "langfuse_secret_key": "my-mock-secret-key", + }, + dynamic_success_callbacks=["langfuse"], + ) + + with patch.object(cl, "log_success_event") as mock_log_success_event: + cl.log_success_event = mock_log_success_event + litellm.success_callback = [cl] + + try: + litellm_logging.success_handler( + result=ModelResponse( + id="chatcmpl-5418737b-ab14-420b-b9c5-b278b6681b70", + created=1732306261, + model="claude-3-opus-20240229", + object="chat.completion", + system_fingerprint=None, + choices=[ + Choices( + finish_reason="stop", + index=0, + message=Message( + content="hello", + role="assistant", + tool_calls=None, + function_call=None, + ), + ) + ], + usage=Usage( + completion_tokens=20, + prompt_tokens=10, + total_tokens=30, + completion_tokens_details=None, + prompt_tokens_details=None, + ), + ), + start_time=datetime.now(), + end_time=datetime.now(), + cache_hit=False, + ) + except Exception as e: + print(f"Error: {e}") + + mock_log_success_event.assert_called_once() + + +def test_get_combined_callback_list(): + from litellm.litellm_core_utils.litellm_logging import get_combined_callback_list + + assert "langfuse" in get_combined_callback_list( + dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"] + ) + assert "lago" in get_combined_callback_list( + dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"] + ) diff --git a/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py b/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py index afb77f718..ecd289005 100644 --- a/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py +++ b/tests/pass_through_unit_tests/test_unit_test_anthropic_pass_through.py @@ -73,7 +73,7 @@ async def test_anthropic_passthrough_handler( start_time = datetime.now() end_time = datetime.now() - await AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler( + result = AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler( httpx_response=mock_httpx_response, response_body=mock_response, logging_obj=mock_logging_obj, @@ -84,30 +84,7 @@ async def test_anthropic_passthrough_handler( cache_hit=False, ) - # Assert that async_success_handler was called - assert mock_logging_obj.async_success_handler.called - - call_args = mock_logging_obj.async_success_handler.call_args - call_kwargs = call_args.kwargs - print("call_kwargs", call_kwargs) - - # Assert required fields are present in call_kwargs - assert "result" in call_kwargs - assert "start_time" in call_kwargs - assert "end_time" in call_kwargs - assert "cache_hit" in call_kwargs - assert "response_cost" in call_kwargs - assert "model" in call_kwargs - assert "standard_logging_object" in call_kwargs - - # Assert specific values and types - assert isinstance(call_kwargs["result"], litellm.ModelResponse) - assert isinstance(call_kwargs["start_time"], datetime) - assert isinstance(call_kwargs["end_time"], datetime) - assert isinstance(call_kwargs["cache_hit"], bool) - assert isinstance(call_kwargs["response_cost"], float) - assert call_kwargs["model"] == "claude-3-opus-20240229" - assert isinstance(call_kwargs["standard_logging_object"], dict) + assert isinstance(result["result"], litellm.ModelResponse) def test_create_anthropic_response_logging_payload(mock_logging_obj): diff --git a/tests/pass_through_unit_tests/test_unit_test_streaming.py b/tests/pass_through_unit_tests/test_unit_test_streaming.py index bbbc465fc..61b71b56d 100644 --- a/tests/pass_through_unit_tests/test_unit_test_streaming.py +++ b/tests/pass_through_unit_tests/test_unit_test_streaming.py @@ -64,6 +64,7 @@ async def test_chunk_processor_yields_raw_bytes(endpoint_type, url_route): litellm_logging_obj = MagicMock() start_time = datetime.now() passthrough_success_handler_obj = MagicMock() + litellm_logging_obj.async_success_handler = AsyncMock() # Capture yielded chunks and perform detailed assertions received_chunks = [] diff --git a/tests/proxy_admin_ui_tests/conftest.py b/tests/proxy_admin_ui_tests/conftest.py new file mode 100644 index 000000000..eca0bc431 --- /dev/null +++ b/tests/proxy_admin_ui_tests/conftest.py @@ -0,0 +1,54 @@ +# conftest.py + +import importlib +import os +import sys + +import pytest + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm + + +@pytest.fixture(scope="function", autouse=True) +def setup_and_teardown(): + """ + This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained. + """ + curr_dir = os.getcwd() # Get the current working directory + sys.path.insert( + 0, os.path.abspath("../..") + ) # Adds the project directory to the system path + + import litellm + from litellm import Router + + importlib.reload(litellm) + import asyncio + + loop = asyncio.get_event_loop_policy().new_event_loop() + asyncio.set_event_loop(loop) + print(litellm) + # from litellm import Router, completion, aembedding, acompletion, embedding + yield + + # Teardown code (executes after the yield point) + loop.close() # Close the loop created earlier + asyncio.set_event_loop(None) # Remove the reference to the loop + + +def pytest_collection_modifyitems(config, items): + # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests + custom_logger_tests = [ + item for item in items if "custom_logger" in item.parent.name + ] + other_tests = [item for item in items if "custom_logger" not in item.parent.name] + + # Sort tests based on their names + custom_logger_tests.sort(key=lambda x: x.name) + other_tests.sort(key=lambda x: x.name) + + # Reorder the items list + items[:] = custom_logger_tests + other_tests diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py index b039a101b..81d9fb676 100644 --- a/tests/proxy_admin_ui_tests/test_key_management.py +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -542,3 +542,65 @@ async def test_list_teams(prisma_client): # Clean up await prisma_client.delete_data(team_id_list=[team_id], table_name="team") + + +def test_is_team_key(): + from litellm.proxy.management_endpoints.key_management_endpoints import _is_team_key + + assert _is_team_key(GenerateKeyRequest(team_id="test_team_id")) + assert not _is_team_key(GenerateKeyRequest(user_id="test_user_id")) + + +def test_team_key_generation_check(): + from litellm.proxy.management_endpoints.key_management_endpoints import ( + _team_key_generation_check, + ) + from fastapi import HTTPException + + litellm.key_generation_settings = { + "team_key_generation": {"allowed_team_member_roles": ["admin"]} + } + + assert _team_key_generation_check( + UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + api_key="sk-1234", + team_member=Member(role="admin", user_id="test_user_id"), + ) + ) + + with pytest.raises(HTTPException): + _team_key_generation_check( + UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + api_key="sk-1234", + user_id="test_user_id", + team_member=Member(role="user", user_id="test_user_id"), + ) + ) + + +def test_personal_key_generation_check(): + from litellm.proxy.management_endpoints.key_management_endpoints import ( + _personal_key_generation_check, + ) + from fastapi import HTTPException + + litellm.key_generation_settings = { + "personal_key_generation": {"allowed_user_roles": ["proxy_admin"]} + } + + assert _personal_key_generation_check( + UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="admin" + ) + ) + + with pytest.raises(HTTPException): + _personal_key_generation_check( + UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + api_key="sk-1234", + user_id="admin", + ) + ) diff --git a/tests/proxy_admin_ui_tests/test_role_based_access.py b/tests/proxy_admin_ui_tests/test_role_based_access.py index 609a3598d..ff73143bf 100644 --- a/tests/proxy_admin_ui_tests/test_role_based_access.py +++ b/tests/proxy_admin_ui_tests/test_role_based_access.py @@ -160,7 +160,7 @@ async def test_create_new_user_in_organization(prisma_client, user_role): response = await organization_member_add( data=OrganizationMemberAddRequest( organization_id=org_id, - member=Member(role=user_role, user_id=created_user_id), + member=OrgMember(role=user_role, user_id=created_user_id), ), http_request=None, ) @@ -220,7 +220,7 @@ async def test_org_admin_create_team_permissions(prisma_client): response = await organization_member_add( data=OrganizationMemberAddRequest( organization_id=org_id, - member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), + member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), ), http_request=None, ) @@ -292,7 +292,7 @@ async def test_org_admin_create_user_permissions(prisma_client): response = await organization_member_add( data=OrganizationMemberAddRequest( organization_id=org_id, - member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), + member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), ), http_request=None, ) @@ -323,7 +323,7 @@ async def test_org_admin_create_user_permissions(prisma_client): response = await organization_member_add( data=OrganizationMemberAddRequest( organization_id=org_id, - member=Member( + member=OrgMember( role=LitellmUserRoles.INTERNAL_USER, user_id=new_internal_user_for_org ), ), @@ -375,7 +375,7 @@ async def test_org_admin_create_user_team_wrong_org_permissions(prisma_client): response = await organization_member_add( data=OrganizationMemberAddRequest( organization_id=org1_id, - member=Member(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), + member=OrgMember(role=LitellmUserRoles.ORG_ADMIN, user_id=created_user_id), ), http_request=None, )