Compare commits


16 commits

Author SHA1 Message Date
Krrish Dholakia
5ae71bda70 fix: new run 2024-11-14 23:42:01 +05:30
Krrish Dholakia
4b1f390381 test: fix test 2024-11-14 23:35:23 +05:30
Krrish Dholakia
7d4829f908 test: handle gemini error 2024-11-14 22:26:06 +05:30
Krrish Dholakia
8b4c33c398 fix(exception_mapping_utils.py): map 'model is overloaded' to internal server error 2024-11-14 20:03:21 +05:30
Krish Dholakia
68d81f88f9
Litellm router disable fallbacks (#6743)
* bump: version 1.52.6 → 1.52.7

* feat(router.py): enable dynamically disabling fallbacks

Allows for enabling/disabling fallbacks per key

* feat(litellm_pre_call_utils.py): support setting 'disable_fallbacks' on litellm key

* test: fix test
2024-11-14 19:15:13 +05:30
Krish Dholakia
02b6f69004
Litellm router trace (#6742)
* feat(router.py): add trace_id to parent functions - allows tracking retry/fallbacks

* feat(router.py): log trace id across retry/fallback logic

allows grouping llm logs for the same request

* test: fix tests

* fix: fix test

* fix(transformation.py): only set non-none stop_sequences
2024-11-14 19:13:36 +05:30
Krrish Dholakia
9053a6a203 test: fix test 2024-11-14 19:11:29 +05:30
Krrish Dholakia
b4f7771ed1 test: update test to handle overloaded error 2024-11-14 19:02:57 +05:30
Krrish Dholakia
fef31a1c01 fix(handler.py): refactor together ai rerank call 2024-11-14 18:54:08 +05:30
Krrish Dholakia
8bf0f57b59 test: handle service unavailable error 2024-11-14 18:13:56 +05:30
Krrish Dholakia
51ec270501 feat(jina_ai/): add rerank support
Closes https://github.com/BerriAI/litellm/issues/6691
2024-11-14 18:11:18 +05:30
Krrish Dholakia
1988b13f46 fix: fix tests 2024-11-14 17:08:29 +05:30
Krrish Dholakia
b3e367db19 test: fix tests 2024-11-14 17:00:08 +05:30
Krrish Dholakia
62b97388a4 docs(anthropic.md): document all supported openai params for anthropic 2024-11-14 12:13:29 +05:30
Krrish Dholakia
756d838dfa feat(anthropic/chat/transformation.py): support passing user id to anthropic via openai 'user' param 2024-11-14 12:07:23 +05:30
Krrish Dholakia
b6c9032454 fix(ollama.py): fix get model info request
Fixes https://github.com/BerriAI/litellm/issues/6703
2024-11-14 01:04:33 +05:30
37 changed files with 856 additions and 249 deletions

View file

@@ -957,3 +957,69 @@ curl http://0.0.0.0:4000/v1/chat/completions \
```
</TabItem>
</Tabs>
## Usage - passing 'user_id' to Anthropic
LiteLLM translates the OpenAI `user` param to Anthropic's `metadata[user_id]` param.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
response = completion(
model="claude-3-5-sonnet-20240620",
messages=messages,
user="user_123",
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: claude-3-5-sonnet-20240620
litellm_params:
model: anthropic/claude-3-5-sonnet-20240620
api_key: os.environ/ANTHROPIC_API_KEY
```
2. Start Proxy
```
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl http://0.0.0.0:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer <YOUR-LITELLM-KEY>" \
-d '{
"model": "claude-3-5-sonnet-20240620",
"messages": [{"role": "user", "content": "What is Anthropic?"}],
"user": "user_123"
}'
```
</TabItem>
</Tabs>
## All Supported OpenAI Params
```
"stream",
"stop",
"temperature",
"top_p",
"max_tokens",
"max_completion_tokens",
"tools",
"tool_choice",
"extra_headers",
"parallel_tool_calls",
"response_format",
"user"
```

View file

@@ -1124,10 +1124,13 @@ def exception_type(  # type: ignore  # noqa: PLR0915
                        ),
                    ),
                )
-           elif "500 Internal Server Error" in error_str:
+           elif (
+               "500 Internal Server Error" in error_str
+               or "The model is overloaded." in error_str
+           ):
                exception_mapping_worked = True
-               raise ServiceUnavailableError(
-                   message=f"litellm.ServiceUnavailableError: VertexAIException - {error_str}",
+               raise litellm.InternalServerError(
+                   message=f"litellm.InternalServerError: VertexAIException - {error_str}",
                    model=model,
                    llm_provider="vertex_ai",
                    litellm_debug_info=extra_information,
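
A hedged sketch of what the remapping above means for callers (the model name is illustrative, and the branch only triggers when Vertex/Gemini actually reports an overload):

```python
import litellm

try:
    litellm.completion(
        model="vertex_ai/gemini-1.5-pro",  # illustrative deployment name
        messages=[{"role": "user", "content": "hi"}],
    )
except litellm.InternalServerError:
    # "The model is overloaded." now maps here instead of ServiceUnavailableError
    pass
```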

View file

@@ -201,6 +201,7 @@ class Logging:
        start_time,
        litellm_call_id: str,
        function_id: str,
+       litellm_trace_id: Optional[str] = None,
        dynamic_input_callbacks: Optional[
            List[Union[str, Callable, CustomLogger]]
        ] = None,
@@ -238,6 +239,7 @@
        self.start_time = start_time  # log the call start time
        self.call_type = call_type
        self.litellm_call_id = litellm_call_id
+       self.litellm_trace_id = litellm_trace_id
        self.function_id = function_id
        self.streaming_chunks: List[Any] = []  # for generating complete stream response
        self.sync_streaming_chunks: List[Any] = (
@@ -274,6 +276,11 @@
        self.completion_start_time: Optional[datetime.datetime] = None
        self._llm_caching_handler: Optional[LLMCachingHandler] = None

+       self.model_call_details = {
+           "litellm_trace_id": litellm_trace_id,
+           "litellm_call_id": litellm_call_id,
+       }
+
    def process_dynamic_callbacks(self):
        """
        Initializes CustomLogger compatible callbacks in self.dynamic_* callbacks
@@ -381,21 +388,23 @@
        self.logger_fn = litellm_params.get("logger_fn", None)
        verbose_logger.debug(f"self.optional_params: {self.optional_params}")

-       self.model_call_details = {
-           "model": self.model,
-           "messages": self.messages,
-           "optional_params": self.optional_params,
-           "litellm_params": self.litellm_params,
-           "start_time": self.start_time,
-           "stream": self.stream,
-           "user": user,
-           "call_type": str(self.call_type),
-           "litellm_call_id": self.litellm_call_id,
-           "completion_start_time": self.completion_start_time,
-           "standard_callback_dynamic_params": self.standard_callback_dynamic_params,
-           **self.optional_params,
-           **additional_params,
-       }
+       self.model_call_details.update(
+           {
+               "model": self.model,
+               "messages": self.messages,
+               "optional_params": self.optional_params,
+               "litellm_params": self.litellm_params,
+               "start_time": self.start_time,
+               "stream": self.stream,
+               "user": user,
+               "call_type": str(self.call_type),
+               "litellm_call_id": self.litellm_call_id,
+               "completion_start_time": self.completion_start_time,
+               "standard_callback_dynamic_params": self.standard_callback_dynamic_params,
+               **self.optional_params,
+               **additional_params,
+           }
+       )

        ## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation
        if "stream_options" in additional_params:
@@ -2806,6 +2815,7 @@ def get_standard_logging_object_payload(
    payload: StandardLoggingPayload = StandardLoggingPayload(
        id=str(id),
+       trace_id=kwargs.get("litellm_trace_id"),  # type: ignore
        call_type=call_type or "",
        cache_hit=cache_hit,
        status=status,
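
A minimal sketch of consuming the new `trace_id` from a callback, assuming a `CustomLogger` subclass (the field surfaces in `standard_logging_object`, which the tests further down exercise):

```python
import litellm
from litellm.integrations.custom_logger import CustomLogger


class TraceGrouper(CustomLogger):
    """Group retry/fallback attempts of one request by their shared trace_id."""

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        slp = kwargs.get("standard_logging_object") or {}
        print("trace_id:", slp.get("trace_id"), "call_id:", kwargs.get("litellm_call_id"))


litellm.callbacks = [TraceGrouper()]
```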

View file

@@ -440,8 +440,8 @@ class AnthropicChatCompletion(BaseLLM):
        logging_obj,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
+       litellm_params: dict,
        acompletion=None,
-       litellm_params=None,
        logger_fn=None,
        headers={},
        client=None,
@@ -464,6 +464,7 @@
            model=model,
            messages=messages,
            optional_params=optional_params,
+           litellm_params=litellm_params,
            headers=headers,
            _is_function_call=_is_function_call,
            is_vertex_request=is_vertex_request,

View file

@@ -91,6 +91,7 @@ class AnthropicConfig:
            "extra_headers",
            "parallel_tool_calls",
            "response_format",
+           "user",
        ]

    def get_cache_control_headers(self) -> dict:
@@ -246,6 +247,28 @@
            anthropic_tools.append(new_tool)
        return anthropic_tools

+   def _map_stop_sequences(
+       self, stop: Optional[Union[str, List[str]]]
+   ) -> Optional[List[str]]:
+       new_stop: Optional[List[str]] = None
+       if isinstance(stop, str):
+           if (
+               stop == "\n"
+           ) and litellm.drop_params is True:  # anthropic doesn't allow whitespace characters as stop-sequences
+               return new_stop
+           new_stop = [stop]
+       elif isinstance(stop, list):
+           new_v = []
+           for v in stop:
+               if (
+                   v == "\n"
+               ) and litellm.drop_params is True:  # anthropic doesn't allow whitespace characters as stop-sequences
+                   continue
+               new_v.append(v)
+           if len(new_v) > 0:
+               new_stop = new_v
+       return new_stop
+
    def map_openai_params(
        self,
        non_default_params: dict,
@@ -271,26 +294,10 @@
                optional_params["tool_choice"] = _tool_choice
            if param == "stream" and value is True:
                optional_params["stream"] = value
-           if param == "stop":
-               if isinstance(value, str):
-                   if (
-                       value == "\n"
-                   ) and litellm.drop_params is True:  # anthropic doesn't allow whitespace characters as stop-sequences
-                       continue
-                   value = [value]
-               elif isinstance(value, list):
-                   new_v = []
-                   for v in value:
-                       if (
-                           v == "\n"
-                       ) and litellm.drop_params is True:  # anthropic doesn't allow whitespace characters as stop-sequences
-                           continue
-                       new_v.append(v)
-                   if len(new_v) > 0:
-                       value = new_v
-                   else:
-                       continue
-               optional_params["stop_sequences"] = value
+           if param == "stop" and (isinstance(value, str) or isinstance(value, list)):
+               _value = self._map_stop_sequences(value)
+               if _value is not None:
+                   optional_params["stop_sequences"] = _value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
@@ -314,7 +321,8 @@
                optional_params["tools"] = [_tool]
                optional_params["tool_choice"] = _tool_choice
                optional_params["json_mode"] = True
+           if param == "user":
+               optional_params["metadata"] = {"user_id": value}
        ## VALIDATE REQUEST
        """
        Anthropic doesn't support tool calling without `tools=` param specified.
@@ -465,6 +473,7 @@
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
+       litellm_params: dict,
        headers: dict,
        _is_function_call: bool,
        is_vertex_request: bool,
@@ -502,6 +511,15 @@
        if "tools" in optional_params:
            _is_function_call = True

+       ## Handle user_id in metadata
+       _litellm_metadata = litellm_params.get("metadata", None)
+       if (
+           _litellm_metadata
+           and isinstance(_litellm_metadata, dict)
+           and "user_id" in _litellm_metadata
+       ):
+           optional_params["metadata"] = {"user_id": _litellm_metadata["user_id"]}
+
        data = {
            "messages": anthropic_messages,
            **optional_params,
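
Roughly, the refactored stop handling behaves like this — a sketch via `get_optional_params`, assuming `litellm.drop_params` is enabled so the bare newline gets dropped:

```python
import litellm
from litellm.utils import get_optional_params

litellm.drop_params = True  # Anthropic rejects whitespace-only stop sequences
optional_params = get_optional_params(
    model="claude-3-5-sonnet-20240620",
    custom_llm_provider="anthropic",
    stop=["\n", "###"],
)
print(optional_params.get("stop_sequences"))  # expected: ['###']
```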

View file

@@ -76,4 +76,4 @@ class JinaAIEmbeddingConfig:
            or get_secret_str("JINA_AI_API_KEY")
            or get_secret_str("JINA_AI_TOKEN")
        )
-       return LlmProviders.OPENAI_LIKE.value, api_base, dynamic_api_key
+       return LlmProviders.JINA_AI.value, api_base, dynamic_api_key

View file

@@ -0,0 +1,96 @@
"""
Re rank api
LiteLLM supports the re rank API format, no parameter transformation occurs
"""
import uuid
from typing import Any, Dict, List, Optional, Union
import httpx
from pydantic import BaseModel
import litellm
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.jina_ai.rerank.transformation import JinaAIRerankConfig
from litellm.types.rerank import RerankRequest, RerankResponse
class JinaAIRerank(BaseLLM):
def rerank(
self,
model: str,
api_key: str,
query: str,
documents: List[Union[str, Dict[str, Any]]],
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
_is_async: Optional[bool] = False,
) -> RerankResponse:
client = _get_httpx_client()
request_data = RerankRequest(
model=model,
query=query,
top_n=top_n,
documents=documents,
rank_fields=rank_fields,
return_documents=return_documents,
)
# exclude None values from request_data
request_data_dict = request_data.dict(exclude_none=True)
if _is_async:
return self.async_rerank(request_data_dict, api_key) # type: ignore # Call async method
response = client.post(
"https://api.jina.ai/v1/rerank",
headers={
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {api_key}",
},
json=request_data_dict,
)
if response.status_code != 200:
raise Exception(response.text)
_json_response = response.json()
return JinaAIRerankConfig()._transform_response(_json_response)
async def async_rerank( # New async method
self,
request_data_dict: Dict[str, Any],
api_key: str,
) -> RerankResponse:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.JINA_AI
) # Use async client
response = await client.post(
"https://api.jina.ai/v1/rerank",
headers={
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {api_key}",
},
json=request_data_dict,
)
if response.status_code != 200:
raise Exception(response.text)
_json_response = response.json()
return JinaAIRerankConfig()._transform_response(_json_response)
pass
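
A short usage sketch for the new provider (assumes `JINA_AI_API_KEY` is set in the environment; the model name matches the one used in the tests further down):

```python
import litellm

# Jina AI rerank via the standard litellm.rerank() entrypoint.
response = litellm.rerank(
    model="jina_ai/jina-reranker-v2-base-multilingual",
    query="hello",
    documents=["hello", "world"],
    top_n=2,
)
print(response.results)
```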

View file

@@ -0,0 +1,36 @@
"""
Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format.
Why separate file? Make it easy to see how transformation works
Docs - https://jina.ai/reranker
"""
import uuid
from typing import List, Optional
from litellm.types.rerank import (
RerankBilledUnits,
RerankResponse,
RerankResponseMeta,
RerankTokens,
)
class JinaAIRerankConfig:
def _transform_response(self, response: dict) -> RerankResponse:
_billed_units = RerankBilledUnits(**response.get("usage", {}))
_tokens = RerankTokens(**response.get("usage", {}))
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
_results: Optional[List[dict]] = response.get("results")
if _results is None:
raise ValueError(f"No results found in the response={response}")
return RerankResponse(
id=response.get("id") or str(uuid.uuid4()),
results=_results,
meta=rerank_meta,
) # Return response

View file

@@ -185,6 +185,8 @@ class OllamaConfig:
            "name": "mistral"
        }'
        """
+       if model.startswith("ollama/") or model.startswith("ollama_chat/"):
+           model = model.split("/", 1)[1]
        api_base = get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434"
        try:
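
A quick sketch of the fixed lookup (assumes a local Ollama server with `mistral` pulled; before the prefix-stripping above, the `/api/show` request received the literal `ollama/mistral` name):

```python
from litellm import get_model_info

# The provider prefix is now stripped before Ollama's /api/show request.
info = get_model_info("ollama/mistral")
print(info["supports_function_calling"])
```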

View file

@@ -15,7 +15,14 @@ from litellm.llms.custom_httpx.http_handler import (
    _get_httpx_client,
    get_async_httpx_client,
)
-from litellm.types.rerank import RerankRequest, RerankResponse
+from litellm.llms.together_ai.rerank.transformation import TogetherAIRerankConfig
+from litellm.types.rerank import (
+    RerankBilledUnits,
+    RerankRequest,
+    RerankResponse,
+    RerankResponseMeta,
+    RerankTokens,
+)

class TogetherAIRerank(BaseLLM):
@@ -65,13 +72,7 @@ class TogetherAIRerank(BaseLLM):
        _json_response = response.json()
-       response = RerankResponse(
-           id=_json_response.get("id"),
-           results=_json_response.get("results"),
-           meta=_json_response.get("meta") or {},
-       )
-       return response
+       return TogetherAIRerankConfig()._transform_response(_json_response)

    async def async_rerank(  # New async method
        self,
@@ -97,10 +98,4 @@
        _json_response = response.json()
-       return RerankResponse(
-           id=_json_response.get("id"),
-           results=_json_response.get("results"),
-           meta=_json_response.get("meta") or {},
-       )  # Return response
-
-   pass
+       return TogetherAIRerankConfig()._transform_response(_json_response)

View file

@@ -0,0 +1,34 @@
"""
Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format.
Why separate file? Make it easy to see how transformation works
"""
import uuid
from typing import List, Optional
from litellm.types.rerank import (
RerankBilledUnits,
RerankResponse,
RerankResponseMeta,
RerankTokens,
)
class TogetherAIRerankConfig:
def _transform_response(self, response: dict) -> RerankResponse:
_billed_units = RerankBilledUnits(**response.get("usage", {}))
_tokens = RerankTokens(**response.get("usage", {}))
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
_results: Optional[List[dict]] = response.get("results")
if _results is None:
raise ValueError(f"No results found in the response={response}")
return RerankResponse(
id=response.get("id") or str(uuid.uuid4()),
results=_results,
meta=rerank_meta,
) # Return response

View file

@@ -1066,6 +1066,7 @@ def completion(  # type: ignore  # noqa: PLR0915
        azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
        user_continue_message=kwargs.get("user_continue_message"),
        base_model=base_model,
+       litellm_trace_id=kwargs.get("litellm_trace_id"),
    )
    logging.update_environment_variables(
        model=model,
@@ -3455,7 +3456,7 @@ def embedding(  # noqa: PLR0915
            client=client,
            aembedding=aembedding,
        )
-   elif custom_llm_provider == "openai_like":
+   elif custom_llm_provider == "openai_like" or custom_llm_provider == "jina_ai":
        api_base = (
            api_base or litellm.api_base or get_secret_str("OPENAI_LIKE_API_BASE")
        )

View file

@@ -1,122 +1,15 @@
 model_list:
-  - model_name: "*" # GPT-4 Turbo Models
-    litellm_params:
-      model: claude-3-5-sonnet-20240620
-      api_key: os.environ/ANTHROPIC_API_KEY
-  - model_name: claude-3-5-sonnet-aihubmix
-    litellm_params:
-      model: openai/claude-3-5-sonnet-20240620
-      input_cost_per_token: 0.000003 # 3$/M
-      output_cost_per_token: 0.000015 # 15$/M
-      api_base: "https://exampleopenaiendpoint-production.up.railway.app"
-      api_key: my-fake-key
-  - model_name: fake-openai-endpoint-2
-    litellm_params:
-      model: openai/my-fake-model
-      api_key: my-fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/
-      stream_timeout: 0.001
-      timeout: 1
-      rpm: 1
-  - model_name: fake-openai-endpoint
-    litellm_params:
-      model: openai/my-fake-model
-      api_key: my-fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/
-  ## bedrock chat completions
-  - model_name: "*anthropic.claude*"
-    litellm_params:
-      model: bedrock/*anthropic.claude*
-      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
-      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
-      aws_region_name: os.environ/AWS_REGION_NAME
-      guardrailConfig:
-        "guardrailIdentifier": "h4dsqwhp6j66"
-        "guardrailVersion": "2"
-        "trace": "enabled"
-  ## bedrock embeddings
-  - model_name: "*amazon.titan-embed-*"
-    litellm_params:
-      model: bedrock/amazon.titan-embed-*
-      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
-      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
-      aws_region_name: os.environ/AWS_REGION_NAME
-  - model_name: "*cohere.embed-*"
-    litellm_params:
-      model: bedrock/cohere.embed-*
-      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
-      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
-      aws_region_name: os.environ/AWS_REGION_NAME
-  - model_name: "bedrock/*"
-    litellm_params:
-      model: bedrock/*
-      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
-      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
-      aws_region_name: os.environ/AWS_REGION_NAME
   - model_name: gpt-4
     litellm_params:
-      model: azure/chatgpt-v-2
-      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
-      api_version: "2023-05-15"
-      api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
-      rpm: 480
-      timeout: 300
-      stream_timeout: 60
-litellm_settings:
-  fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
-  # callbacks: ["otel", "prometheus"]
-  default_redis_batch_cache_expiry: 10
-  # default_team_settings:
-  #   - team_id: "dbe2f686-a686-4896-864a-4c3924458709"
-  #     success_callback: ["langfuse"]
-  #     langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1
-  #     langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1
-# litellm_settings:
-#   cache: True
-#   cache_params:
-#     type: redis
-#     # disable caching on the actual API call
-#     supported_call_types: []
-#     # see https://docs.litellm.ai/docs/proxy/prod#3-use-redis-porthost-password-not-redis_url
-#     host: os.environ/REDIS_HOST
-#     port: os.environ/REDIS_PORT
-#     password: os.environ/REDIS_PASSWORD
-#     # see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
-#     # see https://docs.litellm.ai/docs/proxy/prometheus
-#   callbacks: ['otel']
-# # router_settings:
-# #   routing_strategy: latency-based-routing
-# #   routing_strategy_args:
-# #     # only assign 40% of traffic to the fastest deployment to avoid overloading it
-# #     lowest_latency_buffer: 0.4
-# #     # consider last five minutes of calls for latency calculation
-# #     ttl: 300
-# #   redis_host: os.environ/REDIS_HOST
-# #   redis_port: os.environ/REDIS_PORT
-# #   redis_password: os.environ/REDIS_PASSWORD
-# # # see https://docs.litellm.ai/docs/proxy/prod#1-use-this-configyaml
-# # general_settings:
-# #   master_key: os.environ/LITELLM_MASTER_KEY
-# #   database_url: os.environ/DATABASE_URL
-# #   disable_master_key_return: true
-# #   # alerting: ['slack', 'email']
-# #   alerting: ['email']
-# #   # Batch write spend updates every 60s
-# #   proxy_batch_write_at: 60
-# #   # see https://docs.litellm.ai/docs/proxy/caching#advanced---user-api-key-cache-ttl
-# #   # our api keys rarely change
-# #   user_api_key_cache_ttl: 3600
+      model: gpt-4
+  - model_name: rerank-model
+    litellm_params:
+      model: jina_ai/jina-reranker-v2-base-multilingual
+
+router_settings:
+  model_group_alias:
+    "gpt-4-turbo": # Aliased model name
+      model: "gpt-4" # Actual model name in 'model_list'
+      hidden: true
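
With a config like the new one above, the proxy's rerank route can be exercised roughly like this (a sketch only; the port and `/rerank` path assume the default `litellm --config ...` setup, and the bearer token is a placeholder):

```python
import httpx

# Call the proxy for the aliased "rerank-model" entry defined in the config.
resp = httpx.post(
    "http://0.0.0.0:4000/rerank",
    headers={"Authorization": "Bearer sk-1234"},  # placeholder proxy key
    json={
        "model": "rerank-model",
        "query": "hello",
        "documents": ["hello", "world"],
    },
)
print(resp.json())
```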

View file

@@ -8,6 +8,7 @@ Run checks for:
2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
import time
import traceback
from datetime import datetime

View file

@@ -274,6 +274,51 @@ class LiteLLMProxyRequestSetup:
        )
        return user_api_key_logged_metadata
@staticmethod
def add_key_level_controls(
key_metadata: dict, data: dict, _metadata_variable_name: str
):
data = data.copy()
if "cache" in key_metadata:
data["cache"] = {}
if isinstance(key_metadata["cache"], dict):
for k, v in key_metadata["cache"].items():
if k in SupportedCacheControls:
data["cache"][k] = v
## KEY-LEVEL SPEND LOGS / TAGS
if "tags" in key_metadata and key_metadata["tags"] is not None:
if "tags" in data[_metadata_variable_name] and isinstance(
data[_metadata_variable_name]["tags"], list
):
data[_metadata_variable_name]["tags"].extend(key_metadata["tags"])
else:
data[_metadata_variable_name]["tags"] = key_metadata["tags"]
if "spend_logs_metadata" in key_metadata and isinstance(
key_metadata["spend_logs_metadata"], dict
):
if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
data[_metadata_variable_name]["spend_logs_metadata"], dict
):
for key, value in key_metadata["spend_logs_metadata"].items():
if (
key not in data[_metadata_variable_name]["spend_logs_metadata"]
): # don't override k-v pair sent by request (user request)
data[_metadata_variable_name]["spend_logs_metadata"][
key
] = value
else:
data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[
"spend_logs_metadata"
]
## KEY-LEVEL DISABLE FALLBACKS
if "disable_fallbacks" in key_metadata and isinstance(
key_metadata["disable_fallbacks"], bool
):
data["disable_fallbacks"] = key_metadata["disable_fallbacks"]
return data
async def add_litellm_data_to_request(  # noqa: PLR0915
    data: dict,
@@ -389,37 +434,11 @@
    ### KEY-LEVEL Controls
    key_metadata = user_api_key_dict.metadata
-   if "cache" in key_metadata:
-       data["cache"] = {}
-       if isinstance(key_metadata["cache"], dict):
-           for k, v in key_metadata["cache"].items():
-               if k in SupportedCacheControls:
-                   data["cache"][k] = v
-   ## KEY-LEVEL SPEND LOGS / TAGS
-   if "tags" in key_metadata and key_metadata["tags"] is not None:
-       if "tags" in data[_metadata_variable_name] and isinstance(
-           data[_metadata_variable_name]["tags"], list
-       ):
-           data[_metadata_variable_name]["tags"].extend(key_metadata["tags"])
-       else:
-           data[_metadata_variable_name]["tags"] = key_metadata["tags"]
-   if "spend_logs_metadata" in key_metadata and isinstance(
-       key_metadata["spend_logs_metadata"], dict
-   ):
-       if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
-           data[_metadata_variable_name]["spend_logs_metadata"], dict
-       ):
-           for key, value in key_metadata["spend_logs_metadata"].items():
-               if (
-                   key not in data[_metadata_variable_name]["spend_logs_metadata"]
-               ):  # don't override k-v pair sent by request (user request)
-                   data[_metadata_variable_name]["spend_logs_metadata"][key] = value
-       else:
-           data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[
-               "spend_logs_metadata"
-           ]
+   data = LiteLLMProxyRequestSetup.add_key_level_controls(
+       key_metadata=key_metadata,
+       data=data,
+       _metadata_variable_name=_metadata_variable_name,
+   )

    ## TEAM-LEVEL SPEND LOGS/TAGS
    team_metadata = user_api_key_dict.team_metadata or {}
    if "tags" in team_metadata and team_metadata["tags"] is not None:

View file

@@ -8,7 +8,8 @@ from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.azure_ai.rerank import AzureAIRerank
from litellm.llms.cohere.rerank import CohereRerank
-from litellm.llms.together_ai.rerank import TogetherAIRerank
+from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
+from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
from litellm.secret_managers.main import get_secret
from litellm.types.rerank import RerankRequest, RerankResponse
from litellm.types.router import *
@@ -19,6 +20,7 @@ from litellm.utils import client, exception_type, supports_httpx_timeout
cohere_rerank = CohereRerank()
together_rerank = TogetherAIRerank()
azure_ai_rerank = AzureAIRerank()
+jina_ai_rerank = JinaAIRerank()

#################################################
@@ -247,7 +249,23 @@ def rerank(
            api_key=api_key,
            _is_async=_is_async,
        )
+   elif _custom_llm_provider == "jina_ai":
+       if dynamic_api_key is None:
+           raise ValueError(
+               "Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
+           )
+       response = jina_ai_rerank.rerank(
+           model=model,
+           api_key=dynamic_api_key,
+           query=query,
+           documents=documents,
+           top_n=top_n,
+           rank_fields=rank_fields,
+           return_documents=return_documents,
+           max_chunks_per_doc=max_chunks_per_doc,
+           _is_async=_is_async,
+       )
    else:
        raise ValueError(f"Unsupported provider: {_custom_llm_provider}")

View file

@@ -679,9 +679,8 @@ class Router:
            kwargs["model"] = model
            kwargs["messages"] = messages
            kwargs["original_function"] = self._completion
-           kwargs.get("request_timeout", self.timeout)
-           kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-           kwargs.setdefault("metadata", {}).update({"model_group": model})
+           self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
            response = self.function_with_fallbacks(**kwargs)
            return response
        except Exception as e:
@@ -783,8 +782,7 @@
        kwargs["stream"] = stream
        kwargs["original_function"] = self._acompletion
        kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-       kwargs.setdefault("metadata", {}).update({"model_group": model})
+       self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)

        request_priority = kwargs.get("priority") or self.default_priority
@@ -948,6 +946,17 @@
            self.fail_calls[model_name] += 1
            raise e

+   def _update_kwargs_before_fallbacks(self, model: str, kwargs: dict) -> None:
+       """
+       Adds/updates to kwargs:
+       - num_retries
+       - litellm_trace_id
+       - metadata
+       """
+       kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+       kwargs.setdefault("litellm_trace_id", str(uuid.uuid4()))
+       kwargs.setdefault("metadata", {}).update({"model_group": model})
+
    def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
        """
        Adds default litellm params to kwargs, if set.
@@ -1511,9 +1520,7 @@
            kwargs["model"] = model
            kwargs["file"] = file
            kwargs["original_function"] = self._atranscription
-           kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-           kwargs.get("request_timeout", self.timeout)
-           kwargs.setdefault("metadata", {}).update({"model_group": model})
+           self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
            response = await self.async_function_with_fallbacks(**kwargs)

            return response
@@ -1688,9 +1695,7 @@
        kwargs["model"] = model
        kwargs["input"] = input
        kwargs["original_function"] = self._arerank
-       kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-       kwargs.get("request_timeout", self.timeout)
-       kwargs.setdefault("metadata", {}).update({"model_group": model})
+       self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)

        response = await self.async_function_with_fallbacks(**kwargs)
@@ -1839,9 +1844,7 @@
        kwargs["model"] = model
        kwargs["prompt"] = prompt
        kwargs["original_function"] = self._atext_completion
-       kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-       kwargs.get("request_timeout", self.timeout)
-       kwargs.setdefault("metadata", {}).update({"model_group": model})
+       self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
        response = await self.async_function_with_fallbacks(**kwargs)

        return response
@@ -2112,9 +2115,7 @@
        kwargs["model"] = model
        kwargs["input"] = input
        kwargs["original_function"] = self._aembedding
-       kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-       kwargs.get("request_timeout", self.timeout)
-       kwargs.setdefault("metadata", {}).update({"model_group": model})
+       self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
        response = await self.async_function_with_fallbacks(**kwargs)

        return response
    except Exception as e:
@@ -2609,6 +2610,7 @@
        If it fails after num_retries, fall back to another model group
        """
        model_group: Optional[str] = kwargs.get("model")
+       disable_fallbacks: Optional[bool] = kwargs.pop("disable_fallbacks", False)
        fallbacks: Optional[List] = kwargs.get("fallbacks", self.fallbacks)
        context_window_fallbacks: Optional[List] = kwargs.get(
            "context_window_fallbacks", self.context_window_fallbacks
@@ -2616,6 +2618,7 @@
        content_policy_fallbacks: Optional[List] = kwargs.get(
            "content_policy_fallbacks", self.content_policy_fallbacks
        )
        try:
            self._handle_mock_testing_fallbacks(
                kwargs=kwargs,
@@ -2635,7 +2638,7 @@
            original_model_group: Optional[str] = kwargs.get("model")  # type: ignore
            fallback_failure_exception_str = ""

-           if original_model_group is None:
+           if disable_fallbacks is True or original_model_group is None:
                raise e

            input_kwargs = {
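
A sketch of the per-request switch on the router side (model names and keys are placeholders; it mirrors the fallback test added near the end of this diff):

```python
import asyncio
import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "bad-model",
            "litellm_params": {"model": "openai/my-bad-model", "api_key": "my-bad-api-key"},
        },
        {
            "model_name": "good-model",
            "litellm_params": {"model": "gpt-4o", "api_key": os.getenv("OPENAI_API_KEY")},
        },
    ],
    fallbacks=[{"bad-model": ["good-model"]}],
)


async def main():
    try:
        await router.acompletion(
            model="bad-model",
            messages=[{"role": "user", "content": "hi"}],
            disable_fallbacks=True,  # fail fast instead of retrying on good-model
        )
    except Exception as e:
        print("raised without falling back:", type(e).__name__)


asyncio.run(main())
```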

View file

@@ -7,6 +7,7 @@ https://docs.cohere.com/reference/rerank
from typing import List, Optional, Union

from pydantic import BaseModel, PrivateAttr
+from typing_extensions import TypedDict

class RerankRequest(BaseModel):
@@ -19,10 +20,26 @@
    max_chunks_per_doc: Optional[int] = None

+class RerankBilledUnits(TypedDict, total=False):
+    search_units: int
+    total_tokens: int
+
+class RerankTokens(TypedDict, total=False):
+    input_tokens: int
+    output_tokens: int
+
+class RerankResponseMeta(TypedDict, total=False):
+    api_version: dict
+    billed_units: RerankBilledUnits
+    tokens: RerankTokens
+
class RerankResponse(BaseModel):
    id: str
    results: List[dict]  # Contains index and relevance_score
-   meta: Optional[dict] = None  # Contains api_version and billed_units
+   meta: Optional[RerankResponseMeta] = None  # Contains api_version and billed_units

    # Define private attributes using PrivateAttr
    _hidden_params: dict = PrivateAttr(default_factory=dict)

View file

@@ -150,6 +150,8 @@ class GenericLiteLLMParams(BaseModel):
    max_retries: Optional[int] = None
    organization: Optional[str] = None  # for openai orgs
    configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
+   ## LOGGING PARAMS ##
+   litellm_trace_id: Optional[str] = None
    ## UNIFIED PROJECT/REGION ##
    region_name: Optional[str] = None
    ## VERTEX AI ##
@@ -186,6 +188,8 @@
        None  # timeout when making stream=True calls, if str, pass in as os.environ/
    ),
    organization: Optional[str] = None,  # for openai orgs
+   ## LOGGING PARAMS ##
+   litellm_trace_id: Optional[str] = None,
    ## UNIFIED PROJECT/REGION ##
    region_name: Optional[str] = None,
    ## VERTEX AI ##

View file

@@ -1334,6 +1334,7 @@ class ResponseFormatChunk(TypedDict, total=False):
all_litellm_params = [
    "metadata",
+   "litellm_trace_id",
    "tags",
    "acompletion",
    "aimg_generation",
@@ -1523,6 +1524,7 @@ StandardLoggingPayloadStatus = Literal["success", "failure"]
class StandardLoggingPayload(TypedDict):
    id: str
+   trace_id: str  # Trace multiple LLM calls belonging to same overall request (e.g. fallbacks/retries)
    call_type: str
    response_cost: float
    response_cost_failure_debug_info: Optional[

View file

@@ -527,6 +527,7 @@ def function_setup(  # noqa: PLR0915
            messages=messages,
            stream=stream,
            litellm_call_id=kwargs["litellm_call_id"],
+           litellm_trace_id=kwargs.get("litellm_trace_id"),
            function_id=function_id or "",
            call_type=call_type,
            start_time=start_time,
@@ -2056,6 +2057,7 @@ def get_litellm_params(
    azure_ad_token_provider=None,
    user_continue_message=None,
    base_model=None,
+   litellm_trace_id=None,
):
    litellm_params = {
        "acompletion": acompletion,
@@ -2084,6 +2086,7 @@
        "user_continue_message": user_continue_message,
        "base_model": base_model
        or _get_base_model_from_litellm_call_metadata(metadata=metadata),
+       "litellm_trace_id": litellm_trace_id,
    }

    return litellm_params

View file

@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
-version = "1.52.6"
+version = "1.52.7"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"
@@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"

[tool.commitizen]
-version = "1.52.6"
+version = "1.52.7"
version_files = [
    "pyproject.toml:^version"
]

View file

@@ -13,8 +13,11 @@ sys.path.insert(
import litellm
from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.utils import CustomStreamWrapper
+from litellm.utils import (
+    CustomStreamWrapper,
+    get_supported_openai_params,
+    get_optional_params,
+)

# test_example.py
from abc import ABC, abstractmethod

View file

@@ -0,0 +1,115 @@
import asyncio
import httpx
import json
import pytest
import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import (
CustomStreamWrapper,
get_supported_openai_params,
get_optional_params,
)
# test_example.py
from abc import ABC, abstractmethod
def assert_response_shape(response, custom_llm_provider):
expected_response_shape = {"id": str, "results": list, "meta": dict}
expected_results_shape = {"index": int, "relevance_score": float}
expected_meta_shape = {"api_version": dict, "billed_units": dict}
expected_api_version_shape = {"version": str}
expected_billed_units_shape = {"search_units": int}
assert isinstance(response.id, expected_response_shape["id"])
assert isinstance(response.results, expected_response_shape["results"])
for result in response.results:
assert isinstance(result["index"], expected_results_shape["index"])
assert isinstance(
result["relevance_score"], expected_results_shape["relevance_score"]
)
assert isinstance(response.meta, expected_response_shape["meta"])
if custom_llm_provider == "cohere":
assert isinstance(
response.meta["api_version"], expected_meta_shape["api_version"]
)
assert isinstance(
response.meta["api_version"]["version"],
expected_api_version_shape["version"],
)
assert isinstance(
response.meta["billed_units"], expected_meta_shape["billed_units"]
)
assert isinstance(
response.meta["billed_units"]["search_units"],
expected_billed_units_shape["search_units"],
)
class BaseLLMRerankTest(ABC):
"""
Abstract base test class that enforces a common test across all test classes.
"""
@abstractmethod
def get_base_rerank_call_args(self) -> dict:
"""Must return the base rerank call args"""
pass
@abstractmethod
def get_custom_llm_provider(self) -> litellm.LlmProviders:
"""Must return the custom llm provider"""
pass
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_basic_rerank(self, sync_mode):
rerank_call_args = self.get_base_rerank_call_args()
custom_llm_provider = self.get_custom_llm_provider()
if sync_mode is True:
response = litellm.rerank(
**rerank_call_args,
query="hello",
documents=["hello", "world"],
top_n=3,
)
print("re rank response: ", response)
assert response.id is not None
assert response.results is not None
assert_response_shape(
response=response, custom_llm_provider=custom_llm_provider.value
)
else:
response = await litellm.arerank(
**rerank_call_args,
query="hello",
documents=["hello", "world"],
top_n=3,
)
print("async re rank response: ", response)
assert response.id is not None
assert response.results is not None
assert_response_shape(
response=response, custom_llm_provider=custom_llm_provider.value
)

View file

@@ -0,0 +1,23 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from base_rerank_unit_tests import BaseLLMRerankTest
import litellm
class TestJinaAI(BaseLLMRerankTest):
def get_custom_llm_provider(self) -> litellm.LlmProviders:
return litellm.LlmProviders.JINA_AI
def get_base_rerank_call_args(self) -> dict:
return {
"model": "jina_ai/jina-reranker-v2-base-multilingual",
}

View file

@@ -921,3 +921,16 @@ def test_watsonx_text_top_k():
    )
    print(optional_params)
    assert optional_params["top_k"] == 10
def test_forward_user_param():
from litellm.utils import get_supported_openai_params, get_optional_params
model = "claude-3-5-sonnet-20240620"
optional_params = get_optional_params(
model=model,
user="test_user",
custom_llm_provider="anthropic",
)
assert optional_params["metadata"]["user_id"] == "test_user"

View file

@@ -679,6 +679,8 @@ async def test_anthropic_no_content_error():
            frequency_penalty=0.8,
        )
+       pass
+   except litellm.InternalServerError:
        pass
    except litellm.APIError as e:
        assert e.status_code == 500

View file

@@ -1624,3 +1624,55 @@ async def test_standard_logging_payload_stream_usage(sync_mode):
        print(f"standard_logging_object usage: {built_response.usage}")
    except litellm.InternalServerError:
        pass
def test_standard_logging_retries():
"""
know if a request was retried.
"""
from litellm.types.utils import StandardLoggingPayload
from litellm.router import Router
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "openai/gpt-3.5-turbo",
"api_key": "test-api-key",
},
}
]
)
with patch.object(
customHandler, "log_failure_event", new=MagicMock()
) as mock_client:
try:
router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
num_retries=1,
mock_response="litellm.RateLimitError",
)
except litellm.RateLimitError:
pass
assert mock_client.call_count == 2
assert (
mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
"trace_id"
]
is not None
)
assert (
mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
"trace_id"
]
== mock_client.call_args_list[1].kwargs["kwargs"][
"standard_logging_object"
]["trace_id"]
)

View file

@@ -157,7 +157,7 @@ def test_get_llm_provider_jina_ai():
    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
        model="jina_ai/jina-embeddings-v3",
    )
-   assert custom_llm_provider == "openai_like"
+   assert custom_llm_provider == "jina_ai"
    assert api_base == "https://api.jina.ai/v1"
    assert model == "jina-embeddings-v3"

View file

@@ -89,11 +89,16 @@ def test_get_model_info_ollama_chat():
                "template": "tools",
            }
        ),
-   ):
+   ) as mock_client:
        info = OllamaConfig().get_model_info("mistral")
+       print("info", info)
        assert info["supports_function_calling"] is True

        info = get_model_info("ollama/mistral")
+       print("info", info)
        assert info["supports_function_calling"] is True
+
+       mock_client.assert_called()
+       print(mock_client.call_args.kwargs)
+       assert mock_client.call_args.kwargs["json"]["name"] == "mistral"

View file

@@ -1455,3 +1455,46 @@ async def test_router_fallbacks_default_and_model_specific_fallbacks(sync_mode):
    assert isinstance(
        exc_info.value, litellm.AuthenticationError
    ), f"Expected AuthenticationError, but got {type(exc_info.value).__name__}"
@pytest.mark.asyncio
async def test_router_disable_fallbacks_dynamically():
from litellm.router import run_async_fallback
router = Router(
model_list=[
{
"model_name": "bad-model",
"litellm_params": {
"model": "openai/my-bad-model",
"api_key": "my-bad-api-key",
},
},
{
"model_name": "good-model",
"litellm_params": {
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
],
fallbacks=[{"bad-model": ["good-model"]}],
default_fallbacks=["good-model"],
)
with patch.object(
router,
"log_retry",
new=MagicMock(return_value=None),
) as mock_client:
try:
resp = await router.acompletion(
model="bad-model",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
disable_fallbacks=True,
)
print(resp)
except Exception as e:
print(e)
mock_client.assert_not_called()

View file

@@ -14,6 +14,7 @@ from litellm.router import Deployment, LiteLLM_Params, ModelInfo
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dotenv import load_dotenv
+from unittest.mock import patch, MagicMock, AsyncMock

load_dotenv()
@@ -83,3 +84,93 @@
    except Exception:
        print(traceback.format_exc())
        pytest.fail("An error occurred - " + traceback.format_exc())
from litellm.types.utils import CallTypes
def test_update_kwargs_before_fallbacks_unit_test():
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
}
],
)
kwargs = {"messages": [{"role": "user", "content": "write 1 sentence poem"}]}
router._update_kwargs_before_fallbacks(
model="gpt-3.5-turbo",
kwargs=kwargs,
)
assert kwargs["litellm_trace_id"] is not None
@pytest.mark.parametrize(
"call_type",
[
CallTypes.acompletion,
CallTypes.atext_completion,
CallTypes.aembedding,
CallTypes.arerank,
CallTypes.atranscription,
],
)
@pytest.mark.asyncio
async def test_update_kwargs_before_fallbacks(call_type):
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
}
],
)
if call_type.value.startswith("a"):
with patch.object(router, "async_function_with_fallbacks") as mock_client:
if call_type.value == "acompletion":
input_kwarg = {
"messages": [{"role": "user", "content": "Hello, how are you?"}],
}
elif (
call_type.value == "atext_completion"
or call_type.value == "aimage_generation"
):
input_kwarg = {
"prompt": "Hello, how are you?",
}
elif call_type.value == "aembedding" or call_type.value == "arerank":
input_kwarg = {
"input": "Hello, how are you?",
}
elif call_type.value == "atranscription":
input_kwarg = {
"file": "path/to/file",
}
else:
input_kwarg = {}
await getattr(router, call_type.value)(
model="gpt-3.5-turbo",
**input_kwarg,
)
mock_client.assert_called_once()
print(mock_client.call_args.kwargs)
assert mock_client.call_args.kwargs["litellm_trace_id"] is not None

View file

@@ -172,6 +172,8 @@ def test_stream_chunk_builder_litellm_usage_chunks():
    """
    Checks if stream_chunk_builder is able to correctly rebuild with given metadata from streaming chunks
    """
+   from litellm.types.utils import Usage
+
    messages = [
        {"role": "user", "content": "Tell me the funniest joke you know."},
        {
@@ -182,24 +184,28 @@
        {"role": "assistant", "content": "uhhhh\n\n\nhmmmm.....\nthinking....\n"},
        {"role": "user", "content": "\nI am waiting...\n\n...\n"},
    ]
-   # make a regular gemini call
-   response = completion(
-       model="gemini/gemini-1.5-flash",
-       messages=messages,
-   )
-   usage: litellm.Usage = response.usage
+   usage: litellm.Usage = Usage(
+       completion_tokens=27,
+       prompt_tokens=55,
+       total_tokens=82,
+       completion_tokens_details=None,
+       prompt_tokens_details=None,
+   )
    gemini_pt = usage.prompt_tokens

    # make a streaming gemini call
-   response = completion(
-       model="gemini/gemini-1.5-flash",
-       messages=messages,
-       stream=True,
-       complete_response=True,
-       stream_options={"include_usage": True},
-   )
+   try:
+       response = completion(
+           model="gemini/gemini-1.5-flash",
+           messages=messages,
+           stream=True,
+           complete_response=True,
+           stream_options={"include_usage": True},
+       )
+   except litellm.InternalServerError as e:
+       pytest.skip(f"Skipping test due to internal server error - {str(e)}")

    usage: litellm.Usage = response.usage

View file

@@ -736,6 +736,8 @@ async def test_acompletion_claude_2_stream():
        if complete_response.strip() == "":
            raise Exception("Empty response received")
        print(f"completion_response: {complete_response}")
+   except litellm.InternalServerError:
+       pass
    except litellm.RateLimitError:
        pass
    except Exception as e:
@@ -3272,7 +3274,7 @@ def test_completion_claude_3_function_call_with_streaming():
    ],  # "claude-3-opus-20240229"
)  #
@pytest.mark.asyncio
-async def test_acompletion_claude_3_function_call_with_streaming(model):
+async def test_acompletion_function_call_with_streaming(model):
    litellm.set_verbose = True
    tools = [
        {
@@ -3331,6 +3333,10 @@ async def test_acompletion_claude_3_function_call_with_streaming(model):
                validate_final_streaming_function_calling_chunk(chunk=chunk)
            idx += 1
        # raise Exception("it worked! ")
+   except litellm.InternalServerError:
+       pass
+   except litellm.ServiceUnavailableError:
+       pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

View file

@@ -748,7 +748,7 @@ def test_convert_model_response_object():
        ("vertex_ai/gemini-1.5-pro", True),
        ("gemini/gemini-1.5-pro", True),
        ("predibase/llama3-8b-instruct", True),
-       ("gpt-4o", False),
+       ("gpt-3.5-turbo", False),
    ],
)
def test_supports_response_schema(model, expected_bool):

View file

@@ -188,7 +188,8 @@ def test_completion_claude_3_function_call_with_otel(model):
        )
        print("response from LiteLLM", response)
+   except litellm.InternalServerError:
+       pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
    finally:

View file

@@ -1500,6 +1500,31 @@ async def test_add_callback_via_key_litellm_pre_call_utils(
    assert new_data["failure_callback"] == expected_failure_callbacks
@pytest.mark.asyncio
@pytest.mark.parametrize(
"disable_fallbacks_set",
[
True,
False,
],
)
async def test_disable_fallbacks_by_key(disable_fallbacks_set):
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
key_metadata = {"disable_fallbacks": disable_fallbacks_set}
existing_data = {
"model": "azure/chatgpt-v-2",
"messages": [{"role": "user", "content": "write 1 sentence poem"}],
}
data = LiteLLMProxyRequestSetup.add_key_level_controls(
key_metadata=key_metadata,
data=existing_data,
_metadata_variable_name="metadata",
)
assert data["disable_fallbacks"] == disable_fallbacks_set
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "callback_type, expected_success_callbacks, expected_failure_callbacks",