Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
LiteLLM Minor Fixes & Improvements (10/16/2024) (#6265)

* fix(caching_handler.py): handle positional arguments in add cache logic. Fixes https://github.com/BerriAI/litellm/issues/6264
* feat(litellm_pre_call_utils.py): allow forwarding OpenAI org id to the backend client. https://github.com/BerriAI/litellm/issues/6237
* docs(configs.md): add 'forward_openai_org_id' to docs
* fix(proxy_server.py): return model info if user_model is set. Fixes https://github.com/BerriAI/litellm/issues/6233
* fix(hosted_vllm/chat/transformation.py): don't set tools unless non-None
* fix(openai.py): improve debug log for the OpenAI 'str' error. Addresses https://github.com/BerriAI/litellm/issues/6272
* fix(proxy_server.py): fix linting errors
* test: skip WIP test
* docs(openai.md): add docs on passing the OpenAI org id from the client to OpenAI

Parent: 43878bd2a0
Commit: 38a9a106d2
14 changed files with 371 additions and 47 deletions
@@ -492,4 +492,49 @@ response = completion("openai/your-model-name", messages)

 If you need to set api_base dynamically, just pass it in completions instead - `completions(...,api_base="your-proxy-api-base")`

 For more check out [setting API Base/Keys](../set_keys.md)
+
+### Forwarding Org ID for Proxy requests
+
+Forward OpenAI org IDs from the client to OpenAI with the `forward_openai_org_id` param.
+
+1. Setup config.yaml
+
+```yaml
+model_list:
+  - model_name: "gpt-3.5-turbo"
+    litellm_params:
+      model: gpt-3.5-turbo
+      api_key: os.environ/OPENAI_API_KEY
+
+general_settings:
+  forward_openai_org_id: true # 👈 KEY CHANGE
+```
+
+2. Start Proxy
+
+```bash
+litellm --config config.yaml --detailed_debug
+
+# RUNNING on http://0.0.0.0:4000
+```
+
+3. Make OpenAI call
+
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    api_key="sk-1234",
+    organization="my-special-org",
+    base_url="http://0.0.0.0:4000"
+)
+
+client.chat.completions.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])
+```
+
+In your logs you should see the forwarded org id:
+
+```bash
+LiteLLM:DEBUG: utils.py:255 - Request to litellm:
+LiteLLM:DEBUG: utils.py:255 - litellm.acompletion(... organization='my-special-org',)
+```
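Editor's note: clients that don't use the OpenAI SDK should get the same behavior as long as the request carries the standard `OpenAI-Organization` header, since that is the header the proxy reads (see `get_openai_org_id_from_headers` later in this diff). A minimal sketch with `requests`, assuming the proxy's OpenAI-compatible `/v1/chat/completions` route and the placeholder key/URL from the example above:

```python
import requests

# Sketch only: forward the org id by setting the standard OpenAI-Organization
# header on a raw HTTP call to the proxy.
response = requests.post(
    "http://0.0.0.0:4000/v1/chat/completions",
    headers={
        "Authorization": "Bearer sk-1234",
        "OpenAI-Organization": "my-special-org",
    },
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello world"}],
    },
)
print(response.json())
```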
@@ -811,6 +811,8 @@ general_settings:
 | oauth2_config_mappings | Dict[str, str] | Define the OAuth2 config mappings |
 | pass_through_endpoints | List[Dict[str, Any]] | Define the pass through endpoints. [Docs](./pass_through) |
 | enable_oauth2_proxy_auth | boolean | (Enterprise Feature) If true, enables oauth2.0 authentication |
+| forward_openai_org_id | boolean | If true, forwards the OpenAI Organization ID to the backend LLM call (if it's OpenAI). |
+
 ### router_settings - Reference

 ```yaml
@@ -859,6 +861,7 @@ router_settings:
 | allowed_fails | integer | The number of failures allowed before cooling down a model. [More information here](reliability) |
 | allowed_fails_policy | object | Specifies the number of allowed failures for different error types before cooling down a deployment. [More information here](reliability) |


 ### environment variables - Reference

 | Name | Description |
@@ -16,6 +16,7 @@ In each method it will call the appropriate method from caching.py

 import asyncio
 import datetime
+import inspect
 import threading
 from typing import (
     TYPE_CHECKING,
@@ -632,7 +633,7 @@ class LLMCachingHandler:
             logging_obj=logging_obj,
         )

-    async def _async_set_cache(
+    async def async_set_cache(
         self,
         result: Any,
         original_function: Callable,
@@ -653,7 +654,7 @@ class LLMCachingHandler:
         Raises:
             None
         """
-        args = args or ()
+        kwargs.update(convert_args_to_kwargs(result, original_function, kwargs, args))
         if litellm.cache is None:
             return
         # [OPTIONAL] ADD TO CACHE
@@ -675,24 +676,24 @@ class LLMCachingHandler:
                 )  # s3 doesn't support bulk writing. Exclude.
             ):
                 asyncio.create_task(
-                    litellm.cache.async_add_cache_pipeline(result, *args, **kwargs)
+                    litellm.cache.async_add_cache_pipeline(result, **kwargs)
                 )
             elif isinstance(litellm.cache.cache, S3Cache):
                 threading.Thread(
                     target=litellm.cache.add_cache,
-                    args=(result,) + args,
+                    args=(result,),
                     kwargs=kwargs,
                 ).start()
             else:
                 asyncio.create_task(
-                    litellm.cache.async_add_cache(result.json(), *args, **kwargs)
+                    litellm.cache.async_add_cache(
+                        result.model_dump_json(), **kwargs
+                    )
                 )
         else:
-            asyncio.create_task(
-                litellm.cache.async_add_cache(result, *args, **kwargs)
-            )
+            asyncio.create_task(litellm.cache.async_add_cache(result, **kwargs))

-    def _sync_set_cache(
+    def sync_set_cache(
         self,
         result: Any,
         kwargs: Dict[str, Any],
@@ -701,14 +702,16 @@ class LLMCachingHandler:
         """
         Sync internal method to add the result to the cache
         """
+        kwargs.update(
+            convert_args_to_kwargs(result, self.original_function, kwargs, args)
+        )
         if litellm.cache is None:
             return

-        args = args or ()
         if self._should_store_result_in_cache(
             original_function=self.original_function, kwargs=kwargs
         ):
-            litellm.cache.add_cache(result, *args, **kwargs)
+            litellm.cache.add_cache(result, **kwargs)

         return
@@ -772,7 +775,7 @@ class LLMCachingHandler:

         # if a complete_streaming_response is assembled, add it to the cache
         if complete_streaming_response is not None:
-            await self._async_set_cache(
+            await self.async_set_cache(
                 result=complete_streaming_response,
                 original_function=self.original_function,
                 kwargs=self.request_kwargs,
@@ -795,7 +798,7 @@ class LLMCachingHandler:

         # if a complete_streaming_response is assembled, add it to the cache
         if complete_streaming_response is not None:
-            self._sync_set_cache(
+            self.sync_set_cache(
                 result=complete_streaming_response,
                 kwargs=self.request_kwargs,
             )
@@ -849,3 +852,26 @@ class LLMCachingHandler:
             additional_args=None,
             stream=kwargs.get("stream", False),
         )
+
+
+def convert_args_to_kwargs(
+    result: Any,
+    original_function: Callable,
+    kwargs: Dict[str, Any],
+    args: Optional[Tuple[Any, ...]] = None,
+) -> Dict[str, Any]:
+    # Get the signature of the original function
+    signature = inspect.signature(original_function)
+
+    # Get parameter names in the order they appear in the original function
+    param_names = list(signature.parameters.keys())
+
+    # Create a mapping of positional arguments to parameter names
+    args_to_kwargs = {}
+    if args:
+        for index, arg in enumerate(args):
+            if index < len(param_names):
+                param_name = param_names[index]
+                args_to_kwargs[param_name] = arg
+
+    return args_to_kwargs
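Editor's note: a self-contained sketch of the technique this helper relies on (`inspect.signature` mapping positional arguments onto parameter names). The toy `completion` function below is a hypothetical stand-in, not LiteLLM's:

```python
import inspect
from typing import Any, Callable, Dict, Optional, Tuple


def convert_args_to_kwargs_sketch(
    original_function: Callable,
    args: Optional[Tuple[Any, ...]] = None,
) -> Dict[str, Any]:
    # Map each positional argument onto the corresponding parameter name
    # of the wrapped function, mirroring the helper added above.
    param_names = list(inspect.signature(original_function).parameters.keys())
    return {
        param_names[i]: arg
        for i, arg in enumerate(args or ())
        if i < len(param_names)
    }


def completion(model, messages=None, **kwargs):  # hypothetical stand-in
    ...


# A positional call like completion("gpt-3.5-turbo", [...]) now contributes
# {"model": ..., "messages": ...} to kwargs, so cache-key logic that only
# inspects kwargs still sees the model and messages.
print(convert_args_to_kwargs_sketch(completion, ("gpt-3.5-turbo", [{"role": "user", "content": "hi"}])))
```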
@@ -590,6 +590,7 @@ class OpenAIChatCompletion(BaseLLM):
         - call chat.completions.create.with_raw_response when litellm.return_response_headers is True
         - call chat.completions.create by default
         """
+        raw_response = None
         try:
             raw_response = openai_client.chat.completions.with_raw_response.create(
                 **data, timeout=timeout
@@ -602,7 +603,14 @@ class OpenAIChatCompletion(BaseLLM):
             response = raw_response.parse()
             return headers, response
         except Exception as e:
-            raise e
+            if raw_response is not None:
+                raise Exception(
+                    "error - {}, Received response - {}, Type of response - {}".format(
+                        e, raw_response, type(raw_response)
+                    )
+                )
+            else:
+                raise e

     def completion(  # type: ignore
         self,
@@ -28,7 +28,8 @@ class HostedVLLMChatConfig(OpenAIGPTConfig):
             _tools = _remove_additional_properties(_tools)
             # remove 'strict' from tools
             _tools = _remove_strict_from_schema(_tools)
-            non_default_params["tools"] = _tools
+            if _tools is not None:
+                non_default_params["tools"] = _tools
         return super().map_openai_params(
             non_default_params, optional_params, model, drop_params
         )
@@ -1,33 +1,12 @@
 model_list:
-  - model_name: gpt-3.5-turbo
+  - model_name: "gpt-3.5-turbo"
-    litellm_params:
-      model: azure/gpt-35-turbo # 👈 EU azure model
-      api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
-      api_key: os.environ/AZURE_EUROPE_API_KEY
-      region_name: "eu"
-  - model_name: gpt-4o
-    litellm_params:
-      model: azure/gpt-4o
-      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
-      api_key: os.environ/AZURE_API_KEY
-      region_name: "us"
-  - model_name: gpt-3.5-turbo-end-user-test
     litellm_params:
       model: gpt-3.5-turbo
-      region_name: "eu"
+      api_key: os.environ/OPENAI_API_KEY
-    model_info:
-      id: "1"
-
-# guardrails:
-#   - guardrail_name: "gibberish-guard"
-#     litellm_params:
-#       guardrail: guardrails_ai
-#       guard_name: "gibberish_guard"
-#       mode: "post_call"
-#       api_base: os.environ/GUARDRAILS_AI_API_BASE

 assistant_settings:
   custom_llm_provider: azure
   litellm_params:
     api_key: os.environ/AZURE_API_KEY
     api_base: os.environ/AZURE_API_BASE
@@ -2030,3 +2030,8 @@ class SpecialHeaders(enum.Enum):
     openai_authorization = "Authorization"
     azure_authorization = "API-Key"
     anthropic_authorization = "x-api-key"
+
+
+class LitellmDataForBackendLLMCall(TypedDict, total=False):
+    headers: dict
+    organization: str
@@ -9,6 +9,7 @@ from litellm._logging import verbose_logger, verbose_proxy_logger
 from litellm.proxy._types import (
     AddTeamCallback,
     CommonProxyErrors,
+    LitellmDataForBackendLLMCall,
     LiteLLMRoutes,
     SpecialHeaders,
     TeamCallbackMetadata,
@@ -172,9 +173,44 @@ def get_forwardable_headers(
             "x-stainless"
         ):  # causes openai sdk to fail
             forwarded_headers[header] = value

     return forwarded_headers


+def get_openai_org_id_from_headers(
+    headers: dict, general_settings: Optional[Dict] = None
+) -> Optional[str]:
+    """
+    Get the OpenAI Org ID from the headers.
+    """
+    if (
+        general_settings is not None
+        and general_settings.get("forward_openai_org_id") is not True
+    ):
+        return None
+    for header, value in headers.items():
+        if header.lower() == "openai-organization":
+            return value
+    return None
+
+
+def add_litellm_data_for_backend_llm_call(
+    headers: dict, general_settings: Optional[Dict[str, Any]] = None
+) -> LitellmDataForBackendLLMCall:
+    """
+    - Adds forwardable headers
+    - Adds org id
+    """
+    data = LitellmDataForBackendLLMCall()
+    _headers = get_forwardable_headers(headers)
+    if _headers != {}:
+        data["headers"] = _headers
+    _organization = get_openai_org_id_from_headers(headers, general_settings)
+    if _organization is not None:
+        data["organization"] = _organization
+    return data
+
+
 async def add_litellm_data_to_request(
     data: dict,
     request: Request,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if get_forwardable_headers(_headers) != {}:
|
data.update(add_litellm_data_for_backend_llm_call(_headers, general_settings))
|
||||||
data["headers"] = get_forwardable_headers(_headers)
|
|
||||||
# Include original request and headers in the data
|
# Include original request and headers in the data
|
||||||
data["proxy_server_request"] = {
|
data["proxy_server_request"] = {
|
||||||
"url": str(request.url),
|
"url": str(request.url),
|
||||||
|
|
|
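Editor's note: the parametrized test at the end of this commit pins down the resulting behavior; a quick standalone sketch of the same call, using a placeholder org id:

```python
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_for_backend_llm_call

# With forwarding enabled, the standard OpenAI-Organization header (any casing)
# is surfaced as `organization` on the data merged into the backend call.
data = add_litellm_data_for_backend_llm_call(
    headers={"OpenAI-Organization": "my-special-org"},
    general_settings={"forward_openai_org_id": True},
)
print(data)  # expected: {"organization": "my-special-org"}
```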
@@ -19,6 +19,7 @@ from typing import (
     List,
     Optional,
     Tuple,
+    cast,
     get_args,
     get_origin,
     get_type_hints,
@@ -7313,18 +7314,40 @@ async def model_info_v1(

    ```
    """
-    global llm_model_list, general_settings, user_config_file_path, proxy_config, llm_router
+    global llm_model_list, general_settings, user_config_file_path, proxy_config, llm_router, user_model
+
+    if user_model is not None:
+        # user is trying to get specific model from litellm router
+        try:
+            model_info: Dict = cast(Dict, litellm.get_model_info(model=user_model))
+        except Exception:
+            model_info = {}
+        _deployment_info = Deployment(
+            model_name="*",
+            litellm_params=LiteLLM_Params(
+                model=user_model,
+            ),
+            model_info=model_info,
+        )
+        _deployment_info_dict = _deployment_info.model_dump()
+        _deployment_info_dict = remove_sensitive_info_from_deployment(
+            deployment_dict=_deployment_info_dict
+        )
+        return {"data": _deployment_info_dict}

     if llm_model_list is None:
         raise HTTPException(
-            status_code=500, detail={"error": "LLM Model List not loaded in"}
+            status_code=500,
+            detail={
+                "error": "LLM Model List not loaded in. Make sure you passed models in your config.yaml or on the LiteLLM Admin UI. - https://docs.litellm.ai/docs/proxy/configs"
+            },
         )

     if llm_router is None:
         raise HTTPException(
             status_code=500,
             detail={
-                "error": "LLM Router is not loaded in. Make sure you passed models in your config.yaml or on the LiteLLM Admin UI."
+                "error": "LLM Router is not loaded in. Make sure you passed models in your config.yaml or on the LiteLLM Admin UI. - https://docs.litellm.ai/docs/proxy/configs"
             },
         )

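Editor's note: a minimal sketch of what this enables when the proxy is started with only a CLI model (so `user_model` is set and no router is configured). The endpoint path, URL, and key below are assumptions based on the example values used elsewhere in this commit:

```python
import requests

# With something like `litellm --model gpt-3.5-turbo` (no config.yaml), the
# model info endpoint should now return a single wildcard deployment instead
# of a 500 error.
resp = requests.get(
    "http://0.0.0.0:4000/model/info",  # assumed path for model_info_v1
    headers={"Authorization": "Bearer sk-1234"},
)
print(resp.json())  # expected shape: {"data": {"model_name": "*", "litellm_params": {...}, "model_info": {...}}}
```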
@@ -927,7 +927,7 @@ def client(original_function):
            )

            # [OPTIONAL] ADD TO CACHE
-            _llm_caching_handler._sync_set_cache(
+            _llm_caching_handler.sync_set_cache(
                result=result,
                args=args,
                kwargs=kwargs,
@@ -1126,7 +1126,7 @@ def client(original_function):
            )

            ## Add response to cache
-            await _llm_caching_handler._async_set_cache(
+            await _llm_caching_handler.async_set_cache(
                result=result,
                original_function=original_function,
                kwargs=kwargs,
@@ -732,3 +732,18 @@ def test_drop_nested_params_add_prop_and_strict(provider, model):
     )

     _check_additional_properties(optional_params["tools"])
+
+
+def test_hosted_vllm_tool_param():
+    """
+    Relevant issue - https://github.com/BerriAI/litellm/issues/6228
+    """
+    optional_params = get_optional_params(
+        model="my-vllm-model",
+        custom_llm_provider="hosted_vllm",
+        temperature=0.2,
+        tools=None,
+        tool_choice=None,
+    )
+    assert "tools" not in optional_params
+    assert "tool_choice" not in optional_params
@@ -2298,3 +2298,70 @@ def test_basic_caching_import():

     assert Cache is not None
     print("Cache imported successfully")
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio()
+async def test_caching_kwargs_input(sync_mode):
+    from litellm import acompletion
+    from litellm.caching.caching_handler import LLMCachingHandler
+    from litellm.types.utils import (
+        Choices,
+        EmbeddingResponse,
+        Message,
+        ModelResponse,
+        Usage,
+        CompletionTokensDetails,
+        PromptTokensDetails,
+    )
+    from datetime import datetime
+
+    llm_caching_handler = LLMCachingHandler(
+        original_function=acompletion, request_kwargs={}, start_time=datetime.now()
+    )
+
+    input = {
+        "result": ModelResponse(
+            id="chatcmpl-AJ119H5XsDnYiZPp5axJ5d7niwqeR",
+            choices=[
+                Choices(
+                    finish_reason="stop",
+                    index=0,
+                    message=Message(
+                        content="Hello! I'm just a computer program, so I don't have feelings, but I'm here to assist you. How can I help you today?",
+                        role="assistant",
+                        tool_calls=None,
+                        function_call=None,
+                    ),
+                )
+            ],
+            created=1729095507,
+            model="gpt-3.5-turbo-0125",
+            object="chat.completion",
+            system_fingerprint=None,
+            usage=Usage(
+                completion_tokens=31,
+                prompt_tokens=16,
+                total_tokens=47,
+                completion_tokens_details=CompletionTokensDetails(
+                    audio_tokens=None, reasoning_tokens=0
+                ),
+                prompt_tokens_details=PromptTokensDetails(
+                    audio_tokens=None, cached_tokens=0
+                ),
+            ),
+            service_tier=None,
+        ),
+        "kwargs": {
+            "messages": [{"role": "user", "content": "42HHey, how's it going?"}],
+            "caching": True,
+            "litellm_call_id": "fae2aa4f-9f75-4f11-8c9c-63ab8d9fae26",
+            "preset_cache_key": "2f69f5640d5e0f25315d0e132f1278bb643554d14565d2c61d61564b10ade90f",
+        },
+        "args": ("gpt-3.5-turbo",),
+    }
+    if sync_mode is True:
+        llm_caching_handler.sync_set_cache(**input)
+    else:
+        input["original_function"] = acompletion
+        await llm_caching_handler.async_set_cache(**input)
@@ -1796,3 +1796,81 @@ async def test_proxy_model_group_info_rerank(prisma_client):
     print(resp)
     models = resp["data"]
     assert models[0].mode == "rerank"
+
+
+# @pytest.mark.asyncio
+# async def test_proxy_team_member_add(prisma_client):
+#     """
+#     Add 10 people to a team. Confirm all 10 are added.
+#     """
+#     from litellm.proxy.management_endpoints.team_endpoints import (
+#         team_member_add,
+#         new_team,
+#     )
+#     from litellm.proxy._types import TeamMemberAddRequest, Member, NewTeamRequest
+
+#     setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+#     setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+#     try:
+
+#         async def test():
+#             await litellm.proxy.proxy_server.prisma_client.connect()
+#             from litellm.proxy.proxy_server import user_api_key_cache
+
+#             user_api_key_dict = UserAPIKeyAuth(
+#                 user_role=LitellmUserRoles.PROXY_ADMIN,
+#                 api_key="sk-1234",
+#                 user_id="1234",
+#             )
+
+#             new_team()
+#             for _ in range(10):
+#                 request = TeamMemberAddRequest(
+#                     team_id="1234",
+#                     member=Member(
+#                         user_id="1234",
+#                         user_role=LitellmUserRoles.INTERNAL_USER,
+#                     ),
+#                 )
+#                 key = await team_member_add(
+#                     request, user_api_key_dict=user_api_key_dict
+#                 )
+
+#             print(key)
+#             user_id = key.user_id
+
+#             # check /user/info to verify user_role was set correctly
+#             new_user_info = await user_info(
+#                 user_id=user_id, user_api_key_dict=user_api_key_dict
+#             )
+#             new_user_info = new_user_info.user_info
+#             print("new_user_info=", new_user_info)
+#             assert new_user_info["user_role"] == LitellmUserRoles.INTERNAL_USER
+#             assert new_user_info["user_id"] == user_id
+
+#             generated_key = key.key
+#             bearer_token = "Bearer " + generated_key
+
+#             assert generated_key not in user_api_key_cache.in_memory_cache.cache_dict
+
+#             value_from_prisma = await prisma_client.get_data(
+#                 token=generated_key,
+#             )
+#             print("token from prisma", value_from_prisma)
+
+#             request = Request(
+#                 {
+#                     "type": "http",
+#                     "route": api_route,
+#                     "path": api_route.path,
+#                     "headers": [("Authorization", bearer_token)],
+#                 }
+#             )
+
+#             # use generated key to auth in
+#             result = await user_api_key_auth(request=request, api_key=bearer_token)
+#             print("result from user auth with new key", result)
+
+#         asyncio.run(test())
+#     except Exception as e:
+#         pytest.fail(f"An exception occurred - {str(e)}")
@@ -368,3 +368,41 @@ def test_is_request_body_safe_model_enabled(
         error_raised = True

     assert expect_error == error_raised
+
+
+def test_reading_openai_org_id_from_headers():
+    from litellm.proxy.litellm_pre_call_utils import get_openai_org_id_from_headers
+
+    headers = {
+        "OpenAI-Organization": "test_org_id",
+    }
+    org_id = get_openai_org_id_from_headers(headers)
+    assert org_id == "test_org_id"
+
+
+@pytest.mark.parametrize(
+    "headers, expected_data",
+    [
+        ({"OpenAI-Organization": "test_org_id"}, {"organization": "test_org_id"}),
+        ({"openai-organization": "test_org_id"}, {"organization": "test_org_id"}),
+        ({}, {}),
+        (
+            {
+                "OpenAI-Organization": "test_org_id",
+                "Authorization": "Bearer test_token",
+            },
+            {
+                "organization": "test_org_id",
+            },
+        ),
+    ],
+)
+def test_add_litellm_data_for_backend_llm_call(headers, expected_data):
+    import json
+
+    from litellm.proxy.litellm_pre_call_utils import (
+        add_litellm_data_for_backend_llm_call,
+    )
+
+    data = add_litellm_data_for_backend_llm_call(headers)
+
+    assert json.dumps(data, sort_keys=True) == json.dumps(expected_data, sort_keys=True)