diff --git a/docs/my-website/docs/proxy/self_serve.md b/docs/my-website/docs/proxy/self_serve.md index 3eb3d6df89..604ceee3e5 100644 --- a/docs/my-website/docs/proxy/self_serve.md +++ b/docs/my-website/docs/proxy/self_serve.md @@ -231,6 +231,14 @@ curl -X POST '/team/new' \ Here's a walkthrough of [how it works](https://www.loom.com/share/8959be458edf41fd85937452c29a33f3?sid=7ebd6d37-569a-4023-866e-e0cde67cb23e) +### Restrict Users from creating personal keys + +This is useful if you only want users to create keys under a specific team. + +This will also prevent users from using their session tokens on the test keys chat pane. + +👉 [**See this**](./virtual_keys.md#restricting-key-generation) + ## **All Settings for Self Serve / SSO Flow** ```yaml diff --git a/litellm/integrations/langfuse/langfuse.py b/litellm/integrations/langfuse/langfuse.py index c443496942..2a459af9b9 100644 --- a/litellm/integrations/langfuse/langfuse.py +++ b/litellm/integrations/langfuse/langfuse.py @@ -15,7 +15,10 @@ from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info from litellm.llms.custom_httpx.http_handler import _get_httpx_client from litellm.secret_managers.main import str_to_bool from litellm.types.integrations.langfuse import * -from litellm.types.utils import StandardLoggingPayload +from litellm.types.utils import ( + StandardLoggingPayload, + StandardLoggingPromptManagementMetadata, +) if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache @@ -463,14 +466,16 @@ class LangFuseLogger: if standard_logging_object is None: end_user_id = None - prompt_management_metadata: Optional[dict] = None + prompt_management_metadata: Optional[ + StandardLoggingPromptManagementMetadata + ] = None else: end_user_id = standard_logging_object["metadata"].get( "user_api_key_end_user_id", None ) prompt_management_metadata = cast( - Optional[dict], + Optional[StandardLoggingPromptManagementMetadata], standard_logging_object["metadata"].get( "prompt_management_metadata", None ), @@ -707,7 +712,10 @@ class LangFuseLogger: if supports_prompt: generation_params = _add_prompt_to_generation_params( - generation_params=generation_params, clean_metadata=clean_metadata + generation_params=generation_params, + clean_metadata=clean_metadata, + prompt_management_metadata=prompt_management_metadata, + langfuse_client=self.Langfuse, ) if output is not None and isinstance(output, str) and level == "ERROR": generation_params["status_message"] = output @@ -753,8 +761,12 @@ class LangFuseLogger: def _add_prompt_to_generation_params( - generation_params: dict, clean_metadata: dict + generation_params: dict, + clean_metadata: dict, + prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata], + langfuse_client: Any, ) -> dict: + from langfuse import Langfuse from langfuse.model import ( ChatPromptClient, Prompt_Chat, @@ -762,8 +774,10 @@ def _add_prompt_to_generation_params( TextPromptClient, ) + langfuse_client = cast(Langfuse, langfuse_client) + user_prompt = clean_metadata.pop("prompt", None) - if user_prompt is None: + if user_prompt is None and prompt_management_metadata is None: pass elif isinstance(user_prompt, dict): if user_prompt.get("type", "") == "chat": @@ -815,6 +829,20 @@ def _add_prompt_to_generation_params( verbose_logger.error( "[Non-blocking] Langfuse Logger: Invalid prompt format. 
No prompt logged to Langfuse" ) + elif ( + prompt_management_metadata is not None + and prompt_management_metadata["prompt_integration"] == "langfuse" + ): + try: + generation_params["prompt"] = langfuse_client.get_prompt( + prompt_management_metadata["prompt_id"] + ) + except Exception as e: + verbose_logger.debug( + f"[Non-blocking] Langfuse Logger: Error getting prompt client for logging: {e}" + ) + pass + else: generation_params["prompt"] = user_prompt diff --git a/litellm/integrations/s3.py b/litellm/integrations/s3.py index bcc59c416f..4a0c27354f 100644 --- a/litellm/integrations/s3.py +++ b/litellm/integrations/s3.py @@ -1,7 +1,8 @@ #### What this does #### # On success + failure, log events to Supabase -from typing import Optional +from datetime import datetime +from typing import Optional, cast import litellm from litellm._logging import print_verbose, verbose_logger @@ -32,6 +33,8 @@ class S3Logger: f"in init s3 logger - s3_callback_params {litellm.s3_callback_params}" ) + s3_use_team_prefix = False + if litellm.s3_callback_params is not None: # read in .env variables - example os.environ/AWS_BUCKET_NAME for key, value in litellm.s3_callback_params.items(): @@ -56,7 +59,10 @@ class S3Logger: s3_config = litellm.s3_callback_params.get("s3_config") s3_path = litellm.s3_callback_params.get("s3_path") # done reading litellm.s3_callback_params - + s3_use_team_prefix = bool( + litellm.s3_callback_params.get("s3_use_team_prefix", False) + ) + self.s3_use_team_prefix = s3_use_team_prefix self.bucket_name = s3_bucket_name self.s3_path = s3_path verbose_logger.debug(f"s3 logger using endpoint url {s3_endpoint_url}") @@ -114,21 +120,31 @@ class S3Logger: clean_metadata[key] = value # Ensure everything in the payload is converted to str - payload: Optional[StandardLoggingPayload] = kwargs.get( - "standard_logging_object", None + payload: Optional[StandardLoggingPayload] = cast( + Optional[StandardLoggingPayload], + kwargs.get("standard_logging_object", None), ) if payload is None: return + team_alias = payload["metadata"].get("user_api_key_team_alias") + + team_alias_prefix = "" + if ( + litellm.enable_preview_features + and self.s3_use_team_prefix + and team_alias is not None + ): + team_alias_prefix = f"{team_alias}/" + s3_file_name = litellm.utils.get_logging_id(start_time, payload) or "" - s3_object_key = ( - (self.s3_path.rstrip("/") + "/" if self.s3_path else "") - + start_time.strftime("%Y-%m-%d") - + "/" - + s3_file_name - ) # we need the s3 key to include the time, so we log cache hits too - s3_object_key += ".json" + s3_object_key = get_s3_object_key( + cast(Optional[str], self.s3_path) or "", + team_alias_prefix, + start_time, + s3_file_name, + ) s3_object_download_filename = ( "time-" @@ -161,3 +177,20 @@ class S3Logger: except Exception as e: verbose_logger.exception(f"s3 Layer Error - {str(e)}") pass + + +def get_s3_object_key( + s3_path: str, + team_alias_prefix: str, + start_time: datetime, + s3_file_name: str, +) -> str: + s3_object_key = ( + (s3_path.rstrip("/") + "/" if s3_path else "") + + team_alias_prefix + + start_time.strftime("%Y-%m-%d") + + "/" + + s3_file_name + ) # we need the s3 key to include the time, so we log cache hits too + s3_object_key += ".json" + return s3_object_key diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 34a270258d..cbec481ab0 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -416,6 +416,8 @@ class 
Logging(LiteLLMLoggingBaseClass): if custom_logger is None: continue + old_name = model + model, messages, non_default_params = ( custom_logger.get_chat_completion_prompt( model=model, @@ -426,6 +428,7 @@ class Logging(LiteLLMLoggingBaseClass): dynamic_callback_params=self.standard_callback_dynamic_params, ) ) + self.model_call_details["prompt_integration"] = old_name.split("/")[0] self.messages = messages return model, messages, non_default_params @@ -2790,7 +2793,9 @@ class StandardLoggingPayloadSetup: @staticmethod def get_standard_logging_metadata( - metadata: Optional[Dict[str, Any]], litellm_params: Optional[dict] = None + metadata: Optional[Dict[str, Any]], + litellm_params: Optional[dict] = None, + prompt_integration: Optional[str] = None, ) -> StandardLoggingMetadata: """ Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata. @@ -2814,10 +2819,12 @@ class StandardLoggingPayloadSetup: Optional[dict], litellm_params.get("prompt_variables", None) ) - prompt_management_metadata = StandardLoggingPromptManagementMetadata( - prompt_id=prompt_id, - prompt_variables=prompt_variables, - ) + if prompt_id is not None and prompt_integration is not None: + prompt_management_metadata = StandardLoggingPromptManagementMetadata( + prompt_id=prompt_id, + prompt_variables=prompt_variables, + prompt_integration=prompt_integration, + ) # Initialize with default values clean_metadata = StandardLoggingMetadata( @@ -3126,7 +3133,9 @@ def get_standard_logging_object_payload( ) # clean up litellm metadata clean_metadata = StandardLoggingPayloadSetup.get_standard_logging_metadata( - metadata=metadata, litellm_params=litellm_params + metadata=metadata, + litellm_params=litellm_params, + prompt_integration=kwargs.get("prompt_integration", None), ) saved_cache_cost: float = 0.0 diff --git a/litellm/llms/openai_like/embedding/handler.py b/litellm/llms/openai_like/embedding/handler.py index 6e2471baca..95a4aa854a 100644 --- a/litellm/llms/openai_like/embedding/handler.py +++ b/litellm/llms/openai_like/embedding/handler.py @@ -36,16 +36,15 @@ class OpenAILikeEmbeddingHandler(OpenAILikeBase): ) -> EmbeddingResponse: response = None try: - if client is None or isinstance(client, AsyncHTTPHandler): - self.async_client = get_async_httpx_client( + if client is None or not isinstance(client, AsyncHTTPHandler): + async_client = get_async_httpx_client( llm_provider=litellm.LlmProviders.OPENAI, params={"timeout": timeout}, ) else: - self.async_client = client - + async_client = client try: - response = await self.async_client.post( + response = await async_client.post( api_base, headers=headers, data=json.dumps(data), diff --git a/litellm/main.py b/litellm/main.py index 082244091c..5484cf43a3 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3055,52 +3055,17 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse: model=model, api_base=kwargs.get("api_base", None) ) + # Await normally + init_response = await loop.run_in_executor(None, func_with_context) + response: Optional[EmbeddingResponse] = None - if ( - custom_llm_provider == "openai" - or custom_llm_provider == "azure" - or custom_llm_provider == "xinference" - or custom_llm_provider == "voyage" - or custom_llm_provider == "mistral" - or custom_llm_provider == "custom_openai" - or custom_llm_provider == "triton" - or custom_llm_provider == "anyscale" - or custom_llm_provider == "openrouter" - or custom_llm_provider == "deepinfra" - or custom_llm_provider == "perplexity" - or custom_llm_provider == "groq" - 
        or custom_llm_provider == "nvidia_nim"
-        or custom_llm_provider == "cerebras"
-        or custom_llm_provider == "sambanova"
-        or custom_llm_provider == "ai21_chat"
-        or custom_llm_provider == "volcengine"
-        or custom_llm_provider == "deepseek"
-        or custom_llm_provider == "fireworks_ai"
-        or custom_llm_provider == "ollama"
-        or custom_llm_provider == "vertex_ai"
-        or custom_llm_provider == "gemini"
-        or custom_llm_provider == "databricks"
-        or custom_llm_provider == "watsonx"
-        or custom_llm_provider == "cohere"
-        or custom_llm_provider == "huggingface"
-        or custom_llm_provider == "bedrock"
-        or custom_llm_provider == "azure_ai"
-        or custom_llm_provider == "together_ai"
-        or custom_llm_provider == "openai_like"
-        or custom_llm_provider == "jina_ai"
-        or custom_llm_provider == "voyage"
-    ): # currently implemented aiohttp calls for just azure and openai, soon all.
-        # Await normally
-        init_response = await loop.run_in_executor(None, func_with_context)
-        if isinstance(init_response, dict):
-            response = EmbeddingResponse(**init_response)
-        elif isinstance(init_response, EmbeddingResponse): ## CACHING SCENARIO
-            response = init_response
-        elif asyncio.iscoroutine(init_response):
-            response = await init_response # type: ignore
-    else:
-        # Call the synchronous function using run_in_executor
-        response = await loop.run_in_executor(None, func_with_context)
+    if isinstance(init_response, dict):
+        response = EmbeddingResponse(**init_response)
+    elif isinstance(init_response, EmbeddingResponse): ## CACHING SCENARIO
+        response = init_response
+    elif asyncio.iscoroutine(init_response):
+        response = await init_response # type: ignore
+
     if (
         response is not None
         and isinstance(response, EmbeddingResponse)
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index aaffa1c851..2527441d9d 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -32,8 +32,13 @@ model_list:
       output_cost_per_token: 0.0000006

-# litellm_settings:
-#   key_generation_settings:
-#     personal_key_generation: # maps to 'Default Team' on UI
-#       allowed_user_roles: ["proxy_admin"]
+litellm_settings:
+  success_callback: ["s3"]
+  enable_preview_features: true
+  s3_callback_params:
+    s3_bucket_name: my-new-test-bucket-litellm # AWS Bucket Name for S3
+    s3_region_name: us-west-2 # AWS Region Name for S3
+    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # use os.environ/ to pass environment variables. This is the AWS Access Key ID for S3
+    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
+    s3_use_team_prefix: true
diff --git a/litellm/router.py b/litellm/router.py
index 9b0d34f92c..a7d9667f43 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -4163,31 +4163,6 @@ class Router:
             litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
         )

-        # set region (if azure model) ## PREVIEW FEATURE ##
-        if litellm.enable_preview_features is True:
-            print("Auto inferring region") # noqa
-            """
-            Hiding behind a feature flag
-            When there is a large amount of LLM deployments this makes startup times blow up
-            """
-            try:
-                if (
-                    "azure" in deployment.litellm_params.model
-                    and deployment.litellm_params.region_name is None
-                ):
-                    region = litellm.utils.get_model_region(
-                        litellm_params=deployment.litellm_params, mode=None
-                    )
-
-                    deployment.litellm_params.region_name = region
-            except Exception as e:
-                verbose_router_logger.debug(
-                    "Unable to get the region for azure model - {}, {}".format(
-                        deployment.litellm_params.model, str(e)
-                    )
-                )
-                pass # [NON-BLOCKING]
-
         return deployment

     def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 09567b88b4..e96cd91825 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1458,8 +1458,9 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):


 class StandardLoggingPromptManagementMetadata(TypedDict):
-    prompt_id: Optional[str]
+    prompt_id: str
     prompt_variables: Optional[dict]
+    prompt_integration: str


 class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
diff --git a/tests/local_testing/test_alangfuse.py b/tests/local_testing/test_alangfuse.py
index 3887805996..3784ee547e 100644
--- a/tests/local_testing/test_alangfuse.py
+++ b/tests/local_testing/test_alangfuse.py
@@ -1114,6 +1114,7 @@ generation_params = {

 def test_langfuse_prompt_type(prompt):
     from litellm.integrations.langfuse.langfuse import _add_prompt_to_generation_params
+    from unittest.mock import patch, MagicMock, Mock

     clean_metadata = {
         "prompt": {
@@ -1215,7 +1216,10 @@
         "cache_hit": False,
     }
     _add_prompt_to_generation_params(
-        generation_params=generation_params, clean_metadata=clean_metadata
+        generation_params=generation_params,
+        clean_metadata=clean_metadata,
+        prompt_management_metadata=None,
+        langfuse_client=Mock(),
     )


diff --git a/tests/local_testing/test_embedding.py b/tests/local_testing/test_embedding.py
index 4aedc00871..6bb1e95532 100644
--- a/tests/local_testing/test_embedding.py
+++ b/tests/local_testing/test_embedding.py
@@ -1019,18 +1019,27 @@ def test_hosted_vllm_embedding(monkeypatch):
         assert json_data["model"] == "jina-embeddings-v3"


-def test_lm_studio_embedding(monkeypatch):
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_lm_studio_embedding(monkeypatch, sync_mode):
     monkeypatch.setenv("LM_STUDIO_API_BASE", "http://localhost:8000")
-    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler

-    client = HTTPHandler()
+    client = HTTPHandler() if sync_mode else AsyncHTTPHandler()
     with patch.object(client, "post") as mock_post:
         try:
-            embedding(
-                model="lm_studio/jina-embeddings-v3",
-                input=["Hello world"],
-                client=client,
-            )
+            if sync_mode:
+                embedding(
+                    model="lm_studio/jina-embeddings-v3",
+                    input=["Hello world"],
+                    client=client,
+                )
+            else:
+                await litellm.aembedding(
+                    model="lm_studio/jina-embeddings-v3",
+                    input=["Hello world"],
+                    client=client,
+                )
         except Exception as e:
             print(e)
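
Note on the S3 object-key layout added above: when `s3_use_team_prefix: true` is set (and `enable_preview_features` is enabled), the team alias is inserted between the configured `s3_path` and the date segment. Below is a minimal pytest-style sketch, not part of the diff itself, that illustrates the layout produced by the new `get_s3_object_key` helper. It assumes the helper remains importable from `litellm.integrations.s3`; the test name and example values are hypothetical.

```python
from datetime import datetime

from litellm.integrations.s3 import get_s3_object_key


def test_get_s3_object_key_layout():
    start_time = datetime(2024, 6, 1)

    # With a team alias prefix: <s3_path>/<team_alias>/<YYYY-MM-DD>/<file>.json
    key = get_s3_object_key(
        s3_path="logs/",
        team_alias_prefix="my-team/",
        start_time=start_time,
        s3_file_name="request-id-123",
    )
    assert key == "logs/my-team/2024-06-01/request-id-123.json"

    # Without a prefix the previous layout is preserved: <YYYY-MM-DD>/<file>.json
    key = get_s3_object_key(
        s3_path="",
        team_alias_prefix="",
        start_time=start_time,
        s3_file_name="request-id-123",
    )
    assert key == "2024-06-01/request-id-123.json"
```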