mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
Merge branch 'main' into litellm_arize_dynamic_logging
This commit is contained in:
commit
57e5c94360
13 changed files with 225 additions and 135 deletions
|
@ -1,4 +1,3 @@
|
|||
import json
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
|
|
@ -129,17 +129,15 @@ def get_llm_provider( # noqa: PLR0915
|
|||
model, custom_llm_provider
|
||||
)
|
||||
|
||||
if custom_llm_provider:
|
||||
if (
|
||||
model.split("/")[0] == custom_llm_provider
|
||||
): # handle scenario where model="azure/*" and custom_llm_provider="azure"
|
||||
model = model.replace("{}/".format(custom_llm_provider), "")
|
||||
|
||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||
if custom_llm_provider and (
|
||||
model.split("/")[0] != custom_llm_provider
|
||||
): # handle scenario where model="azure/*" and custom_llm_provider="azure"
|
||||
model = custom_llm_provider + "/" + model
|
||||
|
||||
if api_key and api_key.startswith("os.environ/"):
|
||||
dynamic_api_key = get_secret_str(api_key)
|
||||
# check if llm provider part of model name
|
||||
|
||||
if (
|
||||
model.split("/", 1)[0] in litellm.provider_list
|
||||
and model.split("/", 1)[0] not in litellm.model_list_set
|
||||
|
@ -573,11 +571,11 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
|
|||
dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
|
||||
elif custom_llm_provider == "snowflake":
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret("SNOWFLAKE_API_BASE")
|
||||
or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret("SNOWFLAKE_JWT")
|
||||
api_base
|
||||
or get_secret_str("SNOWFLAKE_API_BASE")
|
||||
or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("SNOWFLAKE_JWT")
|
||||
|
||||
if api_base is not None and not isinstance(api_base, str):
|
||||
raise Exception("api base needs to be a string. api_base={}".format(api_base))
|
||||
|
|
|
@ -3257,6 +3257,7 @@ class StandardLoggingPayloadSetup:
|
|||
additional_headers=None,
|
||||
litellm_overhead_time_ms=None,
|
||||
batch_models=None,
|
||||
litellm_model_name=None,
|
||||
)
|
||||
if hidden_params is not None:
|
||||
for key in StandardLoggingHiddenParams.__annotations__.keys():
|
||||
|
@ -3371,6 +3372,7 @@ def get_standard_logging_object_payload(
|
|||
response_cost=None,
|
||||
litellm_overhead_time_ms=None,
|
||||
batch_models=None,
|
||||
litellm_model_name=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -3656,6 +3658,7 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
|
|||
additional_headers=None,
|
||||
litellm_overhead_time_ms=None,
|
||||
batch_models=None,
|
||||
litellm_model_name=None,
|
||||
)
|
||||
|
||||
# Convert numeric values to appropriate types
|
||||
|
|
|
@ -44,6 +44,7 @@ class ResponseMetadata:
|
|||
"additional_headers": process_response_headers(
|
||||
self._get_value_from_hidden_params("additional_headers") or {}
|
||||
),
|
||||
"litellm_model_name": model,
|
||||
}
|
||||
self._update_hidden_params(new_params)
|
||||
|
||||
|
|
|
@ -65,10 +65,13 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
|
|||
# ------------
|
||||
# Setup values
|
||||
# ------------
|
||||
|
||||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime("%H-%M")
|
||||
model_id = deployment.get("model_info", {}).get("id")
|
||||
rpm_key = f"{model_id}:rpm:{current_minute}"
|
||||
deployment_name = deployment.get("litellm_params", {}).get("model")
|
||||
rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
|
||||
|
||||
local_result = self.router_cache.get_cache(
|
||||
key=rpm_key, local_only=True
|
||||
) # check local result first
|
||||
|
@ -151,7 +154,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
|
|||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime("%H-%M")
|
||||
model_id = deployment.get("model_info", {}).get("id")
|
||||
rpm_key = f"{model_id}:rpm:{current_minute}"
|
||||
deployment_name = deployment.get("litellm_params", {}).get("model")
|
||||
|
||||
rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
|
||||
local_result = await self.router_cache.async_get_cache(
|
||||
key=rpm_key, local_only=True
|
||||
) # check local result first
|
||||
|
@ -228,8 +233,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
|
|||
if standard_logging_object is None:
|
||||
raise ValueError("standard_logging_object not passed in.")
|
||||
model_group = standard_logging_object.get("model_group")
|
||||
model = standard_logging_object["hidden_params"].get("litellm_model_name")
|
||||
id = standard_logging_object.get("model_id")
|
||||
if model_group is None or id is None:
|
||||
if model_group is None or id is None or model is None:
|
||||
return
|
||||
elif isinstance(id, int):
|
||||
id = str(id)
|
||||
|
@ -244,7 +250,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
|
|||
"%H-%M"
|
||||
) # use the same timezone regardless of system clock
|
||||
|
||||
tpm_key = f"{id}:tpm:{current_minute}"
|
||||
tpm_key = f"{id}:{model}:tpm:{current_minute}"
|
||||
# ------------
|
||||
# Update usage
|
||||
# ------------
|
||||
|
@ -276,6 +282,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
|
|||
if standard_logging_object is None:
|
||||
raise ValueError("standard_logging_object not passed in.")
|
||||
model_group = standard_logging_object.get("model_group")
|
||||
model = standard_logging_object["hidden_params"]["litellm_model_name"]
|
||||
id = standard_logging_object.get("model_id")
|
||||
if model_group is None or id is None:
|
||||
return
|
||||
|
@ -290,7 +297,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
|
|||
"%H-%M"
|
||||
) # use the same timezone regardless of system clock
|
||||
|
||||
tpm_key = f"{id}:tpm:{current_minute}"
|
||||
tpm_key = f"{id}:{model}:tpm:{current_minute}"
|
||||
# ------------
|
||||
# Update usage
|
||||
# ------------
|
||||
|
@ -458,8 +465,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
|
|||
id = m.get("model_info", {}).get(
|
||||
"id"
|
||||
) # a deployment should always have an 'id'. this is set in router.py
|
||||
tpm_key = "{}:tpm:{}".format(id, current_minute)
|
||||
rpm_key = "{}:rpm:{}".format(id, current_minute)
|
||||
deployment_name = m.get("litellm_params", {}).get("model")
|
||||
tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
|
||||
rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)
|
||||
|
||||
tpm_keys.append(tpm_key)
|
||||
rpm_keys.append(rpm_key)
|
||||
|
@ -576,8 +584,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
|
|||
id = m.get("model_info", {}).get(
|
||||
"id"
|
||||
) # a deployment should always have an 'id'. this is set in router.py
|
||||
tpm_key = "{}:tpm:{}".format(id, current_minute)
|
||||
rpm_key = "{}:rpm:{}".format(id, current_minute)
|
||||
deployment_name = m.get("litellm_params", {}).get("model")
|
||||
tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
|
||||
rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)
|
||||
|
||||
tpm_keys.append(tpm_key)
|
||||
rpm_keys.append(rpm_key)
|
||||
|
|
|
@ -1625,13 +1625,16 @@ class StandardLoggingAdditionalHeaders(TypedDict, total=False):
|
|||
|
||||
|
||||
class StandardLoggingHiddenParams(TypedDict):
|
||||
model_id: Optional[str]
|
||||
model_id: Optional[
|
||||
str
|
||||
] # id of the model in the router, separates multiple models with the same name but different credentials
|
||||
cache_key: Optional[str]
|
||||
api_base: Optional[str]
|
||||
response_cost: Optional[str]
|
||||
litellm_overhead_time_ms: Optional[float]
|
||||
additional_headers: Optional[StandardLoggingAdditionalHeaders]
|
||||
batch_models: Optional[List[str]]
|
||||
litellm_model_name: Optional[str] # the model name sent to the provider by litellm
|
||||
|
||||
|
||||
class StandardLoggingModelInformation(TypedDict):
|
||||
|
|
|
@ -20,7 +20,8 @@ from litellm.caching.redis_cache import RedisCache
|
|||
|
||||
@pytest.mark.parametrize("namespace", [None, "test"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_cache_async_increment(namespace):
|
||||
async def test_redis_cache_async_increment(namespace, monkeypatch):
|
||||
monkeypatch.setenv("REDIS_HOST", "https://my-test-host")
|
||||
redis_cache = RedisCache(namespace=namespace)
|
||||
# Create an AsyncMock for the Redis client
|
||||
mock_redis_instance = AsyncMock()
|
||||
|
@ -46,7 +47,8 @@ async def test_redis_cache_async_increment(namespace):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_client_init_with_socket_timeout():
|
||||
async def test_redis_client_init_with_socket_timeout(monkeypatch):
|
||||
monkeypatch.setenv("REDIS_HOST", "my-fake-host")
|
||||
redis_cache = RedisCache(socket_timeout=1.0)
|
||||
assert redis_cache.redis_kwargs["socket_timeout"] == 1.0
|
||||
client = redis_cache.init_async_client()
|
||||
|
|
|
@ -8,84 +8,24 @@ import pytest
|
|||
sys.path.insert(
|
||||
0, os.path.abspath("../../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
from litellm.llms.azure_ai.chat.transformation import AzureAIStudioConfig
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_async", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_azure_ai_request_format(is_async):
|
||||
async def test_get_openai_compatible_provider_info():
|
||||
"""
|
||||
Test that Azure AI requests are formatted correctly with the proper endpoint and parameters
|
||||
for both synchronous and asynchronous calls
|
||||
"""
|
||||
litellm._turn_on_debug()
|
||||
config = AzureAIStudioConfig()
|
||||
|
||||
# Set up the test parameters
|
||||
api_key = "00xxx"
|
||||
api_base = "https://my-endpoint-europe-berri-992.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-08-01-preview"
|
||||
model = "azure_ai/gpt-4o-mini"
|
||||
messages = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": "Hello! How can I assist you today?"},
|
||||
{"role": "user", "content": "hi"},
|
||||
]
|
||||
api_base, dynamic_api_key, custom_llm_provider = (
|
||||
config._get_openai_compatible_provider_info(
|
||||
model="azure_ai/gpt-4o-mini",
|
||||
api_base="https://my-base",
|
||||
api_key="my-key",
|
||||
custom_llm_provider="azure_ai",
|
||||
)
|
||||
)
|
||||
|
||||
if is_async:
|
||||
# Mock AsyncHTTPHandler.post method for async test
|
||||
with patch(
|
||||
"litellm.llms.custom_httpx.llm_http_handler.AsyncHTTPHandler.post"
|
||||
) as mock_post:
|
||||
# Set up mock response
|
||||
mock_post.return_value = AsyncMock()
|
||||
|
||||
# Call the acompletion function
|
||||
try:
|
||||
await litellm.acompletion(
|
||||
custom_llm_provider="azure_ai",
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
model=model,
|
||||
messages=messages,
|
||||
)
|
||||
except Exception as e:
|
||||
# We expect an exception since we're mocking the response
|
||||
pass
|
||||
|
||||
else:
|
||||
# Mock HTTPHandler.post method for sync test
|
||||
with patch(
|
||||
"litellm.llms.custom_httpx.llm_http_handler.HTTPHandler.post"
|
||||
) as mock_post:
|
||||
# Set up mock response
|
||||
mock_post.return_value = MagicMock()
|
||||
|
||||
# Call the completion function
|
||||
try:
|
||||
litellm.completion(
|
||||
custom_llm_provider="azure_ai",
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
model=model,
|
||||
messages=messages,
|
||||
)
|
||||
except Exception as e:
|
||||
# We expect an exception since we're mocking the response
|
||||
pass
|
||||
|
||||
# Verify the request was made with the correct parameters
|
||||
mock_post.assert_called_once()
|
||||
call_args = mock_post.call_args
|
||||
print("sync request call=", json.dumps(call_args.kwargs, indent=4, default=str))
|
||||
|
||||
# Check URL
|
||||
assert call_args.kwargs["url"] == api_base
|
||||
|
||||
# Check headers
|
||||
assert call_args.kwargs["headers"]["api-key"] == api_key
|
||||
|
||||
# Check request body
|
||||
request_body = json.loads(call_args.kwargs["data"])
|
||||
assert (
|
||||
request_body["model"] == "gpt-4o-mini"
|
||||
) # Model name should be stripped of provider prefix
|
||||
assert request_body["messages"] == messages
|
||||
assert custom_llm_provider == "azure"
|
||||
|
|
|
@ -265,3 +265,32 @@ class TestAzureAIRerank(BaseLLMRerankTest):
|
|||
"api_base": os.getenv("AZURE_AI_COHERE_API_BASE"),
|
||||
"api_key": os.getenv("AZURE_AI_COHERE_API_KEY"),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_azure_ai_request_format():
|
||||
"""
|
||||
Test that Azure AI requests are formatted correctly with the proper endpoint and parameters
|
||||
for both synchronous and asynchronous calls
|
||||
"""
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
|
||||
litellm._turn_on_debug()
|
||||
|
||||
# Set up the test parameters
|
||||
api_key = os.getenv("AZURE_API_KEY")
|
||||
api_base = f"{os.getenv('AZURE_API_BASE')}/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"
|
||||
model = "azure_ai/gpt-4o"
|
||||
messages = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": "Hello! How can I assist you today?"},
|
||||
{"role": "user", "content": "hi"},
|
||||
]
|
||||
|
||||
await litellm.acompletion(
|
||||
custom_llm_provider="azure_ai",
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
model=model,
|
||||
messages=messages,
|
||||
)
|
||||
|
|
|
@ -2,7 +2,9 @@ import asyncio
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import time
|
||||
from unittest.mock import patch, Mock
|
||||
import opentelemetry.exporter.otlp.proto.grpc.trace_exporter
|
||||
from litellm import Choices
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
@ -81,3 +83,52 @@ def test_get_arize_config_with_endpoints(mock_env_vars, monkeypatch):
|
|||
config = ArizeLogger.get_arize_config()
|
||||
assert config.endpoint == "grpc://test.endpoint"
|
||||
assert config.protocol == "otlp_grpc"
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Works locally but not in CI/CD. We'll need a better way to test Arize on CI/CD"
|
||||
)
|
||||
def test_arize_callback():
|
||||
litellm.callbacks = ["arize"]
|
||||
os.environ["ARIZE_SPACE_KEY"] = "test_space_key"
|
||||
os.environ["ARIZE_API_KEY"] = "test_api_key"
|
||||
os.environ["ARIZE_ENDPOINT"] = "https://otlp.arize.com/v1"
|
||||
|
||||
# Set the batch span processor to quickly flush after a span has been added
|
||||
# This is to ensure that the span is exported before the test ends
|
||||
os.environ["OTEL_BSP_MAX_QUEUE_SIZE"] = "1"
|
||||
os.environ["OTEL_BSP_MAX_EXPORT_BATCH_SIZE"] = "1"
|
||||
os.environ["OTEL_BSP_SCHEDULE_DELAY_MILLIS"] = "1"
|
||||
os.environ["OTEL_BSP_EXPORT_TIMEOUT_MILLIS"] = "5"
|
||||
|
||||
try:
|
||||
with patch.object(
|
||||
opentelemetry.exporter.otlp.proto.grpc.trace_exporter.OTLPSpanExporter,
|
||||
"export",
|
||||
new=Mock(),
|
||||
) as patched_export:
|
||||
litellm.completion(
|
||||
model="openai/test-model",
|
||||
messages=[{"role": "user", "content": "arize test content"}],
|
||||
stream=False,
|
||||
mock_response="hello there!",
|
||||
)
|
||||
|
||||
time.sleep(1) # Wait for the batch span processor to flush
|
||||
assert patched_export.called
|
||||
finally:
|
||||
# Clean up environment variables
|
||||
for key in [
|
||||
"ARIZE_SPACE_KEY",
|
||||
"ARIZE_API_KEY",
|
||||
"ARIZE_ENDPOINT",
|
||||
"OTEL_BSP_MAX_QUEUE_SIZE",
|
||||
"OTEL_BSP_MAX_EXPORT_BATCH_SIZE",
|
||||
"OTEL_BSP_SCHEDULE_DELAY_MILLIS",
|
||||
"OTEL_BSP_EXPORT_TIMEOUT_MILLIS",
|
||||
]:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
# Reset callbacks
|
||||
litellm.callbacks = []
|
||||
|
|
|
@ -216,3 +216,21 @@ def test_bedrock_invoke_anthropic():
|
|||
)
|
||||
assert custom_llm_provider == "bedrock"
|
||||
assert model == "invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["xai/grok-2-vision-latest", "grok-2-vision-latest"])
|
||||
def test_xai_api_base(model):
|
||||
args = {
|
||||
"model": model,
|
||||
"custom_llm_provider": "xai",
|
||||
"api_base": None,
|
||||
"api_key": "xai-my-specialkey",
|
||||
"litellm_params": None,
|
||||
}
|
||||
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
|
||||
**args
|
||||
)
|
||||
assert custom_llm_provider == "xai"
|
||||
assert model == "grok-2-vision-latest"
|
||||
assert api_base == "https://api.x.ai/v1"
|
||||
assert dynamic_api_key == "xai-my-specialkey"
|
||||
|
|
|
@ -20,7 +20,7 @@ sys.path.insert(
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
import pytest
|
||||
|
||||
from litellm.types.router import DeploymentTypedDict
|
||||
import litellm
|
||||
from litellm import Router
|
||||
from litellm.caching.caching import DualCache
|
||||
|
@ -47,12 +47,14 @@ def test_tpm_rpm_updated():
|
|||
deployment_id = "1234"
|
||||
deployment = "azure/chatgpt-v-2"
|
||||
total_tokens = 50
|
||||
standard_logging_payload = create_standard_logging_payload()
|
||||
standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
|
||||
standard_logging_payload["model_group"] = model_group
|
||||
standard_logging_payload["model_id"] = deployment_id
|
||||
standard_logging_payload["total_tokens"] = total_tokens
|
||||
standard_logging_payload["hidden_params"]["litellm_model_name"] = deployment
|
||||
kwargs = {
|
||||
"litellm_params": {
|
||||
"model": deployment,
|
||||
"metadata": {
|
||||
"model_group": model_group,
|
||||
"deployment": deployment,
|
||||
|
@ -62,10 +64,16 @@ def test_tpm_rpm_updated():
|
|||
"standard_logging_object": standard_logging_payload,
|
||||
}
|
||||
|
||||
litellm_deployment_dict: DeploymentTypedDict = {
|
||||
"model_name": model_group,
|
||||
"litellm_params": {"model": deployment},
|
||||
"model_info": {"id": deployment_id},
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
response_obj = {"usage": {"total_tokens": total_tokens}}
|
||||
end_time = time.time()
|
||||
lowest_tpm_logger.pre_call_check(deployment=kwargs["litellm_params"])
|
||||
lowest_tpm_logger.pre_call_check(deployment=litellm_deployment_dict)
|
||||
lowest_tpm_logger.log_success_event(
|
||||
response_obj=response_obj,
|
||||
kwargs=kwargs,
|
||||
|
@ -74,8 +82,8 @@ def test_tpm_rpm_updated():
|
|||
)
|
||||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime("%H-%M")
|
||||
tpm_count_api_key = f"{deployment_id}:tpm:{current_minute}"
|
||||
rpm_count_api_key = f"{deployment_id}:rpm:{current_minute}"
|
||||
tpm_count_api_key = f"{deployment_id}:{deployment}:tpm:{current_minute}"
|
||||
rpm_count_api_key = f"{deployment_id}:{deployment}:rpm:{current_minute}"
|
||||
|
||||
print(f"tpm_count_api_key={tpm_count_api_key}")
|
||||
assert response_obj["usage"]["total_tokens"] == test_cache.get_cache(
|
||||
|
@ -113,6 +121,7 @@ def test_get_available_deployments():
|
|||
standard_logging_payload["model_group"] = model_group
|
||||
standard_logging_payload["model_id"] = deployment_id
|
||||
standard_logging_payload["total_tokens"] = total_tokens
|
||||
standard_logging_payload["hidden_params"]["litellm_model_name"] = deployment
|
||||
kwargs = {
|
||||
"litellm_params": {
|
||||
"metadata": {
|
||||
|
@ -135,10 +144,11 @@ def test_get_available_deployments():
|
|||
## DEPLOYMENT 2 ##
|
||||
total_tokens = 20
|
||||
deployment_id = "5678"
|
||||
standard_logging_payload = create_standard_logging_payload()
|
||||
standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
|
||||
standard_logging_payload["model_group"] = model_group
|
||||
standard_logging_payload["model_id"] = deployment_id
|
||||
standard_logging_payload["total_tokens"] = total_tokens
|
||||
standard_logging_payload["hidden_params"]["litellm_model_name"] = deployment
|
||||
kwargs = {
|
||||
"litellm_params": {
|
||||
"metadata": {
|
||||
|
@ -209,11 +219,12 @@ def test_router_get_available_deployments():
|
|||
print(f"router id's: {router.get_model_ids()}")
|
||||
## DEPLOYMENT 1 ##
|
||||
deployment_id = 1
|
||||
standard_logging_payload = create_standard_logging_payload()
|
||||
standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
|
||||
standard_logging_payload["model_group"] = "azure-model"
|
||||
standard_logging_payload["model_id"] = str(deployment_id)
|
||||
total_tokens = 50
|
||||
standard_logging_payload["total_tokens"] = total_tokens
|
||||
standard_logging_payload["hidden_params"]["litellm_model_name"] = "azure/gpt-turbo"
|
||||
kwargs = {
|
||||
"litellm_params": {
|
||||
"metadata": {
|
||||
|
@ -237,6 +248,9 @@ def test_router_get_available_deployments():
|
|||
standard_logging_payload = create_standard_logging_payload()
|
||||
standard_logging_payload["model_group"] = "azure-model"
|
||||
standard_logging_payload["model_id"] = str(deployment_id)
|
||||
standard_logging_payload["hidden_params"][
|
||||
"litellm_model_name"
|
||||
] = "azure/gpt-35-turbo"
|
||||
kwargs = {
|
||||
"litellm_params": {
|
||||
"metadata": {
|
||||
|
@ -293,10 +307,11 @@ def test_router_skip_rate_limited_deployments():
|
|||
## DEPLOYMENT 1 ##
|
||||
deployment_id = 1
|
||||
total_tokens = 1439
|
||||
standard_logging_payload = create_standard_logging_payload()
|
||||
standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
|
||||
standard_logging_payload["model_group"] = "azure-model"
|
||||
standard_logging_payload["model_id"] = str(deployment_id)
|
||||
standard_logging_payload["total_tokens"] = total_tokens
|
||||
standard_logging_payload["hidden_params"]["litellm_model_name"] = "azure/gpt-turbo"
|
||||
kwargs = {
|
||||
"litellm_params": {
|
||||
"metadata": {
|
||||
|
@ -699,3 +714,54 @@ def test_return_potential_deployments():
|
|||
)
|
||||
|
||||
assert len(potential_deployments) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tpm_rpm_routing_model_name_checks():
|
||||
deployment = {
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
"mock_response": "Hey, how's it going?",
|
||||
},
|
||||
}
|
||||
router = Router(model_list=[deployment], routing_strategy="usage-based-routing-v2")
|
||||
|
||||
async def side_effect_pre_call_check(*args, **kwargs):
|
||||
return args[0]
|
||||
|
||||
with patch.object(
|
||||
router.lowesttpm_logger_v2,
|
||||
"async_pre_call_check",
|
||||
side_effect=side_effect_pre_call_check,
|
||||
) as mock_object, patch.object(
|
||||
router.lowesttpm_logger_v2, "async_log_success_event"
|
||||
) as mock_logging_event:
|
||||
response = await router.acompletion(
|
||||
model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey!"}]
|
||||
)
|
||||
|
||||
mock_object.assert_called()
|
||||
print(f"mock_object.call_args: {mock_object.call_args[0][0]}")
|
||||
assert (
|
||||
mock_object.call_args[0][0]["litellm_params"]["model"]
|
||||
== deployment["litellm_params"]["model"]
|
||||
)
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
mock_logging_event.assert_called()
|
||||
|
||||
print(f"mock_logging_event: {mock_logging_event.call_args.kwargs}")
|
||||
standard_logging_payload: StandardLoggingPayload = (
|
||||
mock_logging_event.call_args.kwargs.get("kwargs", {}).get(
|
||||
"standard_logging_object"
|
||||
)
|
||||
)
|
||||
|
||||
assert (
|
||||
standard_logging_payload["hidden_params"]["litellm_model_name"]
|
||||
== "azure/chatgpt-v-2"
|
||||
)
|
||||
|
|
|
@ -15,35 +15,6 @@ import litellm
|
|||
from litellm.types.utils import Choices
|
||||
|
||||
|
||||
def test_arize_callback():
|
||||
litellm.callbacks = ["arize"]
|
||||
os.environ["ARIZE_SPACE_KEY"] = "test_space_key"
|
||||
os.environ["ARIZE_API_KEY"] = "test_api_key"
|
||||
os.environ["ARIZE_ENDPOINT"] = "https://otlp.arize.com/v1"
|
||||
|
||||
# Set the batch span processor to quickly flush after a span has been added
|
||||
# This is to ensure that the span is exported before the test ends
|
||||
os.environ["OTEL_BSP_MAX_QUEUE_SIZE"] = "1"
|
||||
os.environ["OTEL_BSP_MAX_EXPORT_BATCH_SIZE"] = "1"
|
||||
os.environ["OTEL_BSP_SCHEDULE_DELAY_MILLIS"] = "1"
|
||||
os.environ["OTEL_BSP_EXPORT_TIMEOUT_MILLIS"] = "5"
|
||||
|
||||
with patch.object(
|
||||
opentelemetry.exporter.otlp.proto.grpc.trace_exporter.OTLPSpanExporter,
|
||||
"export",
|
||||
new=Mock(),
|
||||
) as patched_export:
|
||||
completion(
|
||||
model="openai/test-model",
|
||||
messages=[{"role": "user", "content": "arize test content"}],
|
||||
stream=False,
|
||||
mock_response="hello there!",
|
||||
)
|
||||
|
||||
time.sleep(1) # Wait for the batch span processor to flush
|
||||
assert patched_export.called
|
||||
|
||||
|
||||
def test_arize_set_attributes():
|
||||
"""
|
||||
Test setting attributes for Arize
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue