Merge branch 'main' into litellm_fix_azure_api_version

Commit 409306b266 by Krish Dholakia, 2024-08-20 11:40:53 -07:00 (committed by GitHub).
23 changed files with 490 additions and 103 deletions

View file

@@ -194,6 +194,8 @@ jobs:
          platforms: local,linux/amd64,linux/arm64,linux/arm64/v8
  build-and-push-helm-chart:
+   if: github.event.inputs.release_type != 'dev'
+   needs: [docker-hub-deploy, build-and-push-image, build-and-push-image-database]
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
@@ -211,9 +213,17 @@ jobs:
      - name: lowercase github.repository_owner
        run: |
          echo "REPO_OWNER=`echo ${{github.repository_owner}} | tr '[:upper:]' '[:lower:]'`" >>${GITHUB_ENV}
      - name: Get LiteLLM Latest Tag
        id: current_app_tag
-       uses: WyriHaximus/github-action-get-previous-tag@v1.3.0
+       shell: bash
+       run: |
+         LATEST_TAG=$(git describe --tags --exclude "*dev*" --abbrev=0)
+         if [ -z "${LATEST_TAG}" ]; then
+           echo "latest_tag=latest" | tee -a $GITHUB_OUTPUT
+         else
+           echo "latest_tag=${LATEST_TAG}" | tee -a $GITHUB_OUTPUT
+         fi
      - name: Get last published chart version
        id: current_version
@@ -241,7 +251,7 @@ jobs:
          name: ${{ env.CHART_NAME }}
          repository: ${{ env.REPO_OWNER }}
          tag: ${{ github.event.inputs.chartVersion || steps.bump_version.outputs.next-version || '0.1.0' }}
-         app_version: ${{ steps.current_app_tag.outputs.tag || 'latest' }}
+         app_version: ${{ steps.current_app_tag.outputs.latest_tag }}
          path: deploy/charts/${{ env.CHART_NAME }}
          registry: ${{ env.REGISTRY }}
          registry_username: ${{ github.actor }}
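
The new `Get LiteLLM Latest Tag` step resolves the most recent non-dev git tag and falls back to `latest` when none is found. A minimal Python sketch of the same fallback logic, assuming a checked-out repository with tags fetched (the function name is illustrative, not part of the workflow):

```python
import subprocess

def get_latest_non_dev_tag(default: str = "latest") -> str:
    """Return the newest tag that doesn't match *dev*, or a default if none exists."""
    try:
        tag = subprocess.run(
            ["git", "describe", "--tags", "--exclude", "*dev*", "--abbrev=0"],
            capture_output=True, text=True, check=True,
        ).stdout.strip()
    except subprocess.CalledProcessError:  # git errors out when no matching tag exists
        tag = ""
    return tag or default

# Mirrors what the workflow step writes to $GITHUB_OUTPUT
print(f"latest_tag={get_latest_non_dev_tag()}")
```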

View file

@@ -0,0 +1,132 @@
# Langfuse Endpoints (Pass-Through)
Pass-through endpoints for Langfuse - call langfuse endpoints with LiteLLM Virtual Key.
Just replace `https://us.cloud.langfuse.com` with `LITELLM_PROXY_BASE_URL/langfuse` 🚀
#### **Example Usage**
```python
from langfuse import Langfuse
langfuse = Langfuse(
host="http://localhost:4000/langfuse", # your litellm proxy endpoint
public_key="anything", # no key required since this is a pass through
secret_key="LITELLM_VIRTUAL_KEY", # pass your litellm virtual key here
)
print("sending langfuse trace request")
trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough")
print("flushing langfuse request")
langfuse.flush()
print("flushed langfuse request")
```
Supports **ALL** Langfuse Endpoints.
[**See All Langfuse Endpoints**](https://api.reference.langfuse.com/)
## Quick Start
Let's log a trace to Langfuse.
1. Add Langfuse Public/Private keys to environment
```bash
export LANGFUSE_PUBLIC_KEY=""
export LANGFUSE_PRIVATE_KEY=""
```
2. Start LiteLLM Proxy
```bash
litellm
# RUNNING on http://0.0.0.0:4000
```
3. Test it!
Let's log a trace to Langfuse!
```python
from langfuse import Langfuse
langfuse = Langfuse(
host="http://localhost:4000/langfuse", # your litellm proxy endpoint
public_key="anything", # no key required since this is a pass through
secret_key="anything", # no key required since this is a pass through
)
print("sending langfuse trace request")
trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough")
print("flushing langfuse request")
langfuse.flush()
print("flushed langfuse request")
```
## Advanced - Use with Virtual Keys
Pre-requisites
- [Setup proxy with DB](../proxy/virtual_keys.md#setup)
Use this to avoid giving developers the raw Langfuse keys, while still letting them use the Langfuse endpoints.
### Usage
1. Setup environment
```bash
export DATABASE_URL=""
export LITELLM_MASTER_KEY=""
export LANGFUSE_PUBLIC_KEY=""
export LANGFUSE_PRIVATE_KEY=""
```
```bash
litellm
# RUNNING on http://0.0.0.0:4000
```
2. Generate virtual key
```bash
curl -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{}'
```
Expected Response
```bash
{
...
"key": "sk-1234ewknldferwedojwojw"
}
```
3. Test it!
```python
from langfuse import Langfuse
langfuse = Langfuse(
host="http://localhost:4000/langfuse", # your litellm proxy endpoint
public_key="anything", # no key required since this is a pass through
secret_key="sk-1234ewknldferwedojwojw", # the litellm virtual key generated in step 2
)
print("sending langfuse trace request")
trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough")
print("flushing langfuse request")
langfuse.flush()
print("flushed langfuse request")
```
## [Advanced - Log to separate langfuse projects (by key/team)](../proxy/team_logging.md)

View file

@@ -207,7 +207,7 @@ curl -X POST 'http://0.0.0.0:4000/key/generate' \
 -H 'Content-Type: application/json' \
 -d '{
     "metadata": {
-        "logging": {
+        "logging": [{
             "callback_name": "langfuse", # 'otel', 'langfuse', 'lunary'
             "callback_type": "success" # set, if required by integration - future improvement, have logging tools work for success + failure by default
             "callback_vars": {
@@ -215,7 +215,7 @@ curl -X POST 'http://0.0.0.0:4000/key/generate' \
                 "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY", # [RECOMMENDED] reference key in proxy environment
                 "langfuse_host": "https://cloud.langfuse.com"
             }
-        }
+        }]
     }
 }'
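
This change turns `metadata.logging` into a list, so a key can carry more than one logging callback configuration. A hedged sketch of the same `/key/generate` call from Python, assuming a local proxy at `http://0.0.0.0:4000` with master key `sk-1234` as in the curl above (the `langfuse_public_key` entry is inferred from the surrounding config, not shown in this hunk):

```python
import requests

payload = {
    "metadata": {
        "logging": [  # now a list: one entry per logging callback
            {
                "callback_name": "langfuse",
                "callback_type": "success",
                "callback_vars": {
                    "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
                    "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
                    "langfuse_host": "https://cloud.langfuse.com",
                },
            }
        ]
    }
}

resp = requests.post(
    "http://0.0.0.0:4000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},
    json=payload,
)
print(resp.json()["key"])
```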

View file

@@ -61,7 +61,7 @@ guardrails:
 - `pre_call` Run **before** LLM call, on **input**
 - `post_call` Run **after** LLM call, on **input & output**
-- `during_call` Run **during** LLM call, on **input**
+- `during_call` Run **during** LLM call, on **input**. Same as `pre_call`, but runs in parallel with the LLM call. The response is not returned until the guardrail check completes.

 ## 3. Start LiteLLM Gateway
@@ -72,6 +72,8 @@ litellm --config config.yaml --detailed_debug
 ## 4. Test request

+**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys##request-format)**
+
 <Tabs>
 <TabItem label="Unsuccessful call" value = "not-allowed">
@@ -134,12 +136,10 @@ curl -i http://localhost:4000/v1/chat/completions \
 </Tabs>

-## Advanced
-### Control Guardrails per Project (API Key)
-
-Use this to control what guardrails run per project. In this tutorial we only want the following guardrails to run for 1 project
-- `pre_call_guardrails`: ["aporia-pre-guard"]
-- `post_call_guardrails`: ["aporia-post-guard"]
+## 5. Control Guardrails per Project (API Key)
+
+Use this to control what guardrails run per project. In this tutorial we only want the following guardrails to run for 1 project (API Key)
+- `guardrails`: ["aporia-pre-guard", "aporia-post-guard"]

 **Step 1** Create Key with guardrail settings
@@ -151,8 +151,7 @@ curl -X POST 'http://0.0.0.0:4000/key/generate' \
 -H 'Authorization: Bearer sk-1234' \
 -H 'Content-Type: application/json' \
 -D '{
-        "pre_call_guardrails": ["aporia-pre-guard"],
-        "post_call_guardrails": ["aporia"]
+        "guardrails": ["aporia-pre-guard", "aporia-post-guard"]
     }
 }'
 ```
@@ -166,8 +165,7 @@ curl --location 'http://0.0.0.0:4000/key/update' \
 --header 'Content-Type: application/json' \
 --data '{
         "key": "sk-jNm1Zar7XfNdZXp49Z1kSQ",
-        "pre_call_guardrails": ["aporia"],
-        "post_call_guardrails": ["aporia"]
+        "guardrails": ["aporia-pre-guard", "aporia-post-guard"]
     }
 }'
 ```
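
With this change, key-level guardrail control collapses into a single `guardrails` list instead of separate `pre_call_guardrails`/`post_call_guardrails` fields. A minimal Python sketch of Step 1 under the same assumptions as the curl above (proxy on `0.0.0.0:4000`, master key `sk-1234`):

```python
import requests

# Create a key that only runs the two Aporia guardrails
resp = requests.post(
    "http://0.0.0.0:4000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},
    json={"guardrails": ["aporia-pre-guard", "aporia-post-guard"]},
)
print(resp.json()["key"])  # requests made with this key are checked by both guardrails
```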

View file

@@ -195,7 +195,8 @@ const sidebars = {
            "pass_through/vertex_ai",
            "pass_through/google_ai_studio",
            "pass_through/cohere",
-           "pass_through/bedrock"
+           "pass_through/bedrock",
+           "pass_through/langfuse"
        ],
    },
    "scheduler",

View file

@@ -509,16 +509,16 @@ async def ollama_acompletion(
 async def ollama_aembeddings(
     api_base: str,
     model: str,
-    prompts: list,
+    prompts: List[str],
     model_response: litellm.EmbeddingResponse,
     optional_params: dict,
     logging_obj=None,
     encoding=None,
 ):
-    if api_base.endswith("/api/embeddings"):
+    if api_base.endswith("/api/embed"):
         url = api_base
     else:
-        url = f"{api_base}/api/embeddings"
+        url = f"{api_base}/api/embed"

     ## Load Config
     config = litellm.OllamaConfig.get_config()
@@ -528,25 +528,22 @@ async def ollama_aembeddings(
     ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
         optional_params[k] = v

-    input_data: Dict[str, Any] = {"model": model}
+    data: Dict[str, Any] = {"model": model, "input": prompts}
     special_optional_params = ["truncate", "options", "keep_alive"]
     for k, v in optional_params.items():
         if k in special_optional_params:
-            input_data[k] = v
+            data[k] = v
         else:
             # Ensure "options" is a dictionary before updating it
-            input_data.setdefault("options", {})
-            if isinstance(input_data["options"], dict):
-                input_data["options"].update({k: v})
+            data.setdefault("options", {})
+            if isinstance(data["options"], dict):
+                data["options"].update({k: v})
     total_input_tokens = 0
     output_data = []
     timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
     async with aiohttp.ClientSession(timeout=timeout) as session:
-        for idx, prompt in enumerate(prompts):
-            data = deepcopy(input_data)
-            data["prompt"] = prompt
         ## LOGGING
         logging_obj.pre_call(
             input=None,
@@ -559,33 +556,25 @@ async def ollama_aembeddings(
         )

         response = await session.post(url, json=data)
         if response.status != 200:
             text = await response.text()
             raise OllamaError(status_code=response.status, message=text)

-        ## LOGGING
-        logging_obj.post_call(
-            input=prompt,
-            api_key="",
-            original_response=response.text,
-            additional_args={
-                "headers": None,
-                "api_base": api_base,
-            },
-        )
         response_json = await response.json()
-        embeddings: list[float] = response_json["embedding"]
-        output_data.append(
-            {"object": "embedding", "index": idx, "embedding": embeddings}
-        )
-        input_tokens = len(encoding.encode(prompt))
+        embeddings: List[List[float]] = response_json["embeddings"]
+        for idx, emb in enumerate(embeddings):
+            output_data.append({"object": "embedding", "index": idx, "embedding": emb})
+
+        input_tokens = response_json.get("prompt_eval_count") or len(
+            encoding.encode("".join(prompt for prompt in prompts))
+        )
         total_input_tokens += input_tokens

     model_response.object = "list"
     model_response.data = output_data
-    model_response.model = model
+    model_response.model = "ollama/" + model
     setattr(
         model_response,
         "usage",

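The embedding handler now makes a single batched request to Ollama's newer `/api/embed` endpoint instead of one `/api/embeddings` call per prompt. A rough standalone sketch of that batched call, assuming a local Ollama server on port 11434; the model name `nomic-embed-text` is only an example:

```python
import asyncio
import aiohttp

async def embed(prompts: list[str], model: str = "nomic-embed-text"):
    async with aiohttp.ClientSession() as session:
        resp = await session.post(
            "http://localhost:11434/api/embed",
            json={"model": model, "input": prompts},  # one request for all prompts
        )
        body = await resp.json()
        # "embeddings" holds one vector per input; "prompt_eval_count", when present,
        # is the token count the handler above uses for usage tracking.
        return body["embeddings"], body.get("prompt_eval_count")

vectors, tokens = asyncio.run(embed(["hello", "world"]))
print(len(vectors), tokens)
```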
View file

@@ -2195,7 +2195,7 @@ def _convert_to_bedrock_tool_call_invoke(
 def _convert_to_bedrock_tool_call_result(
     message: dict,
-) -> BedrockMessageBlock:
+) -> BedrockContentBlock:
     """
     OpenAI message with a tool result looks like:
     {
@@ -2247,7 +2247,7 @@ def _convert_to_bedrock_tool_call_result(
     )
     content_block = BedrockContentBlock(toolResult=tool_result)

-    return BedrockMessageBlock(role="user", content=[content_block])
+    return content_block

 def _bedrock_converse_messages_pt(
@@ -2289,6 +2289,12 @@ def _bedrock_converse_messages_pt(
             msg_i += 1

+        ## MERGE CONSECUTIVE TOOL CALL MESSAGES ##
+        while msg_i < len(messages) and messages[msg_i]["role"] == "tool":
+            tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
+            user_content.append(tool_call_result)
+            msg_i += 1
+
         if user_content:
             contents.append(BedrockMessageBlock(role="user", content=user_content))
         assistant_content: List[BedrockContentBlock] = []
@@ -2332,11 +2338,6 @@ def _bedrock_converse_messages_pt(
                 BedrockMessageBlock(role="assistant", content=assistant_content)
             )

-        ## APPEND TOOL CALL MESSAGES ##
-        if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
-            tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
-            contents.append(tool_call_result)
-            msg_i += 1
         if msg_i == init_msg_i:  # prevent infinite loops
             raise litellm.BadRequestError(
                 message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
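
The converter now returns a bare content block, and consecutive `tool` messages are folded into the surrounding user turn so Bedrock's Converse API never sees two user-role messages in a row (the parallel-tool-call case from issue #5277, exercised by the test further down). A simplified, self-contained sketch of the merge pattern; the dict shapes here are illustrative, not litellm's actual typed blocks:

```python
from typing import List

def merge_tool_results(messages: List[dict]) -> List[dict]:
    """Collect runs of adjacent 'tool' messages into a single user-role message."""
    contents: List[dict] = []
    i = 0
    while i < len(messages):
        if messages[i]["role"] == "tool":
            blocks = []
            # gather every adjacent tool result into one user turn
            while i < len(messages) and messages[i]["role"] == "tool":
                blocks.append({"toolResult": {"toolUseId": messages[i]["tool_call_id"],
                                              "content": messages[i]["content"]}})
                i += 1
            contents.append({"role": "user", "content": blocks})
        else:
            contents.append(messages[i])
            i += 1
    return contents
```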

View file

@@ -365,6 +365,7 @@ class CodestralTextCompletion(BaseLLM):
         stream = optional_params.pop("stream", False)

         data = {
+            "model": model,
             "prompt": prompt,
             **optional_params,
         }

View file

@@ -253,7 +253,7 @@ async def acompletion(
     logit_bias: Optional[dict] = None,
     user: Optional[str] = None,
     # openai v1.0+ new params
-    response_format: Optional[dict] = None,
+    response_format: Optional[Union[dict, Type[BaseModel]]] = None,
     seed: Optional[int] = None,
     tools: Optional[List] = None,
     tool_choice: Optional[str] = None,

View file

@@ -1,6 +1,4 @@
 model_list:
-  - model_name: gpt-3.5-turbo
+  - model_name: ollama/mistral
     litellm_params:
-      model: azure/chatgpt-v-2
-      api_key: os.environ/AZURE_API_KEY
-      api_base: os.environ/AZURE_API_BASE
+      model: ollama/mistral

View file

@@ -587,6 +587,7 @@ class GenerateKeyRequest(GenerateRequestBase):
     send_invite_email: Optional[bool] = None
     model_rpm_limit: Optional[dict] = None
     model_tpm_limit: Optional[dict] = None
+    guardrails: Optional[List[str]] = None

 class GenerateKeyResponse(GenerateKeyRequest):

View file

@@ -1269,8 +1269,9 @@ def _get_user_role(
 def _get_request_ip_address(
     request: Request, use_x_forwarded_for: Optional[bool] = False
-) -> str:
+) -> Optional[str]:
+    client_ip = None
     if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
         client_ip = request.headers["x-forwarded-for"]
     elif request.client is not None:
View file

@@ -331,13 +331,33 @@ async def add_litellm_data_to_request(
     # Guardrails
     move_guardrails_to_metadata(
-        data=data, _metadata_variable_name=_metadata_variable_name
+        data=data,
+        _metadata_variable_name=_metadata_variable_name,
+        user_api_key_dict=user_api_key_dict,
     )

     return data

-def move_guardrails_to_metadata(data: dict, _metadata_variable_name: str):
+def move_guardrails_to_metadata(
+    data: dict,
+    _metadata_variable_name: str,
+    user_api_key_dict: UserAPIKeyAuth,
+):
+    """
+    Helper to add guardrails from request to metadata
+
+    - If guardrails set on API Key metadata then sets guardrails on request metadata
+    - If guardrails not set on API key, then checks request metadata
+    """
+    if user_api_key_dict.metadata:
+        if "guardrails" in user_api_key_dict.metadata:
+            data[_metadata_variable_name]["guardrails"] = user_api_key_dict.metadata[
+                "guardrails"
+            ]
+            return
+
     if "guardrails" in data:
         data[_metadata_variable_name]["guardrails"] = data["guardrails"]
         del data["guardrails"]
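
The helper now gives key-level guardrails precedence over guardrails sent in the request body. A hedged standalone sketch of that precedence rule, using plain dicts instead of litellm's `UserAPIKeyAuth` object:

```python
def resolve_guardrails(key_metadata: dict, request_body: dict) -> list:
    """Key-level guardrails win; otherwise fall back to guardrails sent in the request."""
    if key_metadata.get("guardrails"):
        return key_metadata["guardrails"]  # pinned on the API key -> always used
    return request_body.pop("guardrails", [])  # honor the per-request setting

# Example: the key pins its own guardrails, so the request's (empty) list is ignored
print(resolve_guardrails({"guardrails": ["aporia-pre-guard"]}, {"guardrails": []}))
```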

View file

@@ -66,6 +66,7 @@ async def generate_key_fn(
     - budget_duration: Optional[str] - Budget is reset at the end of specified duration. If not set, budget is never reset. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
     - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x.
     - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }
+    - guardrails: Optional[List[str]] - List of active guardrails for the key
     - permissions: Optional[dict] - key-specific permissions. Currently just used for turning off pii masking (if connected). Example - {"pii": false}
     - model_max_budget: Optional[dict] - key-specific model budget in USD. Example - {"text-davinci-002": 0.5, "gpt-3.5-turbo": 0.5}. IF null or {} then no model specific budget.
     - model_rpm_limit: Optional[dict] - key-specific model rpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific rpm limit.
@@ -321,11 +322,12 @@ async def update_key_fn(
                 detail={"error": f"Team not found, passed team_id={data.team_id}"},
             )

+        _metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"]
         # get non default values for key
         non_default_values = {}
         for k, v in data_json.items():
             # this field gets stored in metadata
-            if key == "model_rpm_limit" or key == "model_tpm_limit":
+            if key in _metadata_fields:
                 continue
             if v is not None and v not in (
                 [],
@@ -366,6 +368,14 @@ async def update_key_fn(
             non_default_values["metadata"] = _metadata
             non_default_values.pop("model_rpm_limit", None)

+        if data.guardrails:
+            _metadata = existing_key_row.metadata or {}
+            _metadata["guardrails"] = data.guardrails
+
+            # update values that will be written to the DB
+            non_default_values["metadata"] = _metadata
+            non_default_values.pop("guardrails", None)
+
         response = await prisma_client.update_data(
             token=key, data={**non_default_values, "token": key}
         )
@@ -734,6 +744,7 @@ async def generate_key_helper_fn(
     model_max_budget: Optional[dict] = {},
     model_rpm_limit: Optional[dict] = {},
     model_tpm_limit: Optional[dict] = {},
+    guardrails: Optional[list] = None,
     teams: Optional[list] = None,
     organization_id: Optional[str] = None,
     table_name: Optional[Literal["key", "user"]] = None,
@@ -783,6 +794,9 @@ async def generate_key_helper_fn(
     if model_tpm_limit is not None:
         metadata = metadata or {}
         metadata["model_tpm_limit"] = model_tpm_limit
+    if guardrails is not None:
+        metadata = metadata or {}
+        metadata["guardrails"] = guardrails

     metadata_json = json.dumps(metadata)
     model_max_budget_json = json.dumps(model_max_budget)

View file

@@ -360,24 +360,22 @@ async def pass_through_request(
         # combine url with query params for logging
-        # requested_query_params = query_params or request.query_params.__dict__
-        # requested_query_params_str = "&".join(
-        #     f"{k}={v}" for k, v in requested_query_params.items()
-        # )
-        requested_query_params = None
-        # if "?" in str(url):
-        #     logging_url = str(url) + "&" + requested_query_params_str
-        # else:
-        #     logging_url = str(url) + "?" + requested_query_params_str
+        requested_query_params = query_params or request.query_params.__dict__
+        requested_query_params_str = "&".join(
+            f"{k}={v}" for k, v in requested_query_params.items()
+        )
+        if "?" in str(url):
+            logging_url = str(url) + "&" + requested_query_params_str
+        else:
+            logging_url = str(url) + "?" + requested_query_params_str

         logging_obj.pre_call(
             input=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
             api_key="",
             additional_args={
                 "complete_input_dict": _parsed_body,
-                "api_base": str(url),
+                "api_base": str(logging_url),
                 "headers": headers,
             },
         )
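
With the query-string reconstruction re-enabled, the logged `api_base` now includes the caller's query parameters. A small illustrative helper mirroring the same URL-building rule (this is a sketch, not litellm's actual function):

```python
def build_logging_url(url: str, query_params: dict) -> str:
    """Append query params for logging, matching the '?' vs '&' branch above."""
    query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
    if not query_string:
        return url
    separator = "&" if "?" in url else "?"
    return f"{url}{separator}{query_string}"

# e.g. build_logging_url("http://localhost:4000/langfuse/api/public/traces", {"page": 1})
```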

View file

@@ -2350,6 +2350,7 @@ async def initialize(
     config=None,
 ):
     global user_model, user_api_base, user_debug, user_detailed_debug, user_user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth, prisma_client
-    generate_feedback_box()
+    if os.getenv("LITELLM_DONT_SHOW_FEEDBACK_BOX", "").lower() != "true":
+        generate_feedback_box()
     user_model = model
     user_debug = debug
@@ -8065,14 +8066,14 @@ async def login(request: Request):
             return redirect_response
         else:
             raise ProxyException(
-                message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nNot valid credentials for {username}",
+                message=f"Invalid credentials used to access UI.\nNot valid credentials for {username}",
                 type=ProxyErrorTypes.auth_error,
                 param="invalid_credentials",
                 code=status.HTTP_401_UNAUTHORIZED,
             )
     else:
         raise ProxyException(
-            message=f"Invalid credentials used to access UI. Passed in username: {username}, passed in password: {password}.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
+            message="Invalid credentials used to access UI.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
             type=ProxyErrorTypes.auth_error,
             param="invalid_credentials",
             code=status.HTTP_401_UNAUTHORIZED,

View file

@@ -1,18 +1,20 @@
-import sys, os
+import os
+import sys
 import traceback

 from dotenv import load_dotenv

 load_dotenv()
-import os, io
+import io
+import os

 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import pytest

 import litellm
-from litellm import embedding, completion, completion_cost, Timeout
-from litellm import RateLimitError
-import pytest
+from litellm import RateLimitError, Timeout, completion, completion_cost, embedding

 litellm.num_retries = 0
 litellm.cache = None
@@ -41,7 +43,14 @@ def get_current_weather(location, unit="fahrenheit"):
     # In production, this could be your backend API or an external API

 @pytest.mark.parametrize(
-    "model", ["gpt-3.5-turbo-1106", "mistral/mistral-large-latest"]
+    "model",
+    [
+        "gpt-3.5-turbo-1106",
+        "mistral/mistral-large-latest",
+        "claude-3-haiku-20240307",
+        "gemini/gemini-1.5-pro",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+    ],
 )
 def test_parallel_function_call(model):
     try:
@@ -124,7 +133,12 @@ def test_parallel_function_call(model):
         )  # extend conversation with function response
         print(f"messages: {messages}")
         second_response = litellm.completion(
-            model=model, messages=messages, temperature=0.2, seed=22
+            model=model,
+            messages=messages,
+            temperature=0.2,
+            seed=22,
+            tools=tools,
+            drop_params=True,
         )  # get a new response from the model where it can see the function response
         print("second response\n", second_response)
     except Exception as e:

View file

@@ -2770,6 +2770,60 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
     }

+@pytest.mark.asyncio()
+async def test_generate_key_with_guardrails(prisma_client):
+    print("prisma client=", prisma_client)
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+    request = GenerateKeyRequest(
+        guardrails=["aporia-pre-call"],
+        metadata={
+            "team": "litellm-team3",
+        },
+    )
+    key = await generate_key_fn(
+        data=request,
+        user_api_key_dict=UserAPIKeyAuth(
+            user_role=LitellmUserRoles.PROXY_ADMIN,
+            api_key="sk-1234",
+            user_id="1234",
+        ),
+    )
+    print("generated key=", key)
+
+    generated_key = key.key
+
+    # use generated key to auth in
+    result = await info_key_fn(key=generated_key)
+    print("result from info_key_fn", result)
+    assert result["key"] == generated_key
+
+    print("\n info for key=", result["info"])
+    assert result["info"]["metadata"] == {
+        "team": "litellm-team3",
+        "guardrails": ["aporia-pre-call"],
+    }
+
+    # Update guardrails for the key
+    request = UpdateKeyRequest(
+        key=generated_key,
+        guardrails=["aporia-pre-call", "aporia-post-call"],
+    )
+    _request = Request(scope={"type": "http"})
+    _request._url = URL(url="/update/key")
+
+    await update_key_fn(data=request, request=_request)
+
+    result = await info_key_fn(key=generated_key)
+    print("result from info_key_fn", result)
+    assert result["key"] == generated_key
+
+    print("\n info for key=", result["info"])
+    assert result["info"]["metadata"] == {
+        "team": "litellm-team3",
+        "guardrails": ["aporia-pre-call", "aporia-post-call"],
+    }
+
 @pytest.mark.asyncio()
 async def test_team_access_groups(prisma_client):
     """

View file

@@ -132,6 +132,7 @@ def test_ollama_aembeddings(mock_aembeddings):
 # test_ollama_aembeddings()

+@pytest.mark.skip(reason="local only test")
 def test_ollama_chat_function_calling():
     import json
View file

@@ -313,3 +313,78 @@ def test_anthropic_cache_controls_pt():
         assert msg["content"][0]["cache_control"] == {"type": "ephemeral"}

     print("translated_messages: ", translated_messages)

+@pytest.mark.parametrize("provider", ["bedrock", "anthropic"])
+def test_bedrock_parallel_tool_calling_pt(provider):
+    """
+    Make sure parallel tool call blocks are merged correctly - https://github.com/BerriAI/litellm/issues/5277
+    """
+    from litellm.llms.prompt_templates.factory import _bedrock_converse_messages_pt
+    from litellm.types.utils import ChatCompletionMessageToolCall, Function, Message
+
+    messages = [
+        {
+            "role": "user",
+            "content": "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses",
+        },
+        Message(
+            content="Here are the current weather conditions for San Francisco, Tokyo, and Paris:",
+            role="assistant",
+            tool_calls=[
+                ChatCompletionMessageToolCall(
+                    index=1,
+                    function=Function(
+                        arguments='{"city": "New York"}',
+                        name="get_current_weather",
+                    ),
+                    id="tooluse_XcqEBfm8R-2YVaPhDUHsPQ",
+                    type="function",
+                ),
+                ChatCompletionMessageToolCall(
+                    index=2,
+                    function=Function(
+                        arguments='{"city": "London"}',
+                        name="get_current_weather",
+                    ),
+                    id="tooluse_VB9nk7UGRniVzGcaj6xrAQ",
+                    type="function",
+                ),
+            ],
+            function_call=None,
+        ),
+        {
+            "tool_call_id": "tooluse_XcqEBfm8R-2YVaPhDUHsPQ",
+            "role": "tool",
+            "name": "get_current_weather",
+            "content": "25 degrees celsius.",
+        },
+        {
+            "tool_call_id": "tooluse_VB9nk7UGRniVzGcaj6xrAQ",
+            "role": "tool",
+            "name": "get_current_weather",
+            "content": "28 degrees celsius.",
+        },
+    ]
+
+    if provider == "bedrock":
+        translated_messages = _bedrock_converse_messages_pt(
+            messages=messages,
+            model="anthropic.claude-3-sonnet-20240229-v1:0",
+            llm_provider="bedrock",
+        )
+    else:
+        translated_messages = anthropic_messages_pt(
+            messages=messages,
+            model="claude-3-sonnet-20240229-v1:0",
+            llm_provider=provider,
+        )
+    print(translated_messages)
+
+    number_of_messages = len(translated_messages)
+
+    # assert last 2 messages are not the same role
+    assert (
+        translated_messages[number_of_messages - 1]["role"]
+        != translated_messages[number_of_messages - 2]["role"]
+    )

View file

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.43.18"
+version = "1.43.19"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"

 [tool.commitizen]
-version = "1.43.18"
+version = "1.43.19"
 version_files = [
     "pyproject.toml:^version"
 ]

View file

@@ -22,10 +22,6 @@ async def chat_completion(
     data = {
         "model": model,
         "messages": messages,
-        "guardrails": [
-            "aporia-post-guard",
-            "aporia-pre-guard",
-        ],  # default guardrails for all tests
     }

     if guardrails is not None:
@@ -41,7 +37,7 @@ async def chat_completion(
     print()

     if status != 200:
-        return response_text
+        raise Exception(response_text)

     # response headers
     response_headers = response.headers
@@ -50,6 +46,29 @@ async def chat_completion(
     return await response.json(), response_headers

+async def generate_key(session, guardrails):
+    url = "http://0.0.0.0:4000/key/generate"
+    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
+    if guardrails:
+        data = {
+            "guardrails": guardrails,
+        }
+    else:
+        data = {}
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request did not return a 200 status code: {status}")
+
+        return await response.json()
+
 @pytest.mark.asyncio
 async def test_llm_guard_triggered_safe_request():
     """
@@ -62,6 +81,10 @@ async def test_llm_guard_triggered_safe_request():
             "sk-1234",
             model="fake-openai-endpoint",
             messages=[{"role": "user", "content": f"Hello what's the weather"}],
+            guardrails=[
+                "aporia-post-guard",
+                "aporia-pre-guard",
+            ],
         )
         await asyncio.sleep(3)
@@ -90,6 +113,10 @@ async def test_llm_guard_triggered():
                 messages=[
                     {"role": "user", "content": f"Hello my name is ishaan@berri.ai"}
                 ],
+                guardrails=[
+                    "aporia-post-guard",
+                    "aporia-pre-guard",
+                ],
             )
             pytest.fail("Should have thrown an exception")
         except Exception as e:
@@ -116,3 +143,54 @@ async def test_no_llm_guard_triggered():
         print("response=", response, "response headers", headers)
         assert "x-litellm-applied-guardrails" not in headers

+@pytest.mark.asyncio
+async def test_guardrails_with_api_key_controls():
+    """
+    - Make two API Keys
+      - Key 1 with no guardrails
+      - Key 2 with guardrails
+    - Request to Key 1 -> should be success with no guardrails
+    - Request to Key 2 -> should be error since guardrails are triggered
+    """
+    async with aiohttp.ClientSession() as session:
+        key_with_guardrails = await generate_key(
+            session=session,
+            guardrails=[
+                "aporia-post-guard",
+                "aporia-pre-guard",
+            ],
+        )
+
+        key_with_guardrails = key_with_guardrails["key"]
+
+        key_without_guardrails = await generate_key(session=session, guardrails=None)
+
+        key_without_guardrails = key_without_guardrails["key"]
+
+        # test no guardrails triggered for key without guardrails
+        response, headers = await chat_completion(
+            session,
+            key_without_guardrails,
+            model="fake-openai-endpoint",
+            messages=[{"role": "user", "content": f"Hello what's the weather"}],
+        )
+        await asyncio.sleep(3)
+
+        print("response=", response, "response headers", headers)
+        assert "x-litellm-applied-guardrails" not in headers
+
+        # test guardrails triggered for key with guardrails
+        try:
+            response, headers = await chat_completion(
+                session,
+                key_with_guardrails,
+                model="fake-openai-endpoint",
+                messages=[
+                    {"role": "user", "content": f"Hello my name is ishaan@berri.ai"}
+                ],
+            )
+            pytest.fail("Should have thrown an exception")
+        except Exception as e:
+            print(e)
+            assert "Aporia detected and blocked PII" in str(e)