Add /vllm/* and /mistral/* passthrough endpoints (adds support for Mistral OCR via passthrough)

* feat(llm_passthrough_endpoints.py): support mistral passthrough

Closes https://github.com/BerriAI/litellm/issues/9051

* feat(llm_passthrough_endpoints.py): initial commit for adding vllm passthrough route

* feat(vllm/common_utils.py): add new vllm model info route

make it possible to use vllm passthrough route via factory function

* fix(llm_passthrough_endpoints.py): add all methods to vllm passthrough route

* fix: fix linting error

* fix: fix linting error

* fix: fix ruff check

* fix(proxy/_types.py): add new passthrough routes

* docs(config_settings.md): add mistral env vars to docs
Krish Dholakia 2025-04-14 22:06:33 -07:00 committed by GitHub
parent 5fcdf4becf
commit 3031fff297
12 changed files with 450 additions and 176 deletions
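
Usage sketch (not part of the diff): once the proxy has MISTRAL_API_BASE / MISTRAL_API_KEY configured, Mistral OCR can be reached through the new /mistral/* passthrough. The proxy address, virtual key, and the OCR request body below are illustrative assumptions, not values taken from this commit.

import httpx

# Assumed local proxy URL and LiteLLM virtual key -- replace with your own.
LITELLM_PROXY_BASE = "http://localhost:4000"
LITELLM_API_KEY = "sk-1234"

# /mistral/{endpoint} is forwarded to MISTRAL_API_BASE/{endpoint},
# with the proxy swapping in MISTRAL_API_KEY as the bearer token downstream.
response = httpx.post(
    f"{LITELLM_PROXY_BASE}/mistral/v1/ocr",
    headers={"Authorization": f"Bearer {LITELLM_API_KEY}"},
    json={
        # Request shape assumed from Mistral's OCR API docs.
        "model": "mistral-ocr-latest",
        "document": {
            "type": "document_url",
            "document_url": "https://arxiv.org/pdf/2201.04234",
        },
    },
    timeout=60.0,
)
print(response.json())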

View file

@@ -449,6 +449,8 @@ router_settings:
 | LITELLM_TOKEN | Access token for LiteLLM integration
 | LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD | If true, prints the standard logging payload to the console - useful for debugging
 | LOGFIRE_TOKEN | Token for Logfire logging service
+| MISTRAL_API_BASE | Base URL for Mistral API
+| MISTRAL_API_KEY | API key for Mistral API
 | MICROSOFT_CLIENT_ID | Client ID for Microsoft services
 | MICROSOFT_CLIENT_SECRET | Client secret for Microsoft services
 | MICROSOFT_TENANT | Tenant ID for Microsoft Azure
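
These are the variables the new /mistral/* passthrough reads; a minimal sketch of setting them before the proxy starts (values are placeholders):

import os

# Placeholders -- point these at your Mistral account before launching the proxy.
os.environ["MISTRAL_API_BASE"] = "https://api.mistral.ai"  # optional, this is the default
os.environ["MISTRAL_API_KEY"] = "sk-mistral-..."           # used as the downstream bearer token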

View file

@@ -44,7 +44,7 @@ from litellm.utils import (
     token_counter,
 )
-from ..common_utils import AnthropicError, process_anthropic_headers
+from ..common_utils import AnthropicError, AnthropicModelInfo, process_anthropic_headers

 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
@@ -54,7 +54,7 @@ else:
     LoggingClass = Any


-class AnthropicConfig(BaseConfig):
+class AnthropicConfig(AnthropicModelInfo, BaseConfig):
     """
     Reference: https://docs.anthropic.com/claude/reference/messages_post
@@ -127,41 +127,6 @@ class AnthropicConfig(BaseConfig):
             "anthropic-beta": "prompt-caching-2024-07-31",
         }

-    def get_anthropic_headers(
-        self,
-        api_key: str,
-        anthropic_version: Optional[str] = None,
-        computer_tool_used: bool = False,
-        prompt_caching_set: bool = False,
-        pdf_used: bool = False,
-        is_vertex_request: bool = False,
-        user_anthropic_beta_headers: Optional[List[str]] = None,
-    ) -> dict:
-        betas = set()
-        if prompt_caching_set:
-            betas.add("prompt-caching-2024-07-31")
-        if computer_tool_used:
-            betas.add("computer-use-2024-10-22")
-        if pdf_used:
-            betas.add("pdfs-2024-09-25")
-        headers = {
-            "anthropic-version": anthropic_version or "2023-06-01",
-            "x-api-key": api_key,
-            "accept": "application/json",
-            "content-type": "application/json",
-        }
-        if user_anthropic_beta_headers is not None:
-            betas.update(user_anthropic_beta_headers)
-        # Don't send any beta headers to Vertex, Vertex has failed requests when they are sent
-        if is_vertex_request is True:
-            pass
-        elif len(betas) > 0:
-            headers["anthropic-beta"] = ",".join(betas)
-        return headers
-
     def _map_tool_choice(
         self, tool_choice: Optional[str], parallel_tool_use: Optional[bool]
     ) -> Optional[AnthropicMessagesToolChoice]:
@@ -446,49 +411,6 @@
         )
         return _tool

-    def is_cache_control_set(self, messages: List[AllMessageValues]) -> bool:
-        """
-        Return if {"cache_control": ..} in message content block
-        Used to check if anthropic prompt caching headers need to be set.
-        """
-        for message in messages:
-            if message.get("cache_control", None) is not None:
-                return True
-            _message_content = message.get("content")
-            if _message_content is not None and isinstance(_message_content, list):
-                for content in _message_content:
-                    if "cache_control" in content:
-                        return True
-        return False
-
-    def is_computer_tool_used(
-        self, tools: Optional[List[AllAnthropicToolsValues]]
-    ) -> bool:
-        if tools is None:
-            return False
-        for tool in tools:
-            if "type" in tool and tool["type"].startswith("computer_"):
-                return True
-        return False
-
-    def is_pdf_used(self, messages: List[AllMessageValues]) -> bool:
-        """
-        Set to true if media passed into messages.
-        """
-        for message in messages:
-            if (
-                "content" in message
-                and message["content"] is not None
-                and isinstance(message["content"], list)
-            ):
-                for content in message["content"]:
-                    if "type" in content and content["type"] != "text":
-                        return True
-        return False
-
     def translate_system_message(
         self, messages: List[AllMessageValues]
     ) -> List[AnthropicSystemMessageContent]:
@@ -862,47 +784,3 @@
             message=error_message,
             headers=cast(httpx.Headers, headers),
         )
-
-    def _get_user_anthropic_beta_headers(
-        self, anthropic_beta_header: Optional[str]
-    ) -> Optional[List[str]]:
-        if anthropic_beta_header is None:
-            return None
-        return anthropic_beta_header.split(",")
-
-    def validate_environment(
-        self,
-        headers: dict,
-        model: str,
-        messages: List[AllMessageValues],
-        optional_params: dict,
-        litellm_params: dict,
-        api_key: Optional[str] = None,
-        api_base: Optional[str] = None,
-    ) -> Dict:
-        if api_key is None:
-            raise litellm.AuthenticationError(
-                message="Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params. Please set `ANTHROPIC_API_KEY` in your environment vars",
-                llm_provider="anthropic",
-                model=model,
-            )
-        tools = optional_params.get("tools")
-        prompt_caching_set = self.is_cache_control_set(messages=messages)
-        computer_tool_used = self.is_computer_tool_used(tools=tools)
-        pdf_used = self.is_pdf_used(messages=messages)
-        user_anthropic_beta_headers = self._get_user_anthropic_beta_headers(
-            anthropic_beta_header=headers.get("anthropic-beta")
-        )
-        anthropic_headers = self.get_anthropic_headers(
-            computer_tool_used=computer_tool_used,
-            prompt_caching_set=prompt_caching_set,
-            pdf_used=pdf_used,
-            api_key=api_key,
-            is_vertex_request=optional_params.get("is_vertex_request", False),
-            user_anthropic_beta_headers=user_anthropic_beta_headers,
-        )
-        headers = {**headers, **anthropic_headers}
-        return headers

View file

@@ -2,7 +2,7 @@
 This file contains common utils for anthropic calls.
 """

-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union

 import httpx
@@ -10,6 +10,8 @@ import litellm
 from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.anthropic import AllAnthropicToolsValues
+from litellm.types.llms.openai import AllMessageValues


 class AnthropicError(BaseLLMException):
@@ -23,6 +25,128 @@ class AnthropicError(BaseLLMException):


 class AnthropicModelInfo(BaseLLMModelInfo):
+    def is_cache_control_set(self, messages: List[AllMessageValues]) -> bool:
+        """
+        Return if {"cache_control": ..} in message content block
+        Used to check if anthropic prompt caching headers need to be set.
+        """
+        for message in messages:
+            if message.get("cache_control", None) is not None:
+                return True
+            _message_content = message.get("content")
+            if _message_content is not None and isinstance(_message_content, list):
+                for content in _message_content:
+                    if "cache_control" in content:
+                        return True
+        return False
+
+    def is_computer_tool_used(
+        self, tools: Optional[List[AllAnthropicToolsValues]]
+    ) -> bool:
+        if tools is None:
+            return False
+        for tool in tools:
+            if "type" in tool and tool["type"].startswith("computer_"):
+                return True
+        return False
+
+    def is_pdf_used(self, messages: List[AllMessageValues]) -> bool:
+        """
+        Set to true if media passed into messages.
+        """
+        for message in messages:
+            if (
+                "content" in message
+                and message["content"] is not None
+                and isinstance(message["content"], list)
+            ):
+                for content in message["content"]:
+                    if "type" in content and content["type"] != "text":
+                        return True
+        return False
+
+    def _get_user_anthropic_beta_headers(
+        self, anthropic_beta_header: Optional[str]
+    ) -> Optional[List[str]]:
+        if anthropic_beta_header is None:
+            return None
+        return anthropic_beta_header.split(",")
+
+    def get_anthropic_headers(
+        self,
+        api_key: str,
+        anthropic_version: Optional[str] = None,
+        computer_tool_used: bool = False,
+        prompt_caching_set: bool = False,
+        pdf_used: bool = False,
+        is_vertex_request: bool = False,
+        user_anthropic_beta_headers: Optional[List[str]] = None,
+    ) -> dict:
+        betas = set()
+        if prompt_caching_set:
+            betas.add("prompt-caching-2024-07-31")
+        if computer_tool_used:
+            betas.add("computer-use-2024-10-22")
+        if pdf_used:
+            betas.add("pdfs-2024-09-25")
+        headers = {
+            "anthropic-version": anthropic_version or "2023-06-01",
+            "x-api-key": api_key,
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+        if user_anthropic_beta_headers is not None:
+            betas.update(user_anthropic_beta_headers)
+        # Don't send any beta headers to Vertex, Vertex has failed requests when they are sent
+        if is_vertex_request is True:
+            pass
+        elif len(betas) > 0:
+            headers["anthropic-beta"] = ",".join(betas)
+        return headers
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> Dict:
+        if api_key is None:
+            raise litellm.AuthenticationError(
+                message="Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params. Please set `ANTHROPIC_API_KEY` in your environment vars",
+                llm_provider="anthropic",
+                model=model,
+            )
+        tools = optional_params.get("tools")
+        prompt_caching_set = self.is_cache_control_set(messages=messages)
+        computer_tool_used = self.is_computer_tool_used(tools=tools)
+        pdf_used = self.is_pdf_used(messages=messages)
+        user_anthropic_beta_headers = self._get_user_anthropic_beta_headers(
+            anthropic_beta_header=headers.get("anthropic-beta")
+        )
+        anthropic_headers = self.get_anthropic_headers(
+            computer_tool_used=computer_tool_used,
+            prompt_caching_set=prompt_caching_set,
+            pdf_used=pdf_used,
+            api_key=api_key,
+            is_vertex_request=optional_params.get("is_vertex_request", False),
+            user_anthropic_beta_headers=user_anthropic_beta_headers,
+        )
+        headers = {**headers, **anthropic_headers}
+        return headers
+
     @staticmethod
     def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
         return (
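
Note the logic above is unchanged; it only moved from AnthropicConfig onto the AnthropicModelInfo mixin, which AnthropicConfig now inherits (see the transformation.py diff). A rough sketch of the resulting call path, with placeholder values of my own:

import litellm

config = litellm.AnthropicConfig()
# validate_environment is now inherited from the AnthropicModelInfo mixin.
headers = config.validate_environment(
    headers={},
    model="claude-3-5-sonnet-20240620",
    messages=[{"role": "user", "content": "hello"}],
    optional_params={},
    litellm_params={},
    api_key="sk-ant-...",  # placeholder key
)
print(headers)  # x-api-key, anthropic-version, plus any beta headers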

View file

@@ -44,6 +44,19 @@ class BaseLLMModelInfo(ABC):
     def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
         pass

+    @abstractmethod
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        pass
+
     @staticmethod
     @abstractmethod
     def get_base_model(model: str) -> Optional[str]:

View file

@@ -389,7 +389,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
         )

     @staticmethod
-    def get_base_model(model: str) -> str:
+    def get_base_model(model: Optional[str] = None) -> Optional[str]:
         return model

     def get_model_response_iterator(

View file

@@ -1,6 +1,7 @@
 from typing import List, Optional

 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues

 from ..base_llm.base_utils import BaseLLMModelInfo
 from ..base_llm.chat.transformation import BaseLLMException
@@ -11,6 +12,26 @@ class TopazException(BaseLLMException):


 class TopazModelInfo(BaseLLMModelInfo):
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            raise ValueError(
+                "API key is required for Topaz image variations. Set via `TOPAZ_API_KEY` or `api_key=..`"
+            )
+        return {
+            # "Content-Type": "multipart/form-data",
+            "Accept": "image/jpeg",
+            "X-API-Key": api_key,
+        }
+
     def get_models(
         self, api_key: Optional[str] = None, api_base: Optional[str] = None
     ) -> List[str]:

View file

@@ -10,10 +10,7 @@ from litellm.llms.base_llm.chat.transformation import (
     BaseLLMException,
     LiteLLMLoggingObj,
 )
-from litellm.types.llms.openai import (
-    AllMessageValues,
-    OpenAIImageVariationOptionalParams,
-)
+from litellm.types.llms.openai import OpenAIImageVariationOptionalParams
 from litellm.types.utils import (
     FileTypes,
     HttpHandlerRequestFields,
@@ -22,35 +19,15 @@ from litellm.types.utils import (
 )

 from ...base_llm.image_variations.transformation import BaseImageVariationConfig
-from ..common_utils import TopazException
+from ..common_utils import TopazException, TopazModelInfo


-class TopazImageVariationConfig(BaseImageVariationConfig):
+class TopazImageVariationConfig(TopazModelInfo, BaseImageVariationConfig):
     def get_supported_openai_params(
         self, model: str
     ) -> List[OpenAIImageVariationOptionalParams]:
         return ["response_format", "size"]

-    def validate_environment(
-        self,
-        headers: dict,
-        model: str,
-        messages: List[AllMessageValues],
-        optional_params: dict,
-        litellm_params: dict,
-        api_key: Optional[str] = None,
-        api_base: Optional[str] = None,
-    ) -> dict:
-        if api_key is None:
-            raise ValueError(
-                "API key is required for Topaz image variations. Set via `TOPAZ_API_KEY` or `api_key=..`"
-            )
-        return {
-            # "Content-Type": "multipart/form-data",
-            "Accept": "image/jpeg",
-            "X-API-Key": api_key,
-        }
-
     def get_complete_url(
         self,
         api_base: Optional[str],
View file

@@ -0,0 +1,75 @@
+from typing import List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
+from litellm.utils import _add_path_to_api_base
+
+
+class VLLMError(BaseLLMException):
+    pass
+
+
+class VLLMModelInfo(BaseLLMModelInfo):
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        """Google AI Studio sends api key in query params"""
+        return headers
+
+    @staticmethod
+    def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
+        api_base = api_base or get_secret_str("VLLM_API_BASE")
+        if api_base is None:
+            raise ValueError(
+                "VLLM_API_BASE is not set. Please set the environment variable, to use VLLM's pass-through - `{LITELLM_API_BASE}/vllm/{endpoint}`."
+            )
+        return api_base
+
+    @staticmethod
+    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
+        return None
+
+    @staticmethod
+    def get_base_model(model: str) -> Optional[str]:
+        return model
+
+    def get_models(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None
+    ) -> List[str]:
+        api_base = VLLMModelInfo.get_api_base(api_base)
+        api_key = VLLMModelInfo.get_api_key(api_key)
+        endpoint = "/v1/models"
+        if api_base is None or api_key is None:
+            raise ValueError(
+                "GEMINI_API_BASE or GEMINI_API_KEY is not set. Please set the environment variable, to query Gemini's `/models` endpoint."
+            )
+        url = _add_path_to_api_base(api_base, endpoint)
+        response = litellm.module_level_client.get(
+            url=url,
+        )
+        response.raise_for_status()
+        models = response.json()["data"]
+        return [model["id"] for model in models]
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return VLLMError(
+            status_code=status_code, message=error_message, headers=headers
+        )
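
A short sketch of exercising the new model-info route directly; the vLLM server address is a placeholder, and the dummy api_key works around the None-check shown above (note its Gemini-flavored error message):

import os

from litellm.llms.vllm.common_utils import VLLMModelInfo

# Placeholder address of a running vLLM OpenAI-compatible server.
os.environ["VLLM_API_BASE"] = "http://localhost:8000"

vllm_info = VLLMModelInfo()
# get_models() requires a non-None api_key even though vLLM does not use one,
# so pass a dummy value; it calls {VLLM_API_BASE}/v1/models under the hood.
print(vllm_info.get_models(api_key="dummy"))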

View file

@@ -5,9 +5,29 @@ import httpx
 import litellm
 from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues


 class XAIModelInfo(BaseLLMModelInfo):
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        if api_key is not None:
+            headers["Authorization"] = f"Bearer {api_key}"
+
+        # Ensure Content-Type is set to application/json
+        if "content-type" not in headers and "Content-Type" not in headers:
+            headers["Content-Type"] = "application/json"
+
+        return headers
+
     @staticmethod
     def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
         return api_base or get_secret_str("XAI_API_BASE") or "https://api.x.ai"

View file

@@ -317,6 +317,8 @@ class LiteLLMRoutes(enum.Enum):
         "/openai",
         "/assemblyai",
         "/eu.assemblyai",
+        "/vllm",
+        "/mistral",
     ]

     anthropic_routes = [

View file

@@ -6,6 +6,7 @@ Provider-specific Pass-Through Endpoints
 Use litellm with Anthropic SDK, Vertex AI SDK, Cohere SDK, etc.
 """

+import os
 from typing import Optional

 import httpx
@@ -43,6 +44,84 @@ def create_request_copy(request: Request):
     }


+async def llm_passthrough_factory_proxy_route(
+    custom_llm_provider: str,
+    endpoint: str,
+    request: Request,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    Factory function for creating pass-through endpoints for LLM providers.
+    """
+    from litellm.types.utils import LlmProviders
+    from litellm.utils import ProviderConfigManager
+
+    provider_config = ProviderConfigManager.get_provider_model_info(
+        provider=LlmProviders(custom_llm_provider),
+        model=None,
+    )
+    if provider_config is None:
+        raise HTTPException(
+            status_code=404, detail=f"Provider {custom_llm_provider} not found"
+        )
+    base_target_url = provider_config.get_api_base()
+    if base_target_url is None:
+        raise HTTPException(
+            status_code=404, detail=f"Provider {custom_llm_provider} api base not found"
+        )
+    encoded_endpoint = httpx.URL(endpoint).path
+
+    # Ensure endpoint starts with '/' for proper URL construction
+    if not encoded_endpoint.startswith("/"):
+        encoded_endpoint = "/" + encoded_endpoint
+
+    # Construct the full target URL using httpx
+    base_url = httpx.URL(base_target_url)
+    updated_url = base_url.copy_with(path=encoded_endpoint)
+
+    # Add or update query parameters
+    provider_api_key = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider=custom_llm_provider,
+        region_name=None,
+    )
+
+    auth_headers = provider_config.validate_environment(
+        headers={},
+        model="",
+        messages=[],
+        optional_params={},
+        litellm_params={},
+        api_key=provider_api_key,
+        api_base=base_target_url,
+    )
+
+    ## check for streaming
+    is_streaming_request = False
+    # anthropic is streaming when 'stream' = True is in the body
+    if request.method == "POST":
+        _request_body = await request.json()
+        if _request_body.get("stream"):
+            is_streaming_request = True
+
+    ## CREATE PASS-THROUGH
+    endpoint_func = create_pass_through_route(
+        endpoint=endpoint,
+        target=str(updated_url),
+        custom_headers=auth_headers,
+    )  # dynamically construct pass-through endpoint based on incoming path
+    received_value = await endpoint_func(
+        request,
+        fastapi_response,
+        user_api_key_dict,
+        stream=is_streaming_request,  # type: ignore
+    )
+
+    return received_value
+
+
 @router.api_route(
     "/gemini/{endpoint:path}",
     methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
@@ -162,6 +241,84 @@ async def cohere_proxy_route(
     return received_value


+@router.api_route(
+    "/vllm/{endpoint:path}",
+    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
+    tags=["VLLM Pass-through", "pass-through"],
+)
+async def vllm_proxy_route(
+    endpoint: str,
+    request: Request,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    [Docs](https://docs.litellm.ai/docs/pass_through/vllm)
+    """
+    return await llm_passthrough_factory_proxy_route(
+        endpoint=endpoint,
+        request=request,
+        fastapi_response=fastapi_response,
+        user_api_key_dict=user_api_key_dict,
+        custom_llm_provider="vllm",
+    )
+
+
+@router.api_route(
+    "/mistral/{endpoint:path}",
+    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
+    tags=["Mistral Pass-through", "pass-through"],
+)
+async def mistral_proxy_route(
+    endpoint: str,
+    request: Request,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    [Docs](https://docs.litellm.ai/docs/anthropic_completion)
+    """
+    base_target_url = os.getenv("MISTRAL_API_BASE") or "https://api.mistral.ai"
+    encoded_endpoint = httpx.URL(endpoint).path
+
+    # Ensure endpoint starts with '/' for proper URL construction
+    if not encoded_endpoint.startswith("/"):
+        encoded_endpoint = "/" + encoded_endpoint
+
+    # Construct the full target URL using httpx
+    base_url = httpx.URL(base_target_url)
+    updated_url = base_url.copy_with(path=encoded_endpoint)
+
+    # Add or update query parameters
+    mistral_api_key = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider="mistral",
+        region_name=None,
+    )
+
+    ## check for streaming
+    is_streaming_request = False
+    # anthropic is streaming when 'stream' = True is in the body
+    if request.method == "POST":
+        _request_body = await request.json()
+        if _request_body.get("stream"):
+            is_streaming_request = True
+
+    ## CREATE PASS-THROUGH
+    endpoint_func = create_pass_through_route(
+        endpoint=endpoint,
+        target=str(updated_url),
+        custom_headers={"Authorization": "Bearer {}".format(mistral_api_key)},
+    )  # dynamically construct pass-through endpoint based on incoming path
+    received_value = await endpoint_func(
+        request,
+        fastapi_response,
+        user_api_key_dict,
+        stream=is_streaming_request,  # type: ignore
+    )
+
+    return received_value
+
+
 @router.api_route(
     "/anthropic/{endpoint:path}",
     methods=["GET", "POST", "PUT", "DELETE", "PATCH"],

View file

@@ -516,9 +516,9 @@ def function_setup( # noqa: PLR0915
     function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None

     ## DYNAMIC CALLBACKS ##
-    dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]] = (
-        kwargs.pop("callbacks", None)
-    )
+    dynamic_callbacks: Optional[
+        List[Union[str, Callable, CustomLogger]]
+    ] = kwargs.pop("callbacks", None)
     all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks)

     if len(all_callbacks) > 0:
@@ -1202,9 +1202,9 @@ def client(original_function): # noqa: PLR0915
                         exception=e,
                         retry_policy=kwargs.get("retry_policy"),
                     )
-                    kwargs["retry_policy"] = (
-                        reset_retry_policy()
-                    )  # prevent infinite loops
+                    kwargs[
+                        "retry_policy"
+                    ] = reset_retry_policy()  # prevent infinite loops
                     litellm.num_retries = (
                         None  # set retries to None to prevent infinite loops
                     )
@@ -3013,16 +3013,16 @@ def get_optional_params( # noqa: PLR0915
                 True  # so that main.py adds the function call to the prompt
             )
             if "tools" in non_default_params:
-                optional_params["functions_unsupported_model"] = (
-                    non_default_params.pop("tools")
-                )
+                optional_params[
+                    "functions_unsupported_model"
+                ] = non_default_params.pop("tools")
                 non_default_params.pop(
                     "tool_choice", None
                 )  # causes ollama requests to hang
             elif "functions" in non_default_params:
-                optional_params["functions_unsupported_model"] = (
-                    non_default_params.pop("functions")
-                )
+                optional_params[
+                    "functions_unsupported_model"
+                ] = non_default_params.pop("functions")
         elif (
             litellm.add_function_to_prompt
         ):  # if user opts to add it to prompt instead
@@ -3045,10 +3045,10 @@ def get_optional_params( # noqa: PLR0915
     if "response_format" in non_default_params:
         if provider_config is not None:
-            non_default_params["response_format"] = (
-                provider_config.get_json_schema_from_pydantic_object(
-                    response_format=non_default_params["response_format"]
-                )
-            )
+            non_default_params[
+                "response_format"
+            ] = provider_config.get_json_schema_from_pydantic_object(
+                response_format=non_default_params["response_format"]
+            )
         else:
             non_default_params["response_format"] = type_to_response_format_param(
@@ -4064,9 +4064,9 @@ def _count_characters(text: str) -> int:


 def get_response_string(response_obj: Union[ModelResponse, ModelResponseStream]) -> str:
-    _choices: Union[List[Union[Choices, StreamingChoices]], List[StreamingChoices]] = (
-        response_obj.choices
-    )
+    _choices: Union[
+        List[Union[Choices, StreamingChoices]], List[StreamingChoices]
+    ] = response_obj.choices

     response_str = ""
     for choice in _choices:
@@ -4458,14 +4458,14 @@ def _get_model_info_helper( # noqa: PLR0915
         if combined_model_name in litellm.model_cost:
             key = combined_model_name
-            _model_info = _get_model_info_from_model_cost(key=key)
+            _model_info = _get_model_info_from_model_cost(key=cast(str, key))
             if not _check_provider_match(
                 model_info=_model_info, custom_llm_provider=custom_llm_provider
             ):
                 _model_info = None
         if _model_info is None and model in litellm.model_cost:
             key = model
-            _model_info = _get_model_info_from_model_cost(key=key)
+            _model_info = _get_model_info_from_model_cost(key=cast(str, key))
             if not _check_provider_match(
                 model_info=_model_info, custom_llm_provider=custom_llm_provider
             ):
@@ -4475,21 +4475,21 @@
             and combined_stripped_model_name in litellm.model_cost
         ):
             key = combined_stripped_model_name
-            _model_info = _get_model_info_from_model_cost(key=key)
+            _model_info = _get_model_info_from_model_cost(key=cast(str, key))
             if not _check_provider_match(
                 model_info=_model_info, custom_llm_provider=custom_llm_provider
            ):
                 _model_info = None
         if _model_info is None and stripped_model_name in litellm.model_cost:
             key = stripped_model_name
-            _model_info = _get_model_info_from_model_cost(key=key)
+            _model_info = _get_model_info_from_model_cost(key=cast(str, key))
             if not _check_provider_match(
                 model_info=_model_info, custom_llm_provider=custom_llm_provider
             ):
                 _model_info = None
         if _model_info is None and split_model in litellm.model_cost:
             key = split_model
-            _model_info = _get_model_info_from_model_cost(key=key)
+            _model_info = _get_model_info_from_model_cost(key=cast(str, key))
             if not _check_provider_match(
                 model_info=_model_info, custom_llm_provider=custom_llm_provider
             ):
@@ -6510,7 +6510,12 @@ class ProviderConfigManager:
             return litellm.AnthropicModelInfo()
         elif LlmProviders.XAI == provider:
             return litellm.XAIModelInfo()
+        elif LlmProviders.VLLM == provider:
+            from litellm.llms.vllm.common_utils import (
+                VLLMModelInfo,  # experimental approach, to reduce bloat on __init__.py
+            )
+
+            return VLLMModelInfo()
         return None

     @staticmethod
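
The factory passthrough above hinges on this lookup: ProviderConfigManager.get_provider_model_info resolves a provider name to its ModelInfo (VLLM now lazily imported), and the route then uses its get_api_base and validate_environment. A small sketch of that resolution, assuming VLLM_API_BASE points at a placeholder address:

import os

from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager

os.environ["VLLM_API_BASE"] = "http://localhost:8000"  # placeholder

provider_config = ProviderConfigManager.get_provider_model_info(
    provider=LlmProviders("vllm"), model=None
)
assert provider_config is not None
print(provider_config.get_api_base())  # -> http://localhost:8000

headers = provider_config.validate_environment(
    headers={}, model="", messages=[], optional_params={}, litellm_params={}
)
print(headers)  # vLLM's implementation returns the headers unchanged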