Merge branch 'main' into litellm_fix_azure_api_version

Commit: 409306b266

23 changed files with 490 additions and 103 deletions
14  .github/workflows/ghcr_deploy.yml (vendored)

@@ -194,6 +194,8 @@ jobs:
           platforms: local,linux/amd64,linux/arm64,linux/arm64/v8

   build-and-push-helm-chart:
+    if: github.event.inputs.release_type != 'dev'
+    needs: [docker-hub-deploy, build-and-push-image, build-and-push-image-database]
     runs-on: ubuntu-latest
     steps:
       - name: Checkout repository
@@ -211,9 +213,17 @@ jobs:
       - name: lowercase github.repository_owner
        run: |
          echo "REPO_OWNER=`echo ${{github.repository_owner}} | tr '[:upper:]' '[:lower:]'`" >>${GITHUB_ENV}

      - name: Get LiteLLM Latest Tag
        id: current_app_tag
-        uses: WyriHaximus/github-action-get-previous-tag@v1.3.0
+        shell: bash
+        run: |
+          LATEST_TAG=$(git describe --tags --exclude "*dev*" --abbrev=0)
+          if [ -z "${LATEST_TAG}" ]; then
+            echo "latest_tag=latest" | tee -a $GITHUB_OUTPUT
+          else
+            echo "latest_tag=${LATEST_TAG}" | tee -a $GITHUB_OUTPUT
+          fi

      - name: Get last published chart version
        id: current_version
@@ -241,7 +251,7 @@ jobs:
          name: ${{ env.CHART_NAME }}
          repository: ${{ env.REPO_OWNER }}
          tag: ${{ github.event.inputs.chartVersion || steps.bump_version.outputs.next-version || '0.1.0' }}
-         app_version: ${{ steps.current_app_tag.outputs.tag || 'latest' }}
+         app_version: ${{ steps.current_app_tag.outputs.latest_tag }}
          path: deploy/charts/${{ env.CHART_NAME }}
          registry: ${{ env.REGISTRY }}
          registry_username: ${{ github.actor }}
132  docs/my-website/docs/pass_through/langfuse.md (new file)

@@ -0,0 +1,132 @@
+# Langfuse Endpoints (Pass-Through)
+
+Pass-through endpoints for Langfuse - call langfuse endpoints with LiteLLM Virtual Key.
+
+Just replace `https://us.cloud.langfuse.com` with `LITELLM_PROXY_BASE_URL/langfuse` 🚀
+
+#### **Example Usage**
+```python
+from langfuse import Langfuse
+
+langfuse = Langfuse(
+    host="http://localhost:4000/langfuse",  # your litellm proxy endpoint
+    public_key="anything",                  # no key required since this is a pass through
+    secret_key="LITELLM_VIRTUAL_KEY",       # no key required since this is a pass through
+)
+
+print("sending langfuse trace request")
+trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough")
+print("flushing langfuse request")
+langfuse.flush()
+
+print("flushed langfuse request")
+```
+
+Supports **ALL** Langfuse Endpoints.
+
+[**See All Langfuse Endpoints**](https://api.reference.langfuse.com/)
+
+## Quick Start
+
+Let's log a trace to Langfuse.
+
+1. Add Langfuse Public/Private keys to environment
+
+```bash
+export LANGFUSE_PUBLIC_KEY=""
+export LANGFUSE_PRIVATE_KEY=""
+```
+
+2. Start LiteLLM Proxy
+
+```bash
+litellm
+
+# RUNNING on http://0.0.0.0:4000
+```
+
+3. Test it!
+
+Let's log a trace to Langfuse!
+
+```python
+from langfuse import Langfuse
+
+langfuse = Langfuse(
+    host="http://localhost:4000/langfuse",  # your litellm proxy endpoint
+    public_key="anything",                  # no key required since this is a pass through
+    secret_key="anything",                  # no key required since this is a pass through
+)
+
+print("sending langfuse trace request")
+trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough")
+print("flushing langfuse request")
+langfuse.flush()
+
+print("flushed langfuse request")
+```
+
+
+## Advanced - Use with Virtual Keys
+
+Pre-requisites
+- [Setup proxy with DB](../proxy/virtual_keys.md#setup)
+
+Use this to avoid giving developers the raw Langfuse keys, while still letting them use Langfuse endpoints through LiteLLM.
+
+### Usage
+
+1. Setup environment
+
+```bash
+export DATABASE_URL=""
+export LITELLM_MASTER_KEY=""
+export LANGFUSE_PUBLIC_KEY=""
+export LANGFUSE_PRIVATE_KEY=""
+```
+
+```bash
+litellm
+
+# RUNNING on http://0.0.0.0:4000
+```
+
+2. Generate virtual key
+
+```bash
+curl -X POST 'http://0.0.0.0:4000/key/generate' \
+-H 'Authorization: Bearer sk-1234' \
+-H 'Content-Type: application/json' \
+-d '{}'
+```
+
+Expected Response
+
+```bash
+{
+    ...
+    "key": "sk-1234ewknldferwedojwojw"
+}
+```
+
+3. Test it!
+
+
+```python
+from langfuse import Langfuse
+
+langfuse = Langfuse(
+    host="http://localhost:4000/langfuse",  # your litellm proxy endpoint
+    public_key="anything",                  # no key required since this is a pass through
+    secret_key="sk-1234ewknldferwedojwojw", # no key required since this is a pass through
+)
+
+print("sending langfuse trace request")
+trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough")
+print("flushing langfuse request")
+langfuse.flush()
+
+print("flushed langfuse request")
+```
+
+## [Advanced - Log to separate langfuse projects (by key/team)](../proxy/team_logging.md)
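As a rough illustration of what the pass-through route means for raw HTTP clients (not just the Langfuse SDK), the sketch below reads traces back through the proxy. The `/api/public/traces` path and the Basic-auth scheme (public key as username, key as password) are assumptions based on Langfuse's public API; the proxy URL and virtual key are placeholders.

```python
# Hedged sketch: query Langfuse's public API through the LiteLLM pass-through route.
# Assumptions: proxy on localhost:4000, Langfuse exposes GET /api/public/traces,
# and Basic auth is accepted with the LiteLLM virtual key as the password.
import requests

PROXY_BASE = "http://localhost:4000/langfuse"    # LITELLM_PROXY_BASE_URL + /langfuse
PUBLIC_KEY = "anything"                          # not checked by the pass-through itself
VIRTUAL_KEY = "sk-1234ewknldferwedojwojw"        # placeholder key from /key/generate

resp = requests.get(
    f"{PROXY_BASE}/api/public/traces",
    auth=(PUBLIC_KEY, VIRTUAL_KEY),
    timeout=30,
)
print(resp.status_code)
print(resp.json())
```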

@@ -207,7 +207,7 @@ curl -X POST 'http://0.0.0.0:4000/key/generate' \
 -H 'Content-Type: application/json' \
 -d '{
     "metadata": {
-        "logging": {
+        "logging": [{
            "callback_name": "langfuse", # 'otel', 'langfuse', 'lunary'
            "callback_type": "success" # set, if required by integration - future improvement, have logging tools work for success + failure by default
            "callback_vars": {
@@ -215,7 +215,7 @@ curl -X POST 'http://0.0.0.0:4000/key/generate' \
                "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY", # [RECOMMENDED] reference key in proxy environment
                "langfuse_host": "https://cloud.langfuse.com"
            }
-        }
+        }]
    }
}'

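The change above turns `metadata.logging` into a list of callback configurations rather than a single object, so more than one callback could be attached to a key. A hedged Python equivalent of the documented curl call; the proxy URL and master key are placeholders, and the `langfuse_public_key` variable is assumed by analogy with the other `callback_vars` in the doc.

```python
# Hedged sketch: generate a key whose logging settings use the new list-of-callbacks shape.
# Assumes a LiteLLM proxy on 0.0.0.0:4000 with master key "sk-1234" (placeholders).
import requests

payload = {
    "metadata": {
        "logging": [  # now a list, so multiple callbacks could be configured per key
            {
                "callback_name": "langfuse",
                "callback_type": "success",
                "callback_vars": {
                    "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",  # assumed var name
                    "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
                    "langfuse_host": "https://cloud.langfuse.com",
                },
            }
        ]
    }
}

resp = requests.post(
    "http://0.0.0.0:4000/key/generate",
    headers={"Authorization": "Bearer sk-1234", "Content-Type": "application/json"},
    json=payload,
    timeout=30,
)
print(resp.json())
```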

@@ -61,7 +61,7 @@ guardrails:

 - `pre_call` Run **before** LLM call, on **input**
 - `post_call` Run **after** LLM call, on **input & output**
-- `during_call` Run **during** LLM call, on **input**
+- `during_call` Run **during** LLM call, on **input**. Same as `pre_call` but runs in parallel with the LLM call. Response not returned until guardrail check completes

 ## 3. Start LiteLLM Gateway

@@ -72,6 +72,8 @@ litellm --config config.yaml --detailed_debug

 ## 4. Test request

+**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys##request-format)**
+
 <Tabs>
 <TabItem label="Unsuccessful call" value = "not-allowed">

@@ -134,12 +136,10 @@ curl -i http://localhost:4000/v1/chat/completions \

 </Tabs>

-## Advanced
-### Control Guardrails per Project (API Key)
+## 5. Control Guardrails per Project (API Key)

-Use this to control what guardrails run per project. In this tutorial we only want the following guardrails to run for 1 project
-- `pre_call_guardrails`: ["aporia-pre-guard"]
-- `post_call_guardrails`: ["aporia-post-guard"]
+Use this to control what guardrails run per project. In this tutorial we only want the following guardrails to run for 1 project (API Key)
+- `guardrails`: ["aporia-pre-guard", "aporia-post-guard"]

 **Step 1** Create Key with guardrail settings

@@ -151,8 +151,7 @@ curl -X POST 'http://0.0.0.0:4000/key/generate' \
 -H 'Authorization: Bearer sk-1234' \
 -H 'Content-Type: application/json' \
 -D '{
-        "pre_call_guardrails": ["aporia-pre-guard"],
-        "post_call_guardrails": ["aporia"]
+        "guardrails": ["aporia-pre-guard", "aporia-post-guard"]
    }
}'
```

@@ -166,8 +165,7 @@ curl --location 'http://0.0.0.0:4000/key/update' \
 --header 'Content-Type: application/json' \
 --data '{
    "key": "sk-jNm1Zar7XfNdZXp49Z1kSQ",
-        "pre_call_guardrails": ["aporia"],
-        "post_call_guardrails": ["aporia"]
+        "guardrails": ["aporia-pre-guard", "aporia-post-guard"]
    }
}'
```

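Since the tutorial only shows curl, here is a hedged Python sketch of the same flow: generate a key scoped to both Aporia guardrails, then send a chat request with it through the proxy. The proxy URL, master key, and the `fake-openai-endpoint` model name are placeholders taken from the surrounding examples, not guaranteed values.

```python
# Hedged sketch: create a key with per-key guardrails, then call the proxy with it.
# Assumes a LiteLLM proxy on localhost:4000, master key "sk-1234", and a model
# named "fake-openai-endpoint" configured on the proxy (all placeholders).
import requests
from openai import OpenAI

key_resp = requests.post(
    "http://localhost:4000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},
    json={"guardrails": ["aporia-pre-guard", "aporia-post-guard"]},
    timeout=30,
)
virtual_key = key_resp.json()["key"]

# Requests sent with this key run both guardrails; other keys are unaffected.
client = OpenAI(api_key=virtual_key, base_url="http://localhost:4000")
completion = client.chat.completions.create(
    model="fake-openai-endpoint",
    messages=[{"role": "user", "content": "Hello, what's the weather?"}],
)
print(completion.choices[0].message.content)
```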

@@ -195,7 +195,8 @@ const sidebars = {
             "pass_through/vertex_ai",
             "pass_through/google_ai_studio",
             "pass_through/cohere",
-            "pass_through/bedrock"
+            "pass_through/bedrock",
+            "pass_through/langfuse"
           ],
         },
         "scheduler",

@@ -509,16 +509,16 @@ async def ollama_acompletion(
 async def ollama_aembeddings(
     api_base: str,
     model: str,
-    prompts: list,
+    prompts: List[str],
     model_response: litellm.EmbeddingResponse,
     optional_params: dict,
     logging_obj=None,
     encoding=None,
 ):
-    if api_base.endswith("/api/embeddings"):
+    if api_base.endswith("/api/embed"):
         url = api_base
     else:
-        url = f"{api_base}/api/embeddings"
+        url = f"{api_base}/api/embed"

     ## Load Config
     config = litellm.OllamaConfig.get_config()
@@ -528,64 +528,53 @@ async def ollama_aembeddings(
     ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
         optional_params[k] = v

-    input_data: Dict[str, Any] = {"model": model}
+    data: Dict[str, Any] = {"model": model, "input": prompts}
     special_optional_params = ["truncate", "options", "keep_alive"]

     for k, v in optional_params.items():
         if k in special_optional_params:
-            input_data[k] = v
+            data[k] = v
         else:
             # Ensure "options" is a dictionary before updating it
-            input_data.setdefault("options", {})
-            if isinstance(input_data["options"], dict):
-                input_data["options"].update({k: v})
+            data.setdefault("options", {})
+            if isinstance(data["options"], dict):
+                data["options"].update({k: v})
     total_input_tokens = 0
     output_data = []

     timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
     async with aiohttp.ClientSession(timeout=timeout) as session:
-        for idx, prompt in enumerate(prompts):
-            data = deepcopy(input_data)
-            data["prompt"] = prompt
-            ## LOGGING
-            logging_obj.pre_call(
-                input=None,
-                api_key=None,
-                additional_args={
-                    "api_base": url,
-                    "complete_input_dict": data,
-                    "headers": {},
-                },
-            )
-
-            response = await session.post(url, json=data)
-            if response.status != 200:
-                text = await response.text()
-                raise OllamaError(status_code=response.status, message=text)
-
-            ## LOGGING
-            logging_obj.post_call(
-                input=prompt,
-                api_key="",
-                original_response=response.text,
-                additional_args={
-                    "headers": None,
-                    "api_base": api_base,
-                },
-            )
-
-            response_json = await response.json()
-            embeddings: list[float] = response_json["embedding"]
-            output_data.append(
-                {"object": "embedding", "index": idx, "embedding": embeddings}
-            )
-
-            input_tokens = len(encoding.encode(prompt))
-            total_input_tokens += input_tokens
+        ## LOGGING
+        logging_obj.pre_call(
+            input=None,
+            api_key=None,
+            additional_args={
+                "api_base": url,
+                "complete_input_dict": data,
+                "headers": {},
+            },
+        )
+
+        response = await session.post(url, json=data)
+
+        if response.status != 200:
+            text = await response.text()
+            raise OllamaError(status_code=response.status, message=text)
+
+        response_json = await response.json()
+
+        embeddings: List[List[float]] = response_json["embeddings"]
+        for idx, emb in enumerate(embeddings):
+            output_data.append({"object": "embedding", "index": idx, "embedding": emb})
+
+        input_tokens = response_json.get("prompt_eval_count") or len(
+            encoding.encode("".join(prompt for prompt in prompts))
+        )
+        total_input_tokens += input_tokens

     model_response.object = "list"
     model_response.data = output_data
-    model_response.model = model
+    model_response.model = "ollama/" + model
     setattr(
         model_response,
         "usage",
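The rewrite above replaces the per-prompt loop against `/api/embeddings` with a single batched call to Ollama's `/api/embed`, which returns `embeddings` (one vector per input) plus `prompt_eval_count` for token accounting. A minimal sketch of that request/response shape, assuming a local Ollama server with an embedding model already pulled; the host and model name are placeholders.

```python
# Hedged sketch of the batched Ollama embedding call the new code path makes.
# Assumes Ollama is running on localhost:11434 with an embedding model pulled
# (e.g. "nomic-embed-text"); both the host and model name are placeholders.
import requests

payload = {
    "model": "nomic-embed-text",
    "input": ["hello world", "good morning"],  # whole batch in one request
}
resp = requests.post("http://localhost:11434/api/embed", json=payload, timeout=60)
resp.raise_for_status()
body = resp.json()

# One vector per input prompt, in order.
embeddings = body["embeddings"]
print(len(embeddings), "vectors of dim", len(embeddings[0]))

# Token accounting: prefer the server-reported count, as the updated litellm code does.
print("prompt_eval_count:", body.get("prompt_eval_count"))
```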
@@ -2195,7 +2195,7 @@ def _convert_to_bedrock_tool_call_invoke(

 def _convert_to_bedrock_tool_call_result(
     message: dict,
-) -> BedrockMessageBlock:
+) -> BedrockContentBlock:
     """
     OpenAI message with a tool result looks like:
     {
@@ -2247,7 +2247,7 @@ def _convert_to_bedrock_tool_call_result(
     )
     content_block = BedrockContentBlock(toolResult=tool_result)

-    return BedrockMessageBlock(role="user", content=[content_block])
+    return content_block


 def _bedrock_converse_messages_pt(
@@ -2289,6 +2289,12 @@ def _bedrock_converse_messages_pt(

             msg_i += 1

+        ## MERGE CONSECUTIVE TOOL CALL MESSAGES ##
+        while msg_i < len(messages) and messages[msg_i]["role"] == "tool":
+            tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
+
+            user_content.append(tool_call_result)
+            msg_i += 1
         if user_content:
             contents.append(BedrockMessageBlock(role="user", content=user_content))
         assistant_content: List[BedrockContentBlock] = []
@@ -2332,11 +2338,6 @@ def _bedrock_converse_messages_pt(
                 BedrockMessageBlock(role="assistant", content=assistant_content)
             )

-        ## APPEND TOOL CALL MESSAGES ##
-        if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
-            tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
-            contents.append(tool_call_result)
-            msg_i += 1
         if msg_i == init_msg_i:  # prevent infinite loops
             raise litellm.BadRequestError(
                 message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
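The change above makes `_convert_to_bedrock_tool_call_result` return a single content block, and the caller now folds every consecutive `role: "tool"` message into one user turn instead of appending each as its own message (Bedrock's Converse API rejects back-to-back messages with the same role). A standalone, hedged sketch of that merge idea, with plain dicts standing in for litellm's `BedrockContentBlock` / `BedrockMessageBlock` types.

```python
# Hedged, self-contained sketch of merging consecutive tool-result messages into
# a single user turn, mirroring the loop added in _bedrock_converse_messages_pt.
from typing import Dict, List, Tuple


def to_tool_result_block(message: Dict) -> Dict:
    # Stand-in for _convert_to_bedrock_tool_call_result: one content block per tool message.
    return {
        "toolResult": {
            "toolUseId": message["tool_call_id"],
            "content": [{"text": message["content"]}],
        }
    }


def merge_tool_results(messages: List[Dict], msg_i: int) -> Tuple[List[Dict], int]:
    user_content: List[Dict] = []
    # Consume every consecutive tool message and pack it into the same user turn.
    while msg_i < len(messages) and messages[msg_i]["role"] == "tool":
        user_content.append(to_tool_result_block(messages[msg_i]))
        msg_i += 1
    return user_content, msg_i


messages = [
    {"role": "tool", "tool_call_id": "call_1", "content": "25 degrees celsius."},
    {"role": "tool", "tool_call_id": "call_2", "content": "28 degrees celsius."},
]
content, next_i = merge_tool_results(messages, 0)
print({"role": "user", "content": content}, "consumed", next_i, "messages")
```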
@@ -365,6 +365,7 @@ class CodestralTextCompletion(BaseLLM):
         stream = optional_params.pop("stream", False)

         data = {
+            "model": model,
             "prompt": prompt,
             **optional_params,
         }
@@ -253,7 +253,7 @@ async def acompletion(
     logit_bias: Optional[dict] = None,
     user: Optional[str] = None,
     # openai v1.0+ new params
-    response_format: Optional[dict] = None,
+    response_format: Optional[Union[dict, Type[BaseModel]]] = None,
     seed: Optional[int] = None,
     tools: Optional[List] = None,
     tool_choice: Optional[str] = None,
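The widened type hint above lets callers pass a Pydantic model class as `response_format` instead of a raw dict. A hedged usage sketch: the model name is a placeholder, structured-output support depends on the chosen provider/model, and an API key is assumed to be set in the environment.

```python
# Hedged sketch: passing a Pydantic model class as response_format to litellm.
# Assumes OPENAI_API_KEY is set and the chosen model supports structured outputs;
# "gpt-4o-mini" is a placeholder model name.
import asyncio

from pydantic import BaseModel

import litellm


class CalendarEvent(BaseModel):
    name: str
    date: str
    participants: list[str]


async def main():
    resp = await litellm.acompletion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Alice and Bob meet for lunch on Friday."}],
        response_format=CalendarEvent,  # a class, not a dict, per the new type hint
    )
    print(resp.choices[0].message.content)


asyncio.run(main())
```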
@@ -1,6 +1,4 @@
 model_list:
-  - model_name: gpt-3.5-turbo
+  - model_name: ollama/mistral
     litellm_params:
-      model: azure/chatgpt-v-2
-      api_key: os.environ/AZURE_API_KEY
-      api_base: os.environ/AZURE_API_BASE
+      model: ollama/mistral
@@ -587,6 +587,7 @@ class GenerateKeyRequest(GenerateRequestBase):
     send_invite_email: Optional[bool] = None
     model_rpm_limit: Optional[dict] = None
     model_tpm_limit: Optional[dict] = None
+    guardrails: Optional[List[str]] = None


 class GenerateKeyResponse(GenerateKeyRequest):
@@ -1269,8 +1269,9 @@ def _get_user_role(

 def _get_request_ip_address(
     request: Request, use_x_forwarded_for: Optional[bool] = False
-) -> str:
+) -> Optional[str]:

+    client_ip = None
     if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
         client_ip = request.headers["x-forwarded-for"]
     elif request.client is not None:
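The return type now admits `None` when neither `x-forwarded-for` nor `request.client` yields an address. A hedged standalone sketch of that lookup order, with a plain dict of headers standing in for the Starlette `Request`.

```python
# Hedged sketch of the client-IP lookup with the new Optional return: prefer
# x-forwarded-for when trusted, then the socket peer, else None.
from typing import Optional


def get_request_ip(
    headers: dict, peer_host: Optional[str], use_x_forwarded_for: bool = False
) -> Optional[str]:
    if use_x_forwarded_for and "x-forwarded-for" in headers:
        return headers["x-forwarded-for"]
    if peer_host is not None:
        return peer_host
    return None  # nothing known about the caller


print(get_request_ip({"x-forwarded-for": "203.0.113.7"}, "10.0.0.2", use_x_forwarded_for=True))
print(get_request_ip({}, None))  # -> None
```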
@@ -331,13 +331,33 @@ async def add_litellm_data_to_request(

     # Guardrails
     move_guardrails_to_metadata(
-        data=data, _metadata_variable_name=_metadata_variable_name
+        data=data,
+        _metadata_variable_name=_metadata_variable_name,
+        user_api_key_dict=user_api_key_dict,
     )

     return data


-def move_guardrails_to_metadata(data: dict, _metadata_variable_name: str):
+def move_guardrails_to_metadata(
+    data: dict,
+    _metadata_variable_name: str,
+    user_api_key_dict: UserAPIKeyAuth,
+):
+    """
+    Helper to add guardrails from request to metadata
+
+    - If guardrails set on API Key metadata then sets guardrails on request metadata
+    - If guardrails not set on API key, then checks request metadata
+
+    """
+    if user_api_key_dict.metadata:
+        if "guardrails" in user_api_key_dict.metadata:
+            data[_metadata_variable_name]["guardrails"] = user_api_key_dict.metadata[
+                "guardrails"
+            ]
+            return
+
     if "guardrails" in data:
         data[_metadata_variable_name]["guardrails"] = data["guardrails"]
         del data["guardrails"]
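The new `user_api_key_dict` parameter gives key-level guardrails precedence over anything sent in the request body. A tiny, hedged standalone sketch of that precedence rule, with plain dicts standing in for the proxy's `UserAPIKeyAuth` and request objects.

```python
# Hedged sketch of the precedence implemented above: guardrails stored on the
# API key's metadata win; otherwise guardrails from the request body are moved
# into the request metadata.


def resolve_guardrails(request_body: dict, key_metadata: dict, metadata_field: str = "metadata") -> dict:
    request_body.setdefault(metadata_field, {})

    # 1. Key-level guardrails take priority.
    if key_metadata.get("guardrails"):
        request_body[metadata_field]["guardrails"] = key_metadata["guardrails"]
        return request_body

    # 2. Fall back to guardrails passed in the request body, moving them into metadata.
    if "guardrails" in request_body:
        request_body[metadata_field]["guardrails"] = request_body.pop("guardrails")
    return request_body


print(resolve_guardrails({"guardrails": ["req-guard"]}, {"guardrails": ["aporia-pre-guard"]}))
# -> {'metadata': {'guardrails': ['aporia-pre-guard']}}
print(resolve_guardrails({"guardrails": ["req-guard"]}, {}))
# -> {'metadata': {'guardrails': ['req-guard']}}
```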
@@ -66,6 +66,7 @@ async def generate_key_fn(
     - budget_duration: Optional[str] - Budget is reset at the end of specified duration. If not set, budget is never reset. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
     - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x.
     - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }
+    - guardrails: Optional[List[str]] - List of active guardrails for the key
     - permissions: Optional[dict] - key-specific permissions. Currently just used for turning off pii masking (if connected). Example - {"pii": false}
     - model_max_budget: Optional[dict] - key-specific model budget in USD. Example - {"text-davinci-002": 0.5, "gpt-3.5-turbo": 0.5}. IF null or {} then no model specific budget.
     - model_rpm_limit: Optional[dict] - key-specific model rpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific rpm limit.
@@ -321,11 +322,12 @@ async def update_key_fn(
             detail={"error": f"Team not found, passed team_id={data.team_id}"},
         )

+    _metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"]
     # get non default values for key
     non_default_values = {}
     for k, v in data_json.items():
         # this field gets stored in metadata
-        if key == "model_rpm_limit" or key == "model_tpm_limit":
+        if key in _metadata_fields:
             continue
         if v is not None and v not in (
             [],
@@ -366,6 +368,14 @@ async def update_key_fn(
         non_default_values["metadata"] = _metadata
         non_default_values.pop("model_rpm_limit", None)

+    if data.guardrails:
+        _metadata = existing_key_row.metadata or {}
+        _metadata["guardrails"] = data.guardrails
+
+        # update values that will be written to the DB
+        non_default_values["metadata"] = _metadata
+        non_default_values.pop("guardrails", None)
+
     response = await prisma_client.update_data(
         token=key, data={**non_default_values, "token": key}
     )
@@ -734,6 +744,7 @@ async def generate_key_helper_fn(
     model_max_budget: Optional[dict] = {},
     model_rpm_limit: Optional[dict] = {},
     model_tpm_limit: Optional[dict] = {},
+    guardrails: Optional[list] = None,
     teams: Optional[list] = None,
     organization_id: Optional[str] = None,
     table_name: Optional[Literal["key", "user"]] = None,
@@ -783,6 +794,9 @@ async def generate_key_helper_fn(
     if model_tpm_limit is not None:
         metadata = metadata or {}
         metadata["model_tpm_limit"] = model_tpm_limit
+    if guardrails is not None:
+        metadata = metadata or {}
+        metadata["guardrails"] = guardrails

     metadata_json = json.dumps(metadata)
     model_max_budget_json = json.dumps(model_max_budget)
@@ -360,24 +360,22 @@ async def pass_through_request(

         # combine url with query params for logging

-        # requested_query_params = query_params or request.query_params.__dict__
-        # requested_query_params_str = "&".join(
-        #     f"{k}={v}" for k, v in requested_query_params.items()
-        # )
-
-        requested_query_params = None
-
-        # if "?" in str(url):
-        #     logging_url = str(url) + "&" + requested_query_params_str
-        # else:
-        #     logging_url = str(url) + "?" + requested_query_params_str
+        requested_query_params = query_params or request.query_params.__dict__
+        requested_query_params_str = "&".join(
+            f"{k}={v}" for k, v in requested_query_params.items()
+        )
+
+        if "?" in str(url):
+            logging_url = str(url) + "&" + requested_query_params_str
+        else:
+            logging_url = str(url) + "?" + requested_query_params_str

         logging_obj.pre_call(
             input=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
             api_key="",
             additional_args={
                 "complete_input_dict": _parsed_body,
-                "api_base": str(url),
+                "api_base": str(logging_url),
                 "headers": headers,
             },
         )
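The un-commented block above rebuilds the full target URL (path plus query string) purely for logging. A small, hedged sketch of that reconstruction on its own, with placeholder URLs and query params.

```python
# Hedged sketch of the logging-URL reconstruction enabled above: append query
# params with "&" if the target URL already has a query string, otherwise "?".


def build_logging_url(url: str, query_params: dict) -> str:
    if not query_params:
        return url
    query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
    separator = "&" if "?" in url else "?"
    return url + separator + query_string


print(build_logging_url("https://example.com/langfuse/api/public/traces", {"page": 1}))
print(build_logging_url("https://example.com/search?q=litellm", {"page": 2}))
```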
@@ -2350,7 +2350,8 @@ async def initialize(
     config=None,
 ):
     global user_model, user_api_base, user_debug, user_detailed_debug, user_user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth, prisma_client
-    generate_feedback_box()
+    if os.getenv("LITELLM_DONT_SHOW_FEEDBACK_BOX", "").lower() != "true":
+        generate_feedback_box()
     user_model = model
     user_debug = debug
     if debug is True:  # this needs to be first, so users can see Router init debugg
@@ -8065,14 +8066,14 @@ async def login(request: Request):
             return redirect_response
         else:
             raise ProxyException(
-                message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nNot valid credentials for {username}",
+                message=f"Invalid credentials used to access UI.\nNot valid credentials for {username}",
                 type=ProxyErrorTypes.auth_error,
                 param="invalid_credentials",
                 code=status.HTTP_401_UNAUTHORIZED,
             )
     else:
         raise ProxyException(
-            message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
+            message="Invalid credentials used to access UI.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
             type=ProxyErrorTypes.auth_error,
             param="invalid_credentials",
             code=status.HTTP_401_UNAUTHORIZED,
@@ -23,7 +23,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.prompt_templates.factory import anthropic_messages_pt

-# litellm.num_retries =3
+# litellm.num_retries=3
 litellm.cache = None
 litellm.success_callback = []
 user_message = "Write a short poem about the sky"
@@ -1,18 +1,20 @@
-import sys, os
+import os
+import sys
 import traceback
+
 from dotenv import load_dotenv

 load_dotenv()
-import os, io
+import io
+import os

 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import pytest
+
 import litellm
-from litellm import embedding, completion, completion_cost, Timeout
-from litellm import RateLimitError
-import pytest
+from litellm import RateLimitError, Timeout, completion, completion_cost, embedding

 litellm.num_retries = 0
 litellm.cache = None
@@ -41,7 +43,14 @@ def get_current_weather(location, unit="fahrenheit"):

     # In production, this could be your backend API or an external API
 @pytest.mark.parametrize(
-    "model", ["gpt-3.5-turbo-1106", "mistral/mistral-large-latest"]
+    "model",
+    [
+        "gpt-3.5-turbo-1106",
+        "mistral/mistral-large-latest",
+        "claude-3-haiku-20240307",
+        "gemini/gemini-1.5-pro",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+    ],
 )
 def test_parallel_function_call(model):
     try:
@@ -124,7 +133,12 @@ def test_parallel_function_call(model):
         )  # extend conversation with function response
         print(f"messages: {messages}")
         second_response = litellm.completion(
-            model=model, messages=messages, temperature=0.2, seed=22
+            model=model,
+            messages=messages,
+            temperature=0.2,
+            seed=22,
+            tools=tools,
+            drop_params=True,
         )  # get a new response from the model where it can see the function response
         print("second response\n", second_response)
     except Exception as e:
@@ -2770,6 +2770,60 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
     }


+@pytest.mark.asyncio()
+async def test_generate_key_with_guardrails(prisma_client):
+    print("prisma client=", prisma_client)
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+    request = GenerateKeyRequest(
+        guardrails=["aporia-pre-call"],
+        metadata={
+            "team": "litellm-team3",
+        },
+    )
+    key = await generate_key_fn(
+        data=request,
+        user_api_key_dict=UserAPIKeyAuth(
+            user_role=LitellmUserRoles.PROXY_ADMIN,
+            api_key="sk-1234",
+            user_id="1234",
+        ),
+    )
+    print("generated key=", key)
+
+    generated_key = key.key
+
+    # use generated key to auth in
+    result = await info_key_fn(key=generated_key)
+    print("result from info_key_fn", result)
+    assert result["key"] == generated_key
+    print("\n info for key=", result["info"])
+    assert result["info"]["metadata"] == {
+        "team": "litellm-team3",
+        "guardrails": ["aporia-pre-call"],
+    }
+
+    # Update guardrails for the key
+    request = UpdateKeyRequest(
+        key=generated_key,
+        guardrails=["aporia-pre-call", "aporia-post-call"],
+    )
+    _request = Request(scope={"type": "http"})
+    _request._url = URL(url="/update/key")
+
+    await update_key_fn(data=request, request=_request)
+    result = await info_key_fn(key=generated_key)
+    print("result from info_key_fn", result)
+    assert result["key"] == generated_key
+    print("\n info for key=", result["info"])
+    assert result["info"]["metadata"] == {
+        "team": "litellm-team3",
+        "guardrails": ["aporia-pre-call", "aporia-post-call"],
+    }
+
+
 @pytest.mark.asyncio()
 async def test_team_access_groups(prisma_client):
     """
@@ -132,6 +132,7 @@ def test_ollama_aembeddings(mock_aembeddings):
 # test_ollama_aembeddings()


+@pytest.mark.skip(reason="local only test")
 def test_ollama_chat_function_calling():
     import json

@@ -313,3 +313,78 @@ def test_anthropic_cache_controls_pt():
         assert msg["content"][0]["cache_control"] == {"type": "ephemeral"}

     print("translated_messages: ", translated_messages)
+
+
+@pytest.mark.parametrize("provider", ["bedrock", "anthropic"])
+def test_bedrock_parallel_tool_calling_pt(provider):
+    """
+    Make sure parallel tool call blocks are merged correctly - https://github.com/BerriAI/litellm/issues/5277
+    """
+    from litellm.llms.prompt_templates.factory import _bedrock_converse_messages_pt
+    from litellm.types.utils import ChatCompletionMessageToolCall, Function, Message
+
+    messages = [
+        {
+            "role": "user",
+            "content": "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses",
+        },
+        Message(
+            content="Here are the current weather conditions for San Francisco, Tokyo, and Paris:",
+            role="assistant",
+            tool_calls=[
+                ChatCompletionMessageToolCall(
+                    index=1,
+                    function=Function(
+                        arguments='{"city": "New York"}',
+                        name="get_current_weather",
+                    ),
+                    id="tooluse_XcqEBfm8R-2YVaPhDUHsPQ",
+                    type="function",
+                ),
+                ChatCompletionMessageToolCall(
+                    index=2,
+                    function=Function(
+                        arguments='{"city": "London"}',
+                        name="get_current_weather",
+                    ),
+                    id="tooluse_VB9nk7UGRniVzGcaj6xrAQ",
+                    type="function",
+                ),
+            ],
+            function_call=None,
+        ),
+        {
+            "tool_call_id": "tooluse_XcqEBfm8R-2YVaPhDUHsPQ",
+            "role": "tool",
+            "name": "get_current_weather",
+            "content": "25 degrees celsius.",
+        },
+        {
+            "tool_call_id": "tooluse_VB9nk7UGRniVzGcaj6xrAQ",
+            "role": "tool",
+            "name": "get_current_weather",
+            "content": "28 degrees celsius.",
+        },
+    ]
+
+    if provider == "bedrock":
+        translated_messages = _bedrock_converse_messages_pt(
+            messages=messages,
+            model="anthropic.claude-3-sonnet-20240229-v1:0",
+            llm_provider="bedrock",
+        )
+    else:
+        translated_messages = anthropic_messages_pt(
+            messages=messages,
+            model="claude-3-sonnet-20240229-v1:0",
+            llm_provider=provider,
+        )
+    print(translated_messages)
+
+    number_of_messages = len(translated_messages)
+
+    # assert last 2 messages are not the same role
+    assert (
+        translated_messages[number_of_messages - 1]["role"]
+        != translated_messages[number_of_messages - 2]["role"]
+    )
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.43.18"
+version = "1.43.19"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"

 [tool.commitizen]
-version = "1.43.18"
+version = "1.43.19"
 version_files = [
     "pyproject.toml:^version"
 ]
@@ -22,10 +22,6 @@ async def chat_completion(
     data = {
         "model": model,
         "messages": messages,
-        "guardrails": [
-            "aporia-post-guard",
-            "aporia-pre-guard",
-        ],  # default guardrails for all tests
     }

     if guardrails is not None:
@@ -41,7 +37,7 @@ async def chat_completion(
     print()

     if status != 200:
-        return response_text
+        raise Exception(response_text)

     # response headers
     response_headers = response.headers
@@ -50,6 +46,29 @@ async def chat_completion(
     return await response.json(), response_headers


+async def generate_key(session, guardrails):
+    url = "http://0.0.0.0:4000/key/generate"
+    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
+    if guardrails:
+        data = {
+            "guardrails": guardrails,
+        }
+    else:
+        data = {}
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request did not return a 200 status code: {status}")
+
+        return await response.json()
+
+
 @pytest.mark.asyncio
 async def test_llm_guard_triggered_safe_request():
     """
@@ -62,6 +81,10 @@ async def test_llm_guard_triggered_safe_request():
         "sk-1234",
         model="fake-openai-endpoint",
         messages=[{"role": "user", "content": f"Hello what's the weather"}],
+        guardrails=[
+            "aporia-post-guard",
+            "aporia-pre-guard",
+        ],
     )
     await asyncio.sleep(3)

@@ -90,6 +113,10 @@ async def test_llm_guard_triggered():
             messages=[
                 {"role": "user", "content": f"Hello my name is ishaan@berri.ai"}
             ],
+            guardrails=[
+                "aporia-post-guard",
+                "aporia-pre-guard",
+            ],
         )
         pytest.fail("Should have thrown an exception")
     except Exception as e:
@@ -116,3 +143,54 @@ async def test_no_llm_guard_triggered():
     print("response=", response, "response headers", headers)

     assert "x-litellm-applied-guardrails" not in headers
+
+
+@pytest.mark.asyncio
+async def test_guardrails_with_api_key_controls():
+    """
+    - Make two API Keys
+    - Key 1 with no guardrails
+    - Key 2 with guardrails
+    - Request to Key 1 -> should be success with no guardrails
+    - Request to Key 2 -> should be error since guardrails are triggered
+    """
+    async with aiohttp.ClientSession() as session:
+        key_with_guardrails = await generate_key(
+            session=session,
+            guardrails=[
+                "aporia-post-guard",
+                "aporia-pre-guard",
+            ],
+        )
+
+        key_with_guardrails = key_with_guardrails["key"]
+
+        key_without_guardrails = await generate_key(session=session, guardrails=None)
+
+        key_without_guardrails = key_without_guardrails["key"]
+
+        # test no guardrails triggered for key without guardrails
+        response, headers = await chat_completion(
+            session,
+            key_without_guardrails,
+            model="fake-openai-endpoint",
+            messages=[{"role": "user", "content": f"Hello what's the weather"}],
+        )
+        await asyncio.sleep(3)
+
+        print("response=", response, "response headers", headers)
+        assert "x-litellm-applied-guardrails" not in headers
+
+        # test guardrails triggered for key with guardrails
+        try:
+            response, headers = await chat_completion(
+                session,
+                key_with_guardrails,
+                model="fake-openai-endpoint",
+                messages=[
+                    {"role": "user", "content": f"Hello my name is ishaan@berri.ai"}
+                ],
+            )
+            pytest.fail("Should have thrown an exception")
+        except Exception as e:
+            print(e)
+            assert "Aporia detected and blocked PII" in str(e)