Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-24 18:24:20 +00:00
Add /vllm/* and /mistral/* passthrough endpoints (adds support for Mistral OCR via passthrough)
* feat(llm_passthrough_endpoints.py): support mistral passthrough - Closes https://github.com/BerriAI/litellm/issues/9051
* feat(llm_passthrough_endpoints.py): initial commit for adding vllm passthrough route
* feat(vllm/common_utils.py): add new vllm model info route - make it possible to use vllm passthrough route via factory function
* fix(llm_passthrough_endpoints.py): add all methods to vllm passthrough route
* fix: fix linting error
* fix: fix linting error
* fix: fix ruff check
* fix(proxy/_types.py): add new passthrough routes
* docs(config_settings.md): add mistral env vars to docs
This commit is contained in:
parent 8faf56922c
commit 9b0f871129
12 changed files with 450 additions and 176 deletions
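For orientation (not part of the diff), a minimal sketch of how a client could exercise the new routes through a running LiteLLM proxy. The proxy URL, virtual key, and the OCR payload shape are assumptions for illustration only; the exact request body should be taken from Mistral's OCR docs.

    # Minimal sketch, assuming a LiteLLM proxy at http://localhost:4000 with virtual key sk-1234.
    import httpx

    PROXY_BASE = "http://localhost:4000"            # assumed proxy address
    HEADERS = {"Authorization": "Bearer sk-1234"}   # assumed LiteLLM virtual key

    # Mistral OCR via the new /mistral/* passthrough (forwarded to MISTRAL_API_BASE)
    ocr_resp = httpx.post(
        f"{PROXY_BASE}/mistral/v1/ocr",
        headers=HEADERS,
        json={
            "model": "mistral-ocr-latest",  # illustrative model name
            "document": {"type": "document_url", "document_url": "https://example.com/sample.pdf"},
        },
    )
    print(ocr_resp.json())

    # vLLM model listing via the new /vllm/* passthrough (forwarded to VLLM_API_BASE)
    models_resp = httpx.get(f"{PROXY_BASE}/vllm/v1/models", headers=HEADERS)
    print(models_resp.json())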
@@ -449,6 +449,8 @@ router_settings:
 | LITELLM_TOKEN | Access token for LiteLLM integration |
 | LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD | If true, prints the standard logging payload to the console - useful for debugging |
 | LOGFIRE_TOKEN | Token for Logfire logging service |
+| MISTRAL_API_BASE | Base URL for Mistral API |
+| MISTRAL_API_KEY | API key for Mistral API |
 | MICROSOFT_CLIENT_ID | Client ID for Microsoft services |
 | MICROSOFT_CLIENT_SECRET | Client secret for Microsoft services |
 | MICROSOFT_TENANT | Tenant ID for Microsoft Azure |
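A short, hedged sketch of how these two settings are consumed (mirroring the Mistral route added later in this diff; the key value is a placeholder):

    # Sketch only: MISTRAL_API_BASE is optional and falls back to the public API;
    # MISTRAL_API_KEY is resolved by the proxy's passthrough credential router.
    import os

    os.environ.setdefault("MISTRAL_API_KEY", "sk-mistral-placeholder")
    base_target_url = os.getenv("MISTRAL_API_BASE") or "https://api.mistral.ai"
    print(base_target_url)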
@@ -44,7 +44,7 @@ from litellm.utils import (
     token_counter,
 )
 
-from ..common_utils import AnthropicError, process_anthropic_headers
+from ..common_utils import AnthropicError, AnthropicModelInfo, process_anthropic_headers
 
 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
@@ -54,7 +54,7 @@ else:
     LoggingClass = Any
 
 
-class AnthropicConfig(BaseConfig):
+class AnthropicConfig(AnthropicModelInfo, BaseConfig):
     """
     Reference: https://docs.anthropic.com/claude/reference/messages_post
 
@@ -127,41 +127,6 @@ class AnthropicConfig(BaseConfig):
             "anthropic-beta": "prompt-caching-2024-07-31",
         }
 
-    def get_anthropic_headers(
-        self,
-        api_key: str,
-        anthropic_version: Optional[str] = None,
-        computer_tool_used: bool = False,
-        prompt_caching_set: bool = False,
-        pdf_used: bool = False,
-        is_vertex_request: bool = False,
-        user_anthropic_beta_headers: Optional[List[str]] = None,
-    ) -> dict:
-        betas = set()
-        if prompt_caching_set:
-            betas.add("prompt-caching-2024-07-31")
-        if computer_tool_used:
-            betas.add("computer-use-2024-10-22")
-        if pdf_used:
-            betas.add("pdfs-2024-09-25")
-        headers = {
-            "anthropic-version": anthropic_version or "2023-06-01",
-            "x-api-key": api_key,
-            "accept": "application/json",
-            "content-type": "application/json",
-        }
-
-        if user_anthropic_beta_headers is not None:
-            betas.update(user_anthropic_beta_headers)
-
-        # Don't send any beta headers to Vertex, Vertex has failed requests when they are sent
-        if is_vertex_request is True:
-            pass
-        elif len(betas) > 0:
-            headers["anthropic-beta"] = ",".join(betas)
-
-        return headers
-
     def _map_tool_choice(
         self, tool_choice: Optional[str], parallel_tool_use: Optional[bool]
     ) -> Optional[AnthropicMessagesToolChoice]:
@@ -446,49 +411,6 @@
         )
         return _tool
 
-    def is_cache_control_set(self, messages: List[AllMessageValues]) -> bool:
-        """
-        Return if {"cache_control": ..} in message content block
-
-        Used to check if anthropic prompt caching headers need to be set.
-        """
-        for message in messages:
-            if message.get("cache_control", None) is not None:
-                return True
-            _message_content = message.get("content")
-            if _message_content is not None and isinstance(_message_content, list):
-                for content in _message_content:
-                    if "cache_control" in content:
-                        return True
-
-        return False
-
-    def is_computer_tool_used(
-        self, tools: Optional[List[AllAnthropicToolsValues]]
-    ) -> bool:
-        if tools is None:
-            return False
-        for tool in tools:
-            if "type" in tool and tool["type"].startswith("computer_"):
-                return True
-        return False
-
-    def is_pdf_used(self, messages: List[AllMessageValues]) -> bool:
-        """
-        Set to true if media passed into messages.
-
-        """
-        for message in messages:
-            if (
-                "content" in message
-                and message["content"] is not None
-                and isinstance(message["content"], list)
-            ):
-                for content in message["content"]:
-                    if "type" in content and content["type"] != "text":
-                        return True
-        return False
-
     def translate_system_message(
         self, messages: List[AllMessageValues]
     ) -> List[AnthropicSystemMessageContent]:
@@ -862,47 +784,3 @@
             message=error_message,
             headers=cast(httpx.Headers, headers),
         )
-
-    def _get_user_anthropic_beta_headers(
-        self, anthropic_beta_header: Optional[str]
-    ) -> Optional[List[str]]:
-        if anthropic_beta_header is None:
-            return None
-        return anthropic_beta_header.split(",")
-
-    def validate_environment(
-        self,
-        headers: dict,
-        model: str,
-        messages: List[AllMessageValues],
-        optional_params: dict,
-        litellm_params: dict,
-        api_key: Optional[str] = None,
-        api_base: Optional[str] = None,
-    ) -> Dict:
-        if api_key is None:
-            raise litellm.AuthenticationError(
-                message="Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params. Please set `ANTHROPIC_API_KEY` in your environment vars",
-                llm_provider="anthropic",
-                model=model,
-            )
-
-        tools = optional_params.get("tools")
-        prompt_caching_set = self.is_cache_control_set(messages=messages)
-        computer_tool_used = self.is_computer_tool_used(tools=tools)
-        pdf_used = self.is_pdf_used(messages=messages)
-        user_anthropic_beta_headers = self._get_user_anthropic_beta_headers(
-            anthropic_beta_header=headers.get("anthropic-beta")
-        )
-        anthropic_headers = self.get_anthropic_headers(
-            computer_tool_used=computer_tool_used,
-            prompt_caching_set=prompt_caching_set,
-            pdf_used=pdf_used,
-            api_key=api_key,
-            is_vertex_request=optional_params.get("is_vertex_request", False),
-            user_anthropic_beta_headers=user_anthropic_beta_headers,
-        )
-
-        headers = {**headers, **anthropic_headers}
-
-        return headers
@@ -2,7 +2,7 @@
 This file contains common utils for anthropic calls.
 """
 
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import httpx
 
@@ -10,6 +10,8 @@ import litellm
 from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.anthropic import AllAnthropicToolsValues
+from litellm.types.llms.openai import AllMessageValues
 
 
 class AnthropicError(BaseLLMException):
@@ -23,6 +25,128 @@
 
 
 class AnthropicModelInfo(BaseLLMModelInfo):
+    def is_cache_control_set(self, messages: List[AllMessageValues]) -> bool:
+        """
+        Return if {"cache_control": ..} in message content block
+
+        Used to check if anthropic prompt caching headers need to be set.
+        """
+        for message in messages:
+            if message.get("cache_control", None) is not None:
+                return True
+            _message_content = message.get("content")
+            if _message_content is not None and isinstance(_message_content, list):
+                for content in _message_content:
+                    if "cache_control" in content:
+                        return True
+
+        return False
+
+    def is_computer_tool_used(
+        self, tools: Optional[List[AllAnthropicToolsValues]]
+    ) -> bool:
+        if tools is None:
+            return False
+        for tool in tools:
+            if "type" in tool and tool["type"].startswith("computer_"):
+                return True
+        return False
+
+    def is_pdf_used(self, messages: List[AllMessageValues]) -> bool:
+        """
+        Set to true if media passed into messages.
+
+        """
+        for message in messages:
+            if (
+                "content" in message
+                and message["content"] is not None
+                and isinstance(message["content"], list)
+            ):
+                for content in message["content"]:
+                    if "type" in content and content["type"] != "text":
+                        return True
+        return False
+
+    def _get_user_anthropic_beta_headers(
+        self, anthropic_beta_header: Optional[str]
+    ) -> Optional[List[str]]:
+        if anthropic_beta_header is None:
+            return None
+        return anthropic_beta_header.split(",")
+
+    def get_anthropic_headers(
+        self,
+        api_key: str,
+        anthropic_version: Optional[str] = None,
+        computer_tool_used: bool = False,
+        prompt_caching_set: bool = False,
+        pdf_used: bool = False,
+        is_vertex_request: bool = False,
+        user_anthropic_beta_headers: Optional[List[str]] = None,
+    ) -> dict:
+        betas = set()
+        if prompt_caching_set:
+            betas.add("prompt-caching-2024-07-31")
+        if computer_tool_used:
+            betas.add("computer-use-2024-10-22")
+        if pdf_used:
+            betas.add("pdfs-2024-09-25")
+        headers = {
+            "anthropic-version": anthropic_version or "2023-06-01",
+            "x-api-key": api_key,
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+
+        if user_anthropic_beta_headers is not None:
+            betas.update(user_anthropic_beta_headers)
+
+        # Don't send any beta headers to Vertex, Vertex has failed requests when they are sent
+        if is_vertex_request is True:
+            pass
+        elif len(betas) > 0:
+            headers["anthropic-beta"] = ",".join(betas)
+
+        return headers
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> Dict:
+        if api_key is None:
+            raise litellm.AuthenticationError(
+                message="Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params. Please set `ANTHROPIC_API_KEY` in your environment vars",
+                llm_provider="anthropic",
+                model=model,
+            )
+
+        tools = optional_params.get("tools")
+        prompt_caching_set = self.is_cache_control_set(messages=messages)
+        computer_tool_used = self.is_computer_tool_used(tools=tools)
+        pdf_used = self.is_pdf_used(messages=messages)
+        user_anthropic_beta_headers = self._get_user_anthropic_beta_headers(
+            anthropic_beta_header=headers.get("anthropic-beta")
+        )
+        anthropic_headers = self.get_anthropic_headers(
+            computer_tool_used=computer_tool_used,
+            prompt_caching_set=prompt_caching_set,
+            pdf_used=pdf_used,
+            api_key=api_key,
+            is_vertex_request=optional_params.get("is_vertex_request", False),
+            user_anthropic_beta_headers=user_anthropic_beta_headers,
+        )
+
+        headers = {**headers, **anthropic_headers}
+
+        return headers
+
     @staticmethod
     def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
         return (
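A brief, hedged sketch (not part of the commit) of what this refactor enables: AnthropicConfig now inherits these helpers from AnthropicModelInfo, so beta-header construction can be exercised directly. The API key below is a placeholder, and exposing AnthropicConfig at the top-level litellm namespace is an assumption.

    # Minimal sketch, assuming litellm exposes AnthropicConfig at the top level.
    import litellm

    config = litellm.AnthropicConfig()
    headers = config.get_anthropic_headers(
        api_key="sk-ant-placeholder",   # placeholder key
        prompt_caching_set=True,        # adds prompt-caching-2024-07-31
        pdf_used=True,                  # adds pdfs-2024-09-25
    )
    print(headers["anthropic-beta"])    # comma-joined betas (set order may vary)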
@@ -44,6 +44,19 @@ class BaseLLMModelInfo(ABC):
     def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
         pass
 
+    @abstractmethod
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        pass
+
     @staticmethod
     @abstractmethod
     def get_base_model(model: str) -> Optional[str]:
@@ -389,7 +389,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
         )
 
     @staticmethod
-    def get_base_model(model: str) -> str:
+    def get_base_model(model: Optional[str] = None) -> Optional[str]:
         return model
 
     def get_model_response_iterator(
@@ -1,6 +1,7 @@
 from typing import List, Optional
 
 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
 
 from ..base_llm.base_utils import BaseLLMModelInfo
 from ..base_llm.chat.transformation import BaseLLMException
@@ -11,6 +12,26 @@ class TopazException(BaseLLMException):
 
 
 class TopazModelInfo(BaseLLMModelInfo):
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            raise ValueError(
+                "API key is required for Topaz image variations. Set via `TOPAZ_API_KEY` or `api_key=..`"
+            )
+        return {
+            # "Content-Type": "multipart/form-data",
+            "Accept": "image/jpeg",
+            "X-API-Key": api_key,
+        }
+
     def get_models(
         self, api_key: Optional[str] = None, api_base: Optional[str] = None
     ) -> List[str]:
@@ -10,10 +10,7 @@ from litellm.llms.base_llm.chat.transformation import (
     BaseLLMException,
     LiteLLMLoggingObj,
 )
-from litellm.types.llms.openai import (
-    AllMessageValues,
-    OpenAIImageVariationOptionalParams,
-)
+from litellm.types.llms.openai import OpenAIImageVariationOptionalParams
 from litellm.types.utils import (
     FileTypes,
     HttpHandlerRequestFields,
@@ -22,35 +19,15 @@
 )
 
 from ...base_llm.image_variations.transformation import BaseImageVariationConfig
-from ..common_utils import TopazException
+from ..common_utils import TopazException, TopazModelInfo
 
 
-class TopazImageVariationConfig(BaseImageVariationConfig):
+class TopazImageVariationConfig(TopazModelInfo, BaseImageVariationConfig):
     def get_supported_openai_params(
         self, model: str
     ) -> List[OpenAIImageVariationOptionalParams]:
         return ["response_format", "size"]
 
-    def validate_environment(
-        self,
-        headers: dict,
-        model: str,
-        messages: List[AllMessageValues],
-        optional_params: dict,
-        litellm_params: dict,
-        api_key: Optional[str] = None,
-        api_base: Optional[str] = None,
-    ) -> dict:
-        if api_key is None:
-            raise ValueError(
-                "API key is required for Topaz image variations. Set via `TOPAZ_API_KEY` or `api_key=..`"
-            )
-        return {
-            # "Content-Type": "multipart/form-data",
-            "Accept": "image/jpeg",
-            "X-API-Key": api_key,
-        }
-
     def get_complete_url(
         self,
         api_base: Optional[str],
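A hedged sketch of what this inheritance change means in practice: the image-variation config now reuses TopazModelInfo's environment validation for header construction. The import path and model name below are assumptions inferred from this diff.

    # Sketch only; the module path mirrors the file locations implied by the diff.
    from litellm.llms.topaz.image_variations.transformation import TopazImageVariationConfig

    headers = TopazImageVariationConfig().validate_environment(
        headers={},
        model="Standard V2",            # placeholder Topaz model name
        messages=[],
        optional_params={},
        litellm_params={},
        api_key="tz-placeholder-key",   # placeholder key
    )
    # -> {"Accept": "image/jpeg", "X-API-Key": "tz-placeholder-key"}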
litellm/llms/vllm/common_utils.py (new file, 75 lines)

@@ -0,0 +1,75 @@
+from typing import List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
+from litellm.utils import _add_path_to_api_base
+
+
+class VLLMError(BaseLLMException):
+    pass
+
+
+class VLLMModelInfo(BaseLLMModelInfo):
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        """Google AI Studio sends api key in query params"""
+        return headers
+
+    @staticmethod
+    def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
+        api_base = api_base or get_secret_str("VLLM_API_BASE")
+        if api_base is None:
+            raise ValueError(
+                "VLLM_API_BASE is not set. Please set the environment variable, to use VLLM's pass-through - `{LITELLM_API_BASE}/vllm/{endpoint}`."
+            )
+        return api_base
+
+    @staticmethod
+    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
+        return None
+
+    @staticmethod
+    def get_base_model(model: str) -> Optional[str]:
+        return model
+
+    def get_models(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None
+    ) -> List[str]:
+        api_base = VLLMModelInfo.get_api_base(api_base)
+        api_key = VLLMModelInfo.get_api_key(api_key)
+        endpoint = "/v1/models"
+        if api_base is None or api_key is None:
+            raise ValueError(
+                "GEMINI_API_BASE or GEMINI_API_KEY is not set. Please set the environment variable, to query Gemini's `/models` endpoint."
+            )
+
+        url = _add_path_to_api_base(api_base, endpoint)
+        response = litellm.module_level_client.get(
+            url=url,
+        )
+
+        response.raise_for_status()
+
+        models = response.json()["data"]
+
+        return [model["id"] for model in models]
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return VLLMError(
+            status_code=status_code, message=error_message, headers=headers
+        )
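A hedged sketch of how the factory route in this PR resolves and uses the class above (it mirrors the factory function and the ProviderConfigManager change later in the diff; the vLLM server URL is an assumption):

    # Sketch only: resolve the vLLM provider config the same way the factory route does.
    import os

    os.environ["VLLM_API_BASE"] = "http://localhost:8000"    # assumed local vLLM server

    from litellm.types.utils import LlmProviders
    from litellm.utils import ProviderConfigManager

    provider_config = ProviderConfigManager.get_provider_model_info(
        provider=LlmProviders.VLLM, model=None
    )                                                        # -> VLLMModelInfo()
    base_target_url = provider_config.get_api_base()         # -> "http://localhost:8000"
    auth_headers = provider_config.validate_environment(
        headers={}, model="", messages=[], optional_params={}, litellm_params={}
    )                                                        # vLLM passthrough adds no auth header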
@@ -5,9 +5,29 @@ import httpx
 import litellm
 from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
 
 
 class XAIModelInfo(BaseLLMModelInfo):
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        if api_key is not None:
+            headers["Authorization"] = f"Bearer {api_key}"
+
+        # Ensure Content-Type is set to application/json
+        if "content-type" not in headers and "Content-Type" not in headers:
+            headers["Content-Type"] = "application/json"
+
+        return headers
+
     @staticmethod
     def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
         return api_base or get_secret_str("XAI_API_BASE") or "https://api.x.ai"
@@ -317,6 +317,8 @@ class LiteLLMRoutes(enum.Enum):
         "/openai",
         "/assemblyai",
         "/eu.assemblyai",
+        "/vllm",
+        "/mistral",
     ]
 
     anthropic_routes = [
@@ -6,6 +6,7 @@ Provider-specific Pass-Through Endpoints
 Use litellm with Anthropic SDK, Vertex AI SDK, Cohere SDK, etc.
 """
 
+import os
 from typing import Optional
 
 import httpx
@@ -43,6 +44,84 @@ def create_request_copy(request: Request):
     }
 
 
+async def llm_passthrough_factory_proxy_route(
+    custom_llm_provider: str,
+    endpoint: str,
+    request: Request,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    Factory function for creating pass-through endpoints for LLM providers.
+    """
+    from litellm.types.utils import LlmProviders
+    from litellm.utils import ProviderConfigManager
+
+    provider_config = ProviderConfigManager.get_provider_model_info(
+        provider=LlmProviders(custom_llm_provider),
+        model=None,
+    )
+    if provider_config is None:
+        raise HTTPException(
+            status_code=404, detail=f"Provider {custom_llm_provider} not found"
+        )
+    base_target_url = provider_config.get_api_base()
+
+    if base_target_url is None:
+        raise HTTPException(
+            status_code=404, detail=f"Provider {custom_llm_provider} api base not found"
+        )
+
+    encoded_endpoint = httpx.URL(endpoint).path
+
+    # Ensure endpoint starts with '/' for proper URL construction
+    if not encoded_endpoint.startswith("/"):
+        encoded_endpoint = "/" + encoded_endpoint
+
+    # Construct the full target URL using httpx
+    base_url = httpx.URL(base_target_url)
+    updated_url = base_url.copy_with(path=encoded_endpoint)
+
+    # Add or update query parameters
+    provider_api_key = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider=custom_llm_provider,
+        region_name=None,
+    )
+
+    auth_headers = provider_config.validate_environment(
+        headers={},
+        model="",
+        messages=[],
+        optional_params={},
+        litellm_params={},
+        api_key=provider_api_key,
+        api_base=base_target_url,
+    )
+
+    ## check for streaming
+    is_streaming_request = False
+    # anthropic is streaming when 'stream' = True is in the body
+    if request.method == "POST":
+        _request_body = await request.json()
+        if _request_body.get("stream"):
+            is_streaming_request = True
+
+    ## CREATE PASS-THROUGH
+    endpoint_func = create_pass_through_route(
+        endpoint=endpoint,
+        target=str(updated_url),
+        custom_headers=auth_headers,
+    )  # dynamically construct pass-through endpoint based on incoming path
+    received_value = await endpoint_func(
+        request,
+        fastapi_response,
+        user_api_key_dict,
+        stream=is_streaming_request,  # type: ignore
+    )
+
+    return received_value
+
+
 @router.api_route(
     "/gemini/{endpoint:path}",
     methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
@@ -162,6 +241,84 @@ async def cohere_proxy_route(
     return received_value
 
 
+@router.api_route(
+    "/vllm/{endpoint:path}",
+    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
+    tags=["VLLM Pass-through", "pass-through"],
+)
+async def vllm_proxy_route(
+    endpoint: str,
+    request: Request,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    [Docs](https://docs.litellm.ai/docs/pass_through/vllm)
+    """
+    return await llm_passthrough_factory_proxy_route(
+        endpoint=endpoint,
+        request=request,
+        fastapi_response=fastapi_response,
+        user_api_key_dict=user_api_key_dict,
+        custom_llm_provider="vllm",
+    )
+
+
+@router.api_route(
+    "/mistral/{endpoint:path}",
+    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
+    tags=["Mistral Pass-through", "pass-through"],
+)
+async def mistral_proxy_route(
+    endpoint: str,
+    request: Request,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    [Docs](https://docs.litellm.ai/docs/anthropic_completion)
+    """
+    base_target_url = os.getenv("MISTRAL_API_BASE") or "https://api.mistral.ai"
+    encoded_endpoint = httpx.URL(endpoint).path
+
+    # Ensure endpoint starts with '/' for proper URL construction
+    if not encoded_endpoint.startswith("/"):
+        encoded_endpoint = "/" + encoded_endpoint
+
+    # Construct the full target URL using httpx
+    base_url = httpx.URL(base_target_url)
+    updated_url = base_url.copy_with(path=encoded_endpoint)
+
+    # Add or update query parameters
+    mistral_api_key = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider="mistral",
+        region_name=None,
+    )
+
+    ## check for streaming
+    is_streaming_request = False
+    # anthropic is streaming when 'stream' = True is in the body
+    if request.method == "POST":
+        _request_body = await request.json()
+        if _request_body.get("stream"):
+            is_streaming_request = True
+
+    ## CREATE PASS-THROUGH
+    endpoint_func = create_pass_through_route(
+        endpoint=endpoint,
+        target=str(updated_url),
+        custom_headers={"Authorization": "Bearer {}".format(mistral_api_key)},
+    )  # dynamically construct pass-through endpoint based on incoming path
+    received_value = await endpoint_func(
+        request,
+        fastapi_response,
+        user_api_key_dict,
+        stream=is_streaming_request,  # type: ignore
+    )
+
+    return received_value
+
+
 @router.api_route(
     "/anthropic/{endpoint:path}",
     methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
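For completeness, a hedged sketch of how a client could exercise the streaming branch of the routes above (they flag a request as streaming whenever the JSON body contains "stream": true); the proxy URL, virtual key, and model name are assumptions.

    # Sketch only: stream chat completions from a vLLM server through the proxy.
    import httpx

    with httpx.stream(
        "POST",
        "http://localhost:4000/vllm/v1/chat/completions",   # assumed proxy address
        headers={"Authorization": "Bearer sk-1234"},         # assumed virtual key
        json={
            "model": "facebook/opt-125m",                    # whatever the vLLM server serves
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": True,
        },
        timeout=60.0,
    ) as response:
        for line in response.iter_lines():
            if line:
                print(line)   # SSE chunks forwarded unchanged from the backend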
@@ -516,9 +516,9 @@ def function_setup( # noqa: PLR0915
     function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
 
     ## DYNAMIC CALLBACKS ##
-    dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]] = (
-        kwargs.pop("callbacks", None)
-    )
+    dynamic_callbacks: Optional[
+        List[Union[str, Callable, CustomLogger]]
+    ] = kwargs.pop("callbacks", None)
     all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks)
 
     if len(all_callbacks) > 0:
@@ -1202,9 +1202,9 @@ def client(original_function): # noqa: PLR0915
                         exception=e,
                         retry_policy=kwargs.get("retry_policy"),
                     )
-                    kwargs["retry_policy"] = (
-                        reset_retry_policy()
-                    )  # prevent infinite loops
+                    kwargs[
+                        "retry_policy"
+                    ] = reset_retry_policy()  # prevent infinite loops
                     litellm.num_retries = (
                         None  # set retries to None to prevent infinite loops
                     )
@@ -3013,16 +3013,16 @@ def get_optional_params( # noqa: PLR0915
                 True  # so that main.py adds the function call to the prompt
             )
             if "tools" in non_default_params:
-                optional_params["functions_unsupported_model"] = (
-                    non_default_params.pop("tools")
-                )
+                optional_params[
+                    "functions_unsupported_model"
+                ] = non_default_params.pop("tools")
                 non_default_params.pop(
                     "tool_choice", None
                 )  # causes ollama requests to hang
             elif "functions" in non_default_params:
-                optional_params["functions_unsupported_model"] = (
-                    non_default_params.pop("functions")
-                )
+                optional_params[
+                    "functions_unsupported_model"
+                ] = non_default_params.pop("functions")
         elif (
             litellm.add_function_to_prompt
         ):  # if user opts to add it to prompt instead
@@ -3045,10 +3045,10 @@
 
     if "response_format" in non_default_params:
         if provider_config is not None:
-            non_default_params["response_format"] = (
-                provider_config.get_json_schema_from_pydantic_object(
-                    response_format=non_default_params["response_format"]
-                )
-            )
+            non_default_params[
+                "response_format"
+            ] = provider_config.get_json_schema_from_pydantic_object(
+                response_format=non_default_params["response_format"]
+            )
         else:
             non_default_params["response_format"] = type_to_response_format_param(
@@ -4064,9 +4064,9 @@ def _count_characters(text: str) -> int:
 
 
 def get_response_string(response_obj: Union[ModelResponse, ModelResponseStream]) -> str:
-    _choices: Union[List[Union[Choices, StreamingChoices]], List[StreamingChoices]] = (
-        response_obj.choices
-    )
+    _choices: Union[
+        List[Union[Choices, StreamingChoices]], List[StreamingChoices]
+    ] = response_obj.choices
 
     response_str = ""
     for choice in _choices:
@@ -4458,14 +4458,14 @@ def _get_model_info_helper( # noqa: PLR0915
 
         if combined_model_name in litellm.model_cost:
             key = combined_model_name
-            _model_info = _get_model_info_from_model_cost(key=key)
+            _model_info = _get_model_info_from_model_cost(key=cast(str, key))
             if not _check_provider_match(
                 model_info=_model_info, custom_llm_provider=custom_llm_provider
             ):
                 _model_info = None
         if _model_info is None and model in litellm.model_cost:
             key = model
-            _model_info = _get_model_info_from_model_cost(key=key)
+            _model_info = _get_model_info_from_model_cost(key=cast(str, key))
             if not _check_provider_match(
                 model_info=_model_info, custom_llm_provider=custom_llm_provider
             ):
@@ -4475,21 +4475,21 @@
             and combined_stripped_model_name in litellm.model_cost
         ):
             key = combined_stripped_model_name
-            _model_info = _get_model_info_from_model_cost(key=key)
+            _model_info = _get_model_info_from_model_cost(key=cast(str, key))
             if not _check_provider_match(
                 model_info=_model_info, custom_llm_provider=custom_llm_provider
             ):
                 _model_info = None
         if _model_info is None and stripped_model_name in litellm.model_cost:
             key = stripped_model_name
-            _model_info = _get_model_info_from_model_cost(key=key)
+            _model_info = _get_model_info_from_model_cost(key=cast(str, key))
             if not _check_provider_match(
                 model_info=_model_info, custom_llm_provider=custom_llm_provider
             ):
                 _model_info = None
         if _model_info is None and split_model in litellm.model_cost:
             key = split_model
-            _model_info = _get_model_info_from_model_cost(key=key)
+            _model_info = _get_model_info_from_model_cost(key=cast(str, key))
             if not _check_provider_match(
                 model_info=_model_info, custom_llm_provider=custom_llm_provider
             ):
@@ -6510,7 +6510,12 @@ class ProviderConfigManager:
             return litellm.AnthropicModelInfo()
         elif LlmProviders.XAI == provider:
             return litellm.XAIModelInfo()
+        elif LlmProviders.VLLM == provider:
+            from litellm.llms.vllm.common_utils import (
+                VLLMModelInfo,  # experimental approach, to reduce bloat on __init__.py
+            )
+
+            return VLLMModelInfo()
         return None
 
     @staticmethod