Litellm remove circular imports (#7232)

* fix(utils.py): initial commit to remove circular imports - moves `LlmProviders` to `litellm.types.utils`

* fix(router.py): fix 'litellm.EmbeddingResponse' import from router.py


* refactor: fix litellm.ModelResponse import on pass through endpoints

* refactor(litellm_logging.py): fix circular import for custom callbacks literal

* fix(factory.py): fix circular imports inside prompt factory

* fix(cost_calculator.py): fix circular import for 'litellm.Usage'

* fix(proxy_server.py): fix potential circular import with `litellm.Router`

* fix(proxy/utils.py): fix potential circular import in `litellm.Router`

* fix: remove circular imports in 'auth_checks' and 'guardrails/'

* fix(prompt_injection_detection.py): fix router import

* fix(vertex_passthrough_logging_handler.py): fix potential circular imports in vertex pass through

* fix(anthropic_pass_through_logging_handler.py): fix potential circular imports

* fix(slack_alerting.py-+-ollama_chat.py): fix ModelResponse import

* fix(base.py): fix potential circular import

* fix(handler.py): fix potential circular ref in codestral + cohere handlers

* fix(azure.py): fix potential circular imports

* fix(gpt_transformation.py): fix ModelResponse import

* fix(litellm_logging.py): add logging base class - simplify typing

Makes it easy for other files to type-check the logging obj without introducing circular imports (a sketch of the pattern follows this list).

* fix(azure_ai/embed): fix potential circular import on handler.py

* fix(databricks/): fix potential circular imports in databricks/

* fix(vertex_ai/): fix potential circular imports on vertex ai embeddings

* fix(vertex_ai/image_gen): fix import

* fix(watsonx-+-bedrock): cleanup imports

* refactor(anthropic-pass-through-+-petals): cleanup imports

* refactor(huggingface/): cleanup imports

* fix(ollama-+-clarifai): cleanup circular imports

* fix(openai_like/): fix import

* fix(openai_like/): fix embedding handler

cleanup imports

* refactor(openai.py): cleanup imports

* fix(sagemaker/transformation.py): fix import

* ci(config.yml): add circular import test to CI/CD
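Many of the changes in this commit follow two patterns: response/usage types are imported directly from `litellm.types.utils` instead of being resolved off the `litellm` package in annotations, and the `Logging` object is imported only under `typing.TYPE_CHECKING`, so the annotation falls back to `Any` at runtime and the module never triggers the import cycle. A minimal sketch of the deferred-import pattern as it appears in the transformation configs below (the `transform_response` function here is illustrative, not an actual handler):

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Only evaluated by type checkers, never at runtime, so this import
    # cannot participate in an import cycle.
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    # The name still has to exist at runtime for the annotation below.
    LiteLLMLoggingObj = Any


def transform_response(raw_response: dict, logging_obj: LiteLLMLoggingObj) -> dict:
    # Illustrative call site: post_call() is available on the logging obj
    # without importing the concrete Logging class at module import time.
    logging_obj.post_call(original_response=raw_response)
    return raw_response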
Krish Dholakia 2024-12-14 16:28:34 -08:00 committed by GitHub
parent 0dbf71291e
commit 516c2a6a70
48 changed files with 489 additions and 256 deletions

View file

@ -817,6 +817,7 @@ jobs:
- run: python ./tests/documentation_tests/test_api_docs.py
- run: python ./tests/code_coverage_tests/ensure_async_clients_test.py
- run: python ./tests/code_coverage_tests/enforce_llms_folder_style.py
- run: python ./tests/documentation_tests/test_circular_imports.py
- run: helm lint ./deploy/charts/litellm-helm
db_migration_disable_update_check:

View file

@ -474,12 +474,9 @@ class _ENTERPRISE_SecretDetection(CustomGuardrail):
from detect_secrets import SecretsCollection
from detect_secrets.settings import default_settings
print("INSIDE SECRET DETECTION PRE-CALL HOOK!")
if await self.should_run_check(user_api_key_dict) is False:
return
print("RUNNING CHECK!")
if "messages" in data and isinstance(data["messages"], list):
for message in data["messages"]:
if "content" in message and isinstance(message["content"], str):

View file

@ -32,7 +32,7 @@ from litellm.proxy._types import (
KeyManagementSettings,
LiteLLM_UpperboundKeyGenerateParams,
)
from litellm.types.utils import StandardKeyGenerationConfig
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
import httpx
import dotenv
from enum import Enum
@ -838,71 +838,6 @@ model_list = (
)
class LlmProviders(str, Enum):
OPENAI = "openai"
OPENAI_LIKE = "openai_like" # embedding only
JINA_AI = "jina_ai"
XAI = "xai"
CUSTOM_OPENAI = "custom_openai"
TEXT_COMPLETION_OPENAI = "text-completion-openai"
COHERE = "cohere"
COHERE_CHAT = "cohere_chat"
CLARIFAI = "clarifai"
ANTHROPIC = "anthropic"
ANTHROPIC_TEXT = "anthropic_text"
REPLICATE = "replicate"
HUGGINGFACE = "huggingface"
TOGETHER_AI = "together_ai"
OPENROUTER = "openrouter"
VERTEX_AI = "vertex_ai"
VERTEX_AI_BETA = "vertex_ai_beta"
GEMINI = "gemini"
AI21 = "ai21"
BASETEN = "baseten"
AZURE = "azure"
AZURE_TEXT = "azure_text"
AZURE_AI = "azure_ai"
SAGEMAKER = "sagemaker"
SAGEMAKER_CHAT = "sagemaker_chat"
BEDROCK = "bedrock"
VLLM = "vllm"
NLP_CLOUD = "nlp_cloud"
PETALS = "petals"
OOBABOOGA = "oobabooga"
OLLAMA = "ollama"
OLLAMA_CHAT = "ollama_chat"
DEEPINFRA = "deepinfra"
PERPLEXITY = "perplexity"
MISTRAL = "mistral"
GROQ = "groq"
NVIDIA_NIM = "nvidia_nim"
CEREBRAS = "cerebras"
AI21_CHAT = "ai21_chat"
VOLCENGINE = "volcengine"
CODESTRAL = "codestral"
TEXT_COMPLETION_CODESTRAL = "text-completion-codestral"
DEEPSEEK = "deepseek"
SAMBANOVA = "sambanova"
MARITALK = "maritalk"
VOYAGE = "voyage"
CLOUDFLARE = "cloudflare"
XINFERENCE = "xinference"
FIREWORKS_AI = "fireworks_ai"
FRIENDLIAI = "friendliai"
WATSONX = "watsonx"
WATSONX_TEXT = "watsonx_text"
TRITON = "triton"
PREDIBASE = "predibase"
DATABRICKS = "databricks"
EMPOWER = "empower"
GITHUB = "github"
CUSTOM = "custom"
LITELLM_PROXY = "litellm_proxy"
HOSTED_VLLM = "hosted_vllm"
LM_STUDIO = "lm_studio"
GALADRIEL = "galadriel"
provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)

View file

@ -18,7 +18,7 @@ from litellm.types.llms.anthropic import (
AnthropicResponse,
ContentBlockDelta,
)
from litellm.types.utils import AdapterCompletionStreamWrapper
from litellm.types.utils import AdapterCompletionStreamWrapper, ModelResponse
class AnthropicAdapter(CustomLogger):
@ -41,7 +41,7 @@ class AnthropicAdapter(CustomLogger):
return translated_body
def translate_completion_output_params(
self, response: litellm.ModelResponse
self, response: ModelResponse
) -> Optional[AnthropicResponse]:
return litellm.AnthropicExperimentalPassThroughConfig().translate_openai_response_to_anthropic(

View file

@ -484,7 +484,7 @@ def completion_cost( # noqa: PLR0915
completion_characters: Optional[int] = None
cache_creation_input_tokens: Optional[int] = None
cache_read_input_tokens: Optional[int] = None
cost_per_token_usage_object: Optional[litellm.Usage] = _get_usage_object(
cost_per_token_usage_object: Optional[Usage] = _get_usage_object(
completion_response=completion_response
)
if completion_response is not None and (
@ -492,7 +492,7 @@ def completion_cost( # noqa: PLR0915
or isinstance(completion_response, dict)
): # tts returns a custom class
usage_obj: Optional[Union[dict, litellm.Usage]] = completion_response.get( # type: ignore
usage_obj: Optional[Union[dict, Usage]] = completion_response.get( # type: ignore
"usage", {}
)
if isinstance(usage_obj, BaseModel) and not isinstance(

View file

@ -39,6 +39,7 @@ from litellm.proxy._types import (
VirtualKeyEvent,
WebhookEvent,
)
from litellm.router import Router
from litellm.types.integrations.slack_alerting import *
from litellm.types.router import LiteLLM_Params
@ -93,7 +94,7 @@ class SlackAlerting(CustomBatchLogger):
alert_types: Optional[List[AlertType]] = None,
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]] = None,
alerting_args: Optional[Dict] = None,
llm_router: Optional[litellm.Router] = None,
llm_router: Optional[Router] = None,
):
if alerting is not None:
self.alerting = alerting

View file

@ -18,6 +18,7 @@ from pydantic import BaseModel
import litellm
from litellm import (
_custom_logger_compatible_callbacks_literal,
json_logs,
log_raw_request_response,
turn_off_message_logging,
@ -41,6 +42,7 @@ from litellm.types.utils import (
CallTypes,
EmbeddingResponse,
ImageResponse,
LiteLLMLoggingBaseClass,
ModelResponse,
StandardCallbackDynamicParams,
StandardLoggingAdditionalHeaders,
@ -190,7 +192,7 @@ in_memory_trace_id_cache = ServiceTraceIDCache()
in_memory_dynamic_logger_cache = DynamicLoggingCache()
class Logging:
class Logging(LiteLLMLoggingBaseClass):
global supabaseClient, promptLayerLogger, weightsBiasesLogger, logfireLogger, capture_exception, add_breadcrumb, lunaryLogger, logfireLogger, prometheusLogger, slack_app
custom_pricing: bool = False
stream_options = None
@ -2142,7 +2144,7 @@ def set_callbacks(callback_list, function_id=None): # noqa: PLR0915
def _init_custom_logger_compatible_class( # noqa: PLR0915
logging_integration: litellm._custom_logger_compatible_callbacks_literal,
logging_integration: _custom_logger_compatible_callbacks_literal,
internal_usage_cache: Optional[DualCache],
llm_router: Optional[
Any
@ -2362,7 +2364,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
def get_custom_logger_compatible_class( # noqa: PLR0915
logging_integration: litellm._custom_logger_compatible_callbacks_literal,
logging_integration: _custom_logger_compatible_callbacks_literal,
) -> Optional[CustomLogger]:
if logging_integration == "lago":
for callback in _in_memory_loggers:

View file

@ -13,7 +13,6 @@ from jinja2.sandbox import ImmutableSandboxedEnvironment
import litellm
import litellm.types
import litellm.types.llms
import litellm.types.llms.vertex_ai
from litellm import verbose_logger
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.types.completion import (
@ -40,6 +39,9 @@ from litellm.types.llms.openai import (
ChatCompletionUserMessage,
OpenAIMessageContentListBlock,
)
from litellm.types.llms.vertex_ai import FunctionCall as VertexFunctionCall
from litellm.types.llms.vertex_ai import FunctionResponse as VertexFunctionResponse
from litellm.types.llms.vertex_ai import PartType as VertexPartType
from litellm.types.utils import GenericImageParsingChunk
from .common_utils import convert_content_list_to_str, is_non_content_values_set
@ -965,11 +967,11 @@ def infer_protocol_value(
def _gemini_tool_call_invoke_helper(
function_call_params: ChatCompletionToolCallFunctionChunk,
) -> Optional[litellm.types.llms.vertex_ai.FunctionCall]:
) -> Optional[VertexFunctionCall]:
name = function_call_params.get("name", "") or ""
arguments = function_call_params.get("arguments", "")
arguments_dict = json.loads(arguments)
function_call = litellm.types.llms.vertex_ai.FunctionCall(
function_call = VertexFunctionCall(
name=name,
args=arguments_dict,
)
@ -978,7 +980,7 @@ def _gemini_tool_call_invoke_helper(
def convert_to_gemini_tool_call_invoke(
message: ChatCompletionAssistantMessage,
) -> List[litellm.types.llms.vertex_ai.PartType]:
) -> List[VertexPartType]:
"""
OpenAI tool invokes:
{
@ -1019,22 +1021,20 @@ def convert_to_gemini_tool_call_invoke(
- json.load the arguments
"""
try:
_parts_list: List[litellm.types.llms.vertex_ai.PartType] = []
_parts_list: List[VertexPartType] = []
tool_calls = message.get("tool_calls", None)
function_call = message.get("function_call", None)
if tool_calls is not None:
for tool in tool_calls:
if "function" in tool:
gemini_function_call: Optional[
litellm.types.llms.vertex_ai.FunctionCall
] = _gemini_tool_call_invoke_helper(
function_call_params=tool["function"]
gemini_function_call: Optional[VertexFunctionCall] = (
_gemini_tool_call_invoke_helper(
function_call_params=tool["function"]
)
)
if gemini_function_call is not None:
_parts_list.append(
litellm.types.llms.vertex_ai.PartType(
function_call=gemini_function_call
)
VertexPartType(function_call=gemini_function_call)
)
else: # don't silently drop params. Make it clear to user what's happening.
raise Exception(
@ -1047,11 +1047,7 @@ def convert_to_gemini_tool_call_invoke(
function_call_params=function_call
)
if gemini_function_call is not None:
_parts_list.append(
litellm.types.llms.vertex_ai.PartType(
function_call=gemini_function_call
)
)
_parts_list.append(VertexPartType(function_call=gemini_function_call))
else: # don't silently drop params. Make it clear to user what's happening.
raise Exception(
"function_call missing. Received tool call with 'type': 'function'. No function call in argument - {}".format(
@ -1070,7 +1066,7 @@ def convert_to_gemini_tool_call_invoke(
def convert_to_gemini_tool_call_result(
message: Union[ChatCompletionToolMessage, ChatCompletionFunctionMessage],
last_message_with_tool_calls: Optional[dict],
) -> litellm.types.llms.vertex_ai.PartType:
) -> VertexPartType:
"""
OpenAI message with a tool result looks like:
{
@ -1119,11 +1115,11 @@ def convert_to_gemini_tool_call_result(
# We can't determine from openai message format whether it's a successful or
# error call result so default to the successful result template
_function_response = litellm.types.llms.vertex_ai.FunctionResponse(
_function_response = VertexFunctionResponse(
name=name, response={"content": content_str} # type: ignore
)
_part = litellm.types.llms.vertex_ai.PartType(function_response=_function_response)
_part = VertexPartType(function_response=_function_response)
return _part

View file

@ -5,6 +5,12 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice
import litellm
from litellm.litellm_core_utils.prompt_templates.factory import (
anthropic_messages_pt,
custom_prompt,
prompt_factory,
)
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.types.llms.anthropic import (
AllAnthropicToolsValues,
AnthopicMessagesAssistantMessageParam,
@ -53,15 +59,9 @@ from litellm.types.llms.openai import (
ChatCompletionUserMessage,
OpenAIMessageContent,
)
from litellm.types.utils import Choices, GenericStreamingChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from litellm.types.utils import Choices, GenericStreamingChunk, ModelResponse, Usage
from ...base import BaseLLM
from litellm.litellm_core_utils.prompt_templates.factory import (
anthropic_messages_pt,
custom_prompt,
prompt_factory,
)
class AnthropicExperimentalPassThroughConfig:
@ -338,7 +338,7 @@ class AnthropicExperimentalPassThroughConfig:
return "end_turn"
def translate_openai_response_to_anthropic(
self, response: litellm.ModelResponse
self, response: ModelResponse
) -> AnthropicResponse:
## translate content block
anthropic_content = self._translate_openai_content_to_anthropic(choices=response.choices) # type: ignore
@ -347,7 +347,7 @@ class AnthropicExperimentalPassThroughConfig:
openai_finish_reason=response.choices[0].finish_reason # type: ignore
)
# extract usage
usage: litellm.Usage = getattr(response, "usage")
usage: Usage = getattr(response, "usage")
anthropic_usage = AnthropicResponseUsageBlock(
input_tokens=usage.prompt_tokens or 0,
output_tokens=usage.completion_tokens or 0,
@ -393,7 +393,7 @@ class AnthropicExperimentalPassThroughConfig:
return "text_delta", ContentTextBlockDelta(type="text_delta", text=text)
def translate_streaming_openai_response_to_anthropic(
self, response: litellm.ModelResponse
self, response: ModelResponse
) -> Union[ContentBlockDelta, MessageBlockDelta]:
## base case - final chunk w/ finish reason
if response.choices[0].finish_reason is not None:
@ -403,7 +403,7 @@ class AnthropicExperimentalPassThroughConfig:
),
)
if getattr(response, "usage", None) is not None:
litellm_usage_chunk: Optional[litellm.Usage] = response.usage # type: ignore
litellm_usage_chunk: Optional[Usage] = response.usage # type: ignore
elif (
hasattr(response, "_hidden_params")
and "usage" in response._hidden_params

View file

@ -17,10 +17,14 @@ from litellm.llms.custom_httpx.http_handler import (
HTTPHandler,
get_async_httpx_client,
)
from litellm.types.utils import EmbeddingResponse
from litellm.types.utils import (
EmbeddingResponse,
ImageResponse,
LlmProviders,
ModelResponse,
)
from litellm.utils import (
CustomStreamWrapper,
ModelResponse,
UnsupportedParamsError,
convert_to_model_response_object,
get_secret,
@ -853,7 +857,7 @@ class AzureChatCompletion(BaseLLM):
client=None,
aembedding=None,
headers: Optional[dict] = None,
) -> litellm.EmbeddingResponse:
) -> EmbeddingResponse:
if headers:
optional_params["extra_headers"] = headers
if self._client_session is None:
@ -963,7 +967,7 @@ class AzureChatCompletion(BaseLLM):
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.AZURE,
llm_provider=LlmProviders.AZURE,
params=_params,
)
else:
@ -1242,11 +1246,11 @@ class AzureChatCompletion(BaseLLM):
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None,
model_response: Optional[ImageResponse] = None,
azure_ad_token: Optional[str] = None,
client=None,
aimg_generation=None,
) -> litellm.ImageResponse:
) -> ImageResponse:
try:
if model and len(model) > 0:
model = model
@ -1510,7 +1514,7 @@ class AzureChatCompletion(BaseLLM):
) -> dict:
client_session = (
litellm.aclient_session
or get_async_httpx_client(llm_provider=litellm.LlmProviders.AZURE).client
or get_async_httpx_client(llm_provider=LlmProviders.AZURE).client
) # handle dall-e-2 calls
if "gateway.ai.cloudflare.com" in api_base:

View file

@ -4,7 +4,11 @@ from typing import TYPE_CHECKING, Any, List, Optional, Type, Union
from httpx._models import Headers, Response
import litellm
from litellm.litellm_core_utils.prompt_templates.factory import (
convert_to_azure_openai_messages,
)
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.types.utils import ModelResponse
from ....exceptions import UnsupportedParamsError
from ....types.llms.openai import (
@ -14,9 +18,7 @@ from ....types.llms.openai import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
from ...base_llm.transformation import BaseConfig
from litellm.litellm_core_utils.prompt_templates.factory import convert_to_azure_openai_messages
from ..common_utils import AzureOpenAIError
if TYPE_CHECKING:
@ -26,6 +28,7 @@ if TYPE_CHECKING:
else:
LoggingClass = Any
class AzureOpenAIConfig(BaseConfig):
"""
Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
@ -221,7 +224,7 @@ class AzureOpenAIConfig(BaseConfig):
self,
model: str,
raw_response: Response,
model_response: litellm.ModelResponse,
model_response: ModelResponse,
logging_obj: LoggingClass,
request_data: dict,
messages: List[AllMessageValues],
@ -230,7 +233,7 @@ class AzureOpenAIConfig(BaseConfig):
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> litellm.ModelResponse:
) -> ModelResponse:
raise NotImplementedError(
"Azure OpenAI handler.py has custom logic for transforming response, as it uses the OpenAI SDK."
)

View file

@ -89,7 +89,7 @@ class AzureAIEmbedding(OpenAIChatCompletion):
embedding_response = response.json()
embedding_headers = dict(response.headers)
returned_response: litellm.EmbeddingResponse = convert_to_model_response_object( # type: ignore
returned_response: EmbeddingResponse = convert_to_model_response_object( # type: ignore
response_object=embedding_response,
model_response_object=model_response,
response_type="embedding",
@ -104,7 +104,7 @@ class AzureAIEmbedding(OpenAIChatCompletion):
data: ImageEmbeddingRequest,
timeout: float,
logging_obj,
model_response: litellm.EmbeddingResponse,
model_response: EmbeddingResponse,
optional_params: dict,
api_key: Optional[str],
api_base: Optional[str],
@ -132,7 +132,7 @@ class AzureAIEmbedding(OpenAIChatCompletion):
embedding_response = response.json()
embedding_headers = dict(response.headers)
returned_response: litellm.EmbeddingResponse = convert_to_model_response_object( # type: ignore
returned_response: EmbeddingResponse = convert_to_model_response_object( # type: ignore
response_object=embedding_response,
model_response_object=model_response,
response_type="embedding",
@ -213,14 +213,14 @@ class AzureAIEmbedding(OpenAIChatCompletion):
input: List,
timeout: float,
logging_obj,
model_response: litellm.EmbeddingResponse,
model_response: EmbeddingResponse,
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
aembedding=None,
max_retries: Optional[int] = None,
) -> litellm.EmbeddingResponse:
) -> EmbeddingResponse:
"""
- Separate image url from text
-> route image url call to `/image/embeddings`

View file

@ -5,6 +5,8 @@ import httpx
import requests
import litellm
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.types.utils import ModelResponse, TextCompletionResponse
class BaseLLM:
@ -15,7 +17,7 @@ class BaseLLM:
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: litellm.utils.ModelResponse,
model_response: ModelResponse,
stream: bool,
logging_obj: Any,
optional_params: dict,
@ -24,7 +26,7 @@ class BaseLLM:
messages: list,
print_verbose,
encoding,
) -> Union[litellm.utils.ModelResponse, litellm.utils.CustomStreamWrapper]:
) -> Union[ModelResponse, CustomStreamWrapper]:
"""
Helper function to process the response across sync + async completion calls
"""
@ -34,7 +36,7 @@ class BaseLLM:
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: litellm.utils.TextCompletionResponse,
model_response: TextCompletionResponse,
stream: bool,
logging_obj: Any,
optional_params: dict,
@ -43,7 +45,7 @@ class BaseLLM:
messages: list,
print_verbose,
encoding,
) -> Union[litellm.utils.TextCompletionResponse, litellm.utils.CustomStreamWrapper]:
) -> Union[TextCompletionResponse, CustomStreamWrapper]:
"""
Helper function to process the response across sync + async completion calls
"""

View file

@ -32,6 +32,17 @@ from litellm import verbose_logger
from litellm.caching.caching import InMemoryCache
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.litellm_core_utils.prompt_templates.factory import (
_bedrock_converse_messages_pt,
_bedrock_tools_pt,
cohere_message_pt,
construct_tool_use_system_prompt,
contains_tag,
custom_prompt,
extract_between_tags,
parse_xml_params,
prompt_factory,
)
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@ -50,20 +61,10 @@ from litellm.types.llms.openai import (
ChatCompletionUsageBlock,
)
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage, get_secret
from litellm.types.utils import ModelResponse, Usage
from litellm.utils import CustomStreamWrapper, get_secret
from ..base_aws_llm import BaseAWSLLM
from litellm.litellm_core_utils.prompt_templates.factory import (
_bedrock_converse_messages_pt,
_bedrock_tools_pt,
cohere_message_pt,
construct_tool_use_system_prompt,
contains_tag,
custom_prompt,
extract_between_tags,
parse_xml_params,
prompt_factory,
)
from ..common_utils import BedrockError, ModelResponseIterator, get_bedrock_tool_name
from .converse_transformation import AmazonConverseConfig
@ -1317,7 +1318,7 @@ class MockResponseIterator: # for returning ai21 streaming responses
def _chunk_parser(self, chunk_data: ModelResponse) -> GChunk:
try:
chunk_usage: litellm.Usage = getattr(chunk_data, "usage")
chunk_usage: Usage = getattr(chunk_data, "usage")
text = chunk_data.choices[0].message.content or "" # type: ignore
tool_use = None
if self.json_mode is True:

View file

@ -5,9 +5,11 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional,
import httpx
import litellm
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.llms.base_llm.base_model_iterator import FakeStreamResponseIterator
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ChatCompletionToolCallChunk,
@ -152,7 +154,7 @@ class ClarifaiConfig(BaseConfig):
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> litellm.ModelResponse:
) -> ModelResponse:
logging_obj.post_call(
input=messages,
api_key=api_key,

View file

@ -29,6 +29,7 @@ from litellm.llms.custom_httpx.http_handler import (
)
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
from litellm.types.llms.databricks import GenericStreamingChunk
from litellm.types.utils import TextChoices
from litellm.utils import (
Choices,
CustomStreamWrapper,
@ -169,7 +170,7 @@ class CodestralTextCompletion(BaseLLM):
raise TextCompletionCodestralError(message=response.text, status_code=422)
_original_choices = completion_response.get("choices", [])
_choices: List[litellm.utils.TextChoices] = []
_choices: List[TextChoices] = []
for choice in _original_choices:
# This is what 1 choice looks like from codestral API
# {

View file

@ -17,6 +17,7 @@ from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
)
from litellm.types.llms.bedrock import CohereEmbeddingRequest
from litellm.types.utils import EmbeddingResponse
from litellm.utils import Choices, Message, ModelResponse, Usage
from .transformation import CohereEmbeddingConfig
@ -118,7 +119,7 @@ async def async_embedding(
def embedding(
model: str,
input: list,
model_response: litellm.EmbeddingResponse,
model_response: EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
headers: dict,

View file

@ -21,7 +21,7 @@ class DatabricksEmbeddingHandler(OpenAILikeEmbeddingHandler, DatabricksBase):
api_key: Optional[str],
api_base: Optional[str],
optional_params: dict,
model_response: Optional[litellm.utils.EmbeddingResponse] = None,
model_response: Optional[EmbeddingResponse] = None,
client=None,
aembedding=None,
custom_endpoint: Optional[bool] = None,

View file

@ -55,9 +55,7 @@ class ModelResponseIterator:
is_finished = True
finish_reason = processed_chunk.choices[0].finish_reason
usage_chunk: Optional[litellm.Usage] = getattr(
processed_chunk, "usage", None
)
usage_chunk: Optional[Usage] = getattr(processed_chunk, "usage", None)
if usage_chunk is not None:
usage = ChatCompletionUsageBlock(

View file

@ -24,6 +24,7 @@ import requests
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@ -36,8 +37,9 @@ from litellm.llms.huggingface.chat.transformation import (
from litellm.secret_managers.main import get_secret_str
from litellm.types.completion import ChatCompletionMessageToolCallParam
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import EmbeddingResponse
from litellm.types.utils import Logprobs as TextCompletionLogprobs
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
from litellm.types.utils import ModelResponse, Usage
from ...base import BaseLLM
from ..common_utils import HuggingfaceError, hf_task_list, hf_tasks
@ -453,11 +455,11 @@ class Huggingface(BaseLLM):
def _process_embedding_response(
self,
embeddings: dict,
model_response: litellm.EmbeddingResponse,
model_response: EmbeddingResponse,
model: str,
input: List,
encoding: Any,
) -> litellm.EmbeddingResponse:
) -> EmbeddingResponse:
output_data = []
if "similarities" in embeddings:
for idx, embedding in embeddings["similarities"]:
@ -583,7 +585,7 @@ class Huggingface(BaseLLM):
self,
model: str,
input: list,
model_response: litellm.EmbeddingResponse,
model_response: EmbeddingResponse,
optional_params: dict,
logging_obj: LiteLLMLoggingObj,
encoding: Callable,
@ -593,7 +595,7 @@ class Huggingface(BaseLLM):
aembedding: Optional[bool] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
headers={},
) -> litellm.EmbeddingResponse:
) -> EmbeddingResponse:
super().embedding()
headers = hf_chat_config.validate_environment(
api_key=api_key,

View file

@ -8,10 +8,15 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
import litellm
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.litellm_core_utils.prompt_templates.factory import (
custom_prompt,
prompt_factory,
)
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage
@ -407,7 +412,7 @@ class HuggingfaceChatConfig(BaseConfig):
def convert_to_model_response_object( # noqa: PLR0915
self,
completion_response: Union[List[Dict[str, Any]], Dict[str, Any]],
model_response: litellm.ModelResponse,
model_response: ModelResponse,
task: Optional[hf_tasks],
optional_params: dict,
encoding: Any,

View file

@ -14,11 +14,20 @@ import requests # type: ignore
import litellm
from litellm import verbose_logger
from litellm.litellm_core_utils.prompt_templates.factory import (
custom_prompt,
prompt_factory,
)
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import ModelInfo, ProviderField, StreamingChoices
from litellm.types.utils import (
EmbeddingResponse,
ModelInfo,
ModelResponse,
ProviderField,
StreamingChoices,
)
from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
from ..common_utils import OllamaError
from .transformation import OllamaConfig
@ -53,7 +62,7 @@ def _convert_image(image):
# ollama implementation
def get_ollama_response(
model_response: litellm.ModelResponse,
model_response: ModelResponse,
model: str,
prompt: str,
optional_params: dict,
@ -391,7 +400,7 @@ async def ollama_aembeddings(
api_base: str,
model: str,
prompts: List[str],
model_response: litellm.EmbeddingResponse,
model_response: EmbeddingResponse,
optional_params: dict,
logging_obj: Any,
encoding: Any,
@ -479,7 +488,7 @@ def ollama_embeddings(
model: str,
prompts: list,
optional_params: dict,
model_response: litellm.EmbeddingResponse,
model_response: EmbeddingResponse,
logging_obj: Any,
encoding=None,
):

View file

@ -17,7 +17,7 @@ from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
from litellm.types.llms.openai import ChatCompletionAssistantToolCall
from litellm.types.utils import StreamingChoices
from litellm.types.utils import ModelResponse, StreamingChoices
class OllamaError(Exception):
@ -198,7 +198,7 @@ class OllamaChatConfig(OpenAIGPTConfig):
# ollama implementation
def get_ollama_response( # noqa: PLR0915
model_response: litellm.ModelResponse,
model_response: ModelResponse,
messages: list,
optional_params: dict,
model: str,

View file

@ -28,24 +28,31 @@ import litellm
from litellm import LlmProviders
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.prompt_templates.factory import (
custom_prompt,
prompt_factory,
)
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import ProviderField
from litellm.types.utils import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
ProviderField,
TextCompletionResponse,
Usage,
)
from litellm.utils import (
Choices,
CustomStreamWrapper,
Message,
ModelResponse,
ProviderConfigManager,
TextCompletionResponse,
Usage,
convert_to_model_response_object,
)
from ...types.llms.openai import *
from ..base import BaseLLM
from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
from .chat.gpt_transformation import OpenAIGPTConfig
from .common_utils import OpenAIError, drop_params_from_unprocessable_entity_error
@ -882,7 +889,7 @@ class OpenAIChatCompletion(BaseLLM):
self,
input: list,
data: dict,
model_response: litellm.utils.EmbeddingResponse,
model_response: EmbeddingResponse,
timeout: float,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
@ -911,9 +918,7 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
returned_response: (
litellm.EmbeddingResponse
) = convert_to_model_response_object(
returned_response: EmbeddingResponse = convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
response_type="embedding",
@ -953,14 +958,14 @@ class OpenAIChatCompletion(BaseLLM):
input: list,
timeout: float,
logging_obj,
model_response: litellm.utils.EmbeddingResponse,
model_response: EmbeddingResponse,
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
aembedding=None,
max_retries: Optional[int] = None,
) -> litellm.EmbeddingResponse:
) -> EmbeddingResponse:
super().embedding()
try:
model = model
@ -1011,7 +1016,7 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
original_response=sync_embedding_response,
)
response: litellm.EmbeddingResponse = convert_to_model_response_object(
response: EmbeddingResponse = convert_to_model_response_object(
response_object=sync_embedding_response.model_dump(),
model_response_object=model_response,
_response_headers=headers,
@ -1068,7 +1073,7 @@ class OpenAIChatCompletion(BaseLLM):
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
input=prompt,
api_key=api_key,
original_response=str(e),
)
@ -1083,10 +1088,10 @@ class OpenAIChatCompletion(BaseLLM):
logging_obj: Any,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None,
model_response: Optional[ImageResponse] = None,
client=None,
aimg_generation=None,
) -> litellm.ImageResponse:
) -> ImageResponse:
data = {}
try:
model = model

View file

@ -3,7 +3,7 @@ OpenAI-like chat completion transformation
"""
import types
from typing import List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
import httpx
from pydantic import BaseModel
@ -16,6 +16,13 @@ from litellm.types.utils import ModelResponse
from ....utils import _remove_additional_properties, _remove_strict_from_schema
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class OpenAILikeChatConfig(OpenAIGPTConfig):
def _get_openai_compatible_provider_info(
@ -64,7 +71,7 @@ class OpenAILikeChatConfig(OpenAIGPTConfig):
response: httpx.Response,
model_response: ModelResponse,
stream: bool,
logging_obj: litellm.litellm_core_utils.litellm_logging.Logging, # type: ignore
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
api_key: Optional[str],
data: Union[dict, str],

View file

@ -11,8 +11,7 @@ from enum import Enum
from functools import partial
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
import httpx
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
@ -21,7 +20,7 @@ from litellm.llms.custom_httpx.http_handler import (
HTTPHandler,
get_async_httpx_client,
)
from litellm.utils import EmbeddingResponse
from litellm.types.utils import EmbeddingResponse
from ..common_utils import OpenAILikeBase, OpenAILikeError
@ -100,7 +99,7 @@ class OpenAILikeEmbeddingHandler(OpenAILikeBase):
api_key: Optional[str],
api_base: Optional[str],
optional_params: dict,
model_response: Optional[litellm.utils.EmbeddingResponse] = None,
model_response: Optional[EmbeddingResponse] = None,
client=None,
aembedding=None,
custom_endpoint: Optional[bool] = None,

View file

@ -10,6 +10,7 @@ from litellm.llms.base_llm.transformation import (
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
from ..common_utils import PetalsError
@ -111,7 +112,7 @@ class PetalsConfig(BaseConfig):
self,
model: str,
raw_response: Response,
model_response: litellm.ModelResponse,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
@ -120,7 +121,7 @@ class PetalsConfig(BaseConfig):
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> litellm.ModelResponse:
) -> ModelResponse:
raise NotImplementedError(
"Petals transformation currently done in handler.py. [TODO] Move to the transformation.py"
)

View file

@ -27,6 +27,7 @@ from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
from litellm.types.utils import LiteLLMLoggingBaseClass
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
from ...base import BaseLLM
@ -92,7 +93,7 @@ class PredibaseChatCompletion(BaseLLM):
response: Union[requests.Response, httpx.Response],
model_response: ModelResponse,
stream: bool,
logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
logging_obj: LiteLLMLoggingBaseClass,
optional_params: dict,
api_key: str,
data: Union[dict, str],

View file

@ -13,10 +13,13 @@ from httpx._models import Headers, Response
import litellm
from litellm.litellm_core_utils.asyncify import asyncify
from litellm.litellm_core_utils.prompt_templates.factory import (
custom_prompt,
prompt_factory,
)
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Usage
from litellm.types.utils import ModelResponse, Usage
from ..common_utils import SagemakerError
@ -197,7 +200,7 @@ class SagemakerConfig(BaseConfig):
self,
model: str,
raw_response: Response,
model_response: litellm.ModelResponse,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
@ -206,7 +209,7 @@ class SagemakerConfig(BaseConfig):
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> litellm.ModelResponse:
) -> ModelResponse:
completion_response = raw_response.json()
## LOGGING
logging_obj.post_call(

View file

@ -5,20 +5,20 @@ Why separate file? Make it easy to see how transformation works
"""
import os
from typing import List, Literal, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union, cast
import httpx
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.litellm_core_utils.prompt_templates.factory import (
convert_to_anthropic_image_obj,
convert_to_gemini_tool_call_invoke,
convert_to_gemini_tool_call_result,
response_schema_prompt,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.files import (
get_file_mime_type_for_file_type,
get_file_type_from_extension,
@ -49,6 +49,13 @@ from ..common_utils import (
get_supports_system_message,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
def _process_gemini_image(image_url: str) -> PartType:
"""
@ -348,7 +355,7 @@ def sync_transform_request_body(
timeout: Optional[Union[float, httpx.Timeout]],
extra_headers: Optional[dict],
optional_params: dict,
logging_obj: litellm.litellm_core_utils.litellm_logging.Logging, # type: ignore
logging_obj: LiteLLMLoggingObj,
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
litellm_params: dict,
) -> RequestBody:

View file

@ -10,18 +10,17 @@ from litellm.llms.custom_httpx.http_handler import (
HTTPHandler,
get_async_httpx_client,
)
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from litellm.types.utils import ImageResponse
class VertexImageGeneration(VertexLLM):
def process_image_generation_response(
self,
json_response: Dict[str, Any],
model_response: litellm.ImageResponse,
model_response: ImageResponse,
model: Optional[str] = None,
) -> litellm.ImageResponse:
) -> ImageResponse:
if "predictions" not in json_response:
raise litellm.InternalServerError(
message=f"image generation response does not contain 'predictions', got {json_response}",
@ -46,7 +45,7 @@ class VertexImageGeneration(VertexLLM):
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
model_response: litellm.ImageResponse,
model_response: ImageResponse,
logging_obj: Any,
model: Optional[
str
@ -55,7 +54,7 @@ class VertexImageGeneration(VertexLLM):
optional_params: Optional[dict] = None,
timeout: Optional[int] = None,
aimg_generation=False,
) -> litellm.ImageResponse:
) -> ImageResponse:
if aimg_generation is True:
return self.aimage_generation( # type: ignore
prompt=prompt,

View file

@ -22,7 +22,7 @@ from litellm.types.llms.vertex_ai import (
MultimodalPredictions,
VertexMultimodalEmbeddingRequest,
)
from litellm.types.utils import Embedding
from litellm.types.utils import Embedding, EmbeddingResponse
from litellm.utils import is_base64_encoded
@ -39,7 +39,7 @@ class VertexMultimodalEmbedding(VertexLLM):
model: str,
input: Union[list, str],
print_verbose,
model_response: litellm.EmbeddingResponse,
model_response: EmbeddingResponse,
custom_llm_provider: Literal["gemini", "vertex_ai"],
optional_params: dict,
logging_obj: LiteLLMLoggingObj,
@ -52,7 +52,7 @@ class VertexMultimodalEmbedding(VertexLLM):
aembedding=False,
timeout=300,
client=None,
) -> litellm.EmbeddingResponse:
) -> EmbeddingResponse:
_auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials,

View file

@ -15,12 +15,10 @@ from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.vertex_ai.vertex_ai_non_gemini import (
VertexAIError,
)
from litellm.llms.vertex_ai.vertex_ai_non_gemini import VertexAIError
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.types.llms.vertex_ai import *
from litellm.utils import Usage
from litellm.types.utils import EmbeddingResponse, Usage
from .transformation import VertexAITextEmbeddingConfig
from .types import *
@ -35,7 +33,7 @@ class VertexEmbedding(VertexBase):
model: str,
input: Union[list, str],
print_verbose,
model_response: litellm.EmbeddingResponse,
model_response: EmbeddingResponse,
optional_params: dict,
logging_obj: LiteLLMLoggingObject,
custom_llm_provider: Literal[
@ -52,7 +50,7 @@ class VertexEmbedding(VertexBase):
vertex_credentials: Optional[str] = None,
gemini_api_key: Optional[str] = None,
extra_headers: Optional[dict] = None,
) -> litellm.EmbeddingResponse:
) -> EmbeddingResponse:
if aembedding is True:
return self.async_embedding( # type: ignore
model=model,

View file

@ -4,7 +4,7 @@ from typing import List, Literal, Optional, Union
from pydantic import BaseModel
import litellm
from litellm.utils import Usage
from litellm.types.utils import EmbeddingResponse, Usage
from .types import *
@ -198,8 +198,8 @@ class VertexAITextEmbeddingConfig(BaseModel):
return text_embedding_input
def transform_vertex_response_to_openai(
self, response: dict, model: str, model_response: litellm.EmbeddingResponse
) -> litellm.EmbeddingResponse:
self, response: dict, model: str, model_response: EmbeddingResponse
) -> EmbeddingResponse:
"""
Transforms a vertex embedding response to an openai response.
"""
@ -234,8 +234,8 @@ class VertexAITextEmbeddingConfig(BaseModel):
return model_response
def _transform_vertex_response_to_openai_for_fine_tuned_models(
self, response: dict, model: str, model_response: litellm.EmbeddingResponse
) -> litellm.EmbeddingResponse:
self, response: dict, model: str, model_response: EmbeddingResponse
) -> EmbeddingResponse:
"""
Transforms a vertex fine-tuned model embedding response to an openai response format.
"""

View file

@ -24,6 +24,8 @@ import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.prompt_templates import factory as ptf
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
@ -34,7 +36,6 @@ from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
from ...base import BaseLLM
from litellm.litellm_core_utils.prompt_templates import factory as ptf
from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token
from .transformation import IBMWatsonXAIConfig
@ -204,7 +205,7 @@ class IBMWatsonXAI(BaseLLM):
def process_stream_response(
stream_resp: Union[Iterator[str], AsyncIterator],
) -> litellm.CustomStreamWrapper:
) -> CustomStreamWrapper:
streamwrapper = litellm.CustomStreamWrapper(
stream_resp,
model=model,
@ -235,7 +236,7 @@ class IBMWatsonXAI(BaseLLM):
json_resp = resp.json()
return self._process_text_gen_response(json_resp, model_response)
def handle_stream_request(request_params: dict) -> litellm.CustomStreamWrapper:
def handle_stream_request(request_params: dict) -> CustomStreamWrapper:
# stream the response - generated chunks will be handled
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
with self.request_manager.request(
@ -249,7 +250,7 @@ class IBMWatsonXAI(BaseLLM):
async def handle_stream_request_async(
request_params: dict,
) -> litellm.CustomStreamWrapper:
) -> CustomStreamWrapper:
# stream the response - generated chunks will be handled
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
async with self.request_manager.async_request(
@ -321,14 +322,14 @@ class IBMWatsonXAI(BaseLLM):
self,
model: str,
input: Union[list, str],
model_response: litellm.EmbeddingResponse,
model_response: EmbeddingResponse,
api_key: Optional[str],
logging_obj: Any,
optional_params: dict,
encoding=None,
print_verbose=None,
aembedding=None,
) -> litellm.EmbeddingResponse:
) -> EmbeddingResponse:
"""
Send a text embedding request to the IBM Watsonx.ai API.
"""

View file

@ -35,6 +35,7 @@ from litellm.proxy._types import (
)
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
from litellm.router import Router
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
from .auth_checks_organization import organization_role_based_access_check
@ -61,7 +62,7 @@ def common_checks( # noqa: PLR0915
global_proxy_spend: Optional[float],
general_settings: dict,
route: str,
llm_router: Optional[litellm.Router],
llm_router: Optional[Router],
) -> bool:
"""
Common checks across jwt + key-based auth.
@ -347,7 +348,7 @@ async def get_end_user_object(
def model_in_access_group(
model: str, team_models: Optional[List[str]], llm_router: Optional[litellm.Router]
model: str, team_models: Optional[List[str]], llm_router: Optional[Router]
) -> bool:
from collections import defaultdict

View file

@ -4,10 +4,11 @@ from typing import Any, Optional
import litellm
from litellm import CustomLLM, ImageObject, ImageResponse, completion, get_llm_provider
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.utils import ModelResponse
class MyCustomLLM(CustomLLM):
def completion(self, *args, **kwargs) -> litellm.ModelResponse:
def completion(self, *args, **kwargs) -> ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],

View file

@ -45,6 +45,7 @@ from litellm.types.guardrails import (
BedrockTextContent,
GuardrailEventHooks,
)
from litellm.types.utils import ModelResponse
GUARDRAIL_NAME = "bedrock"
@ -70,7 +71,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
def convert_to_bedrock_format(
self,
messages: Optional[List[Dict[str, str]]] = None,
response: Optional[Union[Any, litellm.ModelResponse]] = None,
response: Optional[Union[Any, ModelResponse]] = None,
) -> BedrockRequest:
bedrock_request: BedrockRequest = BedrockRequest(source="INPUT")
bedrock_request_content: List[BedrockContentItem] = []

View file

@ -20,8 +20,11 @@ import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.prompt_templates.factory import prompt_injection_detection_default_pt
from litellm.litellm_core_utils.prompt_templates.factory import (
prompt_injection_detection_default_pt,
)
from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth
from litellm.router import Router
from litellm.utils import get_formatted_prompt
@ -32,7 +35,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None,
):
self.prompt_injection_params = prompt_injection_params
self.llm_router: Optional[litellm.Router] = None
self.llm_router: Optional[Router] = None
self.verbs = [
"Ignore",
@ -74,7 +77,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
if litellm.set_verbose is True:
print(print_statement) # noqa
def update_environment(self, router: Optional[litellm.Router] = None):
def update_environment(self, router: Optional[Router] = None):
self.llm_router = router
if (

View file

@ -16,6 +16,7 @@ from litellm.llms.anthropic.chat.handler import (
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggingPayload
from litellm.types.utils import ModelResponse, TextCompletionResponse
if TYPE_CHECKING:
from ..success_handler import PassThroughEndpointLogging
@ -43,9 +44,7 @@ class AnthropicPassthroughLoggingHandler:
Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled
"""
model = response_body.get("model", "")
litellm_model_response: (
litellm.ModelResponse
) = AnthropicConfig().transform_response(
litellm_model_response: ModelResponse = AnthropicConfig().transform_response(
raw_response=httpx_response,
model_response=litellm.ModelResponse(),
model=model,
@ -89,9 +88,7 @@ class AnthropicPassthroughLoggingHandler:
@staticmethod
def _create_anthropic_response_logging_payload(
litellm_model_response: Union[
litellm.ModelResponse, litellm.TextCompletionResponse
],
litellm_model_response: Union[ModelResponse, TextCompletionResponse],
model: str,
kwargs: dict,
start_time: datetime,
@ -204,7 +201,7 @@ class AnthropicPassthroughLoggingHandler:
all_chunks: List[str],
litellm_logging_obj: LiteLLMLoggingObj,
model: str,
) -> Optional[Union[litellm.ModelResponse, litellm.TextCompletionResponse]]:
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
"""
Builds complete response from raw Anthropic chunks

View file

@ -15,6 +15,12 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
ModelResponseIterator as VertexModelResponseIterator,
)
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.types.utils import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
TextCompletionResponse,
)
if TYPE_CHECKING:
from ..success_handler import PassThroughEndpointLogging
@ -40,7 +46,7 @@ class VertexPassthroughLoggingHandler:
model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
instance_of_vertex_llm = litellm.VertexGeminiConfig()
litellm_model_response: litellm.ModelResponse = (
litellm_model_response: ModelResponse = (
instance_of_vertex_llm.transform_response(
model=model,
messages=[
@ -82,8 +88,8 @@ class VertexPassthroughLoggingHandler:
_json_response = httpx_response.json()
litellm_prediction_response: Union[
litellm.ModelResponse, litellm.EmbeddingResponse, litellm.ImageResponse
] = litellm.ModelResponse()
ModelResponse, EmbeddingResponse, ImageResponse
] = ModelResponse()
if vertex_image_generation_class.is_image_generation_response(
_json_response
):
@ -176,7 +182,7 @@ class VertexPassthroughLoggingHandler:
all_chunks: List[str],
litellm_logging_obj: LiteLLMLoggingObj,
model: str,
) -> Optional[Union[litellm.ModelResponse, litellm.TextCompletionResponse]]:
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
vertex_iterator = VertexModelResponseIterator(
streaming_response=None,
sync_stream=False,
@ -212,9 +218,7 @@ class VertexPassthroughLoggingHandler:
@staticmethod
def _create_vertex_response_logging_payload_for_generate_content(
litellm_model_response: Union[
litellm.ModelResponse, litellm.TextCompletionResponse
],
litellm_model_response: Union[ModelResponse, TextCompletionResponse],
model: str,
kwargs: dict,
start_time: datetime,

View file

@ -109,6 +109,7 @@ from litellm import (
CreateBatchRequest,
ListBatchRequest,
RetrieveBatchRequest,
Router,
)
from litellm._logging import verbose_proxy_logger, verbose_router_logger
from litellm.caching.caching import DualCache, RedisCache
@ -482,7 +483,7 @@ user_config_file_path: Optional[str] = None
local_logging = True # writes logs to a local api_log.json file for debugging
experimental = False
#### GLOBAL VARIABLES ####
llm_router: Optional[litellm.Router] = None
llm_router: Optional[Router] = None
llm_model_list: Optional[list] = None
general_settings: dict = {}
callback_settings: dict = {}
@ -2833,7 +2834,7 @@ class ProxyStartupEvent:
@classmethod
def _initialize_startup_logging(
cls,
llm_router: Optional[litellm.Router],
llm_router: Optional[Router],
proxy_logging_obj: ProxyLogging,
redis_usage_cache: Optional[RedisCache],
):

View file

@ -289,7 +289,7 @@ class ProxyLogging:
def startup_event(
self,
llm_router: Optional[litellm.Router],
llm_router: Optional[Router],
redis_usage_cache: Optional[RedisCache],
):
"""Initialize logging and alerting on proxy startup"""
@ -359,7 +359,7 @@ class ProxyLogging:
if redis_cache is not None:
self.internal_usage_cache.dual_cache.redis_cache = redis_cache
def _init_litellm_callbacks(self, llm_router: Optional[litellm.Router] = None):
def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
litellm.callbacks.append(self.max_parallel_request_limiter) # type: ignore
litellm.callbacks.append(self.max_budget_limiter) # type: ignore
litellm.callbacks.append(self.cache_control_check) # type: ignore

View file

@ -145,6 +145,7 @@ from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import (
CustomStreamWrapper,
EmbeddingResponse,
ModelResponse,
_is_region_eu,
calculate_max_parallel_requests,
@ -2071,7 +2072,7 @@ class Router:
input: Union[str, List],
is_async: Optional[bool] = False,
**kwargs,
) -> litellm.EmbeddingResponse:
) -> EmbeddingResponse:
try:
kwargs["model"] = model
kwargs["input"] = input
@ -2146,7 +2147,7 @@ class Router:
input: Union[str, List],
is_async: Optional[bool] = True,
**kwargs,
) -> litellm.EmbeddingResponse:
) -> EmbeddingResponse:
try:
kwargs["model"] = model
kwargs["input"] = input

View file

@ -1660,3 +1660,84 @@ class PersonalUIKeyGenerationConfig(KeyGenerationConfig):
class StandardKeyGenerationConfig(TypedDict, total=False):
team_key_generation: TeamUIKeyGenerationConfig
personal_key_generation: PersonalUIKeyGenerationConfig
class LlmProviders(str, Enum):
OPENAI = "openai"
OPENAI_LIKE = "openai_like" # embedding only
JINA_AI = "jina_ai"
XAI = "xai"
CUSTOM_OPENAI = "custom_openai"
TEXT_COMPLETION_OPENAI = "text-completion-openai"
COHERE = "cohere"
COHERE_CHAT = "cohere_chat"
CLARIFAI = "clarifai"
ANTHROPIC = "anthropic"
ANTHROPIC_TEXT = "anthropic_text"
REPLICATE = "replicate"
HUGGINGFACE = "huggingface"
TOGETHER_AI = "together_ai"
OPENROUTER = "openrouter"
VERTEX_AI = "vertex_ai"
VERTEX_AI_BETA = "vertex_ai_beta"
GEMINI = "gemini"
AI21 = "ai21"
BASETEN = "baseten"
AZURE = "azure"
AZURE_TEXT = "azure_text"
AZURE_AI = "azure_ai"
SAGEMAKER = "sagemaker"
SAGEMAKER_CHAT = "sagemaker_chat"
BEDROCK = "bedrock"
VLLM = "vllm"
NLP_CLOUD = "nlp_cloud"
PETALS = "petals"
OOBABOOGA = "oobabooga"
OLLAMA = "ollama"
OLLAMA_CHAT = "ollama_chat"
DEEPINFRA = "deepinfra"
PERPLEXITY = "perplexity"
MISTRAL = "mistral"
GROQ = "groq"
NVIDIA_NIM = "nvidia_nim"
CEREBRAS = "cerebras"
AI21_CHAT = "ai21_chat"
VOLCENGINE = "volcengine"
CODESTRAL = "codestral"
TEXT_COMPLETION_CODESTRAL = "text-completion-codestral"
DEEPSEEK = "deepseek"
SAMBANOVA = "sambanova"
MARITALK = "maritalk"
VOYAGE = "voyage"
CLOUDFLARE = "cloudflare"
XINFERENCE = "xinference"
FIREWORKS_AI = "fireworks_ai"
FRIENDLIAI = "friendliai"
WATSONX = "watsonx"
WATSONX_TEXT = "watsonx_text"
TRITON = "triton"
PREDIBASE = "predibase"
DATABRICKS = "databricks"
EMPOWER = "empower"
GITHUB = "github"
CUSTOM = "custom"
LITELLM_PROXY = "litellm_proxy"
HOSTED_VLLM = "hosted_vllm"
LM_STUDIO = "lm_studio"
GALADRIEL = "galadriel"
class LiteLLMLoggingBaseClass:
"""
Base class for logging pre and post call
Meant to simplify type checking for logging obj.
"""
def pre_call(self, input, api_key, model=None, additional_args={}):
pass
def post_call(
self, original_response, input=None, api_key=None, additional_args={}
):
pass
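
With this base class in `litellm.types.utils`, handlers (e.g. the Predibase change earlier in this diff) can type their logging parameter without importing the concrete `Logging` class. A hedged sketch of the call-site pattern (the `process_response` function is illustrative):

from litellm.types.utils import LiteLLMLoggingBaseClass, ModelResponse


def process_response(
    raw_json: dict,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingBaseClass,
) -> ModelResponse:
    # Typing against the base class keeps the pre_call()/post_call() contract
    # checkable without pulling litellm_core_utils.litellm_logging into the
    # module's import graph.
    logging_obj.post_call(original_response=raw_json)
    return model_response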

View file

@ -126,6 +126,7 @@ from litellm.types.utils import (
EmbeddingResponse,
Function,
ImageResponse,
LlmProviders,
Message,
ModelInfo,
ModelResponse,
@ -147,6 +148,7 @@ claude_json_str = json.dumps(json_data)
import importlib.metadata
from concurrent.futures import ThreadPoolExecutor
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
@ -163,6 +165,8 @@ from typing import (
from openai import OpenAIError as OriginalError
from litellm.llms.base_llm.transformation import BaseConfig
from ._logging import verbose_logger
from .caching.caching import (
Cache,
@ -235,7 +239,6 @@ last_fetched_at = None
last_fetched_at_keys = None
######## Model Response #########################
# All liteLLM Model responses will be in this format, Follows the OpenAI Format
# https://docs.litellm.ai/docs/completion/output
# {
@ -6205,13 +6208,10 @@ def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
return messages
from litellm.llms.base_llm.transformation import BaseConfig
class ProviderConfigManager:
@staticmethod
def get_provider_chat_config( # noqa: PLR0915
model: str, provider: litellm.LlmProviders
model: str, provider: LlmProviders
) -> BaseConfig:
"""
Returns the provider config for a given provider.

View file

@ -173,8 +173,8 @@ def main():
"list_organization",
"user_update",
]
directory = "../../litellm/proxy/management_endpoints" # LOCAL
# directory = "./litellm/proxy/management_endpoints"
# directory = "../../litellm/proxy/management_endpoints" # LOCAL
directory = "./litellm/proxy/management_endpoints"
# Convert function names to set for faster lookup
target_functions = set(function_names)

View file

@ -0,0 +1,162 @@
import os
import ast
import sys
from typing import List, Tuple, Optional
def find_litellm_type_hints(directory: str) -> List[Tuple[str, int, str]]:
"""
Recursively search for Python files in the given directory
and find type hints containing 'litellm.'.
Args:
directory (str): The root directory to search for Python files
Returns:
List of tuples containing (file_path, line_number, type_hint)
"""
litellm_type_hints = []
def is_litellm_type_hint(node):
"""
Recursively check if a type annotation contains 'litellm.'
Handles more complex type hints like:
- Optional[litellm.Type]
- Union[litellm.Type1, litellm.Type2]
- Nested type hints
"""
try:
# Convert node to string representation
type_str = ast.unparse(node)
# Direct check for litellm in type string
if "litellm." in type_str:
return True
# Handle more complex type hints
if isinstance(node, ast.Subscript):
# Check Union or Optional types
if isinstance(node.value, ast.Name) and node.value.id in [
"Union",
"Optional",
]:
# Check each element in the Union/Optional type
if isinstance(node.slice, ast.Tuple):
return any(is_litellm_type_hint(elt) for elt in node.slice.elts)
else:
return is_litellm_type_hint(node.slice)
# Recursive check for subscripted types
return is_litellm_type_hint(node.value) or is_litellm_type_hint(
node.slice
)
# Recursive check for attribute types
if isinstance(node, ast.Attribute):
return "litellm." in ast.unparse(node)
# Recursive check for name types
if isinstance(node, ast.Name):
return "litellm" in node.id
return False
except Exception:
# Fallback to string checking if parsing fails
try:
return "litellm." in ast.unparse(node)
except:
return False
def scan_file(file_path: str):
"""
Scan a single Python file for LiteLLM type hints
"""
try:
# Use utf-8-sig to handle files with BOM, ignore errors
with open(file_path, "r", encoding="utf-8-sig", errors="ignore") as file:
tree = ast.parse(file.read())
for node in ast.walk(tree):
# Check type annotations in variable annotations
if isinstance(node, ast.AnnAssign) and node.annotation:
if is_litellm_type_hint(node.annotation):
litellm_type_hints.append(
(file_path, node.lineno, ast.unparse(node.annotation))
)
# Check type hints in function arguments
elif isinstance(node, ast.FunctionDef):
for arg in node.args.args:
if arg.annotation and is_litellm_type_hint(arg.annotation):
litellm_type_hints.append(
(file_path, arg.lineno, ast.unparse(arg.annotation))
)
# Check return type annotation
if node.returns and is_litellm_type_hint(node.returns):
litellm_type_hints.append(
(file_path, node.lineno, ast.unparse(node.returns))
)
except SyntaxError as e:
print(f"Syntax error in {file_path}: {e}", file=sys.stderr)
except Exception as e:
print(f"Error processing {file_path}: {e}", file=sys.stderr)
# Recursively walk through directory
for root, dirs, files in os.walk(directory):
# Remove virtual environment and cache directories from search
dirs[:] = [
d
for d in dirs
if not any(
venv in d
for venv in [
"venv",
"env",
"myenv",
".venv",
"__pycache__",
".pytest_cache",
]
)
]
for file in files:
if file.endswith(".py"):
full_path = os.path.join(root, file)
# Skip files in virtual environment or cache directories
if not any(
venv in full_path
for venv in [
"venv",
"env",
"myenv",
".venv",
"__pycache__",
".pytest_cache",
]
):
scan_file(full_path)
return litellm_type_hints
def main():
# Get directory from command line argument or use current directory
directory = "./litellm/"
# Find LiteLLM type hints
results = find_litellm_type_hints(directory)
# Print results
if results:
print("LiteLLM Type Hints Found:")
for file_path, line_num, type_hint in results:
print(f"{file_path}:{line_num} - {type_hint}")
else:
print("No LiteLLM type hints found.")
if __name__ == "__main__":
main()
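
For reference, the style of annotation this script flags versus the style the rest of the commit converts call sites to (illustrative snippet with made-up function names, not part of the test itself):

import litellm
from litellm.types.utils import ModelResponse


# Flagged: ast.unparse() of the annotation contains "litellm.", i.e. the
# attribute-style hint that keeps modules coupled to the litellm package.
def flagged_example(response: litellm.ModelResponse) -> litellm.ModelResponse:
    return response


# Not flagged: the type is imported directly from litellm.types.utils,
# matching the conversions in the diffs above.
def clean_example(response: ModelResponse) -> ModelResponse:
    return response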