mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
Litellm dev 01 10 2025 p3 (#7682)
* feat(langfuse.py): log the used prompt when prompt management is used
* test: fix test
* docs(self_serve.md): add a doc on restricting personal key creation on the UI
* feat(s3.py): support S3 logging with team alias prefixes (if available). New preview feature
* fix(main.py): remove the old if block; simplify to just await if a coroutine is returned. Fixes the lm_studio async embedding error
* fix(langfuse.py): handle the get-prompt check
parent e54d23c919
commit 953c021aa7

11 changed files with 148 additions and 112 deletions
@@ -231,6 +231,14 @@ curl -X POST '<PROXY_BASE_URL>/team/new' \
 
 Here's a walkthrough of [how it works](https://www.loom.com/share/8959be458edf41fd85937452c29a33f3?sid=7ebd6d37-569a-4023-866e-e0cde67cb23e)
 
+### Restrict Users from creating personal keys
+
+This is useful if you only want users to create keys under a specific team.
+
+This will also prevent users from using their session tokens on the test keys chat pane.
+
+👉 [**See this**](./virtual_keys.md#restricting-key-generation)
+
 ## **All Settings for Self Serve / SSO Flow**
 
 ```yaml
@@ -15,7 +15,10 @@ from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info
 from litellm.llms.custom_httpx.http_handler import _get_httpx_client
 from litellm.secret_managers.main import str_to_bool
 from litellm.types.integrations.langfuse import *
-from litellm.types.utils import StandardLoggingPayload
+from litellm.types.utils import (
+    StandardLoggingPayload,
+    StandardLoggingPromptManagementMetadata,
+)
 
 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache

@@ -463,14 +466,16 @@ class LangFuseLogger:
 
         if standard_logging_object is None:
             end_user_id = None
-            prompt_management_metadata: Optional[dict] = None
+            prompt_management_metadata: Optional[
+                StandardLoggingPromptManagementMetadata
+            ] = None
         else:
             end_user_id = standard_logging_object["metadata"].get(
                 "user_api_key_end_user_id", None
             )
 
             prompt_management_metadata = cast(
-                Optional[dict],
+                Optional[StandardLoggingPromptManagementMetadata],
                 standard_logging_object["metadata"].get(
                     "prompt_management_metadata", None
                 ),

@@ -707,7 +712,10 @@ class LangFuseLogger:
 
         if supports_prompt:
             generation_params = _add_prompt_to_generation_params(
-                generation_params=generation_params, clean_metadata=clean_metadata
+                generation_params=generation_params,
+                clean_metadata=clean_metadata,
+                prompt_management_metadata=prompt_management_metadata,
+                langfuse_client=self.Langfuse,
             )
         if output is not None and isinstance(output, str) and level == "ERROR":
             generation_params["status_message"] = output

@@ -753,8 +761,12 @@ class LangFuseLogger:
 
 
 def _add_prompt_to_generation_params(
-    generation_params: dict, clean_metadata: dict
+    generation_params: dict,
+    clean_metadata: dict,
+    prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata],
+    langfuse_client: Any,
 ) -> dict:
+    from langfuse import Langfuse
     from langfuse.model import (
         ChatPromptClient,
         Prompt_Chat,

@@ -762,8 +774,10 @@ def _add_prompt_to_generation_params(
         TextPromptClient,
     )
 
+    langfuse_client = cast(Langfuse, langfuse_client)
+
     user_prompt = clean_metadata.pop("prompt", None)
-    if user_prompt is None:
+    if user_prompt is None and prompt_management_metadata is None:
         pass
     elif isinstance(user_prompt, dict):
         if user_prompt.get("type", "") == "chat":

@@ -815,6 +829,20 @@ def _add_prompt_to_generation_params(
            verbose_logger.error(
                "[Non-blocking] Langfuse Logger: Invalid prompt format. No prompt logged to Langfuse"
            )
+    elif (
+        prompt_management_metadata is not None
+        and prompt_management_metadata["prompt_integration"] == "langfuse"
+    ):
+        try:
+            generation_params["prompt"] = langfuse_client.get_prompt(
+                prompt_management_metadata["prompt_id"]
+            )
+        except Exception as e:
+            verbose_logger.debug(
+                f"[Non-blocking] Langfuse Logger: Error getting prompt client for logging: {e}"
+            )
+            pass
+
    else:
        generation_params["prompt"] = user_prompt
 
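For context, here is a minimal sketch of what the new Langfuse branch does when a managed prompt was used: it looks the prompt up by id on the Langfuse client and attaches it to the generation params, swallowing any lookup failure so logging stays non-blocking. The helper name and the plain-dict metadata below are illustrative, not part of the diff.

```python
from typing import Any, Optional

def resolve_langfuse_prompt(
    langfuse_client: Any,                        # an instance of langfuse.Langfuse
    prompt_management_metadata: Optional[dict],  # e.g. {"prompt_integration": "langfuse", "prompt_id": "..."}
) -> Optional[Any]:
    """Illustrative helper: return the Langfuse prompt object for a managed prompt, or None."""
    if (
        prompt_management_metadata is None
        or prompt_management_metadata.get("prompt_integration") != "langfuse"
    ):
        return None
    try:
        # langfuse.Langfuse.get_prompt(name) fetches a prompt by its Langfuse name
        return langfuse_client.get_prompt(prompt_management_metadata["prompt_id"])
    except Exception:
        # non-blocking: a failed lookup must not break request logging
        return None
```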
@@ -1,7 +1,8 @@
 #### What this does ####
 # On success + failure, log events to Supabase
 
-from typing import Optional
+from datetime import datetime
+from typing import Optional, cast
 
 import litellm
 from litellm._logging import print_verbose, verbose_logger

@@ -32,6 +33,8 @@ class S3Logger:
            f"in init s3 logger - s3_callback_params {litellm.s3_callback_params}"
        )
 
+        s3_use_team_prefix = False
+
        if litellm.s3_callback_params is not None:
            # read in .env variables - example os.environ/AWS_BUCKET_NAME
            for key, value in litellm.s3_callback_params.items():

@@ -56,7 +59,10 @@ class S3Logger:
            s3_config = litellm.s3_callback_params.get("s3_config")
            s3_path = litellm.s3_callback_params.get("s3_path")
            # done reading litellm.s3_callback_params
+            s3_use_team_prefix = bool(
+                litellm.s3_callback_params.get("s3_use_team_prefix", False)
+            )
+        self.s3_use_team_prefix = s3_use_team_prefix
        self.bucket_name = s3_bucket_name
        self.s3_path = s3_path
        verbose_logger.debug(f"s3 logger using endpoint url {s3_endpoint_url}")

@@ -114,21 +120,31 @@ class S3Logger:
                    clean_metadata[key] = value
 
            # Ensure everything in the payload is converted to str
-            payload: Optional[StandardLoggingPayload] = kwargs.get(
-                "standard_logging_object", None
+            payload: Optional[StandardLoggingPayload] = cast(
+                Optional[StandardLoggingPayload],
+                kwargs.get("standard_logging_object", None),
            )
 
            if payload is None:
                return
 
+            team_alias = payload["metadata"].get("user_api_key_team_alias")
+
+            team_alias_prefix = ""
+            if (
+                litellm.enable_preview_features
+                and self.s3_use_team_prefix
+                and team_alias is not None
+            ):
+                team_alias_prefix = f"{team_alias}/"
+
            s3_file_name = litellm.utils.get_logging_id(start_time, payload) or ""
-            s3_object_key = (
-                (self.s3_path.rstrip("/") + "/" if self.s3_path else "")
-                + start_time.strftime("%Y-%m-%d")
-                + "/"
-                + s3_file_name
-            )  # we need the s3 key to include the time, so we log cache hits too
-            s3_object_key += ".json"
+            s3_object_key = get_s3_object_key(
+                cast(Optional[str], self.s3_path) or "",
+                team_alias_prefix,
+                start_time,
+                s3_file_name,
+            )
 
            s3_object_download_filename = (
                "time-"

@@ -161,3 +177,20 @@ class S3Logger:
        except Exception as e:
            verbose_logger.exception(f"s3 Layer Error - {str(e)}")
            pass
+
+
+def get_s3_object_key(
+    s3_path: str,
+    team_alias_prefix: str,
+    start_time: datetime,
+    s3_file_name: str,
+) -> str:
+    s3_object_key = (
+        (s3_path.rstrip("/") + "/" if s3_path else "")
+        + team_alias_prefix
+        + start_time.strftime("%Y-%m-%d")
+        + "/"
+        + s3_file_name
+    )  # we need the s3 key to include the time, so we log cache hits too
+    s3_object_key += ".json"
+    return s3_object_key
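To make the new key layout concrete, here is a standalone sketch of the `get_s3_object_key` helper from the diff plus an example call; the path, team alias, and file name are made-up values.

```python
from datetime import datetime

def get_s3_object_key(
    s3_path: str,
    team_alias_prefix: str,
    start_time: datetime,
    s3_file_name: str,
) -> str:
    # Produces <s3_path>/<team_alias>/<YYYY-MM-DD>/<logging-id>.json
    s3_object_key = (
        (s3_path.rstrip("/") + "/" if s3_path else "")
        + team_alias_prefix
        + start_time.strftime("%Y-%m-%d")
        + "/"
        + s3_file_name
    )
    s3_object_key += ".json"
    return s3_object_key

key = get_s3_object_key(
    s3_path="litellm-logs",                      # hypothetical s3_path
    team_alias_prefix="my-team/",                # only set when preview features + s3_use_team_prefix are on
    start_time=datetime(2025, 1, 10, 12, 0, 0),
    s3_file_name="chatcmpl-abc123",              # hypothetical logging id
)
print(key)  # litellm-logs/my-team/2025-01-10/chatcmpl-abc123.json
```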
@@ -416,6 +416,8 @@ class Logging(LiteLLMLoggingBaseClass):
 
                if custom_logger is None:
                    continue
+                old_name = model
+
                model, messages, non_default_params = (
                    custom_logger.get_chat_completion_prompt(
                        model=model,

@@ -426,6 +428,7 @@ class Logging(LiteLLMLoggingBaseClass):
                        dynamic_callback_params=self.standard_callback_dynamic_params,
                    )
                )
+                self.model_call_details["prompt_integration"] = old_name.split("/")[0]
                self.messages = messages
 
        return model, messages, non_default_params
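The `prompt_integration` value recorded here is simply the custom-logger prefix of the model string before prompt management rewrites it, as in this small illustration (the model string is a hypothetical example):

```python
# Illustrative only: a hypothetical prompt-management model string
old_name = "langfuse/gpt-4o"
prompt_integration = old_name.split("/")[0]
assert prompt_integration == "langfuse"
```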
@@ -2790,7 +2793,9 @@ class StandardLoggingPayloadSetup:
 
    @staticmethod
    def get_standard_logging_metadata(
-        metadata: Optional[Dict[str, Any]], litellm_params: Optional[dict] = None
+        metadata: Optional[Dict[str, Any]],
+        litellm_params: Optional[dict] = None,
+        prompt_integration: Optional[str] = None,
    ) -> StandardLoggingMetadata:
        """
        Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.

@@ -2814,9 +2819,11 @@ class StandardLoggingPayloadSetup:
                Optional[dict], litellm_params.get("prompt_variables", None)
            )
 
+            if prompt_id is not None and prompt_integration is not None:
                prompt_management_metadata = StandardLoggingPromptManagementMetadata(
                    prompt_id=prompt_id,
                    prompt_variables=prompt_variables,
+                    prompt_integration=prompt_integration,
                )
 
        # Initialize with default values

@@ -3126,7 +3133,9 @@ def get_standard_logging_object_payload(
        )
        # clean up litellm metadata
        clean_metadata = StandardLoggingPayloadSetup.get_standard_logging_metadata(
-            metadata=metadata, litellm_params=litellm_params
+            metadata=metadata,
+            litellm_params=litellm_params,
+            prompt_integration=kwargs.get("prompt_integration", None),
        )
 
        saved_cache_cost: float = 0.0
@@ -36,16 +36,15 @@ class OpenAILikeEmbeddingHandler(OpenAILikeBase):
    ) -> EmbeddingResponse:
        response = None
        try:
-            if client is None or isinstance(client, AsyncHTTPHandler):
-                self.async_client = get_async_httpx_client(
+            if client is None or not isinstance(client, AsyncHTTPHandler):
+                async_client = get_async_httpx_client(
                    llm_provider=litellm.LlmProviders.OPENAI,
                    params={"timeout": timeout},
                )
            else:
-                self.async_client = client
+                async_client = client
 
            try:
-                response = await self.async_client.post(
+                response = await async_client.post(
                    api_base,
                    headers=headers,
                    data=json.dumps(data),
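The inverted isinstance check was the root of the lm_studio async embedding error: a synchronous HTTPHandler passed by the caller was reused and its `.post()` awaited. A minimal sketch of the corrected selection logic, with stand-in handler classes rather than litellm's own:

```python
from typing import Optional, Union

class HTTPHandler:        # stand-in for the sync handler
    ...

class AsyncHTTPHandler:   # stand-in for the async handler
    ...

def pick_async_client(
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]]
) -> AsyncHTTPHandler:
    # Old check was `client is None or isinstance(client, AsyncHTTPHandler)`,
    # so a sync HTTPHandler fell into the else-branch and was awaited later.
    if client is None or not isinstance(client, AsyncHTTPHandler):
        return AsyncHTTPHandler()  # build a fresh async client instead
    return client
```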
@@ -3055,52 +3055,17 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
            model=model, api_base=kwargs.get("api_base", None)
        )
 
-        response: Optional[EmbeddingResponse] = None
-        if (
-            custom_llm_provider == "openai"
-            or custom_llm_provider == "azure"
-            or custom_llm_provider == "xinference"
-            or custom_llm_provider == "voyage"
-            or custom_llm_provider == "mistral"
-            or custom_llm_provider == "custom_openai"
-            or custom_llm_provider == "triton"
-            or custom_llm_provider == "anyscale"
-            or custom_llm_provider == "openrouter"
-            or custom_llm_provider == "deepinfra"
-            or custom_llm_provider == "perplexity"
-            or custom_llm_provider == "groq"
-            or custom_llm_provider == "nvidia_nim"
-            or custom_llm_provider == "cerebras"
-            or custom_llm_provider == "sambanova"
-            or custom_llm_provider == "ai21_chat"
-            or custom_llm_provider == "volcengine"
-            or custom_llm_provider == "deepseek"
-            or custom_llm_provider == "fireworks_ai"
-            or custom_llm_provider == "ollama"
-            or custom_llm_provider == "vertex_ai"
-            or custom_llm_provider == "gemini"
-            or custom_llm_provider == "databricks"
-            or custom_llm_provider == "watsonx"
-            or custom_llm_provider == "cohere"
-            or custom_llm_provider == "huggingface"
-            or custom_llm_provider == "bedrock"
-            or custom_llm_provider == "azure_ai"
-            or custom_llm_provider == "together_ai"
-            or custom_llm_provider == "openai_like"
-            or custom_llm_provider == "jina_ai"
-            or custom_llm_provider == "voyage"
-        ):  # currently implemented aiohttp calls for just azure and openai, soon all.
        # Await normally
        init_response = await loop.run_in_executor(None, func_with_context)
 
+        response: Optional[EmbeddingResponse] = None
        if isinstance(init_response, dict):
            response = EmbeddingResponse(**init_response)
        elif isinstance(init_response, EmbeddingResponse):  ## CACHING SCENARIO
            response = init_response
        elif asyncio.iscoroutine(init_response):
            response = await init_response  # type: ignore
-        else:
-            # Call the synchronous function using run_in_executor
-            response = await loop.run_in_executor(None, func_with_context)
+
        if (
            response is not None
            and isinstance(response, EmbeddingResponse)
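The simplification drops the provider allow-list entirely: the executor result is inspected once, and only a coroutine is awaited. A minimal sketch of that pattern, outside of litellm:

```python
import asyncio
from typing import Any

async def resolve_embedding_result(init_response: Any) -> Any:
    # The provider call may hand back a plain dict, an already-built response
    # object (e.g. a cache hit), or a coroutine; only the coroutine needs awaiting.
    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response
```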
@@ -32,8 +32,13 @@ model_list:
      output_cost_per_token: 0.0000006
 
 
-# litellm_settings:
-#   key_generation_settings:
-#     personal_key_generation: # maps to 'Default Team' on UI
-#       allowed_user_roles: ["proxy_admin"]
+litellm_settings:
+  success_callback: ["s3"]
+  enable_preview_features: true
+  s3_callback_params:
+    s3_bucket_name: my-new-test-bucket-litellm # AWS Bucket Name for S3
+    s3_region_name: us-west-2 # AWS Region Name for S3
+    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # use os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
+    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
+    s3_use_team_prefix: true
 
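The same preview feature can also be toggled from Python rather than the proxy YAML; all attribute names below appear in the diff, while the bucket and region values are placeholders:

```python
import litellm

litellm.enable_preview_features = True     # team-alias prefixes are gated behind preview features
litellm.success_callback = ["s3"]
litellm.s3_callback_params = {
    "s3_bucket_name": "my-new-test-bucket-litellm",
    "s3_region_name": "us-west-2",
    "s3_use_team_prefix": True,            # objects land under <s3_path>/<team_alias>/<date>/...
}
```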
@@ -4163,31 +4163,6 @@ class Router:
            litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
        )
 
-        # set region (if azure model) ## PREVIEW FEATURE ##
-        if litellm.enable_preview_features is True:
-            print("Auto inferring region")  # noqa
-            """
-            Hiding behind a feature flag
-            When there is a large amount of LLM deployments this makes startup times blow up
-            """
-            try:
-                if (
-                    "azure" in deployment.litellm_params.model
-                    and deployment.litellm_params.region_name is None
-                ):
-                    region = litellm.utils.get_model_region(
-                        litellm_params=deployment.litellm_params, mode=None
-                    )
-
-                    deployment.litellm_params.region_name = region
-            except Exception as e:
-                verbose_router_logger.debug(
-                    "Unable to get the region for azure model - {}, {}".format(
-                        deployment.litellm_params.model, str(e)
-                    )
-                )
-                pass  # [NON-BLOCKING]
-
        return deployment
 
    def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:
@@ -1458,8 +1458,9 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):
 
 
 class StandardLoggingPromptManagementMetadata(TypedDict):
-    prompt_id: Optional[str]
+    prompt_id: str
     prompt_variables: Optional[dict]
+    prompt_integration: str
 
 
 class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
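A quick sketch of the tightened TypedDict shape; the field values are invented examples:

```python
from typing import Optional, TypedDict

class StandardLoggingPromptManagementMetadata(TypedDict):
    prompt_id: str                   # now required, no longer Optional[str]
    prompt_variables: Optional[dict]
    prompt_integration: str          # e.g. "langfuse", taken from the model prefix

metadata = StandardLoggingPromptManagementMetadata(
    prompt_id="my-langfuse-prompt",
    prompt_variables={"user_name": "Alice"},
    prompt_integration="langfuse",
)
```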
@@ -1114,6 +1114,7 @@ generation_params = {
 def test_langfuse_prompt_type(prompt):
 
     from litellm.integrations.langfuse.langfuse import _add_prompt_to_generation_params
+    from unittest.mock import patch, MagicMock, Mock
 
     clean_metadata = {
         "prompt": {

@@ -1215,7 +1216,10 @@ def test_langfuse_prompt_type(prompt):
        "cache_hit": False,
    }
    _add_prompt_to_generation_params(
-        generation_params=generation_params, clean_metadata=clean_metadata
+        generation_params=generation_params,
+        clean_metadata=clean_metadata,
+        prompt_management_metadata=None,
+        langfuse_client=Mock(),
    )
 
 
@@ -1019,18 +1019,27 @@ def test_hosted_vllm_embedding(monkeypatch):
        assert json_data["model"] == "jina-embeddings-v3"
 
 
-def test_lm_studio_embedding(monkeypatch):
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_lm_studio_embedding(monkeypatch, sync_mode):
    monkeypatch.setenv("LM_STUDIO_API_BASE", "http://localhost:8000")
-    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
 
-    client = HTTPHandler()
+    client = HTTPHandler() if sync_mode else AsyncHTTPHandler()
    with patch.object(client, "post") as mock_post:
        try:
+            if sync_mode:
                embedding(
                    model="lm_studio/jina-embeddings-v3",
                    input=["Hello world"],
                    client=client,
                )
+            else:
+                await litellm.aembedding(
+                    model="lm_studio/jina-embeddings-v3",
+                    input=["Hello world"],
+                    client=client,
+                )
        except Exception as e:
            print(e)