From 3031fff2976faa615833c86c63cad5489ad3a922 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Mon, 14 Apr 2025 22:06:33 -0700 Subject: [PATCH] Add `/vllm/*` and `/mistral/*` passthrough endpoints (adds support for Mistral OCR via passthrough) * feat(llm_passthrough_endpoints.py): support mistral passthrough Closes https://github.com/BerriAI/litellm/issues/9051 * feat(llm_passthrough_endpoints.py): initial commit for adding vllm passthrough route * feat(vllm/common_utils.py): add new vllm model info route make it possible to use vllm passthrough route via factory function * fix(llm_passthrough_endpoints.py): add all methods to vllm passthrough route * fix: fix linting error * fix: fix linting error * fix: fix ruff check * fix(proxy/_types.py): add new passthrough routes * docs(config_settings.md): add mistral env vars to docs --- docs/my-website/docs/proxy/config_settings.md | 2 + litellm/llms/anthropic/chat/transformation.py | 126 +------------- litellm/llms/anthropic/common_utils.py | 126 +++++++++++++- litellm/llms/base_llm/base_utils.py | 13 ++ .../llms/openai/chat/gpt_transformation.py | 2 +- litellm/llms/topaz/common_utils.py | 21 +++ .../topaz/image_variations/transformation.py | 29 +--- litellm/llms/vllm/common_utils.py | 75 +++++++++ litellm/llms/xai/common_utils.py | 20 +++ litellm/proxy/_types.py | 2 + .../llm_passthrough_endpoints.py | 157 ++++++++++++++++++ litellm/utils.py | 53 +++--- 12 files changed, 450 insertions(+), 176 deletions(-) create mode 100644 litellm/llms/vllm/common_utils.py diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index 779d74acbb..db12e1ed07 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -449,6 +449,8 @@ router_settings: | LITELLM_TOKEN | Access token for LiteLLM integration | LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD | If true, prints the standard logging payload to the console - useful for debugging | LOGFIRE_TOKEN | Token for Logfire logging service +| MISTRAL_API_BASE | Base URL for Mistral API +| MISTRAL_API_KEY | API key for Mistral API | MICROSOFT_CLIENT_ID | Client ID for Microsoft services | MICROSOFT_CLIENT_SECRET | Client secret for Microsoft services | MICROSOFT_TENANT | Tenant ID for Microsoft Azure diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index 96da34a855..590931321d 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -44,7 +44,7 @@ from litellm.utils import ( token_counter, ) -from ..common_utils import AnthropicError, process_anthropic_headers +from ..common_utils import AnthropicError, AnthropicModelInfo, process_anthropic_headers if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj @@ -54,7 +54,7 @@ else: LoggingClass = Any -class AnthropicConfig(BaseConfig): +class AnthropicConfig(AnthropicModelInfo, BaseConfig): """ Reference: https://docs.anthropic.com/claude/reference/messages_post @@ -127,41 +127,6 @@ class AnthropicConfig(BaseConfig): "anthropic-beta": "prompt-caching-2024-07-31", } - def get_anthropic_headers( - self, - api_key: str, - anthropic_version: Optional[str] = None, - computer_tool_used: bool = False, - prompt_caching_set: bool = False, - pdf_used: bool = False, - is_vertex_request: bool = False, - user_anthropic_beta_headers: Optional[List[str]] = None, - ) -> dict: - betas = set() - if prompt_caching_set: - 
betas.add("prompt-caching-2024-07-31") - if computer_tool_used: - betas.add("computer-use-2024-10-22") - if pdf_used: - betas.add("pdfs-2024-09-25") - headers = { - "anthropic-version": anthropic_version or "2023-06-01", - "x-api-key": api_key, - "accept": "application/json", - "content-type": "application/json", - } - - if user_anthropic_beta_headers is not None: - betas.update(user_anthropic_beta_headers) - - # Don't send any beta headers to Vertex, Vertex has failed requests when they are sent - if is_vertex_request is True: - pass - elif len(betas) > 0: - headers["anthropic-beta"] = ",".join(betas) - - return headers - def _map_tool_choice( self, tool_choice: Optional[str], parallel_tool_use: Optional[bool] ) -> Optional[AnthropicMessagesToolChoice]: @@ -446,49 +411,6 @@ class AnthropicConfig(BaseConfig): ) return _tool - def is_cache_control_set(self, messages: List[AllMessageValues]) -> bool: - """ - Return if {"cache_control": ..} in message content block - - Used to check if anthropic prompt caching headers need to be set. - """ - for message in messages: - if message.get("cache_control", None) is not None: - return True - _message_content = message.get("content") - if _message_content is not None and isinstance(_message_content, list): - for content in _message_content: - if "cache_control" in content: - return True - - return False - - def is_computer_tool_used( - self, tools: Optional[List[AllAnthropicToolsValues]] - ) -> bool: - if tools is None: - return False - for tool in tools: - if "type" in tool and tool["type"].startswith("computer_"): - return True - return False - - def is_pdf_used(self, messages: List[AllMessageValues]) -> bool: - """ - Set to true if media passed into messages. - - """ - for message in messages: - if ( - "content" in message - and message["content"] is not None - and isinstance(message["content"], list) - ): - for content in message["content"]: - if "type" in content and content["type"] != "text": - return True - return False - def translate_system_message( self, messages: List[AllMessageValues] ) -> List[AnthropicSystemMessageContent]: @@ -862,47 +784,3 @@ class AnthropicConfig(BaseConfig): message=error_message, headers=cast(httpx.Headers, headers), ) - - def _get_user_anthropic_beta_headers( - self, anthropic_beta_header: Optional[str] - ) -> Optional[List[str]]: - if anthropic_beta_header is None: - return None - return anthropic_beta_header.split(",") - - def validate_environment( - self, - headers: dict, - model: str, - messages: List[AllMessageValues], - optional_params: dict, - litellm_params: dict, - api_key: Optional[str] = None, - api_base: Optional[str] = None, - ) -> Dict: - if api_key is None: - raise litellm.AuthenticationError( - message="Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params. 
Please set `ANTHROPIC_API_KEY` in your environment vars", - llm_provider="anthropic", - model=model, - ) - - tools = optional_params.get("tools") - prompt_caching_set = self.is_cache_control_set(messages=messages) - computer_tool_used = self.is_computer_tool_used(tools=tools) - pdf_used = self.is_pdf_used(messages=messages) - user_anthropic_beta_headers = self._get_user_anthropic_beta_headers( - anthropic_beta_header=headers.get("anthropic-beta") - ) - anthropic_headers = self.get_anthropic_headers( - computer_tool_used=computer_tool_used, - prompt_caching_set=prompt_caching_set, - pdf_used=pdf_used, - api_key=api_key, - is_vertex_request=optional_params.get("is_vertex_request", False), - user_anthropic_beta_headers=user_anthropic_beta_headers, - ) - - headers = {**headers, **anthropic_headers} - - return headers diff --git a/litellm/llms/anthropic/common_utils.py b/litellm/llms/anthropic/common_utils.py index 9eae6734ff..bacd2a54d0 100644 --- a/litellm/llms/anthropic/common_utils.py +++ b/litellm/llms/anthropic/common_utils.py @@ -2,7 +2,7 @@ This file contains common utils for anthropic calls. """ -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import httpx @@ -10,6 +10,8 @@ import litellm from litellm.llms.base_llm.base_utils import BaseLLMModelInfo from litellm.llms.base_llm.chat.transformation import BaseLLMException from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.anthropic import AllAnthropicToolsValues +from litellm.types.llms.openai import AllMessageValues class AnthropicError(BaseLLMException): @@ -23,6 +25,128 @@ class AnthropicError(BaseLLMException): class AnthropicModelInfo(BaseLLMModelInfo): + def is_cache_control_set(self, messages: List[AllMessageValues]) -> bool: + """ + Return if {"cache_control": ..} in message content block + + Used to check if anthropic prompt caching headers need to be set. + """ + for message in messages: + if message.get("cache_control", None) is not None: + return True + _message_content = message.get("content") + if _message_content is not None and isinstance(_message_content, list): + for content in _message_content: + if "cache_control" in content: + return True + + return False + + def is_computer_tool_used( + self, tools: Optional[List[AllAnthropicToolsValues]] + ) -> bool: + if tools is None: + return False + for tool in tools: + if "type" in tool and tool["type"].startswith("computer_"): + return True + return False + + def is_pdf_used(self, messages: List[AllMessageValues]) -> bool: + """ + Set to true if media passed into messages. 
+ + """ + for message in messages: + if ( + "content" in message + and message["content"] is not None + and isinstance(message["content"], list) + ): + for content in message["content"]: + if "type" in content and content["type"] != "text": + return True + return False + + def _get_user_anthropic_beta_headers( + self, anthropic_beta_header: Optional[str] + ) -> Optional[List[str]]: + if anthropic_beta_header is None: + return None + return anthropic_beta_header.split(",") + + def get_anthropic_headers( + self, + api_key: str, + anthropic_version: Optional[str] = None, + computer_tool_used: bool = False, + prompt_caching_set: bool = False, + pdf_used: bool = False, + is_vertex_request: bool = False, + user_anthropic_beta_headers: Optional[List[str]] = None, + ) -> dict: + betas = set() + if prompt_caching_set: + betas.add("prompt-caching-2024-07-31") + if computer_tool_used: + betas.add("computer-use-2024-10-22") + if pdf_used: + betas.add("pdfs-2024-09-25") + headers = { + "anthropic-version": anthropic_version or "2023-06-01", + "x-api-key": api_key, + "accept": "application/json", + "content-type": "application/json", + } + + if user_anthropic_beta_headers is not None: + betas.update(user_anthropic_beta_headers) + + # Don't send any beta headers to Vertex, Vertex has failed requests when they are sent + if is_vertex_request is True: + pass + elif len(betas) > 0: + headers["anthropic-beta"] = ",".join(betas) + + return headers + + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> Dict: + if api_key is None: + raise litellm.AuthenticationError( + message="Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params. 
Please set `ANTHROPIC_API_KEY` in your environment vars", + llm_provider="anthropic", + model=model, + ) + + tools = optional_params.get("tools") + prompt_caching_set = self.is_cache_control_set(messages=messages) + computer_tool_used = self.is_computer_tool_used(tools=tools) + pdf_used = self.is_pdf_used(messages=messages) + user_anthropic_beta_headers = self._get_user_anthropic_beta_headers( + anthropic_beta_header=headers.get("anthropic-beta") + ) + anthropic_headers = self.get_anthropic_headers( + computer_tool_used=computer_tool_used, + prompt_caching_set=prompt_caching_set, + pdf_used=pdf_used, + api_key=api_key, + is_vertex_request=optional_params.get("is_vertex_request", False), + user_anthropic_beta_headers=user_anthropic_beta_headers, + ) + + headers = {**headers, **anthropic_headers} + + return headers + @staticmethod def get_api_base(api_base: Optional[str] = None) -> Optional[str]: return ( diff --git a/litellm/llms/base_llm/base_utils.py b/litellm/llms/base_llm/base_utils.py index 5b175f4756..712f5de8cc 100644 --- a/litellm/llms/base_llm/base_utils.py +++ b/litellm/llms/base_llm/base_utils.py @@ -44,6 +44,19 @@ class BaseLLMModelInfo(ABC): def get_api_base(api_base: Optional[str] = None) -> Optional[str]: pass + @abstractmethod + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + pass + @staticmethod @abstractmethod def get_base_model(model: str) -> Optional[str]: diff --git a/litellm/llms/openai/chat/gpt_transformation.py b/litellm/llms/openai/chat/gpt_transformation.py index 03257e50f0..e8f60357a6 100644 --- a/litellm/llms/openai/chat/gpt_transformation.py +++ b/litellm/llms/openai/chat/gpt_transformation.py @@ -389,7 +389,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig): ) @staticmethod - def get_base_model(model: str) -> str: + def get_base_model(model: Optional[str] = None) -> Optional[str]: return model def get_model_response_iterator( diff --git a/litellm/llms/topaz/common_utils.py b/litellm/llms/topaz/common_utils.py index 0252585922..95fe291493 100644 --- a/litellm/llms/topaz/common_utils.py +++ b/litellm/llms/topaz/common_utils.py @@ -1,6 +1,7 @@ from typing import List, Optional from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.openai import AllMessageValues from ..base_llm.base_utils import BaseLLMModelInfo from ..base_llm.chat.transformation import BaseLLMException @@ -11,6 +12,26 @@ class TopazException(BaseLLMException): class TopazModelInfo(BaseLLMModelInfo): + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + if api_key is None: + raise ValueError( + "API key is required for Topaz image variations. 
Set via `TOPAZ_API_KEY` or `api_key=..`" + ) + return { + # "Content-Type": "multipart/form-data", + "Accept": "image/jpeg", + "X-API-Key": api_key, + } + def get_models( self, api_key: Optional[str] = None, api_base: Optional[str] = None ) -> List[str]: diff --git a/litellm/llms/topaz/image_variations/transformation.py b/litellm/llms/topaz/image_variations/transformation.py index afbd89b9bc..41b51a558c 100644 --- a/litellm/llms/topaz/image_variations/transformation.py +++ b/litellm/llms/topaz/image_variations/transformation.py @@ -10,10 +10,7 @@ from litellm.llms.base_llm.chat.transformation import ( BaseLLMException, LiteLLMLoggingObj, ) -from litellm.types.llms.openai import ( - AllMessageValues, - OpenAIImageVariationOptionalParams, -) +from litellm.types.llms.openai import OpenAIImageVariationOptionalParams from litellm.types.utils import ( FileTypes, HttpHandlerRequestFields, @@ -22,35 +19,15 @@ from litellm.types.utils import ( ) from ...base_llm.image_variations.transformation import BaseImageVariationConfig -from ..common_utils import TopazException +from ..common_utils import TopazException, TopazModelInfo -class TopazImageVariationConfig(BaseImageVariationConfig): +class TopazImageVariationConfig(TopazModelInfo, BaseImageVariationConfig): def get_supported_openai_params( self, model: str ) -> List[OpenAIImageVariationOptionalParams]: return ["response_format", "size"] - def validate_environment( - self, - headers: dict, - model: str, - messages: List[AllMessageValues], - optional_params: dict, - litellm_params: dict, - api_key: Optional[str] = None, - api_base: Optional[str] = None, - ) -> dict: - if api_key is None: - raise ValueError( - "API key is required for Topaz image variations. Set via `TOPAZ_API_KEY` or `api_key=..`" - ) - return { - # "Content-Type": "multipart/form-data", - "Accept": "image/jpeg", - "X-API-Key": api_key, - } - def get_complete_url( self, api_base: Optional[str], diff --git a/litellm/llms/vllm/common_utils.py b/litellm/llms/vllm/common_utils.py new file mode 100644 index 0000000000..8dca3e1de2 --- /dev/null +++ b/litellm/llms/vllm/common_utils.py @@ -0,0 +1,75 @@ +from typing import List, Optional, Union + +import httpx + +import litellm +from litellm.llms.base_llm.base_utils import BaseLLMModelInfo +from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.openai import AllMessageValues +from litellm.utils import _add_path_to_api_base + + +class VLLMError(BaseLLMException): + pass + + +class VLLMModelInfo(BaseLLMModelInfo): + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + """Google AI Studio sends api key in query params""" + return headers + + @staticmethod + def get_api_base(api_base: Optional[str] = None) -> Optional[str]: + api_base = api_base or get_secret_str("VLLM_API_BASE") + if api_base is None: + raise ValueError( + "VLLM_API_BASE is not set. Please set the environment variable, to use VLLM's pass-through - `{LITELLM_API_BASE}/vllm/{endpoint}`." 
+ ) + return api_base + + @staticmethod + def get_api_key(api_key: Optional[str] = None) -> Optional[str]: + return None + + @staticmethod + def get_base_model(model: str) -> Optional[str]: + return model + + def get_models( + self, api_key: Optional[str] = None, api_base: Optional[str] = None + ) -> List[str]: + api_base = VLLMModelInfo.get_api_base(api_base) + api_key = VLLMModelInfo.get_api_key(api_key) + endpoint = "/v1/models" + if api_base is None or api_key is None: + raise ValueError( + "GEMINI_API_BASE or GEMINI_API_KEY is not set. Please set the environment variable, to query Gemini's `/models` endpoint." + ) + + url = _add_path_to_api_base(api_base, endpoint) + response = litellm.module_level_client.get( + url=url, + ) + + response.raise_for_status() + + models = response.json()["data"] + + return [model["id"] for model in models] + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] + ) -> BaseLLMException: + return VLLMError( + status_code=status_code, message=error_message, headers=headers + ) diff --git a/litellm/llms/xai/common_utils.py b/litellm/llms/xai/common_utils.py index d7ceeadd95..a26dc1e043 100644 --- a/litellm/llms/xai/common_utils.py +++ b/litellm/llms/xai/common_utils.py @@ -5,9 +5,29 @@ import httpx import litellm from litellm.llms.base_llm.base_utils import BaseLLMModelInfo from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.openai import AllMessageValues class XAIModelInfo(BaseLLMModelInfo): + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + if api_key is not None: + headers["Authorization"] = f"Bearer {api_key}" + + # Ensure Content-Type is set to application/json + if "content-type" not in headers and "Content-Type" not in headers: + headers["Content-Type"] = "application/json" + + return headers + @staticmethod def get_api_base(api_base: Optional[str] = None) -> Optional[str]: return api_base or get_secret_str("XAI_API_BASE") or "https://api.x.ai" diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 81c69dc6e6..e0bdfdb649 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -317,6 +317,8 @@ class LiteLLMRoutes(enum.Enum): "/openai", "/assemblyai", "/eu.assemblyai", + "/vllm", + "/mistral", ] anthropic_routes = [ diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py index bfdb324b92..a7cb9d9e41 100644 --- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py @@ -6,6 +6,7 @@ Provider-specific Pass-Through Endpoints Use litellm with Anthropic SDK, Vertex AI SDK, Cohere SDK, etc. """ +import os from typing import Optional import httpx @@ -43,6 +44,84 @@ def create_request_copy(request: Request): } +async def llm_passthrough_factory_proxy_route( + custom_llm_provider: str, + endpoint: str, + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Factory function for creating pass-through endpoints for LLM providers. 
+ """ + from litellm.types.utils import LlmProviders + from litellm.utils import ProviderConfigManager + + provider_config = ProviderConfigManager.get_provider_model_info( + provider=LlmProviders(custom_llm_provider), + model=None, + ) + if provider_config is None: + raise HTTPException( + status_code=404, detail=f"Provider {custom_llm_provider} not found" + ) + base_target_url = provider_config.get_api_base() + + if base_target_url is None: + raise HTTPException( + status_code=404, detail=f"Provider {custom_llm_provider} api base not found" + ) + + encoded_endpoint = httpx.URL(endpoint).path + + # Ensure endpoint starts with '/' for proper URL construction + if not encoded_endpoint.startswith("/"): + encoded_endpoint = "/" + encoded_endpoint + + # Construct the full target URL using httpx + base_url = httpx.URL(base_target_url) + updated_url = base_url.copy_with(path=encoded_endpoint) + + # Add or update query parameters + provider_api_key = passthrough_endpoint_router.get_credentials( + custom_llm_provider=custom_llm_provider, + region_name=None, + ) + + auth_headers = provider_config.validate_environment( + headers={}, + model="", + messages=[], + optional_params={}, + litellm_params={}, + api_key=provider_api_key, + api_base=base_target_url, + ) + + ## check for streaming + is_streaming_request = False + # anthropic is streaming when 'stream' = True is in the body + if request.method == "POST": + _request_body = await request.json() + if _request_body.get("stream"): + is_streaming_request = True + + ## CREATE PASS-THROUGH + endpoint_func = create_pass_through_route( + endpoint=endpoint, + target=str(updated_url), + custom_headers=auth_headers, + ) # dynamically construct pass-through endpoint based on incoming path + received_value = await endpoint_func( + request, + fastapi_response, + user_api_key_dict, + stream=is_streaming_request, # type: ignore + ) + + return received_value + + @router.api_route( "/gemini/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"], @@ -162,6 +241,84 @@ async def cohere_proxy_route( return received_value +@router.api_route( + "/vllm/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + tags=["VLLM Pass-through", "pass-through"], +) +async def vllm_proxy_route( + endpoint: str, + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + [Docs](https://docs.litellm.ai/docs/pass_through/vllm) + """ + return await llm_passthrough_factory_proxy_route( + endpoint=endpoint, + request=request, + fastapi_response=fastapi_response, + user_api_key_dict=user_api_key_dict, + custom_llm_provider="vllm", + ) + + +@router.api_route( + "/mistral/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + tags=["Mistral Pass-through", "pass-through"], +) +async def mistral_proxy_route( + endpoint: str, + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + [Docs](https://docs.litellm.ai/docs/anthropic_completion) + """ + base_target_url = os.getenv("MISTRAL_API_BASE") or "https://api.mistral.ai" + encoded_endpoint = httpx.URL(endpoint).path + + # Ensure endpoint starts with '/' for proper URL construction + if not encoded_endpoint.startswith("/"): + encoded_endpoint = "/" + encoded_endpoint + + # Construct the full target URL using httpx + base_url = httpx.URL(base_target_url) + updated_url = base_url.copy_with(path=encoded_endpoint) + + # Add or update query parameters + mistral_api_key 
= passthrough_endpoint_router.get_credentials( + custom_llm_provider="mistral", + region_name=None, + ) + + ## check for streaming + is_streaming_request = False + # anthropic is streaming when 'stream' = True is in the body + if request.method == "POST": + _request_body = await request.json() + if _request_body.get("stream"): + is_streaming_request = True + + ## CREATE PASS-THROUGH + endpoint_func = create_pass_through_route( + endpoint=endpoint, + target=str(updated_url), + custom_headers={"Authorization": "Bearer {}".format(mistral_api_key)}, + ) # dynamically construct pass-through endpoint based on incoming path + received_value = await endpoint_func( + request, + fastapi_response, + user_api_key_dict, + stream=is_streaming_request, # type: ignore + ) + + return received_value + + @router.api_route( "/anthropic/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"], diff --git a/litellm/utils.py b/litellm/utils.py index b31b929b27..340d06f930 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -516,9 +516,9 @@ def function_setup( # noqa: PLR0915 function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None ## DYNAMIC CALLBACKS ## - dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]] = ( - kwargs.pop("callbacks", None) - ) + dynamic_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = kwargs.pop("callbacks", None) all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks) if len(all_callbacks) > 0: @@ -1202,9 +1202,9 @@ def client(original_function): # noqa: PLR0915 exception=e, retry_policy=kwargs.get("retry_policy"), ) - kwargs["retry_policy"] = ( - reset_retry_policy() - ) # prevent infinite loops + kwargs[ + "retry_policy" + ] = reset_retry_policy() # prevent infinite loops litellm.num_retries = ( None # set retries to None to prevent infinite loops ) @@ -3013,16 +3013,16 @@ def get_optional_params( # noqa: PLR0915 True # so that main.py adds the function call to the prompt ) if "tools" in non_default_params: - optional_params["functions_unsupported_model"] = ( - non_default_params.pop("tools") - ) + optional_params[ + "functions_unsupported_model" + ] = non_default_params.pop("tools") non_default_params.pop( "tool_choice", None ) # causes ollama requests to hang elif "functions" in non_default_params: - optional_params["functions_unsupported_model"] = ( - non_default_params.pop("functions") - ) + optional_params[ + "functions_unsupported_model" + ] = non_default_params.pop("functions") elif ( litellm.add_function_to_prompt ): # if user opts to add it to prompt instead @@ -3045,10 +3045,10 @@ def get_optional_params( # noqa: PLR0915 if "response_format" in non_default_params: if provider_config is not None: - non_default_params["response_format"] = ( - provider_config.get_json_schema_from_pydantic_object( - response_format=non_default_params["response_format"] - ) + non_default_params[ + "response_format" + ] = provider_config.get_json_schema_from_pydantic_object( + response_format=non_default_params["response_format"] ) else: non_default_params["response_format"] = type_to_response_format_param( @@ -4064,9 +4064,9 @@ def _count_characters(text: str) -> int: def get_response_string(response_obj: Union[ModelResponse, ModelResponseStream]) -> str: - _choices: Union[List[Union[Choices, StreamingChoices]], List[StreamingChoices]] = ( - response_obj.choices - ) + _choices: Union[ + List[Union[Choices, StreamingChoices]], List[StreamingChoices] + ] = response_obj.choices response_str = "" for choice in _choices: 
@@ -4458,14 +4458,14 @@ def _get_model_info_helper( # noqa: PLR0915 if combined_model_name in litellm.model_cost: key = combined_model_name - _model_info = _get_model_info_from_model_cost(key=key) + _model_info = _get_model_info_from_model_cost(key=cast(str, key)) if not _check_provider_match( model_info=_model_info, custom_llm_provider=custom_llm_provider ): _model_info = None if _model_info is None and model in litellm.model_cost: key = model - _model_info = _get_model_info_from_model_cost(key=key) + _model_info = _get_model_info_from_model_cost(key=cast(str, key)) if not _check_provider_match( model_info=_model_info, custom_llm_provider=custom_llm_provider ): @@ -4475,21 +4475,21 @@ def _get_model_info_helper( # noqa: PLR0915 and combined_stripped_model_name in litellm.model_cost ): key = combined_stripped_model_name - _model_info = _get_model_info_from_model_cost(key=key) + _model_info = _get_model_info_from_model_cost(key=cast(str, key)) if not _check_provider_match( model_info=_model_info, custom_llm_provider=custom_llm_provider ): _model_info = None if _model_info is None and stripped_model_name in litellm.model_cost: key = stripped_model_name - _model_info = _get_model_info_from_model_cost(key=key) + _model_info = _get_model_info_from_model_cost(key=cast(str, key)) if not _check_provider_match( model_info=_model_info, custom_llm_provider=custom_llm_provider ): _model_info = None if _model_info is None and split_model in litellm.model_cost: key = split_model - _model_info = _get_model_info_from_model_cost(key=key) + _model_info = _get_model_info_from_model_cost(key=cast(str, key)) if not _check_provider_match( model_info=_model_info, custom_llm_provider=custom_llm_provider ): @@ -6510,7 +6510,12 @@ class ProviderConfigManager: return litellm.AnthropicModelInfo() elif LlmProviders.XAI == provider: return litellm.XAIModelInfo() + elif LlmProviders.VLLM == provider: + from litellm.llms.vllm.common_utils import ( + VLLMModelInfo, # experimental approach, to reduce bloat on __init__.py + ) + return VLLMModelInfo() return None @staticmethod
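
Usage sketch (outside the patch itself): how the new `/mistral/*` and `/vllm/*` pass-through routes might be exercised once this change runs on a LiteLLM proxy. The proxy address, virtual key, and document URL below are placeholders; the `/v1/ocr` path and `mistral-ocr-latest` model follow Mistral's public OCR API, and `/v1/models` follows vLLM's OpenAI-compatible server, as assumed here rather than confirmed by this diff. The proxy is expected to have MISTRAL_API_KEY (and, for the factory route, VLLM_API_BASE) configured.

# minimal_passthrough_sketch.py -- assumptions noted above
import httpx

PROXY_BASE = "http://localhost:4000"   # assumed LiteLLM proxy address
LITELLM_KEY = "sk-1234"                # assumed LiteLLM virtual key

# Mistral OCR via the new /mistral/* pass-through route.
# Per the patch, the proxy forwards the path after /mistral to MISTRAL_API_BASE
# and attaches the server-side MISTRAL_API_KEY as the Authorization header.
ocr_response = httpx.post(
    f"{PROXY_BASE}/mistral/v1/ocr",
    headers={"Authorization": f"Bearer {LITELLM_KEY}"},
    json={
        "model": "mistral-ocr-latest",
        "document": {
            "type": "document_url",
            "document_url": "https://example.com/sample.pdf",  # placeholder document
        },
    },
    timeout=60.0,
)
print(ocr_response.json())

# vLLM model listing via the new /vllm/* factory pass-through route.
# The factory resolves the upstream base URL from VLLM_API_BASE and forwards /v1/models.
models_response = httpx.get(
    f"{PROXY_BASE}/vllm/v1/models",
    headers={"Authorization": f"Bearer {LITELLM_KEY}"},
)
print(models_response.json())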