Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)
Add /vllm/* and /mistral/* passthrough endpoints (adds support for Mistral OCR via passthrough)

* feat(llm_passthrough_endpoints.py): support mistral passthrough - Closes https://github.com/BerriAI/litellm/issues/9051
* feat(llm_passthrough_endpoints.py): initial commit for adding vllm passthrough route
* feat(vllm/common_utils.py): add new vllm model info route - make it possible to use vllm passthrough route via factory function
* fix(llm_passthrough_endpoints.py): add all methods to vllm passthrough route
* fix: fix linting error
* fix: fix linting error
* fix: fix ruff check
* fix(proxy/_types.py): add new passthrough routes
* docs(config_settings.md): add mistral env vars to docs

This commit is contained in: parent 5fcdf4becf, commit 3031fff297
12 changed files with 450 additions and 176 deletions
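For context, a minimal sketch of how a client might exercise the new pass-through routes once the proxy is running. This is not part of the commit; the proxy address, virtual key, and the Mistral `/v1/ocr` path and payload shape are assumptions for illustration only.

import httpx

PROXY_BASE = "http://localhost:4000"   # assumed LiteLLM proxy address
LITELLM_KEY = "sk-1234"                # assumed LiteLLM virtual key

# Mistral OCR via the new /mistral/* pass-through route; the body is
# forwarded as-is to MISTRAL_API_BASE (payload shape assumed from Mistral docs).
ocr_response = httpx.post(
    f"{PROXY_BASE}/mistral/v1/ocr",
    headers={"Authorization": f"Bearer {LITELLM_KEY}"},
    json={
        "model": "mistral-ocr-latest",
        "document": {"type": "document_url", "document_url": "https://example.com/sample.pdf"},
    },
    timeout=60,
)
print(ocr_response.json())

# List models on the vLLM server via the new /vllm/* pass-through route.
models = httpx.get(
    f"{PROXY_BASE}/vllm/v1/models",
    headers={"Authorization": f"Bearer {LITELLM_KEY}"},
)
print(models.json())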
@@ -449,6 +449,8 @@ router_settings:
 | LITELLM_TOKEN | Access token for LiteLLM integration |
 | LITELLM_PRINT_STANDARD_LOGGING_PAYLOAD | If true, prints the standard logging payload to the console - useful for debugging |
 | LOGFIRE_TOKEN | Token for Logfire logging service |
+| MISTRAL_API_BASE | Base URL for Mistral API |
+| MISTRAL_API_KEY | API key for Mistral API |
 | MICROSOFT_CLIENT_ID | Client ID for Microsoft services |
 | MICROSOFT_CLIENT_SECRET | Client secret for Microsoft services |
 | MICROSOFT_TENANT | Tenant ID for Microsoft Azure |
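A hedged sketch of wiring these two variables up before starting the proxy; the values are placeholders, and https://api.mistral.ai matches the default the mistral route below falls back to.

import os

os.environ["MISTRAL_API_BASE"] = "https://api.mistral.ai"  # Base URL for Mistral API
os.environ["MISTRAL_API_KEY"] = "sk-mistral-..."           # placeholder API key for Mistral API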
@@ -44,7 +44,7 @@ from litellm.utils import (
     token_counter,
 )

-from ..common_utils import AnthropicError, process_anthropic_headers
+from ..common_utils import AnthropicError, AnthropicModelInfo, process_anthropic_headers

 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
@@ -54,7 +54,7 @@ else:
     LoggingClass = Any


-class AnthropicConfig(BaseConfig):
+class AnthropicConfig(AnthropicModelInfo, BaseConfig):
     """
     Reference: https://docs.anthropic.com/claude/reference/messages_post

@@ -127,41 +127,6 @@ class AnthropicConfig(BaseConfig):
             "anthropic-beta": "prompt-caching-2024-07-31",
         }

-    def get_anthropic_headers(
-        self,
-        api_key: str,
-        anthropic_version: Optional[str] = None,
-        computer_tool_used: bool = False,
-        prompt_caching_set: bool = False,
-        pdf_used: bool = False,
-        is_vertex_request: bool = False,
-        user_anthropic_beta_headers: Optional[List[str]] = None,
-    ) -> dict:
-        betas = set()
-        if prompt_caching_set:
-            betas.add("prompt-caching-2024-07-31")
-        if computer_tool_used:
-            betas.add("computer-use-2024-10-22")
-        if pdf_used:
-            betas.add("pdfs-2024-09-25")
-        headers = {
-            "anthropic-version": anthropic_version or "2023-06-01",
-            "x-api-key": api_key,
-            "accept": "application/json",
-            "content-type": "application/json",
-        }
-
-        if user_anthropic_beta_headers is not None:
-            betas.update(user_anthropic_beta_headers)
-
-        # Don't send any beta headers to Vertex, Vertex has failed requests when they are sent
-        if is_vertex_request is True:
-            pass
-        elif len(betas) > 0:
-            headers["anthropic-beta"] = ",".join(betas)
-
-        return headers
-
     def _map_tool_choice(
         self, tool_choice: Optional[str], parallel_tool_use: Optional[bool]
     ) -> Optional[AnthropicMessagesToolChoice]:
@@ -446,49 +411,6 @@ class AnthropicConfig(BaseConfig):
             )
         return _tool

-    def is_cache_control_set(self, messages: List[AllMessageValues]) -> bool:
-        """
-        Return if {"cache_control": ..} in message content block
-
-        Used to check if anthropic prompt caching headers need to be set.
-        """
-        for message in messages:
-            if message.get("cache_control", None) is not None:
-                return True
-            _message_content = message.get("content")
-            if _message_content is not None and isinstance(_message_content, list):
-                for content in _message_content:
-                    if "cache_control" in content:
-                        return True
-
-        return False
-
-    def is_computer_tool_used(
-        self, tools: Optional[List[AllAnthropicToolsValues]]
-    ) -> bool:
-        if tools is None:
-            return False
-        for tool in tools:
-            if "type" in tool and tool["type"].startswith("computer_"):
-                return True
-        return False
-
-    def is_pdf_used(self, messages: List[AllMessageValues]) -> bool:
-        """
-        Set to true if media passed into messages.
-
-        """
-        for message in messages:
-            if (
-                "content" in message
-                and message["content"] is not None
-                and isinstance(message["content"], list)
-            ):
-                for content in message["content"]:
-                    if "type" in content and content["type"] != "text":
-                        return True
-        return False
-
     def translate_system_message(
         self, messages: List[AllMessageValues]
     ) -> List[AnthropicSystemMessageContent]:
@@ -862,47 +784,3 @@ class AnthropicConfig(BaseConfig):
             message=error_message,
             headers=cast(httpx.Headers, headers),
         )
-
-    def _get_user_anthropic_beta_headers(
-        self, anthropic_beta_header: Optional[str]
-    ) -> Optional[List[str]]:
-        if anthropic_beta_header is None:
-            return None
-        return anthropic_beta_header.split(",")
-
-    def validate_environment(
-        self,
-        headers: dict,
-        model: str,
-        messages: List[AllMessageValues],
-        optional_params: dict,
-        litellm_params: dict,
-        api_key: Optional[str] = None,
-        api_base: Optional[str] = None,
-    ) -> Dict:
-        if api_key is None:
-            raise litellm.AuthenticationError(
-                message="Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params. Please set `ANTHROPIC_API_KEY` in your environment vars",
-                llm_provider="anthropic",
-                model=model,
-            )
-
-        tools = optional_params.get("tools")
-        prompt_caching_set = self.is_cache_control_set(messages=messages)
-        computer_tool_used = self.is_computer_tool_used(tools=tools)
-        pdf_used = self.is_pdf_used(messages=messages)
-        user_anthropic_beta_headers = self._get_user_anthropic_beta_headers(
-            anthropic_beta_header=headers.get("anthropic-beta")
-        )
-        anthropic_headers = self.get_anthropic_headers(
-            computer_tool_used=computer_tool_used,
-            prompt_caching_set=prompt_caching_set,
-            pdf_used=pdf_used,
-            api_key=api_key,
-            is_vertex_request=optional_params.get("is_vertex_request", False),
-            user_anthropic_beta_headers=user_anthropic_beta_headers,
-        )
-
-        headers = {**headers, **anthropic_headers}
-
-        return headers
@@ -2,7 +2,7 @@
 This file contains common utils for anthropic calls.
 """

-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union

 import httpx

@@ -10,6 +10,8 @@ import litellm
 from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.anthropic import AllAnthropicToolsValues
+from litellm.types.llms.openai import AllMessageValues


 class AnthropicError(BaseLLMException):
@@ -23,6 +25,128 @@ class AnthropicError(BaseLLMException):


 class AnthropicModelInfo(BaseLLMModelInfo):
+    def is_cache_control_set(self, messages: List[AllMessageValues]) -> bool:
+        """
+        Return if {"cache_control": ..} in message content block
+
+        Used to check if anthropic prompt caching headers need to be set.
+        """
+        for message in messages:
+            if message.get("cache_control", None) is not None:
+                return True
+            _message_content = message.get("content")
+            if _message_content is not None and isinstance(_message_content, list):
+                for content in _message_content:
+                    if "cache_control" in content:
+                        return True
+
+        return False
+
+    def is_computer_tool_used(
+        self, tools: Optional[List[AllAnthropicToolsValues]]
+    ) -> bool:
+        if tools is None:
+            return False
+        for tool in tools:
+            if "type" in tool and tool["type"].startswith("computer_"):
+                return True
+        return False
+
+    def is_pdf_used(self, messages: List[AllMessageValues]) -> bool:
+        """
+        Set to true if media passed into messages.
+
+        """
+        for message in messages:
+            if (
+                "content" in message
+                and message["content"] is not None
+                and isinstance(message["content"], list)
+            ):
+                for content in message["content"]:
+                    if "type" in content and content["type"] != "text":
+                        return True
+        return False
+
+    def _get_user_anthropic_beta_headers(
+        self, anthropic_beta_header: Optional[str]
+    ) -> Optional[List[str]]:
+        if anthropic_beta_header is None:
+            return None
+        return anthropic_beta_header.split(",")
+
+    def get_anthropic_headers(
+        self,
+        api_key: str,
+        anthropic_version: Optional[str] = None,
+        computer_tool_used: bool = False,
+        prompt_caching_set: bool = False,
+        pdf_used: bool = False,
+        is_vertex_request: bool = False,
+        user_anthropic_beta_headers: Optional[List[str]] = None,
+    ) -> dict:
+        betas = set()
+        if prompt_caching_set:
+            betas.add("prompt-caching-2024-07-31")
+        if computer_tool_used:
+            betas.add("computer-use-2024-10-22")
+        if pdf_used:
+            betas.add("pdfs-2024-09-25")
+        headers = {
+            "anthropic-version": anthropic_version or "2023-06-01",
+            "x-api-key": api_key,
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+
+        if user_anthropic_beta_headers is not None:
+            betas.update(user_anthropic_beta_headers)
+
+        # Don't send any beta headers to Vertex, Vertex has failed requests when they are sent
+        if is_vertex_request is True:
+            pass
+        elif len(betas) > 0:
+            headers["anthropic-beta"] = ",".join(betas)
+
+        return headers
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> Dict:
+        if api_key is None:
+            raise litellm.AuthenticationError(
+                message="Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params. Please set `ANTHROPIC_API_KEY` in your environment vars",
+                llm_provider="anthropic",
+                model=model,
+            )
+
+        tools = optional_params.get("tools")
+        prompt_caching_set = self.is_cache_control_set(messages=messages)
+        computer_tool_used = self.is_computer_tool_used(tools=tools)
+        pdf_used = self.is_pdf_used(messages=messages)
+        user_anthropic_beta_headers = self._get_user_anthropic_beta_headers(
+            anthropic_beta_header=headers.get("anthropic-beta")
+        )
+        anthropic_headers = self.get_anthropic_headers(
+            computer_tool_used=computer_tool_used,
+            prompt_caching_set=prompt_caching_set,
+            pdf_used=pdf_used,
+            api_key=api_key,
+            is_vertex_request=optional_params.get("is_vertex_request", False),
+            user_anthropic_beta_headers=user_anthropic_beta_headers,
+        )
+
+        headers = {**headers, **anthropic_headers}
+
+        return headers
+
     @staticmethod
     def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
         return (
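As a quick illustration of what the relocated header logic produces, a hedged sketch; it assumes the module lives at litellm.llms.anthropic.common_utils (as the relative imports above suggest) and uses a placeholder key.

from litellm.llms.anthropic.common_utils import AnthropicModelInfo

info = AnthropicModelInfo()
headers = info.get_anthropic_headers(
    api_key="sk-ant-placeholder",   # placeholder key
    prompt_caching_set=True,        # adds the prompt-caching beta
    pdf_used=True,                  # adds the pdfs beta
)
# headers now carries x-api-key, anthropic-version, and a comma-joined
# "anthropic-beta" value such as "prompt-caching-2024-07-31,pdfs-2024-09-25"
print(headers["anthropic-beta"])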
@@ -44,6 +44,19 @@ class BaseLLMModelInfo(ABC):
     def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
         pass

+    @abstractmethod
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        pass
+
     @staticmethod
     @abstractmethod
     def get_base_model(model: str) -> Optional[str]:
@@ -389,7 +389,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
         )

     @staticmethod
-    def get_base_model(model: str) -> str:
+    def get_base_model(model: Optional[str] = None) -> Optional[str]:
         return model

     def get_model_response_iterator(
@@ -1,6 +1,7 @@
 from typing import List, Optional

 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues

 from ..base_llm.base_utils import BaseLLMModelInfo
 from ..base_llm.chat.transformation import BaseLLMException
@@ -11,6 +12,26 @@ class TopazException(BaseLLMException):


 class TopazModelInfo(BaseLLMModelInfo):
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            raise ValueError(
+                "API key is required for Topaz image variations. Set via `TOPAZ_API_KEY` or `api_key=..`"
+            )
+        return {
+            # "Content-Type": "multipart/form-data",
+            "Accept": "image/jpeg",
+            "X-API-Key": api_key,
+        }
+
     def get_models(
         self, api_key: Optional[str] = None, api_base: Optional[str] = None
     ) -> List[str]:
@@ -10,10 +10,7 @@ from litellm.llms.base_llm.chat.transformation import (
     BaseLLMException,
     LiteLLMLoggingObj,
 )
-from litellm.types.llms.openai import (
-    AllMessageValues,
-    OpenAIImageVariationOptionalParams,
-)
+from litellm.types.llms.openai import OpenAIImageVariationOptionalParams
 from litellm.types.utils import (
     FileTypes,
     HttpHandlerRequestFields,
@@ -22,35 +19,15 @@ from litellm.types.utils import (
 )

 from ...base_llm.image_variations.transformation import BaseImageVariationConfig
-from ..common_utils import TopazException
+from ..common_utils import TopazException, TopazModelInfo


-class TopazImageVariationConfig(BaseImageVariationConfig):
+class TopazImageVariationConfig(TopazModelInfo, BaseImageVariationConfig):
     def get_supported_openai_params(
         self, model: str
     ) -> List[OpenAIImageVariationOptionalParams]:
         return ["response_format", "size"]

-    def validate_environment(
-        self,
-        headers: dict,
-        model: str,
-        messages: List[AllMessageValues],
-        optional_params: dict,
-        litellm_params: dict,
-        api_key: Optional[str] = None,
-        api_base: Optional[str] = None,
-    ) -> dict:
-        if api_key is None:
-            raise ValueError(
-                "API key is required for Topaz image variations. Set via `TOPAZ_API_KEY` or `api_key=..`"
-            )
-        return {
-            # "Content-Type": "multipart/form-data",
-            "Accept": "image/jpeg",
-            "X-API-Key": api_key,
-        }
-
     def get_complete_url(
         self,
         api_base: Optional[str],
litellm/llms/vllm/common_utils.py (new file, 75 lines)
@@ -0,0 +1,75 @@
+from typing import List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
+from litellm.utils import _add_path_to_api_base
+
+
+class VLLMError(BaseLLMException):
+    pass
+
+
+class VLLMModelInfo(BaseLLMModelInfo):
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        """Google AI Studio sends api key in query params"""
+        return headers
+
+    @staticmethod
+    def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
+        api_base = api_base or get_secret_str("VLLM_API_BASE")
+        if api_base is None:
+            raise ValueError(
+                "VLLM_API_BASE is not set. Please set the environment variable, to use VLLM's pass-through - `{LITELLM_API_BASE}/vllm/{endpoint}`."
+            )
+        return api_base
+
+    @staticmethod
+    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
+        return None
+
+    @staticmethod
+    def get_base_model(model: str) -> Optional[str]:
+        return model
+
+    def get_models(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None
+    ) -> List[str]:
+        api_base = VLLMModelInfo.get_api_base(api_base)
+        api_key = VLLMModelInfo.get_api_key(api_key)
+        endpoint = "/v1/models"
+        if api_base is None or api_key is None:
+            raise ValueError(
+                "GEMINI_API_BASE or GEMINI_API_KEY is not set. Please set the environment variable, to query Gemini's `/models` endpoint."
+            )
+
+        url = _add_path_to_api_base(api_base, endpoint)
+        response = litellm.module_level_client.get(
+            url=url,
+        )
+
+        response.raise_for_status()
+
+        models = response.json()["data"]
+
+        return [model["id"] for model in models]
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return VLLMError(
+            status_code=status_code, message=error_message, headers=headers
+        )
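A small, hedged usage sketch of the new helper above; it assumes a vLLM OpenAI-compatible server is reachable at the address exported in VLLM_API_BASE.

import os

os.environ["VLLM_API_BASE"] = "http://localhost:8000"  # assumed local vLLM server

from litellm.llms.vllm.common_utils import VLLMModelInfo

vllm_info = VLLMModelInfo()
# Resolves VLLM_API_BASE, calls {base}/v1/models, and returns the served model ids.
print(vllm_info.get_models())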
@@ -5,9 +5,29 @@ import httpx
 import litellm
 from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues


 class XAIModelInfo(BaseLLMModelInfo):
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        if api_key is not None:
+            headers["Authorization"] = f"Bearer {api_key}"
+
+        # Ensure Content-Type is set to application/json
+        if "content-type" not in headers and "Content-Type" not in headers:
+            headers["Content-Type"] = "application/json"
+
+        return headers
+
     @staticmethod
     def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
         return api_base or get_secret_str("XAI_API_BASE") or "https://api.x.ai"
@@ -317,6 +317,8 @@ class LiteLLMRoutes(enum.Enum):
         "/openai",
         "/assemblyai",
         "/eu.assemblyai",
+        "/vllm",
+        "/mistral",
     ]

     anthropic_routes = [
@@ -6,6 +6,7 @@ Provider-specific Pass-Through Endpoints
 Use litellm with Anthropic SDK, Vertex AI SDK, Cohere SDK, etc.
 """

+import os
 from typing import Optional

 import httpx
@@ -43,6 +44,84 @@ def create_request_copy(request: Request):
     }


+async def llm_passthrough_factory_proxy_route(
+    custom_llm_provider: str,
+    endpoint: str,
+    request: Request,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    Factory function for creating pass-through endpoints for LLM providers.
+    """
+    from litellm.types.utils import LlmProviders
+    from litellm.utils import ProviderConfigManager
+
+    provider_config = ProviderConfigManager.get_provider_model_info(
+        provider=LlmProviders(custom_llm_provider),
+        model=None,
+    )
+    if provider_config is None:
+        raise HTTPException(
+            status_code=404, detail=f"Provider {custom_llm_provider} not found"
+        )
+    base_target_url = provider_config.get_api_base()
+
+    if base_target_url is None:
+        raise HTTPException(
+            status_code=404, detail=f"Provider {custom_llm_provider} api base not found"
+        )
+
+    encoded_endpoint = httpx.URL(endpoint).path
+
+    # Ensure endpoint starts with '/' for proper URL construction
+    if not encoded_endpoint.startswith("/"):
+        encoded_endpoint = "/" + encoded_endpoint
+
+    # Construct the full target URL using httpx
+    base_url = httpx.URL(base_target_url)
+    updated_url = base_url.copy_with(path=encoded_endpoint)
+
+    # Add or update query parameters
+    provider_api_key = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider=custom_llm_provider,
+        region_name=None,
+    )
+
+    auth_headers = provider_config.validate_environment(
+        headers={},
+        model="",
+        messages=[],
+        optional_params={},
+        litellm_params={},
+        api_key=provider_api_key,
+        api_base=base_target_url,
+    )
+
+    ## check for streaming
+    is_streaming_request = False
+    # anthropic is streaming when 'stream' = True is in the body
+    if request.method == "POST":
+        _request_body = await request.json()
+        if _request_body.get("stream"):
+            is_streaming_request = True
+
+    ## CREATE PASS-THROUGH
+    endpoint_func = create_pass_through_route(
+        endpoint=endpoint,
+        target=str(updated_url),
+        custom_headers=auth_headers,
+    )  # dynamically construct pass-through endpoint based on incoming path
+    received_value = await endpoint_func(
+        request,
+        fastapi_response,
+        user_api_key_dict,
+        stream=is_streaming_request,  # type: ignore
+    )
+
+    return received_value
+
+
 @router.api_route(
     "/gemini/{endpoint:path}",
     methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
@@ -162,6 +241,84 @@ async def cohere_proxy_route(
     return received_value


+@router.api_route(
+    "/vllm/{endpoint:path}",
+    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
+    tags=["VLLM Pass-through", "pass-through"],
+)
+async def vllm_proxy_route(
+    endpoint: str,
+    request: Request,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    [Docs](https://docs.litellm.ai/docs/pass_through/vllm)
+    """
+    return await llm_passthrough_factory_proxy_route(
+        endpoint=endpoint,
+        request=request,
+        fastapi_response=fastapi_response,
+        user_api_key_dict=user_api_key_dict,
+        custom_llm_provider="vllm",
+    )
+
+
+@router.api_route(
+    "/mistral/{endpoint:path}",
+    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
+    tags=["Mistral Pass-through", "pass-through"],
+)
+async def mistral_proxy_route(
+    endpoint: str,
+    request: Request,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    [Docs](https://docs.litellm.ai/docs/anthropic_completion)
+    """
+    base_target_url = os.getenv("MISTRAL_API_BASE") or "https://api.mistral.ai"
+    encoded_endpoint = httpx.URL(endpoint).path
+
+    # Ensure endpoint starts with '/' for proper URL construction
+    if not encoded_endpoint.startswith("/"):
+        encoded_endpoint = "/" + encoded_endpoint
+
+    # Construct the full target URL using httpx
+    base_url = httpx.URL(base_target_url)
+    updated_url = base_url.copy_with(path=encoded_endpoint)
+
+    # Add or update query parameters
+    mistral_api_key = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider="mistral",
+        region_name=None,
+    )
+
+    ## check for streaming
+    is_streaming_request = False
+    # anthropic is streaming when 'stream' = True is in the body
+    if request.method == "POST":
+        _request_body = await request.json()
+        if _request_body.get("stream"):
+            is_streaming_request = True
+
+    ## CREATE PASS-THROUGH
+    endpoint_func = create_pass_through_route(
+        endpoint=endpoint,
+        target=str(updated_url),
+        custom_headers={"Authorization": "Bearer {}".format(mistral_api_key)},
+    )  # dynamically construct pass-through endpoint based on incoming path
+    received_value = await endpoint_func(
+        request,
+        fastapi_response,
+        user_api_key_dict,
+        stream=is_streaming_request,  # type: ignore
+    )
+
+    return received_value
+
+
 @router.api_route(
     "/anthropic/{endpoint:path}",
     methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
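To make the streaming-detection branch concrete, a hedged client-side sketch that sends a streaming chat request through the new /vllm route; the proxy address, key, and model name are placeholders.

import httpx

PROXY_BASE = "http://localhost:4000"   # placeholder LiteLLM proxy address
LITELLM_KEY = "sk-1234"                # placeholder virtual key

# "stream": True in the body is what flips is_streaming_request in the routes above.
with httpx.stream(
    "POST",
    f"{PROXY_BASE}/vllm/v1/chat/completions",
    headers={"Authorization": f"Bearer {LITELLM_KEY}"},
    json={
        "model": "my-vllm-model",      # whatever model the vLLM server is serving
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    },
    timeout=None,
) as response:
    for line in response.iter_lines():
        if line:
            print(line)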
@@ -516,9 +516,9 @@ def function_setup(  # noqa: PLR0915
     function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None

     ## DYNAMIC CALLBACKS ##
-    dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]] = (
-        kwargs.pop("callbacks", None)
-    )
+    dynamic_callbacks: Optional[
+        List[Union[str, Callable, CustomLogger]]
+    ] = kwargs.pop("callbacks", None)
     all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks)

     if len(all_callbacks) > 0:
@@ -1202,9 +1202,9 @@ def client(original_function):  # noqa: PLR0915
                        exception=e,
                        retry_policy=kwargs.get("retry_policy"),
                    )
-                    kwargs["retry_policy"] = (
-                        reset_retry_policy()
-                    )  # prevent infinite loops
+                    kwargs[
+                        "retry_policy"
+                    ] = reset_retry_policy()  # prevent infinite loops
                    litellm.num_retries = (
                        None  # set retries to None to prevent infinite loops
                    )
@@ -3013,16 +3013,16 @@ def get_optional_params(  # noqa: PLR0915
                 True  # so that main.py adds the function call to the prompt
             )
         if "tools" in non_default_params:
-            optional_params["functions_unsupported_model"] = (
-                non_default_params.pop("tools")
-            )
+            optional_params[
+                "functions_unsupported_model"
+            ] = non_default_params.pop("tools")
             non_default_params.pop(
                 "tool_choice", None
             )  # causes ollama requests to hang
         elif "functions" in non_default_params:
-            optional_params["functions_unsupported_model"] = (
-                non_default_params.pop("functions")
-            )
+            optional_params[
+                "functions_unsupported_model"
+            ] = non_default_params.pop("functions")
         elif (
             litellm.add_function_to_prompt
         ):  # if user opts to add it to prompt instead
@@ -3045,11 +3045,11 @@ def get_optional_params(  # noqa: PLR0915

     if "response_format" in non_default_params:
         if provider_config is not None:
-            non_default_params["response_format"] = (
-                provider_config.get_json_schema_from_pydantic_object(
-                    response_format=non_default_params["response_format"]
-                )
-            )
+            non_default_params[
+                "response_format"
+            ] = provider_config.get_json_schema_from_pydantic_object(
+                response_format=non_default_params["response_format"]
+            )
         else:
             non_default_params["response_format"] = type_to_response_format_param(
                 response_format=non_default_params["response_format"]
@@ -4064,9 +4064,9 @@ def _count_characters(text: str) -> int:


 def get_response_string(response_obj: Union[ModelResponse, ModelResponseStream]) -> str:
-    _choices: Union[List[Union[Choices, StreamingChoices]], List[StreamingChoices]] = (
-        response_obj.choices
-    )
+    _choices: Union[
+        List[Union[Choices, StreamingChoices]], List[StreamingChoices]
+    ] = response_obj.choices

     response_str = ""
     for choice in _choices:
@@ -4458,14 +4458,14 @@ def _get_model_info_helper(  # noqa: PLR0915

     if combined_model_name in litellm.model_cost:
         key = combined_model_name
-        _model_info = _get_model_info_from_model_cost(key=key)
+        _model_info = _get_model_info_from_model_cost(key=cast(str, key))
         if not _check_provider_match(
             model_info=_model_info, custom_llm_provider=custom_llm_provider
         ):
             _model_info = None
     if _model_info is None and model in litellm.model_cost:
         key = model
-        _model_info = _get_model_info_from_model_cost(key=key)
+        _model_info = _get_model_info_from_model_cost(key=cast(str, key))
         if not _check_provider_match(
             model_info=_model_info, custom_llm_provider=custom_llm_provider
         ):
@@ -4475,21 +4475,21 @@ def _get_model_info_helper(  # noqa: PLR0915
         and combined_stripped_model_name in litellm.model_cost
     ):
         key = combined_stripped_model_name
-        _model_info = _get_model_info_from_model_cost(key=key)
+        _model_info = _get_model_info_from_model_cost(key=cast(str, key))
         if not _check_provider_match(
             model_info=_model_info, custom_llm_provider=custom_llm_provider
         ):
             _model_info = None
     if _model_info is None and stripped_model_name in litellm.model_cost:
         key = stripped_model_name
-        _model_info = _get_model_info_from_model_cost(key=key)
+        _model_info = _get_model_info_from_model_cost(key=cast(str, key))
         if not _check_provider_match(
             model_info=_model_info, custom_llm_provider=custom_llm_provider
         ):
             _model_info = None
     if _model_info is None and split_model in litellm.model_cost:
         key = split_model
-        _model_info = _get_model_info_from_model_cost(key=key)
+        _model_info = _get_model_info_from_model_cost(key=cast(str, key))
         if not _check_provider_match(
             model_info=_model_info, custom_llm_provider=custom_llm_provider
         ):
@@ -6510,7 +6510,12 @@ class ProviderConfigManager:
             return litellm.AnthropicModelInfo()
         elif LlmProviders.XAI == provider:
             return litellm.XAIModelInfo()
+        elif LlmProviders.VLLM == provider:
+            from litellm.llms.vllm.common_utils import (
+                VLLMModelInfo,  # experimental approach, to reduce bloat on __init__.py
+            )
+
+            return VLLMModelInfo()
         return None

     @staticmethod
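A hedged sketch of how the new ProviderConfigManager branch is meant to be exercised (this mirrors what the pass-through factory route does internally); VLLM_API_BASE is assumed to point at a reachable vLLM server.

import os

os.environ["VLLM_API_BASE"] = "http://localhost:8000"  # assumed vLLM server address

from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager

# Resolve the provider-level model-info config for vLLM, as the factory route does.
provider_config = ProviderConfigManager.get_provider_model_info(
    provider=LlmProviders("vllm"), model=None
)
if provider_config is not None:
    print(provider_config.get_api_base())  # -> http://localhost:8000
    print(provider_config.get_models())    # ids served by the vLLM server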