mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Litellm dev 01 10 2025 p3 (#7682)
* feat(langfuse.py): log the used prompt when prompt management used * test: fix test * docs(self_serve.md): add doc on restricting personal key creation on ui * feat(s3.py): support s3 logging with team alias prefixes (if available) New preview feature * fix(main.py): remove old if block - simplify to just await if coroutine returned fixes lm_studio async embedding error * fix(langfuse.py): handle get prompt check
This commit is contained in:
parent
e54d23c919
commit
953c021aa7
11 changed files with 148 additions and 112 deletions
|
@ -231,6 +231,14 @@ curl -X POST '<PROXY_BASE_URL>/team/new' \
|
|||
|
||||
Here's a walkthrough of [how it works](https://www.loom.com/share/8959be458edf41fd85937452c29a33f3?sid=7ebd6d37-569a-4023-866e-e0cde67cb23e)
|
||||
|
||||
### Restrict Users from creating personal keys
|
||||
|
||||
This is useful if you only want users to create keys under a specific team.
|
||||
|
||||
This will also prevent users from using their session tokens on the test keys chat pane.
|
||||
|
||||
👉 [**See this**](./virtual_keys.md#restricting-key-generation)
|
||||
|
||||
## **All Settings for Self Serve / SSO Flow**
|
||||
|
||||
```yaml
|
||||
|
|
|
@ -15,7 +15,10 @@ from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info
|
|||
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
from litellm.types.integrations.langfuse import *
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
from litellm.types.utils import (
|
||||
StandardLoggingPayload,
|
||||
StandardLoggingPromptManagementMetadata,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache
|
||||
|
@ -463,14 +466,16 @@ class LangFuseLogger:
|
|||
|
||||
if standard_logging_object is None:
|
||||
end_user_id = None
|
||||
prompt_management_metadata: Optional[dict] = None
|
||||
prompt_management_metadata: Optional[
|
||||
StandardLoggingPromptManagementMetadata
|
||||
] = None
|
||||
else:
|
||||
end_user_id = standard_logging_object["metadata"].get(
|
||||
"user_api_key_end_user_id", None
|
||||
)
|
||||
|
||||
prompt_management_metadata = cast(
|
||||
Optional[dict],
|
||||
Optional[StandardLoggingPromptManagementMetadata],
|
||||
standard_logging_object["metadata"].get(
|
||||
"prompt_management_metadata", None
|
||||
),
|
||||
|
@ -707,7 +712,10 @@ class LangFuseLogger:
|
|||
|
||||
if supports_prompt:
|
||||
generation_params = _add_prompt_to_generation_params(
|
||||
generation_params=generation_params, clean_metadata=clean_metadata
|
||||
generation_params=generation_params,
|
||||
clean_metadata=clean_metadata,
|
||||
prompt_management_metadata=prompt_management_metadata,
|
||||
langfuse_client=self.Langfuse,
|
||||
)
|
||||
if output is not None and isinstance(output, str) and level == "ERROR":
|
||||
generation_params["status_message"] = output
|
||||
|
@ -753,8 +761,12 @@ class LangFuseLogger:
|
|||
|
||||
|
||||
def _add_prompt_to_generation_params(
|
||||
generation_params: dict, clean_metadata: dict
|
||||
generation_params: dict,
|
||||
clean_metadata: dict,
|
||||
prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata],
|
||||
langfuse_client: Any,
|
||||
) -> dict:
|
||||
from langfuse import Langfuse
|
||||
from langfuse.model import (
|
||||
ChatPromptClient,
|
||||
Prompt_Chat,
|
||||
|
@ -762,8 +774,10 @@ def _add_prompt_to_generation_params(
|
|||
TextPromptClient,
|
||||
)
|
||||
|
||||
langfuse_client = cast(Langfuse, langfuse_client)
|
||||
|
||||
user_prompt = clean_metadata.pop("prompt", None)
|
||||
if user_prompt is None:
|
||||
if user_prompt is None and prompt_management_metadata is None:
|
||||
pass
|
||||
elif isinstance(user_prompt, dict):
|
||||
if user_prompt.get("type", "") == "chat":
|
||||
|
@ -815,6 +829,20 @@ def _add_prompt_to_generation_params(
|
|||
verbose_logger.error(
|
||||
"[Non-blocking] Langfuse Logger: Invalid prompt format. No prompt logged to Langfuse"
|
||||
)
|
||||
elif (
|
||||
prompt_management_metadata is not None
|
||||
and prompt_management_metadata["prompt_integration"] == "langfuse"
|
||||
):
|
||||
try:
|
||||
generation_params["prompt"] = langfuse_client.get_prompt(
|
||||
prompt_management_metadata["prompt_id"]
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.debug(
|
||||
f"[Non-blocking] Langfuse Logger: Error getting prompt client for logging: {e}"
|
||||
)
|
||||
pass
|
||||
|
||||
else:
|
||||
generation_params["prompt"] = user_prompt
|
||||
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
#### What this does ####
|
||||
# On success + failure, log events to Supabase
|
||||
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from typing import Optional, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
|
@ -32,6 +33,8 @@ class S3Logger:
|
|||
f"in init s3 logger - s3_callback_params {litellm.s3_callback_params}"
|
||||
)
|
||||
|
||||
s3_use_team_prefix = False
|
||||
|
||||
if litellm.s3_callback_params is not None:
|
||||
# read in .env variables - example os.environ/AWS_BUCKET_NAME
|
||||
for key, value in litellm.s3_callback_params.items():
|
||||
|
@ -56,7 +59,10 @@ class S3Logger:
|
|||
s3_config = litellm.s3_callback_params.get("s3_config")
|
||||
s3_path = litellm.s3_callback_params.get("s3_path")
|
||||
# done reading litellm.s3_callback_params
|
||||
|
||||
s3_use_team_prefix = bool(
|
||||
litellm.s3_callback_params.get("s3_use_team_prefix", False)
|
||||
)
|
||||
self.s3_use_team_prefix = s3_use_team_prefix
|
||||
self.bucket_name = s3_bucket_name
|
||||
self.s3_path = s3_path
|
||||
verbose_logger.debug(f"s3 logger using endpoint url {s3_endpoint_url}")
|
||||
|
@ -114,21 +120,31 @@ class S3Logger:
|
|||
clean_metadata[key] = value
|
||||
|
||||
# Ensure everything in the payload is converted to str
|
||||
payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
payload: Optional[StandardLoggingPayload] = cast(
|
||||
Optional[StandardLoggingPayload],
|
||||
kwargs.get("standard_logging_object", None),
|
||||
)
|
||||
|
||||
if payload is None:
|
||||
return
|
||||
|
||||
team_alias = payload["metadata"].get("user_api_key_team_alias")
|
||||
|
||||
team_alias_prefix = ""
|
||||
if (
|
||||
litellm.enable_preview_features
|
||||
and self.s3_use_team_prefix
|
||||
and team_alias is not None
|
||||
):
|
||||
team_alias_prefix = f"{team_alias}/"
|
||||
|
||||
s3_file_name = litellm.utils.get_logging_id(start_time, payload) or ""
|
||||
s3_object_key = (
|
||||
(self.s3_path.rstrip("/") + "/" if self.s3_path else "")
|
||||
+ start_time.strftime("%Y-%m-%d")
|
||||
+ "/"
|
||||
+ s3_file_name
|
||||
) # we need the s3 key to include the time, so we log cache hits too
|
||||
s3_object_key += ".json"
|
||||
s3_object_key = get_s3_object_key(
|
||||
cast(Optional[str], self.s3_path) or "",
|
||||
team_alias_prefix,
|
||||
start_time,
|
||||
s3_file_name,
|
||||
)
|
||||
|
||||
s3_object_download_filename = (
|
||||
"time-"
|
||||
|
@ -161,3 +177,20 @@ class S3Logger:
|
|||
except Exception as e:
|
||||
verbose_logger.exception(f"s3 Layer Error - {str(e)}")
|
||||
pass
|
||||
|
||||
|
||||
def get_s3_object_key(
|
||||
s3_path: str,
|
||||
team_alias_prefix: str,
|
||||
start_time: datetime,
|
||||
s3_file_name: str,
|
||||
) -> str:
|
||||
s3_object_key = (
|
||||
(s3_path.rstrip("/") + "/" if s3_path else "")
|
||||
+ team_alias_prefix
|
||||
+ start_time.strftime("%Y-%m-%d")
|
||||
+ "/"
|
||||
+ s3_file_name
|
||||
) # we need the s3 key to include the time, so we log cache hits too
|
||||
s3_object_key += ".json"
|
||||
return s3_object_key
|
||||
|
|
|
@ -416,6 +416,8 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
|
||||
if custom_logger is None:
|
||||
continue
|
||||
old_name = model
|
||||
|
||||
model, messages, non_default_params = (
|
||||
custom_logger.get_chat_completion_prompt(
|
||||
model=model,
|
||||
|
@ -426,6 +428,7 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
dynamic_callback_params=self.standard_callback_dynamic_params,
|
||||
)
|
||||
)
|
||||
self.model_call_details["prompt_integration"] = old_name.split("/")[0]
|
||||
self.messages = messages
|
||||
|
||||
return model, messages, non_default_params
|
||||
|
@ -2790,7 +2793,9 @@ class StandardLoggingPayloadSetup:
|
|||
|
||||
@staticmethod
|
||||
def get_standard_logging_metadata(
|
||||
metadata: Optional[Dict[str, Any]], litellm_params: Optional[dict] = None
|
||||
metadata: Optional[Dict[str, Any]],
|
||||
litellm_params: Optional[dict] = None,
|
||||
prompt_integration: Optional[str] = None,
|
||||
) -> StandardLoggingMetadata:
|
||||
"""
|
||||
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
|
||||
|
@ -2814,9 +2819,11 @@ class StandardLoggingPayloadSetup:
|
|||
Optional[dict], litellm_params.get("prompt_variables", None)
|
||||
)
|
||||
|
||||
if prompt_id is not None and prompt_integration is not None:
|
||||
prompt_management_metadata = StandardLoggingPromptManagementMetadata(
|
||||
prompt_id=prompt_id,
|
||||
prompt_variables=prompt_variables,
|
||||
prompt_integration=prompt_integration,
|
||||
)
|
||||
|
||||
# Initialize with default values
|
||||
|
@ -3126,7 +3133,9 @@ def get_standard_logging_object_payload(
|
|||
)
|
||||
# clean up litellm metadata
|
||||
clean_metadata = StandardLoggingPayloadSetup.get_standard_logging_metadata(
|
||||
metadata=metadata, litellm_params=litellm_params
|
||||
metadata=metadata,
|
||||
litellm_params=litellm_params,
|
||||
prompt_integration=kwargs.get("prompt_integration", None),
|
||||
)
|
||||
|
||||
saved_cache_cost: float = 0.0
|
||||
|
|
|
@ -36,16 +36,15 @@ class OpenAILikeEmbeddingHandler(OpenAILikeBase):
|
|||
) -> EmbeddingResponse:
|
||||
response = None
|
||||
try:
|
||||
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||
self.async_client = get_async_httpx_client(
|
||||
if client is None or not isinstance(client, AsyncHTTPHandler):
|
||||
async_client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.OPENAI,
|
||||
params={"timeout": timeout},
|
||||
)
|
||||
else:
|
||||
self.async_client = client
|
||||
|
||||
async_client = client
|
||||
try:
|
||||
response = await self.async_client.post(
|
||||
response = await async_client.post(
|
||||
api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
|
|
|
@ -3055,52 +3055,17 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
|
|||
model=model, api_base=kwargs.get("api_base", None)
|
||||
)
|
||||
|
||||
response: Optional[EmbeddingResponse] = None
|
||||
if (
|
||||
custom_llm_provider == "openai"
|
||||
or custom_llm_provider == "azure"
|
||||
or custom_llm_provider == "xinference"
|
||||
or custom_llm_provider == "voyage"
|
||||
or custom_llm_provider == "mistral"
|
||||
or custom_llm_provider == "custom_openai"
|
||||
or custom_llm_provider == "triton"
|
||||
or custom_llm_provider == "anyscale"
|
||||
or custom_llm_provider == "openrouter"
|
||||
or custom_llm_provider == "deepinfra"
|
||||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "groq"
|
||||
or custom_llm_provider == "nvidia_nim"
|
||||
or custom_llm_provider == "cerebras"
|
||||
or custom_llm_provider == "sambanova"
|
||||
or custom_llm_provider == "ai21_chat"
|
||||
or custom_llm_provider == "volcengine"
|
||||
or custom_llm_provider == "deepseek"
|
||||
or custom_llm_provider == "fireworks_ai"
|
||||
or custom_llm_provider == "ollama"
|
||||
or custom_llm_provider == "vertex_ai"
|
||||
or custom_llm_provider == "gemini"
|
||||
or custom_llm_provider == "databricks"
|
||||
or custom_llm_provider == "watsonx"
|
||||
or custom_llm_provider == "cohere"
|
||||
or custom_llm_provider == "huggingface"
|
||||
or custom_llm_provider == "bedrock"
|
||||
or custom_llm_provider == "azure_ai"
|
||||
or custom_llm_provider == "together_ai"
|
||||
or custom_llm_provider == "openai_like"
|
||||
or custom_llm_provider == "jina_ai"
|
||||
or custom_llm_provider == "voyage"
|
||||
): # currently implemented aiohttp calls for just azure and openai, soon all.
|
||||
# Await normally
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
||||
response: Optional[EmbeddingResponse] = None
|
||||
if isinstance(init_response, dict):
|
||||
response = EmbeddingResponse(**init_response)
|
||||
elif isinstance(init_response, EmbeddingResponse): ## CACHING SCENARIO
|
||||
response = init_response
|
||||
elif asyncio.iscoroutine(init_response):
|
||||
response = await init_response # type: ignore
|
||||
else:
|
||||
# Call the synchronous function using run_in_executor
|
||||
response = await loop.run_in_executor(None, func_with_context)
|
||||
|
||||
if (
|
||||
response is not None
|
||||
and isinstance(response, EmbeddingResponse)
|
||||
|
|
|
@ -32,8 +32,13 @@ model_list:
|
|||
output_cost_per_token: 0.0000006
|
||||
|
||||
|
||||
# litellm_settings:
|
||||
# key_generation_settings:
|
||||
# personal_key_generation: # maps to 'Default Team' on UI
|
||||
# allowed_user_roles: ["proxy_admin"]
|
||||
litellm_settings:
|
||||
success_callback: ["s3"]
|
||||
enable_preview_features: true
|
||||
s3_callback_params:
|
||||
s3_bucket_name: my-new-test-bucket-litellm # AWS Bucket Name for S3
|
||||
s3_region_name: us-west-2 # AWS Region Name for S3
|
||||
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
|
||||
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
|
||||
s3_use_team_prefix: true
|
||||
|
||||
|
|
|
@ -4163,31 +4163,6 @@ class Router:
|
|||
litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
|
||||
)
|
||||
|
||||
# set region (if azure model) ## PREVIEW FEATURE ##
|
||||
if litellm.enable_preview_features is True:
|
||||
print("Auto inferring region") # noqa
|
||||
"""
|
||||
Hiding behind a feature flag
|
||||
When there is a large amount of LLM deployments this makes startup times blow up
|
||||
"""
|
||||
try:
|
||||
if (
|
||||
"azure" in deployment.litellm_params.model
|
||||
and deployment.litellm_params.region_name is None
|
||||
):
|
||||
region = litellm.utils.get_model_region(
|
||||
litellm_params=deployment.litellm_params, mode=None
|
||||
)
|
||||
|
||||
deployment.litellm_params.region_name = region
|
||||
except Exception as e:
|
||||
verbose_router_logger.debug(
|
||||
"Unable to get the region for azure model - {}, {}".format(
|
||||
deployment.litellm_params.model, str(e)
|
||||
)
|
||||
)
|
||||
pass # [NON-BLOCKING]
|
||||
|
||||
return deployment
|
||||
|
||||
def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:
|
||||
|
|
|
@ -1458,8 +1458,9 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):
|
|||
|
||||
|
||||
class StandardLoggingPromptManagementMetadata(TypedDict):
|
||||
prompt_id: Optional[str]
|
||||
prompt_id: str
|
||||
prompt_variables: Optional[dict]
|
||||
prompt_integration: str
|
||||
|
||||
|
||||
class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
|
||||
|
|
|
@ -1114,6 +1114,7 @@ generation_params = {
|
|||
def test_langfuse_prompt_type(prompt):
|
||||
|
||||
from litellm.integrations.langfuse.langfuse import _add_prompt_to_generation_params
|
||||
from unittest.mock import patch, MagicMock, Mock
|
||||
|
||||
clean_metadata = {
|
||||
"prompt": {
|
||||
|
@ -1215,7 +1216,10 @@ def test_langfuse_prompt_type(prompt):
|
|||
"cache_hit": False,
|
||||
}
|
||||
_add_prompt_to_generation_params(
|
||||
generation_params=generation_params, clean_metadata=clean_metadata
|
||||
generation_params=generation_params,
|
||||
clean_metadata=clean_metadata,
|
||||
prompt_management_metadata=None,
|
||||
langfuse_client=Mock(),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -1019,18 +1019,27 @@ def test_hosted_vllm_embedding(monkeypatch):
|
|||
assert json_data["model"] == "jina-embeddings-v3"
|
||||
|
||||
|
||||
def test_lm_studio_embedding(monkeypatch):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
async def test_lm_studio_embedding(monkeypatch, sync_mode):
|
||||
monkeypatch.setenv("LM_STUDIO_API_BASE", "http://localhost:8000")
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
|
||||
|
||||
client = HTTPHandler()
|
||||
client = HTTPHandler() if sync_mode else AsyncHTTPHandler()
|
||||
with patch.object(client, "post") as mock_post:
|
||||
try:
|
||||
if sync_mode:
|
||||
embedding(
|
||||
model="lm_studio/jina-embeddings-v3",
|
||||
input=["Hello world"],
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
await litellm.aembedding(
|
||||
model="lm_studio/jina-embeddings-v3",
|
||||
input=["Hello world"],
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue