Litellm dev 01 10 2025 p3 (#7682)

* feat(langfuse.py): log the used prompt when prompt management is used

* test: fix test

* docs(self_serve.md): add doc on restricting personal key creation on the UI

* feat(s3.py): support s3 logging with team alias prefixes (if available)

New preview feature

* fix(main.py): remove old if block - simplify to just await if coroutine returned

fixes lm_studio async embedding error

* fix(langfuse.py): handle get prompt check
Krish Dholakia 2025-01-10 21:56:42 -08:00 committed by GitHub
parent e54d23c919
commit 953c021aa7
11 changed files with 148 additions and 112 deletions

View file

@@ -231,6 +231,14 @@ curl -X POST '<PROXY_BASE_URL>/team/new' \

 Here's a walkthrough of [how it works](https://www.loom.com/share/8959be458edf41fd85937452c29a33f3?sid=7ebd6d37-569a-4023-866e-e0cde67cb23e)

+### Restrict Users from creating personal keys
+
+This is useful if you only want users to create keys under a specific team.
+
+This will also prevent users from using their session tokens on the test keys chat pane.
+
+👉 [**See this**](./virtual_keys.md#restricting-key-generation)
+
 ## **All Settings for Self Serve / SSO Flow**

 ```yaml

View file

@@ -15,7 +15,10 @@ from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info
 from litellm.llms.custom_httpx.http_handler import _get_httpx_client
 from litellm.secret_managers.main import str_to_bool
 from litellm.types.integrations.langfuse import *
-from litellm.types.utils import StandardLoggingPayload
+from litellm.types.utils import (
+    StandardLoggingPayload,
+    StandardLoggingPromptManagementMetadata,
+)

 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache
@@ -463,14 +466,16 @@ class LangFuseLogger:
         if standard_logging_object is None:
             end_user_id = None
-            prompt_management_metadata: Optional[dict] = None
+            prompt_management_metadata: Optional[
+                StandardLoggingPromptManagementMetadata
+            ] = None
         else:
             end_user_id = standard_logging_object["metadata"].get(
                 "user_api_key_end_user_id", None
             )
             prompt_management_metadata = cast(
-                Optional[dict],
+                Optional[StandardLoggingPromptManagementMetadata],
                 standard_logging_object["metadata"].get(
                     "prompt_management_metadata", None
                 ),
@@ -707,7 +712,10 @@ class LangFuseLogger:
         if supports_prompt:
             generation_params = _add_prompt_to_generation_params(
-                generation_params=generation_params, clean_metadata=clean_metadata
+                generation_params=generation_params,
+                clean_metadata=clean_metadata,
+                prompt_management_metadata=prompt_management_metadata,
+                langfuse_client=self.Langfuse,
             )
         if output is not None and isinstance(output, str) and level == "ERROR":
             generation_params["status_message"] = output
@@ -753,8 +761,12 @@ class LangFuseLogger:


 def _add_prompt_to_generation_params(
-    generation_params: dict, clean_metadata: dict
+    generation_params: dict,
+    clean_metadata: dict,
+    prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata],
+    langfuse_client: Any,
 ) -> dict:
+    from langfuse import Langfuse
     from langfuse.model import (
         ChatPromptClient,
         Prompt_Chat,
@@ -762,8 +774,10 @@ def _add_prompt_to_generation_params(
         TextPromptClient,
     )

+    langfuse_client = cast(Langfuse, langfuse_client)
+
     user_prompt = clean_metadata.pop("prompt", None)
-    if user_prompt is None:
+    if user_prompt is None and prompt_management_metadata is None:
         pass
     elif isinstance(user_prompt, dict):
         if user_prompt.get("type", "") == "chat":
@@ -815,6 +829,20 @@ def _add_prompt_to_generation_params(
             verbose_logger.error(
                 "[Non-blocking] Langfuse Logger: Invalid prompt format. No prompt logged to Langfuse"
             )
+    elif (
+        prompt_management_metadata is not None
+        and prompt_management_metadata["prompt_integration"] == "langfuse"
+    ):
+        try:
+            generation_params["prompt"] = langfuse_client.get_prompt(
+                prompt_management_metadata["prompt_id"]
+            )
+        except Exception as e:
+            verbose_logger.debug(
+                f"[Non-blocking] Langfuse Logger: Error getting prompt client for logging: {e}"
+            )
+            pass
     else:
         generation_params["prompt"] = user_prompt

View file

@@ -1,7 +1,8 @@
 #### What this does ####
 #    On success + failure, log events to Supabase
-from typing import Optional
+from datetime import datetime
+from typing import Optional, cast

 import litellm
 from litellm._logging import print_verbose, verbose_logger
@@ -32,6 +33,8 @@ class S3Logger:
                 f"in init s3 logger - s3_callback_params {litellm.s3_callback_params}"
             )

+            s3_use_team_prefix = False
+
             if litellm.s3_callback_params is not None:
                 # read in .env variables - example os.environ/AWS_BUCKET_NAME
                 for key, value in litellm.s3_callback_params.items():
@@ -56,7 +59,10 @@ class S3Logger:
                 s3_config = litellm.s3_callback_params.get("s3_config")
                 s3_path = litellm.s3_callback_params.get("s3_path")
                 # done reading litellm.s3_callback_params
+                s3_use_team_prefix = bool(
+                    litellm.s3_callback_params.get("s3_use_team_prefix", False)
+                )
+            self.s3_use_team_prefix = s3_use_team_prefix
             self.bucket_name = s3_bucket_name
             self.s3_path = s3_path
             verbose_logger.debug(f"s3 logger using endpoint url {s3_endpoint_url}")
@@ -114,21 +120,31 @@ class S3Logger:
                     clean_metadata[key] = value

             # Ensure everything in the payload is converted to str
-            payload: Optional[StandardLoggingPayload] = kwargs.get(
-                "standard_logging_object", None
+            payload: Optional[StandardLoggingPayload] = cast(
+                Optional[StandardLoggingPayload],
+                kwargs.get("standard_logging_object", None),
             )

             if payload is None:
                 return

+            team_alias = payload["metadata"].get("user_api_key_team_alias")
+            team_alias_prefix = ""
+            if (
+                litellm.enable_preview_features
+                and self.s3_use_team_prefix
+                and team_alias is not None
+            ):
+                team_alias_prefix = f"{team_alias}/"
+
             s3_file_name = litellm.utils.get_logging_id(start_time, payload) or ""
-            s3_object_key = (
-                (self.s3_path.rstrip("/") + "/" if self.s3_path else "")
-                + start_time.strftime("%Y-%m-%d")
-                + "/"
-                + s3_file_name
-            )  # we need the s3 key to include the time, so we log cache hits too
-            s3_object_key += ".json"
+            s3_object_key = get_s3_object_key(
+                cast(Optional[str], self.s3_path) or "",
+                team_alias_prefix,
+                start_time,
+                s3_file_name,
+            )

             s3_object_download_filename = (
                 "time-"
@@ -161,3 +177,20 @@ class S3Logger:
         except Exception as e:
             verbose_logger.exception(f"s3 Layer Error - {str(e)}")
             pass
+
+
+def get_s3_object_key(
+    s3_path: str,
+    team_alias_prefix: str,
+    start_time: datetime,
+    s3_file_name: str,
+) -> str:
+    s3_object_key = (
+        (s3_path.rstrip("/") + "/" if s3_path else "")
+        + team_alias_prefix
+        + start_time.strftime("%Y-%m-%d")
+        + "/"
+        + s3_file_name
+    )  # we need the s3 key to include the time, so we log cache hits too
+    s3_object_key += ".json"
+    return s3_object_key
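A quick sketch of what the new helper produces (values illustrative; the import path assumes the logger lives at `litellm.integrations.s3`, as the file changed here does):

```python
from datetime import datetime

from litellm.integrations.s3 import get_s3_object_key

key = get_s3_object_key(
    s3_path="litellm-logs",         # optional prefix from s3_path, may be ""
    team_alias_prefix="my-team/",   # set when enable_preview_features + s3_use_team_prefix
    start_time=datetime(2025, 1, 10),
    s3_file_name="chatcmpl-abc123", # hypothetical logging id
)
print(key)  # litellm-logs/my-team/2025-01-10/chatcmpl-abc123.json
```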

View file

@@ -416,6 +416,8 @@ class Logging(LiteLLMLoggingBaseClass):
                 if custom_logger is None:
                     continue

+                old_name = model
+
                 model, messages, non_default_params = (
                     custom_logger.get_chat_completion_prompt(
                         model=model,
@@ -426,6 +428,7 @@ class Logging(LiteLLMLoggingBaseClass):
                         dynamic_callback_params=self.standard_callback_dynamic_params,
                     )
                 )
+                self.model_call_details["prompt_integration"] = old_name.split("/")[0]
                 self.messages = messages

         return model, messages, non_default_params
@@ -2790,7 +2793,9 @@ class StandardLoggingPayloadSetup:
     @staticmethod
     def get_standard_logging_metadata(
-        metadata: Optional[Dict[str, Any]], litellm_params: Optional[dict] = None
+        metadata: Optional[Dict[str, Any]],
+        litellm_params: Optional[dict] = None,
+        prompt_integration: Optional[str] = None,
     ) -> StandardLoggingMetadata:
         """
         Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
@@ -2814,9 +2819,11 @@ class StandardLoggingPayloadSetup:
                 Optional[dict], litellm_params.get("prompt_variables", None)
             )

+            if prompt_id is not None and prompt_integration is not None:
                 prompt_management_metadata = StandardLoggingPromptManagementMetadata(
                     prompt_id=prompt_id,
                     prompt_variables=prompt_variables,
+                    prompt_integration=prompt_integration,
                 )

         # Initialize with default values
# Initialize with default values # Initialize with default values
@@ -3126,7 +3133,9 @@ def get_standard_logging_object_payload(
         )
         # clean up litellm metadata
         clean_metadata = StandardLoggingPayloadSetup.get_standard_logging_metadata(
-            metadata=metadata, litellm_params=litellm_params
+            metadata=metadata,
+            litellm_params=litellm_params,
+            prompt_integration=kwargs.get("prompt_integration", None),
         )

         saved_cache_cost: float = 0.0
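A sketch (illustrative values, not from the commit) tying the hunks above together: the provider prefix captured from the original model string becomes `prompt_integration`, which `get_standard_logging_metadata` folds into the prompt-management metadata:

```python
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup

old_name = "langfuse/gpt-4o"                 # hypothetical model before the callback rewrites it
prompt_integration = old_name.split("/")[0]  # -> "langfuse"

clean_metadata = StandardLoggingPayloadSetup.get_standard_logging_metadata(
    metadata={},
    litellm_params={
        "prompt_id": "movie-critic",         # hypothetical prompt name
        "prompt_variables": {"genre": "sci-fi"},
    },
    prompt_integration=prompt_integration,
)
# clean_metadata["prompt_management_metadata"] should now carry
# prompt_id, prompt_variables and prompt_integration="langfuse"
```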

View file

@@ -36,16 +36,15 @@ class OpenAILikeEmbeddingHandler(OpenAILikeBase):
     ) -> EmbeddingResponse:
         response = None
         try:
-            if client is None or isinstance(client, AsyncHTTPHandler):
-                self.async_client = get_async_httpx_client(
+            if client is None or not isinstance(client, AsyncHTTPHandler):
+                async_client = get_async_httpx_client(
                     llm_provider=litellm.LlmProviders.OPENAI,
                     params={"timeout": timeout},
                 )
             else:
-                self.async_client = client
+                async_client = client

             try:
-                response = await self.async_client.post(
+                response = await async_client.post(
                     api_base,
                     headers=headers,
                     data=json.dumps(data),

View file

@@ -3055,52 +3055,17 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
             model=model, api_base=kwargs.get("api_base", None)
         )

-        response: Optional[EmbeddingResponse] = None
-        if (
-            custom_llm_provider == "openai"
-            or custom_llm_provider == "azure"
-            or custom_llm_provider == "xinference"
-            or custom_llm_provider == "voyage"
-            or custom_llm_provider == "mistral"
-            or custom_llm_provider == "custom_openai"
-            or custom_llm_provider == "triton"
-            or custom_llm_provider == "anyscale"
-            or custom_llm_provider == "openrouter"
-            or custom_llm_provider == "deepinfra"
-            or custom_llm_provider == "perplexity"
-            or custom_llm_provider == "groq"
-            or custom_llm_provider == "nvidia_nim"
-            or custom_llm_provider == "cerebras"
-            or custom_llm_provider == "sambanova"
-            or custom_llm_provider == "ai21_chat"
-            or custom_llm_provider == "volcengine"
-            or custom_llm_provider == "deepseek"
-            or custom_llm_provider == "fireworks_ai"
-            or custom_llm_provider == "ollama"
-            or custom_llm_provider == "vertex_ai"
-            or custom_llm_provider == "gemini"
-            or custom_llm_provider == "databricks"
-            or custom_llm_provider == "watsonx"
-            or custom_llm_provider == "cohere"
-            or custom_llm_provider == "huggingface"
-            or custom_llm_provider == "bedrock"
-            or custom_llm_provider == "azure_ai"
-            or custom_llm_provider == "together_ai"
-            or custom_llm_provider == "openai_like"
-            or custom_llm_provider == "jina_ai"
-            or custom_llm_provider == "voyage"
-        ):  # currently implemented aiohttp calls for just azure and openai, soon all.
-            # Await normally
-            init_response = await loop.run_in_executor(None, func_with_context)
-            if isinstance(init_response, dict):
-                response = EmbeddingResponse(**init_response)
-            elif isinstance(init_response, EmbeddingResponse):  ## CACHING SCENARIO
-                response = init_response
-            elif asyncio.iscoroutine(init_response):
-                response = await init_response  # type: ignore
-        else:
-            # Call the synchronous function using run_in_executor
-            response = await loop.run_in_executor(None, func_with_context)
+        # Await normally
+        init_response = await loop.run_in_executor(None, func_with_context)
+
+        response: Optional[EmbeddingResponse] = None
+        if isinstance(init_response, dict):
+            response = EmbeddingResponse(**init_response)
+        elif isinstance(init_response, EmbeddingResponse):  ## CACHING SCENARIO
+            response = init_response
+        elif asyncio.iscoroutine(init_response):
+            response = await init_response  # type: ignore

         if (
             response is not None
             and isinstance(response, EmbeddingResponse)
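The removed provider allow-list is what the commit message's "simplify to just await if coroutine returned" refers to. A minimal, self-contained sketch (plain asyncio, not the litellm internals) of why the remaining branch is enough:

```python
import asyncio


async def main() -> None:
    loop = asyncio.get_running_loop()

    async def _async_embedding() -> dict:
        return {"data": [[0.1, 0.2]]}

    def func_with_context() -> object:
        # like an async provider path: calling it just builds a coroutine
        return _async_embedding()

    # run_in_executor hands back whatever the callable returned, which may
    # itself be a coroutine that still needs awaiting
    init_response = await loop.run_in_executor(None, func_with_context)
    if asyncio.iscoroutine(init_response):  # e.g. the lm_studio / openai_like case
        init_response = await init_response
    print(init_response)  # {'data': [[0.1, 0.2]]}


asyncio.run(main())
```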

View file

@@ -32,8 +32,13 @@ model_list:
       output_cost_per_token: 0.0000006

-# litellm_settings:
-#   key_generation_settings:
-#     personal_key_generation: # maps to 'Default Team' on UI
-#       allowed_user_roles: ["proxy_admin"]
+litellm_settings:
+  success_callback: ["s3"]
+  enable_preview_features: true
+  s3_callback_params:
+    s3_bucket_name: my-new-test-bucket-litellm # AWS Bucket Name for S3
+    s3_region_name: us-west-2 # AWS Region Name for S3
+    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # use os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
+    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
+    s3_use_team_prefix: true

View file

@@ -4163,31 +4163,6 @@ class Router:
             litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
         )

-        # set region (if azure model) ## PREVIEW FEATURE ##
-        if litellm.enable_preview_features is True:
-            print("Auto inferring region")  # noqa
-            """
-            Hiding behind a feature flag
-            When there is a large amount of LLM deployments this makes startup times blow up
-            """
-            try:
-                if (
-                    "azure" in deployment.litellm_params.model
-                    and deployment.litellm_params.region_name is None
-                ):
-                    region = litellm.utils.get_model_region(
-                        litellm_params=deployment.litellm_params, mode=None
-                    )
-                    deployment.litellm_params.region_name = region
-            except Exception as e:
-                verbose_router_logger.debug(
-                    "Unable to get the region for azure model - {}, {}".format(
-                        deployment.litellm_params.model, str(e)
-                    )
-                )
-                pass  # [NON-BLOCKING]
-
         return deployment

     def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:

View file

@@ -1458,8 +1458,9 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):

 class StandardLoggingPromptManagementMetadata(TypedDict):
-    prompt_id: Optional[str]
+    prompt_id: str
     prompt_variables: Optional[dict]
+    prompt_integration: str


 class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):

View file

@@ -1114,6 +1114,7 @@ generation_params = {
 def test_langfuse_prompt_type(prompt):
     from litellm.integrations.langfuse.langfuse import _add_prompt_to_generation_params
+    from unittest.mock import patch, MagicMock, Mock

     clean_metadata = {
         "prompt": {
@@ -1215,7 +1216,10 @@ def test_langfuse_prompt_type(prompt):
         "cache_hit": False,
     }
     _add_prompt_to_generation_params(
-        generation_params=generation_params, clean_metadata=clean_metadata
+        generation_params=generation_params,
+        clean_metadata=clean_metadata,
+        prompt_management_metadata=None,
+        langfuse_client=Mock(),
     )

View file

@@ -1019,18 +1019,27 @@ def test_hosted_vllm_embedding(monkeypatch):
     assert json_data["model"] == "jina-embeddings-v3"


-def test_lm_studio_embedding(monkeypatch):
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_lm_studio_embedding(monkeypatch, sync_mode):
     monkeypatch.setenv("LM_STUDIO_API_BASE", "http://localhost:8000")
-    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler

-    client = HTTPHandler()
+    client = HTTPHandler() if sync_mode else AsyncHTTPHandler()

     with patch.object(client, "post") as mock_post:
         try:
+            if sync_mode:
                 embedding(
                     model="lm_studio/jina-embeddings-v3",
                     input=["Hello world"],
                     client=client,
                 )
+            else:
+                await litellm.aembedding(
+                    model="lm_studio/jina-embeddings-v3",
+                    input=["Hello world"],
+                    client=client,
+                )
         except Exception as e:
             print(e)