Add /vllm/* and /mistral/* passthrough endpoints (adds support for Mistral OCR via passthrough)

* feat(llm_passthrough_endpoints.py): support mistral passthrough

Closes https://github.com/BerriAI/litellm/issues/9051

* feat(llm_passthrough_endpoints.py): initial commit for adding vllm passthrough route

* feat(vllm/common_utils.py): add new vllm model info route

make it possible to use the vLLM passthrough route via a factory function

* fix(llm_passthrough_endpoints.py): add all methods to vllm passthrough route

* fix: fix linting error

* fix: fix linting error

* fix: fix ruff check

* fix(proxy/_types.py): add new passthrough routes

* docs(config_settings.md): add mistral env vars to docs
Krish Dholakia authored 2025-04-14 22:06:33 -07:00, committed via GitHub
commit 9b0f871129 (parent 8faf56922c)
12 changed files with 450 additions and 176 deletions
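
For context, a minimal client-side sketch (not part of the diff) of calling Mistral OCR through the new /mistral/* passthrough once the proxy is running; the proxy address, virtual key, model name, and OCR payload shape are placeholders/assumptions based on Mistral's public OCR API:

import httpx

PROXY_BASE = "http://localhost:4000"   # assumed LiteLLM proxy address (placeholder)
LITELLM_KEY = "sk-1234"                # assumed LiteLLM virtual key (placeholder)

# /mistral/v1/ocr is forwarded to MISTRAL_API_BASE (default https://api.mistral.ai) + /v1/ocr
response = httpx.post(
    f"{PROXY_BASE}/mistral/v1/ocr",
    headers={"Authorization": f"Bearer {LITELLM_KEY}"},
    json={
        "model": "mistral-ocr-latest",  # assumed model name, check Mistral's docs
        "document": {
            "type": "document_url",
            "document_url": "https://example.com/sample.pdf",  # placeholder document
        },
    },
    timeout=60,
)
print(response.json())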


@@ -449,6 +449,8 @@ router_settings:
| LITELLM_TOKEN | Access token for LiteLLM integration
| LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD | If true, prints the standard logging payload to the console - useful for debugging
| LOGFIRE_TOKEN | Token for Logfire logging service
| MISTRAL_API_BASE | Base URL for Mistral API
| MISTRAL_API_KEY | API key for Mistral API
| MICROSOFT_CLIENT_ID | Client ID for Microsoft services
| MICROSOFT_CLIENT_SECRET | Client secret for Microsoft services
| MICROSOFT_TENANT | Tenant ID for Microsoft Azure


@@ -44,7 +44,7 @@ from litellm.utils import (
token_counter,
)
from ..common_utils import AnthropicError, process_anthropic_headers
from ..common_utils import AnthropicError, AnthropicModelInfo, process_anthropic_headers
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
@@ -54,7 +54,7 @@ else:
LoggingClass = Any
class AnthropicConfig(BaseConfig):
class AnthropicConfig(AnthropicModelInfo, BaseConfig):
"""
Reference: https://docs.anthropic.com/claude/reference/messages_post
@@ -127,41 +127,6 @@ class AnthropicConfig(BaseConfig):
"anthropic-beta": "prompt-caching-2024-07-31",
}
def get_anthropic_headers(
self,
api_key: str,
anthropic_version: Optional[str] = None,
computer_tool_used: bool = False,
prompt_caching_set: bool = False,
pdf_used: bool = False,
is_vertex_request: bool = False,
user_anthropic_beta_headers: Optional[List[str]] = None,
) -> dict:
betas = set()
if prompt_caching_set:
betas.add("prompt-caching-2024-07-31")
if computer_tool_used:
betas.add("computer-use-2024-10-22")
if pdf_used:
betas.add("pdfs-2024-09-25")
headers = {
"anthropic-version": anthropic_version or "2023-06-01",
"x-api-key": api_key,
"accept": "application/json",
"content-type": "application/json",
}
if user_anthropic_beta_headers is not None:
betas.update(user_anthropic_beta_headers)
# Don't send any beta headers to Vertex, Vertex has failed requests when they are sent
if is_vertex_request is True:
pass
elif len(betas) > 0:
headers["anthropic-beta"] = ",".join(betas)
return headers
def _map_tool_choice(
self, tool_choice: Optional[str], parallel_tool_use: Optional[bool]
) -> Optional[AnthropicMessagesToolChoice]:
@@ -446,49 +411,6 @@ class AnthropicConfig(BaseConfig):
)
return _tool
def is_cache_control_set(self, messages: List[AllMessageValues]) -> bool:
"""
Return if {"cache_control": ..} in message content block
Used to check if anthropic prompt caching headers need to be set.
"""
for message in messages:
if message.get("cache_control", None) is not None:
return True
_message_content = message.get("content")
if _message_content is not None and isinstance(_message_content, list):
for content in _message_content:
if "cache_control" in content:
return True
return False
def is_computer_tool_used(
self, tools: Optional[List[AllAnthropicToolsValues]]
) -> bool:
if tools is None:
return False
for tool in tools:
if "type" in tool and tool["type"].startswith("computer_"):
return True
return False
def is_pdf_used(self, messages: List[AllMessageValues]) -> bool:
"""
Set to true if media passed into messages.
"""
for message in messages:
if (
"content" in message
and message["content"] is not None
and isinstance(message["content"], list)
):
for content in message["content"]:
if "type" in content and content["type"] != "text":
return True
return False
def translate_system_message(
self, messages: List[AllMessageValues]
) -> List[AnthropicSystemMessageContent]:
@@ -862,47 +784,3 @@ class AnthropicConfig(BaseConfig):
message=error_message,
headers=cast(httpx.Headers, headers),
)
def _get_user_anthropic_beta_headers(
self, anthropic_beta_header: Optional[str]
) -> Optional[List[str]]:
if anthropic_beta_header is None:
return None
return anthropic_beta_header.split(",")
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
if api_key is None:
raise litellm.AuthenticationError(
message="Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params. Please set `ANTHROPIC_API_KEY` in your environment vars",
llm_provider="anthropic",
model=model,
)
tools = optional_params.get("tools")
prompt_caching_set = self.is_cache_control_set(messages=messages)
computer_tool_used = self.is_computer_tool_used(tools=tools)
pdf_used = self.is_pdf_used(messages=messages)
user_anthropic_beta_headers = self._get_user_anthropic_beta_headers(
anthropic_beta_header=headers.get("anthropic-beta")
)
anthropic_headers = self.get_anthropic_headers(
computer_tool_used=computer_tool_used,
prompt_caching_set=prompt_caching_set,
pdf_used=pdf_used,
api_key=api_key,
is_vertex_request=optional_params.get("is_vertex_request", False),
user_anthropic_beta_headers=user_anthropic_beta_headers,
)
headers = {**headers, **anthropic_headers}
return headers


@@ -2,7 +2,7 @@
This file contains common utils for anthropic calls.
"""
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union
import httpx
@@ -10,6 +10,8 @@ import litellm
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.anthropic import AllAnthropicToolsValues
from litellm.types.llms.openai import AllMessageValues
class AnthropicError(BaseLLMException):
@@ -23,6 +25,128 @@ class AnthropicError(BaseLLMException):
class AnthropicModelInfo(BaseLLMModelInfo):
def is_cache_control_set(self, messages: List[AllMessageValues]) -> bool:
"""
Return if {"cache_control": ..} in message content block
Used to check if anthropic prompt caching headers need to be set.
"""
for message in messages:
if message.get("cache_control", None) is not None:
return True
_message_content = message.get("content")
if _message_content is not None and isinstance(_message_content, list):
for content in _message_content:
if "cache_control" in content:
return True
return False
def is_computer_tool_used(
self, tools: Optional[List[AllAnthropicToolsValues]]
) -> bool:
if tools is None:
return False
for tool in tools:
if "type" in tool and tool["type"].startswith("computer_"):
return True
return False
def is_pdf_used(self, messages: List[AllMessageValues]) -> bool:
"""
Set to true if media passed into messages.
"""
for message in messages:
if (
"content" in message
and message["content"] is not None
and isinstance(message["content"], list)
):
for content in message["content"]:
if "type" in content and content["type"] != "text":
return True
return False
def _get_user_anthropic_beta_headers(
self, anthropic_beta_header: Optional[str]
) -> Optional[List[str]]:
if anthropic_beta_header is None:
return None
return anthropic_beta_header.split(",")
def get_anthropic_headers(
self,
api_key: str,
anthropic_version: Optional[str] = None,
computer_tool_used: bool = False,
prompt_caching_set: bool = False,
pdf_used: bool = False,
is_vertex_request: bool = False,
user_anthropic_beta_headers: Optional[List[str]] = None,
) -> dict:
betas = set()
if prompt_caching_set:
betas.add("prompt-caching-2024-07-31")
if computer_tool_used:
betas.add("computer-use-2024-10-22")
if pdf_used:
betas.add("pdfs-2024-09-25")
headers = {
"anthropic-version": anthropic_version or "2023-06-01",
"x-api-key": api_key,
"accept": "application/json",
"content-type": "application/json",
}
if user_anthropic_beta_headers is not None:
betas.update(user_anthropic_beta_headers)
# Don't send any beta headers to Vertex, Vertex has failed requests when they are sent
if is_vertex_request is True:
pass
elif len(betas) > 0:
headers["anthropic-beta"] = ",".join(betas)
return headers
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
if api_key is None:
raise litellm.AuthenticationError(
message="Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params. Please set `ANTHROPIC_API_KEY` in your environment vars",
llm_provider="anthropic",
model=model,
)
tools = optional_params.get("tools")
prompt_caching_set = self.is_cache_control_set(messages=messages)
computer_tool_used = self.is_computer_tool_used(tools=tools)
pdf_used = self.is_pdf_used(messages=messages)
user_anthropic_beta_headers = self._get_user_anthropic_beta_headers(
anthropic_beta_header=headers.get("anthropic-beta")
)
anthropic_headers = self.get_anthropic_headers(
computer_tool_used=computer_tool_used,
prompt_caching_set=prompt_caching_set,
pdf_used=pdf_used,
api_key=api_key,
is_vertex_request=optional_params.get("is_vertex_request", False),
user_anthropic_beta_headers=user_anthropic_beta_headers,
)
headers = {**headers, **anthropic_headers}
return headers
@staticmethod
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
return (


@@ -44,6 +44,19 @@ class BaseLLMModelInfo(ABC):
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
pass
@abstractmethod
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
pass
@staticmethod
@abstractmethod
def get_base_model(model: str) -> Optional[str]:


@@ -389,7 +389,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
)
@staticmethod
def get_base_model(model: str) -> str:
def get_base_model(model: Optional[str] = None) -> Optional[str]:
return model
def get_model_response_iterator(


@@ -1,6 +1,7 @@
from typing import List, Optional
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from ..base_llm.base_utils import BaseLLMModelInfo
from ..base_llm.chat.transformation import BaseLLMException
@@ -11,6 +12,26 @@ class TopazException(BaseLLMException):
class TopazModelInfo(BaseLLMModelInfo):
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
if api_key is None:
raise ValueError(
"API key is required for Topaz image variations. Set via `TOPAZ_API_KEY` or `api_key=..`"
)
return {
# "Content-Type": "multipart/form-data",
"Accept": "image/jpeg",
"X-API-Key": api_key,
}
def get_models(
self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> List[str]:

View file

@@ -10,10 +10,7 @@ from litellm.llms.base_llm.chat.transformation import (
BaseLLMException,
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import (
AllMessageValues,
OpenAIImageVariationOptionalParams,
)
from litellm.types.llms.openai import OpenAIImageVariationOptionalParams
from litellm.types.utils import (
FileTypes,
HttpHandlerRequestFields,
@@ -22,35 +19,15 @@ from litellm.types.utils import (
)
from ...base_llm.image_variations.transformation import BaseImageVariationConfig
from ..common_utils import TopazException
from ..common_utils import TopazException, TopazModelInfo
class TopazImageVariationConfig(BaseImageVariationConfig):
class TopazImageVariationConfig(TopazModelInfo, BaseImageVariationConfig):
def get_supported_openai_params(
self, model: str
) -> List[OpenAIImageVariationOptionalParams]:
return ["response_format", "size"]
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
if api_key is None:
raise ValueError(
"API key is required for Topaz image variations. Set via `TOPAZ_API_KEY` or `api_key=..`"
)
return {
# "Content-Type": "multipart/form-data",
"Accept": "image/jpeg",
"X-API-Key": api_key,
}
def get_complete_url(
self,
api_base: Optional[str],


@@ -0,0 +1,75 @@
from typing import List, Optional, Union
import httpx
import litellm
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import _add_path_to_api_base
class VLLMError(BaseLLMException):
pass
class VLLMModelInfo(BaseLLMModelInfo):
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
"""Google AI Studio sends api key in query params"""
return headers
@staticmethod
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
api_base = api_base or get_secret_str("VLLM_API_BASE")
if api_base is None:
raise ValueError(
"VLLM_API_BASE is not set. Please set the environment variable, to use VLLM's pass-through - `{LITELLM_API_BASE}/vllm/{endpoint}`."
)
return api_base
@staticmethod
def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
return None
@staticmethod
def get_base_model(model: str) -> Optional[str]:
return model
def get_models(
self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> List[str]:
api_base = VLLMModelInfo.get_api_base(api_base)
api_key = VLLMModelInfo.get_api_key(api_key)
endpoint = "/v1/models"
if api_base is None or api_key is None:
raise ValueError(
"GEMINI_API_BASE or GEMINI_API_KEY is not set. Please set the environment variable, to query Gemini's `/models` endpoint."
)
url = _add_path_to_api_base(api_base, endpoint)
response = litellm.module_level_client.get(
url=url,
)
response.raise_for_status()
models = response.json()["data"]
return [model["id"] for model in models]
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return VLLMError(
status_code=status_code, message=error_message, headers=headers
)


@@ -5,9 +5,29 @@ import httpx
import litellm
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
class XAIModelInfo(BaseLLMModelInfo):
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
# Ensure Content-Type is set to application/json
if "content-type" not in headers and "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return headers
@staticmethod
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
return api_base or get_secret_str("XAI_API_BASE") or "https://api.x.ai"


@@ -317,6 +317,8 @@ class LiteLLMRoutes(enum.Enum):
"/openai",
"/assemblyai",
"/eu.assemblyai",
"/vllm",
"/mistral",
]
anthropic_routes = [


@@ -6,6 +6,7 @@ Provider-specific Pass-Through Endpoints
Use litellm with Anthropic SDK, Vertex AI SDK, Cohere SDK, etc.
"""
import os
from typing import Optional
import httpx
@@ -43,6 +44,84 @@ def create_request_copy(request: Request):
}
async def llm_passthrough_factory_proxy_route(
custom_llm_provider: str,
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Factory function for creating pass-through endpoints for LLM providers.
"""
from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager
provider_config = ProviderConfigManager.get_provider_model_info(
provider=LlmProviders(custom_llm_provider),
model=None,
)
if provider_config is None:
raise HTTPException(
status_code=404, detail=f"Provider {custom_llm_provider} not found"
)
base_target_url = provider_config.get_api_base()
if base_target_url is None:
raise HTTPException(
status_code=404, detail=f"Provider {custom_llm_provider} api base not found"
)
encoded_endpoint = httpx.URL(endpoint).path
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
provider_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider=custom_llm_provider,
region_name=None,
)
auth_headers = provider_config.validate_environment(
headers={},
model="",
messages=[],
optional_params={},
litellm_params={},
api_key=provider_api_key,
api_base=base_target_url,
)
## check for streaming
is_streaming_request = False
# anthropic is streaming when 'stream' = True is in the body
if request.method == "POST":
_request_body = await request.json()
if _request_body.get("stream"):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(updated_url),
custom_headers=auth_headers,
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request, # type: ignore
)
return received_value
@router.api_route(
"/gemini/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
@@ -162,6 +241,84 @@ async def cohere_proxy_route(
return received_value
@router.api_route(
"/vllm/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["VLLM Pass-through", "pass-through"],
)
async def vllm_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
[Docs](https://docs.litellm.ai/docs/pass_through/vllm)
"""
return await llm_passthrough_factory_proxy_route(
endpoint=endpoint,
request=request,
fastapi_response=fastapi_response,
user_api_key_dict=user_api_key_dict,
custom_llm_provider="vllm",
)
@router.api_route(
"/mistral/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["Mistral Pass-through", "pass-through"],
)
async def mistral_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
[Docs](https://docs.litellm.ai/docs/anthropic_completion)
"""
base_target_url = os.getenv("MISTRAL_API_BASE") or "https://api.mistral.ai"
encoded_endpoint = httpx.URL(endpoint).path
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
mistral_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider="mistral",
region_name=None,
)
## check for streaming
is_streaming_request = False
# anthropic is streaming when 'stream' = True is in the body
if request.method == "POST":
_request_body = await request.json()
if _request_body.get("stream"):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(updated_url),
custom_headers={"Authorization": "Bearer {}".format(mistral_api_key)},
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request, # type: ignore
)
return received_value
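
A comparable client-side sketch (not part of the diff) for the /vllm/* route above, assuming VLLM_API_BASE points at an OpenAI-compatible vLLM server; the addresses, key, and model name are placeholders:

import httpx

PROXY_BASE = "http://localhost:4000"   # assumed LiteLLM proxy address (placeholder)
LITELLM_KEY = "sk-1234"                # assumed LiteLLM virtual key (placeholder)
headers = {"Authorization": f"Bearer {LITELLM_KEY}"}

# GET {proxy}/vllm/v1/models is forwarded to {VLLM_API_BASE}/v1/models
models = httpx.get(f"{PROXY_BASE}/vllm/v1/models", headers=headers)
print(models.json())

# The route inspects the JSON body for "stream" to decide whether to stream
# the upstream response back to the client.
with httpx.stream(
    "POST",
    f"{PROXY_BASE}/vllm/v1/chat/completions",
    headers=headers,
    json={
        "model": "facebook/opt-125m",   # placeholder: whatever model the vLLM server serves
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    },
) as resp:
    for line in resp.iter_lines():
        print(line)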
@router.api_route(
"/anthropic/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],


@@ -516,9 +516,9 @@ def function_setup( # noqa: PLR0915
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
## DYNAMIC CALLBACKS ##
dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]] = (
kwargs.pop("callbacks", None)
)
dynamic_callbacks: Optional[
List[Union[str, Callable, CustomLogger]]
] = kwargs.pop("callbacks", None)
all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks)
if len(all_callbacks) > 0:
@@ -1202,9 +1202,9 @@ def client(original_function): # noqa: PLR0915
exception=e,
retry_policy=kwargs.get("retry_policy"),
)
kwargs["retry_policy"] = (
reset_retry_policy()
) # prevent infinite loops
kwargs[
"retry_policy"
] = reset_retry_policy() # prevent infinite loops
litellm.num_retries = (
None # set retries to None to prevent infinite loops
)
@@ -3013,16 +3013,16 @@ def get_optional_params( # noqa: PLR0915
True # so that main.py adds the function call to the prompt
)
if "tools" in non_default_params:
optional_params["functions_unsupported_model"] = (
non_default_params.pop("tools")
)
optional_params[
"functions_unsupported_model"
] = non_default_params.pop("tools")
non_default_params.pop(
"tool_choice", None
) # causes ollama requests to hang
elif "functions" in non_default_params:
optional_params["functions_unsupported_model"] = (
non_default_params.pop("functions")
)
optional_params[
"functions_unsupported_model"
] = non_default_params.pop("functions")
elif (
litellm.add_function_to_prompt
): # if user opts to add it to prompt instead
@@ -3045,10 +3045,10 @@
if "response_format" in non_default_params:
if provider_config is not None:
non_default_params["response_format"] = (
provider_config.get_json_schema_from_pydantic_object(
response_format=non_default_params["response_format"]
)
non_default_params[
"response_format"
] = provider_config.get_json_schema_from_pydantic_object(
response_format=non_default_params["response_format"]
)
else:
non_default_params["response_format"] = type_to_response_format_param(
@@ -4064,9 +4064,9 @@ def _count_characters(text: str) -> int:
def get_response_string(response_obj: Union[ModelResponse, ModelResponseStream]) -> str:
_choices: Union[List[Union[Choices, StreamingChoices]], List[StreamingChoices]] = (
response_obj.choices
)
_choices: Union[
List[Union[Choices, StreamingChoices]], List[StreamingChoices]
] = response_obj.choices
response_str = ""
for choice in _choices:
@@ -4458,14 +4458,14 @@ def _get_model_info_helper( # noqa: PLR0915
if combined_model_name in litellm.model_cost:
key = combined_model_name
_model_info = _get_model_info_from_model_cost(key=key)
_model_info = _get_model_info_from_model_cost(key=cast(str, key))
if not _check_provider_match(
model_info=_model_info, custom_llm_provider=custom_llm_provider
):
_model_info = None
if _model_info is None and model in litellm.model_cost:
key = model
_model_info = _get_model_info_from_model_cost(key=key)
_model_info = _get_model_info_from_model_cost(key=cast(str, key))
if not _check_provider_match(
model_info=_model_info, custom_llm_provider=custom_llm_provider
):
@@ -4475,21 +4475,21 @@
and combined_stripped_model_name in litellm.model_cost
):
key = combined_stripped_model_name
_model_info = _get_model_info_from_model_cost(key=key)
_model_info = _get_model_info_from_model_cost(key=cast(str, key))
if not _check_provider_match(
model_info=_model_info, custom_llm_provider=custom_llm_provider
):
_model_info = None
if _model_info is None and stripped_model_name in litellm.model_cost:
key = stripped_model_name
_model_info = _get_model_info_from_model_cost(key=key)
_model_info = _get_model_info_from_model_cost(key=cast(str, key))
if not _check_provider_match(
model_info=_model_info, custom_llm_provider=custom_llm_provider
):
_model_info = None
if _model_info is None and split_model in litellm.model_cost:
key = split_model
_model_info = _get_model_info_from_model_cost(key=key)
_model_info = _get_model_info_from_model_cost(key=cast(str, key))
if not _check_provider_match(
model_info=_model_info, custom_llm_provider=custom_llm_provider
):
@@ -6510,7 +6510,12 @@ class ProviderConfigManager:
return litellm.AnthropicModelInfo()
elif LlmProviders.XAI == provider:
return litellm.XAIModelInfo()
elif LlmProviders.VLLM == provider:
from litellm.llms.vllm.common_utils import (
VLLMModelInfo, # experimental approach, to reduce bloat on __init__.py
)
return VLLMModelInfo()
return None
@staticmethod
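
For reference, a small sketch (not part of the diff) of the registry lookup that llm_passthrough_factory_proxy_route relies on; with this change, the vLLM provider resolves to the new VLLMModelInfo:

from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager

# Same lookup the passthrough factory performs for custom_llm_provider="vllm"
provider_config = ProviderConfigManager.get_provider_model_info(
    provider=LlmProviders.VLLM,
    model=None,
)
print(type(provider_config).__name__)  # expected: VLLMModelInfo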