Merge branch 'main' into litellm_redis_cluster

commit 68cb5cae58
Krish Dholakia, 2024-08-22 11:06:14 -07:00 (committed by GitHub)
56 changed files with 2079 additions and 411 deletions

View file

@@ -282,7 +282,7 @@ jobs:
pip install "pytest==7.3.1"
pip install "pytest-asyncio==0.21.1"
pip install aiohttp
-pip install openai
+pip install "openai==1.40.0"
python -m pip install --upgrade pip
pip install "pydantic==2.7.1"
pip install "pytest==7.3.1"

View file

@@ -13,10 +13,11 @@ spec:
    {{- include "litellm.selectorLabels" . | nindent 6 }}
  template:
    metadata:
-     {{- with .Values.podAnnotations }}
      annotations:
+       checksum/config: {{ include (print $.Template.BasePath "/configmap-litellm.yaml") . | sha256sum }}
+     {{- with .Values.podAnnotations }}
      {{- toYaml . | nindent 8 }}
      {{- end }}
      labels:
        {{- include "litellm.labels" . | nindent 8 }}
        {{- with .Values.podLabels }}

View file

@@ -81,6 +81,7 @@ Works for:
```python
import os
from litellm import completion
+from pydantic import BaseModel
# add to env var
os.environ["OPENAI_API_KEY"] = ""

View file

@@ -8,6 +8,7 @@ liteLLM supports:
- [Custom Callback Functions](https://docs.litellm.ai/docs/observability/custom_callback)
- [Langfuse](https://langfuse.com/docs)
- [LangSmith](https://www.langchain.com/langsmith)
- [Helicone](https://docs.helicone.ai/introduction)
- [Traceloop](https://traceloop.com/docs)
- [Lunary](https://lunary.ai/docs)
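All of these integrations are enabled the same way, by registering callbacks on `litellm`. A minimal sketch of that pattern (the callback names below are illustrative; each integration's page lists its exact name and required environment variables):

```python
# Sketch: enabling logging/observability integrations via callbacks.
# Callback names are illustrative - see each integration's page for the exact
# string and the environment variables it needs.
import litellm

litellm.success_callback = ["langfuse", "langsmith"]  # log successful LLM calls
litellm.failure_callback = ["langfuse"]               # log failed LLM calls
```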

View file

@@ -56,7 +56,7 @@ response = litellm.completion(
```
## Advanced
-### Set Langsmith fields - Custom Projec, Run names, tags
+### Set Langsmith fields
```python
import litellm
@@ -75,9 +75,17 @@ response = litellm.completion(
        {"role": "user", "content": "Hi 👋 - i'm openai"}
    ],
    metadata={
        "run_name": "litellmRUN", # langsmith run name
        "project_name": "litellm-completion", # langsmith project name
-       "tags": ["model1", "prod-2"] # tags to log on langsmith
+       "run_id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", # langsmith run id
+       "parent_run_id": "f8faf8c1-9778-49a4-9004-628cdb0047e5", # langsmith run parent run id
+       "trace_id": "df570c03-5a03-4cea-8df0-c162d05127ac", # langsmith run trace id
+       "session_id": "1ffd059c-17ea-40a8-8aef-70fd0307db82", # langsmith run session id
+       "tags": ["model1", "prod-2"], # langsmith run tags
+       "metadata": { # langsmith run metadata
+           "key1": "value1"
+       },
+       "dotted_order": "20240429T004912090000Z497f6eca-6276-4993-bfeb-53cbbbba6f08"
    }
)
print(response)

View file

@@ -1,6 +1,10 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# [BETA] Vertex AI Endpoints (Pass-Through)
Pass-through endpoints for Vertex AI - call provider-specific endpoint, in native format (no translation). Use VertexAI SDK to call endpoints on LiteLLM Gateway (native provider format)
:::tip
@@ -40,16 +44,119 @@ litellm --config /path/to/config.yaml
#### 3. Test it
-```shell
-curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:countTokens \
-  -H "Content-Type: application/json" \
-  -H "Authorization: Bearer sk-1234" \
-  -d '{"instances":[{"content": "gm"}]}'
+```python
import vertexai
from google.auth.credentials import Credentials
from vertexai.generative_models import GenerativeModel
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
import datetime
class CredentialsWrapper(Credentials):
def __init__(self, token=None):
super().__init__()
self.token = token
self.expiry = None # or set to a future date if needed
def refresh(self, request):
pass
def apply(self, headers, token=None):
headers["Authorization"] = f"Bearer {self.token}"
@property
def expired(self):
return False # Always consider the token as non-expired
@property
def valid(self):
return True # Always consider the credentials as valid
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
credentials=credentials,
api_transport="rest",
)
model = GenerativeModel("gemini-1.5-flash-001")
response = model.generate_content(
"What's a good name for a flower shop that specializes in selling bouquets of dried flowers?"
)
print(response.text)
```
## Usage Examples
### Gemini API (Generate Content)
<Tabs>
<TabItem value="py" label="Vertex Python SDK">
```python
import vertexai
from google.auth.credentials import Credentials
from vertexai.generative_models import GenerativeModel
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
import datetime
class CredentialsWrapper(Credentials):
def __init__(self, token=None):
super().__init__()
self.token = token
self.expiry = None # or set to a future date if needed
def refresh(self, request):
pass
def apply(self, headers, token=None):
headers["Authorization"] = f"Bearer {self.token}"
@property
def expired(self):
return False # Always consider the token as non-expired
@property
def valid(self):
return True # Always consider the credentials as valid
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
credentials=credentials,
api_transport="rest",
)
model = GenerativeModel("gemini-1.5-flash-001")
response = model.generate_content(
"What's a good name for a flower shop that specializes in selling bouquets of dried flowers?"
)
print(response.text)
```
</TabItem>
<TabItem value="Curl" label="Curl">
```shell
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
  -H "Content-Type: application/json" \
@@ -57,8 +164,77 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-0
  -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
```
</TabItem>
</Tabs>
### Embeddings API
<Tabs>
<TabItem value="py" label="Vertex Python SDK">
```python
from typing import List, Optional
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
import vertexai
from google.auth.credentials import Credentials
from vertexai.generative_models import GenerativeModel
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
import datetime
class CredentialsWrapper(Credentials):
def __init__(self, token=None):
super().__init__()
self.token = token
self.expiry = None # or set to a future date if needed
def refresh(self, request):
pass
def apply(self, headers, token=None):
headers["Authorization"] = f"Bearer {self.token}"
@property
def expired(self):
return False # Always consider the token as non-expired
@property
def valid(self):
return True # Always consider the credentials as valid
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
credentials=credentials,
api_transport="rest",
)
def embed_text(
texts: List[str] = ["banana muffins? ", "banana bread? banana muffins?"],
task: str = "RETRIEVAL_DOCUMENT",
model_name: str = "text-embedding-004",
dimensionality: Optional[int] = 256,
) -> List[List[float]]:
"""Embeds texts with a pre-trained, foundational model."""
model = TextEmbeddingModel.from_pretrained(model_name)
inputs = [TextEmbeddingInput(text, task) for text in texts]
kwargs = dict(output_dimensionality=dimensionality) if dimensionality else {}
embeddings = model.get_embeddings(inputs, **kwargs)
return [embedding.values for embedding in embeddings]
```
</TabItem>
<TabItem value="curl" label="Curl">
```shell
curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:predict \
  -H "Content-Type: application/json" \
@@ -66,8 +242,86 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-geck
  -d '{"instances":[{"content": "gm"}]}'
```
</TabItem>
</Tabs>
### Imagen API
<Tabs>
<TabItem value="py" label="Vertex Python SDK">
```python
from typing import List, Optional
from vertexai.preview.vision_models import ImageGenerationModel
import vertexai
from google.auth.credentials import Credentials
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
import datetime
class CredentialsWrapper(Credentials):
def __init__(self, token=None):
super().__init__()
self.token = token
self.expiry = None # or set to a future date if needed
def refresh(self, request):
pass
def apply(self, headers, token=None):
headers["Authorization"] = f"Bearer {self.token}"
@property
def expired(self):
return False # Always consider the token as non-expired
@property
def valid(self):
return True # Always consider the credentials as valid
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
credentials=credentials,
api_transport="rest",
)
model = ImageGenerationModel.from_pretrained("imagen-3.0-generate-001")
images = model.generate_images(
prompt=prompt,
# Optional parameters
number_of_images=1,
language="en",
# You can't use a seed value and watermark at the same time.
# add_watermark=False,
# seed=100,
aspect_ratio="1:1",
safety_filter_level="block_some",
person_generation="allow_adult",
)
images[0].save(location=output_file, include_generation_parameters=False)
# Optional. View the generated image in a notebook.
# images[0].show()
print(f"Created output image using {len(images[0]._image_bytes)} bytes")
```
</TabItem>
<TabItem value="curl" label="Curl">
```shell
curl http://localhost:4000/vertex-ai/publishers/google/models/imagen-3.0-generate-001:predict \
  -H "Content-Type: application/json" \
@@ -75,8 +329,86 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/imagen-3.0-generat
  -d '{"instances":[{"prompt": "make an otter"}], "parameters": {"sampleCount": 1}}'
```
</TabItem>
</Tabs>
### Count Tokens API
<Tabs>
<TabItem value="py" label="Vertex Python SDK">
```python
from typing import List, Optional
from vertexai.generative_models import GenerativeModel
import vertexai
from google.auth.credentials import Credentials
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
import datetime
class CredentialsWrapper(Credentials):
def __init__(self, token=None):
super().__init__()
self.token = token
self.expiry = None # or set to a future date if needed
def refresh(self, request):
pass
def apply(self, headers, token=None):
headers["Authorization"] = f"Bearer {self.token}"
@property
def expired(self):
return False # Always consider the token as non-expired
@property
def valid(self):
return True # Always consider the credentials as valid
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
credentials=credentials,
api_transport="rest",
)
model = GenerativeModel("gemini-1.5-flash-001")
prompt = "Why is the sky blue?"
# Prompt tokens count
response = model.count_tokens(prompt)
print(f"Prompt Token Count: {response.total_tokens}")
print(f"Prompt Character Count: {response.total_billable_characters}")
# Send text to Gemini
response = model.generate_content(prompt)
# Response tokens count
usage_metadata = response.usage_metadata
print(f"Prompt Token Count: {usage_metadata.prompt_token_count}")
print(f"Candidates Token Count: {usage_metadata.candidates_token_count}")
print(f"Total Token Count: {usage_metadata.total_token_count}")
```
</TabItem>
<TabItem value="curl" label="Curl">
```shell
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
  -H "Content-Type: application/json" \
@@ -84,10 +416,83 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-0
  -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
```
</TabItem>
</Tabs>
### Tuning API
Create Fine Tuning Job
<Tabs>
<TabItem value="py" label="Vertex Python SDK">
```python
from typing import List, Optional
from vertexai.preview.tuning import sft
import vertexai
from google.auth.credentials import Credentials
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
import datetime
class CredentialsWrapper(Credentials):
def __init__(self, token=None):
super().__init__()
self.token = token
self.expiry = None # or set to a future date if needed
def refresh(self, request):
pass
def apply(self, headers, token=None):
headers["Authorization"] = f"Bearer {self.token}"
@property
def expired(self):
return False # Always consider the token as non-expired
@property
def valid(self):
return True # Always consider the credentials as valid
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
credentials=credentials,
api_transport="rest",
)
# TODO(developer): Update project
vertexai.init(project=PROJECT_ID, location="us-central1")
sft_tuning_job = sft.train(
source_model="gemini-1.0-pro-002",
train_dataset="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl",
)
# Polling for job completion
while not sft_tuning_job.has_ended:
time.sleep(60)
sft_tuning_job.refresh()
print(sft_tuning_job.tuned_model_name)
print(sft_tuning_job.tuned_model_endpoint_name)
print(sft_tuning_job.experiment)
```
</TabItem>
<TabItem value="curl" label="Curl">
```shell
curl http://localhost:4000/vertex-ai/tuningJobs \
  -H "Content-Type: application/json" \
@@ -98,4 +503,8 @@ curl http://localhost:4000/vertex-ai/tuningJobs \
    "training_dataset_uri": "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"
  }
}'
```
</TabItem>
</Tabs>

View file

@@ -131,6 +131,56 @@ Expected Response
}
```
## Add Streaming Support
Here's a simple example of returning unix epoch seconds for both completion + streaming use-cases.
s/o [@Eloy Lafuente](https://github.com/stronk7) for this code example.
```python
import time
from typing import Iterator, AsyncIterator
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from litellm import CustomLLM, completion, acompletion

class UnixTimeLLM(CustomLLM):
    def completion(self, *args, **kwargs) -> ModelResponse:
        return completion(
            model="test/unixtime",
            mock_response=str(int(time.time())),
        )  # type: ignore

    async def acompletion(self, *args, **kwargs) -> ModelResponse:
        return await acompletion(
            model="test/unixtime",
            mock_response=str(int(time.time())),
        )  # type: ignore

    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": str(int(time.time())),
            "tool_use": None,
            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
        }
        return generic_streaming_chunk  # type: ignore

    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": str(int(time.time())),
            "tool_use": None,
            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
        }
        yield generic_streaming_chunk  # type: ignore

unixtime = UnixTimeLLM()
```
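To try the handler above from the SDK, a minimal registration sketch (assuming the `custom_provider_map` setup described in the rest of this guide; the `"test"` provider name matches the `"test/unixtime"` model string used above):

```python
# Sketch only: registering the UnixTimeLLM handler and calling it.
import litellm
from litellm import completion

litellm.custom_provider_map = [
    {"provider": "test", "custom_handler": unixtime}
]

resp = completion(
    model="test/unixtime",
    messages=[{"role": "user", "content": "what time is it?"}],
)
print(resp.choices[0].message.content)  # unix epoch seconds, e.g. "1724340000"
```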
## Custom Handler Spec
```python

View file

@@ -661,6 +661,7 @@ Here's how to use Vertex AI with the LiteLLM Proxy Server
## Specifying Safety Settings
In certain use-cases you may need to make calls to the models and pass [safety settings](https://ai.google.dev/docs/safety_setting_gemini) different from the defaults. To do so, simply pass the `safety_settings` argument to `completion` or `acompletion`. For example:
### Set per model/request
<Tabs>
@@ -752,6 +753,65 @@ response = client.chat.completions.create(
</TabItem>
</Tabs>
### Set Globally
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import litellm
from litellm import completion

litellm.set_verbose = True # 👈 See RAW REQUEST/RESPONSE

litellm.vertex_ai_safety_settings = [
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_NONE",
    },
]

response = completion(
    model="vertex_ai/gemini-pro",
    messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]
)
```
</TabItem>
<TabItem value="proxy" label="Proxy">
```yaml
model_list:
  - model_name: gemini-experimental
    litellm_params:
      model: vertex_ai/gemini-experimental
      vertex_project: litellm-epic
      vertex_location: us-central1

litellm_settings:
  vertex_ai_safety_settings:
    - category: HARM_CATEGORY_HARASSMENT
      threshold: BLOCK_NONE
    - category: HARM_CATEGORY_HATE_SPEECH
      threshold: BLOCK_NONE
    - category: HARM_CATEGORY_SEXUALLY_EXPLICIT
      threshold: BLOCK_NONE
    - category: HARM_CATEGORY_DANGEROUS_CONTENT
      threshold: BLOCK_NONE
```
</TabItem>
</Tabs>
## Set Vertex Project & Vertex Location
All calls using Vertex AI require the following parameters:
* Your Project ID
@@ -1450,7 +1510,7 @@ curl http://0.0.0.0:4000/v1/chat/completions \
| code-gecko@latest| `completion('code-gecko@latest', messages)` |
-## Embedding Models
+## **Embedding Models**
#### Usage - Embedding
```python
@@ -1504,7 +1564,158 @@ response = litellm.embedding(
)
```
-## Image Generation Models
+## **Multi-Modal Embeddings**
Usage
<Tabs>
<TabItem value="sdk" label="SDK">
```python
response = await litellm.aembedding(
    model="vertex_ai/multimodalembedding@001",
    input=[
        {
            "image": {
                "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
            },
            "text": "this is a unicorn",
        },
    ],
)
```
</TabItem>
<TabItem value="proxy" label="LiteLLM PROXY (Unified Endpoint)">
1. Add model to config.yaml
```yaml
model_list:
  - model_name: multimodalembedding@001
    litellm_params:
      model: vertex_ai/multimodalembedding@001
      vertex_project: "adroit-crow-413218"
      vertex_location: "us-central1"
      vertex_credentials: adroit-crow-413218-a956eef1a2a8.json

litellm_settings:
  drop_params: True
```
2. Start Proxy
```
$ litellm --config /path/to/config.yaml
```
3. Make a request using the OpenAI Python SDK
```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

# request sent to model set on litellm proxy, `litellm --model`
response = client.embeddings.create(
    model="multimodalembedding@001",
    input=None,
    extra_body={
        "instances": [
            {
                "image": {
                    "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
                },
                "text": "this is a unicorn",
            },
        ],
    },
)

print(response)
```
</TabItem>
<TabItem value="proxy-vtx" label="LiteLLM PROXY (Vertex SDK)">
1. Add model to config.yaml
```yaml
default_vertex_config:
  vertex_project: "adroit-crow-413218"
  vertex_location: "us-central1"
  vertex_credentials: adroit-crow-413218-a956eef1a2a8.json
```
2. Start Proxy
```
$ litellm --config /path/to/config.yaml
```
3. Make a request using the Vertex AI SDK
```python
import vertexai
from vertexai.vision_models import Image, MultiModalEmbeddingModel, Video
from vertexai.vision_models import VideoSegmentConfig
from google.auth.credentials import Credentials
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
import datetime
class CredentialsWrapper(Credentials):
def __init__(self, token=None):
super().__init__()
self.token = token
self.expiry = None # or set to a future date if needed
def refresh(self, request):
pass
def apply(self, headers, token=None):
headers['Authorization'] = f'Bearer {self.token}'
@property
def expired(self):
return False # Always consider the token as non-expired
@property
def valid(self):
return True # Always consider the credentials as valid
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
credentials = credentials,
api_transport="rest",
)
model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding")
image = Image.load_from_file(
"gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
)
embeddings = model.get_embeddings(
image=image,
contextual_text="Colosseum",
dimension=1408,
)
print(f"Image Embedding: {embeddings.image_embedding}")
print(f"Text Embedding: {embeddings.text_embedding}")
```
</TabItem>
</Tabs>
## **Image Generation Models**
Usage
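The usage tabs for image generation are not shown here; a rough SDK sketch of the pattern (the model id and the `vertex_ai_project`/`vertex_ai_location` parameter names are assumptions - check the full Image Generation section for supported values):

```python
# Rough sketch only - model id and vertex_ai_project/vertex_ai_location
# parameter names are assumptions; see the full Image Generation section.
import litellm

response = litellm.image_generation(
    prompt="An olympic size swimming pool",
    model="vertex_ai/imagegeneration@006",
    vertex_ai_project="my-gcp-project",
    vertex_ai_location="us-central1",
)
print(response)
```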

View file

@@ -728,6 +728,7 @@ general_settings:
"disable_spend_logs": "boolean", # turn off writing each transaction to the db
"disable_master_key_return": "boolean", # turn off returning master key on UI (checked on '/user/info' endpoint)
"disable_reset_budget": "boolean", # turn off reset budget scheduled task
"disable_adding_master_key_hash_to_db": "boolean", # turn off storing master key hash in db, for spend tracking
"enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims
"enforce_user_param": "boolean", # requires all openai endpoint requests to have a 'user' param
"allowed_routes": "list", # list of allowed proxy API routes - a user can access. (currently JWT-Auth only)

View file

@@ -101,8 +101,38 @@ Requirements:
<Tabs>
<TabItem value="key" label="Set on Key">
```bash
curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"metadata": {
"tags": ["tag1", "tag2", "tag3"]
}
}
'
```
</TabItem>
<TabItem value="team" label="Set on Team">
```bash
curl -L -X POST 'http://0.0.0.0:4000/team/new' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"metadata": {
"tags": ["tag1", "tag2", "tag3"]
}
}
'
```
</TabItem>
<TabItem value="openai" label="OpenAI Python v1.0.0+"> <TabItem value="openai" label="OpenAI Python v1.0.0+">
Set `extra_body={"metadata": { }}` to `metadata` you want to pass Set `extra_body={"metadata": { }}` to `metadata` you want to pass
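The body of this tab is not shown here; the pattern it describes is roughly the following (tag values illustrative):

```python
# Sketch: passing request tags via `extra_body` metadata (values illustrative).
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what llm are you"}],
    extra_body={"metadata": {"tags": ["tag1", "tag2", "tag3"]}},
)
```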
@@ -270,7 +300,42 @@ Requirements:
<Tabs>
<TabItem value="key" label="Set on Key">
```bash
curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"metadata": {
"spend_logs_metadata": {
"hello": "world"
}
}
}
'
```
</TabItem>
<TabItem value="team" label="Set on Team">
```bash
curl -L -X POST 'http://0.0.0.0:4000/team/new' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"metadata": {
"spend_logs_metadata": {
"hello": "world"
}
}
}
'
```
</TabItem>
<TabItem value="openai" label="OpenAI Python v1.0.0+"> <TabItem value="openai" label="OpenAI Python v1.0.0+">

View file

@@ -61,6 +61,51 @@ litellm_settings:
Removes any field with `user_api_key_*` from metadata.
## What gets logged?
Found under `kwargs["standard_logging_payload"]`. This is a standard payload, logged for every response.
```python
class StandardLoggingPayload(TypedDict):
    id: str
    call_type: str
    response_cost: float
    total_tokens: int
    prompt_tokens: int
    completion_tokens: int
    startTime: float
    endTime: float
    completionStartTime: float
    model_map_information: StandardLoggingModelInformation
    model: str
    model_id: Optional[str]
    model_group: Optional[str]
    api_base: str
    metadata: StandardLoggingMetadata
    cache_hit: Optional[bool]
    cache_key: Optional[str]
    saved_cache_cost: Optional[float]
    request_tags: list
    end_user: Optional[str]
    requester_ip_address: Optional[str]
    messages: Optional[Union[str, list, dict]]
    response: Optional[Union[str, list, dict]]
    model_parameters: dict
    hidden_params: StandardLoggingHiddenParams

class StandardLoggingHiddenParams(TypedDict):
    model_id: Optional[str]
    cache_key: Optional[str]
    api_base: Optional[str]
    response_cost: Optional[str]
    additional_headers: Optional[dict]

class StandardLoggingModelInformation(TypedDict):
    model_map_key: str
    model_map_value: Optional[ModelInfo]
```
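To see this payload in practice, a minimal custom-callback sketch (registering the callback class is covered in the custom callback docs; this only illustrates where the payload appears in `kwargs`):

```python
# Sketch: reading the standard logging payload inside a custom callback.
from litellm.integrations.custom_logger import CustomLogger


class PayloadInspector(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        payload = kwargs.get("standard_logging_payload") or {}
        print(
            "model:", payload.get("model"),
            "cost:", payload.get("response_cost"),
            "cache_hit:", payload.get("cache_hit"),
        )


proxy_handler_instance = PayloadInspector()
```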
## Logging Proxy Input/Output - Langfuse
We will use the `--config` to set `litellm.success_callback = ["langfuse"]`; this will log all successful LLM calls to Langfuse. Make sure to set `LANGFUSE_PUBLIC_KEY` and `LANGFUSE_SECRET_KEY` in your environment.

View file

@@ -333,4 +333,5 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
```
Key=... over available RPM=0. Model RPM=100, Active keys=None
```

View file

@@ -488,9 +488,34 @@ You can set:
<Tabs>
<TabItem value="per-team" label="Per Team">
Use `/team/new` or `/team/update`, to persist rate limits across multiple keys for a team.
```shell
curl --location 'http://0.0.0.0:4000/team/new' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{"team_id": "my-prod-team", "max_parallel_requests": 10, "tpm_limit": 20, "rpm_limit": 4}'
```
[**See Swagger**](https://litellm-api.up.railway.app/#/team%20management/new_team_team_new_post)
**Expected Response**
```json
{
"key": "sk-sA7VDkyhlQ7m8Gt77Mbt3Q",
"expires": "2024-01-19T01:21:12.816168",
"team_id": "my-prod-team",
}
```
</TabItem>
<TabItem value="per-user" label="Per Internal User"> <TabItem value="per-user" label="Per Internal User">
Use `/user/new`, to persist rate limits across multiple keys. Use `/user/new` or `/user/update`, to persist rate limits across multiple keys for internal users.
```shell ```shell
@ -653,6 +678,70 @@ curl --location 'http://localhost:4000/chat/completions' \
</TabItem> </TabItem>
</Tabs> </Tabs>
## Set default budget for ALL internal users
Use this to set a default budget for users who you give keys to.
This will apply when a user has [`user_role="internal_user"`](./self_serve.md#available-roles) (set this via `/user/new` or `/user/update`).
This will NOT apply if a key has a team_id (team budgets will apply then). [Tell us how we can improve this!](https://github.com/BerriAI/litellm/issues)
1. Define max budget in your config.yaml
```yaml
model_list:
  - model_name: "gpt-3.5-turbo"
    litellm_params:
      model: gpt-3.5-turbo
      api_key: os.environ/OPENAI_API_KEY

litellm_settings:
  max_internal_user_budget: 0 # amount in USD
  internal_user_budget_duration: "1mo" # reset every month
```
2. Create key for user
```bash
curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{}'
```
Expected Response:
```bash
{
...
"key": "sk-X53RdxnDhzamRwjKXR4IHg"
}
```
3. Test it!
```bash
curl -L -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-X53RdxnDhzamRwjKXR4IHg' \
-d '{
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Hey, how's it going?"}]
}'
```
Expected Response:
```bash
{
"error": {
"message": "ExceededBudget: User=<user_id> over budget. Spend=3.7e-05, Budget=0.0",
"type": "budget_exceeded",
"param": null,
"code": "400"
}
}
```
## Grant Access to new model
Use model access groups to give users access to select models, and add new ones to it over time (e.g. mistral, llama-2, etc.).

View file

@@ -6,7 +6,7 @@
    "": {
      "dependencies": {
        "@hono/node-server": "^1.10.1",
-       "hono": "^4.2.7"
+       "hono": "^4.5.8"
      },
      "devDependencies": {
        "@types/node": "^20.11.17",
@@ -463,9 +463,9 @@
      }
    },
    "node_modules/hono": {
-     "version": "4.2.7",
-     "resolved": "https://registry.npmjs.org/hono/-/hono-4.2.7.tgz",
-     "integrity": "sha512-k1xHi86tJnRIVvqhFMBDGFKJ8r5O+bEsT4P59ZK59r0F300Xd910/r237inVfuT/VmE86RQQffX4OYNda6dLXw==",
+     "version": "4.5.8",
+     "resolved": "https://registry.npmjs.org/hono/-/hono-4.5.8.tgz",
+     "integrity": "sha512-pqpSlcdqGkpTTRpLYU1PnCz52gVr0zVR9H5GzMyJWuKQLLEBQxh96q45QizJ2PPX8NATtz2mu31/PKW/Jt+90Q==",
      "engines": {
        "node": ">=16.0.0"
      }

View file

@@ -4,7 +4,7 @@
  },
  "dependencies": {
    "@hono/node-server": "^1.10.1",
-   "hono": "^4.2.7"
+   "hono": "^4.5.8"
  },
  "devDependencies": {
    "@types/node": "^20.11.17",

View file

@@ -339,6 +339,7 @@ api_version = None
organization = None
project = None
config_path = None
vertex_ai_safety_settings: Optional[dict] = None
####### COMPLETION MODELS ###################
open_ai_chat_completion_models: List = []
open_ai_text_completion_models: List = []

View file

@@ -98,6 +98,10 @@ class LangsmithLogger(CustomLogger):
project_name = metadata.get("project_name", self.langsmith_project)
run_name = metadata.get("run_name", self.langsmith_default_run_name)
run_id = metadata.get("id", None)
parent_run_id = metadata.get("parent_run_id", None)
trace_id = metadata.get("trace_id", None)
session_id = metadata.get("session_id", None)
dotted_order = metadata.get("dotted_order", None)
tags = metadata.get("tags", []) or []
verbose_logger.debug(
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
@@ -149,6 +153,18 @@ class LangsmithLogger(CustomLogger):
if run_id:
data["id"] = run_id
if parent_run_id:
data["parent_run_id"] = parent_run_id
if trace_id:
data["trace_id"] = trace_id
if session_id:
data["session_id"] = session_id
if dotted_order:
data["dotted_order"] = dotted_order
verbose_logger.debug("Langsmith Logging data on langsmith: %s", data)
return data

View file

@@ -524,6 +524,7 @@ class Logging:
TextCompletionResponse,
HttpxBinaryResponseContent,
],
cache_hit: Optional[bool] = None,
):
"""
Calculate response cost using result + logging object variables.
@@ -535,10 +536,13 @@
litellm_params=self.litellm_params
)
if cache_hit is None:
cache_hit = self.model_call_details.get("cache_hit", False)
response_cost = litellm.response_cost_calculator(
response_object=result,
model=self.model,
-cache_hit=self.model_call_details.get("cache_hit", False),
+cache_hit=cache_hit,
custom_llm_provider=self.model_call_details.get(
"custom_llm_provider", None
),
@@ -630,6 +634,7 @@ class Logging:
init_response_obj=result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
)
)
return start_time, end_time, result
@@ -2181,6 +2186,7 @@ def get_standard_logging_object_payload(
init_response_obj: Any,
start_time: dt_object,
end_time: dt_object,
logging_obj: Logging,
) -> Optional[StandardLoggingPayload]:
try:
if kwargs is None:
@@ -2277,11 +2283,17 @@ def get_standard_logging_object_payload(
cache_key = litellm.cache.get_cache_key(**kwargs)
else:
cache_key = None
saved_cache_cost: Optional[float] = None
if cache_hit is True:
import time
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
saved_cache_cost = logging_obj._response_cost_calculator(
result=init_response_obj, cache_hit=False
)
## Get model cost information ##
base_model = _get_base_model_from_metadata(model_call_details=kwargs)
custom_pricing = use_custom_pricing_for_model(litellm_params=litellm_params)
@@ -2318,6 +2330,7 @@ def get_standard_logging_object_payload(
id=str(id),
call_type=call_type or "",
cache_hit=cache_hit,
saved_cache_cost=saved_cache_cost,
startTime=start_time_float,
endTime=end_time_float,
completionStartTime=completion_start_time_float,

View file

@@ -84,6 +84,7 @@ BEDROCK_CONVERSE_MODELS = [
"meta.llama3-1-8b-instruct-v1:0",
"meta.llama3-1-70b-instruct-v1:0",
"meta.llama3-1-405b-instruct-v1:0",
"meta.llama3-70b-instruct-v1:0",
"mistral.mistral-large-2407-v1:0",
]
@@ -1480,7 +1481,7 @@ class BedrockConverseLLM(BaseAWSLLM):
optional_params: dict,
acompletion: bool,
timeout: Optional[Union[float, httpx.Timeout]],
-litellm_params=None,
+litellm_params: dict,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
@@ -1596,6 +1597,14 @@ class BedrockConverseLLM(BaseAWSLLM):
supported_tool_call_params = ["tools", "tool_choice"]
supported_guardrail_params = ["guardrailConfig"]
## TRANSFORMATION ##
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
messages=messages,
model=model,
llm_provider="bedrock_converse",
user_continue_message=litellm_params.pop("user_continue_message", None),
)
# send all model-specific params in 'additional_request_params'
for k, v in inference_params.items():
if (
@@ -1608,11 +1617,6 @@ class BedrockConverseLLM(BaseAWSLLM):
for key in additional_request_keys:
inference_params.pop(key, None)
-bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
-messages=messages,
-model=model,
-llm_provider="bedrock_converse",
-)
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
inference_params.pop("tools", [])
)

View file

@@ -124,12 +124,14 @@ class CohereConfig:
}
-def validate_environment(api_key):
-headers = {
-"Request-Source": "unspecified:litellm",
-"accept": "application/json",
-"content-type": "application/json",
-}
+def validate_environment(api_key, headers: dict):
+headers.update(
+{
+"Request-Source": "unspecified:litellm",
+"accept": "application/json",
+"content-type": "application/json",
+}
+)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
@@ -144,11 +146,12 @@ def completion(
encoding,
api_key,
logging_obj,
headers: dict,
optional_params=None,
litellm_params=None,
logger_fn=None,
):
-headers = validate_environment(api_key)
+headers = validate_environment(api_key, headers=headers)
completion_url = api_base
model = model
prompt = " ".join(message["content"] for message in messages)
@@ -338,13 +341,14 @@ def embedding(
model_response: litellm.EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
headers: dict,
encoding: Any,
api_key: Optional[str] = None,
aembedding: Optional[bool] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
-headers = validate_environment(api_key)
+headers = validate_environment(api_key, headers=headers)
embed_url = "https://api.cohere.ai/v1/embed"
model = model
data = {"model": model, "texts": input, **optional_params}

View file

@@ -116,12 +116,14 @@ class CohereChatConfig:
}
-def validate_environment(api_key):
-headers = {
-"Request-Source": "unspecified:litellm",
-"accept": "application/json",
-"content-type": "application/json",
-}
+def validate_environment(api_key, headers: dict):
+headers.update(
+{
+"Request-Source": "unspecified:litellm",
+"accept": "application/json",
+"content-type": "application/json",
+}
+)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
@@ -203,13 +205,14 @@ def completion(
model_response: ModelResponse,
print_verbose: Callable,
optional_params: dict,
headers: dict,
encoding,
api_key,
logging_obj,
litellm_params=None,
logger_fn=None,
):
-headers = validate_environment(api_key)
+headers = validate_environment(api_key, headers=headers)
completion_url = api_base
model = model
most_recent_message, chat_history = cohere_messages_pt_v2(

View file

@@ -4,14 +4,17 @@ import traceback
import types
import uuid
from itertools import chain
-from typing import Optional
+from typing import List, Optional
import aiohttp
import httpx
import requests
from pydantic import BaseModel
import litellm
from litellm import verbose_logger
from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
from litellm.types.llms.openai import ChatCompletionAssistantToolCall
class OllamaError(Exception):
@@ -175,7 +178,7 @@ class OllamaChatConfig:
## CHECK IF MODEL SUPPORTS TOOL CALLING ##
try:
model_info = litellm.get_model_info(
-model=model, custom_llm_provider="ollama_chat"
+model=model, custom_llm_provider="ollama"
)
if model_info.get("supports_function_calling") is True:
optional_params["tools"] = value
@@ -237,13 +240,30 @@ def get_ollama_response(
function_name = optional_params.pop("function_name", None)
tools = optional_params.pop("tools", None)
new_messages = []
for m in messages:
-if "role" in m and m["role"] == "tool":
-m["role"] = "assistant"
+if isinstance(
+m, BaseModel
+): # avoid message serialization issues - https://github.com/BerriAI/litellm/issues/5319
+m = m.model_dump(exclude_none=True)
if m.get("tool_calls") is not None and isinstance(m["tool_calls"], list):
new_tools: List[OllamaToolCall] = []
for tool in m["tool_calls"]:
typed_tool = ChatCompletionAssistantToolCall(**tool) # type: ignore
if typed_tool["type"] == "function":
ollama_tool_call = OllamaToolCall(
function=OllamaToolCallFunction(
name=typed_tool["function"]["name"],
arguments=json.loads(typed_tool["function"]["arguments"]),
)
)
new_tools.append(ollama_tool_call)
m["tool_calls"] = new_tools
new_messages.append(m)
data = {
"model": model,
-"messages": messages,
+"messages": new_messages,
"options": optional_params,
"stream": stream,
}
@@ -263,7 +283,7 @@ def get_ollama_response(
},
)
if acompletion is True:
-if stream == True:
+if stream is True:
response = ollama_async_streaming(
url=url,
api_key=api_key,
@@ -283,7 +303,7 @@ def get_ollama_response(
function_name=function_name,
)
return response
-elif stream == True:
+elif stream is True:
return ollama_completion_stream(
url=url, api_key=api_key, data=data, logging_obj=logging_obj
)

View file

@@ -84,6 +84,8 @@ class MistralConfig:
- `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.
- `stop` (string or array of strings): Stop generation if this token is detected. Or if one of these tokens is detected when providing an array
- `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.
- `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.
@@ -99,6 +101,7 @@ class MistralConfig:
random_seed: Optional[int] = None
safe_prompt: Optional[bool] = None
response_format: Optional[dict] = None
stop: Optional[Union[str, list]] = None
def __init__(
self,
@@ -110,6 +113,7 @@ class MistralConfig:
random_seed: Optional[int] = None,
safe_prompt: Optional[bool] = None,
response_format: Optional[dict] = None,
stop: Optional[Union[str, list]] = None
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
@@ -143,6 +147,7 @@ class MistralConfig:
"tools",
"tool_choice",
"seed",
"stop",
"response_format",
]
@@ -166,6 +171,8 @@ class MistralConfig:
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "stop":
optional_params["stop"] = value
if param == "tool_choice" and isinstance(value, str):
optional_params["tool_choice"] = self._map_tool_choice(
tool_choice=value

View file

@@ -38,6 +38,18 @@ def prompt_injection_detection_default_pt():
BAD_MESSAGE_ERROR_STR = "Invalid Message "
# used to interweave user messages, to ensure user/assistant alternating
DEFAULT_USER_CONTINUE_MESSAGE = {
"role": "user",
"content": "Please continue.",
} # similar to autogen. Only used if `litellm.modify_params=True`.
# used to interweave assistant messages, to ensure user/assistant alternating
DEFAULT_ASSISTANT_CONTINUE_MESSAGE = {
"role": "assistant",
"content": "Please continue.",
} # similar to autogen. Only used if `litellm.modify_params=True`.
def map_system_message_pt(messages: list) -> list:
"""
@@ -2254,6 +2266,7 @@ def _bedrock_converse_messages_pt(
messages: List,
model: str,
llm_provider: str,
user_continue_message: Optional[dict] = None,
) -> List[BedrockMessageBlock]:
"""
Converts given messages from OpenAI format to Bedrock format
@@ -2264,6 +2277,21 @@ def _bedrock_converse_messages_pt(
contents: List[BedrockMessageBlock] = []
msg_i = 0
# if initial message is assistant message
if messages[0].get("role") is not None and messages[0]["role"] == "assistant":
if user_continue_message is not None:
messages.insert(0, user_continue_message)
elif litellm.modify_params:
messages.insert(0, DEFAULT_USER_CONTINUE_MESSAGE)
# if final message is assistant message
if messages[-1].get("role") is not None and messages[-1]["role"] == "assistant":
if user_continue_message is not None:
messages.append(user_continue_message)
elif litellm.modify_params:
messages.append(DEFAULT_USER_CONTINUE_MESSAGE)
while msg_i < len(messages):
user_content: List[BedrockContentBlock] = []
init_msg_i = msg_i
@@ -2344,6 +2372,7 @@ def _bedrock_converse_messages_pt(
model=model,
llm_provider=llm_provider,
)
return contents

View file

@@ -9,7 +9,7 @@ import types
import uuid
from enum import Enum
from functools import partial
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
@@ -38,12 +38,15 @@ from litellm.types.llms.vertex_ai import (
FunctionDeclaration,
GenerateContentResponseBody,
GenerationConfig,
Instance,
InstanceVideo,
PartType,
RequestBody,
SafetSettingsConfig,
SystemInstructions,
ToolConfig,
Tools,
VertexMultimodalEmbeddingRequest,
)
from litellm.types.utils import GenericStreamingChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@@ -188,9 +191,11 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
elif value["type"] == "text": # type: ignore
optional_params["response_mime_type"] = "text/plain"
if "response_schema" in value: # type: ignore
optional_params["response_mime_type"] = "application/json"
optional_params["response_schema"] = value["response_schema"] # type: ignore
elif value["type"] == "json_schema": # type: ignore
if "json_schema" in value and "schema" in value["json_schema"]: # type: ignore
optional_params["response_mime_type"] = "application/json"
optional_params["response_schema"] = value["json_schema"]["schema"] # type: ignore
if param == "tools" and isinstance(value, list):
gtool_func_declarations = []
@@ -400,9 +405,11 @@ class VertexGeminiConfig:
elif value["type"] == "text":
optional_params["response_mime_type"] = "text/plain"
if "response_schema" in value:
optional_params["response_mime_type"] = "application/json"
optional_params["response_schema"] = value["response_schema"]
elif value["type"] == "json_schema": # type: ignore
if "json_schema" in value and "schema" in value["json_schema"]: # type: ignore
optional_params["response_mime_type"] = "application/json"
optional_params["response_schema"] = value["json_schema"]["schema"] # type: ignore
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
@@ -594,6 +601,10 @@ class VertexLLM(BaseLLM):
self._credentials: Optional[Any] = None
self.project_id: Optional[str] = None
self.async_handler: Optional[AsyncHTTPHandler] = None
self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
"multimodalembedding",
"multimodalembedding@001",
]
def _process_response(
self,
@@ -1537,6 +1548,160 @@ class VertexLLM(BaseLLM):
return model_response
def multimodal_embedding(
self,
model: str,
input: Union[list, str],
print_verbose,
model_response: litellm.EmbeddingResponse,
optional_params: dict,
api_key: Optional[str] = None,
logging_obj=None,
encoding=None,
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
aembedding=False,
timeout=300,
client=None,
):
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
_httpx_timeout = httpx.Timeout(timeout)
_params["timeout"] = _httpx_timeout
else:
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore
else:
sync_handler = client # type: ignore
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
auth_header, _ = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
optional_params = optional_params or {}
request_data = VertexMultimodalEmbeddingRequest()
if "instances" in optional_params:
request_data["instances"] = optional_params["instances"]
elif isinstance(input, list):
request_data["instances"] = input
else:
# construct instances
vertex_request_instance = Instance(**optional_params)
if isinstance(input, str):
vertex_request_instance["text"] = input
request_data["instances"] = [vertex_request_instance]
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
logging_obj.pre_call(
input=[],
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
logging_obj.pre_call(
input=[],
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}",
}
if aembedding is True:
return self.async_multimodal_embedding(
model=model,
api_base=url,
data=request_data,
timeout=timeout,
headers=headers,
client=client,
model_response=model_response,
)
response = sync_handler.post(
url=url,
headers=headers,
data=json.dumps(request_data),
)
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} {response.text}")
_json_response = response.json()
if "predictions" not in _json_response:
raise litellm.InternalServerError(
message=f"embedding response does not contain 'predictions', got {_json_response}",
llm_provider="vertex_ai",
model=model,
)
_predictions = _json_response["predictions"]
model_response.data = _predictions
model_response.model = model
return model_response
async def async_multimodal_embedding(
self,
model: str,
api_base: str,
data: VertexMultimodalEmbeddingRequest,
model_response: litellm.EmbeddingResponse,
timeout: Optional[Union[float, httpx.Timeout]],
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> litellm.EmbeddingResponse:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = AsyncHTTPHandler(**_params) # type: ignore
else:
client = client # type: ignore
try:
response = await client.post(api_base, headers=headers, json=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise VertexAIError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise VertexAIError(status_code=408, message="Timeout error occurred.")
_json_response = response.json()
if "predictions" not in _json_response:
raise litellm.InternalServerError(
message=f"embedding response does not contain 'predictions', got {_json_response}",
llm_provider="vertex_ai",
model=model,
)
_predictions = _json_response["predictions"]
model_response.data = _predictions
model_response.model = model
return model_response
class ModelResponseIterator: class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool): def __init__(self, streaming_response, sync_stream: bool):

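The new `multimodal_embedding` / `async_multimodal_embedding` path above is exercised later in this commit by `test_vertexai_multimodal_embedding` and the proxy example script. A minimal sketch of the SDK-side call, assuming Vertex AI credentials are already configured for litellm (model name and instance shape mirror that test):

```python
import asyncio

import litellm


async def main():
    # Hedged sketch: routes through the new multimodal embedding code path
    # because "multimodalembedding@001" is in SUPPORTED_MULTIMODAL_EMBEDDING_MODELS.
    response = await litellm.aembedding(
        model="vertex_ai/multimodalembedding@001",
        input=[
            {
                "image": {
                    "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
                },
                "text": "this is a unicorn",
            },
        ],
    )
    # response.data holds the raw Vertex predictions (imageEmbedding / textEmbedding)
    print(response.model, response.data)


asyncio.run(main())
```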
View file

@ -943,6 +943,7 @@ def completion(
output_cost_per_token=output_cost_per_token, output_cost_per_token=output_cost_per_token,
cooldown_time=cooldown_time, cooldown_time=cooldown_time,
text_completion=kwargs.get("text_completion"), text_completion=kwargs.get("text_completion"),
user_continue_message=kwargs.get("user_continue_message"),
) )
logging.update_environment_variables( logging.update_environment_variables(
model=model, model=model,
@ -1634,6 +1635,13 @@ def completion(
or "https://api.cohere.ai/v1/generate" or "https://api.cohere.ai/v1/generate"
) )
headers = headers or litellm.headers or {}
if headers is None:
headers = {}
if extra_headers is not None:
headers.update(extra_headers)
model_response = cohere.completion( model_response = cohere.completion(
model=model, model=model,
messages=messages, messages=messages,
@ -1644,6 +1652,7 @@ def completion(
litellm_params=litellm_params, litellm_params=litellm_params,
logger_fn=logger_fn, logger_fn=logger_fn,
encoding=encoding, encoding=encoding,
headers=headers,
api_key=cohere_key, api_key=cohere_key,
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
) )
@ -1674,6 +1683,13 @@ def completion(
or "https://api.cohere.ai/v1/chat" or "https://api.cohere.ai/v1/chat"
) )
headers = headers or litellm.headers or {}
if headers is None:
headers = {}
if extra_headers is not None:
headers.update(extra_headers)
model_response = cohere_chat.completion( model_response = cohere_chat.completion(
model=model, model=model,
messages=messages, messages=messages,
@ -1682,6 +1698,7 @@ def completion(
print_verbose=print_verbose, print_verbose=print_verbose,
optional_params=optional_params, optional_params=optional_params,
litellm_params=litellm_params, litellm_params=litellm_params,
headers=headers,
logger_fn=logger_fn, logger_fn=logger_fn,
encoding=encoding, encoding=encoding,
api_key=cohere_key, api_key=cohere_key,
@ -2288,7 +2305,7 @@ def completion(
model_response=model_response, model_response=model_response,
print_verbose=print_verbose, print_verbose=print_verbose,
optional_params=optional_params, optional_params=optional_params,
litellm_params=litellm_params, litellm_params=litellm_params, # type: ignore
logger_fn=logger_fn, logger_fn=logger_fn,
encoding=encoding, encoding=encoding,
logging_obj=logging, logging_obj=logging,
@ -2464,7 +2481,7 @@ def completion(
model_response=model_response, model_response=model_response,
encoding=encoding, encoding=encoding,
) )
if acompletion is True or optional_params.get("stream", False) == True: if acompletion is True or optional_params.get("stream", False) is True:
return generator return generator
response = generator response = generator
@ -3158,6 +3175,7 @@ def embedding(
encoding_format = kwargs.get("encoding_format", None) encoding_format = kwargs.get("encoding_format", None)
proxy_server_request = kwargs.get("proxy_server_request", None) proxy_server_request = kwargs.get("proxy_server_request", None)
aembedding = kwargs.get("aembedding", None) aembedding = kwargs.get("aembedding", None)
extra_headers = kwargs.get("extra_headers", None)
### CUSTOM MODEL COST ### ### CUSTOM MODEL COST ###
input_cost_per_token = kwargs.get("input_cost_per_token", None) input_cost_per_token = kwargs.get("input_cost_per_token", None)
output_cost_per_token = kwargs.get("output_cost_per_token", None) output_cost_per_token = kwargs.get("output_cost_per_token", None)
@ -3229,6 +3247,7 @@ def embedding(
"model_config", "model_config",
"cooldown_time", "cooldown_time",
"tags", "tags",
"extra_headers",
] ]
default_params = openai_params + litellm_params default_params = openai_params + litellm_params
non_default_params = { non_default_params = {
@ -3292,7 +3311,7 @@ def embedding(
"cooldown_time": cooldown_time, "cooldown_time": cooldown_time,
}, },
) )
if azure == True or custom_llm_provider == "azure": if azure is True or custom_llm_provider == "azure":
# azure configs # azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure" api_type = get_secret("AZURE_API_TYPE") or "azure"
@ -3398,12 +3417,18 @@ def embedding(
or get_secret("CO_API_KEY") or get_secret("CO_API_KEY")
or litellm.api_key or litellm.api_key
) )
if extra_headers is not None and isinstance(extra_headers, dict):
headers = extra_headers
else:
headers = {}
response = cohere.embedding( response = cohere.embedding(
model=model, model=model,
input=input, input=input,
optional_params=optional_params, optional_params=optional_params,
encoding=encoding, encoding=encoding,
api_key=cohere_key, # type: ignore api_key=cohere_key, # type: ignore
headers=headers,
logging_obj=logging, logging_obj=logging,
model_response=EmbeddingResponse(), model_response=EmbeddingResponse(),
aembedding=aembedding, aembedding=aembedding,
@ -3477,19 +3502,39 @@ def embedding(
or get_secret("VERTEX_CREDENTIALS") or get_secret("VERTEX_CREDENTIALS")
) )
response = vertex_ai.embedding( if (
model=model, "image" in optional_params
input=input, or "video" in optional_params
encoding=encoding, or model in vertex_chat_completion.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
logging_obj=logging, ):
optional_params=optional_params, # multimodal embedding is supported on vertex httpx
model_response=EmbeddingResponse(), response = vertex_chat_completion.multimodal_embedding(
vertex_project=vertex_ai_project, model=model,
vertex_location=vertex_ai_location, input=input,
vertex_credentials=vertex_credentials, encoding=encoding,
aembedding=aembedding, logging_obj=logging,
print_verbose=print_verbose, optional_params=optional_params,
) model_response=EmbeddingResponse(),
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
aembedding=aembedding,
print_verbose=print_verbose,
)
else:
response = vertex_ai.embedding(
model=model,
input=input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
aembedding=aembedding,
print_verbose=print_verbose,
)
elif custom_llm_provider == "oobabooga": elif custom_llm_provider == "oobabooga":
response = oobabooga.embedding( response = oobabooga.embedding(
model=model, model=model,

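With the header plumbing added above, per-request `extra_headers` now reach the Cohere completion, chat, and embedding handlers. A minimal sketch, assuming `COHERE_API_KEY` is set in the environment (the header name mirrors the updated `test_completion_cohere`):

```python
import litellm

# Hedged sketch: extra_headers are merged into the outbound Cohere request headers.
response = litellm.completion(
    model="command-r",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    extra_headers={"Helicone-Property-Locale": "ko"},  # any custom header works here
)
print(response.choices[0].message.content)
```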
View file

@ -3,9 +3,11 @@ model_list:
litellm_params: litellm_params:
model: "*" model: "*"
litellm_settings: litellm_settings:
cache: True success_callback: ["s3"]
cache_params: cache: true
type: redis s3_callback_params:
redis_startup_nodes: [{"host": "127.0.0.1", "port": "7001"}] s3_bucket_name: mytestbucketlitellm # AWS Bucket Name for S3
s3_region_name: us-west-2 # AWS Region Name for S3
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # use os.environ/<variable name> to pass environment variables. This is the AWS Access Key ID for S3
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3

View file

@ -21,6 +21,13 @@ else:
Span = Any Span = Any
class LiteLLMTeamRoles(enum.Enum):
# team admin
TEAM_ADMIN = "admin"
# team member
TEAM_MEMBER = "user"
class LitellmUserRoles(str, enum.Enum): class LitellmUserRoles(str, enum.Enum):
""" """
Admin Roles: Admin Roles:
@ -335,6 +342,11 @@ class LiteLLMRoutes(enum.Enum):
+ sso_only_routes + sso_only_routes
) )
self_managed_routes: List = [
"/team/member_add",
"/team/member_delete",
] # routes that manage their own allowed/disallowed logic
# class LiteLLMAllowedRoutes(LiteLLMBase): # class LiteLLMAllowedRoutes(LiteLLMBase):
# """ # """
@ -1308,6 +1320,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
soft_budget: Optional[float] = None soft_budget: Optional[float] = None
team_model_aliases: Optional[Dict] = None team_model_aliases: Optional[Dict] = None
team_member_spend: Optional[float] = None team_member_spend: Optional[float] = None
team_member: Optional[Member] = None
team_metadata: Optional[Dict] = None team_metadata: Optional[Dict] = None
# End User Params # End User Params

View file

@ -975,8 +975,6 @@ async def user_api_key_auth(
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
if is_llm_api_route(route=route): if is_llm_api_route(route=route):
pass pass
elif is_llm_api_route(route=request["route"].name):
pass
elif ( elif (
route in LiteLLMRoutes.info_routes.value route in LiteLLMRoutes.info_routes.value
): # check if user allowed to call an info route ): # check if user allowed to call an info route
@ -1046,11 +1044,16 @@ async def user_api_key_auth(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}", detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
) )
elif ( elif (
_user_role == LitellmUserRoles.INTERNAL_USER.value _user_role == LitellmUserRoles.INTERNAL_USER.value
and route in LiteLLMRoutes.internal_user_routes.value and route in LiteLLMRoutes.internal_user_routes.value
): ):
pass pass
elif (
route in LiteLLMRoutes.self_managed_routes.value
): # routes that manage their own allowed/disallowed logic
pass
else: else:
user_role = "unknown" user_role = "unknown"
user_id = "unknown" user_id = "unknown"

View file

@ -120,6 +120,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
max_parallel_requests = user_api_key_dict.max_parallel_requests max_parallel_requests = user_api_key_dict.max_parallel_requests
if max_parallel_requests is None: if max_parallel_requests is None:
max_parallel_requests = sys.maxsize max_parallel_requests = sys.maxsize
if data is None:
data = {}
global_max_parallel_requests = data.get("metadata", {}).get( global_max_parallel_requests = data.get("metadata", {}).get(
"global_max_parallel_requests", None "global_max_parallel_requests", None
) )

View file

@ -119,6 +119,7 @@ async def new_user(
http_request=Request( http_request=Request(
scope={"type": "http", "path": "/user/new"}, scope={"type": "http", "path": "/user/new"},
), ),
user_api_key_dict=user_api_key_dict,
) )
if data.send_invite_email is True: if data.send_invite_email is True:
@ -732,7 +733,7 @@ async def delete_user(
delete user and associated user keys delete user and associated user keys
``` ```
curl --location 'http://0.0.0.0:8000/team/delete' \ curl --location 'http://0.0.0.0:8000/user/delete' \
--header 'Authorization: Bearer sk-1234' \ --header 'Authorization: Bearer sk-1234' \

View file

@ -849,7 +849,7 @@ async def generate_key_helper_fn(
} }
if ( if (
litellm.get_secret("DISABLE_KEY_NAME", False) == True litellm.get_secret("DISABLE_KEY_NAME", False) is True
): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much) ): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much)
pass pass
else: else:

View file

@ -30,7 +30,7 @@ from litellm.proxy._types import (
UpdateTeamRequest, UpdateTeamRequest,
UserAPIKeyAuth, UserAPIKeyAuth,
) )
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin, user_api_key_auth
from litellm.proxy.management_helpers.utils import ( from litellm.proxy.management_helpers.utils import (
add_new_member, add_new_member,
management_endpoint_wrapper, management_endpoint_wrapper,
@ -39,6 +39,16 @@ from litellm.proxy.management_helpers.utils import (
router = APIRouter() router = APIRouter()
def _is_user_team_admin(
user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
) -> bool:
for member in team_obj.members_with_roles:
if member.user_id is not None and member.user_id == user_api_key_dict.user_id:
return True
return False
#### TEAM MANAGEMENT #### #### TEAM MANAGEMENT ####
@router.post( @router.post(
"/team/new", "/team/new",
@ -417,6 +427,7 @@ async def team_member_add(
If user doesn't exist, new user row will also be added to User Table If user doesn't exist, new user row will also be added to User Table
Only the proxy admin or an admin of the team is allowed to access this endpoint.
``` ```
curl -X POST 'http://0.0.0.0:4000/team/member_add' \ curl -X POST 'http://0.0.0.0:4000/team/member_add' \
@ -465,6 +476,25 @@ async def team_member_add(
complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump()) complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump())
## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
if (
hasattr(user_api_key_dict, "user_role")
and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
and not _is_user_team_admin(
user_api_key_dict=user_api_key_dict, team_obj=complete_team_data
)
):
raise HTTPException(
status_code=403,
detail={
"error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
"/team/member_add",
complete_team_data.team_id,
)
},
)
if isinstance(data.member, Member): if isinstance(data.member, Member):
# add to team db # add to team db
new_member = data.member new_member = data.member
@ -569,6 +599,23 @@ async def team_member_delete(
) )
existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump()) existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())
## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
if (
user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
and not _is_user_team_admin(
user_api_key_dict=user_api_key_dict, team_obj=existing_team_row
)
):
raise HTTPException(
status_code=403,
detail={
"error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
"/team/member_delete", existing_team_row.team_id
)
},
)
## DELETE MEMBER FROM TEAM ## DELETE MEMBER FROM TEAM
new_team_members: List[Member] = [] new_team_members: List[Member] = []
for m in existing_team_row.members_with_roles: for m in existing_team_row.members_with_roles:

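With the `_is_user_team_admin` check above, `/team/member_add` and `/team/member_delete` now accept keys belonging to a team admin (as determined by the team's `members_with_roles`), not only the proxy admin key. A hedged sketch of calling the endpoint against a locally running proxy (key, team id, and user id are placeholders):

```python
import requests

# Hedged sketch: a 403 is returned if the caller is neither proxy admin nor team admin.
resp = requests.post(
    "http://0.0.0.0:4000/team/member_add",
    headers={"Authorization": "Bearer sk-team-admin-key"},  # hypothetical team-admin key
    json={
        "team_id": "litellm-test-client-id-new",
        "member": [{"role": "user", "user_id": "new-member-id"}],
    },
)
print(resp.status_code, resp.json())
```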
View file

@ -266,7 +266,7 @@ def management_endpoint_wrapper(func):
) )
_http_request: Request = kwargs.get("http_request") _http_request: Request = kwargs.get("http_request")
parent_otel_span = user_api_key_dict.parent_otel_span parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None)
if parent_otel_span is not None: if parent_otel_span is not None:
from litellm.proxy.proxy_server import open_telemetry_logger from litellm.proxy.proxy_server import open_telemetry_logger
@ -310,7 +310,7 @@ def management_endpoint_wrapper(func):
user_api_key_dict: UserAPIKeyAuth = ( user_api_key_dict: UserAPIKeyAuth = (
kwargs.get("user_api_key_dict") or UserAPIKeyAuth() kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
) )
parent_otel_span = user_api_key_dict.parent_otel_span parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None)
if parent_otel_span is not None: if parent_otel_span is not None:
from litellm.proxy.proxy_server import open_telemetry_logger from litellm.proxy.proxy_server import open_telemetry_logger

View file

@ -301,16 +301,19 @@ async def pass_through_request(
request=request, headers=headers, forward_headers=forward_headers request=request, headers=headers, forward_headers=forward_headers
) )
_parsed_body = None
if custom_body: if custom_body:
_parsed_body = custom_body _parsed_body = custom_body
else: else:
request_body = await request.body() request_body = await request.body()
body_str = request_body.decode() if request_body == b"" or request_body is None:
try: _parsed_body = None
_parsed_body = ast.literal_eval(body_str) else:
except Exception: body_str = request_body.decode()
_parsed_body = json.loads(body_str) try:
_parsed_body = ast.literal_eval(body_str)
except Exception:
_parsed_body = json.loads(body_str)
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format( "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
url, headers, _parsed_body url, headers, _parsed_body
@ -320,7 +323,7 @@ async def pass_through_request(
### CALL HOOKS ### - modify incoming data / reject request before calling the model ### CALL HOOKS ### - modify incoming data / reject request before calling the model
_parsed_body = await proxy_logging_obj.pre_call_hook( _parsed_body = await proxy_logging_obj.pre_call_hook(
user_api_key_dict=user_api_key_dict, user_api_key_dict=user_api_key_dict,
data=_parsed_body, data=_parsed_body or {},
call_type="pass_through_endpoint", call_type="pass_through_endpoint",
) )
@ -360,15 +363,24 @@ async def pass_through_request(
# combine url with query params for logging # combine url with query params for logging
requested_query_params = query_params or request.query_params.__dict__ requested_query_params: Optional[dict] = (
requested_query_params_str = "&".join( query_params or request.query_params.__dict__
f"{k}={v}" for k, v in requested_query_params.items()
) )
if requested_query_params == request.query_params.__dict__:
requested_query_params = None
if "?" in str(url): requested_query_params_str = None
logging_url = str(url) + "&" + requested_query_params_str if requested_query_params:
else: requested_query_params_str = "&".join(
logging_url = str(url) + "?" + requested_query_params_str f"{k}={v}" for k, v in requested_query_params.items()
)
logging_url = str(url)
if requested_query_params_str:
if "?" in str(url):
logging_url = str(url) + "&" + requested_query_params_str
else:
logging_url = str(url) + "?" + requested_query_params_str
logging_obj.pre_call( logging_obj.pre_call(
input=[{"role": "user", "content": "no-message-pass-through-endpoint"}], input=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
@ -409,6 +421,14 @@ async def pass_through_request(
status_code=response.status_code, status_code=response.status_code,
) )
verbose_proxy_logger.debug("request method: {}".format(request.method))
verbose_proxy_logger.debug("request url: {}".format(url))
verbose_proxy_logger.debug("request headers: {}".format(headers))
verbose_proxy_logger.debug(
"requested_query_params={}".format(requested_query_params)
)
verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
response = await async_client.request( response = await async_client.request(
method=request.method, method=request.method,
url=url, url=url,

View file

@ -1,20 +1,18 @@
model_list: model_list:
- model_name: fake-openai-endpoint - model_name: gpt-4
litellm_params: litellm_params:
model: openai/fake model: openai/fake
api_key: fake-key api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/ api_base: https://exampleopenaiendpoint-production.up.railway.app/
- model_name: openai-embedding
litellm_params:
model: openai/text-embedding-3-small
api_key: os.environ/OPENAI_API_KEY
litellm_settings: guardrails:
set_verbose: True - guardrail_name: "lakera-pre-guard"
cache: True # set cache responses to True, litellm defaults to using a redis cache litellm_params:
cache_params: guardrail: lakera # supported values: "aporia", "bedrock", "lakera"
type: qdrant-semantic mode: "during_call"
qdrant_semantic_cache_embedding_model: openai-embedding api_key: os.environ/LAKERA_API_KEY
qdrant_collection_name: test_collection api_base: os.environ/LAKERA_API_BASE
qdrant_quantization_config: binary category_thresholds:
similarity_threshold: 0.8 # similarity threshold for semantic cache prompt_injection: 0.1
jailbreak: 0.1

View file

@ -1498,6 +1498,11 @@ class ProxyConfig:
litellm.get_secret(secret_name=key, default_value=value) litellm.get_secret(secret_name=key, default_value=value)
) )
# check if litellm_license in general_settings
if "LITELLM_LICENSE" in environment_variables:
_license_check.license_str = os.getenv("LITELLM_LICENSE", None)
premium_user = _license_check.is_premium()
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..) ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
litellm_settings = config.get("litellm_settings", None) litellm_settings = config.get("litellm_settings", None)
if litellm_settings is None: if litellm_settings is None:
@ -1878,6 +1883,11 @@ class ProxyConfig:
+ CommonProxyErrors.not_premium_user.value + CommonProxyErrors.not_premium_user.value
) )
# check if litellm_license in general_settings
if "litellm_license" in general_settings:
_license_check.license_str = general_settings["litellm_license"]
premium_user = _license_check.is_premium()
router_params: dict = { router_params: dict = {
"cache_responses": litellm.cache "cache_responses": litellm.cache
!= None, # cache if user passed in cache values != None, # cache if user passed in cache values
@ -2784,26 +2794,29 @@ async def startup_event():
await custom_db_client.connect() await custom_db_client.connect()
if prisma_client is not None and master_key is not None: if prisma_client is not None and master_key is not None:
# add master key to db
if os.getenv("PROXY_ADMIN_ID", None) is not None: if os.getenv("PROXY_ADMIN_ID", None) is not None:
litellm_proxy_admin_name = os.getenv( litellm_proxy_admin_name = os.getenv(
"PROXY_ADMIN_ID", litellm_proxy_admin_name "PROXY_ADMIN_ID", litellm_proxy_admin_name
) )
asyncio.create_task( if general_settings.get("disable_adding_master_key_hash_to_db") is True:
generate_key_helper_fn( verbose_proxy_logger.info("Skipping writing master key hash to db")
request_type="user", else:
duration=None, # add master key to db
models=[], asyncio.create_task(
aliases={}, generate_key_helper_fn(
config={}, request_type="user",
spend=0, duration=None,
token=master_key, models=[],
user_id=litellm_proxy_admin_name, aliases={},
user_role=LitellmUserRoles.PROXY_ADMIN, config={},
query_type="update_data", spend=0,
update_key_values={"user_role": LitellmUserRoles.PROXY_ADMIN}, token=master_key,
user_id=litellm_proxy_admin_name,
user_role=LitellmUserRoles.PROXY_ADMIN,
query_type="update_data",
update_key_values={"user_role": LitellmUserRoles.PROXY_ADMIN},
)
) )
)
if prisma_client is not None and litellm.max_budget > 0: if prisma_client is not None and litellm.max_budget > 0:
if litellm.budget_duration is None: if litellm.budget_duration is None:
@ -3011,6 +3024,29 @@ async def chat_completion(
model: Optional[str] = None, model: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
): ):
"""
Follows the exact same API spec as `OpenAI's Chat API https://platform.openai.com/docs/api-reference/chat`
```bash
curl -X POST http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": "Hello!"
}
]
}'
```
"""
global general_settings, user_debug, proxy_logging_obj, llm_model_list global general_settings, user_debug, proxy_logging_obj, llm_model_list
data = {} data = {}
@ -3268,6 +3304,24 @@ async def completion(
model: Optional[str] = None, model: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
): ):
"""
Follows the exact same API spec as `OpenAI's Completions API https://platform.openai.com/docs/api-reference/completions`
```bash
curl -X POST http://localhost:4000/v1/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "gpt-3.5-turbo-instruct",
"prompt": "Once upon a time",
"max_tokens": 50,
"temperature": 0.7
}'
```
"""
global user_temperature, user_request_timeout, user_max_tokens, user_api_base global user_temperature, user_request_timeout, user_max_tokens, user_api_base
data = {} data = {}
try: try:
@ -3474,6 +3528,23 @@ async def embeddings(
model: Optional[str] = None, model: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
): ):
"""
Follows the exact same API spec as `OpenAI's Embeddings API https://platform.openai.com/docs/api-reference/embeddings`
```bash
curl -X POST http://localhost:4000/v1/embeddings \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "text-embedding-ada-002",
"input": "The quick brown fox jumps over the lazy dog"
}'
```
"""
global proxy_logging_obj global proxy_logging_obj
data: Any = {} data: Any = {}
try: try:
@ -3481,6 +3552,11 @@ async def embeddings(
body = await request.body() body = await request.body()
data = orjson.loads(body) data = orjson.loads(body)
verbose_proxy_logger.debug(
"Request received by LiteLLM:\n%s",
json.dumps(data, indent=4),
)
# Include original request and headers in the data # Include original request and headers in the data
data = await add_litellm_data_to_request( data = await add_litellm_data_to_request(
data=data, data=data,

View file

@ -1,4 +1,6 @@
import json import json
import os
import secrets
import traceback import traceback
from typing import Optional from typing import Optional
@ -8,12 +10,30 @@ from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
from litellm.proxy.utils import hash_token from litellm.proxy.utils import hash_token
def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
if _master_key is None:
return False
## string comparison
is_master_key = secrets.compare_digest(api_key, _master_key)
if is_master_key:
return True
## hash comparison
is_master_key = secrets.compare_digest(api_key, hash_token(_master_key))
if is_master_key:
return True
return False
def get_logging_payload( def get_logging_payload(
kwargs, response_obj, start_time, end_time, end_user_id: Optional[str] kwargs, response_obj, start_time, end_time, end_user_id: Optional[str]
) -> SpendLogsPayload: ) -> SpendLogsPayload:
from pydantic import Json from pydantic import Json
from litellm.proxy._types import LiteLLM_SpendLogs from litellm.proxy._types import LiteLLM_SpendLogs
from litellm.proxy.proxy_server import general_settings, master_key
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"SpendTable: get_logging_payload - kwargs: {kwargs}\n\n" f"SpendTable: get_logging_payload - kwargs: {kwargs}\n\n"
@ -36,9 +56,15 @@ def get_logging_payload(
usage = dict(usage) usage = dict(usage)
id = response_obj.get("id", kwargs.get("litellm_call_id")) id = response_obj.get("id", kwargs.get("litellm_call_id"))
api_key = metadata.get("user_api_key", "") api_key = metadata.get("user_api_key", "")
if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"): if api_key is not None and isinstance(api_key, str):
# hash the api_key if api_key.startswith("sk-"):
api_key = hash_token(api_key) # hash the api_key
api_key = hash_token(api_key)
if (
_is_master_key(api_key=api_key, _master_key=master_key)
and general_settings.get("disable_adding_master_key_hash_to_db") is True
):
api_key = "litellm_proxy_master_key" # use a known alias, if the user disabled storing master key in db
_model_id = metadata.get("model_info", {}).get("id", "") _model_id = metadata.get("model_info", {}).get("id", "")
_model_group = metadata.get("model_group", "") _model_group = metadata.get("model_group", "")

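A small sketch of the `_is_master_key` helper's behaviour (the import path for this module is an assumption; `hash_token` comes from `litellm.proxy.utils` as imported above):

```python
from litellm.proxy.utils import hash_token

# Assumed module path for the helper defined in this file.
from litellm.proxy.spend_tracking.spend_tracking_utils import _is_master_key

master_key = "sk-1234"
assert _is_master_key(api_key=master_key, _master_key=master_key)              # raw match
assert _is_master_key(api_key=hash_token(master_key), _master_key=master_key)  # hashed match
assert not _is_master_key(api_key="sk-other", _master_key=master_key)
```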
View file

@ -0,0 +1,21 @@
import openai
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
# request sent to model set on litellm proxy, `litellm --model`
response = client.embeddings.create(
model="multimodalembedding@001",
input=[],
extra_body={
"instances": [
{
"image": {
"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
},
"text": "this is a unicorn",
},
],
},
)
print(response)

View file

@ -0,0 +1,58 @@
import vertexai
from google.auth.credentials import Credentials
from vertexai.vision_models import (
Image,
MultiModalEmbeddingModel,
Video,
VideoSegmentConfig,
)
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
import datetime
class CredentialsWrapper(Credentials):
def __init__(self, token=None):
super().__init__()
self.token = token
self.expiry = None # or set to a future date if needed
def refresh(self, request):
pass
def apply(self, headers, token=None):
headers["Authorization"] = f"Bearer {self.token}"
@property
def expired(self):
return False # Always consider the token as non-expired
@property
def valid(self):
return True # Always consider the credentials as valid
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
credentials=credentials,
api_transport="rest",
)
model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding")
image = Image.load_from_file(
"gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
)
embeddings = model.get_embeddings(
image=image,
contextual_text="Colosseum",
dimension=1408,
)
print(f"Image Embedding: {embeddings.image_embedding}")
print(f"Text Embedding: {embeddings.text_embedding}")

View file

@ -44,6 +44,7 @@ from litellm.proxy._types import (
DynamoDBArgs, DynamoDBArgs,
LiteLLM_VerificationTokenView, LiteLLM_VerificationTokenView,
LitellmUserRoles, LitellmUserRoles,
Member,
ResetTeamBudgetRequest, ResetTeamBudgetRequest,
SpendLogsMetadata, SpendLogsMetadata,
SpendLogsPayload, SpendLogsPayload,
@ -1395,6 +1396,7 @@ class PrismaClient:
t.blocked AS team_blocked, t.blocked AS team_blocked,
t.team_alias AS team_alias, t.team_alias AS team_alias,
t.metadata AS team_metadata, t.metadata AS team_metadata,
t.members_with_roles AS team_members_with_roles,
tm.spend AS team_member_spend, tm.spend AS team_member_spend,
m.aliases as team_model_aliases m.aliases as team_model_aliases
FROM "LiteLLM_VerificationToken" AS v FROM "LiteLLM_VerificationToken" AS v
@ -1412,6 +1414,33 @@ class PrismaClient:
response["team_models"] = [] response["team_models"] = []
if response["team_blocked"] is None: if response["team_blocked"] is None:
response["team_blocked"] = False response["team_blocked"] = False
team_member: Optional[Member] = None
if (
response["team_members_with_roles"] is not None
and response["user_id"] is not None
):
## find the team member corresponding to user id
"""
[
{
"role": "admin",
"user_id": "default_user_id",
"user_email": null
},
{
"role": "user",
"user_id": null,
"user_email": "test@email.com"
}
]
"""
for tm in response["team_members_with_roles"]:
if tm.get("user_id") is not None and response[
"user_id"
] == tm.get("user_id"):
team_member = Member(**tm)
response["team_member"] = team_member
response = LiteLLM_VerificationTokenView( response = LiteLLM_VerificationTokenView(
**response, last_refreshed_at=time.time() **response, last_refreshed_at=time.time()
) )

View file

@ -25,6 +25,9 @@ from litellm.batches.main import FileObject
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
from litellm.proxy._types import * from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
create_pass_through_route,
)
router = APIRouter() router = APIRouter()
default_vertex_config = None default_vertex_config = None
@ -70,10 +73,17 @@ def exception_handler(e: Exception):
) )
async def execute_post_vertex_ai_request( @router.api_route(
"/vertex-ai/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
)
async def vertex_proxy_route(
endpoint: str,
request: Request, request: Request,
route: str, fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
): ):
encoded_endpoint = httpx.URL(endpoint).path
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
if default_vertex_config is None: if default_vertex_config is None:
@ -83,250 +93,52 @@ async def execute_post_vertex_ai_request(
vertex_project = default_vertex_config.get("vertex_project", None) vertex_project = default_vertex_config.get("vertex_project", None)
vertex_location = default_vertex_config.get("vertex_location", None) vertex_location = default_vertex_config.get("vertex_location", None)
vertex_credentials = default_vertex_config.get("vertex_credentials", None) vertex_credentials = default_vertex_config.get("vertex_credentials", None)
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
request_data_json = {} auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url(
body = await request.body() model="",
body_str = body.decode() gemini_api_key=None,
if len(body_str) > 0: vertex_credentials=vertex_credentials,
try: vertex_project=vertex_project,
request_data_json = ast.literal_eval(body_str) vertex_location=vertex_location,
except: stream=False,
request_data_json = json.loads(body_str) custom_llm_provider="vertex_ai_beta",
api_base="",
verbose_proxy_logger.debug(
"Request received by LiteLLM:\n{}".format(
json.dumps(request_data_json, indent=4)
),
) )
response = ( headers = {
await vertex_fine_tuning_apis_instance.pass_through_vertex_ai_POST_request( "Authorization": f"Bearer {auth_header}",
request_data=request_data_json, }
vertex_project=vertex_project,
vertex_location=vertex_location, request_route = encoded_endpoint
vertex_credentials=vertex_credentials, verbose_proxy_logger.debug("request_route %s", request_route)
request_route=route,
) # Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint)
verbose_proxy_logger.debug("updated url %s", updated_url)
## check for streaming
is_streaming_request = False
if "stream" in str(updated_url):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(updated_url),
custom_headers=headers,
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request,
) )
return response return received_value
@router.post(
"/vertex-ai/publishers/google/models/{model_id:path}:generateContent",
dependencies=[Depends(user_api_key_auth)],
tags=["Vertex AI endpoints"],
)
async def vertex_generate_content(
request: Request,
fastapi_response: Response,
model_id: str,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
this is a pass through endpoint for the Vertex AI API. /generateContent endpoint
Example Curl:
```
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
```
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#rest
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
"""
try:
response = await execute_post_vertex_ai_request(
request=request,
route=f"/publishers/google/models/{model_id}:generateContent",
)
return response
except Exception as e:
raise exception_handler(e) from e
@router.post(
"/vertex-ai/publishers/google/models/{model_id:path}:predict",
dependencies=[Depends(user_api_key_auth)],
tags=["Vertex AI endpoints"],
)
async def vertex_predict_endpoint(
request: Request,
fastapi_response: Response,
model_id: str,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
this is a pass through endpoint for the Vertex AI API. /predict endpoint
Use this for:
- Embeddings API - Text Embedding, Multi Modal Embedding
- Imagen API
- Code Completion API
Example Curl:
```
curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:predict \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{"instances":[{"content": "gm"}]}'
```
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#generative-ai-get-text-embedding-drest
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
"""
try:
response = await execute_post_vertex_ai_request(
request=request,
route=f"/publishers/google/models/{model_id}:predict",
)
return response
except Exception as e:
raise exception_handler(e) from e
@router.post(
"/vertex-ai/publishers/google/models/{model_id:path}:countTokens",
dependencies=[Depends(user_api_key_auth)],
tags=["Vertex AI endpoints"],
)
async def vertex_countTokens_endpoint(
request: Request,
fastapi_response: Response,
model_id: str,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
this is a pass through endpoint for the Vertex AI API. /countTokens endpoint
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/count-tokens#curl
Example Curl:
```
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
```
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
"""
try:
response = await execute_post_vertex_ai_request(
request=request,
route=f"/publishers/google/models/{model_id}:countTokens",
)
return response
except Exception as e:
raise exception_handler(e) from e
@router.post(
"/vertex-ai/batchPredictionJobs",
dependencies=[Depends(user_api_key_auth)],
tags=["Vertex AI endpoints"],
)
async def vertex_create_batch_prediction_job(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
this is a pass through endpoint for the Vertex AI API. /batchPredictionJobs endpoint
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/batch-prediction-api#syntax
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
"""
try:
response = await execute_post_vertex_ai_request(
request=request,
route="/batchPredictionJobs",
)
return response
except Exception as e:
raise exception_handler(e) from e
@router.post(
"/vertex-ai/tuningJobs",
dependencies=[Depends(user_api_key_auth)],
tags=["Vertex AI endpoints"],
)
async def vertex_create_fine_tuning_job(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
this is a pass through endpoint for the Vertex AI API. /tuningJobs endpoint
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
"""
try:
response = await execute_post_vertex_ai_request(
request=request,
route="/tuningJobs",
)
return response
except Exception as e:
raise exception_handler(e) from e
@router.post(
"/vertex-ai/tuningJobs/{job_id:path}:cancel",
dependencies=[Depends(user_api_key_auth)],
tags=["Vertex AI endpoints"],
)
async def vertex_cancel_fine_tuning_job(
request: Request,
job_id: str,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
this is a pass through endpoint for the Vertex AI API. tuningJobs/{job_id:path}:cancel
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#cancel_a_tuning_job
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
"""
try:
response = await execute_post_vertex_ai_request(
request=request,
route=f"/tuningJobs/{job_id}:cancel",
)
return response
except Exception as e:
raise exception_handler(e) from e
@router.post(
"/vertex-ai/cachedContents",
dependencies=[Depends(user_api_key_auth)],
tags=["Vertex AI endpoints"],
)
async def vertex_create_add_cached_content(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
this is a pass through endpoint for the Vertex AI API. /cachedContents endpoint
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-create#create-context-cache-sample-drest
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
"""
try:
response = await execute_post_vertex_ai_request(
request=request,
route="/cachedContents",
)
return response
except Exception as e:
raise exception_handler(e) from e

View file

@ -15,7 +15,7 @@ import asyncio
import json import json
import os import os
import tempfile import tempfile
from unittest.mock import MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@ -501,6 +501,8 @@ async def test_async_vertexai_streaming_response():
assert len(complete_response) > 0 assert len(complete_response) > 0
except litellm.RateLimitError as e: except litellm.RateLimitError as e:
pass pass
except litellm.APIConnectionError:
pass
except litellm.Timeout as e: except litellm.Timeout as e:
pass pass
except litellm.InternalServerError as e: except litellm.InternalServerError as e:
@ -955,6 +957,8 @@ async def test_partner_models_httpx(model, sync_mode):
assert isinstance(response._hidden_params["response_cost"], float) assert isinstance(response._hidden_params["response_cost"], float)
except litellm.RateLimitError as e: except litellm.RateLimitError as e:
pass pass
except litellm.InternalServerError as e:
pass
except Exception as e: except Exception as e:
if "429 Quota exceeded" in str(e): if "429 Quota exceeded" in str(e):
pass pass
@ -1004,7 +1008,9 @@ async def test_partner_models_httpx_streaming(model, sync_mode):
idx += 1 idx += 1
print(f"response: {response}") print(f"response: {response}")
except litellm.RateLimitError: except litellm.RateLimitError as e:
pass
except litellm.InternalServerError as e:
pass pass
except Exception as e: except Exception as e:
if "429 Quota exceeded" in str(e): if "429 Quota exceeded" in str(e):
@ -1558,6 +1564,16 @@ async def test_gemini_pro_json_schema_args_sent_httpx_openai_schema(
"response_schema" "response_schema"
in mock_call.call_args.kwargs["json"]["generationConfig"] in mock_call.call_args.kwargs["json"]["generationConfig"]
) )
assert (
"response_mime_type"
in mock_call.call_args.kwargs["json"]["generationConfig"]
)
assert (
mock_call.call_args.kwargs["json"]["generationConfig"][
"response_mime_type"
]
== "application/json"
)
else: else:
assert ( assert (
"response_schema" "response_schema"
@ -1826,6 +1842,71 @@ def test_vertexai_embedding():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_vertexai_multimodal_embedding():
load_vertex_ai_credentials()
mock_response = AsyncMock()
def return_val():
return {
"predictions": [
{
"imageEmbedding": [0.1, 0.2, 0.3], # Simplified example
"textEmbedding": [0.4, 0.5, 0.6], # Simplified example
}
]
}
mock_response.json = return_val
mock_response.status_code = 200
expected_payload = {
"instances": [
{
"image": {
"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
},
"text": "this is a unicorn",
}
]
}
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
return_value=mock_response,
) as mock_post:
# Act: Call the litellm.aembedding function
response = await litellm.aembedding(
model="vertex_ai/multimodalembedding@001",
input=[
{
"image": {
"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
},
"text": "this is a unicorn",
},
],
)
# Assert
mock_post.assert_called_once()
_, kwargs = mock_post.call_args
args_to_vertexai = kwargs["json"]
print("args to vertex ai call:", args_to_vertexai)
assert args_to_vertexai == expected_payload
assert response.model == "multimodalembedding@001"
assert len(response.data) == 1
response_data = response.data[0]
assert "imageEmbedding" in response_data
assert "textEmbedding" in response_data
# Optional: Print for debugging
print("Arguments passed to Vertex AI:", args_to_vertexai)
print("Response:", response)
@pytest.mark.skip( @pytest.mark.skip(
reason="new test - works locally running into vertex version issues on ci/cd" reason="new test - works locally running into vertex version issues on ci/cd"
) )

View file

@ -738,8 +738,9 @@ def test_bedrock_system_prompt(system, model):
"temperature": 0.3, "temperature": 0.3,
"messages": [ "messages": [
{"role": "system", "content": system}, {"role": "system", "content": system},
{"role": "user", "content": "hey, how's it going?"}, {"role": "assistant", "content": "hey, how's it going?"},
], ],
"user_continue_message": {"role": "user", "content": "Be a good bot!"},
} }
response: ModelResponse = completion( response: ModelResponse = completion(
model="bedrock/{}".format(model), model="bedrock/{}".format(model),

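The `user_continue_message` parameter threaded through `completion()` earlier in this commit is what this updated Bedrock test exercises: when the conversation ends on an assistant turn, the supplied user message is appended so providers that require a trailing user turn still accept the request. A hedged sketch (model id is an example):

```python
import litellm

# Hedged sketch: mirrors the arguments used in test_bedrock_system_prompt above.
response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",  # example Bedrock model id
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "assistant", "content": "hey, how's it going?"},
    ],
    user_continue_message={"role": "user", "content": "Be a good bot!"},
    temperature=0.3,
    max_tokens=10,
)
print(response.choices[0].message.content)
```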
View file

@ -3653,6 +3653,7 @@ def test_completion_cohere():
response = completion( response = completion(
model="command-r", model="command-r",
messages=messages, messages=messages,
extra_headers={"Helicone-Property-Locale": "ko"},
) )
print(response) print(response)
except Exception as e: except Exception as e:

View file

@ -1252,3 +1252,48 @@ def test_standard_logging_payload(model, turn_off_message_logging):
] ]
if turn_off_message_logging: if turn_off_message_logging:
assert "redacted-by-litellm" == slobject["messages"][0]["content"] assert "redacted-by-litellm" == slobject["messages"][0]["content"]
@pytest.mark.skip(reason="Works locally. Flaky on ci/cd")
def test_aaastandard_logging_payload_cache_hit():
from litellm.types.utils import StandardLoggingPayload
# sync completion
litellm.cache = Cache()
_ = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
caching=True,
)
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
litellm.success_callback = []
with patch.object(
customHandler, "log_success_event", new=MagicMock()
) as mock_client:
_ = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
caching=True,
)
time.sleep(2)
mock_client.assert_called_once()
assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
is not None
)
standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
"kwargs"
]["standard_logging_object"]
assert standard_logging_object["cache_hit"] is True
assert standard_logging_object["response_cost"] == 0
assert standard_logging_object["saved_cache_cost"] > 0

View file

@ -54,6 +54,7 @@ def get_current_weather(location, unit="fahrenheit"):
) )
def test_parallel_function_call(model): def test_parallel_function_call(model):
try: try:
litellm.set_verbose = True
# Step 1: send the conversation and available functions to the model # Step 1: send the conversation and available functions to the model
messages = [ messages = [
{ {
@ -141,6 +142,8 @@ def test_parallel_function_call(model):
drop_params=True, drop_params=True,
) # get a new response from the model where it can see the function response ) # get a new response from the model where it can see the function response
print("second response\n", second_response) print("second response\n", second_response)
except litellm.RateLimitError:
pass
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
@ -322,6 +325,7 @@ def test_groq_parallel_function_call():
location=function_args.get("location"), location=function_args.get("location"),
unit=function_args.get("unit"), unit=function_args.get("unit"),
) )
messages.append( messages.append(
{ {
"tool_call_id": tool_call.id, "tool_call_id": tool_call.id,
@ -337,27 +341,3 @@ def test_groq_parallel_function_call():
print("second response\n", second_response) print("second response\n", second_response)
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize("model", ["gemini/gemini-1.5-pro"])
def test_simple_function_call_function_param(model):
try:
litellm.set_verbose = True
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
response = completion(
model=model,
messages=messages,
tools=[
{
"type": "function",
"function": {
"name": "plot",
"description": "Generate plots",
},
}
],
tool_choice="auto",
)
print(f"response: {response}")
except Exception as e:
raise e

View file

@ -116,6 +116,8 @@ async def test_async_image_generation_openai():
) )
print(f"response: {response}") print(f"response: {response}")
assert len(response.data) > 0 assert len(response.data) > 0
except litellm.APIError:
pass
except litellm.RateLimitError as e: except litellm.RateLimitError as e:
pass pass
except litellm.ContentPolicyViolationError: except litellm.ContentPolicyViolationError:

View file

@ -2328,6 +2328,11 @@ async def test_master_key_hashing(prisma_client):
from litellm.proxy.proxy_server import user_api_key_cache from litellm.proxy.proxy_server import user_api_key_cache
_team_id = "ishaans-special-team_{}".format(uuid.uuid4()) _team_id = "ishaans-special-team_{}".format(uuid.uuid4())
user_api_key_dict = UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
)
await new_team( await new_team(
NewTeamRequest(team_id=_team_id), NewTeamRequest(team_id=_team_id),
user_api_key_dict=UserAPIKeyAuth( user_api_key_dict=UserAPIKeyAuth(
@ -2343,7 +2348,8 @@ async def test_master_key_hashing(prisma_client):
models=["azure-gpt-3.5"], models=["azure-gpt-3.5"],
team_id=_team_id, team_id=_team_id,
tpm_limit=20, tpm_limit=20,
) ),
user_api_key_dict=user_api_key_dict,
) )
print(_response) print(_response)
assert _response.models == ["azure-gpt-3.5"] assert _response.models == ["azure-gpt-3.5"]

View file

@ -19,7 +19,11 @@ from litellm.types.completion import (
ChatCompletionSystemMessageParam, ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam, ChatCompletionUserMessageParam,
) )
from litellm.utils import get_optional_params, get_optional_params_embeddings from litellm.utils import (
get_optional_params,
get_optional_params_embeddings,
get_optional_params_image_gen,
)
## get_optional_params_embeddings ## get_optional_params_embeddings
### Models: OpenAI, Azure, Bedrock ### Models: OpenAI, Azure, Bedrock
@ -430,7 +434,6 @@ def test_get_optional_params_image_gen():
print(response) print(response)
assert "aws_region_name" not in response assert "aws_region_name" not in response
response = litellm.utils.get_optional_params_image_gen( response = litellm.utils.get_optional_params_image_gen(
aws_region_name="us-east-1", custom_llm_provider="bedrock" aws_region_name="us-east-1", custom_llm_provider="bedrock"
) )
@ -463,3 +466,36 @@ def test_get_optional_params_num_retries():
print(f"mock_client.call_args: {mock_client.call_args}") print(f"mock_client.call_args: {mock_client.call_args}")
assert mock_client.call_args.kwargs["max_retries"] == 10 assert mock_client.call_args.kwargs["max_retries"] == 10
@pytest.mark.parametrize(
"provider",
[
"vertex_ai",
"vertex_ai_beta",
],
)
def test_vertex_safety_settings(provider):
litellm.vertex_ai_safety_settings = [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE",
},
]
optional_params = get_optional_params(
model="gemini-1.5-pro", custom_llm_provider=provider
)
assert len(optional_params) == 1

View file

@ -909,7 +909,7 @@ async def test_create_team_member_add(prisma_client, new_member_method):
await team_member_add( await team_member_add(
data=team_member_add_request, data=team_member_add_request,
user_api_key_dict=UserAPIKeyAuth(), user_api_key_dict=UserAPIKeyAuth(user_role="proxy_admin"),
http_request=Request( http_request=Request(
scope={"type": "http", "path": "/user/new"}, scope={"type": "http", "path": "/user/new"},
), ),
@ -930,6 +930,172 @@ async def test_create_team_member_add(prisma_client, new_member_method):
) )
@pytest.mark.parametrize("team_member_role", ["admin", "user"])
@pytest.mark.parametrize("team_route", ["/team/member_add", "/team/member_delete"])
@pytest.mark.asyncio
async def test_create_team_member_add_team_admin_user_api_key_auth(
prisma_client, team_member_role, team_route
):
import time
from fastapi import Request
from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
from litellm.proxy.proxy_server import (
ProxyException,
hash_token,
user_api_key_auth,
user_api_key_cache,
)
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm, "max_internal_user_budget", 10)
setattr(litellm, "internal_user_budget_duration", "5m")
await litellm.proxy.proxy_server.prisma_client.connect()
user = f"ishaan {uuid.uuid4().hex}"
_team_id = "litellm-test-client-id-new"
user_key = "sk-12345678"
valid_token = UserAPIKeyAuth(
team_id=_team_id,
token=hash_token(user_key),
team_member=Member(role=team_member_role, user_id=user),
last_refreshed_at=time.time(),
)
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
team_obj = LiteLLM_TeamTableCachedObj(
team_id=_team_id,
blocked=False,
last_refreshed_at=time.time(),
metadata={"guardrails": {"modify_guardrails": False}},
)
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
## TEST IF TEAM ADMIN ALLOWED TO CALL /MEMBER_ADD ENDPOINT
import json
from starlette.datastructures import URL
request = Request(scope={"type": "http"})
request._url = URL(url=team_route)
body = {}
json_bytes = json.dumps(body).encode("utf-8")
request._body = json_bytes
## ALLOWED BY USER_API_KEY_AUTH
await user_api_key_auth(request=request, api_key="Bearer " + user_key)
@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"])
@pytest.mark.parametrize("user_role", ["admin", "user"])
@pytest.mark.asyncio
async def test_create_team_member_add_team_admin(
prisma_client, new_member_method, user_role
):
"""
Relevant issue - https://github.com/BerriAI/litellm/issues/5300
Allow team admins to:
- Add and remove team members
- raise error if team member not an existing 'internal_user'
"""
import time
from fastapi import Request
from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
from litellm.proxy.proxy_server import (
HTTPException,
ProxyException,
hash_token,
user_api_key_auth,
user_api_key_cache,
)
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm, "max_internal_user_budget", 10)
setattr(litellm, "internal_user_budget_duration", "5m")
await litellm.proxy.proxy_server.prisma_client.connect()
user = f"ishaan {uuid.uuid4().hex}"
_team_id = "litellm-test-client-id-new"
user_key = "sk-12345678"
valid_token = UserAPIKeyAuth(
team_id=_team_id,
user_id=user,
token=hash_token(user_key),
last_refreshed_at=time.time(),
)
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
team_obj = LiteLLM_TeamTableCachedObj(
team_id=_team_id,
blocked=False,
last_refreshed_at=time.time(),
members_with_roles=[Member(role=user_role, user_id=user)],
metadata={"guardrails": {"modify_guardrails": False}},
)
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
if new_member_method == "user_id":
data = {
"team_id": _team_id,
"member": [{"role": "user", "user_id": user}],
}
elif new_member_method == "user_email":
data = {
"team_id": _team_id,
"member": [{"role": "user", "user_email": user}],
}
team_member_add_request = TeamMemberAddRequest(**data)
with patch(
"litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
new_callable=AsyncMock,
) as mock_litellm_usertable:
mock_client = AsyncMock()
mock_litellm_usertable.upsert = mock_client
mock_litellm_usertable.find_many = AsyncMock(return_value=None)
try:
await team_member_add(
data=team_member_add_request,
user_api_key_dict=valid_token,
http_request=Request(
scope={"type": "http", "path": "/user/new"},
),
)
except HTTPException as e:
if user_role == "user":
assert e.status_code == 403
else:
raise e
mock_client.assert_called()
print(f"mock_client.call_args: {mock_client.call_args}")
print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs))
assert (
mock_client.call_args.kwargs["data"]["create"]["max_budget"]
== litellm.max_internal_user_budget
)
assert (
mock_client.call_args.kwargs["data"]["create"]["budget_duration"]
== litellm.internal_user_budget_duration
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_info_team_list(prisma_client): async def test_user_info_team_list(prisma_client):
"""Assert user_info for admin calls team_list function""" """Assert user_info for admin calls team_list function"""

View file

@ -0,0 +1,24 @@
import json
from typing import Any, Optional, TypedDict, Union
from pydantic import BaseModel
from typing_extensions import (
Protocol,
Required,
Self,
TypeGuard,
get_origin,
override,
runtime_checkable,
)
class OllamaToolCallFunction(
TypedDict
): # follows - https://github.com/ollama/ollama/blob/6bd8a4b0a1ac15d5718f52bbe1cd56f827beb694/api/types.go#L148
name: str
arguments: dict
class OllamaToolCall(TypedDict):
function: OllamaToolCallFunction
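
These TypedDicts only describe the tool-call shape Ollama returns; a minimal sketch of typing a response fragment against them follows. The import path and the tool name/arguments are assumptions for illustration.

```python
# Minimal sketch: typing an Ollama tool call against the TypedDicts above.
# The module path and the tool name/arguments are illustrative assumptions.
from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction

tool_call: OllamaToolCall = {
    "function": OllamaToolCallFunction(
        name="get_current_weather",
        arguments={"location": "San Francisco", "unit": "celsius"},
    )
}
print(tool_call["function"]["name"])  # -> "get_current_weather"
```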

View file

@ -1,6 +1,6 @@
import json import json
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
from typing_extensions import ( from typing_extensions import (
Protocol, Protocol,
@ -305,3 +305,18 @@ class ResponseTuningJob(TypedDict):
] ]
createTime: Optional[str] createTime: Optional[str]
updateTime: Optional[str] updateTime: Optional[str]
class InstanceVideo(TypedDict, total=False):
gcsUri: str
videoSegmentConfig: Tuple[float, float, float]
class Instance(TypedDict, total=False):
text: str
image: Dict[str, str]
video: InstanceVideo
class VertexMultimodalEmbeddingRequest(TypedDict, total=False):
instances: List[Instance]
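
A hedged sketch of assembling a `VertexMultimodalEmbeddingRequest` from these TypedDicts. The import path, the `gcsUri` key inside `image`, and the meaning of the three floats in `videoSegmentConfig` are assumptions, not something this diff states.

```python
# Illustrative only: building a multimodal embedding request by hand.
# GCS URIs, the image dict key, and the segment tuple semantics are assumed.
from litellm.types.llms.vertex_ai import (
    Instance,
    InstanceVideo,
    VertexMultimodalEmbeddingRequest,
)

request: VertexMultimodalEmbeddingRequest = {
    "instances": [
        Instance(text="a photo of a golden retriever"),
        Instance(image={"gcsUri": "gs://my-bucket/dog.png"}),
        Instance(
            video=InstanceVideo(
                gcsUri="gs://my-bucket/dog.mp4",
                videoSegmentConfig=(0.0, 10.0, 5.0),  # assumed: start, end, interval (seconds)
            )
        ),
    ]
}
```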

View file

@ -1116,6 +1116,7 @@ all_litellm_params = [
"cooldown_time", "cooldown_time",
"cache_key", "cache_key",
"max_retries", "max_retries",
"user_continue_message",
] ]
@ -1218,6 +1219,7 @@ class StandardLoggingPayload(TypedDict):
metadata: StandardLoggingMetadata metadata: StandardLoggingMetadata
cache_hit: Optional[bool] cache_hit: Optional[bool]
cache_key: Optional[str] cache_key: Optional[str]
saved_cache_cost: Optional[float]
request_tags: list request_tags: list
end_user: Optional[str] end_user: Optional[str]
requester_ip_address: Optional[str] requester_ip_address: Optional[str]
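
The new `saved_cache_cost` field lands in `StandardLoggingPayload`; a custom logger could surface it roughly as below. Reading the payload from `kwargs["standard_logging_object"]` is an assumption about litellm's logging plumbing rather than something shown in this hunk.

```python
# Sketch of a custom logger reporting the new saved_cache_cost field.
# The "standard_logging_object" kwargs key is an assumption for illustration.
from litellm.integrations.custom_logger import CustomLogger


class CacheSavingsLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        payload = kwargs.get("standard_logging_object") or {}
        saved = payload.get("saved_cache_cost")
        if saved:
            print(f"cache hit saved ~${saved:.6f} on this request")
```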

View file

@ -541,7 +541,7 @@ def function_setup(
call_type == CallTypes.embedding.value call_type == CallTypes.embedding.value
or call_type == CallTypes.aembedding.value or call_type == CallTypes.aembedding.value
): ):
messages = args[1] if len(args) > 1 else kwargs["input"] messages = args[1] if len(args) > 1 else kwargs.get("input", None)
elif ( elif (
call_type == CallTypes.image_generation.value call_type == CallTypes.image_generation.value
or call_type == CallTypes.aimage_generation.value or call_type == CallTypes.aimage_generation.value
@ -2323,6 +2323,7 @@ def get_litellm_params(
output_cost_per_second=None, output_cost_per_second=None,
cooldown_time=None, cooldown_time=None,
text_completion=None, text_completion=None,
user_continue_message=None,
): ):
litellm_params = { litellm_params = {
"acompletion": acompletion, "acompletion": acompletion,
@ -2347,6 +2348,7 @@ def get_litellm_params(
"output_cost_per_second": output_cost_per_second, "output_cost_per_second": output_cost_per_second,
"cooldown_time": cooldown_time, "cooldown_time": cooldown_time,
"text_completion": text_completion, "text_completion": text_completion,
"user_continue_message": user_continue_message,
} }
return litellm_params return litellm_params
@ -3145,7 +3147,6 @@ def get_optional_params(
or model in litellm.vertex_embedding_models or model in litellm.vertex_embedding_models
or model in litellm.vertex_vision_models or model in litellm.vertex_vision_models
): ):
print_verbose(f"(start) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK")
## check if unsupported param passed in ## check if unsupported param passed in
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
@ -3157,9 +3158,8 @@ def get_optional_params(
optional_params=optional_params, optional_params=optional_params,
) )
print_verbose( if litellm.vertex_ai_safety_settings is not None:
f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}" optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
)
elif custom_llm_provider == "gemini": elif custom_llm_provider == "gemini":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
@ -3170,7 +3170,7 @@ def get_optional_params(
optional_params=optional_params, optional_params=optional_params,
model=model, model=model,
) )
elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini": elif custom_llm_provider == "vertex_ai_beta":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
) )
@ -3185,6 +3185,8 @@ def get_optional_params(
else False else False
), ),
) )
if litellm.vertex_ai_safety_settings is not None:
optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
elif ( elif (
custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models
): ):
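
The hunks above forward a module-level `litellm.vertex_ai_safety_settings` into `optional_params` for the `vertex_ai` and `vertex_ai_beta` paths. A sketch of setting it, with category/threshold values taken from Google's documented enum names rather than from this diff, and an assumed model name:

```python
# Sketch: module-level Vertex safety settings, forwarded on every Vertex call
# by get_optional_params. Category/threshold values and model name are illustrative.
import litellm

litellm.vertex_ai_safety_settings = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
]

response = litellm.completion(
    model="vertex_ai_beta/gemini-1.5-pro",
    messages=[{"role": "user", "content": "Hello!"}],
)
```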
@ -4219,6 +4221,7 @@ def get_supported_openai_params(
"presence_penalty", "presence_penalty",
"stop", "stop",
"n", "n",
"extra_headers",
] ]
elif custom_llm_provider == "cohere_chat": elif custom_llm_provider == "cohere_chat":
return [ return [
@ -4233,6 +4236,7 @@ def get_supported_openai_params(
"tools", "tools",
"tool_choice", "tool_choice",
"seed", "seed",
"extra_headers",
] ]
elif custom_llm_provider == "maritalk": elif custom_llm_provider == "maritalk":
return [ return [
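
With `extra_headers` now listed among the supported OpenAI params for `cohere` and `cohere_chat`, a call could pass custom headers roughly as follows; the model alias and header name are placeholders, and whether the header reaches Cohere unchanged depends on litellm's transport layer.

```python
# Sketch: passing extra headers on a Cohere chat completion.
# Model alias and header name/value are placeholders.
import litellm

response = litellm.completion(
    model="cohere/command-r",
    messages=[{"role": "user", "content": "Hi"}],
    extra_headers={"X-Client-Name": "my-app"},
)
```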
@ -7121,6 +7125,14 @@ def exception_type(
llm_provider="bedrock", llm_provider="bedrock",
response=original_exception.response, response=original_exception.response,
) )
elif "A conversation must start with a user message." in error_str:
exception_mapping_worked = True
raise BadRequestError(
message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.\nFor Proxy: do via `litellm_settings::modify_params: True` or user_continue_message under `litellm_params`",
model=model,
llm_provider="bedrock",
response=original_exception.response,
)
elif ( elif (
"Unable to locate credentials" in error_str "Unable to locate credentials" in error_str
or "The security token included in the request is invalid" or "The security token included in the request is invalid"

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "litellm" name = "litellm"
version = "1.44.1" version = "1.44.2"
description = "Library to easily interface with LLM API providers" description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"] authors = ["BerriAI"]
license = "MIT" license = "MIT"
@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.commitizen] [tool.commitizen]
version = "1.44.1" version = "1.44.2"
version_files = [ version_files = [
"pyproject.toml:^version" "pyproject.toml:^version"
] ]