Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
Litellm dev 12 12 2024 (#7203)
* fix(azure/): support passing headers to azure openai endpoints
Fixes https://github.com/BerriAI/litellm/issues/6217
* fix(utils.py): move default tokenizer to just openai
the hf tokenizer makes network calls when trying to fetch the tokenizer - this slows down execution
* fix(router.py): fix pattern matching router - add generic "*" to it as well
Fixes issue where generic "*" model access group wouldn't show up
* fix(pattern_match_deployments.py): match to more specific pattern
allows setting a generic wildcard model access group and excluding specific models more easily
* fix(proxy_server.py): fix _delete_deployment to handle base case where db_model list is empty
don't delete all router models b/c of empty list
Fixes https://github.com/BerriAI/litellm/issues/7196
* fix(anthropic/): fix handling response_format for anthropic messages with anthropic api
* fix(fireworks_ai/): support passing response_format + tool call in same message
Addresses https://github.com/BerriAI/litellm/issues/7135
* Revert "fix(fireworks_ai/): support passing response_format + tool call in same message"
This reverts commit 6a30dc6929.
* test: fix test
* fix(replicate/): fix replicate default retry/polling logic
* test: add unit testing for router pattern matching
* test: update test to use default oai tokenizer
* test: mark flaky test
* test: skip flaky test
This commit is contained in:
parent e65f990319
commit a42f008cd0

19 changed files with 496 additions and 103 deletions
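As context for the Azure header fix above, here is a minimal, hedged sketch of the call it enables; the deployment name, endpoint, and header values are placeholders rather than anything taken from this commit:

# Sketch only: forward per-request headers to an Azure OpenAI embedding deployment.
# The deployment name, api_base, and header values below are made-up placeholders.
import litellm

response = litellm.embedding(
    model="azure/my-embedding-deployment",
    input=["hello world"],
    api_base="https://example-resource.openai.azure.com",
    api_version="2023-07-01-preview",
    api_key="my-azure-api-key",
    extra_headers={"Ocp-Apim-Subscription-Key": "my-gateway-key"},
)
print(len(response.data))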
@@ -22,6 +22,8 @@ from litellm.constants import (
     DEFAULT_FLUSH_INTERVAL_SECONDS,
     ROUTER_MAX_FALLBACKS,
     DEFAULT_MAX_RETRIES,
+    DEFAULT_REPLICATE_POLLING_RETRIES,
+    DEFAULT_REPLICATE_POLLING_DELAY_SECONDS,
     LITELLM_CHAT_PROVIDERS,
 )
 from litellm.types.guardrails import GuardrailItem
@@ -2,6 +2,8 @@ ROUTER_MAX_FALLBACKS = 5
 DEFAULT_BATCH_SIZE = 512
 DEFAULT_FLUSH_INTERVAL_SECONDS = 5
 DEFAULT_MAX_RETRIES = 2
+DEFAULT_REPLICATE_POLLING_RETRIES = 5
+DEFAULT_REPLICATE_POLLING_DELAY_SECONDS = 1
 DEFAULT_IMAGE_TOKEN_COUNT = 250
 DEFAULT_IMAGE_WIDTH = 300
 DEFAULT_IMAGE_HEIGHT = 300
@@ -67,6 +69,7 @@ LITELLM_CHAT_PROVIDERS = [
     "galadriel",
 ]

+RESPONSE_FORMAT_TOOL_NAME = "json_tool_call"  # default tool name used when converting response format to tool call

 ########################### LiteLLM Proxy Specific Constants ###########################
 MAX_SPENDLOG_ROWS_TO_QUERY = (
@@ -74,4 +77,3 @@ MAX_SPENDLOG_ROWS_TO_QUERY = (
 )
 # makes it clear this is a rate limit error for a litellm virtual key
 RATE_LIMIT_ERROR_MESSAGE_FOR_VIRTUAL_KEY = "LiteLLM Virtual Key user_api_key_hash"
-
@@ -19,9 +19,10 @@ import httpx
 import requests

 import litellm
+from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.litellm_core_utils.prompt_templates.factory import anthropic_messages_pt
+from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.anthropic import (
     AllAnthropicToolsValues,
     AnthropicComputerTool,
@@ -298,6 +299,18 @@ class AnthropicConfig(BaseConfig):
                 new_stop = new_v
         return new_stop

+    def _add_tools_to_optional_params(
+        self, optional_params: dict, tools: List[AllAnthropicToolsValues]
+    ) -> dict:
+        if "tools" not in optional_params:
+            optional_params["tools"] = tools
+        else:
+            optional_params["tools"] = [
+                *optional_params["tools"],
+                *tools,
+            ]
+        return optional_params
+
     def map_openai_params(
         self,
         non_default_params: dict,
@@ -311,7 +324,11 @@ class AnthropicConfig(BaseConfig):
             if param == "max_completion_tokens":
                 optional_params["max_tokens"] = value
             if param == "tools":
-                optional_params["tools"] = self._map_tools(value)
+                # check if optional params already has tools
+                tool_value = self._map_tools(value)
+                optional_params = self._add_tools_to_optional_params(
+                    optional_params=optional_params, tools=tool_value
+                )
             if param == "tool_choice" or param == "parallel_tool_calls":
                 _tool_choice: Optional[AnthropicMessagesToolChoice] = (
                     self._map_tool_choice(
@@ -333,6 +350,7 @@ class AnthropicConfig(BaseConfig):
             if param == "top_p":
                 optional_params["top_p"] = value
             if param == "response_format" and isinstance(value, dict):
+
                 json_schema: Optional[dict] = None
                 if "response_schema" in value:
                     json_schema = value["response_schema"]
@@ -344,11 +362,14 @@ class AnthropicConfig(BaseConfig):
                 - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
                 - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
                 """
-                _tool_choice = {"name": "json_tool_call", "type": "tool"}
+                _tool_choice = {"name": RESPONSE_FORMAT_TOOL_NAME, "type": "tool"}
                 _tool = self._create_json_tool_call_for_response_format(
                     json_schema=json_schema,
                 )
-                optional_params["tools"] = [_tool]
+                optional_params = self._add_tools_to_optional_params(
+                    optional_params=optional_params, tools=[_tool]
+                )
                 optional_params["tool_choice"] = _tool_choice
                 optional_params["json_mode"] = True
             if param == "user":
@@ -381,7 +402,9 @@ class AnthropicConfig(BaseConfig):
         else:
             _input_schema["properties"] = {"values": json_schema}

-        _tool = AnthropicMessagesTool(name="json_tool_call", input_schema=_input_schema)
+        _tool = AnthropicMessagesTool(
+            name=RESPONSE_FORMAT_TOOL_NAME, input_schema=_input_schema
+        )
         return _tool

     def is_cache_control_set(self, messages: List[AllMessageValues]) -> bool:
@@ -537,10 +560,6 @@ class AnthropicConfig(BaseConfig):
         ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v

-        ## Handle Tool Calling
-        if "tools" in optional_params:
-            _is_function_call = True
-
         ## Handle user_id in metadata
         _litellm_metadata = litellm_params.get("metadata", None)
         if (
@@ -558,6 +577,26 @@ class AnthropicConfig(BaseConfig):

         return data

+    def _transform_response_for_json_mode(
+        self,
+        json_mode: Optional[bool],
+        tool_calls: List[ChatCompletionToolCallChunk],
+    ) -> Optional[LitellmMessage]:
+        _message: Optional[LitellmMessage] = None
+        if json_mode is True and len(tool_calls) == 1:
+            # check if tool name is the default tool name
+            json_mode_content_str: Optional[str] = None
+            if (
+                "name" in tool_calls[0]["function"]
+                and tool_calls[0]["function"]["name"] == RESPONSE_FORMAT_TOOL_NAME
+            ):
+                json_mode_content_str = tool_calls[0]["function"].get("arguments")
+            if json_mode_content_str is not None:
+                _message = AnthropicConfig._convert_tool_response_to_message(
+                    tool_calls=tool_calls,
+                )
+        return _message
+
     def transform_response(
         self,
         model: str,
@@ -629,19 +668,14 @@ class AnthropicConfig(BaseConfig):
         )

         ## HANDLE JSON MODE - anthropic returns single function call
-        if json_mode is True and len(tool_calls) == 1:
-            json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
-                "arguments"
-            )
-            if json_mode_content_str is not None:
-                _converted_message = (
-                    AnthropicConfig._convert_tool_response_to_message(
-                        tool_calls=tool_calls,
-                    )
-                )
-                if _converted_message is not None:
-                    completion_response["stop_reason"] = "stop"
-                    _message = _converted_message
+        json_mode_message = self._transform_response_for_json_mode(
+            json_mode=json_mode,
+            tool_calls=tool_calls,
+        )
+        if json_mode_message is not None:
+            completion_response["stop_reason"] = "stop"
+            _message = json_mode_message

         model_response.choices[0].message = _message  # type: ignore
         model_response._hidden_params["original_response"] = completion_response[
             "content"
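The Anthropic changes above merge the response_format tool with any user-supplied tools instead of overwriting them. A rough usage sketch, modelled on the new test further down in this diff; the model name and setup are assumptions:

# Sketch only: response_format is converted into a forced json_tool_call tool and
# combined with the caller's own tools in a single Anthropic request.
import litellm
from pydantic import BaseModel


class RFormat(BaseModel):
    question: str
    answer: str


response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",  # assumed model name
    messages=[{"role": "user", "content": "Hey! What's the weather in New York?"}],
    response_format=RFormat,
    tool_choice="required",
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
)
print(response.choices[0].message)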
@@ -342,7 +342,8 @@ class AzureChatCompletion(BaseLLM):
         headers: Optional[dict] = None,
         client=None,
     ):
-        super().completion()
+        if headers:
+            optional_params["extra_headers"] = headers
         try:
             if model is None or messages is None:
                 raise AzureOpenAIError(
@@ -851,8 +852,10 @@ class AzureChatCompletion(BaseLLM):
         max_retries: Optional[int] = None,
         client=None,
         aembedding=None,
+        headers: Optional[dict] = None,
     ) -> litellm.EmbeddingResponse:
-        super().embedding()
+        if headers:
+            optional_params["extra_headers"] = headers
         if self._client_session is None:
             self._client_session = self.create_client_session()
         try:
@@ -259,9 +259,9 @@ async def async_completion(
         )
         return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate")  # type: ignore

-    for _ in range(litellm.DEFAULT_MAX_RETRIES):
+    for _ in range(litellm.DEFAULT_REPLICATE_POLLING_RETRIES):
         await asyncio.sleep(
-            1
+            litellm.DEFAULT_REPLICATE_POLLING_DELAY_SECONDS
         )  # wait 1s to allow response to be generated by replicate - else partial output is generated with status=="processing"
         response = await async_handler.get(url=prediction_url, headers=headers)
     return litellm.ReplicateConfig().transform_response(
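The Replicate change above swaps the hard-coded retry count and 1s sleep for the new constants, giving roughly a 5 x 1 = 5 second polling budget by default. A standalone sketch of the same pattern; the prediction URL, headers, and the "status" check are assumptions about the surrounding Replicate handler, not code from this commit:

# Sketch only: poll a Replicate prediction until it stops reporting "processing".
import asyncio

import httpx
import litellm


async def poll_prediction(prediction_url: str, headers: dict) -> dict:
    async with httpx.AsyncClient() as client:
        response = await client.get(prediction_url, headers=headers)
        for _ in range(litellm.DEFAULT_REPLICATE_POLLING_RETRIES):
            await asyncio.sleep(litellm.DEFAULT_REPLICATE_POLLING_DELAY_SECONDS)
            response = await client.get(prediction_url, headers=headers)
            if response.json().get("status") != "processing":
                break
    return response.json()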
@@ -3171,6 +3171,7 @@ def embedding(  # noqa: PLR0915
     proxy_server_request = kwargs.get("proxy_server_request", None)
     aembedding = kwargs.get("aembedding", None)
     extra_headers = kwargs.get("extra_headers", None)
+    headers = kwargs.get("headers", None)
     ### CUSTOM MODEL COST ###
     input_cost_per_token = kwargs.get("input_cost_per_token", None)
     output_cost_per_token = kwargs.get("output_cost_per_token", None)
@@ -3281,9 +3282,6 @@ def embedding(  # noqa: PLR0915
             "azure_ad_token", None
         ) or get_secret_str("AZURE_AD_TOKEN")

-        if extra_headers is not None:
-            optional_params["extra_headers"] = extra_headers
-
         api_key = (
             api_key
             or litellm.api_key
@@ -3311,6 +3309,7 @@ def embedding(  # noqa: PLR0915
             client=client,
             aembedding=aembedding,
             max_retries=max_retries,
+            headers=headers or extra_headers,
         )
     elif (
         model in litellm.open_ai_embedding_models
@@ -5,26 +5,8 @@ model_list:
       api_key: os.environ/AZURE_API_KEY
       api_base: os.environ/AZURE_API_BASE
       temperature: 0.2
-
-guardrails:
-  - guardrail_name: "presidio-log-guard"
+  - model_name: "*"
     litellm_params:
-      guardrail: presidio
-      mode: "logging_only"
-      mock_redacted_text:
-        text: "hello world, my name is <PERSON>. My number is: <PHONE_NUMBER>"
-        items:
-          - start: 48
-            end: 62
-            entity_type: PHONE_NUMBER
-            text: "<PHONE_NUMBER>"
-            operator: replace
-          - start: 24
-            end: 32
-            entity_type: PERSON
-            text: "<PERSON>"
-            operator: replace
-
-litellm_settings:
-  set_verbose: true
-  success_callback: ["langfuse"]
+      model: "*"
+    model_info:
+      access_groups: ["default"]
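The wildcard entry added to the config above pairs a catch-all deployment with an access group. A small sketch of the intended effect in Router form; the model names are arbitrary and the comment describes the expectation rather than a captured output:

# Sketch only: a generic "*" deployment carrying a "default" access group.
# After this commit the "*" deployment is registered with the pattern router,
# so its access group can be resolved for arbitrary model names.
from litellm import Router

llm_router = Router(
    model_list=[
        {
            "model_name": "*",
            "litellm_params": {"model": "*"},
            "model_info": {"access_groups": ["default"]},
        }
    ]
)

access_groups = llm_router.get_model_access_groups(model_name="gpt-4o")
print(access_groups)  # expected to include the "default" group for the wildcard deployment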
@@ -757,6 +757,7 @@ async def get_key_object(
     except DB_CONNECTION_ERROR_TYPES as e:
         return await _handle_failed_db_connection_for_get_key_object(e=e)
     except Exception:
+        traceback.print_exc()
         raise Exception(
             f"Key doesn't exist in db. key={hashed_token}. Create key via `/key/generate` call."
         )
@@ -870,7 +871,6 @@ async def can_key_call_model(
     access_groups = defaultdict(list)
     if llm_router:
         access_groups = llm_router.get_model_access_groups(model_name=model)
-
     if (
         len(access_groups) > 0 and llm_router is not None
     ):  # check if token contains any model access groups
@@ -25,8 +25,6 @@ from typing import (
     get_type_hints,
 )

-import requests
-
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span

@@ -120,7 +118,7 @@ from litellm.litellm_core_utils.core_helpers import (
     _get_parent_otel_span_from_kwargs,
     get_litellm_metadata_from_kwargs,
 )
-from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.proxy._types import *
 from litellm.proxy.analytics_endpoints.analytics_endpoints import (
     router as analytics_router,
@@ -528,7 +526,7 @@ async_result = None
 celery_app_conn = None
 celery_fn = None  # Redis Queue for handling requests
 ### DB WRITER ###
-db_writer_client: Optional[HTTPHandler] = None
+db_writer_client: Optional[AsyncHTTPHandler] = None
 ### logger ###

@@ -2092,7 +2090,10 @@ class ProxyConfig:
         """
         global user_config_file_path, llm_router
         combined_id_list = []
-        if llm_router is None:
+
+        ## BASE CASES ##
+        # if llm_router is None or db_models is empty, return 0
+        if llm_router is None or len(db_models) == 0:
             return 0

         ## DB MODELS ##
@@ -2422,6 +2423,19 @@ class ProxyConfig:

         return config

+    async def _get_models_from_db(self, prisma_client: PrismaClient) -> list:
+        try:
+            new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
+        except Exception as e:
+            verbose_proxy_logger.exception(
+                "litellm.proxy_server.py::add_deployment() - Error getting new models from DB - {}".format(
+                    str(e)
+                )
+            )
+            new_models = []
+
+        return new_models
+
     async def add_deployment(
         self,
         prisma_client: PrismaClient,
@@ -2439,15 +2453,9 @@ class ProxyConfig:
             raise ValueError(
                 f"Master key is not initialized or formatted. master_key={master_key}"
             )
-        try:
-            new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
-        except Exception as e:
-            verbose_proxy_logger.exception(
-                "litellm.proxy_server.py::add_deployment() - Error getting new models from DB - {}".format(
-                    str(e)
-                )
-            )
-            new_models = []
+
+        new_models = await self._get_models_from_db(prisma_client=prisma_client)
+
         # update llm router
         await self._update_llm_router(
             new_models=new_models, proxy_logging_obj=proxy_logging_obj
@@ -8066,7 +8074,8 @@ def get_image():
     # Check if the logo path is an HTTP/HTTPS URL
     if logo_path.startswith(("http://", "https://")):
         # Download the image and cache it
-        response = requests.get(logo_path)
+        client = HTTPHandler()
+        response = client.get(logo_path)
         if response.status_code == 200:
             # Save the image to a local file
             cache_path = os.path.join(current_dir, "cached_logo.jpg")
@@ -4019,15 +4019,15 @@ class Router:

         # Check if user is trying to use model_name == "*"
         # this is a catch all model for their specific api key
-        if deployment.model_name == "*":
-            if deployment.litellm_params.model == "*":
-                # user wants to pass through all requests to litellm.acompletion for unknown deployments
-                self.router_general_settings.pass_through_all_models = True
-            else:
-                self.default_deployment = deployment.to_json(exclude_none=True)
+        # if deployment.model_name == "*":
+        #     if deployment.litellm_params.model == "*":
+        #         # user wants to pass through all requests to litellm.acompletion for unknown deployments
+        #         self.router_general_settings.pass_through_all_models = True
+        #     else:
+        #         self.default_deployment = deployment.to_json(exclude_none=True)
         # Check if user is using provider specific wildcard routing
         # example model_name = "databricks/*" or model_name = "anthropic/*"
-        elif "*" in deployment.model_name:
+        if "*" in deployment.model_name:
             # store this as a regex pattern - all deployments matching this pattern will be sent to this deployment
             # Store deployment.model_name as a regex pattern
             self.pattern_router.add_pattern(
@@ -4,13 +4,52 @@ Class to handle llm wildcard routing and regex pattern matching

 import copy
 import re
+from functools import cached_property
 from re import Match
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple

 from litellm import get_llm_provider
 from litellm._logging import verbose_router_logger


+class PatternUtils:
+    @staticmethod
+    def calculate_pattern_specificity(pattern: str) -> Tuple[int, int]:
+        """
+        Calculate pattern specificity based on length and complexity.
+
+        Args:
+            pattern: Regex pattern to analyze
+
+        Returns:
+            Tuple of (length, complexity) for sorting
+        """
+        complexity_chars = ["*", "+", "?", "\\", "^", "$", "|", "(", ")"]
+        ret_val = (
+            len(pattern),  # Longer patterns more specific
+            sum(
+                pattern.count(char) for char in complexity_chars
+            ),  # More regex complexity
+        )
+        return ret_val
+
+    @staticmethod
+    def sorted_patterns(
+        patterns: Dict[str, List[Dict]]
+    ) -> List[Tuple[str, List[Dict]]]:
+        """
+        Cached property for patterns sorted by specificity.
+
+        Returns:
+            Sorted list of pattern-deployment tuples
+        """
+        return sorted(
+            patterns.items(),
+            key=lambda x: PatternUtils.calculate_pattern_specificity(x[0]),
+            reverse=True,
+        )
+
+
 class PatternMatchRouter:
     """
     Class to handle llm wildcard routing and regex pattern matching
@@ -99,13 +138,13 @@ class PatternMatchRouter:
         if request is None:
             return None

+        sorted_patterns = PatternUtils.sorted_patterns(self.patterns)
         regex_filtered_model_names = (
             [self._pattern_to_regex(m) for m in filtered_model_names]
             if filtered_model_names is not None
             else []
         )
-        for pattern, llm_deployments in self.patterns.items():
+        for pattern, llm_deployments in sorted_patterns:
             if (
                 filtered_model_names is not None
                 and pattern not in regex_filtered_model_names
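To make the new ordering concrete, a small worked example of the (length, complexity) sort key used by PatternUtils above; the patterns are arbitrary:

# Sketch only: reproduces the specificity sort so more specific wildcard patterns
# are tried before the generic "*" catch-all.
from typing import Tuple

COMPLEXITY_CHARS = ["*", "+", "?", "\\", "^", "$", "|", "(", ")"]


def specificity(pattern: str) -> Tuple[int, int]:
    return (len(pattern), sum(pattern.count(c) for c in COMPLEXITY_CHARS))


for p in sorted(["*", "openai/*", "openai/gpt-*"], key=specificity, reverse=True):
    print(p, specificity(p))
# "openai/gpt-*" (12, 1) sorts ahead of "openai/*" (8, 1), and "*" (1, 1) comes last.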
@@ -1214,7 +1214,9 @@ def client(original_function):  # noqa: PLR0915


 @lru_cache(maxsize=128)
-def _select_tokenizer(model: str):
+def _select_tokenizer(
+    model: str,
+):
     if model in litellm.cohere_models and "command-r" in model:
         # cohere
         cohere_tokenizer = Tokenizer.from_pretrained(
@@ -1235,19 +1237,10 @@ def _select_tokenizer(
         return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
     # default - tiktoken
     else:
-        tokenizer = None
-        if (
-            model in litellm.open_ai_chat_completion_models
-            or model in litellm.open_ai_text_completion_models
-            or model in litellm.open_ai_embedding_models
-        ):
-            return {"type": "openai_tokenizer", "tokenizer": encoding}
-
-        try:
-            tokenizer = Tokenizer.from_pretrained(model)
-            return {"type": "huggingface_tokenizer", "tokenizer": tokenizer}
-        except Exception:
-            return {"type": "openai_tokenizer", "tokenizer": encoding}
+        return {
+            "type": "openai_tokenizer",
+            "tokenizer": encoding,
+        }  # default to openai tokenizer


 def encode(model="", text="", custom_tokenizer: Optional[dict] = None):
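With the tokenizer change above, unknown non-Cohere models fall back to the default OpenAI (tiktoken) encoding instead of fetching a Hugging Face tokenizer over the network. A quick sketch; the model name is only an example, borrowed from the updated vLLM token-counting test at the end of this diff:

# Sketch only: token counting for an unrecognized model now uses the default
# OpenAI tokenizer rather than downloading a Hugging Face tokenizer.
import litellm

tokens = litellm.encode(model="wolfram/miquliz-120b-v2.0", text="hello world")
print(len(tokens))

print(litellm.token_counter(model="wolfram/miquliz-120b-v2.0", text="hello world"))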
@@ -685,6 +685,67 @@ class TestAnthropicCompletion(BaseLLMChatTest):
         """
         pass

+    def test_tool_call_and_json_response_format(self):
+        """
+        Test that the tool call and JSON response format is supported by the LLM API
+        """
+        litellm.set_verbose = True
+        from pydantic import BaseModel
+        from litellm.utils import supports_response_schema
+
+        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+        litellm.model_cost = litellm.get_model_cost_map(url="")
+
+        class RFormat(BaseModel):
+            question: str
+            answer: str
+
+        base_completion_call_args = self.get_base_completion_call_args()
+        if not supports_response_schema(base_completion_call_args["model"], None):
+            pytest.skip("Model does not support response schema")
+
+        try:
+            res = litellm.completion(
+                **base_completion_call_args,
+                messages=[
+                    {
+                        "role": "system",
+                        "content": "response user question with JSON object",
+                    },
+                    {"role": "user", "content": "Hey! What's the weather in NewYork?"},
+                ],
+                tool_choice="required",
+                response_format=RFormat,
+                tools=[
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "get_current_weather",
+                            "description": "Get the current weather in a given location",
+                            "parameters": {
+                                "type": "object",
+                                "properties": {
+                                    "location": {
+                                        "type": "string",
+                                        "description": "The city and state, e.g. San Francisco, CA",
+                                    },
+                                    "unit": {
+                                        "type": "string",
+                                        "enum": ["celsius", "fahrenheit"],
+                                    },
+                                },
+                                "required": ["location"],
+                            },
+                        },
+                    }
+                ],
+            )
+            assert res is not None
+
+            assert res.choices[0].message.tool_calls is not None
+        except litellm.InternalServerError:
+            pytest.skip("Model is overloaded")
+
+
 def test_convert_tool_response_to_message_with_values():
     """Test converting a tool response with 'values' key to a message"""
@@ -829,3 +890,128 @@ def test_anthropic_tool_with_image():
     )

     assert b64_data in json.dumps(result)
+
+
+def test_anthropic_map_openai_params_tools_and_json_schema():
+    import json
+
+    args = {
+        "non_default_params": {
+            "response_format": {
+                "type": "json_schema",
+                "json_schema": {
+                    "schema": {
+                        "properties": {
+                            "question": {"title": "Question", "type": "string"},
+                            "answer": {"title": "Answer", "type": "string"},
+                        },
+                        "required": ["question", "answer"],
+                        "title": "RFormat",
+                        "type": "object",
+                        "additionalProperties": False,
+                    },
+                    "name": "RFormat",
+                    "strict": True,
+                },
+            },
+            "tools": [
+                {
+                    "type": "function",
+                    "function": {
+                        "name": "get_current_weather",
+                        "description": "Get the current weather in a given location",
+                        "parameters": {
+                            "type": "object",
+                            "properties": {
+                                "location": {
+                                    "type": "string",
+                                    "description": "The city and state, e.g. San Francisco, CA",
+                                },
+                                "unit": {
+                                    "type": "string",
+                                    "enum": ["celsius", "fahrenheit"],
+                                },
+                            },
+                            "required": ["location"],
+                        },
+                    },
+                }
+            ],
+            "tool_choice": "required",
+        }
+    }
+
+    mapped_params = litellm.AnthropicConfig().map_openai_params(
+        non_default_params=args["non_default_params"],
+        optional_params={},
+        model="claude-3-5-sonnet-20240620",
+        drop_params=False,
+    )
+
+    assert "Question" in json.dumps(mapped_params)
+
+
+from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
+
+
+@pytest.mark.parametrize(
+    "json_mode, tool_calls, expect_null_response",
+    [
+        (
+            True,
+            [
+                {
+                    "id": "toolu_013JszbnYBVygTxh6EGHEHia",
+                    "type": "function",
+                    "function": {
+                        "name": "get_current_weather",
+                        "arguments": '{"location": "New York, NY"}',
+                    },
+                    "index": 0,
+                }
+            ],
+            True,
+        ),
+        (
+            True,
+            [
+                {
+                    "id": "toolu_013JszbnYBVygTxh6EGHEHia",
+                    "type": "function",
+                    "function": {
+                        "name": RESPONSE_FORMAT_TOOL_NAME,
+                        "arguments": '{"location": "New York, NY"}',
+                    },
+                    "index": 0,
+                }
+            ],
+            False,
+        ),
+        (
+            False,
+            [
+                {
+                    "id": "toolu_013JszbnYBVygTxh6EGHEHia",
+                    "type": "function",
+                    "function": {
+                        "name": RESPONSE_FORMAT_TOOL_NAME,
+                        "arguments": '{"location": "New York, NY"}',
+                    },
+                    "index": 0,
+                }
+            ],
+            True,
+        ),
+    ],
+)
+def test_anthropic_json_mode_and_tool_call_response(
+    json_mode, tool_calls, expect_null_response
+):
+    result = litellm.AnthropicConfig()._transform_response_for_json_mode(
+        json_mode=json_mode,
+        tool_calls=tool_calls,
+    )
+
+    assert (
+        result is None if expect_null_response else result is not None
+    ), f"Expected result to be {None if expect_null_response else 'not None'}, but got {result}"
@@ -113,7 +113,14 @@ import os
         ({"prompt": "Hello world"}, "image_generation"),
     ],
 )
-def test_azure_extra_headers(input, call_type):
+@pytest.mark.parametrize(
+    "header_value",
+    [
+        "headers",
+        "extra_headers",
+    ],
+)
+def test_azure_extra_headers(input, call_type, header_value):
     from litellm import embedding, image_generation

     http_client = Client()
@@ -128,18 +135,21 @@ def test_azure_extra_headers(input, call_type, header_value):
             func = embedding
         elif call_type == "image_generation":
             func = image_generation
-        response = func(
-            model="azure/chatgpt-v-2",
-            api_base="https://openai-gpt-4-test-v-1.openai.azure.com",
-            api_version="2023-07-01-preview",
-            api_key="my-azure-api-key",
-            extra_headers={
+        data = {
+            "model": "azure/chatgpt-v-2",
+            "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com",
+            "api_version": "2023-07-01-preview",
+            "api_key": "my-azure-api-key",
+            header_value: {
                 "Authorization": "my-bad-key",
                 "Ocp-Apim-Subscription-Key": "hello-world-testing",
             },
             **input,
-        )
+        }
+        response = func(**data)
         print(response)

     except Exception as e:
         print(e)
@@ -116,6 +116,7 @@ async def test_audio_speech_litellm_vertex(sync_mode):
         response.stream_to_file(speech_file_path)


+@pytest.mark.flaky(retries=6, delay=2)
 @pytest.mark.asyncio
 async def test_speech_litellm_vertex_async():
     # Mock the response
@@ -3094,6 +3094,7 @@ def test_completion_azure_deployment_id():
 import asyncio


+@pytest.mark.skip(reason="replicate endpoints are extremely flaky")
 @pytest.mark.parametrize("sync_mode", [False, True])
 @pytest.mark.asyncio
 async def test_completion_replicate_llama3(sync_mode):
@@ -175,6 +175,62 @@ async def test_add_existing_deployment():
     assert init_len_list == len(llm_router.model_list)


+@pytest.mark.asyncio
+async def test_db_error_new_model_check():
+    """
+    - if error in db, don't delete existing models
+
+    Relevant issue: https://github.com/BerriAI/litellm/blob/ddfe687b13e9f31db2fb2322887804e3d01dd467/litellm/proxy/proxy_server.py#L2461
+    """
+    import base64
+
+    litellm_params = LiteLLM_Params(
+        model="gpt-3.5-turbo",
+        api_key=os.getenv("AZURE_API_KEY"),
+        api_base=os.getenv("AZURE_API_BASE"),
+        api_version=os.getenv("AZURE_API_VERSION"),
+    )
+    deployment = Deployment(model_name="gpt-3.5-turbo", litellm_params=litellm_params)
+    deployment_2 = Deployment(
+        model_name="gpt-3.5-turbo-2", litellm_params=litellm_params
+    )
+
+    llm_router = litellm.Router(
+        model_list=[
+            deployment.to_json(exclude_none=True),
+            deployment_2.to_json(exclude_none=True),
+        ]
+    )
+
+    init_len_list = len(llm_router.model_list)
+    print(f"llm_router: {llm_router}")
+    master_key = "sk-1234"
+    setattr(litellm.proxy.proxy_server, "llm_router", llm_router)
+    setattr(litellm.proxy.proxy_server, "master_key", master_key)
+    pc = ProxyConfig()
+
+    encrypted_litellm_params = litellm_params.dict(exclude_none=True)
+
+    for k, v in encrypted_litellm_params.items():
+        if isinstance(v, str):
+            encrypted_value = encrypt_value(v, master_key)
+            encrypted_litellm_params[k] = base64.b64encode(encrypted_value).decode(
+                "utf-8"
+            )
+    db_model = DBModel(
+        model_id=deployment.model_info.id,
+        model_name="gpt-3.5-turbo",
+        litellm_params=encrypted_litellm_params,
+        model_info={"id": deployment.model_info.id},
+    )
+
+    db_models = []
+    deleted_deployments = await pc._delete_deployment(db_models=db_models)
+    assert deleted_deployments == 0
+
+    assert init_len_list == len(llm_router.model_list)
+
+
 litellm_params = LiteLLM_Params(
     model="azure/chatgpt-v-2",
     api_key=os.getenv("AZURE_API_KEY"),
@@ -133,7 +133,7 @@ def test_route_with_multiple_matching_patterns():
     router.add_pattern("openai/*", deployment1.to_json(exclude_none=True))
     router.add_pattern("openai/gpt-*", deployment2.to_json(exclude_none=True))
     assert router.route("openai/gpt-3.5-turbo") == [
-        deployment1.to_json(exclude_none=True)
+        deployment2.to_json(exclude_none=True)
     ]

@@ -237,3 +237,79 @@ def test_router_pattern_match_e2e():
         "model": "gpt-4o",
         "messages": [{"role": "user", "content": "Hello, how are you?"}],
     }
+
+
+def test_pattern_matching_router_with_default_wildcard():
+    """
+    Tests that the router returns the default wildcard model when the pattern is not found
+
+    Make sure generic '*' allows all models to be passed through.
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "*",
+                "litellm_params": {"model": "*"},
+                "model_info": {"access_groups": ["default"]},
+            },
+            {
+                "model_name": "anthropic-claude",
+                "litellm_params": {"model": "anthropic/claude-3-5-sonnet"},
+            },
+        ]
+    )
+
+    assert len(router.pattern_router.patterns) > 0
+
+    router.completion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+    )
+
+
+def test_pattern_matching_router_with_default_wildcard_and_model_wildcard():
+    """
+    Match to more specific pattern first.
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "*",
+                "litellm_params": {"model": "*"},
+                "model_info": {"access_groups": ["default"]},
+            },
+            {
+                "model_name": "llmengine/*",
+                "litellm_params": {"model": "openai/*"},
+            },
+        ]
+    )
+
+    assert len(router.pattern_router.patterns) > 0
+
+    pattern_router = router.pattern_router
+    deployments = pattern_router.route("llmengine/gpt-3.5-turbo")
+    assert len(deployments) == 1
+    assert deployments[0]["model_name"] == "llmengine/*"
+
+
+def test_sorted_patterns():
+    """
+    Tests that the pattern specificity is calculated correctly
+    """
+    from litellm.router_utils.pattern_match_deployments import PatternUtils
+
+    sorted_patterns = PatternUtils.sorted_patterns(
+        {
+            "llmengine/*": [{"model_name": "anthropic/claude-3-5-sonnet"}],
+            "*": [{"model_name": "openai/*"}],
+        },
+    )
+    assert sorted_patterns[0][0] == "llmengine/*"
+
+
+def test_calculate_pattern_specificity():
+    from litellm.router_utils.pattern_match_deployments import PatternUtils
+
+    assert PatternUtils.calculate_pattern_specificity("llmengine/*") == (11, 1)
+    assert PatternUtils.calculate_pattern_specificity("*") == (1, 1)
@@ -63,8 +63,8 @@ async def test_vLLM_token_counting():
     print("response: ", response)

     assert (
-        response.tokenizer_type == "huggingface_tokenizer"
-    )  # SHOULD use the hugging face tokenizer
+        response.tokenizer_type == "openai_tokenizer"
+    )  # SHOULD use the default tokenizer
     assert response.model_used == "wolfram/miquliz-120b-v2.0"