forked from phoenix/litellm-mirror
LiteLLM Minor Fixes & Improvements (10/15/2024) (#6242)
* feat(litellm_pre_call_utils.py): support forwarding request headers to backend llm api * fix(litellm_pre_call_utils.py): handle custom litellm key header * test(router_code_coverage.py): check if all router functions are dire… (#6186) * test(router_code_coverage.py): check if all router functions are directly tested prevent regressions * docs(configs.md): document all environment variables (#6185) * docs: make it easier to find anthropic/openai prompt caching doc * aded codecov yml (#6207) * fix codecov.yaml * run ci/cd again * (refactor) caching use LLMCachingHandler for async_get_cache and set_cache (#6208) * use folder for caching * fix importing caching * fix clickhouse pyright * fix linting * fix correctly pass kwargs and args * fix test case for embedding * fix linting * fix embedding caching logic * fix refactor handle utils.py * fix test_embedding_caching_azure_individual_items_reordered * (feat) prometheus have well defined latency buckets (#6211) * fix prometheus have well defined latency buckets * use a well define latency bucket * use types file for prometheus logging * add test for LATENCY_BUCKETS * fix prom testing * fix config.yml * (refactor caching) use LLMCachingHandler for caching streaming responses (#6210) * use folder for caching * fix importing caching * fix clickhouse pyright * fix linting * fix correctly pass kwargs and args * fix test case for embedding * fix linting * fix embedding caching logic * fix refactor handle utils.py * refactor async set stream cache * fix linting * bump (#6187) * update code cov yaml * fix config.yml * add caching component to code cov * fix config.yml ci/cd * add coverage for proxy auth * (refactor caching) use common `_retrieve_from_cache` helper (#6212) * use folder for caching * fix importing caching * fix clickhouse pyright * fix linting * fix correctly pass kwargs and args * fix test case for embedding * fix linting * fix embedding caching logic * fix refactor handle utils.py * refactor async set stream cache * fix linting * refactor - use _retrieve_from_cache * refactor use _convert_cached_result_to_model_response * fix linting errors * bump: version 1.49.2 → 1.49.3 * fix code cov components * test(test_router_helpers.py): add router component unit tests * test: add additional router tests * test: add more router testing * test: add more router testing + more mock functions * ci(router_code_coverage.py): fix check --------- Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com> Co-authored-by: yujonglee <yujonglee.dev@gmail.com> * bump: version 1.49.3 → 1.49.4 * (refactor) use helper function `_assemble_complete_response_from_streaming_chunks` to assemble complete responses in caching and logging callbacks (#6220) * (refactor) use _assemble_complete_response_from_streaming_chunks * add unit test for test_assemble_complete_response_from_streaming_chunks_1 * fix assemble complete_streaming_response * config add logging_testing * add logging_coverage in codecov * test test_assemble_complete_response_from_streaming_chunks_3 * add unit tests for _assemble_complete_response_from_streaming_chunks * fix remove unused / junk function * add test for streaming_chunks when error assembling * (refactor) OTEL - use safe_set_attribute for setting attributes (#6226) * otel - use safe_set_attribute for setting attributes * fix OTEL only use safe_set_attribute * (fix) prompt caching cost calculation OpenAI, Azure OpenAI (#6231) * fix prompt caching cost calculation * fix testing for prompt cache cost calc * fix(allowed_model_region): allow us 
as allowed region (#6234) * test(router_code_coverage.py): check if all router functions are dire… (#6186) * test(router_code_coverage.py): check if all router functions are directly tested prevent regressions * docs(configs.md): document all environment variables (#6185) * docs: make it easier to find anthropic/openai prompt caching doc * aded codecov yml (#6207) * fix codecov.yaml * run ci/cd again * (refactor) caching use LLMCachingHandler for async_get_cache and set_cache (#6208) * use folder for caching * fix importing caching * fix clickhouse pyright * fix linting * fix correctly pass kwargs and args * fix test case for embedding * fix linting * fix embedding caching logic * fix refactor handle utils.py * fix test_embedding_caching_azure_individual_items_reordered * (feat) prometheus have well defined latency buckets (#6211) * fix prometheus have well defined latency buckets * use a well define latency bucket * use types file for prometheus logging * add test for LATENCY_BUCKETS * fix prom testing * fix config.yml * (refactor caching) use LLMCachingHandler for caching streaming responses (#6210) * use folder for caching * fix importing caching * fix clickhouse pyright * fix linting * fix correctly pass kwargs and args * fix test case for embedding * fix linting * fix embedding caching logic * fix refactor handle utils.py * refactor async set stream cache * fix linting * bump (#6187) * update code cov yaml * fix config.yml * add caching component to code cov * fix config.yml ci/cd * add coverage for proxy auth * (refactor caching) use common `_retrieve_from_cache` helper (#6212) * use folder for caching * fix importing caching * fix clickhouse pyright * fix linting * fix correctly pass kwargs and args * fix test case for embedding * fix linting * fix embedding caching logic * fix refactor handle utils.py * refactor async set stream cache * fix linting * refactor - use _retrieve_from_cache * refactor use _convert_cached_result_to_model_response * fix linting errors * bump: version 1.49.2 → 1.49.3 * fix code cov components * test(test_router_helpers.py): add router component unit tests * test: add additional router tests * test: add more router testing * test: add more router testing + more mock functions * ci(router_code_coverage.py): fix check --------- Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com> Co-authored-by: yujonglee <yujonglee.dev@gmail.com> * bump: version 1.49.3 → 1.49.4 * (refactor) use helper function `_assemble_complete_response_from_streaming_chunks` to assemble complete responses in caching and logging callbacks (#6220) * (refactor) use _assemble_complete_response_from_streaming_chunks * add unit test for test_assemble_complete_response_from_streaming_chunks_1 * fix assemble complete_streaming_response * config add logging_testing * add logging_coverage in codecov * test test_assemble_complete_response_from_streaming_chunks_3 * add unit tests for _assemble_complete_response_from_streaming_chunks * fix remove unused / junk function * add test for streaming_chunks when error assembling * (refactor) OTEL - use safe_set_attribute for setting attributes (#6226) * otel - use safe_set_attribute for setting attributes * fix OTEL only use safe_set_attribute * fix(allowed_model_region): allow us as allowed region --------- Co-authored-by: Krish Dholakia <krrishdholakia@gmail.com> Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com> Co-authored-by: yujonglee <yujonglee.dev@gmail.com> * fix(litellm_pre_call_utils.py): support 'us' region routing + fix header 
forwarding to filter on `x-` headers * docs(customer_routing.md): fix region-based routing example * feat(azure.py): handle empty arguments function call - azure Closes https://github.com/BerriAI/litellm/issues/6241 * feat(guardrails_ai.py): support guardrails ai integration Adds support for on-prem guardrails via guardrails ai * fix(proxy/utils.py): prevent sql injection attack Fixes https://huntr.com/bounties/a4f6d357-5b44-4e00-9cac-f1cc351211d2 * fix: fix linting errors * fix(litellm_pre_call_utils.py): don't log litellm api key in proxy server request headers * fix(litellm_pre_call_utils.py): don't forward stainless headers * docs(guardrails_ai.md): add guardrails ai quick start to docs * test: handle flaky test --------- Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com> Co-authored-by: yujonglee <yujonglee.dev@gmail.com> Co-authored-by: Marcus Elwin <marcus@elwin.com>
This commit is contained in:
parent fc5b75d171
commit 54ebdbf7ce
32 changed files with 982 additions and 314 deletions
@@ -24,21 +24,25 @@ curl -X POST --location 'http://0.0.0.0:4000/end_user/new' \

### 2. Add eu models to model-group

Add eu models to a model group. For azure models, litellm can automatically infer the region (no need to set it).
Add eu models to a model group. Use the 'region_name' param to specify the region for each model.

Supported regions are 'eu' and 'us'.

```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/gpt-35-turbo-eu # 👈 EU azure model
      model: azure/gpt-35-turbo # 👈 EU azure model
      api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
      api_key: os.environ/AZURE_EUROPE_API_KEY
      region_name: "eu"
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/chatgpt-v-2
      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
      api_version: "2023-05-15"
      api_key: os.environ/AZURE_API_KEY
      region_name: "us"

router_settings:
  enable_pre_call_checks: true # 👈 IMPORTANT

@@ -74,6 +78,7 @@ Expected API Base in response headers

```
x-litellm-api-base: "https://my-endpoint-europe-berri-992.openai.azure.com/"
x-litellm-model-region: "eu" # 👈 CONFIRMS REGION-BASED ROUTING WORKED
```
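For completeness, the region pin itself is set when the end user is created via the `/end_user/new` endpoint shown in the first hunk's context. A minimal Python sketch of that request (the proxy URL, master key, and `user_id` below are placeholders, not values from this commit):

```python
# Hedged sketch: create an end user whose requests may only hit US-region deployments.
# Field names follow the docs above; key and user_id are placeholders.
import requests

response = requests.post(
    "http://0.0.0.0:4000/end_user/new",
    headers={"Authorization": "Bearer sk-1234", "Content-Type": "application/json"},
    json={"user_id": "my-customer-1", "allowed_model_region": "us"},
)
print(response.json())
```

With `enable_pre_call_checks: true`, requests from that user should then only be routed to deployments whose `region_name` matches.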
### FAQ
docs/my-website/docs/proxy/guardrails/guardrails_ai.md (new file, 118 additions)
@@ -0,0 +1,118 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# Guardrails.ai

Use [Guardrails.ai](https://www.guardrailsai.com/) to add checks to LLM output.

## Pre-requisites

- Setup Guardrails AI Server. [quick start](https://www.guardrailsai.com/docs/getting_started/guardrails_server)

## Usage

1. Setup config.yaml

```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo
      api_key: os.environ/OPENAI_API_KEY

guardrails:
  - guardrail_name: "guardrails_ai-guard"
    litellm_params:
      guardrail: guardrails_ai
      guard_name: "gibberish_guard" # 👈 Guardrail AI guard name
      mode: "post_call"
      api_base: os.environ/GUARDRAILS_AI_API_BASE # 👈 Guardrails AI API Base. Defaults to "http://0.0.0.0:8000"
```

2. Start LiteLLM Gateway

```shell
litellm --config config.yaml --detailed_debug
```

3. Test request

**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**

```shell
curl -i http://localhost:4000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \
  -d '{
    "model": "gpt-3.5-turbo",
    "messages": [
      {"role": "user", "content": "hi my email is ishaan@berri.ai"}
    ],
    "guardrails": ["guardrails_ai-guard"]
  }'
```

## ✨ Control Guardrails per Project (API Key)

:::info

✨ This is an Enterprise only feature [Contact us to get a free trial](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)

:::

Use this to control which guardrails run per project. In this tutorial we only want the following guardrail to run for one project (API Key):
- `guardrails`: ["guardrails_ai-guard"]

**Step 1** Create Key with guardrail settings

<Tabs>
<TabItem value="/key/generate" label="/key/generate">

```shell
curl -X POST 'http://0.0.0.0:4000/key/generate' \
  -H 'Authorization: Bearer sk-1234' \
  -H 'Content-Type: application/json' \
  -d '{
    "guardrails": ["guardrails_ai-guard"]
  }'
```

</TabItem>
<TabItem value="/key/update" label="/key/update">

```shell
curl --location 'http://0.0.0.0:4000/key/update' \
  --header 'Authorization: Bearer sk-1234' \
  --header 'Content-Type: application/json' \
  --data '{
    "key": "sk-jNm1Zar7XfNdZXp49Z1kSQ",
    "guardrails": ["guardrails_ai-guard"]
  }'
```

</TabItem>
</Tabs>

**Step 2** Test it with new key

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
  --header 'Authorization: Bearer sk-jNm1Zar7XfNdZXp49Z1kSQ' \
  --header 'Content-Type: application/json' \
  --data '{
    "model": "gpt-3.5-turbo",
    "messages": [
      {
        "role": "user",
        "content": "my email is ishaan@berri.ai"
      }
    ]
  }'
```
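For intuition, the `post_call` guardrail configured above forwards the model's output to the Guardrails AI server's validate route (see `guardrails_ai.py` later in this diff). A minimal sketch of that call, assuming the default `http://0.0.0.0:8000` api_base and the `gibberish_guard` guard:

```python
# Hedged sketch of the request the proxy makes on post_call; not the hook's exact code.
import httpx

async def validate_with_guardrails_ai(llm_output: str) -> dict:
    url = "http://0.0.0.0:8000/guards/gibberish_guard/validate"
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            url,
            json={"llmOutput": llm_output},
            headers={"Content-Type": "application/json"},
        )
    result = resp.json()  # e.g. {"callId": ..., "validationPassed": true, "validatedOutput": ...}
    if result.get("validationPassed") is False:
        raise ValueError("Violated guardrail policy")
    return result

# run with: import asyncio; asyncio.run(validate_with_guardrails_ai("some model output"))
```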
@@ -95,6 +95,7 @@ const sidebars = {
        items: [
          "proxy/guardrails/quick_start",
          "proxy/guardrails/aporia_api",
          "proxy/guardrails/guardrails_ai",
          "proxy/guardrails/lakera_ai",
          "proxy/guardrails/bedrock",
          "proxy/guardrails/pii_masking_v2",
@@ -1005,10 +1005,11 @@ from .llms.fireworks_ai.embed.fireworks_ai_transformation import (
from .llms.volcengine import VolcEngineConfig
from .llms.text_completion_codestral import MistralTextCompletionConfig
from .llms.AzureOpenAI.azure import (
    AzureOpenAIConfig,
    AzureOpenAIError,
    AzureOpenAIAssistantsAPIConfig,
)

from .llms.AzureOpenAI.chat.gpt_transformation import AzureOpenAIConfig
from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
from .llms.AzureOpenAI.chat.o1_transformation import AzureOpenAIO1Config
from .llms.watsonx import IBMWatsonXAIConfig
@@ -1,4 +1,4 @@
from typing import Literal, Optional
from typing import List, Literal, Optional

from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger

@@ -10,11 +10,20 @@ class CustomGuardrail(CustomLogger):
    def __init__(
        self,
        guardrail_name: Optional[str] = None,
        supported_event_hooks: Optional[List[GuardrailEventHooks]] = None,
        event_hook: Optional[GuardrailEventHooks] = None,
        **kwargs,
    ):
        self.guardrail_name = guardrail_name
        self.supported_event_hooks = supported_event_hooks
        self.event_hook: Optional[GuardrailEventHooks] = event_hook

        if supported_event_hooks:
            ## validate event_hook is in supported_event_hooks
            if event_hook and event_hook not in supported_event_hooks:
                raise ValueError(
                    f"Event hook {event_hook} is not in the supported event hooks {supported_event_hooks}"
                )
        super().__init__(**kwargs)

    def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
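For context, a guardrail built on this base class declares the hooks it supports, and the new constructor check rejects an `event_hook` outside that list. A hedged subclass sketch (the class and guardrail names are illustrative, not from this commit):

```python
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.types.guardrails import GuardrailEventHooks

class MyPostCallGuardrail(CustomGuardrail):  # hypothetical subclass
    def __init__(self, **kwargs):
        super().__init__(
            guardrail_name="my-post-call-guard",                    # hypothetical name
            supported_event_hooks=[GuardrailEventHooks.post_call],  # hooks this guardrail supports
            event_hook=GuardrailEventHooks.post_call,               # must be in supported_event_hooks
            **kwargs,
        )

# Passing event_hook=GuardrailEventHooks.pre_call above would raise the new ValueError.
```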
@@ -26,10 +26,6 @@ from litellm.utils import (
from ...types.llms.openai import (
    Batch,
    CancelBatchRequest,
    ChatCompletionToolChoiceFunctionParam,
    ChatCompletionToolChoiceObjectParam,
    ChatCompletionToolParam,
    ChatCompletionToolParamFunctionChunk,
    CreateBatchRequest,
    HttpxBinaryResponseContent,
    RetrieveBatchRequest,
@ -67,214 +63,6 @@ class AzureOpenAIError(Exception):
|
|||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class AzureOpenAIConfig:
|
||||
"""
|
||||
Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
|
||||
|
||||
The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters::
|
||||
|
||||
- `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
|
||||
|
||||
- `function_call` (string or object): This optional parameter controls how the model calls functions.
|
||||
|
||||
- `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
|
||||
|
||||
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
|
||||
|
||||
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
|
||||
|
||||
- `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
|
||||
|
||||
- `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
|
||||
|
||||
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
|
||||
|
||||
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
|
||||
|
||||
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
function_call: Optional[Union[str, dict]] = None,
|
||||
functions: Optional[list] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
stop: Optional[Union[str, list]] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return [
|
||||
"temperature",
|
||||
"n",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"stop",
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"user",
|
||||
"function_call",
|
||||
"functions",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"top_p",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"response_format",
|
||||
"seed",
|
||||
"extra_headers",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
api_version: str, # Y-M-D-{optional}
|
||||
drop_params,
|
||||
) -> dict:
|
||||
supported_openai_params = self.get_supported_openai_params()
|
||||
|
||||
api_version_times = api_version.split("-")
|
||||
api_version_year = api_version_times[0]
|
||||
api_version_month = api_version_times[1]
|
||||
api_version_day = api_version_times[2]
|
||||
for param, value in non_default_params.items():
|
||||
if param == "tool_choice":
|
||||
"""
|
||||
This parameter requires API version 2023-12-01-preview or later
|
||||
|
||||
tool_choice='required' is not supported as of 2024-05-01-preview
|
||||
"""
|
||||
## check if api version supports this param ##
|
||||
if (
|
||||
api_version_year < "2023"
|
||||
or (api_version_year == "2023" and api_version_month < "12")
|
||||
or (
|
||||
api_version_year == "2023"
|
||||
and api_version_month == "12"
|
||||
and api_version_day < "01"
|
||||
)
|
||||
):
|
||||
if litellm.drop_params is True or (
|
||||
drop_params is not None and drop_params is True
|
||||
):
|
||||
pass
|
||||
else:
|
||||
raise UnsupportedParamsError(
|
||||
status_code=400,
|
||||
message=f"""Azure does not support 'tool_choice', for api_version={api_version}. Bump your API version to '2023-12-01-preview' or later. This parameter requires 'api_version="2023-12-01-preview"' or later. Azure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions""",
|
||||
)
|
||||
elif value == "required" and (
|
||||
api_version_year == "2024" and api_version_month <= "05"
|
||||
): ## check if tool_choice value is supported ##
|
||||
if litellm.drop_params is True or (
|
||||
drop_params is not None and drop_params is True
|
||||
):
|
||||
pass
|
||||
else:
|
||||
raise UnsupportedParamsError(
|
||||
status_code=400,
|
||||
message=f"Azure does not support '{value}' as a {param} param, for api_version={api_version}. To drop 'tool_choice=required' for calls with this Azure API version, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\nAzure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions",
|
||||
)
|
||||
else:
|
||||
optional_params["tool_choice"] = value
|
||||
elif param == "response_format" and isinstance(value, dict):
|
||||
json_schema: Optional[dict] = None
|
||||
schema_name: str = ""
|
||||
if "response_schema" in value:
|
||||
json_schema = value["response_schema"]
|
||||
schema_name = "json_tool_call"
|
||||
elif "json_schema" in value:
|
||||
json_schema = value["json_schema"]["schema"]
|
||||
schema_name = value["json_schema"]["name"]
|
||||
"""
|
||||
Follow similar approach to anthropic - translate to a single tool call.
|
||||
|
||||
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
|
||||
- You usually want to provide a single tool
|
||||
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
|
||||
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
|
||||
"""
|
||||
if json_schema is not None and (
|
||||
(api_version_year <= "2024" and api_version_month < "08")
|
||||
or "gpt-4o" not in model
|
||||
): # azure api version "2024-08-01-preview" onwards supports 'json_schema' only for gpt-4o
|
||||
_tool_choice = ChatCompletionToolChoiceObjectParam(
|
||||
type="function",
|
||||
function=ChatCompletionToolChoiceFunctionParam(
|
||||
name=schema_name
|
||||
),
|
||||
)
|
||||
|
||||
_tool = ChatCompletionToolParam(
|
||||
type="function",
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name=schema_name, parameters=json_schema
|
||||
),
|
||||
)
|
||||
|
||||
optional_params["tools"] = [_tool]
|
||||
optional_params["tool_choice"] = _tool_choice
|
||||
optional_params["json_mode"] = True
|
||||
else:
|
||||
optional_params["response_format"] = value
|
||||
elif param == "max_completion_tokens":
|
||||
# TODO - Azure OpenAI will probably add support for this, we should pass it through when Azure adds support
|
||||
optional_params["max_tokens"] = value
|
||||
elif param in supported_openai_params:
|
||||
optional_params[param] = value
|
||||
|
||||
return optional_params
|
||||
|
||||
def get_mapped_special_auth_params(self) -> dict:
|
||||
return {"token": "azure_ad_token"}
|
||||
|
||||
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "token":
|
||||
optional_params["azure_ad_token"] = value
|
||||
return optional_params
|
||||
|
||||
def get_eu_regions(self) -> List[str]:
|
||||
"""
|
||||
Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
|
||||
"""
|
||||
return ["europe", "sweden", "switzerland", "france", "uk"]
|
||||
|
||||
|
||||
class AzureOpenAIAssistantsAPIConfig:
|
||||
"""
|
||||
Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
|
||||
|
@@ -620,11 +408,9 @@ class AzureChatCompletion(BaseLLM):

                data = {"model": None, "messages": messages, **optional_params}
            else:
                data = {
                    "model": model, # type: ignore
                    "messages": messages,
                    **optional_params,
                }
                data = litellm.AzureOpenAIConfig.transform_request(
                    model=model, messages=messages, optional_params=optional_params
                )

            if acompletion is True:
                if optional_params.get("stream", False):
litellm/llms/AzureOpenAI/chat/gpt_transformation.py (new file, 250 additions)
@ -0,0 +1,250 @@
|
|||
import types
|
||||
from typing import List, Optional, Type, Union
|
||||
|
||||
import litellm
|
||||
|
||||
from ....exceptions import UnsupportedParamsError
|
||||
from ....types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionToolChoiceFunctionParam,
|
||||
ChatCompletionToolChoiceObjectParam,
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionToolParamFunctionChunk,
|
||||
)
|
||||
from ...prompt_templates.factory import convert_to_azure_openai_messages
|
||||
|
||||
|
||||
class AzureOpenAIConfig:
|
||||
"""
|
||||
Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
|
||||
|
||||
The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters::
|
||||
|
||||
- `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
|
||||
|
||||
- `function_call` (string or object): This optional parameter controls how the model calls functions.
|
||||
|
||||
- `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
|
||||
|
||||
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
|
||||
|
||||
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
|
||||
|
||||
- `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
|
||||
|
||||
- `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
|
||||
|
||||
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
|
||||
|
||||
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
|
||||
|
||||
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
function_call: Optional[Union[str, dict]] = None,
|
||||
functions: Optional[list] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
stop: Optional[Union[str, list]] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return [
|
||||
"temperature",
|
||||
"n",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"stop",
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"user",
|
||||
"function_call",
|
||||
"functions",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"top_p",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"response_format",
|
||||
"seed",
|
||||
"extra_headers",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
api_version: str, # Y-M-D-{optional}
|
||||
drop_params,
|
||||
) -> dict:
|
||||
supported_openai_params = self.get_supported_openai_params()
|
||||
|
||||
api_version_times = api_version.split("-")
|
||||
api_version_year = api_version_times[0]
|
||||
api_version_month = api_version_times[1]
|
||||
api_version_day = api_version_times[2]
|
||||
for param, value in non_default_params.items():
|
||||
if param == "tool_choice":
|
||||
"""
|
||||
This parameter requires API version 2023-12-01-preview or later
|
||||
|
||||
tool_choice='required' is not supported as of 2024-05-01-preview
|
||||
"""
|
||||
## check if api version supports this param ##
|
||||
if (
|
||||
api_version_year < "2023"
|
||||
or (api_version_year == "2023" and api_version_month < "12")
|
||||
or (
|
||||
api_version_year == "2023"
|
||||
and api_version_month == "12"
|
||||
and api_version_day < "01"
|
||||
)
|
||||
):
|
||||
if litellm.drop_params is True or (
|
||||
drop_params is not None and drop_params is True
|
||||
):
|
||||
pass
|
||||
else:
|
||||
raise UnsupportedParamsError(
|
||||
status_code=400,
|
||||
message=f"""Azure does not support 'tool_choice', for api_version={api_version}. Bump your API version to '2023-12-01-preview' or later. This parameter requires 'api_version="2023-12-01-preview"' or later. Azure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions""",
|
||||
)
|
||||
elif value == "required" and (
|
||||
api_version_year == "2024" and api_version_month <= "05"
|
||||
): ## check if tool_choice value is supported ##
|
||||
if litellm.drop_params is True or (
|
||||
drop_params is not None and drop_params is True
|
||||
):
|
||||
pass
|
||||
else:
|
||||
raise UnsupportedParamsError(
|
||||
status_code=400,
|
||||
message=f"Azure does not support '{value}' as a {param} param, for api_version={api_version}. To drop 'tool_choice=required' for calls with this Azure API version, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\nAzure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions",
|
||||
)
|
||||
else:
|
||||
optional_params["tool_choice"] = value
|
||||
elif param == "response_format" and isinstance(value, dict):
|
||||
json_schema: Optional[dict] = None
|
||||
schema_name: str = ""
|
||||
if "response_schema" in value:
|
||||
json_schema = value["response_schema"]
|
||||
schema_name = "json_tool_call"
|
||||
elif "json_schema" in value:
|
||||
json_schema = value["json_schema"]["schema"]
|
||||
schema_name = value["json_schema"]["name"]
|
||||
"""
|
||||
Follow similar approach to anthropic - translate to a single tool call.
|
||||
|
||||
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
|
||||
- You usually want to provide a single tool
|
||||
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
|
||||
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
|
||||
"""
|
||||
if json_schema is not None and (
|
||||
(api_version_year <= "2024" and api_version_month < "08")
|
||||
or "gpt-4o" not in model
|
||||
): # azure api version "2024-08-01-preview" onwards supports 'json_schema' only for gpt-4o
|
||||
_tool_choice = ChatCompletionToolChoiceObjectParam(
|
||||
type="function",
|
||||
function=ChatCompletionToolChoiceFunctionParam(
|
||||
name=schema_name
|
||||
),
|
||||
)
|
||||
|
||||
_tool = ChatCompletionToolParam(
|
||||
type="function",
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name=schema_name, parameters=json_schema
|
||||
),
|
||||
)
|
||||
|
||||
optional_params["tools"] = [_tool]
|
||||
optional_params["tool_choice"] = _tool_choice
|
||||
optional_params["json_mode"] = True
|
||||
else:
|
||||
optional_params["response_format"] = value
|
||||
elif param == "max_completion_tokens":
|
||||
# TODO - Azure OpenAI will probably add support for this, we should pass it through when Azure adds support
|
||||
optional_params["max_tokens"] = value
|
||||
elif param in supported_openai_params:
|
||||
optional_params[param] = value
|
||||
|
||||
return optional_params
|
||||
|
||||
@classmethod
|
||||
def transform_request(
|
||||
cls, model: str, messages: List[AllMessageValues], optional_params: dict
|
||||
) -> dict:
|
||||
messages = convert_to_azure_openai_messages(messages)
|
||||
return {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
def get_mapped_special_auth_params(self) -> dict:
|
||||
return {"token": "azure_ad_token"}
|
||||
|
||||
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "token":
|
||||
optional_params["azure_ad_token"] = value
|
||||
return optional_params
|
||||
|
||||
def get_eu_regions(self) -> List[str]:
|
||||
"""
|
||||
Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
|
||||
"""
|
||||
return ["europe", "sweden", "switzerland", "france", "uk"]
|
||||
|
||||
def get_us_regions(self) -> List[str]:
|
||||
"""
|
||||
Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
|
||||
"""
|
||||
return [
|
||||
"us",
|
||||
"eastus",
|
||||
"eastus2",
|
||||
"eastus2euap",
|
||||
"eastus3",
|
||||
"southcentralus",
|
||||
"westus",
|
||||
"westus2",
|
||||
"westus3",
|
||||
"westus4",
|
||||
]
@@ -53,6 +53,17 @@ class AmazonBedrockGlobalConfig:
            "eu-central-1",
        ]

    def get_us_regions(self) -> List[str]:
        """
        Source: https://www.aws-services.info/bedrock.html
        """
        return [
            "us-east-2",
            "us-east-1",
            "us-west-2",
            "us-gov-west-1",
        ]


class AmazonTitanConfig:
    """
@@ -2,9 +2,11 @@
Common utility functions used for translating messages across providers
"""

from typing import List
import json
from typing import Dict, List

from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, ModelResponse, StreamingChoices


def convert_content_list_to_str(message: AllMessageValues) -> str:

@@ -27,3 +29,43 @@ def convert_content_list_to_str(message: AllMessageValues) -> str:
        texts = message_content

    return texts


def convert_openai_message_to_only_content_messages(
    messages: List[AllMessageValues],
) -> List[Dict[str, str]]:
    """
    Converts OpenAI messages to only content messages

    Used for calling guardrails integrations which expect string content
    """
    converted_messages = []
    user_roles = ["user", "tool", "function"]
    for message in messages:
        if message.get("role") in user_roles:
            converted_messages.append(
                {"role": "user", "content": convert_content_list_to_str(message)}
            )
        elif message.get("role") == "assistant":
            converted_messages.append(
                {"role": "assistant", "content": convert_content_list_to_str(message)}
            )
    return converted_messages


def get_content_from_model_response(response: ModelResponse) -> str:
    """
    Gets content from model response
    """
    content = ""
    for choice in response.choices:
        if isinstance(choice, Choices):
            content += choice.message.content if choice.message.content else ""
            if choice.message.function_call:
                content += choice.message.function_call.model_dump_json()
            if choice.message.tool_calls:
                for tc in choice.message.tool_calls:
                    content += tc.model_dump_json()
        elif isinstance(choice, StreamingChoices):
            content += getattr(choice, "delta", {}).get("content", "") or ""
    return content
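A hedged usage sketch of the two new helpers: the first flattens OpenAI-style messages to plain `role`/`content` strings for guardrail backends, the second collects a response's text and tool calls into one string. The printed values are what the logic above implies, not captured output; `mock_response` is used to avoid a live API call.

```python
import litellm
from litellm.llms.prompt_templates.common_utils import (
    convert_openai_message_to_only_content_messages,
    get_content_from_model_response,
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},  # skipped: not a user/assistant/tool role
    {"role": "user", "content": [{"type": "text", "text": "hi my email is ishaan@berri.ai"}]},
]
# content lists are flattened to plain strings, roles collapsed to user/assistant
print(convert_openai_message_to_only_content_messages(messages))

resp = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    mock_response="Hello there!",
)
print(get_content_from_model_response(resp))  # "Hello there!"
```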
@@ -920,6 +920,31 @@ def anthropic_messages_pt_xml(messages: list):
# ------------------------------------------------------------------------------


def _azure_tool_call_invoke_helper(
    function_call_params: ChatCompletionToolCallFunctionChunk,
) -> Optional[ChatCompletionToolCallFunctionChunk]:
    """
    Azure requires 'arguments' to be a string.
    """
    if function_call_params.get("arguments") is None:
        function_call_params["arguments"] = ""
    return function_call_params


def convert_to_azure_openai_messages(
    messages: List[AllMessageValues],
) -> List[AllMessageValues]:
    for m in messages:
        if m["role"] == "assistant":
            function_call = m.get("function_call", None)
            if function_call is not None:
                m["function_call"] = _azure_tool_call_invoke_helper(function_call)
    return messages


# ------------------------------------------------------------------------------


def infer_protocol_value(
    value: Any,
) -> Literal[
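The Azure helpers above exist because Azure rejects an assistant `function_call` whose `arguments` is null (the empty-arguments case referenced in the commit message). A small sketch of the normalization, using the import path shown for this module elsewhere in the diff:

```python
from litellm.llms.prompt_templates.factory import convert_to_azure_openai_messages

messages = [
    {"role": "user", "content": "what's the weather in Paris?"},
    {
        "role": "assistant",
        "content": None,
        # some models emit a call with no arguments; Azure requires a string here
        "function_call": {"name": "get_weather", "arguments": None},
    },
]

fixed = convert_to_azure_openai_messages(messages)
print(fixed[1]["function_call"]["arguments"])  # "" (empty string instead of None)
```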
@@ -254,6 +254,21 @@ class VertexAIConfig:
            "europe-west9",
        ]

    def get_us_regions(self) -> List[str]:
        """
        Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
        """
        return [
            "us-central1",
            "us-east1",
            "us-east4",
            "us-east5",
            "us-south1",
            "us-west1",
            "us-west4",
            "us-west5",
        ]


class VertexGeminiConfig:
    """
@@ -178,6 +178,14 @@ class IBMWatsonXAIConfig:
            "eu-gb",
        ]

    def get_us_regions(self) -> List[str]:
        """
        Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
        """
        return [
            "us-south",
        ]


def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict) -> str:
    # handle anthropic prompts and amazon titan prompts
@@ -4059,12 +4059,12 @@ def text_completion(
            return response["choices"][0]

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
            completed_futures = [
                executor.submit(process_prompt, i, individual_prompt)
                for i, individual_prompt in enumerate(prompt)
            ]
            for i, future in enumerate(
                concurrent.futures.as_completed(futures)
                concurrent.futures.as_completed(completed_futures)
            ):
                responses[i] = future.result()
        text_completion_response.choices = responses # type: ignore
@@ -1,15 +1,27 @@
model_list:
  # - model_name: openai-gpt-4o-realtime-audio
  #   litellm_params:
  #     model: azure/gpt-4o-realtime-preview
  #     api_key: os.environ/AZURE_SWEDEN_API_KEY
  #     api_base: os.environ/AZURE_SWEDEN_API_BASE

  - model_name: openai-gpt-4o-realtime-audio
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: openai/gpt-4o-realtime-preview-2024-10-01
      api_key: os.environ/OPENAI_API_KEY
      api_base: http://localhost:8080
      model: azure/gpt-35-turbo # 👈 EU azure model
      api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
      api_key: os.environ/AZURE_EUROPE_API_KEY
      region_name: "eu"
  - model_name: gpt-4o
    litellm_params:
      model: azure/gpt-4o
      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
      api_key: os.environ/AZURE_API_KEY
      region_name: "us"
  - model_name: gpt-3.5-turbo-end-user-test
    litellm_params:
      model: gpt-3.5-turbo
      region_name: "eu"
    model_info:
      id: "1"

litellm_settings:
  success_callback: ["langfuse"]
  # guardrails:
  #   - guardrail_name: "gibberish-guard"
  #     litellm_params:
  #       guardrail: guardrails_ai
  #       guard_name: "gibberish_guard"
  #       mode: "post_call"
  #       api_base: os.environ/GUARDRAILS_AI_API_BASE
@@ -767,6 +767,9 @@ class DeleteUserRequest(LiteLLMBase):
    user_ids: List[str] # required


AllowedModelRegion = Literal["eu", "us"]


class NewCustomerRequest(LiteLLMBase):
    """
    Create a new customer, allocate a budget to them

@@ -777,7 +780,7 @@ class NewCustomerRequest(LiteLLMBase):
    blocked: bool = False # allow/disallow requests for this end-user
    max_budget: Optional[float] = None
    budget_id: Optional[str] = None # give either a budget_id or max_budget
    allowed_model_region: Optional[Literal["eu"]] = (
    allowed_model_region: Optional[AllowedModelRegion] = (
        None # require all user requests to use models in this specific region
    )
    default_model: Optional[str] = (

@@ -804,7 +807,7 @@ class UpdateCustomerRequest(LiteLLMBase):
    blocked: bool = False # allow/disallow requests for this end-user
    max_budget: Optional[float] = None
    budget_id: Optional[str] = None # give either a budget_id or max_budget
    allowed_model_region: Optional[Literal["eu"]] = (
    allowed_model_region: Optional[AllowedModelRegion] = (
        None # require all user requests to use models in this specific region
    )
    default_model: Optional[str] = (

@@ -1384,7 +1387,7 @@ class UserAPIKeyAuth(

    api_key: Optional[str] = None
    user_role: Optional[LitellmUserRoles] = None
    allowed_model_region: Optional[Literal["eu"]] = None
    allowed_model_region: Optional[AllowedModelRegion] = None
    parent_otel_span: Optional[Span] = None
    rpm_limit_per_model: Optional[Dict[str, int]] = None
    tpm_limit_per_model: Optional[Dict[str, int]] = None

@@ -1466,7 +1469,7 @@ class LiteLLM_EndUserTable(LiteLLMBase):
    blocked: bool
    alias: Optional[str] = None
    spend: float = 0.0
    allowed_model_region: Optional[Literal["eu"]] = None
    allowed_model_region: Optional[AllowedModelRegion] = None
    default_model: Optional[str] = None
    litellm_budget_table: Optional[LiteLLM_BudgetTable] = None

@@ -2019,3 +2022,11 @@ class LoggingCallbackStatus(TypedDict, total=False):
class KeyHealthResponse(TypedDict, total=False):
    key: Literal["healthy", "unhealthy"]
    logging_callbacks: Optional[LoggingCallbackStatus]


class SpecialHeaders(enum.Enum):
    """Used by user_api_key_auth.py to get litellm key"""

    openai_authorization = "Authorization"
    azure_authorization = "API-Key"
    anthropic_authorization = "x-api-key"
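A self-contained sketch of why the new `AllowedModelRegion` alias matters for request validation: `"us"` now passes, while arbitrary strings are still rejected. The model below is a stand-in, not the proxy's actual `NewCustomerRequest`.

```python
from typing import Literal, Optional
from pydantic import BaseModel, ValidationError

AllowedModelRegion = Literal["eu", "us"]  # mirrors the alias added above

class CustomerRequest(BaseModel):  # stand-in for NewCustomerRequest / UpdateCustomerRequest
    user_id: str
    allowed_model_region: Optional[AllowedModelRegion] = None

CustomerRequest(user_id="c1", allowed_model_region="us")  # accepted after this change
try:
    CustomerRequest(user_id="c2", allowed_model_region="mars")
except ValidationError as err:
    print(err)  # still rejected: not a supported region
```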
@@ -75,15 +75,17 @@ from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.utils import _to_ns

api_key_header = APIKeyHeader(
    name="Authorization", auto_error=False, description="Bearer token"
    name=SpecialHeaders.openai_authorization.value,
    auto_error=False,
    description="Bearer token",
)
azure_api_key_header = APIKeyHeader(
    name="API-Key",
    name=SpecialHeaders.azure_authorization.value,
    auto_error=False,
    description="Some older versions of the openai Python package will send an API-Key header with just the API key ",
)
anthropic_api_key_header = APIKeyHeader(
    name="x-api-key",
    name=SpecialHeaders.anthropic_authorization.value,
    auto_error=False,
    description="If anthropic client used.",
)
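In practice this means the same virtual key is picked up from any of the three header names now registered through `SpecialHeaders`. A hedged sketch (base URL and key are placeholders):

```python
import requests

url = "http://0.0.0.0:4000/v1/chat/completions"
payload = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]}

for headers in (
    {"Authorization": "Bearer sk-1234"},  # OpenAI-style bearer token
    {"API-Key": "sk-1234"},               # older openai package versions
    {"x-api-key": "sk-1234"},             # Anthropic-style clients
):
    resp = requests.post(url, json=payload, headers=headers)
    print(list(headers), resp.status_code)
```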
litellm/proxy/guardrails/guardrail_hooks/guardrails_ai.py (new file, 109 additions)
@ -0,0 +1,109 @@
|
|||
# +-------------------------------------------------------------+
|
||||
#
|
||||
# Use GuardrailsAI for your LLM calls
|
||||
#
|
||||
# +-------------------------------------------------------------+
|
||||
# Thank you for using Litellm! - Krrish & Ishaan
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.llms.prompt_templates.common_utils import (
|
||||
convert_openai_message_to_only_content_messages,
|
||||
get_content_from_model_response,
|
||||
)
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
add_guardrail_to_applied_guardrails_header,
|
||||
)
|
||||
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
|
||||
class GuardrailsAIResponse(TypedDict):
|
||||
callId: str
|
||||
rawLlmOutput: str
|
||||
validatedOutput: str
|
||||
validationPassed: bool
|
||||
|
||||
|
||||
class GuardrailsAI(CustomGuardrail):
|
||||
def __init__(
|
||||
self,
|
||||
guard_name: str,
|
||||
api_base: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if guard_name is None:
|
||||
raise Exception(
|
||||
"GuardrailsAIException - Please pass the Guardrails AI guard name via 'litellm_params::guard_name'"
|
||||
)
|
||||
# store kwargs as optional_params
|
||||
self.guardrails_ai_api_base = api_base or "http://0.0.0.0:8000"
|
||||
self.guardrails_ai_guard_name = guard_name
|
||||
self.optional_params = kwargs
|
||||
supported_event_hooks = [GuardrailEventHooks.post_call]
|
||||
super().__init__(supported_event_hooks=supported_event_hooks, **kwargs)
|
||||
|
||||
async def make_guardrails_ai_api_request(self, llm_output: str):
|
||||
from httpx import URL
|
||||
|
||||
data = {"llmOutput": llm_output}
|
||||
_json_data = json.dumps(data)
|
||||
response = await litellm.module_level_aclient.post(
|
||||
url=str(
|
||||
URL(self.guardrails_ai_api_base).join(
|
||||
f"guards/{self.guardrails_ai_guard_name}/validate"
|
||||
)
|
||||
),
|
||||
data=_json_data,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
verbose_proxy_logger.debug("guardrails_ai response: %s", response)
|
||||
_json_response = GuardrailsAIResponse(**response.json()) # type: ignore
|
||||
if _json_response.get("validationPassed") is False:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Violated guardrail policy",
|
||||
"guardrails_ai_response": _json_response,
|
||||
},
|
||||
)
|
||||
return _json_response
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response,
|
||||
):
|
||||
"""
|
||||
Runs on response from LLM API call
|
||||
|
||||
It can be used to reject a response
|
||||
"""
|
||||
event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
|
||||
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||
return
|
||||
|
||||
if not isinstance(response, litellm.ModelResponse):
|
||||
return
|
||||
|
||||
response_str: str = get_content_from_model_response(response)
|
||||
if response_str is not None and len(response_str) > 0:
|
||||
await self.make_guardrails_ai_api_request(llm_output=response_str)
|
||||
|
||||
add_guardrail_to_applied_guardrails_header(
|
||||
request_data=data, guardrail_name=self.guardrail_name
|
||||
)
|
||||
|
||||
return
|
|
@ -17,6 +17,7 @@ from litellm.types.guardrails import (
|
|||
GuardrailItemSpec,
|
||||
LakeraCategoryThresholds,
|
||||
LitellmParams,
|
||||
SupportedGuardrailIntegrations,
|
||||
)
|
||||
|
||||
all_guardrails: List[GuardrailItem] = []
|
||||
|
@ -86,8 +87,8 @@ Map guardrail_name: <pre_call>, <post_call>, during_call
|
|||
|
||||
|
||||
def init_guardrails_v2(
|
||||
all_guardrails: dict,
|
||||
config_file_path: str,
|
||||
all_guardrails: List[Dict],
|
||||
config_file_path: Optional[str] = None,
|
||||
):
|
||||
# Convert the loaded data to the TypedDict structure
|
||||
guardrail_list = []
|
||||
|
@ -124,7 +125,7 @@ def init_guardrails_v2(
|
|||
litellm_params["api_base"] = str(get_secret(litellm_params["api_base"])) # type: ignore
|
||||
|
||||
# Init guardrail CustomLoggerClass
|
||||
if litellm_params["guardrail"] == "aporia":
|
||||
if litellm_params["guardrail"] == SupportedGuardrailIntegrations.APORIA.value:
|
||||
from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import (
|
||||
AporiaGuardrail,
|
||||
)
|
||||
|
@ -136,7 +137,9 @@ def init_guardrails_v2(
|
|||
event_hook=litellm_params["mode"],
|
||||
)
|
||||
litellm.callbacks.append(_aporia_callback) # type: ignore
|
||||
if litellm_params["guardrail"] == "bedrock":
|
||||
elif (
|
||||
litellm_params["guardrail"] == SupportedGuardrailIntegrations.BEDROCK.value
|
||||
):
|
||||
from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
|
||||
BedrockGuardrail,
|
||||
)
|
||||
|
@ -148,7 +151,7 @@ def init_guardrails_v2(
|
|||
guardrailVersion=litellm_params["guardrailVersion"],
|
||||
)
|
||||
litellm.callbacks.append(_bedrock_callback) # type: ignore
|
||||
elif litellm_params["guardrail"] == "lakera":
|
||||
elif litellm_params["guardrail"] == SupportedGuardrailIntegrations.LAKERA.value:
|
||||
from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import (
|
||||
lakeraAI_Moderation,
|
||||
)
|
||||
|
@ -161,7 +164,9 @@ def init_guardrails_v2(
|
|||
category_thresholds=litellm_params.get("category_thresholds"),
|
||||
)
|
||||
litellm.callbacks.append(_lakera_callback) # type: ignore
|
||||
elif litellm_params["guardrail"] == "presidio":
|
||||
elif (
|
||||
litellm_params["guardrail"] == SupportedGuardrailIntegrations.PRESIDIO.value
|
||||
):
|
||||
from litellm.proxy.guardrails.guardrail_hooks.presidio import (
|
||||
_OPTIONAL_PresidioPIIMasking,
|
||||
)
|
||||
|
@ -189,7 +194,10 @@ def init_guardrails_v2(
|
|||
litellm.callbacks.append(_success_callback) # type: ignore
|
||||
|
||||
litellm.callbacks.append(_presidio_callback) # type: ignore
|
||||
elif litellm_params["guardrail"] == "hide-secrets":
|
||||
elif (
|
||||
litellm_params["guardrail"]
|
||||
== SupportedGuardrailIntegrations.HIDE_SECRETS.value
|
||||
):
|
||||
from enterprise.enterprise_hooks.secret_detection import (
|
||||
_ENTERPRISE_SecretDetection,
|
||||
)
|
||||
|
@ -201,10 +209,34 @@ def init_guardrails_v2(
|
|||
)
|
||||
|
||||
litellm.callbacks.append(_secret_detection_object) # type: ignore
|
||||
elif (
|
||||
litellm_params["guardrail"]
|
||||
== SupportedGuardrailIntegrations.GURDRAILS_AI.value
|
||||
):
|
||||
from litellm.proxy.guardrails.guardrail_hooks.guardrails_ai import (
|
||||
GuardrailsAI,
|
||||
)
|
||||
|
||||
_guard_name = litellm_params.get("guard_name")
|
||||
if _guard_name is None:
|
||||
raise Exception(
|
||||
"GuardrailsAIException - Please pass the Guardrails AI guard name via 'litellm_params::guard_name'"
|
||||
)
|
||||
_guardrails_ai_callback = GuardrailsAI(
|
||||
api_base=litellm_params.get("api_base"),
|
||||
guard_name=_guard_name,
|
||||
guardrail_name=SupportedGuardrailIntegrations.GURDRAILS_AI.value,
|
||||
)
|
||||
|
||||
litellm.callbacks.append(_guardrails_ai_callback) # type: ignore
|
||||
elif (
|
||||
isinstance(litellm_params["guardrail"], str)
|
||||
and "." in litellm_params["guardrail"]
|
||||
):
|
||||
if config_file_path is None:
|
||||
raise Exception(
|
||||
"GuardrailsAIException - Please pass the config_file_path to initialize_guardrails_v2"
|
||||
)
|
||||
import os
|
||||
|
||||
from litellm.proxy.utils import get_instance_fn
|
||||
|
@ -238,6 +270,8 @@ def init_guardrails_v2(
|
|||
event_hook=litellm_params["mode"],
|
||||
)
|
||||
litellm.callbacks.append(_guardrail_callback) # type: ignore
|
||||
else:
|
||||
raise ValueError(f"Unsupported guardrail: {litellm_params['guardrail']}")
|
||||
|
||||
parsed_guardrail = Guardrail(
|
||||
guardrail_name=guardrail["guardrail_name"],
|
||||
|
|
@@ -1,16 +1,20 @@
import copy
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

from fastapi import Request
from starlette.datastructures import Headers

import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.proxy._types import (
    AddTeamCallback,
    CommonProxyErrors,
    LiteLLMRoutes,
    SpecialHeaders,
    TeamCallbackMetadata,
    UserAPIKeyAuth,
)
from litellm.proxy.auth.auth_utils import get_request_route
from litellm.types.utils import SupportedCacheControls

if TYPE_CHECKING:

@@ -137,6 +141,40 @@ def _get_dynamic_logging_metadata(
    return callback_settings_obj


def clean_headers(
    headers: Headers, litellm_key_header_name: Optional[str] = None
) -> dict:
    """
    Removes litellm api key from headers
    """
    special_headers = [v.value.lower() for v in SpecialHeaders._member_map_.values()]
    special_headers = special_headers
    if litellm_key_header_name is not None:
        special_headers.append(litellm_key_header_name.lower())
    clean_headers = {}
    for header, value in headers.items():
        if header.lower() not in special_headers:
            clean_headers[header] = value
    return clean_headers


def get_forwardable_headers(
    headers: Union[Headers, dict],
):
    """
    Get the headers that should be forwarded to the LLM Provider.

    Looks for any `x-` headers and sends them to the LLM Provider.
    """
    forwarded_headers = {}
    for header, value in headers.items():
        if header.lower().startswith("x-") and not header.lower().startswith(
            "x-stainless"
        ):  # causes openai sdk to fail
            forwarded_headers[header] = value
    return forwarded_headers


async def add_litellm_data_to_request(
    data: dict,
    request: Request,

@@ -163,8 +201,17 @@ async def add_litellm_data_to_request(

    safe_add_api_version_from_query_params(data, request)

    _headers = dict(request.headers)
    _headers = clean_headers(
        request.headers,
        litellm_key_header_name=(
            general_settings.get("litellm_key_header_name")
            if general_settings is not None
            else None
        ),
    )

    if get_forwardable_headers(_headers) != {}:
        data["headers"] = get_forwardable_headers(_headers)
    # Include original request and headers in the data
    data["proxy_server_request"] = {
        "url": str(request.url),
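A quick illustration of the new header handling, assuming the module path of the file touched here (`litellm/proxy/litellm_pre_call_utils.py`); plain dicts stand in for the incoming starlette `Headers` object:

```python
from litellm.proxy.litellm_pre_call_utils import clean_headers, get_forwardable_headers

incoming = {
    "Authorization": "Bearer sk-1234",   # litellm virtual key: stripped by clean_headers
    "x-api-key": "sk-1234",              # special header: stripped
    "x-stainless-os": "MacOS",           # x-stainless-*: never forwarded (breaks the openai sdk)
    "x-custom-tenant-id": "acme",        # other x-* headers: forwarded to the LLM provider
    "Content-Type": "application/json",  # kept in the cleaned set but not forwarded
}

cleaned = clean_headers(incoming, litellm_key_header_name=None)
print(get_forwardable_headers(cleaned))  # {'x-custom-tenant-id': 'acme'}
```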
@@ -1992,7 +1992,7 @@ class ProxyConfig:
        ) # type:ignore

        # Guardrail settings
        guardrails_v2: Optional[dict] = None
        guardrails_v2: Optional[List[Dict]] = None

        if config is not None:
            guardrails_v2 = config.get("guardrails", None)
||||
|
@@ -5894,7 +5894,7 @@ async def new_end_user(
    - blocked: bool - Flag to allow or disallow requests for this end-user. Default is False.
    - max_budget: Optional[float] - The maximum budget allocated to the user. Either 'max_budget' or 'budget_id' should be provided, not both.
    - budget_id: Optional[str] - The identifier for an existing budget allocated to the user. Either 'max_budget' or 'budget_id' should be provided, not both.
-   - allowed_model_region: Optional[Literal["eu"]] - Require all user requests to use models in this specific region.
+   - allowed_model_region: Optional[Union[Literal["eu"], Literal["us"]]] - Require all user requests to use models in this specific region.
    - default_model: Optional[str] - If no equivalent model in the allowed region, default all requests to this model.
    - metadata: Optional[dict] = Metadata for customer, store information for customer. Example metadata = {"data_training_opt_out": True}
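For context, a hedged client-side sketch of the widened parameter: creating an end user pinned to US-region models. The proxy URL, master key, and user_id value are placeholders, and the /end_user/new route is assumed from the handler name rather than shown in this excerpt.

import requests  # assumption: calling the proxy over HTTP with the requests library

resp = requests.post(
    "http://localhost:4000/end_user/new",  # placeholder proxy URL; route assumed
    headers={"Authorization": "Bearer sk-1234"},  # placeholder master key
    json={
        "user_id": "customer-123",         # placeholder end-user id
        "allowed_model_region": "us",      # now accepts "us" in addition to "eu"
        "default_model": "gpt-3.5-turbo",  # fallback if no model exists in the region
    },
)
print(resp.status_code, resp.json())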
@@ -1580,15 +1580,9 @@ class PrismaClient:
                    }
                )
            elif query_type == "find_all" and user_id_list is not None:
-               user_id_values = ", ".join(f"'{item}'" for item in user_id_list)
-               sql_query = f"""
-               SELECT *
-               FROM "LiteLLM_UserTable"
-               WHERE "user_id" IN ({user_id_values})
-               """
-               # Execute the raw query
-               # The asterisk before `user_id_list` unpacks the list into separate arguments
-               response = await self.db.query_raw(sql_query)
+               response = await self.db.litellm_usertable.find_many(
+                   where={"user_id": {"in": user_id_list}}
+               )
            elif query_type == "find_all":
                if expires is not None:
                    response = await self.db.litellm_usertable.find_many(  # type: ignore
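A note on the change above: one likely motivation, besides simplicity, is that the removed code interpolated user IDs directly into the SQL string, while find_many(where={"user_id": {"in": user_id_list}}) hands the filter to Prisma as structured data. A small standalone illustration of what the f-string construction produces (the IDs are made up):

# Illustration only: the removed f-string builds SQL by string concatenation,
# so an ID containing a quote escapes the literal it was meant to sit inside.
user_id_list = ["user-1", "x'; DROP TABLE \"LiteLLM_UserTable\"; --"]
user_id_values = ", ".join(f"'{item}'" for item in user_id_list)
print(f'SELECT * FROM "LiteLLM_UserTable" WHERE "user_id" IN ({user_id_values})')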
@@ -126,6 +126,7 @@ from litellm.utils import (
    get_llm_provider,
    get_secret,
    get_utc_datetime,
+   is_region_allowed,
)

from .router_utils.pattern_match_deployments import PatternMatchRouter
@@ -4888,25 +4889,14 @@ class Router:
            if (
                request_kwargs is not None
                and request_kwargs.get("allowed_model_region") is not None
-               and request_kwargs["allowed_model_region"] == "eu"
            ):
-               if _litellm_params.get("region_name") is not None and isinstance(
-                   _litellm_params["region_name"], str
-               ):
-                   # check if in allowed_model_region
-                   if (
-                       _is_region_eu(litellm_params=LiteLLM_Params(**_litellm_params))
-                       is False
-                   ):
-                       invalid_model_indices.append(idx)
-                       continue
-               else:
-                   verbose_router_logger.debug(
-                       "Filtering out model - {}, as model_region=None, and allowed_model_region={}".format(
-                           model_id, request_kwargs.get("allowed_model_region")
-                       )
-                   )
-                   # filter out since region unknown, and user wants to filter for specific region
-                   invalid_model_indices.append(idx)
-                   continue
+               allowed_model_region = request_kwargs.get("allowed_model_region")
+
+               if allowed_model_region is not None:
+                   if not is_region_allowed(
+                       litellm_params=LiteLLM_Params(**_litellm_params),
+                       allowed_model_region=allowed_model_region,
+                   ):
+                       invalid_model_indices.append(idx)
+                       continue
@@ -19,6 +19,15 @@ litellm_settings:
"""


+class SupportedGuardrailIntegrations(Enum):
+    APORIA = "aporia"
+    BEDROCK = "bedrock"
+    GURDRAILS_AI = "guardrails_ai"
+    LAKERA = "lakera"
+    PRESIDIO = "presidio"
+    HIDE_SECRETS = "hide-secrets"
+
+
class Role(Enum):
    SYSTEM = "system"
    ASSISTANT = "assistant"
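A quick standalone sketch of how the new enum is addressed from config strings; the enum is re-declared locally here so the snippet runs on its own. Lookups go by value, so the "guardrails_ai" string used in configs (and in the new test later in this PR) still resolves despite the misspelled GURDRAILS_AI member name.

from enum import Enum


class SupportedGuardrailIntegrations(Enum):  # local copy of the enum added above
    APORIA = "aporia"
    BEDROCK = "bedrock"
    GURDRAILS_AI = "guardrails_ai"
    LAKERA = "lakera"
    PRESIDIO = "presidio"
    HIDE_SECRETS = "hide-secrets"


print(SupportedGuardrailIntegrations("guardrails_ai"))    # SupportedGuardrailIntegrations.GURDRAILS_AI
print(SupportedGuardrailIntegrations.HIDE_SECRETS.value)  # hide-secrets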
@@ -92,6 +101,9 @@ class LitellmParams(TypedDict):
    # hide secrets params
    detect_secrets_config: Optional[dict]

+   # guardrails ai params
+   guard_name: Optional[str]
+

class Guardrail(TypedDict):
    guardrail_name: str
@@ -175,7 +175,7 @@ from .exceptions import (
    UnprocessableEntityError,
    UnsupportedParamsError,
)
-from .proxy._types import KeyManagementSystem
+from .proxy._types import AllowedModelRegion, KeyManagementSystem
from .types.llms.openai import (
    ChatCompletionDeltaToolCallChunk,
    ChatCompletionToolCallChunk,
@@ -3839,18 +3839,13 @@ def _get_model_region(
    return litellm_params.region_name


-def _is_region_eu(litellm_params: LiteLLM_Params) -> bool:
+def _infer_model_region(litellm_params: LiteLLM_Params) -> Optional[AllowedModelRegion]:
    """
-   Return true/false if a deployment is in the EU
-   """
-   if litellm_params.region_name == "eu":
-       return True
+   Infer if a model is in the EU or US region

-   ## ELSE ##
-   """
-   - get provider
-   - get provider regions
-   - return true if given region (get_provider_region) in eu region (config.get_eu_regions())
+   Returns:
+   - str (region) - "eu" or "us"
+   - None (if region not found)
    """
    model, custom_llm_provider, _, _ = litellm.get_llm_provider(
        model=litellm_params.model, litellm_params=litellm_params
@@ -3861,21 +3856,71 @@ def _is_region_eu(litellm_params: LiteLLM_Params) -> bool:
    )

    if model_region is None:
-       return False
+       verbose_logger.debug(
+           "Cannot infer model region for model: {}".format(litellm_params.model)
+       )
+       return None

    if custom_llm_provider == "azure":
        eu_regions = litellm.AzureOpenAIConfig().get_eu_regions()
+       us_regions = litellm.AzureOpenAIConfig().get_us_regions()
    elif custom_llm_provider == "vertex_ai":
        eu_regions = litellm.VertexAIConfig().get_eu_regions()
+       us_regions = litellm.VertexAIConfig().get_us_regions()
    elif custom_llm_provider == "bedrock":
        eu_regions = litellm.AmazonBedrockGlobalConfig().get_eu_regions()
+       us_regions = litellm.AmazonBedrockGlobalConfig().get_us_regions()
    elif custom_llm_provider == "watsonx":
        eu_regions = litellm.IBMWatsonXAIConfig().get_eu_regions()
+       us_regions = litellm.IBMWatsonXAIConfig().get_us_regions()
    else:
-       return False
+       eu_regions = []
+       us_regions = []

    for region in eu_regions:
        if region in model_region.lower():
-           return True
-   return False
+           return "eu"
+   for region in us_regions:
+       if region in model_region.lower():
+           return "us"
+   return None
+
+
+def _is_region_eu(litellm_params: LiteLLM_Params) -> bool:
+   """
+   Return true/false if a deployment is in the EU
+   """
+   if litellm_params.region_name == "eu":
+       return True
+
+   ## Else - try and infer from model region
+   model_region = _infer_model_region(litellm_params=litellm_params)
+   if model_region is not None and model_region == "eu":
+       return True
+   return False
+
+
+def _is_region_us(litellm_params: LiteLLM_Params) -> bool:
+   """
+   Return true/false if a deployment is in the US
+   """
+   if litellm_params.region_name == "us":
+       return True
+
+   ## Else - try and infer from model region
+   model_region = _infer_model_region(litellm_params=litellm_params)
+   if model_region is not None and model_region == "us":
+       return True
+   return False
+
+
+def is_region_allowed(
+   litellm_params: LiteLLM_Params, allowed_model_region: str
+) -> bool:
+   """
+   Return true/false if a deployment is in the EU
+   """
+   if litellm_params.region_name == allowed_model_region:
+       return True
+   return False
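Hedged usage sketch of the helpers above, with region_name set explicitly so no provider lookup is needed; the import paths are assumed from how the router imports is_region_allowed out of litellm.utils elsewhere in this PR.

from litellm.types.router import LiteLLM_Params  # assumed import path for LiteLLM_Params
from litellm.utils import _is_region_eu, is_region_allowed

eu_deployment = LiteLLM_Params(model="azure/gpt-35-turbo", region_name="eu")

print(is_region_allowed(litellm_params=eu_deployment, allowed_model_region="eu"))  # True
print(is_region_allowed(litellm_params=eu_deployment, allowed_model_region="us"))  # False
print(_is_region_eu(litellm_params=eu_deployment))  # True (short-circuits on region_name)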
@@ -226,7 +226,7 @@ def test_all_model_configs():
        optional_params={},
    ) == {"max_tokens": 10}

-   from litellm.llms.AzureOpenAI.azure import AzureOpenAIConfig
+   from litellm.llms.AzureOpenAI.chat.gpt_transformation import AzureOpenAIConfig

    assert "max_completion_tokens" in AzureOpenAIConfig().get_supported_openai_params()
    assert AzureOpenAIConfig().map_openai_params(
@@ -82,6 +82,7 @@ def user_api_key_auth() -> UserAPIKeyAuth:

@pytest.mark.parametrize("num_projects", [1, 2, 100])
@pytest.mark.asyncio
+@pytest.mark.flaky(retries=3, delay=1)
async def test_available_tpm(num_projects, dynamic_rate_limit_handler):
    model = "my-fake-model"
    ## SET CACHE W/ ACTIVE PROJECTS
tests/local_testing/test_guardrails_ai.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+import os
+import sys
+import traceback
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm.proxy.guardrails.init_guardrails import init_guardrails_v2
+
+
+def test_guardrails_ai():
+    litellm.set_verbose = True
+    litellm.guardrail_name_config_map = {}
+
+    init_guardrails_v2(
+        all_guardrails=[
+            {
+                "guardrail_name": "gibberish-guard",
+                "litellm_params": {
+                    "guardrail": "guardrails_ai",
+                    "guard_name": "gibberish_guard",
+                    "mode": "post_call",
+                },
+            }
+        ],
+        config_file_path="",
+    )
@@ -436,3 +436,24 @@ def test_vertex_only_image_user_message():

def test_convert_url():
    convert_url_to_base64("https://picsum.photos/id/237/200/300")
+
+
+def test_azure_tool_call_invoke_helper():
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "What is the weather in Copenhagen?"},
+        {"role": "assistant", "function_call": {"name": "get_weather"}},
+    ]
+
+    transformed_messages = litellm.AzureOpenAIConfig.transform_request(
+        model="gpt-4o", messages=messages, optional_params={}
+    )
+
+    assert transformed_messages["messages"] == [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "What is the weather in Copenhagen?"},
+        {
+            "role": "assistant",
+            "function_call": {"name": "get_weather", "arguments": ""},
+        },
+    ]
@@ -72,7 +72,7 @@ def test_embedding(client):

    # assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback
    print("my_custom_logger", my_custom_logger)
-   assert my_custom_logger.async_success_embedding == False
+   assert my_custom_logger.async_success_embedding is False

    test_data = {"model": "azure-embedding-model", "input": ["hello"]}
    response = client.post("/embeddings", json=test_data, headers=headers)
@@ -84,7 +84,7 @@ def test_embedding(client):
        id(my_custom_logger),
    )
    assert (
-       my_custom_logger.async_success_embedding == True
+       my_custom_logger.async_success_embedding is True
    )  # checks if the status of async_success is True, only the async_log_success_event can set this to true
    assert (
        my_custom_logger.async_embedding_kwargs["model"] == "azure-embedding-model"
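The two hunks above swap "== True" / "== False" for identity checks. The distinction matters for a flag that must be the actual bool, not merely truthy; a one-liner illustration:

flag = 1              # truthy, but not the bool singleton
print(flag == True)   # True  -- equality coerces int to bool
print(flag is True)   # False -- identity requires the real True object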
@@ -107,7 +107,6 @@ def test_embedding(client):
            "accept-encoding": "gzip, deflate",
            "connection": "keep-alive",
            "user-agent": "testclient",
-           "authorization": "Bearer sk-1234",
            "content-length": "54",
            "content-type": "application/json",
        },
@@ -194,6 +193,8 @@ def test_chat_completion(client):
        "mode": "chat",
        "db_model": False,
    }
+
+   assert "authorization" not in proxy_server_request_object["headers"]
    assert proxy_server_request_object == {
        "url": "http://testserver/chat/completions",
        "method": "POST",
@@ -203,7 +204,6 @@ def test_chat_completion(client):
            "accept-encoding": "gzip, deflate",
            "connection": "keep-alive",
            "user-agent": "testclient",
-           "authorization": "Bearer sk-1234",
            "content-length": "123",
            "content-type": "application/json",
        },
@@ -173,6 +173,96 @@ def test_chat_completion(mock_acompletion, client_no_auth):
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


+@pytest.mark.parametrize(
+    "litellm_key_header_name",
+    ["x-litellm-key", None],
+)
+def test_add_headers_to_request(litellm_key_header_name):
+    from fastapi import Request
+    from starlette.datastructures import URL
+    import json
+    from litellm.proxy.litellm_pre_call_utils import (
+        clean_headers,
+        get_forwardable_headers,
+    )
+
+    headers = {
+        "Authorization": "Bearer 1234",
+        "X-Custom-Header": "Custom-Value",
+        "X-Stainless-Header": "Stainless-Value",
+    }
+    request = Request(scope={"type": "http"})
+    request._url = URL(url="/chat/completions")
+    request._body = json.dumps({"model": "gpt-3.5-turbo"}).encode("utf-8")
+    request_headers = clean_headers(headers, litellm_key_header_name)
+    forwarded_headers = get_forwardable_headers(request_headers)
+    assert forwarded_headers == {"X-Custom-Header": "Custom-Value"}
+
+
+@pytest.mark.parametrize(
+    "litellm_key_header_name",
+    ["x-litellm-key", None],
+)
+@mock_patch_acompletion()
+def test_chat_completion_forward_headers(
+    mock_acompletion, client_no_auth, litellm_key_header_name
+):
+    global headers
+    try:
+        if litellm_key_header_name is not None:
+            gs = getattr(litellm.proxy.proxy_server, "general_settings")
+            gs["litellm_key_header_name"] = litellm_key_header_name
+            setattr(litellm.proxy.proxy_server, "general_settings", gs)
+        # Your test data
+        test_data = {
+            "model": "gpt-3.5-turbo",
+            "messages": [
+                {"role": "user", "content": "hi"},
+            ],
+            "max_tokens": 10,
+        }
+
+        headers_to_forward = {
+            "X-Custom-Header": "Custom-Value",
+            "X-Another-Header": "Another-Value",
+        }
+
+        if litellm_key_header_name is not None:
+            headers_to_not_forward = {litellm_key_header_name: "Bearer 1234"}
+        else:
+            headers_to_not_forward = {"Authorization": "Bearer 1234"}
+
+        received_headers = {**headers_to_forward, **headers_to_not_forward}
+
+        print("testing proxy server with chat completions")
+        response = client_no_auth.post(
+            "/v1/chat/completions", json=test_data, headers=received_headers
+        )
+        mock_acompletion.assert_called_once_with(
+            model="gpt-3.5-turbo",
+            messages=[
+                {"role": "user", "content": "hi"},
+            ],
+            max_tokens=10,
+            litellm_call_id=mock.ANY,
+            litellm_logging_obj=mock.ANY,
+            request_timeout=mock.ANY,
+            specific_deployment=True,
+            metadata=mock.ANY,
+            proxy_server_request=mock.ANY,
+            headers={
+                "x-custom-header": "Custom-Value",
+                "x-another-header": "Another-Value",
+            },
+        )
+        print(f"response - {response.text}")
+        assert response.status_code == 200
+        result = response.json()
+        print(f"Received response: {result}")
+    except Exception as e:
+        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
+
+
@mock_patch_acompletion()
@pytest.mark.asyncio
async def test_team_disable_guardrails(mock_acompletion, client_no_auth):
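From a client's point of view, the behaviour these tests pin down could look like the hedged sketch below: any x- header on the request (except x-stainless-*) is passed through to the underlying completion call, lowercased. The proxy URL and key are placeholders, and extra_headers is the standard OpenAI SDK option for attaching extra request headers.

import openai

client = openai.OpenAI(base_url="http://localhost:4000", api_key="sk-1234")  # placeholder proxy
client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    extra_headers={"X-Custom-Header": "Custom-Value"},  # arrives as x-custom-header downstream
)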
@@ -1050,7 +1050,7 @@ def test_filter_invalid_params_pre_call_check():
        pytest.fail(f"Got unexpected exception on router! - {str(e)}")


-@pytest.mark.parametrize("allowed_model_region", ["eu", None])
+@pytest.mark.parametrize("allowed_model_region", ["eu", None, "us"])
def test_router_region_pre_call_check(allowed_model_region):
    """
    If region based routing set
@@ -1065,7 +1065,7 @@ def test_router_region_pre_call_check(allowed_model_region):
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
                "base_model": "azure/gpt-35-turbo",
-               "region_name": "eu",
+               "region_name": allowed_model_region,
            },
            "model_info": {"id": "1"},
        },
@@ -1091,7 +1091,9 @@ def test_router_region_pre_call_check(allowed_model_region):
    if allowed_model_region is None:
        assert len(_healthy_deployments) == 2
    else:
-       assert len(_healthy_deployments) == 1, "No models selected as healthy"
+       assert len(_healthy_deployments) == 1, "{} models selected as healthy".format(
+           len(_healthy_deployments)
+       )
        assert (
            _healthy_deployments[0]["model_info"]["id"] == "1"
        ), "Incorrect model id picked. Got id={}, expected id=1".format(
@@ -102,7 +102,6 @@ def test_spend_logs_payload():
            "method": "POST",
            "headers": {
                "content-type": "application/json",
-               "authorization": "Bearer sk-1234",
                "user-agent": "PostmanRuntime/7.32.3",
                "accept": "*/*",
                "postman-token": "92300061-eeaa-423b-a420-0b44896ecdc4",